From 7f1fdd958325f477528d50d923bc8a6138d01477 Mon Sep 17 00:00:00 2001 From: Austin Schey Date: Tue, 3 Dec 2024 20:27:31 -0800 Subject: [PATCH 1/4] feat: add preupdate hook --- .github/workflows/sqlx.yml | 6 +- Cargo.lock | 12 +- Cargo.toml | 3 +- README.md | 4 + sqlx-sqlite/Cargo.toml | 2 + sqlx-sqlite/src/connection/establish.rs | 2 + sqlx-sqlite/src/connection/mod.rs | 221 +++++++++++++++++++++++ sqlx-sqlite/src/lib.rs | 4 + tests/sqlite/sqlite.rs | 230 +++++++++++++++++++++++- 9 files changed, 471 insertions(+), 13 deletions(-) diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index 5a7673f26b..7d3845570c 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -39,7 +39,7 @@ jobs: - run: > cargo clippy --no-default-features - --features all-databases,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros + --features all-databases,_unstable-all-types,sqlite-preupdate-hook,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros -- -D warnings # Run beta for new warnings but don't break the build. @@ -47,7 +47,7 @@ jobs: - run: > cargo +beta clippy --no-default-features - --features all-databases,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros + --features all-databases,_unstable-all-types,sqlite-preupdate-hook,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros --target-dir target/beta/ check-minimal-versions: @@ -140,7 +140,7 @@ jobs: - run: > cargo test --no-default-features - --features any,macros,${{ matrix.linking }},_unstable-all-types,runtime-${{ matrix.runtime }} + --features any,macros,${{ matrix.linking }},${{ matrix.linking == 'sqlite' && 'sqlite-preupdate-hook,' || ''}}_unstable-all-types,runtime-${{ matrix.runtime }} -- --test-threads=1 env: diff --git a/Cargo.lock b/Cargo.lock index 2da47afa54..36cb94422e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1177,7 +1177,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] [[package]] @@ -1914,7 +1914,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] [[package]] @@ -3986,7 +3986,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] [[package]] @@ -4815,7 +4815,7 @@ checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", "synstructure", ] @@ -4856,7 +4856,7 @@ checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", "synstructure", ] @@ -4899,5 +4899,5 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] diff --git a/Cargo.toml b/Cargo.toml index 49aefd7a93..fdd7983f0d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,7 +50,7 @@ authors.workspace = true repository.workspace = true [package.metadata.docs.rs] -features = ["all-databases", "_unstable-all-types"] +features = ["all-databases", "_unstable-all-types", "sqlite-preupdate-hook"] rustdoc-args = ["--cfg", "docsrs"] [features] @@ -108,6 +108,7 @@ postgres = ["sqlx-postgres", "sqlx-macros?/postgres"] mysql = ["sqlx-mysql", "sqlx-macros?/mysql"] sqlite = ["_sqlite", "sqlx-sqlite/bundled", "sqlx-macros?/sqlite"] sqlite-unbundled = ["_sqlite", "sqlx-sqlite/unbundled", "sqlx-macros?/sqlite-unbundled"] +sqlite-preupdate-hook = ["sqlx-sqlite/preupdate-hook"] # types json = ["sqlx-macros?/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlite?/json"] diff --git a/README.md b/README.md index 4d4a2338ec..15d68bbb42 100644 --- a/README.md +++ b/README.md @@ -196,6 +196,10 @@ be removed in the future. * May result in link errors if the SQLite version is too old. Version `3.20.0` or newer is recommended. * Can increase build time due to the use of bindgen. +- `sqlite-preupdate-hook`: enables SQLite's [preupdate hook](https://sqlite.org/c3ref/preupdate_count.html) API. + * Exposed as a separate feature because it's generally not enabled by default. + * Using this feature with `sqlite-unbundled` may cause linker failures if the system SQLite version does not support it. + - `any`: Add support for the `Any` database driver, which can proxy to a database driver at runtime. - `derive`: Add support for the derive family macros, those are `FromRow`, `Type`, `Encode`, `Decode`. diff --git a/sqlx-sqlite/Cargo.toml b/sqlx-sqlite/Cargo.toml index 391bf4523c..f9375d68f6 100644 --- a/sqlx-sqlite/Cargo.toml +++ b/sqlx-sqlite/Cargo.toml @@ -23,6 +23,8 @@ uuid = ["dep:uuid", "sqlx-core/uuid"] regexp = ["dep:regex"] +preupdate-hook = ["libsqlite3-sys/preupdate_hook"] + bundled = ["libsqlite3-sys/bundled"] unbundled = ["libsqlite3-sys/buildtime_bindgen"] diff --git a/sqlx-sqlite/src/connection/establish.rs b/sqlx-sqlite/src/connection/establish.rs index 40f9b4c302..5b8aa01b62 100644 --- a/sqlx-sqlite/src/connection/establish.rs +++ b/sqlx-sqlite/src/connection/establish.rs @@ -296,6 +296,8 @@ impl EstablishParams { log_settings: self.log_settings.clone(), progress_handler_callback: None, update_hook_callback: None, + #[cfg(feature = "preupdate-hook")] + preupdate_hook_callback: None, commit_hook_callback: None, rollback_hook_callback: None, }) diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index a579b8a605..8c1554aa26 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -14,6 +14,8 @@ use libsqlite3_sys::{ sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, }; +#[cfg(feature = "preupdate-hook")] +pub use preupdate_hook::*; pub(crate) use handle::ConnectionHandle; use sqlx_core::common::StatementCache; @@ -88,6 +90,7 @@ pub struct UpdateHookResult<'a> { pub table: &'a str, pub rowid: i64, } + pub(crate) struct UpdateHookHandler(NonNull); unsafe impl Send for UpdateHookHandler {} @@ -112,6 +115,8 @@ pub(crate) struct ConnectionState { progress_handler_callback: Option, update_hook_callback: Option, + #[cfg(feature = "preupdate-hook")] + preupdate_hook_callback: Option, commit_hook_callback: Option, @@ -544,3 +549,219 @@ impl Statements { self.temp = None; } } + +#[cfg(feature = "preupdate-hook")] +mod preupdate_hook { + use super::ConnectionState; + use super::LockedSqliteHandle; + use super::SqliteOperation; + use crate::type_info::DataType; + use crate::{SqliteError, SqliteTypeInfo, SqliteValue}; + use libsqlite3_sys::{ + sqlite3, sqlite3_preupdate_count, sqlite3_preupdate_depth, sqlite3_preupdate_hook, + sqlite3_preupdate_new, sqlite3_preupdate_old, sqlite3_value, sqlite3_value_type, SQLITE_OK, + }; + use sqlx_core::error::Error; + use std::ffi::CStr; + use std::fmt::Debug; + use std::os::raw::{c_char, c_int, c_void}; + use std::panic::catch_unwind; + use std::ptr; + use std::ptr::NonNull; + + pub struct PreupdateHookResult<'a> { + pub operation: SqliteOperation, + pub database: &'a str, + pub table: &'a str, + pub case: PreupdateCase, + } + + pub(crate) struct PreupdateHookHandler( + NonNull, + ); + unsafe impl Send for PreupdateHookHandler {} + + /// The possible cases for when a PreUpdate Hook gets triggered. Allows access to the relevant + /// functions for each case through the contained values. + pub enum PreupdateCase { + /// Pre-update hook was triggered by an insert. + Insert(PreupdateNewValueAccessor), + /// Pre-update hook was triggered by a delete. + Delete(PreupdateOldValueAccessor), + /// Pre-update hook was triggered by an update. + Update { + old_value_accessor: PreupdateOldValueAccessor, + new_value_accessor: PreupdateNewValueAccessor, + }, + /// This variant is not normally produced by SQLite. You may encounter it + /// if you're using a different version than what's supported by this library. + Unknown, + } + + /// An accessor for the old values of the row being deleted/updated during the preupdate callback. + #[derive(Debug)] + pub struct PreupdateOldValueAccessor { + db: *mut sqlite3, + old_row_id: i64, + } + + impl PreupdateOldValueAccessor { + /// Gets the amount of columns in the row being deleted/updated. + pub fn get_column_count(&self) -> i32 { + unsafe { sqlite3_preupdate_count(self.db) } + } + + /// Gets the depth of the query that triggered the preupdate hook. + /// Returns 0 if the preupdate callback was invoked as a result of + /// a direct insert, update, or delete operation; + /// 1 for inserts, updates, or deletes invoked by top-level triggers; + /// 2 for changes resulting from triggers called by top-level triggers; and so forth. + pub fn get_query_depth(&self) -> i32 { + unsafe { sqlite3_preupdate_depth(self.db) } + } + + /// Gets the row id of the row being updated/deleted. + pub fn get_old_row_id(&self) -> i64 { + self.old_row_id + } + + /// Gets the value of the row being updated/deleted at the specified index. + pub fn get_old_column_value(&self, i: i32) -> Result { + let mut p_value: *mut sqlite3_value = ptr::null_mut(); + unsafe { + let ret = sqlite3_preupdate_old(self.db, i, &mut p_value); + if ret != SQLITE_OK { + return Err(Error::Database(Box::new(SqliteError::new(self.db)))); + } + let data_type = DataType::from_code(sqlite3_value_type(p_value)); + Ok(SqliteValue::new(p_value, SqliteTypeInfo(data_type))) + } + } + } + + /// An accessor for the new values of the row being inserted/updated during the preupdate callback. + #[derive(Debug)] + pub struct PreupdateNewValueAccessor { + db: *mut sqlite3, + new_row_id: i64, + } + + impl PreupdateNewValueAccessor { + /// Gets the amount of columns in the row being inserted/updated. + pub fn get_column_count(&self) -> i32 { + unsafe { sqlite3_preupdate_count(self.db) } + } + + /// Gets the depth of the query that triggered the preupdate hook. + /// Returns 0 if the preupdate callback was invoked as a result of + /// a direct insert, update, or delete operation; + /// 1 for inserts, updates, or deletes invoked by top-level triggers; + /// 2 for changes resulting from triggers called by top-level triggers; and so forth. + pub fn get_query_depth(&self) -> i32 { + unsafe { sqlite3_preupdate_depth(self.db) } + } + + /// Gets the row id of the row being inserted/updated. + pub fn get_new_row_id(&self) -> i64 { + self.new_row_id + } + + /// Gets the value of the row being updated/deleted at the specified index. + pub fn get_new_column_value(&self, i: i32) -> Result { + let mut p_value: *mut sqlite3_value = ptr::null_mut(); + unsafe { + let ret = sqlite3_preupdate_new(self.db, i, &mut p_value); + if ret != SQLITE_OK { + return Err(Error::Database(Box::new(SqliteError::new(self.db)))); + } + let data_type = DataType::from_code(sqlite3_value_type(p_value)); + Ok(SqliteValue::new(p_value, SqliteTypeInfo(data_type))) + } + } + } + + impl ConnectionState { + pub(crate) fn remove_preupdate_hook(&mut self) { + if let Some(mut handler) = self.preupdate_hook_callback.take() { + unsafe { + sqlite3_preupdate_hook(self.handle.as_ptr(), None, ptr::null_mut()); + let _ = { Box::from_raw(handler.0.as_mut()) }; + } + } + } + } + + impl LockedSqliteHandle<'_> { + /// Registers a hook that is invoked prior to each `INSERT`, `UPDATE`, and `DELETE` operation on a database table. + /// At most one preupdate hook may be registered at a time on a single database connection. + /// + /// The preupdate hook only fires for changes to real database tables; + /// it is not invoked for changes to virtual tables or to system tables like sqlite_sequence or sqlite_stat1. + /// + /// See https://sqlite.org/c3ref/preupdate_count.html + pub fn set_preupdate_hook(&mut self, callback: F) + where + F: FnMut(PreupdateHookResult) + Send + 'static, + { + unsafe { + let callback_boxed = Box::new(callback); + // SAFETY: `Box::into_raw()` always returns a non-null pointer. + let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed)); + let handler = callback.as_ptr() as *mut _; + self.guard.remove_preupdate_hook(); + self.guard.preupdate_hook_callback = Some(PreupdateHookHandler(callback)); + + sqlite3_preupdate_hook( + self.as_raw_handle().as_mut(), + Some(preupdate_hook::), + handler, + ); + } + } + + pub fn remove_preupdate_hook(&mut self) { + self.guard.remove_preupdate_hook(); + } + } + + extern "C" fn preupdate_hook( + callback: *mut c_void, + db: *mut sqlite3, + op_code: c_int, + database: *const c_char, + table: *const c_char, + old_row_id: i64, + new_row_id: i64, + ) where + F: FnMut(PreupdateHookResult), + { + unsafe { + let _ = catch_unwind(|| { + let callback: *mut F = callback.cast::(); + let operation: SqliteOperation = op_code.into(); + let database = CStr::from_ptr(database).to_str().unwrap_or_default(); + let table = CStr::from_ptr(table).to_str().unwrap_or_default(); + + let preupdate_case = match operation { + SqliteOperation::Insert => { + PreupdateCase::Insert(PreupdateNewValueAccessor { db, new_row_id }) + } + SqliteOperation::Delete => { + PreupdateCase::Delete(PreupdateOldValueAccessor { db, old_row_id }) + } + SqliteOperation::Update => PreupdateCase::Update { + old_value_accessor: PreupdateOldValueAccessor { db, old_row_id }, + new_value_accessor: PreupdateNewValueAccessor { db, new_row_id }, + }, + SqliteOperation::Unknown(_) => PreupdateCase::Unknown, + }; + (*callback)(PreupdateHookResult { + operation, + database, + table, + case: preupdate_case, + }) + }); + } + } +} diff --git a/sqlx-sqlite/src/lib.rs b/sqlx-sqlite/src/lib.rs index f8f5534879..6792e65af7 100644 --- a/sqlx-sqlite/src/lib.rs +++ b/sqlx-sqlite/src/lib.rs @@ -47,6 +47,10 @@ use std::sync::atomic::AtomicBool; pub use arguments::{SqliteArgumentValue, SqliteArguments}; pub use column::SqliteColumn; pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult}; +#[cfg(feature = "preupdate-hook")] +pub use connection::{ + PreupdateCase, PreupdateHookResult, PreupdateNewValueAccessor, PreupdateOldValueAccessor, +}; pub use database::Sqlite; pub use error::SqliteError; pub use options::{ diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index b733ccbb4c..ffcd562140 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -2,11 +2,14 @@ use futures::TryStreamExt; use rand::{Rng, SeedableRng}; use rand_xoshiro::Xoshiro256PlusPlus; use sqlx::sqlite::{SqliteConnectOptions, SqliteOperation, SqlitePoolOptions}; +use sqlx::Decode; +use sqlx::Value; use sqlx::{ query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row, SqliteConnection, SqlitePool, Statement, TypeInfo, }; use sqlx_test::new; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; #[sqlx_macros::test] @@ -798,7 +801,7 @@ async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow:: #[sqlx_macros::test] async fn test_query_with_update_hook() -> anyhow::Result<()> { let mut conn = new::().await?; - + static CALLED: AtomicBool = AtomicBool::new(false); // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. let state = format!("test"); conn.lock_handle().await?.set_update_hook(move |result| { @@ -807,11 +810,13 @@ async fn test_query_with_update_hook() -> anyhow::Result<()> { assert_eq!(result.database, "main"); assert_eq!(result.table, "tweet"); assert_eq!(result.rowid, 2); + CALLED.store(true, Ordering::Relaxed); }); let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 3, 'Hello, World' )") .execute(&mut conn) .await?; + assert!(CALLED.load(Ordering::Relaxed)); Ok(()) } @@ -852,10 +857,11 @@ async fn test_multiple_set_update_hook_calls_drop_old_handler() -> anyhow::Resul #[sqlx_macros::test] async fn test_query_with_commit_hook() -> anyhow::Result<()> { let mut conn = new::().await?; - + static CALLED: AtomicBool = AtomicBool::new(false); // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. let state = format!("test"); conn.lock_handle().await?.set_commit_hook(move || { + CALLED.store(true, Ordering::Relaxed); assert_eq!(state, "test"); false }); @@ -870,7 +876,7 @@ async fn test_query_with_commit_hook() -> anyhow::Result<()> { } _ => panic!("expected an error"), } - + assert!(CALLED.load(Ordering::Relaxed)); Ok(()) } @@ -916,8 +922,10 @@ async fn test_query_with_rollback_hook() -> anyhow::Result<()> { // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. let state = format!("test"); + static CALLED: AtomicBool = AtomicBool::new(false); conn.lock_handle().await?.set_rollback_hook(move || { assert_eq!(state, "test"); + CALLED.store(true, Ordering::Relaxed); }); let mut tx = conn.begin().await?; @@ -925,6 +933,7 @@ async fn test_query_with_rollback_hook() -> anyhow::Result<()> { .execute(&mut *tx) .await?; tx.rollback().await?; + assert!(CALLED.load(Ordering::Relaxed)); Ok(()) } @@ -960,3 +969,218 @@ async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Res assert_eq!(1, Arc::strong_count(&ref_counted_object)); Ok(()) } + +#[cfg(feature = "sqlite-preupdate-hook")] +#[sqlx_macros::test] +async fn test_query_with_preupdate_hook_insert() -> anyhow::Result<()> { + let mut conn = new::().await?; + static CALLED: AtomicBool = AtomicBool::new(false); + // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. + let state = format!("test"); + conn.lock_handle().await?.set_preupdate_hook(move |result| { + assert_eq!(state, "test"); + assert_eq!(result.operation, SqliteOperation::Insert); + assert_eq!(result.database, "main"); + assert_eq!(result.table, "tweet"); + + if let sqlx_sqlite::PreupdateCase::Insert(accessor) = result.case { + assert_eq!(4, accessor.get_column_count()); + assert_eq!(2, accessor.get_new_row_id()); + assert_eq!(0, accessor.get_query_depth()); + assert_eq!( + 4, + >::decode( + accessor.get_new_column_value(0).unwrap().as_ref(), + ) + .unwrap() + ); + assert_eq!( + "Hello, World", + >::decode( + accessor.get_new_column_value(1).unwrap().as_ref(), + ) + .unwrap() + ); + // out of bounds access should return an error + assert!(accessor.get_new_column_value(4).is_err()); + } else { + panic!("wrong preupdate case"); + } + CALLED.store(true, Ordering::Relaxed); + }); + + let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 4, 'Hello, World' )") + .execute(&mut conn) + .await?; + + assert!(CALLED.load(Ordering::Relaxed)); + conn.lock_handle().await?.remove_preupdate_hook(); + let _ = sqlx::query("DELETE FROM tweet where id = 4") + .execute(&mut conn) + .await?; + Ok(()) +} + +#[cfg(feature = "sqlite-preupdate-hook")] +#[sqlx_macros::test] +async fn test_query_with_preupdate_hook_delete() -> anyhow::Result<()> { + let mut conn = new::().await?; + let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 5, 'Hello, World' )") + .execute(&mut conn) + .await?; + static CALLED: AtomicBool = AtomicBool::new(false); + // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. + let state = format!("test"); + conn.lock_handle().await?.set_preupdate_hook(move |result| { + assert_eq!(state, "test"); + assert_eq!(result.operation, SqliteOperation::Delete); + assert_eq!(result.database, "main"); + assert_eq!(result.table, "tweet"); + + if let sqlx_sqlite::PreupdateCase::Delete(accessor) = result.case { + assert_eq!(4, accessor.get_column_count()); + assert_eq!(2, accessor.get_old_row_id()); + assert_eq!(0, accessor.get_query_depth()); + assert_eq!( + 5, + >::decode( + accessor.get_old_column_value(0).unwrap().as_ref(), + ) + .unwrap() + ); + assert_eq!( + "Hello, World", + >::decode( + accessor.get_old_column_value(1).unwrap().as_ref(), + ) + .unwrap() + ); + // out of bounds access should return an error + assert!(accessor.get_old_column_value(4).is_err()); + } else { + panic!("wrong preupdate case"); + } + CALLED.store(true, Ordering::Relaxed); + }); + + let _ = sqlx::query("DELETE FROM tweet WHERE id = 5") + .execute(&mut conn) + .await?; + assert!(CALLED.load(Ordering::Relaxed)); + Ok(()) +} + +#[cfg(feature = "sqlite-preupdate-hook")] +#[sqlx_macros::test] +async fn test_query_with_preupdate_hook_update() -> anyhow::Result<()> { + let mut conn = new::().await?; + let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 6, 'Hello, World' )") + .execute(&mut conn) + .await?; + static CALLED: AtomicBool = AtomicBool::new(false); + // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. + let state = format!("test"); + conn.lock_handle().await?.set_preupdate_hook(move |result| { + assert_eq!(state, "test"); + assert_eq!(result.operation, SqliteOperation::Update); + assert_eq!(result.database, "main"); + assert_eq!(result.table, "tweet"); + + if let sqlx_sqlite::PreupdateCase::Update { + old_value_accessor, + new_value_accessor, + } = result.case + { + assert_eq!(4, old_value_accessor.get_column_count()); + assert_eq!(4, new_value_accessor.get_column_count()); + + assert_eq!(2, old_value_accessor.get_old_row_id()); + assert_eq!(2, new_value_accessor.get_new_row_id()); + + assert_eq!(0, old_value_accessor.get_query_depth()); + assert_eq!(0, new_value_accessor.get_query_depth()); + + assert_eq!( + 6, + >::decode( + old_value_accessor.get_old_column_value(0).unwrap().as_ref(), + ) + .unwrap() + ); + assert_eq!( + 6, + >::decode( + new_value_accessor.get_new_column_value(0).unwrap().as_ref(), + ) + .unwrap() + ); + + assert_eq!( + "Hello, World", + >::decode( + old_value_accessor.get_old_column_value(1).unwrap().as_ref(), + ) + .unwrap() + ); + assert_eq!( + "Hello, World2", + >::decode( + new_value_accessor.get_new_column_value(1).unwrap().as_ref(), + ) + .unwrap() + ); + + // out of bounds access should return an error + assert!(old_value_accessor.get_old_column_value(4).is_err()); + assert!(new_value_accessor.get_new_column_value(4).is_err()); + } else { + panic!("wrong preupdate case"); + } + CALLED.store(true, Ordering::Relaxed); + }); + + let _ = sqlx::query("UPDATE tweet SET text = 'Hello, World2' WHERE id = 6") + .execute(&mut conn) + .await?; + + assert!(CALLED.load(Ordering::Relaxed)); + conn.lock_handle().await?.remove_preupdate_hook(); + let _ = sqlx::query("DELETE FROM tweet where id = 6") + .execute(&mut conn) + .await?; + Ok(()) +} + +#[cfg(feature = "sqlite-preupdate-hook")] +#[sqlx_macros::test] +async fn test_multiple_set_preupdate_hook_calls_drop_old_handler() -> anyhow::Result<()> { + let ref_counted_object = Arc::new(0); + assert_eq!(1, Arc::strong_count(&ref_counted_object)); + + { + let mut conn = new::().await?; + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_preupdate_hook(move |_| { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_preupdate_hook(move |_| { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_preupdate_hook(move |_| { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + conn.lock_handle().await?.remove_preupdate_hook(); + } + + assert_eq!(1, Arc::strong_count(&ref_counted_object)); + Ok(()) +} From 46df6f8e4b5c974c3a2cd130389ec1c88fec399e Mon Sep 17 00:00:00 2001 From: Austin Schey Date: Thu, 12 Dec 2024 22:16:29 -0800 Subject: [PATCH 2/4] address some PR comments --- Cargo.lock | 1 + sqlx-sqlite/Cargo.toml | 1 + sqlx-sqlite/src/connection/mod.rs | 261 ++++--------------- sqlx-sqlite/src/connection/preupdate_hook.rs | 156 +++++++++++ sqlx-sqlite/src/lib.rs | 6 +- src/lib.rs | 8 + tests/sqlite/sqlite.rs | 145 +++++------ 7 files changed, 276 insertions(+), 302 deletions(-) create mode 100644 sqlx-sqlite/src/connection/preupdate_hook.rs diff --git a/Cargo.lock b/Cargo.lock index 36cb94422e..a06106ac4b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3843,6 +3843,7 @@ dependencies = [ "serde_urlencoded", "sqlx", "sqlx-core", + "thiserror 2.0.0", "time", "tracing", "url", diff --git a/sqlx-sqlite/Cargo.toml b/sqlx-sqlite/Cargo.toml index f9375d68f6..5ad57546e7 100644 --- a/sqlx-sqlite/Cargo.toml +++ b/sqlx-sqlite/Cargo.toml @@ -50,6 +50,7 @@ atoi = "2.0" log = "0.4.18" tracing = { version = "0.1.37", features = ["log"] } +thiserror = "2.0.0" serde = { version = "1.0.145", features = ["derive"], optional = true } regex = { version = "1.5.5", optional = true } diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 8c1554aa26..c1f9d46da8 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -38,6 +38,8 @@ mod executor; mod explain; mod handle; pub(crate) mod intmap; +#[cfg(feature = "preupdate-hook")] +mod preupdate_hook; mod worker; @@ -143,6 +145,16 @@ impl ConnectionState { } } + #[cfg(feature = "preupdate-hook")] + pub(crate) fn remove_preupdate_hook(&mut self) { + if let Some(mut handler) = self.preupdate_hook_callback.take() { + unsafe { + libsqlite3_sys::sqlite3_preupdate_hook(self.handle.as_ptr(), None, ptr::null_mut()); + let _ = { Box::from_raw(handler.0.as_mut()) }; + } + } + } + pub(crate) fn remove_commit_hook(&mut self) { if let Some(mut handler) = self.commit_hook_callback.take() { unsafe { @@ -426,6 +438,34 @@ impl LockedSqliteHandle<'_> { } } + /// Registers a hook that is invoked prior to each `INSERT`, `UPDATE`, and `DELETE` operation on a database table. + /// At most one preupdate hook may be registered at a time on a single database connection. + /// + /// The preupdate hook only fires for changes to real database tables; + /// it is not invoked for changes to virtual tables or to system tables like sqlite_sequence or sqlite_stat1. + /// + /// See https://sqlite.org/c3ref/preupdate_count.html + #[cfg(feature = "preupdate-hook")] + pub fn set_preupdate_hook(&mut self, callback: F) + where + F: FnMut(PreupdateHookResult) + Send + 'static, + { + unsafe { + let callback_boxed = Box::new(callback); + // SAFETY: `Box::into_raw()` always returns a non-null pointer. + let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed)); + let handler = callback.as_ptr() as *mut _; + self.guard.remove_preupdate_hook(); + self.guard.preupdate_hook_callback = Some(PreupdateHookHandler(callback)); + + libsqlite3_sys::sqlite3_preupdate_hook( + self.as_raw_handle().as_mut(), + Some(preupdate_hook::), + handler, + ); + } + } + /// Sets a commit hook that is invoked whenever a transaction is committed. If the commit hook callback /// returns `false`, then the operation is turned into a ROLLBACK. /// @@ -490,6 +530,11 @@ impl LockedSqliteHandle<'_> { self.guard.remove_update_hook(); } + #[cfg(feature = "preupdate-hook")] + pub fn remove_preupdate_hook(&mut self) { + self.guard.remove_preupdate_hook(); + } + pub fn remove_commit_hook(&mut self) { self.guard.remove_commit_hook(); } @@ -549,219 +594,3 @@ impl Statements { self.temp = None; } } - -#[cfg(feature = "preupdate-hook")] -mod preupdate_hook { - use super::ConnectionState; - use super::LockedSqliteHandle; - use super::SqliteOperation; - use crate::type_info::DataType; - use crate::{SqliteError, SqliteTypeInfo, SqliteValue}; - use libsqlite3_sys::{ - sqlite3, sqlite3_preupdate_count, sqlite3_preupdate_depth, sqlite3_preupdate_hook, - sqlite3_preupdate_new, sqlite3_preupdate_old, sqlite3_value, sqlite3_value_type, SQLITE_OK, - }; - use sqlx_core::error::Error; - use std::ffi::CStr; - use std::fmt::Debug; - use std::os::raw::{c_char, c_int, c_void}; - use std::panic::catch_unwind; - use std::ptr; - use std::ptr::NonNull; - - pub struct PreupdateHookResult<'a> { - pub operation: SqliteOperation, - pub database: &'a str, - pub table: &'a str, - pub case: PreupdateCase, - } - - pub(crate) struct PreupdateHookHandler( - NonNull, - ); - unsafe impl Send for PreupdateHookHandler {} - - /// The possible cases for when a PreUpdate Hook gets triggered. Allows access to the relevant - /// functions for each case through the contained values. - pub enum PreupdateCase { - /// Pre-update hook was triggered by an insert. - Insert(PreupdateNewValueAccessor), - /// Pre-update hook was triggered by a delete. - Delete(PreupdateOldValueAccessor), - /// Pre-update hook was triggered by an update. - Update { - old_value_accessor: PreupdateOldValueAccessor, - new_value_accessor: PreupdateNewValueAccessor, - }, - /// This variant is not normally produced by SQLite. You may encounter it - /// if you're using a different version than what's supported by this library. - Unknown, - } - - /// An accessor for the old values of the row being deleted/updated during the preupdate callback. - #[derive(Debug)] - pub struct PreupdateOldValueAccessor { - db: *mut sqlite3, - old_row_id: i64, - } - - impl PreupdateOldValueAccessor { - /// Gets the amount of columns in the row being deleted/updated. - pub fn get_column_count(&self) -> i32 { - unsafe { sqlite3_preupdate_count(self.db) } - } - - /// Gets the depth of the query that triggered the preupdate hook. - /// Returns 0 if the preupdate callback was invoked as a result of - /// a direct insert, update, or delete operation; - /// 1 for inserts, updates, or deletes invoked by top-level triggers; - /// 2 for changes resulting from triggers called by top-level triggers; and so forth. - pub fn get_query_depth(&self) -> i32 { - unsafe { sqlite3_preupdate_depth(self.db) } - } - - /// Gets the row id of the row being updated/deleted. - pub fn get_old_row_id(&self) -> i64 { - self.old_row_id - } - - /// Gets the value of the row being updated/deleted at the specified index. - pub fn get_old_column_value(&self, i: i32) -> Result { - let mut p_value: *mut sqlite3_value = ptr::null_mut(); - unsafe { - let ret = sqlite3_preupdate_old(self.db, i, &mut p_value); - if ret != SQLITE_OK { - return Err(Error::Database(Box::new(SqliteError::new(self.db)))); - } - let data_type = DataType::from_code(sqlite3_value_type(p_value)); - Ok(SqliteValue::new(p_value, SqliteTypeInfo(data_type))) - } - } - } - - /// An accessor for the new values of the row being inserted/updated during the preupdate callback. - #[derive(Debug)] - pub struct PreupdateNewValueAccessor { - db: *mut sqlite3, - new_row_id: i64, - } - - impl PreupdateNewValueAccessor { - /// Gets the amount of columns in the row being inserted/updated. - pub fn get_column_count(&self) -> i32 { - unsafe { sqlite3_preupdate_count(self.db) } - } - - /// Gets the depth of the query that triggered the preupdate hook. - /// Returns 0 if the preupdate callback was invoked as a result of - /// a direct insert, update, or delete operation; - /// 1 for inserts, updates, or deletes invoked by top-level triggers; - /// 2 for changes resulting from triggers called by top-level triggers; and so forth. - pub fn get_query_depth(&self) -> i32 { - unsafe { sqlite3_preupdate_depth(self.db) } - } - - /// Gets the row id of the row being inserted/updated. - pub fn get_new_row_id(&self) -> i64 { - self.new_row_id - } - - /// Gets the value of the row being updated/deleted at the specified index. - pub fn get_new_column_value(&self, i: i32) -> Result { - let mut p_value: *mut sqlite3_value = ptr::null_mut(); - unsafe { - let ret = sqlite3_preupdate_new(self.db, i, &mut p_value); - if ret != SQLITE_OK { - return Err(Error::Database(Box::new(SqliteError::new(self.db)))); - } - let data_type = DataType::from_code(sqlite3_value_type(p_value)); - Ok(SqliteValue::new(p_value, SqliteTypeInfo(data_type))) - } - } - } - - impl ConnectionState { - pub(crate) fn remove_preupdate_hook(&mut self) { - if let Some(mut handler) = self.preupdate_hook_callback.take() { - unsafe { - sqlite3_preupdate_hook(self.handle.as_ptr(), None, ptr::null_mut()); - let _ = { Box::from_raw(handler.0.as_mut()) }; - } - } - } - } - - impl LockedSqliteHandle<'_> { - /// Registers a hook that is invoked prior to each `INSERT`, `UPDATE`, and `DELETE` operation on a database table. - /// At most one preupdate hook may be registered at a time on a single database connection. - /// - /// The preupdate hook only fires for changes to real database tables; - /// it is not invoked for changes to virtual tables or to system tables like sqlite_sequence or sqlite_stat1. - /// - /// See https://sqlite.org/c3ref/preupdate_count.html - pub fn set_preupdate_hook(&mut self, callback: F) - where - F: FnMut(PreupdateHookResult) + Send + 'static, - { - unsafe { - let callback_boxed = Box::new(callback); - // SAFETY: `Box::into_raw()` always returns a non-null pointer. - let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed)); - let handler = callback.as_ptr() as *mut _; - self.guard.remove_preupdate_hook(); - self.guard.preupdate_hook_callback = Some(PreupdateHookHandler(callback)); - - sqlite3_preupdate_hook( - self.as_raw_handle().as_mut(), - Some(preupdate_hook::), - handler, - ); - } - } - - pub fn remove_preupdate_hook(&mut self) { - self.guard.remove_preupdate_hook(); - } - } - - extern "C" fn preupdate_hook( - callback: *mut c_void, - db: *mut sqlite3, - op_code: c_int, - database: *const c_char, - table: *const c_char, - old_row_id: i64, - new_row_id: i64, - ) where - F: FnMut(PreupdateHookResult), - { - unsafe { - let _ = catch_unwind(|| { - let callback: *mut F = callback.cast::(); - let operation: SqliteOperation = op_code.into(); - let database = CStr::from_ptr(database).to_str().unwrap_or_default(); - let table = CStr::from_ptr(table).to_str().unwrap_or_default(); - - let preupdate_case = match operation { - SqliteOperation::Insert => { - PreupdateCase::Insert(PreupdateNewValueAccessor { db, new_row_id }) - } - SqliteOperation::Delete => { - PreupdateCase::Delete(PreupdateOldValueAccessor { db, old_row_id }) - } - SqliteOperation::Update => PreupdateCase::Update { - old_value_accessor: PreupdateOldValueAccessor { db, old_row_id }, - new_value_accessor: PreupdateNewValueAccessor { db, new_row_id }, - }, - SqliteOperation::Unknown(_) => PreupdateCase::Unknown, - }; - (*callback)(PreupdateHookResult { - operation, - database, - table, - case: preupdate_case, - }) - }); - } - } -} diff --git a/sqlx-sqlite/src/connection/preupdate_hook.rs b/sqlx-sqlite/src/connection/preupdate_hook.rs new file mode 100644 index 0000000000..8df40e16b9 --- /dev/null +++ b/sqlx-sqlite/src/connection/preupdate_hook.rs @@ -0,0 +1,156 @@ +use super::SqliteOperation; +use crate::type_info::DataType; +use crate::{SqliteError, SqliteTypeInfo, SqliteValue}; + +use libsqlite3_sys::{ + sqlite3, sqlite3_preupdate_count, sqlite3_preupdate_depth, sqlite3_preupdate_new, + sqlite3_preupdate_old, sqlite3_value, sqlite3_value_type, SQLITE_OK, +}; +use std::ffi::CStr; +use std::os::raw::{c_char, c_int, c_void}; +use std::panic::catch_unwind; +use std::ptr; +use std::ptr::NonNull; + +#[derive(Debug, thiserror::Error)] +pub enum PreupdateError { + /// Error returned from the database. + #[error("error returned from database: {0}")] + Database(#[source] SqliteError), + /// Index is not within the valid column range + #[error("{0} is not within the valid column range")] + ColumnIndexOutOfBounds(i32), + /// Column value accessor was invoked from an invalid operation + #[error("column value accessor was invoked from an invalid operation")] + InvalidOperation, +} + +pub(crate) struct PreupdateHookHandler( + pub(super) NonNull, +); +unsafe impl Send for PreupdateHookHandler {} + +#[derive(Debug)] +pub struct PreupdateHookResult<'a> { + pub operation: SqliteOperation, + pub database: &'a str, + pub table: &'a str, + // The database pointer should not be usable after the preupdate hook. + // The lifetime on this struct needs to ensure it cannot outlive the callback. + db: *mut sqlite3, + old_row_id: i64, + new_row_id: i64, +} + +impl<'a> PreupdateHookResult<'a> { + /// Gets the amount of columns in the row being inserted, deleted, or updated. + pub fn get_column_count(&self) -> i32 { + unsafe { sqlite3_preupdate_count(self.db) } + } + + /// Gets the depth of the query that triggered the preupdate hook. + /// Returns 0 if the preupdate callback was invoked as a result of + /// a direct insert, update, or delete operation; + /// 1 for inserts, updates, or deletes invoked by top-level triggers; + /// 2 for changes resulting from triggers called by top-level triggers; and so forth. + pub fn get_query_depth(&self) -> i32 { + unsafe { sqlite3_preupdate_depth(self.db) } + } + + /// Gets the row id of the row being updated/deleted. + /// Returns an error if called from an insert operation. + pub fn get_old_row_id(&self) -> Result { + if self.operation == SqliteOperation::Insert { + return Err(PreupdateError::InvalidOperation); + } + Ok(self.old_row_id) + } + + /// Gets the row id of the row being inserted/updated. + /// Returns an error if called from a delete operation. + pub fn get_new_row_id(&self) -> Result { + if self.operation == SqliteOperation::Delete { + return Err(PreupdateError::InvalidOperation); + } + Ok(self.new_row_id) + } + + /// Gets the value of the row being updated/deleted at the specified index. + /// Returns an error if called from an insert operation or the index is out of bounds. + pub fn get_old_column_value(&self, i: i32) -> Result { + if self.operation == SqliteOperation::Insert { + return Err(PreupdateError::InvalidOperation); + } + self.validate_column_index(i)?; + + let mut p_value: *mut sqlite3_value = ptr::null_mut(); + unsafe { + let ret = sqlite3_preupdate_old(self.db, i, &mut p_value); + self.get_value(ret, p_value) + } + } + + /// Gets the value of the row being inserted/updated at the specified index. + /// Returns an error if called from a delete operation or the index is out of bounds. + pub fn get_new_column_value(&self, i: i32) -> Result { + if self.operation == SqliteOperation::Delete { + return Err(PreupdateError::InvalidOperation); + } + self.validate_column_index(i)?; + + let mut p_value: *mut sqlite3_value = ptr::null_mut(); + unsafe { + let ret = sqlite3_preupdate_new(self.db, i, &mut p_value); + self.get_value(ret, p_value) + } + } + + fn validate_column_index(&self, i: i32) -> Result<(), PreupdateError> { + if i < 0 || i >= self.get_column_count() { + return Err(PreupdateError::ColumnIndexOutOfBounds(i)); + } + Ok(()) + } + + unsafe fn get_value( + &self, + ret: i32, + p_value: *mut sqlite3_value, + ) -> Result { + if ret != SQLITE_OK { + return Err(PreupdateError::Database(SqliteError::new(self.db))); + } + let data_type = DataType::from_code(sqlite3_value_type(p_value)); + Ok(SqliteValue::new(p_value, SqliteTypeInfo(data_type))) + } +} + +pub(super) extern "C" fn preupdate_hook( + callback: *mut c_void, + db: *mut sqlite3, + op_code: c_int, + database: *const c_char, + table: *const c_char, + old_row_id: i64, + new_row_id: i64, +) where + F: FnMut(PreupdateHookResult) + Send + 'static, +{ + unsafe { + let _ = catch_unwind(|| { + let callback: *mut F = callback.cast::(); + let operation: SqliteOperation = op_code.into(); + let database = CStr::from_ptr(database).to_str().unwrap_or_default(); + let table = CStr::from_ptr(table).to_str().unwrap_or_default(); + + (*callback)(PreupdateHookResult { + operation, + database, + table, + old_row_id, + new_row_id, + db, + }) + }); + } +} diff --git a/sqlx-sqlite/src/lib.rs b/sqlx-sqlite/src/lib.rs index 6792e65af7..474bdeee52 100644 --- a/sqlx-sqlite/src/lib.rs +++ b/sqlx-sqlite/src/lib.rs @@ -46,11 +46,9 @@ use std::sync::atomic::AtomicBool; pub use arguments::{SqliteArgumentValue, SqliteArguments}; pub use column::SqliteColumn; -pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult}; #[cfg(feature = "preupdate-hook")] -pub use connection::{ - PreupdateCase, PreupdateHookResult, PreupdateNewValueAccessor, PreupdateOldValueAccessor, -}; +pub use connection::PreupdateHookResult; +pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult}; pub use database::Sqlite; pub use error::SqliteError; pub use options::{ diff --git a/src/lib.rs b/src/lib.rs index a9e90c071d..b84ed74e6b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,14 @@ #![cfg_attr(docsrs, feature(doc_cfg))] #![doc = include_str!("lib.md")] +#[cfg(all( + feature = "sqlite-preupdate-hook", + not(any(feature = "sqlite", feature = "sqlite-unbundled")) +))] +compile_error!( + "sqlite-preupdate-hook requires either 'sqlite' or 'sqlite-unbundled' to be enabled" +); + pub use sqlx_core::acquire::Acquire; pub use sqlx_core::arguments::{Arguments, IntoArguments}; pub use sqlx_core::column::Column; diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index ffcd562140..043520740b 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -977,36 +977,36 @@ async fn test_query_with_preupdate_hook_insert() -> anyhow::Result<()> { static CALLED: AtomicBool = AtomicBool::new(false); // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. let state = format!("test"); - conn.lock_handle().await?.set_preupdate_hook(move |result| { - assert_eq!(state, "test"); - assert_eq!(result.operation, SqliteOperation::Insert); - assert_eq!(result.database, "main"); - assert_eq!(result.table, "tweet"); - - if let sqlx_sqlite::PreupdateCase::Insert(accessor) = result.case { - assert_eq!(4, accessor.get_column_count()); - assert_eq!(2, accessor.get_new_row_id()); - assert_eq!(0, accessor.get_query_depth()); + conn.lock_handle().await?.set_preupdate_hook({ + move |result| { + assert_eq!(state, "test"); + assert_eq!(result.operation, SqliteOperation::Insert); + assert_eq!(result.database, "main"); + assert_eq!(result.table, "tweet"); + + assert_eq!(4, result.get_column_count()); + assert_eq!(2, result.get_new_row_id().unwrap()); + assert_eq!(0, result.get_query_depth()); assert_eq!( - 4, - >::decode( - accessor.get_new_column_value(0).unwrap().as_ref(), - ) - .unwrap() + 4, + >::decode(result.get_new_column_value(0).unwrap().as_ref(),) + .unwrap() ); assert_eq!( "Hello, World", >::decode( - accessor.get_new_column_value(1).unwrap().as_ref(), + result.get_new_column_value(1).unwrap().as_ref(), ) .unwrap() ); // out of bounds access should return an error - assert!(accessor.get_new_column_value(4).is_err()); - } else { - panic!("wrong preupdate case"); + assert!(result.get_new_column_value(4).is_err()); + // old values aren't available for inserts + assert!(result.get_old_column_value(0).is_err()); + assert!(result.get_old_row_id().is_err()); + + CALLED.store(true, Ordering::Relaxed); } - CALLED.store(true, Ordering::Relaxed); }); let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 4, 'Hello, World' )") @@ -1037,29 +1037,25 @@ async fn test_query_with_preupdate_hook_delete() -> anyhow::Result<()> { assert_eq!(result.database, "main"); assert_eq!(result.table, "tweet"); - if let sqlx_sqlite::PreupdateCase::Delete(accessor) = result.case { - assert_eq!(4, accessor.get_column_count()); - assert_eq!(2, accessor.get_old_row_id()); - assert_eq!(0, accessor.get_query_depth()); - assert_eq!( - 5, - >::decode( - accessor.get_old_column_value(0).unwrap().as_ref(), - ) + assert_eq!(4, result.get_column_count()); + assert_eq!(2, result.get_old_row_id().unwrap()); + assert_eq!(0, result.get_query_depth()); + assert_eq!( + 5, + >::decode(result.get_old_column_value(0).unwrap().as_ref(),) .unwrap() - ); - assert_eq!( - "Hello, World", - >::decode( - accessor.get_old_column_value(1).unwrap().as_ref(), - ) + ); + assert_eq!( + "Hello, World", + >::decode(result.get_old_column_value(1).unwrap().as_ref(),) .unwrap() - ); - // out of bounds access should return an error - assert!(accessor.get_old_column_value(4).is_err()); - } else { - panic!("wrong preupdate case"); - } + ); + // out of bounds access should return an error + assert!(result.get_old_column_value(4).is_err()); + // new values aren't available for deletes + assert!(result.get_new_column_value(0).is_err()); + assert!(result.get_new_row_id().is_err()); + CALLED.store(true, Ordering::Relaxed); }); @@ -1086,56 +1082,41 @@ async fn test_query_with_preupdate_hook_update() -> anyhow::Result<()> { assert_eq!(result.database, "main"); assert_eq!(result.table, "tweet"); - if let sqlx_sqlite::PreupdateCase::Update { - old_value_accessor, - new_value_accessor, - } = result.case - { - assert_eq!(4, old_value_accessor.get_column_count()); - assert_eq!(4, new_value_accessor.get_column_count()); + assert_eq!(4, result.get_column_count()); + assert_eq!(4, result.get_column_count()); - assert_eq!(2, old_value_accessor.get_old_row_id()); - assert_eq!(2, new_value_accessor.get_new_row_id()); + assert_eq!(2, result.get_old_row_id().unwrap()); + assert_eq!(2, result.get_new_row_id().unwrap()); - assert_eq!(0, old_value_accessor.get_query_depth()); - assert_eq!(0, new_value_accessor.get_query_depth()); + assert_eq!(0, result.get_query_depth()); + assert_eq!(0, result.get_query_depth()); - assert_eq!( - 6, - >::decode( - old_value_accessor.get_old_column_value(0).unwrap().as_ref(), - ) + assert_eq!( + 6, + >::decode(result.get_old_column_value(0).unwrap().as_ref(),) .unwrap() - ); - assert_eq!( - 6, - >::decode( - new_value_accessor.get_new_column_value(0).unwrap().as_ref(), - ) + ); + assert_eq!( + 6, + >::decode(result.get_new_column_value(0).unwrap().as_ref(),) .unwrap() - ); + ); - assert_eq!( - "Hello, World", - >::decode( - old_value_accessor.get_old_column_value(1).unwrap().as_ref(), - ) + assert_eq!( + "Hello, World", + >::decode(result.get_old_column_value(1).unwrap().as_ref(),) .unwrap() - ); - assert_eq!( - "Hello, World2", - >::decode( - new_value_accessor.get_new_column_value(1).unwrap().as_ref(), - ) + ); + assert_eq!( + "Hello, World2", + >::decode(result.get_new_column_value(1).unwrap().as_ref(),) .unwrap() - ); + ); + + // out of bounds access should return an error + assert!(result.get_old_column_value(4).is_err()); + assert!(result.get_new_column_value(4).is_err()); - // out of bounds access should return an error - assert!(old_value_accessor.get_old_column_value(4).is_err()); - assert!(new_value_accessor.get_new_column_value(4).is_err()); - } else { - panic!("wrong preupdate case"); - } CALLED.store(true, Ordering::Relaxed); }); From 8da5e6b8a74e2b8ab0d36e66b7f67ee4476636e1 Mon Sep 17 00:00:00 2001 From: Austin Schey Date: Fri, 13 Dec 2024 20:30:15 -0800 Subject: [PATCH 3/4] add SqliteValueRef variant that takes a borrowed sqlite value pointer --- sqlx-sqlite/src/connection/preupdate_hook.rs | 11 +- sqlx-sqlite/src/value.rs | 139 +++++++++++++------ tests/sqlite/sqlite.rs | 99 +++++++------ 3 files changed, 154 insertions(+), 95 deletions(-) diff --git a/sqlx-sqlite/src/connection/preupdate_hook.rs b/sqlx-sqlite/src/connection/preupdate_hook.rs index 8df40e16b9..fcc0fe0bc3 100644 --- a/sqlx-sqlite/src/connection/preupdate_hook.rs +++ b/sqlx-sqlite/src/connection/preupdate_hook.rs @@ -1,6 +1,6 @@ use super::SqliteOperation; use crate::type_info::DataType; -use crate::{SqliteError, SqliteTypeInfo, SqliteValue}; +use crate::{SqliteError, SqliteTypeInfo, SqliteValueRef}; use libsqlite3_sys::{ sqlite3, sqlite3_preupdate_count, sqlite3_preupdate_depth, sqlite3_preupdate_new, @@ -77,7 +77,7 @@ impl<'a> PreupdateHookResult<'a> { /// Gets the value of the row being updated/deleted at the specified index. /// Returns an error if called from an insert operation or the index is out of bounds. - pub fn get_old_column_value(&self, i: i32) -> Result { + pub fn get_old_column_value(&self, i: i32) -> Result, PreupdateError> { if self.operation == SqliteOperation::Insert { return Err(PreupdateError::InvalidOperation); } @@ -92,7 +92,7 @@ impl<'a> PreupdateHookResult<'a> { /// Gets the value of the row being inserted/updated at the specified index. /// Returns an error if called from a delete operation or the index is out of bounds. - pub fn get_new_column_value(&self, i: i32) -> Result { + pub fn get_new_column_value(&self, i: i32) -> Result, PreupdateError> { if self.operation == SqliteOperation::Delete { return Err(PreupdateError::InvalidOperation); } @@ -116,12 +116,13 @@ impl<'a> PreupdateHookResult<'a> { &self, ret: i32, p_value: *mut sqlite3_value, - ) -> Result { + ) -> Result, PreupdateError> { if ret != SQLITE_OK { return Err(PreupdateError::Database(SqliteError::new(self.db))); } let data_type = DataType::from_code(sqlite3_value_type(p_value)); - Ok(SqliteValue::new(p_value, SqliteTypeInfo(data_type))) + // SAFETY: SQLite will free the sqlite3_value when the callback returns + Ok(SqliteValueRef::borrowed(p_value, SqliteTypeInfo(data_type))) } } diff --git a/sqlx-sqlite/src/value.rs b/sqlx-sqlite/src/value.rs index 967b3f7476..469c4e70d5 100644 --- a/sqlx-sqlite/src/value.rs +++ b/sqlx-sqlite/src/value.rs @@ -1,4 +1,5 @@ use std::borrow::Cow; +use std::marker::PhantomData; use std::ptr::NonNull; use std::slice::from_raw_parts; use std::str::from_utf8; @@ -17,6 +18,7 @@ use crate::{Sqlite, SqliteTypeInfo}; enum SqliteValueData<'r> { Value(&'r SqliteValue), + BorrowedHandle(ValueHandle<'r>), } pub struct SqliteValueRef<'r>(SqliteValueData<'r>); @@ -26,31 +28,44 @@ impl<'r> SqliteValueRef<'r> { Self(SqliteValueData::Value(value)) } + // SAFETY: The supplied sqlite3_value must not be null and SQLite must free it. It will not be freed on drop. + // The lifetime on this struct should tie it to whatever scope it's valid for before SQLite frees it. + #[allow(unused)] + pub(crate) unsafe fn borrowed(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self { + debug_assert!(!value.is_null()); + let handle = ValueHandle::new_borrowed(NonNull::new_unchecked(value), type_info); + Self(SqliteValueData::BorrowedHandle(handle)) + } + // NOTE: `int()` is deliberately omitted because it will silently truncate a wider value, // which is likely to cause bugs: // https://github.com/launchbadge/sqlx/issues/3179 // (Similar bug in Postgres): https://github.com/launchbadge/sqlx/issues/3161 pub(super) fn int64(&self) -> i64 { - match self.0 { - SqliteValueData::Value(v) => v.int64(), + match &self.0 { + SqliteValueData::Value(v) => v.0.int64(), + SqliteValueData::BorrowedHandle(v) => v.int64(), } } pub(super) fn double(&self) -> f64 { - match self.0 { - SqliteValueData::Value(v) => v.double(), + match &self.0 { + SqliteValueData::Value(v) => v.0.double(), + SqliteValueData::BorrowedHandle(v) => v.double(), } } pub(super) fn blob(&self) -> &'r [u8] { - match self.0 { - SqliteValueData::Value(v) => v.blob(), + match &self.0 { + SqliteValueData::Value(v) => v.0.blob(), + SqliteValueData::BorrowedHandle(v) => v.blob(), } } pub(super) fn text(&self) -> Result<&'r str, BoxDynError> { - match self.0 { - SqliteValueData::Value(v) => v.text(), + match &self.0 { + SqliteValueData::Value(v) => v.0.text(), + SqliteValueData::BorrowedHandle(v) => v.text(), } } } @@ -59,50 +74,66 @@ impl<'r> ValueRef<'r> for SqliteValueRef<'r> { type Database = Sqlite; fn to_owned(&self) -> SqliteValue { - match self.0 { - SqliteValueData::Value(v) => v.clone(), + match &self.0 { + SqliteValueData::Value(v) => (*v).clone(), + SqliteValueData::BorrowedHandle(v) => unsafe { + SqliteValue::new(v.value.as_ptr(), v.type_info.clone()) + }, } } fn type_info(&self) -> Cow<'_, SqliteTypeInfo> { - match self.0 { + match &self.0 { SqliteValueData::Value(v) => v.type_info(), + SqliteValueData::BorrowedHandle(v) => v.type_info(), } } fn is_null(&self) -> bool { - match self.0 { + match &self.0 { SqliteValueData::Value(v) => v.is_null(), + SqliteValueData::BorrowedHandle(v) => v.is_null(), } } } #[derive(Clone)] -pub struct SqliteValue { - pub(crate) handle: Arc, - pub(crate) type_info: SqliteTypeInfo, -} +pub struct SqliteValue(Arc>); -pub(crate) struct ValueHandle(NonNull); +pub(crate) struct ValueHandle<'a> { + value: NonNull, + type_info: SqliteTypeInfo, + free_on_drop: bool, + _sqlite_value_lifetime: PhantomData<&'a ()>, +} // SAFE: only protected value objects are stored in SqliteValue -unsafe impl Send for ValueHandle {} -unsafe impl Sync for ValueHandle {} +unsafe impl<'a> Send for ValueHandle<'a> {} +unsafe impl<'a> Sync for ValueHandle<'a> {} -impl SqliteValue { - pub(crate) unsafe fn new(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self { - debug_assert!(!value.is_null()); +impl ValueHandle<'static> { + fn new_owned(value: NonNull, type_info: SqliteTypeInfo) -> Self { + Self { + value, + type_info, + free_on_drop: true, + _sqlite_value_lifetime: PhantomData, + } + } +} +impl<'a> ValueHandle<'a> { + fn new_borrowed(value: NonNull, type_info: SqliteTypeInfo) -> Self { Self { + value, type_info, - handle: Arc::new(ValueHandle(NonNull::new_unchecked(sqlite3_value_dup( - value, - )))), + free_on_drop: false, + _sqlite_value_lifetime: PhantomData, } } fn type_info_opt(&self) -> Option { - let dt = DataType::from_code(unsafe { sqlite3_value_type(self.handle.0.as_ptr()) }); + let dt = DataType::from_code(unsafe { sqlite3_value_type(self.value.as_ptr()) }); if let DataType::Null = dt { None @@ -112,15 +143,15 @@ impl SqliteValue { } fn int64(&self) -> i64 { - unsafe { sqlite3_value_int64(self.handle.0.as_ptr()) } + unsafe { sqlite3_value_int64(self.value.as_ptr()) } } fn double(&self) -> f64 { - unsafe { sqlite3_value_double(self.handle.0.as_ptr()) } + unsafe { sqlite3_value_double(self.value.as_ptr()) } } - fn blob(&self) -> &[u8] { - let len = unsafe { sqlite3_value_bytes(self.handle.0.as_ptr()) }; + fn blob<'b>(&self) -> &'b [u8] { + let len = unsafe { sqlite3_value_bytes(self.value.as_ptr()) }; // This likely means UB in SQLite itself or our usage of it; // signed integer overflow is UB in the C standard. @@ -133,23 +164,15 @@ impl SqliteValue { return &[]; } - let ptr = unsafe { sqlite3_value_blob(self.handle.0.as_ptr()) } as *const u8; + let ptr = unsafe { sqlite3_value_blob(self.value.as_ptr()) } as *const u8; debug_assert!(!ptr.is_null()); unsafe { from_raw_parts(ptr, len) } } - fn text(&self) -> Result<&str, BoxDynError> { + fn text<'b>(&self) -> Result<&'b str, BoxDynError> { Ok(from_utf8(self.blob())?) } -} - -impl Value for SqliteValue { - type Database = Sqlite; - - fn as_ref(&self) -> SqliteValueRef<'_> { - SqliteValueRef::value(self) - } fn type_info(&self) -> Cow<'_, SqliteTypeInfo> { self.type_info_opt() @@ -158,18 +181,46 @@ impl Value for SqliteValue { } fn is_null(&self) -> bool { - unsafe { sqlite3_value_type(self.handle.0.as_ptr()) == SQLITE_NULL } + unsafe { sqlite3_value_type(self.value.as_ptr()) == SQLITE_NULL } } } -impl Drop for ValueHandle { +impl<'a> Drop for ValueHandle<'a> { fn drop(&mut self) { - unsafe { - sqlite3_value_free(self.0.as_ptr()); + if self.free_on_drop { + unsafe { + sqlite3_value_free(self.value.as_ptr()); + } } } } +impl SqliteValue { + // SAFETY: The sqlite3_value must be non-null and SQLite must not free it. It will be freed on drop. + pub(crate) unsafe fn new(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self { + debug_assert!(!value.is_null()); + let handle = + ValueHandle::new_owned(NonNull::new_unchecked(sqlite3_value_dup(value)), type_info); + Self(Arc::new(handle)) + } +} + +impl Value for SqliteValue { + type Database = Sqlite; + + fn as_ref(&self) -> SqliteValueRef<'_> { + SqliteValueRef::value(self) + } + + fn type_info(&self) -> Cow<'_, SqliteTypeInfo> { + self.0.type_info() + } + + fn is_null(&self) -> bool { + self.0.is_null() + } +} + // #[cfg(feature = "any")] // impl<'r> From> for crate::any::AnyValueRef<'r> { // #[inline] diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 043520740b..d78e1151a9 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -3,11 +3,11 @@ use rand::{Rng, SeedableRng}; use rand_xoshiro::Xoshiro256PlusPlus; use sqlx::sqlite::{SqliteConnectOptions, SqliteOperation, SqlitePoolOptions}; use sqlx::Decode; -use sqlx::Value; use sqlx::{ query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row, SqliteConnection, SqlitePool, Statement, TypeInfo, }; +use sqlx::{Value, ValueRef}; use sqlx_test::new; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -989,15 +989,12 @@ async fn test_query_with_preupdate_hook_insert() -> anyhow::Result<()> { assert_eq!(0, result.get_query_depth()); assert_eq!( 4, - >::decode(result.get_new_column_value(0).unwrap().as_ref(),) - .unwrap() + >::decode(result.get_new_column_value(0).unwrap()).unwrap() ); assert_eq!( "Hello, World", - >::decode( - result.get_new_column_value(1).unwrap().as_ref(), - ) - .unwrap() + >::decode(result.get_new_column_value(1).unwrap()) + .unwrap() ); // out of bounds access should return an error assert!(result.get_new_column_value(4).is_err()); @@ -1042,13 +1039,11 @@ async fn test_query_with_preupdate_hook_delete() -> anyhow::Result<()> { assert_eq!(0, result.get_query_depth()); assert_eq!( 5, - >::decode(result.get_old_column_value(0).unwrap().as_ref(),) - .unwrap() + >::decode(result.get_old_column_value(0).unwrap()).unwrap() ); assert_eq!( "Hello, World", - >::decode(result.get_old_column_value(1).unwrap().as_ref(),) - .unwrap() + >::decode(result.get_old_column_value(1).unwrap()).unwrap() ); // out of bounds access should return an error assert!(result.get_old_column_value(4).is_err()); @@ -1074,50 +1069,54 @@ async fn test_query_with_preupdate_hook_update() -> anyhow::Result<()> { .execute(&mut conn) .await?; static CALLED: AtomicBool = AtomicBool::new(false); + let sqlite_value_stored: Arc>> = Default::default(); // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. let state = format!("test"); - conn.lock_handle().await?.set_preupdate_hook(move |result| { - assert_eq!(state, "test"); - assert_eq!(result.operation, SqliteOperation::Update); - assert_eq!(result.database, "main"); - assert_eq!(result.table, "tweet"); + conn.lock_handle().await?.set_preupdate_hook({ + let sqlite_value_stored = sqlite_value_stored.clone(); + move |result| { + assert_eq!(state, "test"); + assert_eq!(result.operation, SqliteOperation::Update); + assert_eq!(result.database, "main"); + assert_eq!(result.table, "tweet"); - assert_eq!(4, result.get_column_count()); - assert_eq!(4, result.get_column_count()); + assert_eq!(4, result.get_column_count()); + assert_eq!(4, result.get_column_count()); - assert_eq!(2, result.get_old_row_id().unwrap()); - assert_eq!(2, result.get_new_row_id().unwrap()); + assert_eq!(2, result.get_old_row_id().unwrap()); + assert_eq!(2, result.get_new_row_id().unwrap()); - assert_eq!(0, result.get_query_depth()); - assert_eq!(0, result.get_query_depth()); + assert_eq!(0, result.get_query_depth()); + assert_eq!(0, result.get_query_depth()); - assert_eq!( - 6, - >::decode(result.get_old_column_value(0).unwrap().as_ref(),) - .unwrap() - ); - assert_eq!( - 6, - >::decode(result.get_new_column_value(0).unwrap().as_ref(),) - .unwrap() - ); + assert_eq!( + 6, + >::decode(result.get_old_column_value(0).unwrap()).unwrap() + ); + assert_eq!( + 6, + >::decode(result.get_new_column_value(0).unwrap()).unwrap() + ); - assert_eq!( - "Hello, World", - >::decode(result.get_old_column_value(1).unwrap().as_ref(),) - .unwrap() - ); - assert_eq!( - "Hello, World2", - >::decode(result.get_new_column_value(1).unwrap().as_ref(),) - .unwrap() - ); + assert_eq!( + "Hello, World", + >::decode(result.get_old_column_value(1).unwrap()) + .unwrap() + ); + assert_eq!( + "Hello, World2", + >::decode(result.get_new_column_value(1).unwrap()) + .unwrap() + ); + *sqlite_value_stored.lock().unwrap() = + Some(result.get_old_column_value(0).unwrap().to_owned()); - // out of bounds access should return an error - assert!(result.get_old_column_value(4).is_err()); - assert!(result.get_new_column_value(4).is_err()); + // out of bounds access should return an error + assert!(result.get_old_column_value(4).is_err()); + assert!(result.get_new_column_value(4).is_err()); - CALLED.store(true, Ordering::Relaxed); + CALLED.store(true, Ordering::Relaxed); + } }); let _ = sqlx::query("UPDATE tweet SET text = 'Hello, World2' WHERE id = 6") @@ -1129,6 +1128,14 @@ async fn test_query_with_preupdate_hook_update() -> anyhow::Result<()> { let _ = sqlx::query("DELETE FROM tweet where id = 6") .execute(&mut conn) .await?; + // Ensure that taking an owned SqliteValue maintains a valid reference after the callback returns + assert_eq!( + 6, + >::decode( + sqlite_value_stored.lock().unwrap().take().unwrap().as_ref() + ) + .unwrap() + ); Ok(()) } From 3d80e3f329d9d89df419a8a676c61844c69d9305 Mon Sep 17 00:00:00 2001 From: Austin Schey Date: Fri, 27 Dec 2024 19:35:06 -0800 Subject: [PATCH 4/4] add PhantomData for additional lifetime check --- sqlx-sqlite/src/connection/preupdate_hook.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sqlx-sqlite/src/connection/preupdate_hook.rs b/sqlx-sqlite/src/connection/preupdate_hook.rs index fcc0fe0bc3..edcb078124 100644 --- a/sqlx-sqlite/src/connection/preupdate_hook.rs +++ b/sqlx-sqlite/src/connection/preupdate_hook.rs @@ -7,6 +7,7 @@ use libsqlite3_sys::{ sqlite3_preupdate_old, sqlite3_value, sqlite3_value_type, SQLITE_OK, }; use std::ffi::CStr; +use std::marker::PhantomData; use std::os::raw::{c_char, c_int, c_void}; use std::panic::catch_unwind; use std::ptr; @@ -35,9 +36,10 @@ pub struct PreupdateHookResult<'a> { pub operation: SqliteOperation, pub database: &'a str, pub table: &'a str, + db: *mut sqlite3, // The database pointer should not be usable after the preupdate hook. // The lifetime on this struct needs to ensure it cannot outlive the callback. - db: *mut sqlite3, + _db_lifetime: PhantomData<&'a ()>, old_row_id: i64, new_row_id: i64, } @@ -151,6 +153,7 @@ pub(super) extern "C" fn preupdate_hook( old_row_id, new_row_id, db, + _db_lifetime: PhantomData, }) }); }