diff --git a/pgdog-postgres-types/src/array.rs b/pgdog-postgres-types/src/array.rs index 6c50d5573..f056f0184 100644 --- a/pgdog-postgres-types/src/array.rs +++ b/pgdog-postgres-types/src/array.rs @@ -8,7 +8,7 @@ use crate::{DataType, Datum}; #[derive(Debug, Clone)] pub struct Array { elements: Vec>, - element_oid: i32, + pub(crate) element_oid: i32, dim: Dimension, } diff --git a/pgdog-postgres-types/src/datum.rs b/pgdog-postgres-types/src/datum.rs index 2304596a8..9e04e75fa 100644 --- a/pgdog-postgres-types/src/datum.rs +++ b/pgdog-postgres-types/src/datum.rs @@ -1,4 +1,4 @@ -use std::ops::Add; +use std::{fmt, mem}; use bytes::Bytes; use pgdog_vector::{Float, Vector}; @@ -172,27 +172,6 @@ impl ToDataRowColumn for Datum { } } -impl Add for Datum { - type Output = Datum; - - fn add(self, rhs: Self) -> Self::Output { - use Datum::*; - - match (self, rhs) { - (Bigint(a), Bigint(b)) => Bigint(a + b), - (Integer(a), Integer(b)) => Integer(a + b), - (SmallInt(a), SmallInt(b)) => SmallInt(a + b), - (Interval(a), Interval(b)) => Interval(a + b), - (Numeric(a), Numeric(b)) => Numeric(a + b), - (Float(a), Float(b)) => Float(crate::Float(a.0 + b.0)), - (Double(a), Double(b)) => Double(crate::Double(a.0 + b.0)), - (Datum::Null, b) => b, - (a, Datum::Null) => a, - _ => Datum::Null, // Might be good to raise an error. - } - } -} - impl Datum { pub fn new( bytes: &[u8], @@ -254,6 +233,63 @@ impl Datum { Datum::Unknown(bytes) => Ok(bytes.clone()), } } + + fn data_type(&self) -> DataType { + match self { + Datum::Bigint(..) => DataType::Bigint, + Datum::Integer(..) => DataType::Integer, + Datum::Uuid(..) => DataType::Uuid, + Datum::Text(..) => DataType::Text, + Datum::Boolean(..) => DataType::Bool, + Datum::Float(..) => DataType::Real, + Datum::Double(..) => DataType::DoublePrecision, + Datum::Numeric(..) => DataType::Numeric, + Datum::Timestamp(..) => DataType::Timestamp, + Datum::TimestampTz(..) => DataType::TimestampTz, + Datum::SmallInt(..) => DataType::SmallInt, + Datum::Interval(..) => DataType::Interval, + Datum::Vector(..) => DataType::Vector, + Datum::Oid(..) => DataType::Oid, + Datum::Array(a) => DataType::Array(a.element_oid), + Datum::Null => DataType::Other(0), + Datum::Unknown(..) => DataType::Other(0), + } + } + + /// Adds rhs to self. Returns an error if self + rhs are not the same type, + /// or if self is a type that cannot be added. + /// + /// The behavior of this method diverges from postgres when handling NULL. + /// When calculating x + NULL, we will return x, while postgres will return + /// NULL + pub fn checked_add_assign(&mut self, rhs: Self) -> Result<(), Error> { + use Datum::*; + + match (self, rhs) { + (Bigint(a), Bigint(b)) => *a += b, + (Integer(a), Integer(b)) => *a += b, + (SmallInt(a), SmallInt(b)) => *a += b, + (Interval(a), Interval(b)) => *a += b, + (Numeric(a), Numeric(b)) => *a += b, + (Float(a), Float(b)) => a.0 += b.0, + (Double(a), Double(b)) => a.0 += b.0, + // FIXME(sage): We should probably mimic PG in this general method, + // and expect the caller to filter nulls if that's what they want + (a @ Datum::Null, b) => *a = b, + (_, Datum::Null) => {} + (a, b) if mem::discriminant(a) != mem::discriminant(&b) => { + return Err(Error::IncompatibleTypes(a.data_type(), b.data_type())); + } + (a, _) => { + return Err(Error::InvalidOperation { + op: "add", + ty: a.data_type(), + }); + } + } + + Ok(()) + } } /// PostgreSQL data types. @@ -317,17 +353,43 @@ impl DataType { } } +impl fmt::Display for DataType { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + use DataType::*; + match self { + Bigint => write!(f, "bigint"), + Integer => write!(f, "integer"), + Text => write!(f, "text"), + Interval => write!(f, "interval"), + Timestamp => write!(f, "timestamp"), + TimestampTz => write!(f, "timestamptz"), + Real => write!(f, "real"), + DoublePrecision => write!(f, "double precision"), + Bool => write!(f, "boolean"), + SmallInt => write!(f, "smallint"), + TinyInt => write!(f, "tinyint"), + Numeric => write!(f, "numeric"), + Other(i) => write!(f, "unknown type {i}"), + Uuid => write!(f, "uuid"), + Oid => write!(f, "oid"), + Vector => write!(f, "vector"), + Array(i) => write!(f, "{}[]", Self::from_oid(*i)), + } + } +} + #[cfg(test)] mod tests { use super::*; use bytes::{BufMut, BytesMut}; + use std::assert_matches; #[test] fn test_multidimensional_text_array_falls_back_to_unknown() { let input = b"{{1,2},{3,4}}"; let datum = Datum::new(input, DataType::Array(23), Format::Text, false).unwrap(); - assert!(matches!(datum, Datum::Unknown(_))); + assert_matches!(datum, Datum::Unknown(_)); assert_eq!( datum.encode(Format::Text).unwrap(), Bytes::from_static(input) @@ -353,7 +415,22 @@ mod tests { let input = buf.freeze(); let datum = Datum::new(&input, DataType::Array(23), Format::Binary, false).unwrap(); - assert!(matches!(datum, Datum::Unknown(_))); + assert_matches!(datum, Datum::Unknown(_)); assert_eq!(datum.encode(Format::Binary).unwrap(), input); } + + #[test] + fn test_adding_types_which_cannot_be_added() { + let mut datum = Datum::Text("hello".to_owned()); + // operator does not exist: text + text + let result = datum.checked_add_assign(Datum::Text("goodbye".to_owned())); + assert_matches!(result, Err(Error::InvalidOperation { .. })); + } + + #[test] + fn test_adding_incompatible_types() { + let mut datum = Datum::Integer(1); + let result = datum.checked_add_assign(Datum::Text("1".to_owned())); + assert_matches!(result, Err(Error::IncompatibleTypes(..))); + } } diff --git a/pgdog-postgres-types/src/error.rs b/pgdog-postgres-types/src/error.rs index b8378467f..274ce1a1c 100644 --- a/pgdog-postgres-types/src/error.rs +++ b/pgdog-postgres-types/src/error.rs @@ -1,5 +1,6 @@ //! Network errors. +use crate::datum::DataType; use std::array::TryFromSliceError; use thiserror::Error; @@ -47,4 +48,10 @@ pub enum Error { #[error("lsn decode error")] LsnDecode, + + #[error("expected {0}, got {1}")] + IncompatibleTypes(DataType, DataType), + + #[error("invalid operation {op} for {ty}")] + InvalidOperation { op: &'static str, ty: DataType }, } diff --git a/pgdog-postgres-types/src/interval.rs b/pgdog-postgres-types/src/interval.rs index c4b3e5506..652e6e881 100644 --- a/pgdog-postgres-types/src/interval.rs +++ b/pgdog-postgres-types/src/interval.rs @@ -1,11 +1,11 @@ -use std::{num::ParseIntError, ops::Add}; +use std::{num::ParseIntError, ops::Add, ops::AddAssign}; use crate::Data; use super::*; use bytes::{Buf, BufMut, Bytes, BytesMut}; -#[derive(Eq, PartialEq, Ord, PartialOrd, Default, Debug, Clone, Hash)] +#[derive(Eq, PartialEq, Ord, PartialOrd, Default, Debug, Clone, Copy, Hash)] pub struct Interval { years: i64, months: i32, @@ -33,6 +33,12 @@ impl Add for Interval { } } +impl AddAssign for Interval { + fn add_assign(&mut self, rhs: Interval) { + *self = *self + rhs + } +} + impl ToDataRowColumn for Interval { fn to_data_row_column(&self) -> Data { self.encode(Format::Text).unwrap().into() diff --git a/pgdog-postgres-types/src/numeric.rs b/pgdog-postgres-types/src/numeric.rs index 5395f0346..19a192168 100644 --- a/pgdog-postgres-types/src/numeric.rs +++ b/pgdog-postgres-types/src/numeric.rs @@ -1,4 +1,10 @@ -use std::{cmp::Ordering, fmt::Display, hash::Hash, ops::Add, str::FromStr}; +use std::{ + cmp::Ordering, + fmt::Display, + hash::Hash, + ops::{Add, AddAssign}, + str::FromStr, +}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use postgres_types::{FromSql, ToSql, Type}; @@ -95,6 +101,12 @@ impl Add for Numeric { } } +impl AddAssign for Numeric { + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + impl FromDataType for Numeric { fn decode(bytes: &[u8], encoding: Format) -> Result { match encoding { diff --git a/pgdog/src/backend/pool/connection/aggregate.rs b/pgdog/src/backend/pool/connection/aggregate.rs index 1e070104f..adb086762 100644 --- a/pgdog/src/backend/pool/connection/aggregate.rs +++ b/pgdog/src/backend/pool/connection/aggregate.rs @@ -102,7 +102,7 @@ impl<'a> Accumulator<'a> { match self.target.function() { AggregateFunction::Count => { if !self.datum.is_null() { - self.datum = self.datum.clone() + column.value; + self.datum.checked_add_assign(column.value)?; } else { self.datum = column.value; } @@ -127,7 +127,7 @@ impl<'a> Accumulator<'a> { } AggregateFunction::Sum => { if !self.datum.is_null() { - self.datum = self.datum.clone() + column.value; + self.datum.checked_add_assign(column.value)?; } else { self.datum = column.value; } @@ -150,8 +150,8 @@ impl<'a> Accumulator<'a> { } if let Some(weighted) = multiply_for_average(&column.value, &count.value) { - state.weighted_sum = state.weighted_sum.clone() + weighted; - state.total_count = state.total_count.clone() + count.value.clone(); + state.weighted_sum.checked_add_assign(weighted)?; + state.total_count.checked_add_assign(count.value.clone())?; } else { state.supported = false; return Ok(false); @@ -300,7 +300,7 @@ impl VarianceState { return Ok(true); } - self.total_count = self.total_count.clone() + count.value.clone(); + self.total_count.checked_add_assign(count.value.clone())?; let Some(sum_column) = self.sum_column else { self.supported = false; @@ -311,7 +311,7 @@ impl VarianceState { self.supported = false; return Ok(false); } - self.total_sum = self.total_sum.clone() + sum.value.clone(); + self.total_sum.checked_add_assign(sum.value.clone())?; let Some(sumsq_column) = self.sumsq_column else { self.supported = false; @@ -322,7 +322,7 @@ impl VarianceState { self.supported = false; return Ok(false); } - self.total_sumsq = self.total_sumsq.clone() + sumsq.value.clone(); + self.total_sumsq.checked_add_assign(sumsq.value.clone())?; Ok(true) }