Skip to content
20 changes: 20 additions & 0 deletions sqlx-core/src/any/connection/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,26 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static {

fn start_rollback(&mut self);

/// Returns the current transaction depth.
///
/// Transaction depth indicates the level of nested transactions:
/// - Level 0: No active transaction.
/// - Level 1: A transaction is active.
/// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it.
fn get_transaction_depth(&self) -> usize {
unimplemented!("get_transaction_depth() is not implemented for this backend. This is a provided method to avoid a breaking change, but it will become a required method in version 0.9 and later.");
}

/// Checks if the connection is currently in a transaction.
///
/// This method returns `true` if the current transaction depth is greater than 0,
/// indicating that a transaction is active. It returns `false` if the transaction depth is 0,
/// meaning no transaction is active.
#[inline]
fn is_in_transaction(&self) -> bool {
self.get_transaction_depth() != 0
}

/// The number of statements currently cached in the connection.
fn cached_statements_size(&self) -> usize {
0
Expand Down
4 changes: 4 additions & 0 deletions sqlx-core/src/any/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ impl Connection for AnyConnection {
Transaction::begin(self)
}

fn get_transaction_depth(&self) -> usize {
self.backend.get_transaction_depth()
}

fn cached_statements_size(&self) -> usize {
self.backend.cached_statements_size()
}
Expand Down
5 changes: 5 additions & 0 deletions sqlx-core/src/any/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use futures_util::future::BoxFuture;

use crate::any::{Any, AnyConnection};
use crate::database::Database;
use crate::error::Error;
use crate::transaction::TransactionManager;

Expand All @@ -24,4 +25,8 @@ impl TransactionManager for AnyTransactionManager {
fn start_rollback(conn: &mut AnyConnection) {
conn.backend.start_rollback()
}

fn get_transaction_depth(conn: &<Self::Database as Database>::Connection) -> usize {
conn.backend.get_transaction_depth()
}
}
23 changes: 22 additions & 1 deletion sqlx-core/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::database::{Database, HasStatementCache};
use crate::error::Error;

use crate::transaction::Transaction;
use crate::transaction::{Transaction, TransactionManager};
use futures_core::future::BoxFuture;
use log::LevelFilter;
use std::fmt::Debug;
Expand Down Expand Up @@ -49,6 +49,27 @@ pub trait Connection: Send {
where
Self: Sized;

/// Returns the current transaction depth.
///
/// Transaction depth indicates the level of nested transactions:
/// - Level 0: No active transaction.
/// - Level 1: A transaction is active.
/// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it.
fn get_transaction_depth(&self) -> usize {
// Fallback implementation to avoid breaking changes
<Self::Database as Database>::TransactionManager::get_transaction_depth(self)
}

/// Checks if the connection is currently in a transaction.
///
/// This method returns `true` if the current transaction depth is greater than 0,
/// indicating that a transaction is active. It returns `false` if the transaction depth is 0,
/// meaning no transaction is active.
#[inline]
fn is_in_transaction(&self) -> bool {
self.get_transaction_depth() != 0
}

/// Execute the function inside a transaction.
///
/// If the function returns an error, the transaction will be rolled back. If it does not
Expand Down
8 changes: 8 additions & 0 deletions sqlx-core/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ pub trait TransactionManager {

/// Starts to abort the active transaction or restore from the most recent snapshot.
fn start_rollback(conn: &mut <Self::Database as Database>::Connection);

/// Returns the current transaction depth.
///
/// Transaction depth indicates the level of nested transactions:
/// - Level 0: No active transaction.
/// - Level 1: A transaction is active.
/// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it.
fn get_transaction_depth(conn: &<Self::Database as Database>::Connection) -> usize;
}

/// An in-progress database transaction or savepoint.
Expand Down
4 changes: 4 additions & 0 deletions sqlx-mysql/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ impl AnyConnectionBackend for MySqlConnection {
MySqlTransactionManager::start_rollback(self)
}

fn get_transaction_depth(&self) -> usize {
MySqlTransactionManager::get_transaction_depth(self)
}

fn shrink_buffers(&mut self) {
Connection::shrink_buffers(self);
}
Expand Down
4 changes: 4 additions & 0 deletions sqlx-mysql/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ impl Connection for MySqlConnection {
Transaction::begin(self)
}

fn get_transaction_depth(&self) -> usize {
self.inner.transaction_depth
}

fn shrink_buffers(&mut self) {
self.inner.stream.shrink_buffers();
}
Expand Down
4 changes: 4 additions & 0 deletions sqlx-mysql/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,8 @@ impl TransactionManager for MySqlTransactionManager {
conn.inner.transaction_depth = depth - 1;
}
}

fn get_transaction_depth(conn: &MySqlConnection) -> usize {
conn.inner.transaction_depth
}
}
4 changes: 4 additions & 0 deletions sqlx-postgres/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ impl AnyConnectionBackend for PgConnection {
PgTransactionManager::start_rollback(self)
}

fn get_transaction_depth(&self) -> usize {
PgTransactionManager::get_transaction_depth(self)
}

fn shrink_buffers(&mut self) {
Connection::shrink_buffers(self);
}
Expand Down
4 changes: 4 additions & 0 deletions sqlx-postgres/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ impl Connection for PgConnection {
Transaction::begin(self)
}

fn get_transaction_depth(&self) -> usize {
self.transaction_depth
}

fn cached_statements_size(&self) -> usize {
self.cache_statement.len()
}
Expand Down
5 changes: 5 additions & 0 deletions sqlx-postgres/src/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use futures_core::future::BoxFuture;
use sqlx_core::database::Database;

use crate::error::Error;
use crate::executor::Executor;
Expand Down Expand Up @@ -59,6 +60,10 @@ impl TransactionManager for PgTransactionManager {
conn.transaction_depth -= 1;
}
}

fn get_transaction_depth(conn: &<Self::Database as Database>::Connection) -> usize {
conn.transaction_depth
}
}

struct Rollback<'c> {
Expand Down
4 changes: 4 additions & 0 deletions sqlx-sqlite/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ impl AnyConnectionBackend for SqliteConnection {
SqliteTransactionManager::start_rollback(self)
}

fn get_transaction_depth(&self) -> usize {
SqliteTransactionManager::get_transaction_depth(self)
}

fn shrink_buffers(&mut self) {
// NO-OP.
}
Expand Down
1 change: 0 additions & 1 deletion sqlx-sqlite/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,6 @@ impl EstablishParams {
Ok(ConnectionState {
handle,
statements: Statements::new(self.statement_cache_capacity),
transaction_depth: 0,
log_settings: self.log_settings.clone(),
progress_handler_callback: None,
update_hook_callback: None,
Expand Down
12 changes: 5 additions & 7 deletions sqlx-sqlite/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ unsafe impl Send for UpdateHookHandler {}
pub(crate) struct ConnectionState {
pub(crate) handle: ConnectionHandle,

// transaction status
pub(crate) transaction_depth: usize,

pub(crate) statements: Statements,

log_settings: LogSettings,
Expand Down Expand Up @@ -210,11 +207,12 @@ impl Connection for SqliteConnection {
Transaction::begin(self)
}

fn get_transaction_depth(&self) -> usize {
self.worker.shared.get_transaction_depth()
}

fn cached_statements_size(&self) -> usize {
self.worker
.shared
.cached_statements_size
.load(std::sync::atomic::Ordering::Acquire)
self.worker.shared.get_cached_statements_size()
}

fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> {
Expand Down
28 changes: 20 additions & 8 deletions sqlx-sqlite/src/connection/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,21 @@ pub(crate) struct ConnectionWorker {
}

pub(crate) struct WorkerSharedState {
pub(crate) cached_statements_size: AtomicUsize,
transaction_depth: AtomicUsize,
cached_statements_size: AtomicUsize,
pub(crate) conn: Mutex<ConnectionState>,
}

impl WorkerSharedState {
pub(crate) fn get_transaction_depth(&self) -> usize {
self.transaction_depth.load(Ordering::Acquire)
}

pub(crate) fn get_cached_statements_size(&self) -> usize {
self.cached_statements_size.load(Ordering::Acquire)
}
}

enum Command {
Prepare {
query: Box<str>,
Expand Down Expand Up @@ -93,6 +104,7 @@ impl ConnectionWorker {
};

let shared = Arc::new(WorkerSharedState {
transaction_depth: AtomicUsize::new(0),
cached_statements_size: AtomicUsize::new(0),
// note: must be fair because in `Command::UnlockDb` we unlock the mutex
// and then immediately try to relock it; an unfair mutex would immediately
Expand Down Expand Up @@ -181,12 +193,12 @@ impl ConnectionWorker {
update_cached_statements_size(&conn, &shared.cached_statements_size);
}
Command::Begin { tx } => {
let depth = conn.transaction_depth;
let depth = shared.transaction_depth.load(Ordering::Acquire);
let res =
conn.handle
.exec(begin_ansi_transaction_sql(depth))
.map(|_| {
conn.transaction_depth += 1;
shared.transaction_depth.fetch_add(1, Ordering::Release);
});
let res_ok = res.is_ok();

Expand All @@ -199,7 +211,7 @@ impl ConnectionWorker {
.handle
.exec(rollback_ansi_transaction_sql(depth + 1))
.map(|_| {
conn.transaction_depth -= 1;
shared.transaction_depth.fetch_sub(1, Ordering::Release);
})
{
// The rollback failed. To prevent leaving the connection
Expand All @@ -211,13 +223,13 @@ impl ConnectionWorker {
}
}
Command::Commit { tx } => {
let depth = conn.transaction_depth;
let depth = shared.transaction_depth.load(Ordering::Acquire);

let res = if depth > 0 {
conn.handle
.exec(commit_ansi_transaction_sql(depth))
.map(|_| {
conn.transaction_depth -= 1;
shared.transaction_depth.fetch_sub(1, Ordering::Release);
})
} else {
Ok(())
Expand All @@ -237,13 +249,13 @@ impl ConnectionWorker {
continue;
}

let depth = conn.transaction_depth;
let depth = shared.transaction_depth.load(Ordering::Acquire);

let res = if depth > 0 {
conn.handle
.exec(rollback_ansi_transaction_sql(depth))
.map(|_| {
conn.transaction_depth -= 1;
shared.transaction_depth.fetch_sub(1, Ordering::Release);
})
} else {
Ok(())
Expand Down
8 changes: 6 additions & 2 deletions sqlx-sqlite/src/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use futures_core::future::BoxFuture;

use crate::{Sqlite, SqliteConnection};
use sqlx_core::error::Error;
use sqlx_core::transaction::TransactionManager;

use crate::{Sqlite, SqliteConnection};

/// Implementation of [`TransactionManager`] for SQLite.
pub struct SqliteTransactionManager;

Expand All @@ -25,4 +25,8 @@ impl TransactionManager for SqliteTransactionManager {
fn start_rollback(conn: &mut SqliteConnection) {
conn.worker.start_rollback().ok();
}

fn get_transaction_depth(conn: &SqliteConnection) -> usize {
conn.worker.shared.get_transaction_depth()
}
}
5 changes: 5 additions & 0 deletions tests/postgres/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> {
#[sqlx_macros::test]
async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;
assert_eq!(conn.get_transaction_depth(), 0);

conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_2523 (id INTEGER PRIMARY KEY)")
.await?;
Expand All @@ -523,6 +524,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> {

// begin
let mut tx = conn.begin().await?; // transaction
assert_eq!(conn.get_transaction_depth(), 1);

// insert a user
sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)")
Expand All @@ -532,6 +534,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> {

// begin once more
let mut tx2 = tx.begin().await?; // savepoint
assert_eq!(conn.get_transaction_depth(), 2);

// insert another user
sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)")
Expand All @@ -541,6 +544,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> {

// never mind, rollback
tx2.rollback().await?; // roll that one back
assert_eq!(conn.get_transaction_depth(), 1);

// did we really?
let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523")
Expand All @@ -551,6 +555,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> {

// actually, commit
tx.commit().await?;
assert_eq!(conn.get_transaction_depth(), 0);

// did we really?
let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523")
Expand Down