diff --git a/src/compute/src/render/top_k.rs b/src/compute/src/render/top_k.rs index 301524da937c8..925de47e72aec 100644 --- a/src/compute/src/render/top_k.rs +++ b/src/compute/src/render/top_k.rs @@ -11,7 +11,9 @@ //! //! Consult [TopKPlan] documentation for details. +use std::cell::RefCell; use std::collections::HashMap; +use std::rc::Rc; use differential_dataflow::hashable::Hashable; use differential_dataflow::lattice::Lattice; @@ -28,7 +30,7 @@ use timely::dataflow::Scope; use mz_compute_client::plan::top_k::{ BasicTopKPlan, MonotonicTop1Plan, MonotonicTopKPlan, TopKPlan, }; -use mz_repr::{Diff, Row}; +use mz_repr::{DatumVec, Diff, Row}; use crate::render::context::CollectionBundle; use crate::render::context::Context; @@ -367,6 +369,11 @@ where { let mut aggregates = HashMap::new(); let mut vector = Vec::new(); + let shared = Rc::new(RefCell::new(monoids::Top1MonoidShared { + order_key, + left: DatumVec::new(), + right: DatumVec::new(), + })); collection .inner .unary_notify( @@ -380,9 +387,9 @@ where .entry(time.time().clone()) .or_insert_with(HashMap::new); for ((grp_row, row), record_time, diff) in vector.drain(..) { - let monoid = monoids::Top1Monoid { + let monoid = monoids::Top1MonoidLocal { row, - order_key: order_key.clone(), + shared: Rc::clone(&shared), }; let topk = agg_time.entry((grp_row, record_time)).or_insert_with( move || { @@ -401,7 +408,11 @@ where let mut session = output.session(&time); for ((grp_row, record_time), topk) in aggs { session.give_iterator(topk.into_iter().map(|(monoid, diff)| { - ((grp_row.clone(), monoid.row), record_time.clone(), diff) + ( + (grp_row.clone(), monoid.into_row()), + record_time.clone(), + diff, + ) })) } } @@ -508,13 +519,16 @@ pub mod topk_agg { /// Monoids for in-place compaction of monotonic streams. 
pub mod monoids { + use std::cell::RefCell; use std::cmp::Ordering; + use std::hash::{Hash, Hasher}; + use std::rc::Rc; use differential_dataflow::difference::Semigroup; use serde::{Deserialize, Serialize}; use mz_expr::ColumnOrder; - use mz_repr::Row; + use mz_repr::{DatumVec, Row}; /// A monoid containing a row and an ordering. #[derive(Eq, PartialEq, Debug, Clone, Serialize, Deserialize, Hash)] @@ -553,4 +567,75 @@ pub mod monoids { false } } + + /// A shared portion of a thread-local top-1 monoid implementation. + #[derive(Debug)] + pub struct Top1MonoidShared { + pub order_key: Vec<ColumnOrder>, + pub left: DatumVec, + pub right: DatumVec, + } + + /// A monoid containing a row and a shared pointer to a shared structure. + /// Only suitable for thread-local aggregations. + #[derive(Debug, Clone)] + pub struct Top1MonoidLocal { + pub row: Row, + pub shared: Rc<RefCell<Top1MonoidShared>>, + } + + impl Top1MonoidLocal { + pub fn into_row(self) -> Row { + self.row + } + } + + impl PartialEq for Top1MonoidLocal { + fn eq(&self, other: &Self) -> bool { + self.row.eq(&other.row) + } + } + + impl Eq for Top1MonoidLocal {} + + impl Hash for Top1MonoidLocal { + fn hash<H: Hasher>(&self, state: &mut H) { + self.row.hash(state); + } + } + + impl Ord for Top1MonoidLocal { + fn cmp(&self, other: &Self) -> Ordering { + debug_assert!(Rc::ptr_eq(&self.shared, &other.shared)); + let Top1MonoidShared { + left, + right, + order_key, + } = &mut *self.shared.borrow_mut(); + + let left = left.borrow_with(&self.row); + let right = right.borrow_with(&other.row); + mz_expr::compare_columns(order_key, &left, &right, || left.cmp(&right)) + } + } + + impl PartialOrd for Top1MonoidLocal { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + Some(self.cmp(other)) + } + } + + impl Semigroup for Top1MonoidLocal { + fn plus_equals(&mut self, rhs: &Self) { + let cmp = (*self).cmp(rhs); + // NB: Reminder that TopK returns the _minimum_ K items. 
+ if cmp == Ordering::Greater { + self.clone_from(rhs); + } + } + + fn is_zero(&self) -> bool { + false + } + } }