diff --git a/src/lib.rs b/src/lib.rs index 1f50069..c205a6a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -194,8 +194,15 @@ impl AtomicRefCell { // Core synchronization logic. Keep this section small and easy to audit. // -const HIGH_BIT: usize = !(::core::usize::MAX >> 1); -const MAX_FAILED_BORROWS: usize = HIGH_BIT + (HIGH_BIT >> 1); +/// Amount of bits used to keep track of readers and writers, respectively. +const REFCOUNT_BITS: usize = (usize::BITS / 2) as usize; +/// Bitmask for readers (lowest half of `usize`) +const LOW_MASK: usize = ::core::usize::MAX >> REFCOUNT_BITS; +/// Bitmask for writers (highest half of `usize`) +const HIGH_MASK: usize = LOW_MASK << REFCOUNT_BITS; +/// Maximum allowed value in the lowest bits during a mutable borrow. This acts +/// as a threshold to prevent overflowing the reader refcount. +const MAX_FAILED_BORROWS: usize = LOW_MASK >> 1; struct AtomicBorrowRef<'b> { borrow: &'b AtomicUsize, @@ -204,9 +211,11 @@ struct AtomicBorrowRef<'b> { impl<'b> AtomicBorrowRef<'b> { #[inline] fn try_new(borrow: &'b AtomicUsize) -> Result { - let new = borrow.fetch_add(1, atomic::Ordering::Acquire) + 1; - if new & HIGH_BIT != 0 { - // If the new count has the high bit set, that almost certainly + let new = borrow + .fetch_add(1, atomic::Ordering::Acquire) + .wrapping_add(1); + if new & HIGH_MASK != 0 { + // If the new count has the high bits set, that almost certainly // means there's an pre-existing mutable borrow. In that case, // we simply leave the increment as a benign side-effect and // return `Err`. Once the mutable borrow is released, the @@ -225,7 +234,7 @@ impl<'b> AtomicBorrowRef<'b> { #[cold] #[inline(never)] fn check_overflow(borrow: &'b AtomicUsize, new: usize) { - if new == HIGH_BIT { + if new & LOW_MASK == 0 { // We overflowed into the reserved upper half of the refcount // space. Before panicking, decrement the refcount to leave things // in a consistent immutable-borrow state. @@ -234,7 +243,7 @@ impl<'b> AtomicBorrowRef<'b> { // in a tight loop. borrow.fetch_sub(1, atomic::Ordering::Release); panic!("too many immutable borrows"); - } else if new >= MAX_FAILED_BORROWS { + } else if new & LOW_MASK >= MAX_FAILED_BORROWS { // During the mutable borrow, an absurd number of threads have // attempted to increment the refcount with immutable borrows. // To avoid hypothetically wrapping the refcount, we abort the @@ -275,7 +284,7 @@ impl<'b> Drop for AtomicBorrowRef<'b> { // thread hits the hypothetical overflow case, since we might observe // the refcount before it fixes it up (and panics). But that never will // never happen in a real program, and this is a debug_assert! anyway. - debug_assert!(old & HIGH_BIT == 0); + debug_assert!(old & HIGH_MASK == 0); } } @@ -286,7 +295,20 @@ struct AtomicBorrowRefMut<'b> { impl<'b> Drop for AtomicBorrowRefMut<'b> { #[inline] fn drop(&mut self) { - self.borrow.store(0, atomic::Ordering::Release); + // Drop the mutable borrow reference count by one, and clear the lower + // bits from failed immutable borrows. This must be done in a single + // step: if we fetch_sub() first an then clear the lower bits, an + // immutable borrow could happen in the middle, and our clear would + // erase that borrow. If we clear the lower bits first, more failed + // immutable borrows could happen before we fetch_sub(). So, use a CAS + // loop. + // + // The closure always returns Some(), so the result is always Ok() + let _ = + self.borrow + .fetch_update(atomic::Ordering::Release, atomic::Ordering::Relaxed, |b| { + Some((b - (1 << REFCOUNT_BITS)) & HIGH_MASK) + }); } } @@ -297,7 +319,7 @@ impl<'b> AtomicBorrowRefMut<'b> { // on illegal mutable borrows. let old = match borrow.compare_exchange( 0, - HIGH_BIT, + 1 << REFCOUNT_BITS, atomic::Ordering::Acquire, atomic::Ordering::Relaxed, ) { @@ -307,12 +329,52 @@ impl<'b> AtomicBorrowRefMut<'b> { if old == 0 { Ok(AtomicBorrowRefMut { borrow }) - } else if old & HIGH_BIT == 0 { + } else if old & HIGH_MASK == 0 { Err("already immutably borrowed") } else { Err("already mutably borrowed") } } + + /// Attempts to create a new `AtomicBorrowRefMut` by incrementing the + /// mutable borrow count by 1. + /// + /// # Errors + /// + /// Returns an error if the mutable borrow count would overflow. + fn try_clone(&self) -> Result, &'static str> { + // Increase the mutable borrow count. To avoid overflowing, do a CAS + // loop and a checked_add(). + // + // This is the only good way to do this - we cannot blindly fetch_add() + // and check later. Overflowing the atomic means that the writer count + // wraps to zero, allowing immutable borrows to happen. Consider the + // following scenario: + // + // let cell = AtomicRefCell::new(Foo::new()); + // let a = cell.borrow_mut(); + // let (b, c) = AtomicRefMut::map_split(a, ..); + // send_to_thread(b) + // .. + // let (d, e) = AtomicRefMut::map_split(c, ..); + // + // If the overflow happens during the last map_split(), there will be a + // window where `b` (a mutable borrow into a field of `Foo`) is alive in + // some other thread, while other threads are able to acquire immutable + // borrows into `cell`, causing UB. + // + // On the upside, this can never panic. + match self.borrow.fetch_update( + atomic::Ordering::Acquire, + atomic::Ordering::Relaxed, + |old| old.checked_add(1 << REFCOUNT_BITS), + ) { + Ok(_) => Ok(Self { + borrow: self.borrow, + }), + Err(_) => Err("mutable borrow count would overflow"), + } + } } unsafe impl Send for AtomicRefCell {} @@ -430,6 +492,52 @@ impl<'b, T: ?Sized> AtomicRef<'b, T> { borrow: orig.borrow, }) } + + /// Splits an `AtomicRef` into two `AtomicRef`s for different components of + /// the borrowed data. + /// + /// The underlying `AtomicRefCell` will remain borrowed until both returned + /// `AtomicRef`s go out of scope. + /// # Errors + /// + /// This function may fail if the operation would overflow the immutable + /// reference count. In this case, it will return the original `AtomicRef`. + #[inline] + pub fn try_map_split( + orig: AtomicRef<'b, T>, + f: F, + ) -> Result<(AtomicRef<'b, U>, AtomicRef<'b, V>), AtomicRef<'b, T>> + where + F: FnOnce(&T) -> (&U, &V), + { + let Ok(borrow) = AtomicBorrowRef::try_new(orig.borrow.borrow) else { + return Err(orig); + }; + let (a, b) = f(&*orig); + Ok(( + AtomicRef { + value: NonNull::from(a), + borrow, + }, + AtomicRef { + value: NonNull::from(b), + borrow: orig.borrow, + }, + )) + } + + /// Like [`try_map_split()`](Self::try_map_split), but instead panics + /// immediately on an error. + #[inline] + pub fn map_split(orig: AtomicRef<'b, T>, f: F) -> (AtomicRef<'b, U>, AtomicRef<'b, V>) + where + F: FnOnce(&T) -> (&U, &V), + { + if let Ok(ret) = Self::try_map_split(orig, f) { + return ret; + }; + panic!("immutable borrow count overflow"); + } } impl<'b, T: ?Sized> AtomicRefMut<'b, T> { @@ -462,6 +570,58 @@ impl<'b, T: ?Sized> AtomicRefMut<'b, T> { marker: PhantomData, }) } + + /// Splits an `AtomicRefMut` into two `AtomicRefMut`s for different + /// components of the borrowed data. + /// + /// The underlying `AtomicRefCell` will remain mutably borrowed until both + /// returned `AtomicRefMut`s go out of scope. + /// + /// # Errors + /// + /// This function may fail if the operation would overflow the mutable + /// reference count. In this case, it will return the original `AtomicRefMut`. + #[inline] + pub fn try_map_split( + mut orig: AtomicRefMut<'b, T>, + f: F, + ) -> Result<(AtomicRefMut<'b, U>, AtomicRefMut<'b, V>), AtomicRefMut<'b, T>> + where + F: FnOnce(&mut T) -> (&mut U, &mut V), + { + let Ok(borrow) = orig.borrow.try_clone() else { + return Err(orig); + }; + let (a, b) = f(&mut *orig); + Ok(( + AtomicRefMut { + value: NonNull::from(a), + borrow, + marker: PhantomData, + }, + AtomicRefMut { + value: NonNull::from(b), + borrow: orig.borrow, + marker: PhantomData, + }, + )) + } + + /// Like [`try_map_split()`](Self::try_map_split), but instead panics + /// immediately on an error. + #[inline] + pub fn map_split( + orig: AtomicRefMut<'b, T>, + f: F, + ) -> (AtomicRefMut<'b, U>, AtomicRefMut<'b, V>) + where + F: FnOnce(&mut T) -> (&mut U, &mut V), + { + if let Ok(ret) = Self::try_map_split(orig, f) { + return ret; + }; + panic!("mutable borrow count overflow"); + } } /// A wrapper type for a mutably borrowed value from an `AtomicRefCell`. diff --git a/tests/basic.rs b/tests/basic.rs index 49eaa77..6e529c7 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -87,6 +87,77 @@ fn try_interleaved() { } } +#[derive(Default, Clone, Copy, Debug)] +struct Baz { + a: u32, + b: u32, +} + +#[test] +fn map_split() { + let a = AtomicRefCell::new(Baz::default()); + { + let _ = a.try_borrow_mut().unwrap(); + } + let read = a.try_borrow().unwrap(); + let (first, second) = AtomicRef::map_split(read, |baz| (&baz.a, &baz.b)); + + // No writers allowed until both readers go away + let _ = a.try_borrow_mut().unwrap_err(); + drop(first); + let _ = a.try_borrow_mut().unwrap_err(); + drop(second); + + let _write = a.try_borrow_mut().unwrap(); +} + +#[test] +#[should_panic(expected = "already immutably borrowed")] +fn map_split_panic() { + let a = AtomicRefCell::new(Baz::default()); + let read = a.try_borrow().unwrap(); + let (first, second) = AtomicRef::map_split(read, |baz| (&baz.a, &baz.b)); + drop(first); + // This should panic even if one of the two immutable references was dropped + let _ = a.borrow_mut(); +} + +#[test] +fn map_split_mut() { + let a = AtomicRefCell::new(Baz::default()); + { + let _ = a.try_borrow().unwrap(); + } + let write = a.try_borrow_mut().unwrap(); + let (first, second) = AtomicRefMut::map_split(write, |baz| (&mut baz.a, &mut baz.b)); + + // No readers or writers allowed until both writers go away + let _ = a.try_borrow().unwrap_err(); + let _ = a.try_borrow_mut().unwrap_err(); + drop(first); + let _ = a.try_borrow().unwrap_err(); + let _ = a.try_borrow_mut().unwrap_err(); + drop(second); + + { + let _ = a.try_borrow().unwrap(); + } + { + let _ = a.try_borrow_mut().unwrap(); + } +} + +#[test] +#[should_panic(expected = "already mutably borrowed")] +fn map_split_mut_panic() { + let a = AtomicRefCell::new(Baz::default()); + let write = a.try_borrow_mut().unwrap(); + let (first, second) = AtomicRefMut::map_split(write, |baz| (&mut baz.a, &mut baz.b)); + drop(first); + // This should panic even if one of the two mutable references was dropped + let _ = a.borrow_mut(); +} + // For Miri to catch issues when calling a function. // // See how this scenerio affects std::cell::RefCell implementation: