From 0ef8393738c20b2ae2fc69512c3d6844c9865b11 Mon Sep 17 00:00:00 2001 From: mpyw Date: Sun, 11 Aug 2024 18:13:49 +0900 Subject: [PATCH 01/27] feat: Implement `get_transaction_depth` for drivers --- sqlx-core/src/any/connection/backend.rs | 9 +++++++++ sqlx-core/src/any/connection/mod.rs | 8 +++++++- sqlx-core/src/connection.rs | 20 ++++++++++++++++++++ sqlx-mysql/src/any.rs | 7 ++++++- sqlx-mysql/src/connection/mod.rs | 6 ++++++ sqlx-postgres/src/any.rs | 7 ++++++- sqlx-postgres/src/connection/mod.rs | 6 ++++++ sqlx-sqlite/src/any.rs | 7 ++++++- sqlx-sqlite/src/connection/mod.rs | 6 ++++++ src/lib.rs | 4 +++- 10 files changed, 75 insertions(+), 5 deletions(-) diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index b30cbe83f3..a2393832dc 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -1,5 +1,6 @@ use crate::any::{Any, AnyArguments, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo}; use crate::describe::Describe; +use crate::Error; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; @@ -34,6 +35,14 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { fn start_rollback(&mut self); + /// Returns the current transaction depth asynchronously. + /// + /// Transaction depth indicates the level of nested transactions: + /// - Level 0: No active transaction. + /// - Level 1: A transaction is active. + /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. + fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result>; + /// The number of statements currently cached in the connection. fn cached_statements_size(&self) -> usize { 0 diff --git a/sqlx-core/src/any/connection/mod.rs b/sqlx-core/src/any/connection/mod.rs index b6f795848a..c2274d9885 100644 --- a/sqlx-core/src/any/connection/mod.rs +++ b/sqlx-core/src/any/connection/mod.rs @@ -1,7 +1,7 @@ use futures_core::future::BoxFuture; use crate::any::{Any, AnyConnectOptions}; -use crate::connection::{ConnectOptions, Connection}; +use crate::connection::{AsyncTransactionDepth, ConnectOptions, Connection}; use crate::error::Error; use crate::database::Database; @@ -112,3 +112,9 @@ impl Connection for AnyConnection { self.backend.should_flush() } } + +impl AsyncTransactionDepth for AnyConnection { + fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { + self.backend.get_transaction_depth() + } +} diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index ce2aa6c629..155f29d268 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -236,3 +236,23 @@ pub trait ConnectOptions: 'static + Send + Sync + FromStr + Debug + .log_slow_statements(LevelFilter::Off, Duration::default()) } } + +pub trait TransactionDepth { + /// Returns the current transaction depth synchronously. + /// + /// Transaction depth indicates the level of nested transactions: + /// - Level 0: No active transaction. + /// - Level 1: A transaction is active. + /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. + fn get_transaction_depth(&self) -> usize; +} + +pub trait AsyncTransactionDepth { + /// Returns the current transaction depth asynchronously. + /// + /// Transaction depth indicates the level of nested transactions: + /// - Level 0: No active transaction. + /// - Level 1: A transaction is active. + /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. + fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result>; +} diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index fa8d34f8db..c688ad0d48 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -11,11 +11,12 @@ use sqlx_core::any::{ Any, AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo, AnyTypeInfoKind, }; -use sqlx_core::connection::Connection; +use sqlx_core::connection::{Connection, TransactionDepth}; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::transaction::TransactionManager; +use sqlx_core::Error; use std::future; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = MySql); @@ -53,6 +54,10 @@ impl AnyConnectionBackend for MySqlConnection { MySqlTransactionManager::start_rollback(self) } + fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { + Box::pin(async { Ok(TransactionDepth::get_transaction_depth(self)) }) + } + fn shrink_buffers(&mut self) { Connection::shrink_buffers(self); } diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index c4978a7701..b2a9a94b42 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -118,3 +118,9 @@ impl Connection for MySqlConnection { self.inner.stream.shrink_buffers(); } } + +impl TransactionDepth for MySqlConnection { + fn get_transaction_depth(&self) -> usize { + self.inner.transaction_depth + } +} diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index 7eae4bcb73..d0e0459721 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -10,12 +10,13 @@ use std::future; pub use sqlx_core::any::*; use crate::type_info::PgType; -use sqlx_core::connection::Connection; +use sqlx_core::connection::{Connection, TransactionDepth}; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::ext::ustr::UStr; use sqlx_core::transaction::TransactionManager; +use sqlx_core::Error; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Postgres); @@ -52,6 +53,10 @@ impl AnyConnectionBackend for PgConnection { PgTransactionManager::start_rollback(self) } + fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { + Box::pin(async { Ok(TransactionDepth::get_transaction_depth(self)) }) + } + fn shrink_buffers(&mut self) { Connection::shrink_buffers(self); } diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index 1c7a468240..8e3143fec7 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -224,3 +224,9 @@ impl AsMut for PgConnection { self } } + +impl TransactionDepth for PgConnection { + fn get_transaction_depth(&self) -> usize { + self.transaction_depth + } +} diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index 01600d9931..c14a4eaa29 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -12,11 +12,12 @@ use sqlx_core::any::{ }; use crate::type_info::DataType; -use sqlx_core::connection::{ConnectOptions, Connection}; +use sqlx_core::connection::{AsyncTransactionDepth, ConnectOptions, Connection}; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::transaction::TransactionManager; +use sqlx_core::Error; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Sqlite); @@ -53,6 +54,10 @@ impl AnyConnectionBackend for SqliteConnection { SqliteTransactionManager::start_rollback(self) } + fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { + AsyncTransactionDepth::get_transaction_depth(self) + } + fn shrink_buffers(&mut self) { // NO-OP. } diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 3588b94f82..a58428e14a 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -426,3 +426,9 @@ impl Statements { self.temp = None; } } + +impl AsyncTransactionDepth for SqliteConnection { + fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { + Box::pin(async { Ok(self.lock_handle().await?.guard.transaction_depth) }) + } +} diff --git a/src/lib.rs b/src/lib.rs index d675fa11c3..7a0a7c5fa0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,9 @@ pub use sqlx_core::acquire::Acquire; pub use sqlx_core::arguments::{Arguments, IntoArguments}; pub use sqlx_core::column::Column; pub use sqlx_core::column::ColumnIndex; -pub use sqlx_core::connection::{ConnectOptions, Connection}; +pub use sqlx_core::connection::{ + AsyncTransactionDepth, ConnectOptions, Connection, TransactionDepth, +}; pub use sqlx_core::database::{self, Database}; pub use sqlx_core::describe::Describe; pub use sqlx_core::executor::{Execute, Executor}; From 4f9d3f9579f5c384957bf6ec652d90176a5219d7 Mon Sep 17 00:00:00 2001 From: mpyw Date: Sun, 11 Aug 2024 18:21:16 +0900 Subject: [PATCH 02/27] test: Verify `get_transaction_depth()` on postgres --- tests/postgres/postgres.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 7edb5a7a8c..cdc518c666 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -6,6 +6,7 @@ use sqlx::postgres::{ PgPoolOptions, PgRow, PgSeverity, Postgres, }; use sqlx::{Column, Connection, Executor, Row, Statement, TypeInfo}; +use sqlx_core::connection::TransactionDepth; use sqlx_core::{bytes::Bytes, error::BoxDynError}; use sqlx_test::{new, pool, setup_if_needed}; use std::env; @@ -515,6 +516,7 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> { #[sqlx_macros::test] async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { let mut conn = new::().await?; + assert_eq!(conn.get_transaction_depth(), 0); conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_2523 (id INTEGER PRIMARY KEY)") .await?; @@ -523,6 +525,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // begin let mut tx = conn.begin().await?; // transaction + assert_eq!(conn.get_transaction_depth(), 1); // insert a user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") @@ -532,6 +535,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // begin once more let mut tx2 = tx.begin().await?; // savepoint + assert_eq!(conn.get_transaction_depth(), 2); // insert another user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") @@ -541,6 +545,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // never mind, rollback tx2.rollback().await?; // roll that one back + assert_eq!(conn.get_transaction_depth(), 1); // did we really? let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") @@ -551,6 +556,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // actually, commit tx.commit().await?; + assert_eq!(conn.get_transaction_depth(), 0); // did we really? let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") From 8f450196e0e9a8d1bb171cc805452a84d91cc0ce Mon Sep 17 00:00:00 2001 From: mpyw Date: Mon, 12 Aug 2024 14:59:50 +0900 Subject: [PATCH 03/27] Refactor: `TransactionManager` delegation without BC SQLite implementation is currently WIP --- sqlx-core/src/any/connection/backend.rs | 5 ++- sqlx-core/src/any/connection/mod.rs | 12 +++---- sqlx-core/src/any/transaction.rs | 5 +++ sqlx-core/src/connection.rs | 43 +++++++++++++------------ sqlx-core/src/transaction.rs | 8 +++++ sqlx-mysql/src/any.rs | 7 ++-- sqlx-mysql/src/connection/mod.rs | 10 +++--- sqlx-mysql/src/transaction.rs | 4 +++ sqlx-postgres/src/any.rs | 7 ++-- sqlx-postgres/src/connection/mod.rs | 10 +++--- sqlx-postgres/src/transaction.rs | 5 +++ sqlx-sqlite/src/any.rs | 7 ++-- sqlx-sqlite/src/connection/mod.rs | 10 +++--- sqlx-sqlite/src/transaction.rs | 9 ++++-- src/lib.rs | 4 +-- tests/postgres/postgres.rs | 1 - 16 files changed, 80 insertions(+), 67 deletions(-) diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index a2393832dc..e4e752c07b 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -1,6 +1,5 @@ use crate::any::{Any, AnyArguments, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo}; use crate::describe::Describe; -use crate::Error; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; @@ -35,13 +34,13 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { fn start_rollback(&mut self); - /// Returns the current transaction depth asynchronously. + /// Returns the current transaction depth. /// /// Transaction depth indicates the level of nested transactions: /// - Level 0: No active transaction. /// - Level 1: A transaction is active. /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. - fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result>; + fn get_transaction_depth(&self) -> usize; /// The number of statements currently cached in the connection. fn cached_statements_size(&self) -> usize { diff --git a/sqlx-core/src/any/connection/mod.rs b/sqlx-core/src/any/connection/mod.rs index c2274d9885..ba06d865f1 100644 --- a/sqlx-core/src/any/connection/mod.rs +++ b/sqlx-core/src/any/connection/mod.rs @@ -1,7 +1,7 @@ use futures_core::future::BoxFuture; use crate::any::{Any, AnyConnectOptions}; -use crate::connection::{AsyncTransactionDepth, ConnectOptions, Connection}; +use crate::connection::{ConnectOptions, Connection}; use crate::error::Error; use crate::database::Database; @@ -90,6 +90,10 @@ impl Connection for AnyConnection { Transaction::begin(self) } + fn get_transaction_depth(&self) -> usize { + self.backend.get_transaction_depth() + } + fn cached_statements_size(&self) -> usize { self.backend.cached_statements_size() } @@ -112,9 +116,3 @@ impl Connection for AnyConnection { self.backend.should_flush() } } - -impl AsyncTransactionDepth for AnyConnection { - fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { - self.backend.get_transaction_depth() - } -} diff --git a/sqlx-core/src/any/transaction.rs b/sqlx-core/src/any/transaction.rs index fce4175626..22937a0bd1 100644 --- a/sqlx-core/src/any/transaction.rs +++ b/sqlx-core/src/any/transaction.rs @@ -1,6 +1,7 @@ use futures_util::future::BoxFuture; use crate::any::{Any, AnyConnection}; +use crate::database::Database; use crate::error::Error; use crate::transaction::TransactionManager; @@ -24,4 +25,8 @@ impl TransactionManager for AnyTransactionManager { fn start_rollback(conn: &mut AnyConnection) { conn.backend.start_rollback() } + + fn get_transaction_depth(conn: &::Connection) -> usize { + conn.backend.get_transaction_depth() + } } diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index 155f29d268..9faa71effe 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -1,7 +1,7 @@ use crate::database::{Database, HasStatementCache}; use crate::error::Error; -use crate::transaction::Transaction; +use crate::transaction::{Transaction, TransactionManager}; use futures_core::future::BoxFuture; use log::LevelFilter; use std::fmt::Debug; @@ -49,6 +49,27 @@ pub trait Connection: Send { where Self: Sized; + /// Returns the current transaction depth synchronously. + /// + /// Transaction depth indicates the level of nested transactions: + /// - Level 0: No active transaction. + /// - Level 1: A transaction is active. + /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. + fn get_transaction_depth(&self) -> usize { + // Fallback implementation to avoid breaking changes + ::TransactionManager::get_transaction_depth(self) + } + + /// Checks if the connection is currently in a transaction. + /// + /// This method returns `true` if the current transaction depth is greater than 0, + /// indicating that a transaction is active. It returns `false` if the transaction depth is 0, + /// meaning no transaction is active. + #[inline] + fn is_in_transaction(&self) -> bool { + self.get_transaction_depth() != 0 + } + /// Execute the function inside a transaction. /// /// If the function returns an error, the transaction will be rolled back. If it does not @@ -236,23 +257,3 @@ pub trait ConnectOptions: 'static + Send + Sync + FromStr + Debug + .log_slow_statements(LevelFilter::Off, Duration::default()) } } - -pub trait TransactionDepth { - /// Returns the current transaction depth synchronously. - /// - /// Transaction depth indicates the level of nested transactions: - /// - Level 0: No active transaction. - /// - Level 1: A transaction is active. - /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. - fn get_transaction_depth(&self) -> usize; -} - -pub trait AsyncTransactionDepth { - /// Returns the current transaction depth asynchronously. - /// - /// Transaction depth indicates the level of nested transactions: - /// - Level 0: No active transaction. - /// - Level 1: A transaction is active. - /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. - fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result>; -} diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index 9cd38aab3a..48b1463f2a 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -32,6 +32,14 @@ pub trait TransactionManager { /// Starts to abort the active transaction or restore from the most recent snapshot. fn start_rollback(conn: &mut ::Connection); + + /// Returns the current transaction depth. + /// + /// Transaction depth indicates the level of nested transactions: + /// - Level 0: No active transaction. + /// - Level 1: A transaction is active. + /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. + fn get_transaction_depth(conn: &::Connection) -> usize; } /// An in-progress database transaction or savepoint. diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index c688ad0d48..709fdafcea 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -11,12 +11,11 @@ use sqlx_core::any::{ Any, AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo, AnyTypeInfoKind, }; -use sqlx_core::connection::{Connection, TransactionDepth}; +use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::transaction::TransactionManager; -use sqlx_core::Error; use std::future; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = MySql); @@ -54,8 +53,8 @@ impl AnyConnectionBackend for MySqlConnection { MySqlTransactionManager::start_rollback(self) } - fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { - Box::pin(async { Ok(TransactionDepth::get_transaction_depth(self)) }) + fn get_transaction_depth(&self) -> usize { + MySqlTransactionManager::get_transaction_depth(self) } fn shrink_buffers(&mut self) { diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index b2a9a94b42..af50188ff6 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -114,13 +114,11 @@ impl Connection for MySqlConnection { Transaction::begin(self) } - fn shrink_buffers(&mut self) { - self.inner.stream.shrink_buffers(); - } -} - -impl TransactionDepth for MySqlConnection { fn get_transaction_depth(&self) -> usize { self.inner.transaction_depth } + + fn shrink_buffers(&mut self) { + self.inner.stream.shrink_buffers(); + } } diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index 99d6526392..5a4c751b2c 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -64,4 +64,8 @@ impl TransactionManager for MySqlTransactionManager { conn.inner.transaction_depth = depth - 1; } } + + fn get_transaction_depth(conn: &MySqlConnection) -> usize { + conn.inner.transaction_depth + } } diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index d0e0459721..bcfef8dce1 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -10,13 +10,12 @@ use std::future; pub use sqlx_core::any::*; use crate::type_info::PgType; -use sqlx_core::connection::{Connection, TransactionDepth}; +use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::ext::ustr::UStr; use sqlx_core::transaction::TransactionManager; -use sqlx_core::Error; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Postgres); @@ -53,8 +52,8 @@ impl AnyConnectionBackend for PgConnection { PgTransactionManager::start_rollback(self) } - fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { - Box::pin(async { Ok(TransactionDepth::get_transaction_depth(self)) }) + fn get_transaction_depth(&mut self) -> usize { + PgTransactionManager::get_transaction_depth(self) } fn shrink_buffers(&mut self) { diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index 8e3143fec7..dcfc6206a5 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -171,6 +171,10 @@ impl Connection for PgConnection { Transaction::begin(self) } + fn get_transaction_depth(&self) -> usize { + self.transaction_depth + } + fn cached_statements_size(&self) -> usize { self.cache_statement.len() } @@ -224,9 +228,3 @@ impl AsMut for PgConnection { self } } - -impl TransactionDepth for PgConnection { - fn get_transaction_depth(&self) -> usize { - self.transaction_depth - } -} diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index 02028624e1..40a2de2e6f 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -1,4 +1,5 @@ use futures_core::future::BoxFuture; +use sqlx_core::database::Database; use crate::error::Error; use crate::executor::Executor; @@ -59,6 +60,10 @@ impl TransactionManager for PgTransactionManager { conn.transaction_depth -= 1; } } + + fn get_transaction_depth(conn: &::Connection) -> usize { + conn.transaction_depth + } } struct Rollback<'c> { diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index c14a4eaa29..29fb6bd7a6 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -12,12 +12,11 @@ use sqlx_core::any::{ }; use crate::type_info::DataType; -use sqlx_core::connection::{AsyncTransactionDepth, ConnectOptions, Connection}; +use sqlx_core::connection::{ConnectOptions, Connection}; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::transaction::TransactionManager; -use sqlx_core::Error; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Sqlite); @@ -54,8 +53,8 @@ impl AnyConnectionBackend for SqliteConnection { SqliteTransactionManager::start_rollback(self) } - fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { - AsyncTransactionDepth::get_transaction_depth(self) + fn get_transaction_depth(&self) -> usize { + SqliteTransactionManager::get_transaction_depth(self) } fn shrink_buffers(&mut self) { diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index a58428e14a..f12f49edbd 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -210,6 +210,10 @@ impl Connection for SqliteConnection { Transaction::begin(self) } + fn get_transaction_depth(&self) -> usize { + todo!() + } + fn cached_statements_size(&self) -> usize { self.worker .shared @@ -426,9 +430,3 @@ impl Statements { self.temp = None; } } - -impl AsyncTransactionDepth for SqliteConnection { - fn get_transaction_depth(&mut self) -> BoxFuture<'_, Result> { - Box::pin(async { Ok(self.lock_handle().await?.guard.transaction_depth) }) - } -} diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index 24eaca51b1..dafc33359d 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -1,9 +1,10 @@ use futures_core::future::BoxFuture; - -use crate::{Sqlite, SqliteConnection}; +use sqlx_core::database::Database; use sqlx_core::error::Error; use sqlx_core::transaction::TransactionManager; +use crate::{Sqlite, SqliteConnection}; + /// Implementation of [`TransactionManager`] for SQLite. pub struct SqliteTransactionManager; @@ -25,4 +26,8 @@ impl TransactionManager for SqliteTransactionManager { fn start_rollback(conn: &mut SqliteConnection) { conn.worker.start_rollback().ok(); } + + fn get_transaction_depth(_conn: &::Connection) -> usize { + todo!() + } } diff --git a/src/lib.rs b/src/lib.rs index 7a0a7c5fa0..d675fa11c3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,9 +5,7 @@ pub use sqlx_core::acquire::Acquire; pub use sqlx_core::arguments::{Arguments, IntoArguments}; pub use sqlx_core::column::Column; pub use sqlx_core::column::ColumnIndex; -pub use sqlx_core::connection::{ - AsyncTransactionDepth, ConnectOptions, Connection, TransactionDepth, -}; +pub use sqlx_core::connection::{ConnectOptions, Connection}; pub use sqlx_core::database::{self, Database}; pub use sqlx_core::describe::Describe; pub use sqlx_core::executor::{Execute, Executor}; diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index cdc518c666..8e5d0d227e 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -6,7 +6,6 @@ use sqlx::postgres::{ PgPoolOptions, PgRow, PgSeverity, Postgres, }; use sqlx::{Column, Connection, Executor, Row, Statement, TypeInfo}; -use sqlx_core::connection::TransactionDepth; use sqlx_core::{bytes::Bytes, error::BoxDynError}; use sqlx_test::{new, pool, setup_if_needed}; use std::env; From 9409e5e9d32cde7928493ac61864f14cb5e31760 Mon Sep 17 00:00:00 2001 From: mpyw Date: Mon, 12 Aug 2024 21:09:52 +0900 Subject: [PATCH 04/27] Fix: Avoid breaking changes on `AnyConnectionBackend` --- sqlx-core/src/any/connection/backend.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index e4e752c07b..5d8b2709fd 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -40,7 +40,9 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { /// - Level 0: No active transaction. /// - Level 1: A transaction is active. /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. - fn get_transaction_depth(&self) -> usize; + fn get_transaction_depth(&self) -> usize { + unimplemented!("get_transaction_depth() is not implemented for this backend. This is a provided method to avoid a breaking change, but it will become a required method in version 0.9 and later."); + } /// The number of statements currently cached in the connection. fn cached_statements_size(&self) -> usize { From 10d0aea44ba5d72de6c179bfbfb868e9294355b8 Mon Sep 17 00:00:00 2001 From: mpyw Date: Mon, 12 Aug 2024 21:34:12 +0900 Subject: [PATCH 05/27] Refactor: Remove verbose `SqliteConnection` typing --- sqlx-sqlite/src/transaction.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index dafc33359d..cbde09ae68 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -1,5 +1,4 @@ use futures_core::future::BoxFuture; -use sqlx_core::database::Database; use sqlx_core::error::Error; use sqlx_core::transaction::TransactionManager; @@ -27,7 +26,7 @@ impl TransactionManager for SqliteTransactionManager { conn.worker.start_rollback().ok(); } - fn get_transaction_depth(_conn: &::Connection) -> usize { + fn get_transaction_depth(_conn: &SqliteConnection) -> usize { todo!() } } From a66787d36d62876b55475ef2326d17bade817aed Mon Sep 17 00:00:00 2001 From: mpyw Date: Mon, 12 Aug 2024 22:39:14 +0900 Subject: [PATCH 06/27] Feat: Implementation for SQLite I have included `AtomicUsize` in `WorkerSharedState`. Ideally, it is not desirable to execute `load` and `fetch_add` in two separate steps, but we decided to allow it here since there is only one thread writing. To prevent writing from other threads, the field itself was made private, and a getter method was provided with `pub(crate)`. --- sqlx-sqlite/src/connection/establish.rs | 1 - sqlx-sqlite/src/connection/mod.rs | 5 +---- sqlx-sqlite/src/connection/worker.rs | 22 +++++++++++++++------- sqlx-sqlite/src/transaction.rs | 4 ++-- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/sqlx-sqlite/src/connection/establish.rs b/sqlx-sqlite/src/connection/establish.rs index 6438b6b7f4..698e64fce7 100644 --- a/sqlx-sqlite/src/connection/establish.rs +++ b/sqlx-sqlite/src/connection/establish.rs @@ -292,7 +292,6 @@ impl EstablishParams { Ok(ConnectionState { handle, statements: Statements::new(self.statement_cache_capacity), - transaction_depth: 0, log_settings: self.log_settings.clone(), progress_handler_callback: None, update_hook_callback: None, diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index f12f49edbd..0234d5d914 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -94,9 +94,6 @@ unsafe impl Send for UpdateHookHandler {} pub(crate) struct ConnectionState { pub(crate) handle: ConnectionHandle, - // transaction status - pub(crate) transaction_depth: usize, - pub(crate) statements: Statements, log_settings: LogSettings, @@ -211,7 +208,7 @@ impl Connection for SqliteConnection { } fn get_transaction_depth(&self) -> usize { - todo!() + self.worker.shared.get_transaction_depth() } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index a01de2419c..c0b6973b73 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -34,10 +34,17 @@ pub(crate) struct ConnectionWorker { } pub(crate) struct WorkerSharedState { + transaction_depth: AtomicUsize, pub(crate) cached_statements_size: AtomicUsize, pub(crate) conn: Mutex, } +impl WorkerSharedState { + pub(crate) fn get_transaction_depth(&self) -> usize { + self.transaction_depth.load(Ordering::Acquire) + } +} + enum Command { Prepare { query: Box, @@ -93,6 +100,7 @@ impl ConnectionWorker { }; let shared = Arc::new(WorkerSharedState { + transaction_depth: AtomicUsize::new(0), cached_statements_size: AtomicUsize::new(0), // note: must be fair because in `Command::UnlockDb` we unlock the mutex // and then immediately try to relock it; an unfair mutex would immediately @@ -181,12 +189,12 @@ impl ConnectionWorker { update_cached_statements_size(&conn, &shared.cached_statements_size); } Command::Begin { tx } => { - let depth = conn.transaction_depth; + let depth = shared.transaction_depth.load(Ordering::Acquire); let res = conn.handle .exec(begin_ansi_transaction_sql(depth)) .map(|_| { - conn.transaction_depth += 1; + shared.transaction_depth.fetch_add(1, Ordering::Release); }); let res_ok = res.is_ok(); @@ -199,7 +207,7 @@ impl ConnectionWorker { .handle .exec(rollback_ansi_transaction_sql(depth + 1)) .map(|_| { - conn.transaction_depth -= 1; + shared.transaction_depth.fetch_sub(1, Ordering::Release); }) { // The rollback failed. To prevent leaving the connection @@ -211,13 +219,13 @@ impl ConnectionWorker { } } Command::Commit { tx } => { - let depth = conn.transaction_depth; + let depth = shared.transaction_depth.load(Ordering::Acquire); let res = if depth > 0 { conn.handle .exec(commit_ansi_transaction_sql(depth)) .map(|_| { - conn.transaction_depth -= 1; + shared.transaction_depth.fetch_sub(1, Ordering::Release); }) } else { Ok(()) @@ -237,13 +245,13 @@ impl ConnectionWorker { continue; } - let depth = conn.transaction_depth; + let depth = shared.transaction_depth.load(Ordering::Acquire); let res = if depth > 0 { conn.handle .exec(rollback_ansi_transaction_sql(depth)) .map(|_| { - conn.transaction_depth -= 1; + shared.transaction_depth.fetch_sub(1, Ordering::Release); }) } else { Ok(()) diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index cbde09ae68..6bce5b30cf 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -26,7 +26,7 @@ impl TransactionManager for SqliteTransactionManager { conn.worker.start_rollback().ok(); } - fn get_transaction_depth(_conn: &SqliteConnection) -> usize { - todo!() + fn get_transaction_depth(conn: &SqliteConnection) -> usize { + conn.worker.shared.get_transaction_depth() } } From 54c73b63cd41ec12cf4a34f71e2d96491f465017 Mon Sep 17 00:00:00 2001 From: mpyw Date: Mon, 12 Aug 2024 22:47:21 +0900 Subject: [PATCH 07/27] Refactor: Same approach for `cached_statements_size` ref: a66787d36d62876b55475ef2326d17bade817aed --- sqlx-sqlite/src/connection/mod.rs | 5 +---- sqlx-sqlite/src/connection/worker.rs | 6 +++++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 0234d5d914..0950b31db1 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -212,10 +212,7 @@ impl Connection for SqliteConnection { } fn cached_statements_size(&self) -> usize { - self.worker - .shared - .cached_statements_size - .load(std::sync::atomic::Ordering::Acquire) + self.worker.shared.get_cached_statements_size() } fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index c0b6973b73..6dda814a0a 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -35,7 +35,7 @@ pub(crate) struct ConnectionWorker { pub(crate) struct WorkerSharedState { transaction_depth: AtomicUsize, - pub(crate) cached_statements_size: AtomicUsize, + cached_statements_size: AtomicUsize, pub(crate) conn: Mutex, } @@ -43,6 +43,10 @@ impl WorkerSharedState { pub(crate) fn get_transaction_depth(&self) -> usize { self.transaction_depth.load(Ordering::Acquire) } + + pub(crate) fn get_cached_statements_size(&self) -> usize { + self.cached_statements_size.load(Ordering::Acquire) + } } enum Command { From b780ed7788f717334aeea90b91bc944aa78fb015 Mon Sep 17 00:00:00 2001 From: mpyw Date: Tue, 13 Aug 2024 11:16:51 +0900 Subject: [PATCH 08/27] Fix: Add missing `is_in_transaction` for backend --- sqlx-core/src/any/connection/backend.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index 5d8b2709fd..6ba5ad92b2 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -44,6 +44,16 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { unimplemented!("get_transaction_depth() is not implemented for this backend. This is a provided method to avoid a breaking change, but it will become a required method in version 0.9 and later."); } + /// Checks if the connection is currently in a transaction. + /// + /// This method returns `true` if the current transaction depth is greater than 0, + /// indicating that a transaction is active. It returns `false` if the transaction depth is 0, + /// meaning no transaction is active. + #[inline] + fn is_in_transaction(&self) -> bool { + self.get_transaction_depth() != 0 + } + /// The number of statements currently cached in the connection. fn cached_statements_size(&self) -> usize { 0 From c866c5b8017948a00b7fe6cf9330ee7526af14de Mon Sep 17 00:00:00 2001 From: mpyw Date: Tue, 13 Aug 2024 17:01:44 +0900 Subject: [PATCH 09/27] Doc: Remove verbose "synchronously" word --- sqlx-core/src/connection.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index 9faa71effe..1f86f6c1a9 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -49,7 +49,7 @@ pub trait Connection: Send { where Self: Sized; - /// Returns the current transaction depth synchronously. + /// Returns the current transaction depth. /// /// Transaction depth indicates the level of nested transactions: /// - Level 0: No active transaction. From 85ad53990788fd0a4cd132953b6a7256ec362473 Mon Sep 17 00:00:00 2001 From: mpyw Date: Tue, 13 Aug 2024 21:49:50 +0900 Subject: [PATCH 10/27] Fix: Remove useless `mut` qualifier --- sqlx-postgres/src/any.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index bcfef8dce1..bc74f58da2 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -52,7 +52,7 @@ impl AnyConnectionBackend for PgConnection { PgTransactionManager::start_rollback(self) } - fn get_transaction_depth(&mut self) -> usize { + fn get_transaction_depth(&self) -> usize { PgTransactionManager::get_transaction_depth(self) } From cfb1f19650bf78808f3d8e05684638c5545a859c Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Wed, 27 Nov 2024 14:35:54 -0800 Subject: [PATCH 11/27] feat: add Connection::begin_with This patch completes the plumbing of an optional statement from these methods to `TransactionManager::begin` without any validation of the provided statement. There is a new `Error::InvalidSavePoint` which is triggered by any attempt to call `Connection::begin_with` when we are already inside of a transaction. --- sqlx-core/src/acquire.rs | 4 ++-- sqlx-core/src/any/connection/backend.rs | 9 ++++++++- sqlx-core/src/any/connection/mod.rs | 13 ++++++++++++- sqlx-core/src/any/transaction.rs | 8 ++++++-- sqlx-core/src/connection.rs | 11 +++++++++++ sqlx-core/src/error.rs | 3 +++ sqlx-core/src/pool/connection.rs | 2 +- sqlx-core/src/pool/mod.rs | 8 ++++++-- sqlx-core/src/transaction.rs | 18 ++++++++++++----- sqlx-mysql/src/any.rs | 8 ++++++-- sqlx-mysql/src/connection/mod.rs | 13 ++++++++++++- sqlx-mysql/src/transaction.rs | 16 +++++++++++---- sqlx-postgres/src/any.rs | 8 ++++++-- sqlx-postgres/src/connection/mod.rs | 13 ++++++++++++- sqlx-postgres/src/transaction.rs | 15 +++++++++++--- sqlx-sqlite/src/any.rs | 9 +++++++-- sqlx-sqlite/src/connection/mod.rs | 13 ++++++++++++- sqlx-sqlite/src/connection/worker.rs | 26 +++++++++++++++++++++---- sqlx-sqlite/src/transaction.rs | 8 ++++++-- 19 files changed, 169 insertions(+), 36 deletions(-) diff --git a/sqlx-core/src/acquire.rs b/sqlx-core/src/acquire.rs index c9d7fb215c..59bac9fa59 100644 --- a/sqlx-core/src/acquire.rs +++ b/sqlx-core/src/acquire.rs @@ -93,7 +93,7 @@ impl<'a, DB: Database> Acquire<'a> for &'_ Pool { let conn = self.acquire(); Box::pin(async move { - Transaction::begin(MaybePoolConnection::PoolConnection(conn.await?)).await + Transaction::begin(MaybePoolConnection::PoolConnection(conn.await?), None).await }) } } @@ -121,7 +121,7 @@ macro_rules! impl_acquire { 'c, Result<$crate::transaction::Transaction<'c, $DB>, $crate::error::Error>, > { - $crate::transaction::Transaction::begin(self) + $crate::transaction::Transaction::begin(self, None) } } }; diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index b30cbe83f3..2fe9ed7656 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -3,6 +3,7 @@ use crate::describe::Describe; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; +use std::borrow::Cow; use std::fmt::Debug; pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { @@ -26,7 +27,13 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { fn ping(&mut self) -> BoxFuture<'_, crate::Result<()>>; /// Begin a new transaction or establish a savepoint within the active transaction. - fn begin(&mut self) -> BoxFuture<'_, crate::Result<()>>; + /// + /// If this is a new transaction, `statement` may be used instead of the + /// default "BEGIN" statement. + /// + /// If we are already inside a transaction and `statement.is_some()`, then + /// `Error::InvalidSavePoint` is returned without running any statements. + fn begin(&mut self, statement: Option>) -> BoxFuture<'_, crate::Result<()>>; fn commit(&mut self) -> BoxFuture<'_, crate::Result<()>>; diff --git a/sqlx-core/src/any/connection/mod.rs b/sqlx-core/src/any/connection/mod.rs index b6f795848a..8cf8fc510c 100644 --- a/sqlx-core/src/any/connection/mod.rs +++ b/sqlx-core/src/any/connection/mod.rs @@ -1,4 +1,5 @@ use futures_core::future::BoxFuture; +use std::borrow::Cow; use crate::any::{Any, AnyConnectOptions}; use crate::connection::{ConnectOptions, Connection}; @@ -87,7 +88,17 @@ impl Connection for AnyConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-core/src/any/transaction.rs b/sqlx-core/src/any/transaction.rs index fce4175626..4972268499 100644 --- a/sqlx-core/src/any/transaction.rs +++ b/sqlx-core/src/any/transaction.rs @@ -1,4 +1,5 @@ use futures_util::future::BoxFuture; +use std::borrow::Cow; use crate::any::{Any, AnyConnection}; use crate::error::Error; @@ -9,8 +10,11 @@ pub struct AnyTransactionManager; impl TransactionManager for AnyTransactionManager { type Database = Any; - fn begin(conn: &mut AnyConnection) -> BoxFuture<'_, Result<(), Error>> { - conn.backend.begin() + fn begin<'conn>( + conn: &'conn mut AnyConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { + conn.backend.begin(statement) } fn commit(conn: &mut AnyConnection) -> BoxFuture<'_, Result<(), Error>> { diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index ce2aa6c629..de0a05799d 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -4,6 +4,7 @@ use crate::error::Error; use crate::transaction::Transaction; use futures_core::future::BoxFuture; use log::LevelFilter; +use std::borrow::Cow; use std::fmt::Debug; use std::str::FromStr; use std::time::Duration; @@ -49,6 +50,16 @@ pub trait Connection: Send { where Self: Sized; + /// Begin a new transaction with a custom statement. + /// + /// Returns a [`Transaction`] for controlling and tracking the new transaction. + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized; + /// Execute the function inside a transaction. /// /// If the function returns an error, the transaction will be rolled back. If it does not diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 17774addd2..8b454575e9 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -111,6 +111,9 @@ pub enum Error { #[cfg(feature = "migrate")] #[error("{0}")] Migrate(#[source] Box), + + #[error("attempted to call begin_with at non-zero transaction depth")] + InvalidSavePointStatement, } impl StdError for Box {} diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index bf3a6d4b1c..c029fec6eb 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -191,7 +191,7 @@ impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection futures_core::future::BoxFuture<'c, Result, Error>> { - crate::transaction::Transaction::begin(&mut **self) + crate::transaction::Transaction::begin(&mut **self, None) } } diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index e998618413..438eebf6c1 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -367,13 +367,17 @@ impl Pool { /// Retrieves a connection and immediately begins a new transaction. pub async fn begin(&self) -> Result, Error> { - Transaction::begin(MaybePoolConnection::PoolConnection(self.acquire().await?)).await + Transaction::begin( + MaybePoolConnection::PoolConnection(self.acquire().await?), + None, + ) + .await } /// Attempts to retrieve a connection and immediately begins a new transaction if successful. pub async fn try_begin(&self) -> Result>, Error> { match self.try_acquire() { - Some(conn) => Transaction::begin(MaybePoolConnection::PoolConnection(conn)) + Some(conn) => Transaction::begin(MaybePoolConnection::PoolConnection(conn), None) .await .map(Some), diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index 9cd38aab3a..d9459c53d4 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -16,9 +16,16 @@ pub trait TransactionManager { type Database: Database; /// Begin a new transaction or establish a savepoint within the active transaction. - fn begin( - conn: &mut ::Connection, - ) -> BoxFuture<'_, Result<(), Error>>; + /// + /// If this is a new transaction, `statement` may be used instead of the + /// default "BEGIN" statement. + /// + /// If we are already inside a transaction and `statement.is_some()`, then + /// `Error::InvalidSavePoint` is returned without running any statements. + fn begin<'conn>( + conn: &'conn mut ::Connection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>>; /// Commit the active transaction or release the most recent savepoint. fn commit( @@ -83,11 +90,12 @@ where #[doc(hidden)] pub fn begin( conn: impl Into>, + statement: Option>, ) -> BoxFuture<'c, Result> { let mut conn = conn.into(); Box::pin(async move { - DB::TransactionManager::begin(&mut conn).await?; + DB::TransactionManager::begin(&mut conn, statement).await?; Ok(Self { connection: conn, @@ -237,7 +245,7 @@ impl<'c, 't, DB: Database> crate::acquire::Acquire<'t> for &'t mut Transaction<' #[inline] fn begin(self) -> BoxFuture<'t, Result, Error>> { - Transaction::begin(&mut **self) + Transaction::begin(&mut **self, None) } } diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index 0466bfc0a4..96190f0bd2 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -16,6 +16,7 @@ use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::transaction::TransactionManager; +use std::borrow::Cow; use std::future; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = MySql); @@ -37,8 +38,11 @@ impl AnyConnectionBackend for MySqlConnection { Connection::ping(self) } - fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { - MySqlTransactionManager::begin(self) + fn begin( + &mut self, + statement: Option>, + ) -> BoxFuture<'_, sqlx_core::Result<()>> { + MySqlTransactionManager::begin(self, statement) } fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index c4978a7701..e2c671046d 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use futures_core::future::BoxFuture; @@ -111,7 +112,17 @@ impl Connection for MySqlConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn shrink_buffers(&mut self) { diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index d8538cc2b3..f287c4a80b 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use futures_core::future::BoxFuture; use crate::connection::Waiting; @@ -14,12 +16,18 @@ pub struct MySqlTransactionManager; impl TransactionManager for MySqlTransactionManager { type Database = MySql; - fn begin(conn: &mut MySqlConnection) -> BoxFuture<'_, Result<(), Error>> { + fn begin<'conn>( + conn: &'conn mut MySqlConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { let depth = conn.inner.transaction_depth; - - conn.execute(&*begin_ansi_transaction_sql(depth)).await?; - conn.inner.transaction_depth = depth + 1; + if statement.is_some() && depth > 0 { + return Err(Error::InvalidSavePointStatement); + } + let statement = statement.unwrap_or_else(|| begin_ansi_transaction_sql(depth)); + conn.execute(&*statement).await?; + conn.inner.transaction_depth += 1; Ok(()) }) diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index efa9a044bc..d189301c13 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -5,6 +5,7 @@ use crate::{ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_util::{stream, StreamExt, TryFutureExt, TryStreamExt}; +use std::borrow::Cow; use std::future; use sqlx_core::any::{ @@ -39,8 +40,11 @@ impl AnyConnectionBackend for PgConnection { Connection::ping(self) } - fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { - PgTransactionManager::begin(self) + fn begin( + &mut self, + statement: Option>, + ) -> BoxFuture<'_, sqlx_core::Result<()>> { + PgTransactionManager::begin(self, statement) } fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index c139f8e53d..04b9a4c9e2 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -179,7 +180,17 @@ impl Connection for PgConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index e7c78488eb..767d83c52e 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -1,4 +1,5 @@ use futures_core::future::BoxFuture; +use std::borrow::Cow; use crate::error::Error; use crate::executor::Executor; @@ -13,11 +14,19 @@ pub struct PgTransactionManager; impl TransactionManager for PgTransactionManager { type Database = Postgres; - fn begin(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> { + fn begin<'conn>( + conn: &'conn mut PgConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { + let depth = conn.inner.transaction_depth; + if statement.is_some() && depth > 0 { + return Err(Error::InvalidSavePointStatement); + } + let statement = statement.unwrap_or_else(|| begin_ansi_transaction_sql(depth)); + let rollback = Rollback::new(conn); - let query = begin_ansi_transaction_sql(rollback.conn.inner.transaction_depth); - rollback.conn.queue_simple_query(&query)?; + rollback.conn.queue_simple_query(&statement)?; rollback.conn.inner.transaction_depth += 1; rollback.conn.wait_until_ready().await?; rollback.defuse(); diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index 01600d9931..2c74c01494 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use crate::{ Either, Sqlite, SqliteArgumentValue, SqliteArguments, SqliteColumn, SqliteConnectOptions, SqliteConnection, SqliteQueryResult, SqliteRow, SqliteTransactionManager, SqliteTypeInfo, @@ -37,8 +39,11 @@ impl AnyConnectionBackend for SqliteConnection { Connection::ping(self) } - fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { - SqliteTransactionManager::begin(self) + fn begin( + &mut self, + statement: Option>, + ) -> BoxFuture<'_, sqlx_core::Result<()>> { + SqliteTransactionManager::begin(self, statement) } fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index c1f9d46da8..ba24976fb4 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::cmp::Ordering; use std::ffi::CStr; use std::fmt::Write; @@ -252,7 +253,17 @@ impl Connection for SqliteConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index c1c67636f1..26a90314c2 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -56,6 +56,7 @@ enum Command { }, Begin { tx: rendezvous_oneshot::Sender>, + statement: Option>, }, Commit { tx: rendezvous_oneshot::Sender>, @@ -182,11 +183,25 @@ impl ConnectionWorker { update_cached_statements_size(&conn, &shared.cached_statements_size); } - Command::Begin { tx } => { + Command::Begin { tx, statement } => { let depth = conn.transaction_depth; + + let statement = if depth == 0 { + statement.unwrap_or_else(|| begin_ansi_transaction_sql(depth)) + } else { + if statement.is_some() { + if tx.blocking_send(Err(Error::InvalidSavePointStatement)).is_err() { + break; + } + continue; + } + + begin_ansi_transaction_sql(depth) + }; + let res = conn.handle - .exec(begin_ansi_transaction_sql(depth)) + .exec(statement) .map(|_| { conn.transaction_depth += 1; }); @@ -333,8 +348,11 @@ impl ConnectionWorker { Ok(rx) } - pub(crate) async fn begin(&mut self) -> Result<(), Error> { - self.oneshot_cmd_with_ack(|tx| Command::Begin { tx }) + pub(crate) async fn begin( + &mut self, + statement: Option>, + ) -> Result<(), Error> { + self.oneshot_cmd_with_ack(|tx| Command::Begin { tx, statement }) .await? } diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index 24eaca51b1..d7c40d4956 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -1,4 +1,5 @@ use futures_core::future::BoxFuture; +use std::borrow::Cow; use crate::{Sqlite, SqliteConnection}; use sqlx_core::error::Error; @@ -10,8 +11,11 @@ pub struct SqliteTransactionManager; impl TransactionManager for SqliteTransactionManager { type Database = Sqlite; - fn begin(conn: &mut SqliteConnection) -> BoxFuture<'_, Result<(), Error>> { - Box::pin(conn.worker.begin()) + fn begin<'conn>( + conn: &'conn mut SqliteConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { + Box::pin(conn.worker.begin(statement)) } fn commit(conn: &mut SqliteConnection) -> BoxFuture<'_, Result<(), Error>> { From 7d802ccd0461c91ad05f189f4414869a6d3b82f6 Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Wed, 27 Nov 2024 17:21:53 -0800 Subject: [PATCH 12/27] feat: add Pool::begin_with and Pool::try_begin_with --- sqlx-core/src/pool/mod.rs | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 438eebf6c1..b759bacdda 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -54,6 +54,7 @@ //! [`Pool::acquire`] or //! [`Pool::begin`]. +use std::borrow::Cow; use std::fmt; use std::future::Future; use std::pin::Pin; @@ -385,6 +386,36 @@ impl Pool { } } + /// Retrieves a connection and immediately begins a new transaction using `statement`. + pub async fn begin_with( + &self, + statement: impl Into>, + ) -> Result, Error> { + Transaction::begin( + MaybePoolConnection::PoolConnection(self.acquire().await?), + Some(statement.into()), + ) + .await + } + + /// Attempts to retrieve a connection and, if successful, immediately begins a new + /// transaction using `statement`. + pub async fn try_begin_with( + &self, + statement: impl Into>, + ) -> Result>, Error> { + match self.try_acquire() { + Some(conn) => Transaction::begin( + MaybePoolConnection::PoolConnection(conn), + Some(statement.into()), + ) + .await + .map(Some), + + None => Ok(None), + } + } + /// Shut down the connection pool, immediately waking all tasks waiting for a connection. /// /// Upon calling this method, any currently waiting or subsequent calls to [`Pool::acquire`] and From c188073e1ff8f64201f5da9c5d4aecc63e3d1237 Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Wed, 27 Nov 2024 18:01:57 -0800 Subject: [PATCH 13/27] feat: add Error::BeginFailed and validate that custom "begin" statements are successful --- sqlx-core/src/error.rs | 3 +++ sqlx-mysql/src/connection/establish.rs | 1 + sqlx-mysql/src/connection/mod.rs | 10 ++++++++++ sqlx-mysql/src/protocol/response/status.rs | 2 +- sqlx-mysql/src/transaction.rs | 3 +++ sqlx-postgres/src/connection/mod.rs | 7 +++++++ sqlx-postgres/src/transaction.rs | 5 ++++- sqlx-sqlite/src/connection/mod.rs | 9 +++++++-- sqlx-sqlite/src/transaction.rs | 13 ++++++++++++- 9 files changed, 48 insertions(+), 5 deletions(-) diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 8b454575e9..150d643180 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -114,6 +114,9 @@ pub enum Error { #[error("attempted to call begin_with at non-zero transaction depth")] InvalidSavePointStatement, + + #[error("got unexpected connection status after attempting to begin transaction")] + BeginFailed, } impl StdError for Box {} diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index 0623a0556c..85a9d84f96 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -27,6 +27,7 @@ impl MySqlConnection { inner: Box::new(MySqlConnectionInner { stream, transaction_depth: 0, + status_flags: Default::default(), cache_statement: StatementCache::new(options.statement_cache_capacity), log_settings: options.log_settings.clone(), }), diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index e2c671046d..0a2f5fb839 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -8,6 +8,7 @@ pub(crate) use stream::{MySqlStream, Waiting}; use crate::common::StatementCache; use crate::error::Error; +use crate::protocol::response::Status; use crate::protocol::statement::StmtClose; use crate::protocol::text::{Ping, Quit}; use crate::statement::MySqlStatementMetadata; @@ -35,6 +36,7 @@ pub(crate) struct MySqlConnectionInner { // transaction status pub(crate) transaction_depth: usize, + status_flags: Status, // cache by query string to the statement id and metadata cache_statement: StatementCache<(u32, MySqlStatementMetadata)>, @@ -42,6 +44,14 @@ pub(crate) struct MySqlConnectionInner { log_settings: LogSettings, } +impl MySqlConnection { + pub(crate) fn in_transaction(&self) -> bool { + self.inner + .status_flags + .intersects(Status::SERVER_STATUS_IN_TRANS) + } +} + impl Debug for MySqlConnection { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("MySqlConnection").finish() diff --git a/sqlx-mysql/src/protocol/response/status.rs b/sqlx-mysql/src/protocol/response/status.rs index bf5013deed..4a8bb0375a 100644 --- a/sqlx-mysql/src/protocol/response/status.rs +++ b/sqlx-mysql/src/protocol/response/status.rs @@ -1,7 +1,7 @@ // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/mysql__com_8h.html#a1d854e841086925be1883e4d7b4e8cad // https://mariadb.com/kb/en/library/mariadb-connectorc-types-and-definitions/#server-status bitflags::bitflags! { - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)] pub struct Status: u16 { // Is raised when a multi-statement transaction has been started, either explicitly, // by means of BEGIN or COMMIT AND CHAIN, or implicitly, by the first diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index f287c4a80b..953735bf9a 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -27,6 +27,9 @@ impl TransactionManager for MySqlTransactionManager { } let statement = statement.unwrap_or_else(|| begin_ansi_transaction_sql(depth)); conn.execute(&*statement).await?; + if !conn.in_transaction() { + return Err(Error::BeginFailed); + } conn.inner.transaction_depth += 1; Ok(()) diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index 04b9a4c9e2..96e3e2fe12 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -128,6 +128,13 @@ impl PgConnection { Ok(()) } + + pub(crate) fn in_transaction(&self) -> bool { + match self.inner.transaction_status { + TransactionStatus::Transaction => true, + TransactionStatus::Error | TransactionStatus::Idle => false, + } + } } impl Debug for PgConnection { diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index 767d83c52e..ec01129d6f 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -27,8 +27,11 @@ impl TransactionManager for PgTransactionManager { let rollback = Rollback::new(conn); rollback.conn.queue_simple_query(&statement)?; - rollback.conn.inner.transaction_depth += 1; rollback.conn.wait_until_ready().await?; + if !rollback.conn.in_transaction() { + return Err(Error::BeginFailed); + } + rollback.conn.inner.transaction_depth += 1; rollback.defuse(); Ok(()) diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index ba24976fb4..6b15fbca5d 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -12,8 +12,8 @@ use futures_core::future::BoxFuture; use futures_intrusive::sync::MutexGuard; use futures_util::future; use libsqlite3_sys::{ - sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook, - sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, + sqlite3, sqlite3_commit_hook, sqlite3_get_autocommit, sqlite3_progress_handler, + sqlite3_rollback_hook, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, }; #[cfg(feature = "preupdate-hook")] pub use preupdate_hook::*; @@ -553,6 +553,11 @@ impl LockedSqliteHandle<'_> { pub fn remove_rollback_hook(&mut self) { self.guard.remove_rollback_hook(); } + + pub(crate) fn in_transaction(&mut self) -> bool { + let ret = unsafe { sqlite3_get_autocommit(self.as_raw_handle().as_ptr()) }; + ret == 0 + } } impl Drop for ConnectionState { diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index d7c40d4956..d217cffd61 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -15,7 +15,18 @@ impl TransactionManager for SqliteTransactionManager { conn: &'conn mut SqliteConnection, statement: Option>, ) -> BoxFuture<'conn, Result<(), Error>> { - Box::pin(conn.worker.begin(statement)) + Box::pin(async { + let is_custom_statement = statement.is_some(); + conn.worker.begin(statement).await?; + if is_custom_statement { + // Check that custom statement actually put the connection into a transaction. + let mut handle = conn.lock_handle().await?; + if !handle.in_transaction() { + return Err(Error::BeginFailed); + } + } + Ok(()) + }) } fn commit(conn: &mut SqliteConnection) -> BoxFuture<'_, Result<(), Error>> { From 36a2fab38a4d3ab138a474d2988488586e6df0ca Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Wed, 27 Nov 2024 18:34:41 -0800 Subject: [PATCH 14/27] chore: add tests of Error::BeginFailed --- tests/mysql/error.rs | 14 +++++++++++++- tests/postgres/error.rs | 14 +++++++++++++- tests/sqlite/error.rs | 14 +++++++++++++- 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/tests/mysql/error.rs b/tests/mysql/error.rs index 7c84266c32..090cbe1980 100644 --- a/tests/mysql/error.rs +++ b/tests/mysql/error.rs @@ -1,4 +1,4 @@ -use sqlx::{error::ErrorKind, mysql::MySql, Connection}; +use sqlx::{error::ErrorKind, mysql::MySql, Connection, Error}; use sqlx_test::new; #[sqlx_macros::test] @@ -74,3 +74,15 @@ async fn it_fails_with_check_violation() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_begin_failed() -> anyhow::Result<()> { + let mut conn = new::().await?; + let res = conn.begin_with("SELECT * FROM tweet").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::BeginFailed), "{err:?}"); + + Ok(()) +} diff --git a/tests/postgres/error.rs b/tests/postgres/error.rs index d6f78140da..5e52155f33 100644 --- a/tests/postgres/error.rs +++ b/tests/postgres/error.rs @@ -1,4 +1,4 @@ -use sqlx::{error::ErrorKind, postgres::Postgres, Connection}; +use sqlx::{error::ErrorKind, postgres::Postgres, Connection, Error}; use sqlx_test::new; #[sqlx_macros::test] @@ -74,3 +74,15 @@ async fn it_fails_with_check_violation() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_begin_failed() -> anyhow::Result<()> { + let mut conn = new::().await?; + let res = conn.begin_with("SELECT * FROM tweet").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::BeginFailed), "{err:?}"); + + Ok(()) +} diff --git a/tests/sqlite/error.rs b/tests/sqlite/error.rs index 1f6b797e69..2227a14d3b 100644 --- a/tests/sqlite/error.rs +++ b/tests/sqlite/error.rs @@ -1,4 +1,4 @@ -use sqlx::{error::ErrorKind, sqlite::Sqlite, Connection, Executor}; +use sqlx::{error::ErrorKind, sqlite::Sqlite, Connection, Error, Executor}; use sqlx_test::new; #[sqlx_macros::test] @@ -70,3 +70,15 @@ async fn it_fails_with_check_violation() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_begin_failed() -> anyhow::Result<()> { + let mut conn = new::().await?; + let res = conn.begin_with("SELECT * FROM tweet").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::BeginFailed), "{err:?}"); + + Ok(()) +} From 1d9afb32f817c028742e08588bcd7ab96c39b33f Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Wed, 27 Nov 2024 18:37:25 -0800 Subject: [PATCH 15/27] chore: add tests of Error::InvalidSavePointStatement --- tests/mysql/error.rs | 14 ++++++++++++++ tests/postgres/error.rs | 14 ++++++++++++++ tests/sqlite/error.rs | 14 ++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/tests/mysql/error.rs b/tests/mysql/error.rs index 090cbe1980..3ee1024fc8 100644 --- a/tests/mysql/error.rs +++ b/tests/mysql/error.rs @@ -86,3 +86,17 @@ async fn it_fails_with_begin_failed() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_invalid_save_point_statement() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut txn = conn.begin().await?; + let txn_conn = sqlx::Acquire::acquire(&mut txn).await?; + let res = txn_conn.begin_with("BEGIN").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::InvalidSavePointStatement), "{err}"); + + Ok(()) +} diff --git a/tests/postgres/error.rs b/tests/postgres/error.rs index 5e52155f33..32bf814770 100644 --- a/tests/postgres/error.rs +++ b/tests/postgres/error.rs @@ -86,3 +86,17 @@ async fn it_fails_with_begin_failed() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_invalid_save_point_statement() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut txn = conn.begin().await?; + let txn_conn = sqlx::Acquire::acquire(&mut txn).await?; + let res = txn_conn.begin_with("BEGIN").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::InvalidSavePointStatement), "{err}"); + + Ok(()) +} diff --git a/tests/sqlite/error.rs b/tests/sqlite/error.rs index 2227a14d3b..8729842b70 100644 --- a/tests/sqlite/error.rs +++ b/tests/sqlite/error.rs @@ -82,3 +82,17 @@ async fn it_fails_with_begin_failed() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_invalid_save_point_statement() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut txn = conn.begin().await?; + let txn_conn = sqlx::Acquire::acquire(&mut txn).await?; + let res = txn_conn.begin_with("BEGIN").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::InvalidSavePointStatement), "{err}"); + + Ok(()) +} From 6750b62b668d0b79aa48a9be6d0037d35ed8b48c Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Wed, 27 Nov 2024 18:50:02 -0800 Subject: [PATCH 16/27] chore: test begin_with works for all SQLite "BEGIN" statements --- sqlx-sqlite/src/connection/mod.rs | 23 +++++++++++++++++++++ sqlx-sqlite/src/lib.rs | 4 +++- tests/sqlite/sqlite.rs | 33 +++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 6b15fbca5d..75d15ef2b6 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -558,6 +558,29 @@ impl LockedSqliteHandle<'_> { let ret = unsafe { sqlite3_get_autocommit(self.as_raw_handle().as_ptr()) }; ret == 0 } + + /// Calls `sqlite3_txn_state` on this handle. + pub fn transaction_state(&mut self) -> Result { + use libsqlite3_sys::{ + sqlite3_txn_state, SQLITE_TXN_NONE, SQLITE_TXN_READ, SQLITE_TXN_WRITE, + }; + + let state = + match unsafe { sqlite3_txn_state(self.as_raw_handle().as_ptr(), std::ptr::null()) } { + SQLITE_TXN_NONE => SqliteTransactionState::None, + SQLITE_TXN_READ => SqliteTransactionState::Read, + SQLITE_TXN_WRITE => SqliteTransactionState::Write, + _ => return Err(Error::Protocol("Invalid transaction state".into())), + }; + Ok(state) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum SqliteTransactionState { + None, + Read, + Write, } impl Drop for ConnectionState { diff --git a/sqlx-sqlite/src/lib.rs b/sqlx-sqlite/src/lib.rs index f1a45c3d34..18b71c4d4b 100644 --- a/sqlx-sqlite/src/lib.rs +++ b/sqlx-sqlite/src/lib.rs @@ -48,7 +48,9 @@ pub use arguments::{SqliteArgumentValue, SqliteArguments}; pub use column::SqliteColumn; #[cfg(feature = "preupdate-hook")] pub use connection::PreupdateHookResult; -pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult}; +pub use connection::{ + LockedSqliteHandle, SqliteConnection, SqliteOperation, SqliteTransactionState, UpdateHookResult, +}; pub use database::Sqlite; pub use error::SqliteError; pub use options::{ diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index d78e1151a9..21cf7b9dcf 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -1172,3 +1172,36 @@ async fn test_multiple_set_preupdate_hook_calls_drop_old_handler() -> anyhow::Re assert_eq!(1, Arc::strong_count(&ref_counted_object)); Ok(()) } + +#[sqlx_macros::test] +async fn it_can_use_transaction_options() -> anyhow::Result<()> { + use sqlx_sqlite::SqliteTransactionState; + + async fn check_txn_state( + conn: &mut SqliteConnection, + expected: SqliteTransactionState, + ) -> Result<(), sqlx::Error> { + let state = conn.lock_handle().await?.transaction_state()?; + assert_eq!(state, expected); + Ok(()) + } + + let mut conn = new::().await?; + + check_txn_state(&mut conn, SqliteTransactionState::None).await?; + + let mut tx = conn.begin_with("BEGIN DEFERRED").await?; + check_txn_state(&mut *tx, SqliteTransactionState::None).await?; + drop(tx); + + let mut tx = conn.begin_with("BEGIN IMMEDIATE").await?; + check_txn_state(&mut *tx, SqliteTransactionState::Write).await?; + drop(tx); + + // Note: may result in database locked errors if tests are run in parallel + let mut tx = conn.begin_with("BEGIN EXCLUSIVE").await?; + check_txn_state(&mut *tx, SqliteTransactionState::Write).await?; + drop(tx); + + Ok(()) +} From bed0190d705ec611aa92fd7ae55965639e83617a Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Mon, 2 Dec 2024 19:03:24 -0800 Subject: [PATCH 17/27] chore: improve comment on Connection::begin_with --- sqlx-core/src/connection.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index de0a05799d..dd9c974fdc 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -53,6 +53,9 @@ pub trait Connection: Send { /// Begin a new transaction with a custom statement. /// /// Returns a [`Transaction`] for controlling and tracking the new transaction. + /// + /// Returns an error if the connection is already in a transaction or if + /// `statement` does not put the connection into a transaction. fn begin_with( &mut self, statement: impl Into>, From 479ab06af2f4723391c4e84c36e98f53c63253af Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Mon, 2 Dec 2024 19:04:33 -0800 Subject: [PATCH 18/27] feat: add default impl of `Connection::begin_with` This makes the new method a non-breaking change. --- sqlx-core/src/connection.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index dd9c974fdc..ba226bc814 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -61,7 +61,10 @@ pub trait Connection: Send { statement: impl Into>, ) -> BoxFuture<'_, Result, Error>> where - Self: Sized; + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) + } /// Execute the function inside a transaction. /// From 36e028988341b13e6da2041a1952e4f5766f1e02 Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Mon, 2 Dec 2024 18:40:17 -0800 Subject: [PATCH 19/27] refactor: combine if statement + unwrap_or_else into one match --- sqlx-mysql/src/transaction.rs | 11 +++++++---- sqlx-postgres/src/transaction.rs | 11 +++++++---- sqlx-sqlite/src/connection/worker.rs | 15 ++++++++------- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index 953735bf9a..11f56c0cb9 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -22,10 +22,13 @@ impl TransactionManager for MySqlTransactionManager { ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { let depth = conn.inner.transaction_depth; - if statement.is_some() && depth > 0 { - return Err(Error::InvalidSavePointStatement); - } - let statement = statement.unwrap_or_else(|| begin_ansi_transaction_sql(depth)); + let statement = match statement { + // custom `BEGIN` statements are not allowed if we're already in a transaction + // (we need to issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), + }; conn.execute(&*statement).await?; if !conn.in_transaction() { return Err(Error::BeginFailed); diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index ec01129d6f..f70961cc19 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -20,10 +20,13 @@ impl TransactionManager for PgTransactionManager { ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { let depth = conn.inner.transaction_depth; - if statement.is_some() && depth > 0 { - return Err(Error::InvalidSavePointStatement); - } - let statement = statement.unwrap_or_else(|| begin_ansi_transaction_sql(depth)); + let statement = match statement { + // custom `BEGIN` statements are not allowed if we're already in + // a transaction (we need to issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), + }; let rollback = Rollback::new(conn); rollback.conn.queue_simple_query(&statement)?; diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index 26a90314c2..3f6a13adef 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -186,17 +186,18 @@ impl ConnectionWorker { Command::Begin { tx, statement } => { let depth = conn.transaction_depth; - let statement = if depth == 0 { - statement.unwrap_or_else(|| begin_ansi_transaction_sql(depth)) - } else { - if statement.is_some() { + let statement = match statement { + // custom `BEGIN` statements are not allowed if + // we're already in a transaction (we need to + // issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => { if tx.blocking_send(Err(Error::InvalidSavePointStatement)).is_err() { break; } continue; - } - - begin_ansi_transaction_sql(depth) + }, + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), }; let res = From 471e32fa083b9dece2c9da0febf676d1b9787172 Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Mon, 2 Dec 2024 18:59:42 -0800 Subject: [PATCH 20/27] feat: use in-memory SQLite DB to avoid conflicts across tests run in parallel --- tests/sqlite/sqlite.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 21cf7b9dcf..bbf4d2737a 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -1186,21 +1186,24 @@ async fn it_can_use_transaction_options() -> anyhow::Result<()> { Ok(()) } - let mut conn = new::().await?; + let mut conn = SqliteConnectOptions::new() + .in_memory(true) + .connect() + .await + .unwrap(); check_txn_state(&mut conn, SqliteTransactionState::None).await?; let mut tx = conn.begin_with("BEGIN DEFERRED").await?; - check_txn_state(&mut *tx, SqliteTransactionState::None).await?; + check_txn_state(&mut tx, SqliteTransactionState::None).await?; drop(tx); let mut tx = conn.begin_with("BEGIN IMMEDIATE").await?; - check_txn_state(&mut *tx, SqliteTransactionState::Write).await?; + check_txn_state(&mut tx, SqliteTransactionState::Write).await?; drop(tx); - // Note: may result in database locked errors if tests are run in parallel let mut tx = conn.begin_with("BEGIN EXCLUSIVE").await?; - check_txn_state(&mut *tx, SqliteTransactionState::Write).await?; + check_txn_state(&mut tx, SqliteTransactionState::Write).await?; drop(tx); Ok(()) From ba57644e3c84e370f76fffe96362098194b90935 Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Mon, 2 Dec 2024 18:55:29 -0800 Subject: [PATCH 21/27] feedback: remove public wrapper for sqlite3_txn_state Move the wrapper directly into the test that uses it instead. --- Cargo.toml | 1 + sqlx-sqlite/src/connection/mod.rs | 23 ------------------ sqlx-sqlite/src/lib.rs | 4 +--- tests/sqlite/sqlite.rs | 39 +++++++++++++++++++++---------- 4 files changed, 29 insertions(+), 38 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f93ed3dded..5f57f98537 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -187,6 +187,7 @@ rand_xoshiro = "0.6.0" hex = "0.4.3" tempfile = "3.10.1" criterion = { version = "0.5.1", features = ["async_tokio"] } +libsqlite3-sys = { version = "0.30.1" } # If this is an unconditional dev-dependency then Cargo will *always* try to build `libsqlite3-sys`, # even when SQLite isn't the intended test target, and fail if the build environment is not set up for compiling C code. diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 75d15ef2b6..6b15fbca5d 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -558,29 +558,6 @@ impl LockedSqliteHandle<'_> { let ret = unsafe { sqlite3_get_autocommit(self.as_raw_handle().as_ptr()) }; ret == 0 } - - /// Calls `sqlite3_txn_state` on this handle. - pub fn transaction_state(&mut self) -> Result { - use libsqlite3_sys::{ - sqlite3_txn_state, SQLITE_TXN_NONE, SQLITE_TXN_READ, SQLITE_TXN_WRITE, - }; - - let state = - match unsafe { sqlite3_txn_state(self.as_raw_handle().as_ptr(), std::ptr::null()) } { - SQLITE_TXN_NONE => SqliteTransactionState::None, - SQLITE_TXN_READ => SqliteTransactionState::Read, - SQLITE_TXN_WRITE => SqliteTransactionState::Write, - _ => return Err(Error::Protocol("Invalid transaction state".into())), - }; - Ok(state) - } -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum SqliteTransactionState { - None, - Read, - Write, } impl Drop for ConnectionState { diff --git a/sqlx-sqlite/src/lib.rs b/sqlx-sqlite/src/lib.rs index 18b71c4d4b..f1a45c3d34 100644 --- a/sqlx-sqlite/src/lib.rs +++ b/sqlx-sqlite/src/lib.rs @@ -48,9 +48,7 @@ pub use arguments::{SqliteArgumentValue, SqliteArguments}; pub use column::SqliteColumn; #[cfg(feature = "preupdate-hook")] pub use connection::PreupdateHookResult; -pub use connection::{ - LockedSqliteHandle, SqliteConnection, SqliteOperation, SqliteTransactionState, UpdateHookResult, -}; +pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult}; pub use database::Sqlite; pub use error::SqliteError; pub use options::{ diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index bbf4d2737a..ee9651eeb4 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -8,6 +8,7 @@ use sqlx::{ SqliteConnection, SqlitePool, Statement, TypeInfo, }; use sqlx::{Value, ValueRef}; +use sqlx_sqlite::LockedSqliteHandle; use sqlx_test::new; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -1175,15 +1176,9 @@ async fn test_multiple_set_preupdate_hook_calls_drop_old_handler() -> anyhow::Re #[sqlx_macros::test] async fn it_can_use_transaction_options() -> anyhow::Result<()> { - use sqlx_sqlite::SqliteTransactionState; - - async fn check_txn_state( - conn: &mut SqliteConnection, - expected: SqliteTransactionState, - ) -> Result<(), sqlx::Error> { - let state = conn.lock_handle().await?.transaction_state()?; + async fn check_txn_state(conn: &mut SqliteConnection, expected: SqliteTransactionState) { + let state = transaction_state(&mut conn.lock_handle().await.unwrap()); assert_eq!(state, expected); - Ok(()) } let mut conn = SqliteConnectOptions::new() @@ -1192,19 +1187,39 @@ async fn it_can_use_transaction_options() -> anyhow::Result<()> { .await .unwrap(); - check_txn_state(&mut conn, SqliteTransactionState::None).await?; + check_txn_state(&mut conn, SqliteTransactionState::None).await; let mut tx = conn.begin_with("BEGIN DEFERRED").await?; - check_txn_state(&mut tx, SqliteTransactionState::None).await?; + check_txn_state(&mut tx, SqliteTransactionState::None).await; drop(tx); let mut tx = conn.begin_with("BEGIN IMMEDIATE").await?; - check_txn_state(&mut tx, SqliteTransactionState::Write).await?; + check_txn_state(&mut tx, SqliteTransactionState::Write).await; drop(tx); let mut tx = conn.begin_with("BEGIN EXCLUSIVE").await?; - check_txn_state(&mut tx, SqliteTransactionState::Write).await?; + check_txn_state(&mut tx, SqliteTransactionState::Write).await; drop(tx); Ok(()) } + +fn transaction_state(handle: &mut LockedSqliteHandle) -> SqliteTransactionState { + use libsqlite3_sys::{sqlite3_txn_state, SQLITE_TXN_NONE, SQLITE_TXN_READ, SQLITE_TXN_WRITE}; + + let unchecked_state = + unsafe { sqlite3_txn_state(handle.as_raw_handle().as_ptr(), std::ptr::null()) }; + match unchecked_state { + SQLITE_TXN_NONE => SqliteTransactionState::None, + SQLITE_TXN_READ => SqliteTransactionState::Read, + SQLITE_TXN_WRITE => SqliteTransactionState::Write, + _ => panic!("unknown txn state: {unchecked_state}"), + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum SqliteTransactionState { + None, + Read, + Write, +} From 07310dd3576fb9194fe89a3953123eb0a89b5481 Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Tue, 3 Dec 2024 14:05:13 -0800 Subject: [PATCH 22/27] fix: cache Status on MySqlConnection --- sqlx-mysql/src/connection/executor.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index 07c7979b08..169dee76b7 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -166,6 +166,8 @@ impl MySqlConnection { // this indicates either a successful query with no rows at all or a failed query let ok = packet.ok()?; + self.inner.status_flags = ok.status; + let rows_affected = ok.affected_rows; logger.increase_rows_affected(rows_affected); let done = MySqlQueryResult { @@ -208,6 +210,8 @@ impl MySqlConnection { if packet[0] == 0xfe && packet.len() < 9 { let eof = packet.eof(self.inner.stream.capabilities)?; + self.inner.status_flags = eof.status; + r#yield!(Either::Left(MySqlQueryResult { rows_affected: 0, last_insert_id: 0, From 2b5b59432f5e6896ce5d4889cb08f78d6a37fc8f Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Sun, 2 Mar 2025 16:39:55 -0800 Subject: [PATCH 23/27] fix: compilation errors --- sqlx-postgres/src/connection/mod.rs | 2 +- sqlx-postgres/src/transaction.rs | 4 ++-- sqlx-sqlite/src/transaction.rs | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index 587dcd0f1d..cbf686af99 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -201,7 +201,7 @@ impl Connection for PgConnection { } fn get_transaction_depth(&self) -> usize { - self.transaction_depth + self.inner.transaction_depth } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index 0306d94078..ef02dd3df1 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -61,7 +61,7 @@ impl TransactionManager for PgTransactionManager { conn.execute(&*rollback_ansi_transaction_sql( conn.inner.transaction_depth, )) - .await?; + .await?; conn.inner.transaction_depth -= 1; } @@ -80,7 +80,7 @@ impl TransactionManager for PgTransactionManager { } fn get_transaction_depth(conn: &::Connection) -> usize { - conn.transaction_depth + conn.inner.transaction_depth } } diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index ce788b2706..55a80ab9f3 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -1,7 +1,6 @@ use futures_core::future::BoxFuture; use std::borrow::Cow; -use crate::{Sqlite, SqliteConnection}; use sqlx_core::error::Error; use sqlx_core::transaction::TransactionManager; From 3c08ede28784cc7bd74d423622a4aad10ab1e7ab Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Sun, 2 Mar 2025 16:44:48 -0800 Subject: [PATCH 24/27] fix: format --- sqlx-postgres/src/transaction.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index ef02dd3df1..23352a8dcf 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -61,7 +61,7 @@ impl TransactionManager for PgTransactionManager { conn.execute(&*rollback_ansi_transaction_sql( conn.inner.transaction_depth, )) - .await?; + .await?; conn.inner.transaction_depth -= 1; } From d035d7974d94929ecc2d94c637a2120b0bacdf1a Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Mon, 3 Mar 2025 16:01:52 -0800 Subject: [PATCH 25/27] fix: postgres test --- tests/postgres/postgres.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 44aa3701ac..384038409b 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -524,7 +524,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // begin let mut tx = conn.begin().await?; // transaction - assert_eq!(conn.get_transaction_depth(), 1); + assert_eq!(tx.get_transaction_depth(), 1); // insert a user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") @@ -534,7 +534,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // begin once more let mut tx2 = tx.begin().await?; // savepoint - assert_eq!(conn.get_transaction_depth(), 2); + assert_eq!(tx2.get_transaction_depth(), 2); // insert another user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") @@ -544,7 +544,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // never mind, rollback tx2.rollback().await?; // roll that one back - assert_eq!(conn.get_transaction_depth(), 1); + assert_eq!(tx.get_transaction_depth(), 1); // did we really? let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") From dbf7b12471c6b2fbed73a3ca316b0ec7a632f417 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Tue, 4 Mar 2025 18:45:47 -0800 Subject: [PATCH 26/27] refactor: delete `Connection::get_transaction_depth` --- sqlx-core/src/any/connection/mod.rs | 4 ---- sqlx-core/src/connection.rs | 22 ++++++---------------- sqlx-mysql/src/connection/mod.rs | 4 ---- sqlx-postgres/src/connection/mod.rs | 4 ---- sqlx-sqlite/src/connection/mod.rs | 4 ---- tests/postgres/postgres.rs | 10 +++++----- 6 files changed, 11 insertions(+), 37 deletions(-) diff --git a/sqlx-core/src/any/connection/mod.rs b/sqlx-core/src/any/connection/mod.rs index b906c67c32..8cf8fc510c 100644 --- a/sqlx-core/src/any/connection/mod.rs +++ b/sqlx-core/src/any/connection/mod.rs @@ -101,10 +101,6 @@ impl Connection for AnyConnection { Transaction::begin(self, Some(statement.into())) } - fn get_transaction_depth(&self) -> usize { - self.backend.get_transaction_depth() - } - fn cached_statements_size(&self) -> usize { self.backend.cached_statements_size() } diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index cc1784dcc4..74e8cd3e8b 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -66,25 +66,15 @@ pub trait Connection: Send { Transaction::begin(self, Some(statement.into())) } - /// Returns the current transaction depth. + /// Returns `true` if the connection is currently in a transaction. /// - /// Transaction depth indicates the level of nested transactions: - /// - Level 0: No active transaction. - /// - Level 1: A transaction is active. - /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. - fn get_transaction_depth(&self) -> usize { - // Fallback implementation to avoid breaking changes - ::TransactionManager::get_transaction_depth(self) - } - - /// Checks if the connection is currently in a transaction. - /// - /// This method returns `true` if the current transaction depth is greater than 0, - /// indicating that a transaction is active. It returns `false` if the transaction depth is 0, - /// meaning no transaction is active. + /// # Note: Automatic Rollbacks May Not Be Counted + /// Certain database errors (such as a serializable isolation failure) + /// can cause automatic rollbacks of a transaction + /// which may not be indicated in the return value of this method. #[inline] fn is_in_transaction(&self) -> bool { - self.get_transaction_depth() != 0 + ::TransactionManager::get_transaction_depth(self) != 0 } /// Execute the function inside a transaction. diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index e3e101e260..0a2f5fb839 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -135,10 +135,6 @@ impl Connection for MySqlConnection { Transaction::begin(self, Some(statement.into())) } - fn get_transaction_depth(&self) -> usize { - self.inner.transaction_depth - } - fn shrink_buffers(&mut self) { self.inner.stream.shrink_buffers(); } diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index cbf686af99..96e3e2fe12 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -200,10 +200,6 @@ impl Connection for PgConnection { Transaction::begin(self, Some(statement.into())) } - fn get_transaction_depth(&self) -> usize { - self.inner.transaction_depth - } - fn cached_statements_size(&self) -> usize { self.inner.cache_statement.len() } diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index c01e2b0210..b94ad91c4d 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -264,10 +264,6 @@ impl Connection for SqliteConnection { Transaction::begin(self, Some(statement.into())) } - fn get_transaction_depth(&self) -> usize { - self.worker.shared.get_transaction_depth() - } - fn cached_statements_size(&self) -> usize { self.worker.shared.get_cached_statements_size() } diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 384038409b..9ef0d0961b 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -515,7 +515,7 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> { #[sqlx_macros::test] async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { let mut conn = new::().await?; - assert_eq!(conn.get_transaction_depth(), 0); + assert!(!conn.is_in_transaction()); conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_2523 (id INTEGER PRIMARY KEY)") .await?; @@ -524,7 +524,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // begin let mut tx = conn.begin().await?; // transaction - assert_eq!(tx.get_transaction_depth(), 1); + assert!(conn.is_in_transaction()); // insert a user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") @@ -534,7 +534,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // begin once more let mut tx2 = tx.begin().await?; // savepoint - assert_eq!(tx2.get_transaction_depth(), 2); + assert!(conn.is_in_transaction()); // insert another user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") @@ -544,7 +544,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // never mind, rollback tx2.rollback().await?; // roll that one back - assert_eq!(tx.get_transaction_depth(), 1); + assert!(conn.is_in_transaction()); // did we really? let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") @@ -555,7 +555,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // actually, commit tx.commit().await?; - assert_eq!(conn.get_transaction_depth(), 0); + assert!(!conn.is_in_transaction()); // did we really? let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") From d32f22f8ebde7ee3e610d1d5bcf02629e26ed9d6 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Mon, 10 Mar 2025 14:03:44 -0700 Subject: [PATCH 27/27] fix: tests --- tests/postgres/postgres.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 9ef0d0961b..fc7108bf4f 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -524,7 +524,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // begin let mut tx = conn.begin().await?; // transaction - assert!(conn.is_in_transaction()); + assert!(tx.is_in_transaction()); // insert a user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") @@ -534,7 +534,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // begin once more let mut tx2 = tx.begin().await?; // savepoint - assert!(conn.is_in_transaction()); + assert!(tx2.is_in_transaction()); // insert another user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") @@ -544,7 +544,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // never mind, rollback tx2.rollback().await?; // roll that one back - assert!(conn.is_in_transaction()); + assert!(tx.is_in_transaction()); // did we really? let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523")