diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index b30cbe83f3..6ba5ad92b2 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -34,6 +34,26 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { fn start_rollback(&mut self); + /// Returns the current transaction depth. + /// + /// Transaction depth indicates the level of nested transactions: + /// - Level 0: No active transaction. + /// - Level 1: A transaction is active. + /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. + fn get_transaction_depth(&self) -> usize { + unimplemented!("get_transaction_depth() is not implemented for this backend. This is a provided method to avoid a breaking change, but it will become a required method in version 0.9 and later."); + } + + /// Checks if the connection is currently in a transaction. + /// + /// This method returns `true` if the current transaction depth is greater than 0, + /// indicating that a transaction is active. It returns `false` if the transaction depth is 0, + /// meaning no transaction is active. + #[inline] + fn is_in_transaction(&self) -> bool { + self.get_transaction_depth() != 0 + } + /// The number of statements currently cached in the connection. fn cached_statements_size(&self) -> usize { 0 diff --git a/sqlx-core/src/any/connection/mod.rs b/sqlx-core/src/any/connection/mod.rs index b6f795848a..ba06d865f1 100644 --- a/sqlx-core/src/any/connection/mod.rs +++ b/sqlx-core/src/any/connection/mod.rs @@ -90,6 +90,10 @@ impl Connection for AnyConnection { Transaction::begin(self) } + fn get_transaction_depth(&self) -> usize { + self.backend.get_transaction_depth() + } + fn cached_statements_size(&self) -> usize { self.backend.cached_statements_size() } diff --git a/sqlx-core/src/any/transaction.rs b/sqlx-core/src/any/transaction.rs index fce4175626..22937a0bd1 100644 --- a/sqlx-core/src/any/transaction.rs +++ b/sqlx-core/src/any/transaction.rs @@ -1,6 +1,7 @@ use futures_util::future::BoxFuture; use crate::any::{Any, AnyConnection}; +use crate::database::Database; use crate::error::Error; use crate::transaction::TransactionManager; @@ -24,4 +25,8 @@ impl TransactionManager for AnyTransactionManager { fn start_rollback(conn: &mut AnyConnection) { conn.backend.start_rollback() } + + fn get_transaction_depth(conn: &::Connection) -> usize { + conn.backend.get_transaction_depth() + } } diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index ce2aa6c629..1f86f6c1a9 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -1,7 +1,7 @@ use crate::database::{Database, HasStatementCache}; use crate::error::Error; -use crate::transaction::Transaction; +use crate::transaction::{Transaction, TransactionManager}; use futures_core::future::BoxFuture; use log::LevelFilter; use std::fmt::Debug; @@ -49,6 +49,27 @@ pub trait Connection: Send { where Self: Sized; + /// Returns the current transaction depth. + /// + /// Transaction depth indicates the level of nested transactions: + /// - Level 0: No active transaction. + /// - Level 1: A transaction is active. + /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. + fn get_transaction_depth(&self) -> usize { + // Fallback implementation to avoid breaking changes + ::TransactionManager::get_transaction_depth(self) + } + + /// Checks if the connection is currently in a transaction. + /// + /// This method returns `true` if the current transaction depth is greater than 0, + /// indicating that a transaction is active. It returns `false` if the transaction depth is 0, + /// meaning no transaction is active. + #[inline] + fn is_in_transaction(&self) -> bool { + self.get_transaction_depth() != 0 + } + /// Execute the function inside a transaction. /// /// If the function returns an error, the transaction will be rolled back. If it does not diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index 9cd38aab3a..48b1463f2a 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -32,6 +32,14 @@ pub trait TransactionManager { /// Starts to abort the active transaction or restore from the most recent snapshot. fn start_rollback(conn: &mut ::Connection); + + /// Returns the current transaction depth. + /// + /// Transaction depth indicates the level of nested transactions: + /// - Level 0: No active transaction. + /// - Level 1: A transaction is active. + /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. + fn get_transaction_depth(conn: &::Connection) -> usize; } /// An in-progress database transaction or savepoint. diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index fa8d34f8db..709fdafcea 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -53,6 +53,10 @@ impl AnyConnectionBackend for MySqlConnection { MySqlTransactionManager::start_rollback(self) } + fn get_transaction_depth(&self) -> usize { + MySqlTransactionManager::get_transaction_depth(self) + } + fn shrink_buffers(&mut self) { Connection::shrink_buffers(self); } diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index c4978a7701..af50188ff6 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -114,6 +114,10 @@ impl Connection for MySqlConnection { Transaction::begin(self) } + fn get_transaction_depth(&self) -> usize { + self.inner.transaction_depth + } + fn shrink_buffers(&mut self) { self.inner.stream.shrink_buffers(); } diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index 99d6526392..5a4c751b2c 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -64,4 +64,8 @@ impl TransactionManager for MySqlTransactionManager { conn.inner.transaction_depth = depth - 1; } } + + fn get_transaction_depth(conn: &MySqlConnection) -> usize { + conn.inner.transaction_depth + } } diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index 7eae4bcb73..bc74f58da2 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -52,6 +52,10 @@ impl AnyConnectionBackend for PgConnection { PgTransactionManager::start_rollback(self) } + fn get_transaction_depth(&self) -> usize { + PgTransactionManager::get_transaction_depth(self) + } + fn shrink_buffers(&mut self) { Connection::shrink_buffers(self); } diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index 1c7a468240..dcfc6206a5 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -171,6 +171,10 @@ impl Connection for PgConnection { Transaction::begin(self) } + fn get_transaction_depth(&self) -> usize { + self.transaction_depth + } + fn cached_statements_size(&self) -> usize { self.cache_statement.len() } diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index 02028624e1..40a2de2e6f 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -1,4 +1,5 @@ use futures_core::future::BoxFuture; +use sqlx_core::database::Database; use crate::error::Error; use crate::executor::Executor; @@ -59,6 +60,10 @@ impl TransactionManager for PgTransactionManager { conn.transaction_depth -= 1; } } + + fn get_transaction_depth(conn: &::Connection) -> usize { + conn.transaction_depth + } } struct Rollback<'c> { diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index 01600d9931..29fb6bd7a6 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -53,6 +53,10 @@ impl AnyConnectionBackend for SqliteConnection { SqliteTransactionManager::start_rollback(self) } + fn get_transaction_depth(&self) -> usize { + SqliteTransactionManager::get_transaction_depth(self) + } + fn shrink_buffers(&mut self) { // NO-OP. } diff --git a/sqlx-sqlite/src/connection/establish.rs b/sqlx-sqlite/src/connection/establish.rs index 6438b6b7f4..698e64fce7 100644 --- a/sqlx-sqlite/src/connection/establish.rs +++ b/sqlx-sqlite/src/connection/establish.rs @@ -292,7 +292,6 @@ impl EstablishParams { Ok(ConnectionState { handle, statements: Statements::new(self.statement_cache_capacity), - transaction_depth: 0, log_settings: self.log_settings.clone(), progress_handler_callback: None, update_hook_callback: None, diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 3588b94f82..0950b31db1 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -94,9 +94,6 @@ unsafe impl Send for UpdateHookHandler {} pub(crate) struct ConnectionState { pub(crate) handle: ConnectionHandle, - // transaction status - pub(crate) transaction_depth: usize, - pub(crate) statements: Statements, log_settings: LogSettings, @@ -210,11 +207,12 @@ impl Connection for SqliteConnection { Transaction::begin(self) } + fn get_transaction_depth(&self) -> usize { + self.worker.shared.get_transaction_depth() + } + fn cached_statements_size(&self) -> usize { - self.worker - .shared - .cached_statements_size - .load(std::sync::atomic::Ordering::Acquire) + self.worker.shared.get_cached_statements_size() } fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index a01de2419c..6dda814a0a 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -34,10 +34,21 @@ pub(crate) struct ConnectionWorker { } pub(crate) struct WorkerSharedState { - pub(crate) cached_statements_size: AtomicUsize, + transaction_depth: AtomicUsize, + cached_statements_size: AtomicUsize, pub(crate) conn: Mutex, } +impl WorkerSharedState { + pub(crate) fn get_transaction_depth(&self) -> usize { + self.transaction_depth.load(Ordering::Acquire) + } + + pub(crate) fn get_cached_statements_size(&self) -> usize { + self.cached_statements_size.load(Ordering::Acquire) + } +} + enum Command { Prepare { query: Box, @@ -93,6 +104,7 @@ impl ConnectionWorker { }; let shared = Arc::new(WorkerSharedState { + transaction_depth: AtomicUsize::new(0), cached_statements_size: AtomicUsize::new(0), // note: must be fair because in `Command::UnlockDb` we unlock the mutex // and then immediately try to relock it; an unfair mutex would immediately @@ -181,12 +193,12 @@ impl ConnectionWorker { update_cached_statements_size(&conn, &shared.cached_statements_size); } Command::Begin { tx } => { - let depth = conn.transaction_depth; + let depth = shared.transaction_depth.load(Ordering::Acquire); let res = conn.handle .exec(begin_ansi_transaction_sql(depth)) .map(|_| { - conn.transaction_depth += 1; + shared.transaction_depth.fetch_add(1, Ordering::Release); }); let res_ok = res.is_ok(); @@ -199,7 +211,7 @@ impl ConnectionWorker { .handle .exec(rollback_ansi_transaction_sql(depth + 1)) .map(|_| { - conn.transaction_depth -= 1; + shared.transaction_depth.fetch_sub(1, Ordering::Release); }) { // The rollback failed. To prevent leaving the connection @@ -211,13 +223,13 @@ impl ConnectionWorker { } } Command::Commit { tx } => { - let depth = conn.transaction_depth; + let depth = shared.transaction_depth.load(Ordering::Acquire); let res = if depth > 0 { conn.handle .exec(commit_ansi_transaction_sql(depth)) .map(|_| { - conn.transaction_depth -= 1; + shared.transaction_depth.fetch_sub(1, Ordering::Release); }) } else { Ok(()) @@ -237,13 +249,13 @@ impl ConnectionWorker { continue; } - let depth = conn.transaction_depth; + let depth = shared.transaction_depth.load(Ordering::Acquire); let res = if depth > 0 { conn.handle .exec(rollback_ansi_transaction_sql(depth)) .map(|_| { - conn.transaction_depth -= 1; + shared.transaction_depth.fetch_sub(1, Ordering::Release); }) } else { Ok(()) diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index 24eaca51b1..6bce5b30cf 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -1,9 +1,9 @@ use futures_core::future::BoxFuture; - -use crate::{Sqlite, SqliteConnection}; use sqlx_core::error::Error; use sqlx_core::transaction::TransactionManager; +use crate::{Sqlite, SqliteConnection}; + /// Implementation of [`TransactionManager`] for SQLite. pub struct SqliteTransactionManager; @@ -25,4 +25,8 @@ impl TransactionManager for SqliteTransactionManager { fn start_rollback(conn: &mut SqliteConnection) { conn.worker.start_rollback().ok(); } + + fn get_transaction_depth(conn: &SqliteConnection) -> usize { + conn.worker.shared.get_transaction_depth() + } } diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 7edb5a7a8c..8e5d0d227e 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -515,6 +515,7 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> { #[sqlx_macros::test] async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { let mut conn = new::().await?; + assert_eq!(conn.get_transaction_depth(), 0); conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_2523 (id INTEGER PRIMARY KEY)") .await?; @@ -523,6 +524,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // begin let mut tx = conn.begin().await?; // transaction + assert_eq!(conn.get_transaction_depth(), 1); // insert a user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") @@ -532,6 +534,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // begin once more let mut tx2 = tx.begin().await?; // savepoint + assert_eq!(conn.get_transaction_depth(), 2); // insert another user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") @@ -541,6 +544,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // never mind, rollback tx2.rollback().await?; // roll that one back + assert_eq!(conn.get_transaction_depth(), 1); // did we really? let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") @@ -551,6 +555,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // actually, commit tx.commit().await?; + assert_eq!(conn.get_transaction_depth(), 0); // did we really? let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523")