From 0ef8393738c20b2ae2fc69512c3d6844c9865b11 Mon Sep 17 00:00:00 2001 From: mpyw Date: Sun, 11 Aug 2024 18:13:49 +0900 Subject: [PATCH 01/10] feat: Implement `get_transaction_depth` for drivers --- sqlx-core/src/any/connection/backend.rs | 9 +++++++++ sqlx-core/src/any/connection/mod.rs | 8 +++++++- sqlx-core/src/connection.rs | 20 ++++++++++++++++++++ sqlx-mysql/src/any.rs | 7 ++++++- sqlx-mysql/src/connection/mod.rs | 6 ++++++ sqlx-postgres/src/any.rs | 7 ++++++- sqlx-postgres/src/connection/mod.rs | 6 ++++++ sqlx-sqlite/src/any.rs | 7 ++++++- sqlx-sqlite/src/connection/mod.rs | 6 ++++++ src/lib.rs | 4 +++- 10 files changed, 75 insertions(+), 5 deletions(-) diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index b30cbe83f3..a2393832dc 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -1,5 +1,6 @@ use crate::any::{Any, AnyArguments, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo}; use crate::describe::Describe; +use crate::Error; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; @@ -34,6 +35,14 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { fn start_rollback(&mut self); + /// Returns the current transaction depth asynchronously. + /// + /// Transaction depth indicates the level of nested transactions: + /// - Level 0: No active transaction. + /// - Level 1: A transaction is active. + /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. + fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result>; + /// The number of statements currently cached in the connection. fn cached_statements_size(&self) -> usize { 0 diff --git a/sqlx-core/src/any/connection/mod.rs b/sqlx-core/src/any/connection/mod.rs index b6f795848a..c2274d9885 100644 --- a/sqlx-core/src/any/connection/mod.rs +++ b/sqlx-core/src/any/connection/mod.rs @@ -1,7 +1,7 @@ use futures_core::future::BoxFuture; use crate::any::{Any, AnyConnectOptions}; -use crate::connection::{ConnectOptions, Connection}; +use crate::connection::{AsyncTransactionDepth, ConnectOptions, Connection}; use crate::error::Error; use crate::database::Database; @@ -112,3 +112,9 @@ impl Connection for AnyConnection { self.backend.should_flush() } } + +impl AsyncTransactionDepth for AnyConnection { + fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { + self.backend.get_transaction_depth() + } +} diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index ce2aa6c629..155f29d268 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -236,3 +236,23 @@ pub trait ConnectOptions: 'static + Send + Sync + FromStr + Debug + .log_slow_statements(LevelFilter::Off, Duration::default()) } } + +pub trait TransactionDepth { + /// Returns the current transaction depth synchronously. + /// + /// Transaction depth indicates the level of nested transactions: + /// - Level 0: No active transaction. + /// - Level 1: A transaction is active. + /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. + fn get_transaction_depth(&self) -> usize; +} + +pub trait AsyncTransactionDepth { + /// Returns the current transaction depth asynchronously. + /// + /// Transaction depth indicates the level of nested transactions: + /// - Level 0: No active transaction. + /// - Level 1: A transaction is active. + /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. + fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result>; +} diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index fa8d34f8db..c688ad0d48 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -11,11 +11,12 @@ use sqlx_core::any::{ Any, AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo, AnyTypeInfoKind, }; -use sqlx_core::connection::Connection; +use sqlx_core::connection::{Connection, TransactionDepth}; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::transaction::TransactionManager; +use sqlx_core::Error; use std::future; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = MySql); @@ -53,6 +54,10 @@ impl AnyConnectionBackend for MySqlConnection { MySqlTransactionManager::start_rollback(self) } + fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { + Box::pin(async { Ok(TransactionDepth::get_transaction_depth(self)) }) + } + fn shrink_buffers(&mut self) { Connection::shrink_buffers(self); } diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index c4978a7701..b2a9a94b42 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -118,3 +118,9 @@ impl Connection for MySqlConnection { self.inner.stream.shrink_buffers(); } } + +impl TransactionDepth for MySqlConnection { + fn get_transaction_depth(&self) -> usize { + self.inner.transaction_depth + } +} diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index 7eae4bcb73..d0e0459721 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -10,12 +10,13 @@ use std::future; pub use sqlx_core::any::*; use crate::type_info::PgType; -use sqlx_core::connection::Connection; +use sqlx_core::connection::{Connection, TransactionDepth}; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::ext::ustr::UStr; use sqlx_core::transaction::TransactionManager; +use sqlx_core::Error; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Postgres); @@ -52,6 +53,10 @@ impl AnyConnectionBackend for PgConnection { PgTransactionManager::start_rollback(self) } + fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { + Box::pin(async { Ok(TransactionDepth::get_transaction_depth(self)) }) + } + fn shrink_buffers(&mut self) { Connection::shrink_buffers(self); } diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index 1c7a468240..8e3143fec7 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -224,3 +224,9 @@ impl AsMut for PgConnection { self } } + +impl TransactionDepth for PgConnection { + fn get_transaction_depth(&self) -> usize { + self.transaction_depth + } +} diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index 01600d9931..c14a4eaa29 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -12,11 +12,12 @@ use sqlx_core::any::{ }; use crate::type_info::DataType; -use sqlx_core::connection::{ConnectOptions, Connection}; +use sqlx_core::connection::{AsyncTransactionDepth, ConnectOptions, Connection}; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::transaction::TransactionManager; +use sqlx_core::Error; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Sqlite); @@ -53,6 +54,10 @@ impl AnyConnectionBackend for SqliteConnection { SqliteTransactionManager::start_rollback(self) } + fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { + AsyncTransactionDepth::get_transaction_depth(self) + } + fn shrink_buffers(&mut self) { // NO-OP. } diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 3588b94f82..a58428e14a 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -426,3 +426,9 @@ impl Statements { self.temp = None; } } + +impl AsyncTransactionDepth for SqliteConnection { + fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { + Box::pin(async { Ok(self.lock_handle().await?.guard.transaction_depth) }) + } +} diff --git a/src/lib.rs b/src/lib.rs index d675fa11c3..7a0a7c5fa0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,9 @@ pub use sqlx_core::acquire::Acquire; pub use sqlx_core::arguments::{Arguments, IntoArguments}; pub use sqlx_core::column::Column; pub use sqlx_core::column::ColumnIndex; -pub use sqlx_core::connection::{ConnectOptions, Connection}; +pub use sqlx_core::connection::{ + AsyncTransactionDepth, ConnectOptions, Connection, TransactionDepth, +}; pub use sqlx_core::database::{self, Database}; pub use sqlx_core::describe::Describe; pub use sqlx_core::executor::{Execute, Executor}; From 4f9d3f9579f5c384957bf6ec652d90176a5219d7 Mon Sep 17 00:00:00 2001 From: mpyw Date: Sun, 11 Aug 2024 18:21:16 +0900 Subject: [PATCH 02/10] test: Verify `get_transaction_depth()` on postgres --- tests/postgres/postgres.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 7edb5a7a8c..cdc518c666 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -6,6 +6,7 @@ use sqlx::postgres::{ PgPoolOptions, PgRow, PgSeverity, Postgres, }; use sqlx::{Column, Connection, Executor, Row, Statement, TypeInfo}; +use sqlx_core::connection::TransactionDepth; use sqlx_core::{bytes::Bytes, error::BoxDynError}; use sqlx_test::{new, pool, setup_if_needed}; use std::env; @@ -515,6 +516,7 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> { #[sqlx_macros::test] async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { let mut conn = new::().await?; + assert_eq!(conn.get_transaction_depth(), 0); conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_2523 (id INTEGER PRIMARY KEY)") .await?; @@ -523,6 +525,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // begin let mut tx = conn.begin().await?; // transaction + assert_eq!(conn.get_transaction_depth(), 1); // insert a user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") @@ -532,6 +535,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // begin once more let mut tx2 = tx.begin().await?; // savepoint + assert_eq!(conn.get_transaction_depth(), 2); // insert another user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") @@ -541,6 +545,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // never mind, rollback tx2.rollback().await?; // roll that one back + assert_eq!(conn.get_transaction_depth(), 1); // did we really? let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") @@ -551,6 +556,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // actually, commit tx.commit().await?; + assert_eq!(conn.get_transaction_depth(), 0); // did we really? let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") From 8f450196e0e9a8d1bb171cc805452a84d91cc0ce Mon Sep 17 00:00:00 2001 From: mpyw Date: Mon, 12 Aug 2024 14:59:50 +0900 Subject: [PATCH 03/10] Refactor: `TransactionManager` delegation without BC SQLite implementation is currently WIP --- sqlx-core/src/any/connection/backend.rs | 5 ++- sqlx-core/src/any/connection/mod.rs | 12 +++---- sqlx-core/src/any/transaction.rs | 5 +++ sqlx-core/src/connection.rs | 43 +++++++++++++------------ sqlx-core/src/transaction.rs | 8 +++++ sqlx-mysql/src/any.rs | 7 ++-- sqlx-mysql/src/connection/mod.rs | 10 +++--- sqlx-mysql/src/transaction.rs | 4 +++ sqlx-postgres/src/any.rs | 7 ++-- sqlx-postgres/src/connection/mod.rs | 10 +++--- sqlx-postgres/src/transaction.rs | 5 +++ sqlx-sqlite/src/any.rs | 7 ++-- sqlx-sqlite/src/connection/mod.rs | 10 +++--- sqlx-sqlite/src/transaction.rs | 9 ++++-- src/lib.rs | 4 +-- tests/postgres/postgres.rs | 1 - 16 files changed, 80 insertions(+), 67 deletions(-) diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index a2393832dc..e4e752c07b 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -1,6 +1,5 @@ use crate::any::{Any, AnyArguments, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo}; use crate::describe::Describe; -use crate::Error; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; @@ -35,13 +34,13 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { fn start_rollback(&mut self); - /// Returns the current transaction depth asynchronously. + /// Returns the current transaction depth. /// /// Transaction depth indicates the level of nested transactions: /// - Level 0: No active transaction. /// - Level 1: A transaction is active. /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. - fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result>; + fn get_transaction_depth(&self) -> usize; /// The number of statements currently cached in the connection. fn cached_statements_size(&self) -> usize { diff --git a/sqlx-core/src/any/connection/mod.rs b/sqlx-core/src/any/connection/mod.rs index c2274d9885..ba06d865f1 100644 --- a/sqlx-core/src/any/connection/mod.rs +++ b/sqlx-core/src/any/connection/mod.rs @@ -1,7 +1,7 @@ use futures_core::future::BoxFuture; use crate::any::{Any, AnyConnectOptions}; -use crate::connection::{AsyncTransactionDepth, ConnectOptions, Connection}; +use crate::connection::{ConnectOptions, Connection}; use crate::error::Error; use crate::database::Database; @@ -90,6 +90,10 @@ impl Connection for AnyConnection { Transaction::begin(self) } + fn get_transaction_depth(&self) -> usize { + self.backend.get_transaction_depth() + } + fn cached_statements_size(&self) -> usize { self.backend.cached_statements_size() } @@ -112,9 +116,3 @@ impl Connection for AnyConnection { self.backend.should_flush() } } - -impl AsyncTransactionDepth for AnyConnection { - fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { - self.backend.get_transaction_depth() - } -} diff --git a/sqlx-core/src/any/transaction.rs b/sqlx-core/src/any/transaction.rs index fce4175626..22937a0bd1 100644 --- a/sqlx-core/src/any/transaction.rs +++ b/sqlx-core/src/any/transaction.rs @@ -1,6 +1,7 @@ use futures_util::future::BoxFuture; use crate::any::{Any, AnyConnection}; +use crate::database::Database; use crate::error::Error; use crate::transaction::TransactionManager; @@ -24,4 +25,8 @@ impl TransactionManager for AnyTransactionManager { fn start_rollback(conn: &mut AnyConnection) { conn.backend.start_rollback() } + + fn get_transaction_depth(conn: &::Connection) -> usize { + conn.backend.get_transaction_depth() + } } diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index 155f29d268..9faa71effe 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -1,7 +1,7 @@ use crate::database::{Database, HasStatementCache}; use crate::error::Error; -use crate::transaction::Transaction; +use crate::transaction::{Transaction, TransactionManager}; use futures_core::future::BoxFuture; use log::LevelFilter; use std::fmt::Debug; @@ -49,6 +49,27 @@ pub trait Connection: Send { where Self: Sized; + /// Returns the current transaction depth synchronously. + /// + /// Transaction depth indicates the level of nested transactions: + /// - Level 0: No active transaction. + /// - Level 1: A transaction is active. + /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. + fn get_transaction_depth(&self) -> usize { + // Fallback implementation to avoid breaking changes + ::TransactionManager::get_transaction_depth(self) + } + + /// Checks if the connection is currently in a transaction. + /// + /// This method returns `true` if the current transaction depth is greater than 0, + /// indicating that a transaction is active. It returns `false` if the transaction depth is 0, + /// meaning no transaction is active. + #[inline] + fn is_in_transaction(&self) -> bool { + self.get_transaction_depth() != 0 + } + /// Execute the function inside a transaction. /// /// If the function returns an error, the transaction will be rolled back. If it does not @@ -236,23 +257,3 @@ pub trait ConnectOptions: 'static + Send + Sync + FromStr + Debug + .log_slow_statements(LevelFilter::Off, Duration::default()) } } - -pub trait TransactionDepth { - /// Returns the current transaction depth synchronously. - /// - /// Transaction depth indicates the level of nested transactions: - /// - Level 0: No active transaction. - /// - Level 1: A transaction is active. - /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. - fn get_transaction_depth(&self) -> usize; -} - -pub trait AsyncTransactionDepth { - /// Returns the current transaction depth asynchronously. - /// - /// Transaction depth indicates the level of nested transactions: - /// - Level 0: No active transaction. - /// - Level 1: A transaction is active. - /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. - fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result>; -} diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index 9cd38aab3a..48b1463f2a 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -32,6 +32,14 @@ pub trait TransactionManager { /// Starts to abort the active transaction or restore from the most recent snapshot. fn start_rollback(conn: &mut ::Connection); + + /// Returns the current transaction depth. + /// + /// Transaction depth indicates the level of nested transactions: + /// - Level 0: No active transaction. + /// - Level 1: A transaction is active. + /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. + fn get_transaction_depth(conn: &::Connection) -> usize; } /// An in-progress database transaction or savepoint. diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index c688ad0d48..709fdafcea 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -11,12 +11,11 @@ use sqlx_core::any::{ Any, AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo, AnyTypeInfoKind, }; -use sqlx_core::connection::{Connection, TransactionDepth}; +use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::transaction::TransactionManager; -use sqlx_core::Error; use std::future; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = MySql); @@ -54,8 +53,8 @@ impl AnyConnectionBackend for MySqlConnection { MySqlTransactionManager::start_rollback(self) } - fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { - Box::pin(async { Ok(TransactionDepth::get_transaction_depth(self)) }) + fn get_transaction_depth(&self) -> usize { + MySqlTransactionManager::get_transaction_depth(self) } fn shrink_buffers(&mut self) { diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index b2a9a94b42..af50188ff6 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -114,13 +114,11 @@ impl Connection for MySqlConnection { Transaction::begin(self) } - fn shrink_buffers(&mut self) { - self.inner.stream.shrink_buffers(); - } -} - -impl TransactionDepth for MySqlConnection { fn get_transaction_depth(&self) -> usize { self.inner.transaction_depth } + + fn shrink_buffers(&mut self) { + self.inner.stream.shrink_buffers(); + } } diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index 99d6526392..5a4c751b2c 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -64,4 +64,8 @@ impl TransactionManager for MySqlTransactionManager { conn.inner.transaction_depth = depth - 1; } } + + fn get_transaction_depth(conn: &MySqlConnection) -> usize { + conn.inner.transaction_depth + } } diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index d0e0459721..bcfef8dce1 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -10,13 +10,12 @@ use std::future; pub use sqlx_core::any::*; use crate::type_info::PgType; -use sqlx_core::connection::{Connection, TransactionDepth}; +use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::ext::ustr::UStr; use sqlx_core::transaction::TransactionManager; -use sqlx_core::Error; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Postgres); @@ -53,8 +52,8 @@ impl AnyConnectionBackend for PgConnection { PgTransactionManager::start_rollback(self) } - fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { - Box::pin(async { Ok(TransactionDepth::get_transaction_depth(self)) }) + fn get_transaction_depth(&mut self) -> usize { + PgTransactionManager::get_transaction_depth(self) } fn shrink_buffers(&mut self) { diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index 8e3143fec7..dcfc6206a5 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -171,6 +171,10 @@ impl Connection for PgConnection { Transaction::begin(self) } + fn get_transaction_depth(&self) -> usize { + self.transaction_depth + } + fn cached_statements_size(&self) -> usize { self.cache_statement.len() } @@ -224,9 +228,3 @@ impl AsMut for PgConnection { self } } - -impl TransactionDepth for PgConnection { - fn get_transaction_depth(&self) -> usize { - self.transaction_depth - } -} diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index 02028624e1..40a2de2e6f 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -1,4 +1,5 @@ use futures_core::future::BoxFuture; +use sqlx_core::database::Database; use crate::error::Error; use crate::executor::Executor; @@ -59,6 +60,10 @@ impl TransactionManager for PgTransactionManager { conn.transaction_depth -= 1; } } + + fn get_transaction_depth(conn: &::Connection) -> usize { + conn.transaction_depth + } } struct Rollback<'c> { diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index c14a4eaa29..29fb6bd7a6 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -12,12 +12,11 @@ use sqlx_core::any::{ }; use crate::type_info::DataType; -use sqlx_core::connection::{AsyncTransactionDepth, ConnectOptions, Connection}; +use sqlx_core::connection::{ConnectOptions, Connection}; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::transaction::TransactionManager; -use sqlx_core::Error; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Sqlite); @@ -54,8 +53,8 @@ impl AnyConnectionBackend for SqliteConnection { SqliteTransactionManager::start_rollback(self) } - fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { - AsyncTransactionDepth::get_transaction_depth(self) + fn get_transaction_depth(&self) -> usize { + SqliteTransactionManager::get_transaction_depth(self) } fn shrink_buffers(&mut self) { diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index a58428e14a..f12f49edbd 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -210,6 +210,10 @@ impl Connection for SqliteConnection { Transaction::begin(self) } + fn get_transaction_depth(&self) -> usize { + todo!() + } + fn cached_statements_size(&self) -> usize { self.worker .shared @@ -426,9 +430,3 @@ impl Statements { self.temp = None; } } - -impl AsyncTransactionDepth for SqliteConnection { - fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { - Box::pin(async { Ok(self.lock_handle().await?.guard.transaction_depth) }) - } -} diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index 24eaca51b1..dafc33359d 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -1,9 +1,10 @@ use futures_core::future::BoxFuture; - -use crate::{Sqlite, SqliteConnection}; +use sqlx_core::database::Database; use sqlx_core::error::Error; use sqlx_core::transaction::TransactionManager; +use crate::{Sqlite, SqliteConnection}; + /// Implementation of [`TransactionManager`] for SQLite. pub struct SqliteTransactionManager; @@ -25,4 +26,8 @@ impl TransactionManager for SqliteTransactionManager { fn start_rollback(conn: &mut SqliteConnection) { conn.worker.start_rollback().ok(); } + + fn get_transaction_depth(_conn: &::Connection) -> usize { + todo!() + } } diff --git a/src/lib.rs b/src/lib.rs index 7a0a7c5fa0..d675fa11c3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,9 +5,7 @@ pub use sqlx_core::acquire::Acquire; pub use sqlx_core::arguments::{Arguments, IntoArguments}; pub use sqlx_core::column::Column; pub use sqlx_core::column::ColumnIndex; -pub use sqlx_core::connection::{ - AsyncTransactionDepth, ConnectOptions, Connection, TransactionDepth, -}; +pub use sqlx_core::connection::{ConnectOptions, Connection}; pub use sqlx_core::database::{self, Database}; pub use sqlx_core::describe::Describe; pub use sqlx_core::executor::{Execute, Executor}; diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index cdc518c666..8e5d0d227e 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -6,7 +6,6 @@ use sqlx::postgres::{ PgPoolOptions, PgRow, PgSeverity, Postgres, }; use sqlx::{Column, Connection, Executor, Row, Statement, TypeInfo}; -use sqlx_core::connection::TransactionDepth; use sqlx_core::{bytes::Bytes, error::BoxDynError}; use sqlx_test::{new, pool, setup_if_needed}; use std::env; From 9409e5e9d32cde7928493ac61864f14cb5e31760 Mon Sep 17 00:00:00 2001 From: mpyw Date: Mon, 12 Aug 2024 21:09:52 +0900 Subject: [PATCH 04/10] Fix: Avoid breaking changes on `AnyConnectionBackend` --- sqlx-core/src/any/connection/backend.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index e4e752c07b..5d8b2709fd 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -40,7 +40,9 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { /// - Level 0: No active transaction. /// - Level 1: A transaction is active. /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. - fn get_transaction_depth(&self) -> usize; + fn get_transaction_depth(&self) -> usize { + unimplemented!("get_transaction_depth() is not implemented for this backend. This is a provided method to avoid a breaking change, but it will become a required method in version 0.9 and later."); + } /// The number of statements currently cached in the connection. fn cached_statements_size(&self) -> usize { From 10d0aea44ba5d72de6c179bfbfb868e9294355b8 Mon Sep 17 00:00:00 2001 From: mpyw Date: Mon, 12 Aug 2024 21:34:12 +0900 Subject: [PATCH 05/10] Refactor: Remove verbose `SqliteConnection` typing --- sqlx-sqlite/src/transaction.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index dafc33359d..cbde09ae68 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -1,5 +1,4 @@ use futures_core::future::BoxFuture; -use sqlx_core::database::Database; use sqlx_core::error::Error; use sqlx_core::transaction::TransactionManager; @@ -27,7 +26,7 @@ impl TransactionManager for SqliteTransactionManager { conn.worker.start_rollback().ok(); } - fn get_transaction_depth(_conn: &::Connection) -> usize { + fn get_transaction_depth(_conn: &SqliteConnection) -> usize { todo!() } } From a66787d36d62876b55475ef2326d17bade817aed Mon Sep 17 00:00:00 2001 From: mpyw Date: Mon, 12 Aug 2024 22:39:14 +0900 Subject: [PATCH 06/10] Feat: Implementation for SQLite I have included `AtomicUsize` in `WorkerSharedState`. Ideally, it is not desirable to execute `load` and `fetch_add` in two separate steps, but we decided to allow it here since there is only one thread writing. To prevent writing from other threads, the field itself was made private, and a getter method was provided with `pub(crate)`. --- sqlx-sqlite/src/connection/establish.rs | 1 - sqlx-sqlite/src/connection/mod.rs | 5 +---- sqlx-sqlite/src/connection/worker.rs | 22 +++++++++++++++------- sqlx-sqlite/src/transaction.rs | 4 ++-- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/sqlx-sqlite/src/connection/establish.rs b/sqlx-sqlite/src/connection/establish.rs index 6438b6b7f4..698e64fce7 100644 --- a/sqlx-sqlite/src/connection/establish.rs +++ b/sqlx-sqlite/src/connection/establish.rs @@ -292,7 +292,6 @@ impl EstablishParams { Ok(ConnectionState { handle, statements: Statements::new(self.statement_cache_capacity), - transaction_depth: 0, log_settings: self.log_settings.clone(), progress_handler_callback: None, update_hook_callback: None, diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index f12f49edbd..0234d5d914 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -94,9 +94,6 @@ unsafe impl Send for UpdateHookHandler {} pub(crate) struct ConnectionState { pub(crate) handle: ConnectionHandle, - // transaction status - pub(crate) transaction_depth: usize, - pub(crate) statements: Statements, log_settings: LogSettings, @@ -211,7 +208,7 @@ impl Connection for SqliteConnection { } fn get_transaction_depth(&self) -> usize { - todo!() + self.worker.shared.get_transaction_depth() } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index a01de2419c..c0b6973b73 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -34,10 +34,17 @@ pub(crate) struct ConnectionWorker { } pub(crate) struct WorkerSharedState { + transaction_depth: AtomicUsize, pub(crate) cached_statements_size: AtomicUsize, pub(crate) conn: Mutex, } +impl WorkerSharedState { + pub(crate) fn get_transaction_depth(&self) -> usize { + self.transaction_depth.load(Ordering::Acquire) + } +} + enum Command { Prepare { query: Box, @@ -93,6 +100,7 @@ impl ConnectionWorker { }; let shared = Arc::new(WorkerSharedState { + transaction_depth: AtomicUsize::new(0), cached_statements_size: AtomicUsize::new(0), // note: must be fair because in `Command::UnlockDb` we unlock the mutex // and then immediately try to relock it; an unfair mutex would immediately @@ -181,12 +189,12 @@ impl ConnectionWorker { update_cached_statements_size(&conn, &shared.cached_statements_size); } Command::Begin { tx } => { - let depth = conn.transaction_depth; + let depth = shared.transaction_depth.load(Ordering::Acquire); let res = conn.handle .exec(begin_ansi_transaction_sql(depth)) .map(|_| { - conn.transaction_depth += 1; + shared.transaction_depth.fetch_add(1, Ordering::Release); }); let res_ok = res.is_ok(); @@ -199,7 +207,7 @@ impl ConnectionWorker { .handle .exec(rollback_ansi_transaction_sql(depth + 1)) .map(|_| { - conn.transaction_depth -= 1; + shared.transaction_depth.fetch_sub(1, Ordering::Release); }) { // The rollback failed. To prevent leaving the connection @@ -211,13 +219,13 @@ impl ConnectionWorker { } } Command::Commit { tx } => { - let depth = conn.transaction_depth; + let depth = shared.transaction_depth.load(Ordering::Acquire); let res = if depth > 0 { conn.handle .exec(commit_ansi_transaction_sql(depth)) .map(|_| { - conn.transaction_depth -= 1; + shared.transaction_depth.fetch_sub(1, Ordering::Release); }) } else { Ok(()) @@ -237,13 +245,13 @@ impl ConnectionWorker { continue; } - let depth = conn.transaction_depth; + let depth = shared.transaction_depth.load(Ordering::Acquire); let res = if depth > 0 { conn.handle .exec(rollback_ansi_transaction_sql(depth)) .map(|_| { - conn.transaction_depth -= 1; + shared.transaction_depth.fetch_sub(1, Ordering::Release); }) } else { Ok(()) diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index cbde09ae68..6bce5b30cf 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -26,7 +26,7 @@ impl TransactionManager for SqliteTransactionManager { conn.worker.start_rollback().ok(); } - fn get_transaction_depth(_conn: &SqliteConnection) -> usize { - todo!() + fn get_transaction_depth(conn: &SqliteConnection) -> usize { + conn.worker.shared.get_transaction_depth() } } From 54c73b63cd41ec12cf4a34f71e2d96491f465017 Mon Sep 17 00:00:00 2001 From: mpyw Date: Mon, 12 Aug 2024 22:47:21 +0900 Subject: [PATCH 07/10] Refactor: Same approach for `cached_statements_size` ref: a66787d36d62876b55475ef2326d17bade817aed --- sqlx-sqlite/src/connection/mod.rs | 5 +---- sqlx-sqlite/src/connection/worker.rs | 6 +++++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 0234d5d914..0950b31db1 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -212,10 +212,7 @@ impl Connection for SqliteConnection { } fn cached_statements_size(&self) -> usize { - self.worker - .shared - .cached_statements_size - .load(std::sync::atomic::Ordering::Acquire) + self.worker.shared.get_cached_statements_size() } fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index c0b6973b73..6dda814a0a 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -35,7 +35,7 @@ pub(crate) struct ConnectionWorker { pub(crate) struct WorkerSharedState { transaction_depth: AtomicUsize, - pub(crate) cached_statements_size: AtomicUsize, + cached_statements_size: AtomicUsize, pub(crate) conn: Mutex, } @@ -43,6 +43,10 @@ impl WorkerSharedState { pub(crate) fn get_transaction_depth(&self) -> usize { self.transaction_depth.load(Ordering::Acquire) } + + pub(crate) fn get_cached_statements_size(&self) -> usize { + self.cached_statements_size.load(Ordering::Acquire) + } } enum Command { From b780ed7788f717334aeea90b91bc944aa78fb015 Mon Sep 17 00:00:00 2001 From: mpyw Date: Tue, 13 Aug 2024 11:16:51 +0900 Subject: [PATCH 08/10] Fix: Add missing `is_in_transaction` for backend --- sqlx-core/src/any/connection/backend.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index 5d8b2709fd..6ba5ad92b2 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -44,6 +44,16 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { unimplemented!("get_transaction_depth() is not implemented for this backend. This is a provided method to avoid a breaking change, but it will become a required method in version 0.9 and later."); } + /// Checks if the connection is currently in a transaction. + /// + /// This method returns `true` if the current transaction depth is greater than 0, + /// indicating that a transaction is active. It returns `false` if the transaction depth is 0, + /// meaning no transaction is active. + #[inline] + fn is_in_transaction(&self) -> bool { + self.get_transaction_depth() != 0 + } + /// The number of statements currently cached in the connection. fn cached_statements_size(&self) -> usize { 0 From c866c5b8017948a00b7fe6cf9330ee7526af14de Mon Sep 17 00:00:00 2001 From: mpyw Date: Tue, 13 Aug 2024 17:01:44 +0900 Subject: [PATCH 09/10] Doc: Remove verbose "synchronously" word --- sqlx-core/src/connection.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index 9faa71effe..1f86f6c1a9 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -49,7 +49,7 @@ pub trait Connection: Send { where Self: Sized; - /// Returns the current transaction depth synchronously. + /// Returns the current transaction depth. /// /// Transaction depth indicates the level of nested transactions: /// - Level 0: No active transaction. From 85ad53990788fd0a4cd132953b6a7256ec362473 Mon Sep 17 00:00:00 2001 From: mpyw Date: Tue, 13 Aug 2024 21:49:50 +0900 Subject: [PATCH 10/10] Fix: Remove useless `mut` qualifier --- sqlx-postgres/src/any.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index bcfef8dce1..bc74f58da2 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -52,7 +52,7 @@ impl AnyConnectionBackend for PgConnection { PgTransactionManager::start_rollback(self) } - fn get_transaction_depth(&mut self) -> usize { + fn get_transaction_depth(&self) -> usize { PgTransactionManager::get_transaction_depth(self) }