diff --git a/Cargo.toml b/Cargo.toml
index 45e433fff..aafe818f2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -17,6 +17,7 @@ rust-version = "1.65.0"
 foldhash = { version = "0.2.0", default-features = false, optional = true }
 
 # For external trait impls
+paralight = { version = "0.0.10", optional = true }
 rayon = { version = "1.9.0", optional = true }
 serde_core = { version = "1.0.221", default-features = false, optional = true }
 
@@ -85,5 +86,5 @@ default-hasher = ["dep:foldhash"]
 inline-more = []
 
 [package.metadata.docs.rs]
-features = ["nightly", "rayon", "serde", "raw-entry"]
+features = ["nightly", "paralight", "rayon", "serde", "raw-entry"]
 rustdoc-args = ["--generate-link-to-definition"]
diff --git a/src/external_trait_impls/mod.rs b/src/external_trait_impls/mod.rs
index ef497836c..cfea0ffef 100644
--- a/src/external_trait_impls/mod.rs
+++ b/src/external_trait_impls/mod.rs
@@ -1,3 +1,5 @@
+#[cfg(feature = "paralight")]
+mod paralight;
 #[cfg(feature = "rayon")]
 pub(crate) mod rayon;
 #[cfg(feature = "serde")]
diff --git a/src/external_trait_impls/paralight.rs b/src/external_trait_impls/paralight.rs
new file mode 100644
index 000000000..926090621
--- /dev/null
+++ b/src/external_trait_impls/paralight.rs
@@ -0,0 +1,742 @@
+use crate::raw::{Allocator, RawTable};
+use crate::{HashMap, HashSet};
+use paralight::iter::{
+    IntoParallelRefMutSource, IntoParallelRefSource, IntoParallelSource, ParallelSource,
+    SourceCleanup, SourceDescriptor,
+};
+
+// HashSet.par_iter()
+impl<'data, T: Sync + 'data, S: 'data, A: Allocator + Sync + 'data> IntoParallelRefSource<'data>
+    for HashSet<T, S, A>
+{
+    type Item = &'data T;
+    type Source = HashSetRefParallelSource<'data, T, S, A>;
+
+    fn par_iter(&'data self) -> Self::Source {
+        HashSetRefParallelSource { hash_set: self }
+    }
+}
+
+#[must_use = "iterator adaptors are lazy"]
+pub struct HashSetRefParallelSource<'data, T, S, A: Allocator> {
+    hash_set: &'data HashSet<T, S, A>,
+}
+
+impl<'data, T: Sync, S, A: Allocator + Sync> ParallelSource
+    for HashSetRefParallelSource<'data, T, S, A>
+{
+    type Item = &'data T;
+
+    fn descriptor(self) -> impl SourceDescriptor<Item = Self::Item> + Sync {
+        HashSetRefSourceDescriptor {
+            table: &self.hash_set.map.table,
+        }
+    }
+}
+
+struct HashSetRefSourceDescriptor<'data, T: Sync, A: Allocator> {
+    table: &'data RawTable<(T, ()), A>,
+}
+
+impl<T: Sync, A: Allocator> SourceCleanup for HashSetRefSourceDescriptor<'_, T, A> {
+    const NEEDS_CLEANUP: bool = false;
+
+    fn len(&self) -> usize {
+        self.table.buckets()
+    }
+
+    unsafe fn cleanup_item_range(&self, _range: core::ops::Range<usize>) {
+        // Nothing to cleanup
+    }
+}
+
+impl<'data, T: Sync, A: Allocator> SourceDescriptor for HashSetRefSourceDescriptor<'data, T, A> {
+    type Item = &'data T;
+
+    unsafe fn fetch_item(&self, index: usize) -> Option<Self::Item> {
+        debug_assert!(index < self.len());
+        // SAFETY: The passed index is less than the number of buckets. This is
+        // ensured by the safety preconditions of `fetch_item()`, given that
+        // `len()` returned the number of buckets, and is further confirmed by
+        // the debug assertion.
+        let full = unsafe { self.table.is_bucket_full(index) };
+        if full {
+            // SAFETY:
+            // - The table is already allocated.
+            // - The index is in bounds (see previous safety comment).
+            // - The table contains elements of type (T, ()).
+            let bucket = unsafe { self.table.bucket(index) };
+            // SAFETY: The bucket is full, so it's safe to derive a const
+            // reference from it.
+            let (t, ()) = unsafe { bucket.as_ref() };
+            Some(t)
+        } else {
+            None
+        }
+    }
+}
+
+// HashMap.par_iter()
+impl<'data, K: Sync + 'data, V: Sync + 'data, S: 'data, A: Allocator + Sync + 'data>
+    IntoParallelRefSource<'data> for HashMap<K, V, S, A>
+{
+    type Item = &'data (K, V);
+    type Source = HashMapRefParallelSource<'data, K, V, S, A>;
+
+    fn par_iter(&'data self) -> Self::Source {
+        HashMapRefParallelSource { hash_map: self }
+    }
+}
+
+#[must_use = "iterator adaptors are lazy"]
+pub struct HashMapRefParallelSource<'data, K, V, S, A: Allocator> {
+    hash_map: &'data HashMap<K, V, S, A>,
+}
+
+impl<'data, K: Sync, V: Sync, S, A: Allocator + Sync> ParallelSource
+    for HashMapRefParallelSource<'data, K, V, S, A>
+{
+    type Item = &'data (K, V);
+
+    fn descriptor(self) -> impl SourceDescriptor<Item = Self::Item> + Sync {
+        HashMapRefSourceDescriptor {
+            table: &self.hash_map.table,
+        }
+    }
+}
+
+struct HashMapRefSourceDescriptor<'data, K: Sync, V: Sync, A: Allocator> {
+    table: &'data RawTable<(K, V), A>,
+}
+
+impl<K: Sync, V: Sync, A: Allocator> SourceCleanup for HashMapRefSourceDescriptor<'_, K, V, A> {
+    const NEEDS_CLEANUP: bool = false;
+
+    fn len(&self) -> usize {
+        self.table.buckets()
+    }
+
+    unsafe fn cleanup_item_range(&self, _range: core::ops::Range<usize>) {
+        // Nothing to cleanup
+    }
+}
+
+impl<'data, K: Sync, V: Sync, A: Allocator> SourceDescriptor
+    for HashMapRefSourceDescriptor<'data, K, V, A>
+{
+    type Item = &'data (K, V);
+
+    unsafe fn fetch_item(&self, index: usize) -> Option<Self::Item> {
+        debug_assert!(index < self.len());
+        // SAFETY: The passed index is less than the number of buckets. This is
+        // ensured by the safety preconditions of `fetch_item()`, given that
+        // `len()` returned the number of buckets, and is further confirmed by
+        // the debug assertion.
+        let full = unsafe { self.table.is_bucket_full(index) };
+        if full {
+            // SAFETY:
+            // - The table is already allocated.
+            // - The index is in bounds (see previous safety comment).
+            // - The table contains elements of type (K, V).
+            let bucket = unsafe { self.table.bucket(index) };
+            // SAFETY: The bucket is full, so it's safe to derive a const
+            // reference from it.
+            unsafe { Some(bucket.as_ref()) }
+        } else {
+            None
+        }
+    }
+}
+
+// HashMap.par_iter_mut()
+impl<'data, K: Sync + 'data, V: Send + 'data, S: 'data, A: Allocator + Sync + 'data>
+    IntoParallelRefMutSource<'data> for HashMap<K, V, S, A>
+{
+    type Item = (&'data K, &'data mut V);
+    type Source = HashMapRefMutParallelSource<'data, K, V, S, A>;
+
+    fn par_iter_mut(&'data mut self) -> Self::Source {
+        HashMapRefMutParallelSource { hash_map: self }
+    }
+}
+
+#[must_use = "iterator adaptors are lazy"]
+pub struct HashMapRefMutParallelSource<'data, K, V, S, A: Allocator> {
+    hash_map: &'data mut HashMap<K, V, S, A>,
+}
+
+impl<'data, K: Sync, V: Send, S, A: Allocator + Sync> ParallelSource
+    for HashMapRefMutParallelSource<'data, K, V, S, A>
+{
+    type Item = (&'data K, &'data mut V);
+
+    fn descriptor(self) -> impl SourceDescriptor<Item = Self::Item> + Sync {
+        HashMapRefMutSourceDescriptor {
+            table: raw_table_wrapper::HashMapRefMut {
+                inner: &self.hash_map.table,
+            },
+        }
+    }
+}
+
+struct HashMapRefMutSourceDescriptor<'data, K: Sync, V: Send, A: Allocator> {
+    table: raw_table_wrapper::HashMapRefMut<'data, K, V, A>,
+}
+
+impl<K: Sync, V: Send, A: Allocator> SourceCleanup for HashMapRefMutSourceDescriptor<'_, K, V, A> {
+    const NEEDS_CLEANUP: bool = false;
+
+    fn len(&self) -> usize {
+        self.table.inner.buckets()
+    }
+
+    unsafe fn cleanup_item_range(&self, _range: core::ops::Range<usize>) {
+        // Nothing to cleanup
+    }
+}
+
+impl<'data, K: Sync, V: Send, A: Allocator> SourceDescriptor
+    for HashMapRefMutSourceDescriptor<'data, K, V, A>
+{
+    type Item = (&'data K, &'data mut V);
+
+    unsafe fn fetch_item(&self, index: usize) -> Option<Self::Item> {
+        debug_assert!(index < self.len());
+        // SAFETY: The passed index is less than the number of buckets. This is
+        // ensured by the safety preconditions of `fetch_item()`, given that
+        // `len()` returned the number of buckets, and is further confirmed by
+        // the debug assertion.
+        let full = unsafe { self.table.inner.is_bucket_full(index) };
+        if full {
+            // SAFETY:
+            // - The table is already allocated.
+            // - The index is in bounds (see previous safety comment).
+            // - The table contains elements of type (K, V).
+            let bucket = unsafe { self.table.inner.bucket(index) };
+            // SAFETY:
+            // - The bucket is full, i.e. points to a valid value.
+            // - While the resulting reference is valid, the memory it points to
+            //   isn't accessed through any other pointer. Indeed, the
+            //   `SourceDescriptor` contract ensures that no other call to
+            //   `fetch_item()` will be made at this index while the iterator is
+            //   active. Furthermore, `HashMapRefMutParallelSource` holds a
+            //   mutable reference to the hash map with the 'data lifetime,
+            //   ensuring that no other part of the program accesses the hash
+            //   map while the returned reference exists.
+            let (key, value) = unsafe { bucket.as_mut() };
+            Some((key, value))
+        } else {
+            None
+        }
+    }
+}
+
+// HashSet.into_par_iter()
+impl<T: Send, S, A: Allocator + Send> IntoParallelSource for HashSet<T, S, A> {
+    type Item = T;
+    type Source = HashSetParallelSource<T, S, A>;
+
+    fn into_par_iter(self) -> Self::Source {
+        HashSetParallelSource { hash_set: self }
+    }
+}
+
+#[must_use = "iterator adaptors are lazy"]
+pub struct HashSetParallelSource<T, S, A: Allocator> {
+    hash_set: HashSet<T, S, A>,
+}
+
+impl<T: Send, S, A: Allocator + Send> ParallelSource for HashSetParallelSource<T, S, A> {
+    type Item = T;
+
+    fn descriptor(self) -> impl SourceDescriptor<Item = Self::Item> + Sync {
+        HashSetSourceDescriptor {
+            table: raw_table_wrapper::HashSet {
+                inner: self.hash_set.map.table,
+            },
+        }
+    }
+}
+
+struct HashSetSourceDescriptor<T: Send, A: Allocator> {
+    table: raw_table_wrapper::HashSet<T, A>,
+}
+
+impl<T: Send, A: Allocator> SourceCleanup for HashSetSourceDescriptor<T, A> {
+    const NEEDS_CLEANUP: bool = core::mem::needs_drop::<T>();
+
+    fn len(&self) -> usize {
+        self.table.inner.buckets()
+    }
+
+    unsafe fn cleanup_item_range(&self, range: core::ops::Range<usize>) {
+        if Self::NEEDS_CLEANUP {
+            debug_assert!(range.start <= range.end);
+            debug_assert!(range.start <= self.len());
+            debug_assert!(range.end <= self.len());
+            for index in range {
+                // SAFETY: The passed index is less than the number of buckets. This is
+                // ensured by the safety preconditions of `cleanup_item_range()`, given
+                // that `len()` returned the number of buckets, and is further confirmed
+                // by the debug assertions.
+                let full = unsafe { self.table.inner.is_bucket_full(index) };
+                if full {
+                    // SAFETY:
+                    // - The table is already allocated.
+                    // - The index is in bounds (see previous safety comment).
+                    // - The table contains elements of type (T, ()).
+                    let bucket = unsafe { self.table.inner.bucket(index) };
+                    // SAFETY:
+                    // - The bucket points to an aligned value of type (T, ()).
+                    // - The value is initialized, as the bucket is full.
+                    // - No other part of the program reads it, as the `SourceCleanup`
+                    //   and `SourceDescriptor` contracts ensure that no other call to
+                    //   `fetch_item()` nor `cleanup_item_range()` is made for this
+                    //   index; and even though the bucket isn't marked as empty here,
+                    //   the Drop implementation clears the table without dropping.
+                    let (t, ()) = unsafe { bucket.read() };
+                    drop(t);
+                }
+            }
+        }
+    }
+}
+
+impl<T: Send, A: Allocator> SourceDescriptor for HashSetSourceDescriptor<T, A> {
+    type Item = T;
+
+    unsafe fn fetch_item(&self, index: usize) -> Option<Self::Item> {
+        debug_assert!(index < self.len());
+        // SAFETY: The passed index is less than the number of buckets. This is
+        // ensured by the safety preconditions of `fetch_item()`, given that
+        // `len()` returned the number of buckets, and is further confirmed by
+        // the debug assertion.
+        let full = unsafe { self.table.inner.is_bucket_full(index) };
+        if full {
+            // SAFETY:
+            // - The table is already allocated.
+            // - The index is in bounds (see previous safety comment).
+            // - The table contains elements of type (T, ()).
+            let bucket = unsafe { self.table.inner.bucket(index) };
+            // SAFETY:
+            // - The bucket points to an aligned value of type (T, ()).
+            // - The value is initialized, as the bucket is full.
+            // - No other part of the program reads it, as the `SourceCleanup`
+            //   and `SourceDescriptor` contracts ensure that no other call to
+            //   `fetch_item()` nor `cleanup_item_range()` is made for this
+            //   index; and even though the bucket isn't marked as empty here,
+            //   the Drop implementation clears the table without dropping.
+            let (t, ()) = unsafe { bucket.read() };
+            Some(t)
+        } else {
+            None
+        }
+    }
+}
+
+impl<T: Send, A: Allocator> Drop for HashSetSourceDescriptor<T, A> {
+    fn drop(&mut self) {
+        // Paralight already dropped each missing* bucket via calls to cleanup_item_range(), so we
+        // can simply mark all buckets as cleared and let the RawTable destructor do the rest.
+        //
+        // *Some buckets may be missing because the iterator exited early (e.g. an item was found
+        // via the find_any() adaptor) or unexpectedly due to a panic (e.g. in the closure passed
+        // to the for_each() adaptor).
+        //
+        // TODO: Optimize this to simply deallocate without touching the control bytes.
+        self.table.inner.clear_no_drop();
+    }
+}
+
+// HashMap.into_par_iter()
+impl<K: Send, V: Send, S, A: Allocator + Send> IntoParallelSource for HashMap<K, V, S, A> {
+    type Item = (K, V);
+    type Source = HashMapParallelSource<K, V, S, A>;
+
+    fn into_par_iter(self) -> Self::Source {
+        HashMapParallelSource { hash_map: self }
+    }
+}
+
+#[must_use = "iterator adaptors are lazy"]
+pub struct HashMapParallelSource<K, V, S, A: Allocator> {
+    hash_map: HashMap<K, V, S, A>,
+}
+
+impl<K: Send, V: Send, S, A: Allocator + Send> ParallelSource
+    for HashMapParallelSource<K, V, S, A>
+{
+    type Item = (K, V);
+
+    fn descriptor(self) -> impl SourceDescriptor<Item = Self::Item> + Sync {
+        HashMapSourceDescriptor {
+            table: raw_table_wrapper::HashMap {
+                inner: self.hash_map.table,
+            },
+        }
+    }
+}
+
+struct HashMapSourceDescriptor<K: Send, V: Send, A: Allocator> {
+    table: raw_table_wrapper::HashMap<K, V, A>,
+}
+
+impl<K: Send, V: Send, A: Allocator> SourceCleanup for HashMapSourceDescriptor<K, V, A> {
+    const NEEDS_CLEANUP: bool = core::mem::needs_drop::<(K, V)>();
+
+    fn len(&self) -> usize {
+        self.table.inner.buckets()
+    }
+
+    unsafe fn cleanup_item_range(&self, range: core::ops::Range<usize>) {
+        if Self::NEEDS_CLEANUP {
+            debug_assert!(range.start <= range.end);
+            debug_assert!(range.start <= self.len());
+            debug_assert!(range.end <= self.len());
+            for index in range {
+                // SAFETY: The passed index is less than the number of buckets. This is
+                // ensured by the safety preconditions of `cleanup_item_range()`, given
+                // that `len()` returned the number of buckets, and is further confirmed
+                // by the debug assertions.
+                let full = unsafe { self.table.inner.is_bucket_full(index) };
+                if full {
+                    // SAFETY:
+                    // - The table is already allocated.
+                    // - The index is in bounds (see previous safety comment).
+                    // - The table contains elements of type (K, V).
+                    let bucket = unsafe { self.table.inner.bucket(index) };
+                    // SAFETY:
+                    // - The bucket points to an aligned value of type (K, V).
+                    // - The value is initialized, as the bucket is full.
+                    // - No other part of the program reads it, as the `SourceCleanup`
+                    //   and `SourceDescriptor` contracts ensure that no other call to
+                    //   `fetch_item()` nor `cleanup_item_range()` is made for this
+                    //   index; and even though the bucket isn't marked as empty here,
+                    //   the Drop implementation clears the table without dropping.
+                    let key_value = unsafe { bucket.read() };
+                    drop(key_value);
+                }
+            }
+        }
+    }
+}
+
+impl<K: Send, V: Send, A: Allocator> SourceDescriptor for HashMapSourceDescriptor<K, V, A> {
+    type Item = (K, V);
+
+    unsafe fn fetch_item(&self, index: usize) -> Option<Self::Item> {
+        debug_assert!(index < self.len());
+        // SAFETY: The passed index is less than the number of buckets. This is
+        // ensured by the safety preconditions of `fetch_item()`, given that
+        // `len()` returned the number of buckets, and is further confirmed by
+        // the debug assertion.
+        let full = unsafe { self.table.inner.is_bucket_full(index) };
+        if full {
+            // SAFETY:
+            // - The table is already allocated.
+            // - The index is in bounds (see previous safety comment).
+            // - The table contains elements of type (K, V).
+            let bucket = unsafe { self.table.inner.bucket(index) };
+            // SAFETY:
+            // - The bucket points to an aligned value of type (K, V).
+            // - The value is initialized, as the bucket is full.
+            // - No other part of the program reads it, as the `SourceCleanup`
+            //   and `SourceDescriptor` contracts ensure that no other call to
+            //   `fetch_item()` nor `cleanup_item_range()` is made for this
+            //   index; and even though the bucket isn't marked as empty here,
+            //   the Drop implementation clears the table without dropping.
+            unsafe { Some(bucket.read()) }
+        } else {
+            None
+        }
+    }
+}
+
+impl<K: Send, V: Send, A: Allocator> Drop for HashMapSourceDescriptor<K, V, A> {
+    fn drop(&mut self) {
+        // Paralight already dropped each missing* bucket via calls to cleanup_item_range(), so we
+        // can simply mark all buckets as cleared and let the RawTable destructor do the rest.
+        //
+        // *Some buckets may be missing because the iterator exited early (e.g. an item was found
+        // via the find_any() adaptor) or unexpectedly due to a panic (e.g. in the closure passed
+        // to the for_each() adaptor).
+        //
+        // TODO: Optimize this to simply deallocate without touching the control bytes.
+        self.table.inner.clear_no_drop();
+    }
+}
+
+mod raw_table_wrapper {
+    use crate::raw::{Allocator, RawTable};
+
+    pub(super) struct HashSet<T, A: Allocator> {
+        pub(super) inner: RawTable<(T, ()), A>,
+    }
+
+    // TODO: Does the Allocator need to be Sync too?
+    unsafe impl<T: Send, A: Allocator + Send> Sync for HashSet<T, A> {}
+
+    pub(super) struct HashMap<K, V, A: Allocator> {
+        pub(super) inner: RawTable<(K, V), A>,
+    }
+
+    // TODO: Does the Allocator need to be Sync too?
+    unsafe impl<K: Send, V: Send, A: Allocator + Send> Sync for HashMap<K, V, A> {}
+
+    pub(super) struct HashMapRefMut<'data, K, V, A: Allocator> {
+        pub(super) inner: &'data RawTable<(K, V), A>,
+    }
+
+    // TODO: Does the Allocator need to be Sync too?
+    unsafe impl<'data, K: Sync, V: Send, A: Allocator + Sync> Sync for HashMapRefMut<'data, K, V, A> {}
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use alloc::boxed::Box;
+    use core::cell::Cell;
+    use core::ops::Deref;
+    use paralight::iter::{ParallelIteratorExt, ParallelSourceExt};
+    use paralight::threads::{CpuPinningPolicy, RangeStrategy, ThreadCount, ThreadPoolBuilder};
+    use std::hash::{Hash, Hasher};
+
+    // A cell that implements Hash.
+    #[derive(PartialEq, Eq)]
+    struct HashCell<T>(Cell<T>);
+
+    impl<T: Copy> HashCell<T> {
+        fn new(t: T) -> Self {
+            Self(Cell::new(t))
+        }
+
+        fn get(&self) -> T {
+            self.0.get()
+        }
+    }
+
+    impl<T: Copy + Hash> Hash for HashCell<T> {
+        fn hash<H: Hasher>(&self, state: &mut H) {
+            self.0.get().hash(state)
+        }
+    }
+
+    #[test]
+    fn test_set_par_iter() {
+        let mut thread_pool = ThreadPoolBuilder {
+            num_threads: ThreadCount::AvailableParallelism,
+            range_strategy: RangeStrategy::WorkStealing,
+            cpu_pinning: CpuPinningPolicy::No,
+        }
+        .build();
+
+        let mut set = HashSet::new();
+        for i in 1..=42 {
+            set.insert(Box::new(i));
+        }
+
+        let sum = set
+            .par_iter()
+            .with_thread_pool(&mut thread_pool)
+            .map(|x| x.deref())
+            .sum::<i32>();
+        assert_eq!(sum, 21 * 43);
+    }
+
+    #[test]
+    fn test_set_into_par_iter() {
+        let mut thread_pool = ThreadPoolBuilder {
+            num_threads: ThreadCount::AvailableParallelism,
+            range_strategy: RangeStrategy::WorkStealing,
+            cpu_pinning: CpuPinningPolicy::No,
+        }
+        .build();
+
+        let mut set = HashSet::new();
+        for i in 1..=42 {
+            set.insert(Box::new(i));
+        }
+
+        let sum = set
+            .into_par_iter()
+            .with_thread_pool(&mut thread_pool)
+            .map(|x| *x)
+            .sum::<i32>();
+        assert_eq!(sum, 21 * 43);
+    }
+
+    #[test]
+    fn test_set_into_par_iter_send() {
+        let mut thread_pool = ThreadPoolBuilder {
+            num_threads: ThreadCount::AvailableParallelism,
+            range_strategy: RangeStrategy::WorkStealing,
+            cpu_pinning: CpuPinningPolicy::No,
+        }
+        .build();
+
+        let mut set = HashSet::new();
+        for i in 1..=42 {
+            set.insert(HashCell::new(i));
+        }
+
+        let sum = set
+            .into_par_iter()
+            .with_thread_pool(&mut thread_pool)
+            .map(|x| x.get())
+            .sum::<i32>();
+        assert_eq!(sum, 21 * 43);
+    }
+
+    #[test]
+    fn test_set_into_par_iter_find() {
+        let mut thread_pool = ThreadPoolBuilder {
+            num_threads: ThreadCount::AvailableParallelism,
+            range_strategy: RangeStrategy::WorkStealing,
+            cpu_pinning: CpuPinningPolicy::No,
+        }
+        .build();
+
+        let mut set = HashSet::new();
+        for i in 1..=42 {
+            set.insert(Box::new(i));
+        }
+
+        // The search will exit once an even number is found; this test checks
+        // (with Miri) that no memory leak happens as a result.
+        let any_even = set
+            .into_par_iter()
+            .with_thread_pool(&mut thread_pool)
+            .find_any(|x| **x % 2 == 0);
+        assert!(any_even.is_some());
+        assert_eq!(*any_even.unwrap() % 2, 0);
+    }
+
+    #[test]
+    fn test_map_par_iter() {
+        let mut thread_pool = ThreadPoolBuilder {
+            num_threads: ThreadCount::AvailableParallelism,
+            range_strategy: RangeStrategy::WorkStealing,
+            cpu_pinning: CpuPinningPolicy::No,
+        }
+        .build();
+
+        let mut map = HashMap::new();
+        for i in 1..=42 {
+            map.insert(Box::new(i), Box::new(i * i));
+        }
+
+        map.par_iter()
+            .with_thread_pool(&mut thread_pool)
+            .for_each(|(k, v)| assert_eq!(**k * **k, **v));
+    }
+
+    #[test]
+    fn test_map_par_iter_mut() {
+        let mut thread_pool = ThreadPoolBuilder {
+            num_threads: ThreadCount::AvailableParallelism,
+            range_strategy: RangeStrategy::WorkStealing,
+            cpu_pinning: CpuPinningPolicy::No,
+        }
+        .build();
+
+        let mut map = HashMap::new();
+        for i in 1..=42 {
+            map.insert(Box::new(i), Box::new(i));
+        }
+
+        map.par_iter_mut()
+            .with_thread_pool(&mut thread_pool)
+            .for_each(|(k, v)| **v *= **k);
+
+        for (k, v) in map.iter() {
+            assert_eq!(**k * **k, **v);
+        }
+    }
+
+    #[test]
+    fn test_map_par_iter_mut_send_sync() {
+        let mut thread_pool = ThreadPoolBuilder {
+            num_threads: ThreadCount::AvailableParallelism,
+            range_strategy: RangeStrategy::WorkStealing,
+            cpu_pinning: CpuPinningPolicy::No,
+        }
+        .build();
+
+        let mut map = HashMap::new();
+        for i in 1..=42 {
+            map.insert(Box::new(i), Cell::new(i));
+        }
+
+        map.par_iter_mut()
+            .with_thread_pool(&mut thread_pool)
+            .for_each(|(k, v)| *v.get_mut() *= **k);
+
+        for (k, v) in map.iter() {
+            assert_eq!(**k * **k, v.get());
+        }
+    }
+
+    #[test]
+    fn test_map_into_par_iter() {
+        let mut thread_pool = ThreadPoolBuilder {
+            num_threads: ThreadCount::AvailableParallelism,
+            range_strategy: RangeStrategy::WorkStealing,
+            cpu_pinning: CpuPinningPolicy::No,
+        }
+        .build();
+
+        let mut map = HashMap::new();
+        for i in 1..=42 {
+            map.insert(Box::new(i), Box::new(i * i));
+        }
+
+        map.into_par_iter()
+            .with_thread_pool(&mut thread_pool)
+            .for_each(|(k, v)| assert_eq!(*k * *k, *v));
+    }
+
+    #[test]
+    fn test_map_into_par_iter_send() {
+        let mut thread_pool = ThreadPoolBuilder {
+            num_threads: ThreadCount::AvailableParallelism,
+            range_strategy: RangeStrategy::WorkStealing,
+            cpu_pinning: CpuPinningPolicy::No,
+        }
+        .build();
+
+        let mut map = HashMap::new();
+        for i in 1..=42 {
+            map.insert(HashCell::new(i), Cell::new(i * i));
+        }
+
+        map.into_par_iter()
+            .with_thread_pool(&mut thread_pool)
+            .for_each(|(k, v)| assert_eq!(k.get() * k.get(), v.get()));
+    }
+
+    #[test]
+    fn test_map_into_par_iter_find() {
+        let mut thread_pool = ThreadPoolBuilder {
+            num_threads: ThreadCount::AvailableParallelism,
+            range_strategy: RangeStrategy::WorkStealing,
+            cpu_pinning: CpuPinningPolicy::No,
+        }
+        .build();
+
+        let mut map = HashMap::new();
+        for i in 1..=42 {
+            map.insert(Box::new(i), Box::new(i * i));
+        }
+
+        // The search will exit once a match is found; this test checks (with
+        // Miri) that no memory leak happens as a result.
+        let needle = map
+            .into_par_iter()
+            .with_thread_pool(&mut thread_pool)
+            .find_any(|(k, v)| **k % 2 == 0 && **v % 3 == 0);
+        assert!(needle.is_some());
+        let (k, v) = needle.unwrap();
+        assert_eq!(*k % 2, 0);
+        assert_eq!(*v % 3, 0);
+    }
+}