Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlx-sqlite/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ impl EstablishParams {
preupdate_hook_callback: None,
commit_hook_callback: None,
rollback_hook_callback: None,
wal_hook_callback: None,
})
}

Expand Down
67 changes: 66 additions & 1 deletion sqlx-sqlite/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::ptr::NonNull;
use futures_intrusive::sync::MutexGuard;
use libsqlite3_sys::{
sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook,
sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE,
sqlite3_update_hook, sqlite3_wal_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_OK, SQLITE_UPDATE,
};
#[cfg(feature = "preupdate-hook")]
pub use preupdate_hook::*;
Expand Down Expand Up @@ -96,6 +96,11 @@ pub struct UpdateHookResult<'a> {
pub rowid: i64,
}

pub struct WalHookResult<'a> {
pub database: &'a str,
pub page_count: i32,
}

pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
unsafe impl Send for UpdateHookHandler {}

Expand All @@ -105,6 +110,9 @@ unsafe impl Send for CommitHookHandler {}
pub(crate) struct RollbackHookHandler(NonNull<dyn FnMut() + Send + 'static>);
unsafe impl Send for RollbackHookHandler {}

pub(crate) struct WalHookHandler(NonNull<dyn FnMut(WalHookResult) + Send + 'static>);
unsafe impl Send for WalHookHandler {}

pub(crate) struct ConnectionState {
pub(crate) handle: ConnectionHandle,

Expand All @@ -123,6 +131,8 @@ pub(crate) struct ConnectionState {
commit_hook_callback: Option<CommitHookHandler>,

rollback_hook_callback: Option<RollbackHookHandler>,

wal_hook_callback: Option<WalHookHandler>,
}

impl ConnectionState {
Expand Down Expand Up @@ -172,6 +182,15 @@ impl ConnectionState {
}
}
}

pub(crate) fn remove_wal_hook(&mut self) {
if let Some(mut handler) = self.wal_hook_callback.take() {
unsafe {
sqlite3_wal_hook(self.handle.as_ptr(), None, ptr::null_mut());
let _ = { Box::from_raw(handler.0.as_mut()) };
}
}
}
}

pub(crate) struct Statements {
Expand Down Expand Up @@ -353,6 +372,28 @@ where
}
}

extern "C" fn wal_hook<F>(
callback: *mut c_void,
_db: *mut sqlite3,
database: *const c_char,
page_count: c_int,
) -> c_int
where
F: FnMut(WalHookResult) + Send + 'static,
{
unsafe {
let _ = catch_unwind(|| {
let callback: *mut F = callback.cast::<F>();
let database = CStr::from_ptr(database).to_str().unwrap_or_default();
(*callback)(WalHookResult {
database,
page_count,
})
});
}
SQLITE_OK
}

impl LockedSqliteHandle<'_> {
/// Returns the underlying sqlite3* connection handle.
///
Expand Down Expand Up @@ -520,6 +561,26 @@ impl LockedSqliteHandle<'_> {
}
}

/// Sets a WAL hook that is invoked whenever a commit occurs in WAL mode. Only a single WAL hook may be
/// defined at one time per database connection; setting a new WAL hook overrides the old one.
///
/// Note that sqlite3_wal_autocheckpoint() and the wal_autocheckpoint pragma overwrite the WAL hook.
pub fn set_wal_hook<F>(&mut self, callback: F)
where
F: FnMut(WalHookResult) + Send + 'static,
{
unsafe {
let callback_boxed = Box::new(callback);
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
let handler = callback.as_ptr() as *mut _;
self.guard.remove_wal_hook();
self.guard.wal_hook_callback = Some(WalHookHandler(callback));

sqlite3_wal_hook(self.as_raw_handle().as_mut(), Some(wal_hook::<F>), handler);
}
}

/// Removes the progress handler on a database connection. The method does nothing if no handler was set.
pub fn remove_progress_handler(&mut self) {
self.guard.remove_progress_handler();
Expand All @@ -542,6 +603,10 @@ impl LockedSqliteHandle<'_> {
self.guard.remove_rollback_hook();
}

pub fn remove_wal_hook(&mut self) {
self.guard.remove_wal_hook();
}

pub fn last_error(&mut self) -> Option<SqliteError> {
self.guard.handle.last_error()
}
Expand Down
56 changes: 56 additions & 0 deletions tests/sqlite/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,62 @@ async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Res
Ok(())
}

#[sqlx_macros::test]
async fn test_query_with_wal_hook() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;
conn.execute("PRAGMA journal_mode = WAL").await?;

// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
let state = "test".to_string();
static CALLED: AtomicBool = AtomicBool::new(false);
conn.lock_handle().await?.set_wal_hook(move |_| {
assert_eq!(state, "test");
CALLED.store(true, Ordering::Relaxed);
});

let mut tx = conn.begin().await?;
sqlx::query("INSERT INTO tweet ( id, text ) VALUES (5, 'Hello, World' )")
.execute(&mut *tx)
.await?;
assert!(!CALLED.load(Ordering::Relaxed));
tx.commit().await?;
assert!(CALLED.load(Ordering::Relaxed));
Ok(())
}

#[sqlx_macros::test]
async fn test_multiple_set_wal_hook_calls_drop_old_handler() -> anyhow::Result<()> {
let ref_counted_object = Arc::new(0);
assert_eq!(1, Arc::strong_count(&ref_counted_object));

{
let mut conn = new::<Sqlite>().await?;

let o = ref_counted_object.clone();
conn.lock_handle().await?.set_wal_hook(move |_| {
println!("{o:?}");
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));

let o = ref_counted_object.clone();
conn.lock_handle().await?.set_wal_hook(move |_| {
println!("{o:?}");
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));

let o = ref_counted_object.clone();
conn.lock_handle().await?.set_wal_hook(move |_| {
println!("{o:?}");
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));

conn.lock_handle().await?.remove_wal_hook();
}

assert_eq!(1, Arc::strong_count(&ref_counted_object));
Ok(())
}

#[sqlx_macros::test]
async fn issue_3150() {
// Same bounds as `tokio::spawn()`
Expand Down
Loading