From aa0358fe072d4cb4ccc2d12738025b81bb47d36c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20L=C3=B3pez?= Date: Tue, 2 Dec 2025 20:12:59 +0100 Subject: [PATCH 1/2] Use higher half of AtomicUsize to track mutable borrows Use higher half of AtomicUsize to track mutable borrows, instead of just the highest bit. This is in preparation to implement map_split() for AtomicRef and AtomicRefMut, the latter of which will require having more than one mutable borrow. Unfortunately, this means that dropping an AtomicBorrowRefMut cannot unconditionally set the borrow count to zero anymore, as in the future there may be more than one mutable borrow. We need to drop the reference count by one, and also clear the lower bits for failed attempts to acquire an immutable borrow, which can only be done in a single step with a CAS loop. --- src/lib.rs | 44 +++++++++++++++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 1f50069..c7059fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -194,8 +194,15 @@ impl AtomicRefCell { // Core synchronization logic. Keep this section small and easy to audit. // -const HIGH_BIT: usize = !(::core::usize::MAX >> 1); -const MAX_FAILED_BORROWS: usize = HIGH_BIT + (HIGH_BIT >> 1); +/// Amount of bits used to keep track of readers and writers, respectively. +const REFCOUNT_BITS: usize = (usize::BITS / 2) as usize; +/// Bitmask for readers (lowest half of `usize`) +const LOW_MASK: usize = ::core::usize::MAX >> REFCOUNT_BITS; +/// Bitmask for writers (highest half of `usize`) +const HIGH_MASK: usize = LOW_MASK << REFCOUNT_BITS; +/// Maximum allowed value in the lowest bits during a mutable borrow. This acts +/// as a threshold to prevent overflowing the reader refcount. +const MAX_FAILED_BORROWS: usize = LOW_MASK >> 1; struct AtomicBorrowRef<'b> { borrow: &'b AtomicUsize, @@ -204,9 +211,11 @@ struct AtomicBorrowRef<'b> { impl<'b> AtomicBorrowRef<'b> { #[inline] fn try_new(borrow: &'b AtomicUsize) -> Result { - let new = borrow.fetch_add(1, atomic::Ordering::Acquire) + 1; - if new & HIGH_BIT != 0 { - // If the new count has the high bit set, that almost certainly + let new = borrow + .fetch_add(1, atomic::Ordering::Acquire) + .wrapping_add(1); + if new & HIGH_MASK != 0 { + // If the new count has the high bits set, that almost certainly // means there's an pre-existing mutable borrow. In that case, // we simply leave the increment as a benign side-effect and // return `Err`. Once the mutable borrow is released, the @@ -225,7 +234,7 @@ impl<'b> AtomicBorrowRef<'b> { #[cold] #[inline(never)] fn check_overflow(borrow: &'b AtomicUsize, new: usize) { - if new == HIGH_BIT { + if new & LOW_MASK == 0 { // We overflowed into the reserved upper half of the refcount // space. Before panicking, decrement the refcount to leave things // in a consistent immutable-borrow state. @@ -234,7 +243,7 @@ impl<'b> AtomicBorrowRef<'b> { // in a tight loop. borrow.fetch_sub(1, atomic::Ordering::Release); panic!("too many immutable borrows"); - } else if new >= MAX_FAILED_BORROWS { + } else if new & LOW_MASK >= MAX_FAILED_BORROWS { // During the mutable borrow, an absurd number of threads have // attempted to increment the refcount with immutable borrows. // To avoid hypothetically wrapping the refcount, we abort the @@ -275,7 +284,7 @@ impl<'b> Drop for AtomicBorrowRef<'b> { // thread hits the hypothetical overflow case, since we might observe // the refcount before it fixes it up (and panics). But that never will // never happen in a real program, and this is a debug_assert! anyway. - debug_assert!(old & HIGH_BIT == 0); + debug_assert!(old & HIGH_MASK == 0); } } @@ -286,7 +295,20 @@ struct AtomicBorrowRefMut<'b> { impl<'b> Drop for AtomicBorrowRefMut<'b> { #[inline] fn drop(&mut self) { - self.borrow.store(0, atomic::Ordering::Release); + // Drop the mutable borrow reference count by one, and clear the lower + // bits from failed immutable borrows. This must be done in a single + // step: if we fetch_sub() first an then clear the lower bits, an + // immutable borrow could happen in the middle, and our clear would + // erase that borrow. If we clear the lower bits first, more failed + // immutable borrows could happen before we fetch_sub(). So, use a CAS + // loop. + // + // The closure always returns Some(), so the result is always Ok() + let _ = + self.borrow + .fetch_update(atomic::Ordering::Release, atomic::Ordering::Relaxed, |b| { + Some((b - (1 << REFCOUNT_BITS)) & HIGH_MASK) + }); } } @@ -297,7 +319,7 @@ impl<'b> AtomicBorrowRefMut<'b> { // on illegal mutable borrows. let old = match borrow.compare_exchange( 0, - HIGH_BIT, + 1 << REFCOUNT_BITS, atomic::Ordering::Acquire, atomic::Ordering::Relaxed, ) { @@ -307,7 +329,7 @@ impl<'b> AtomicBorrowRefMut<'b> { if old == 0 { Ok(AtomicBorrowRefMut { borrow }) - } else if old & HIGH_BIT == 0 { + } else if old & HIGH_MASK == 0 { Err("already immutably borrowed") } else { Err("already mutably borrowed") From 215a1c5aaddaf8c7331c18883bec797524ad9d2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20L=C3=B3pez?= Date: Tue, 2 Dec 2025 22:55:05 +0100 Subject: [PATCH 2/2] Implement AtomicRef{,Mut}::map_split() Add a new method to split an AtomicRef / AtomicRefMut into two, by borrowing from distinct parts of the backing data, just like Ref{,Mut}::map_split() in the standard library. Unfortunately, the mutable borrow increment during the split must be peformed with a CAS loop, as we cannot afford to overflow the borrow count without potentially triggering undefined behavior. Add as well some basic tests to verify the functionality of the new methods. --- src/lib.rs | 138 +++++++++++++++++++++++++++++++++++++++++++++++++ tests/basic.rs | 71 +++++++++++++++++++++++++ 2 files changed, 209 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index c7059fb..c205a6a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -335,6 +335,46 @@ impl<'b> AtomicBorrowRefMut<'b> { Err("already mutably borrowed") } } + + /// Attempts to create a new `AtomicBorrowRefMut` by incrementing the + /// mutable borrow count by 1. + /// + /// # Errors + /// + /// Returns an error if the mutable borrow count would overflow. + fn try_clone(&self) -> Result, &'static str> { + // Increase the mutable borrow count. To avoid overflowing, do a CAS + // loop and a checked_add(). + // + // This is the only good way to do this - we cannot blindly fetch_add() + // and check later. Overflowing the atomic means that the writer count + // wraps to zero, allowing immutable borrows to happen. Consider the + // following scenario: + // + // let cell = AtomicRefCell::new(Foo::new()); + // let a = cell.borrow_mut(); + // let (b, c) = AtomicRefMut::map_split(a, ..); + // send_to_thread(b) + // .. + // let (d, e) = AtomicRefMut::map_split(c, ..); + // + // If the overflow happens during the last map_split(), there will be a + // window where `b` (a mutable borrow into a field of `Foo`) is alive in + // some other thread, while other threads are able to acquire immutable + // borrows into `cell`, causing UB. + // + // On the upside, this can never panic. + match self.borrow.fetch_update( + atomic::Ordering::Acquire, + atomic::Ordering::Relaxed, + |old| old.checked_add(1 << REFCOUNT_BITS), + ) { + Ok(_) => Ok(Self { + borrow: self.borrow, + }), + Err(_) => Err("mutable borrow count would overflow"), + } + } } unsafe impl Send for AtomicRefCell {} @@ -452,6 +492,52 @@ impl<'b, T: ?Sized> AtomicRef<'b, T> { borrow: orig.borrow, }) } + + /// Splits an `AtomicRef` into two `AtomicRef`s for different components of + /// the borrowed data. + /// + /// The underlying `AtomicRefCell` will remain borrowed until both returned + /// `AtomicRef`s go out of scope. + /// # Errors + /// + /// This function may fail if the operation would overflow the immutable + /// reference count. In this case, it will return the original `AtomicRef`. + #[inline] + pub fn try_map_split( + orig: AtomicRef<'b, T>, + f: F, + ) -> Result<(AtomicRef<'b, U>, AtomicRef<'b, V>), AtomicRef<'b, T>> + where + F: FnOnce(&T) -> (&U, &V), + { + let Ok(borrow) = AtomicBorrowRef::try_new(orig.borrow.borrow) else { + return Err(orig); + }; + let (a, b) = f(&*orig); + Ok(( + AtomicRef { + value: NonNull::from(a), + borrow, + }, + AtomicRef { + value: NonNull::from(b), + borrow: orig.borrow, + }, + )) + } + + /// Like [`try_map_split()`](Self::try_map_split), but instead panics + /// immediately on an error. + #[inline] + pub fn map_split(orig: AtomicRef<'b, T>, f: F) -> (AtomicRef<'b, U>, AtomicRef<'b, V>) + where + F: FnOnce(&T) -> (&U, &V), + { + if let Ok(ret) = Self::try_map_split(orig, f) { + return ret; + }; + panic!("immutable borrow count overflow"); + } } impl<'b, T: ?Sized> AtomicRefMut<'b, T> { @@ -484,6 +570,58 @@ impl<'b, T: ?Sized> AtomicRefMut<'b, T> { marker: PhantomData, }) } + + /// Splits an `AtomicRefMut` into two `AtomicRefMut`s for different + /// components of the borrowed data. + /// + /// The underlying `AtomicRefCell` will remain mutably borrowed until both + /// returned `AtomicRefMut`s go out of scope. + /// + /// # Errors + /// + /// This function may fail if the operation would overflow the mutable + /// reference count. In this case, it will return the original `AtomicRefMut`. + #[inline] + pub fn try_map_split( + mut orig: AtomicRefMut<'b, T>, + f: F, + ) -> Result<(AtomicRefMut<'b, U>, AtomicRefMut<'b, V>), AtomicRefMut<'b, T>> + where + F: FnOnce(&mut T) -> (&mut U, &mut V), + { + let Ok(borrow) = orig.borrow.try_clone() else { + return Err(orig); + }; + let (a, b) = f(&mut *orig); + Ok(( + AtomicRefMut { + value: NonNull::from(a), + borrow, + marker: PhantomData, + }, + AtomicRefMut { + value: NonNull::from(b), + borrow: orig.borrow, + marker: PhantomData, + }, + )) + } + + /// Like [`try_map_split()`](Self::try_map_split), but instead panics + /// immediately on an error. + #[inline] + pub fn map_split( + orig: AtomicRefMut<'b, T>, + f: F, + ) -> (AtomicRefMut<'b, U>, AtomicRefMut<'b, V>) + where + F: FnOnce(&mut T) -> (&mut U, &mut V), + { + if let Ok(ret) = Self::try_map_split(orig, f) { + return ret; + }; + panic!("mutable borrow count overflow"); + } } /// A wrapper type for a mutably borrowed value from an `AtomicRefCell`. diff --git a/tests/basic.rs b/tests/basic.rs index 49eaa77..6e529c7 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -87,6 +87,77 @@ fn try_interleaved() { } } +#[derive(Default, Clone, Copy, Debug)] +struct Baz { + a: u32, + b: u32, +} + +#[test] +fn map_split() { + let a = AtomicRefCell::new(Baz::default()); + { + let _ = a.try_borrow_mut().unwrap(); + } + let read = a.try_borrow().unwrap(); + let (first, second) = AtomicRef::map_split(read, |baz| (&baz.a, &baz.b)); + + // No writers allowed until both readers go away + let _ = a.try_borrow_mut().unwrap_err(); + drop(first); + let _ = a.try_borrow_mut().unwrap_err(); + drop(second); + + let _write = a.try_borrow_mut().unwrap(); +} + +#[test] +#[should_panic(expected = "already immutably borrowed")] +fn map_split_panic() { + let a = AtomicRefCell::new(Baz::default()); + let read = a.try_borrow().unwrap(); + let (first, second) = AtomicRef::map_split(read, |baz| (&baz.a, &baz.b)); + drop(first); + // This should panic even if one of the two immutable references was dropped + let _ = a.borrow_mut(); +} + +#[test] +fn map_split_mut() { + let a = AtomicRefCell::new(Baz::default()); + { + let _ = a.try_borrow().unwrap(); + } + let write = a.try_borrow_mut().unwrap(); + let (first, second) = AtomicRefMut::map_split(write, |baz| (&mut baz.a, &mut baz.b)); + + // No readers or writers allowed until both writers go away + let _ = a.try_borrow().unwrap_err(); + let _ = a.try_borrow_mut().unwrap_err(); + drop(first); + let _ = a.try_borrow().unwrap_err(); + let _ = a.try_borrow_mut().unwrap_err(); + drop(second); + + { + let _ = a.try_borrow().unwrap(); + } + { + let _ = a.try_borrow_mut().unwrap(); + } +} + +#[test] +#[should_panic(expected = "already mutably borrowed")] +fn map_split_mut_panic() { + let a = AtomicRefCell::new(Baz::default()); + let write = a.try_borrow_mut().unwrap(); + let (first, second) = AtomicRefMut::map_split(write, |baz| (&mut baz.a, &mut baz.b)); + drop(first); + // This should panic even if one of the two mutable references was dropped + let _ = a.borrow_mut(); +} + // For Miri to catch issues when calling a function. // // See how this scenerio affects std::cell::RefCell implementation: