From a5b3a331966b72d4bb80d6c374e544bcd0ff0490 Mon Sep 17 00:00:00 2001 From: Florian Uekermann Date: Tue, 2 Sep 2025 09:09:46 +0200 Subject: [PATCH] feat(sqlite): add WAL hook support --- sqlx-sqlite/src/connection/establish.rs | 1 + sqlx-sqlite/src/connection/mod.rs | 67 ++++++++++++++++++++++++- tests/sqlite/sqlite.rs | 56 +++++++++++++++++++++ 3 files changed, 123 insertions(+), 1 deletion(-) diff --git a/sqlx-sqlite/src/connection/establish.rs b/sqlx-sqlite/src/connection/establish.rs index d811275409..88d8e2ac48 100644 --- a/sqlx-sqlite/src/connection/establish.rs +++ b/sqlx-sqlite/src/connection/establish.rs @@ -188,6 +188,7 @@ impl EstablishParams { preupdate_hook_callback: None, commit_hook_callback: None, rollback_hook_callback: None, + wal_hook_callback: None, }) } diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 218c747143..3593317d75 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -11,7 +11,7 @@ use std::ptr::NonNull; use futures_intrusive::sync::MutexGuard; use libsqlite3_sys::{ sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook, - sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, + sqlite3_update_hook, sqlite3_wal_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_OK, SQLITE_UPDATE, }; #[cfg(feature = "preupdate-hook")] pub use preupdate_hook::*; @@ -96,6 +96,11 @@ pub struct UpdateHookResult<'a> { pub rowid: i64, } +pub struct WalHookResult<'a> { + pub database: &'a str, + pub page_count: i32, +} + pub(crate) struct UpdateHookHandler(NonNull); unsafe impl Send for UpdateHookHandler {} @@ -105,6 +110,9 @@ unsafe impl Send for CommitHookHandler {} pub(crate) struct RollbackHookHandler(NonNull); unsafe impl Send for RollbackHookHandler {} +pub(crate) struct WalHookHandler(NonNull); +unsafe impl Send for WalHookHandler {} + pub(crate) struct ConnectionState { pub(crate) handle: ConnectionHandle, @@ -123,6 +131,8 @@ pub(crate) struct ConnectionState { commit_hook_callback: Option, rollback_hook_callback: Option, + + wal_hook_callback: Option, } impl ConnectionState { @@ -172,6 +182,15 @@ impl ConnectionState { } } } + + pub(crate) fn remove_wal_hook(&mut self) { + if let Some(mut handler) = self.wal_hook_callback.take() { + unsafe { + sqlite3_wal_hook(self.handle.as_ptr(), None, ptr::null_mut()); + let _ = { Box::from_raw(handler.0.as_mut()) }; + } + } + } } pub(crate) struct Statements { @@ -353,6 +372,28 @@ where } } +extern "C" fn wal_hook( + callback: *mut c_void, + _db: *mut sqlite3, + database: *const c_char, + page_count: c_int, +) -> c_int +where + F: FnMut(WalHookResult) + Send + 'static, +{ + unsafe { + let _ = catch_unwind(|| { + let callback: *mut F = callback.cast::(); + let database = CStr::from_ptr(database).to_str().unwrap_or_default(); + (*callback)(WalHookResult { + database, + page_count, + }) + }); + } + SQLITE_OK +} + impl LockedSqliteHandle<'_> { /// Returns the underlying sqlite3* connection handle. /// @@ -520,6 +561,26 @@ impl LockedSqliteHandle<'_> { } } + /// Sets a WAL hook that is invoked whenever a commit occurs in WAL mode. Only a single WAL hook may be + /// defined at one time per database connection; setting a new WAL hook overrides the old one. + /// + /// Note that sqlite3_wal_autocheckpoint() and the wal_autocheckpoint pragma overwrite the WAL hook. + pub fn set_wal_hook(&mut self, callback: F) + where + F: FnMut(WalHookResult) + Send + 'static, + { + unsafe { + let callback_boxed = Box::new(callback); + // SAFETY: `Box::into_raw()` always returns a non-null pointer. + let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed)); + let handler = callback.as_ptr() as *mut _; + self.guard.remove_wal_hook(); + self.guard.wal_hook_callback = Some(WalHookHandler(callback)); + + sqlite3_wal_hook(self.as_raw_handle().as_mut(), Some(wal_hook::), handler); + } + } + /// Removes the progress handler on a database connection. The method does nothing if no handler was set. pub fn remove_progress_handler(&mut self) { self.guard.remove_progress_handler(); @@ -542,6 +603,10 @@ impl LockedSqliteHandle<'_> { self.guard.remove_rollback_hook(); } + pub fn remove_wal_hook(&mut self) { + self.guard.remove_wal_hook(); + } + pub fn last_error(&mut self) -> Option { self.guard.handle.last_error() } diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index d8f8ee492c..1b7b09abb8 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -978,6 +978,62 @@ async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Res Ok(()) } +#[sqlx_macros::test] +async fn test_query_with_wal_hook() -> anyhow::Result<()> { + let mut conn = new::().await?; + conn.execute("PRAGMA journal_mode = WAL").await?; + + // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. + let state = "test".to_string(); + static CALLED: AtomicBool = AtomicBool::new(false); + conn.lock_handle().await?.set_wal_hook(move |_| { + assert_eq!(state, "test"); + CALLED.store(true, Ordering::Relaxed); + }); + + let mut tx = conn.begin().await?; + sqlx::query("INSERT INTO tweet ( id, text ) VALUES (5, 'Hello, World' )") + .execute(&mut *tx) + .await?; + assert!(!CALLED.load(Ordering::Relaxed)); + tx.commit().await?; + assert!(CALLED.load(Ordering::Relaxed)); + Ok(()) +} + +#[sqlx_macros::test] +async fn test_multiple_set_wal_hook_calls_drop_old_handler() -> anyhow::Result<()> { + let ref_counted_object = Arc::new(0); + assert_eq!(1, Arc::strong_count(&ref_counted_object)); + + { + let mut conn = new::().await?; + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_wal_hook(move |_| { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_wal_hook(move |_| { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_wal_hook(move |_| { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + conn.lock_handle().await?.remove_wal_hook(); + } + + assert_eq!(1, Arc::strong_count(&ref_counted_object)); + Ok(()) +} + #[sqlx_macros::test] async fn issue_3150() { // Same bounds as `tokio::spawn()`