Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pgdog-postgres-types/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{DataType, Datum};
#[derive(Debug, Clone)]
pub struct Array {
elements: Vec<Option<Datum>>,
element_oid: i32,
pub(crate) element_oid: i32,
dim: Dimension,
}

Expand Down
125 changes: 101 additions & 24 deletions pgdog-postgres-types/src/datum.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::Add;
use std::{fmt, mem};

use bytes::Bytes;
use pgdog_vector::{Float, Vector};
Expand Down Expand Up @@ -172,27 +172,6 @@ impl ToDataRowColumn for Datum {
}
}

impl Add for Datum {
type Output = Datum;

fn add(self, rhs: Self) -> Self::Output {
use Datum::*;

match (self, rhs) {
(Bigint(a), Bigint(b)) => Bigint(a + b),
(Integer(a), Integer(b)) => Integer(a + b),
(SmallInt(a), SmallInt(b)) => SmallInt(a + b),
(Interval(a), Interval(b)) => Interval(a + b),
(Numeric(a), Numeric(b)) => Numeric(a + b),
(Float(a), Float(b)) => Float(crate::Float(a.0 + b.0)),
(Double(a), Double(b)) => Double(crate::Double(a.0 + b.0)),
(Datum::Null, b) => b,
(a, Datum::Null) => a,
_ => Datum::Null, // Might be good to raise an error.
}
}
}

impl Datum {
pub fn new(
bytes: &[u8],
Expand Down Expand Up @@ -254,6 +233,63 @@ impl Datum {
Datum::Unknown(bytes) => Ok(bytes.clone()),
}
}

fn data_type(&self) -> DataType {
match self {
Datum::Bigint(..) => DataType::Bigint,
Datum::Integer(..) => DataType::Integer,
Datum::Uuid(..) => DataType::Uuid,
Datum::Text(..) => DataType::Text,
Datum::Boolean(..) => DataType::Bool,
Datum::Float(..) => DataType::Real,
Datum::Double(..) => DataType::DoublePrecision,
Datum::Numeric(..) => DataType::Numeric,
Datum::Timestamp(..) => DataType::Timestamp,
Datum::TimestampTz(..) => DataType::TimestampTz,
Datum::SmallInt(..) => DataType::SmallInt,
Datum::Interval(..) => DataType::Interval,
Datum::Vector(..) => DataType::Vector,
Datum::Oid(..) => DataType::Oid,
Datum::Array(a) => DataType::Array(a.element_oid),
Datum::Null => DataType::Other(0),
Datum::Unknown(..) => DataType::Other(0),
}
}

/// Adds rhs to self. Returns an error if self + rhs are not the same type,
/// or if self is a type that cannot be added.
///
/// The behavior of this method diverges from postgres when handling NULL.
/// When calculating x + NULL, we will return x, while postgres will return
/// NULL
pub fn checked_add_assign(&mut self, rhs: Self) -> Result<(), Error> {
use Datum::*;

match (self, rhs) {
(Bigint(a), Bigint(b)) => *a += b,
(Integer(a), Integer(b)) => *a += b,
(SmallInt(a), SmallInt(b)) => *a += b,
(Interval(a), Interval(b)) => *a += b,
(Numeric(a), Numeric(b)) => *a += b,
(Float(a), Float(b)) => a.0 += b.0,
(Double(a), Double(b)) => a.0 += b.0,
// FIXME(sage): We should probably mimic PG in this general method,
// and expect the caller to filter nulls if that's what they want
(a @ Datum::Null, b) => *a = b,
(_, Datum::Null) => {}
(a, b) if mem::discriminant(a) != mem::discriminant(&b) => {
return Err(Error::IncompatibleTypes(a.data_type(), b.data_type()));
}
(a, _) => {
return Err(Error::InvalidOperation {
op: "add",
ty: a.data_type(),
});
}
}

Ok(())
}
}

/// PostgreSQL data types.
Expand Down Expand Up @@ -317,17 +353,43 @@ impl DataType {
}
}

impl fmt::Display for DataType {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
use DataType::*;
match self {
Bigint => write!(f, "bigint"),
Integer => write!(f, "integer"),
Text => write!(f, "text"),
Interval => write!(f, "interval"),
Timestamp => write!(f, "timestamp"),
TimestampTz => write!(f, "timestamptz"),
Real => write!(f, "real"),
DoublePrecision => write!(f, "double precision"),
Bool => write!(f, "boolean"),
SmallInt => write!(f, "smallint"),
TinyInt => write!(f, "tinyint"),
Numeric => write!(f, "numeric"),
Other(i) => write!(f, "unknown type {i}"),
Uuid => write!(f, "uuid"),
Oid => write!(f, "oid"),
Vector => write!(f, "vector"),
Array(i) => write!(f, "{}[]", Self::from_oid(*i)),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use bytes::{BufMut, BytesMut};
use std::assert_matches;

#[test]
fn test_multidimensional_text_array_falls_back_to_unknown() {
let input = b"{{1,2},{3,4}}";
let datum = Datum::new(input, DataType::Array(23), Format::Text, false).unwrap();

assert!(matches!(datum, Datum::Unknown(_)));
assert_matches!(datum, Datum::Unknown(_));
assert_eq!(
datum.encode(Format::Text).unwrap(),
Bytes::from_static(input)
Expand All @@ -353,7 +415,22 @@ mod tests {
let input = buf.freeze();
let datum = Datum::new(&input, DataType::Array(23), Format::Binary, false).unwrap();

assert!(matches!(datum, Datum::Unknown(_)));
assert_matches!(datum, Datum::Unknown(_));
assert_eq!(datum.encode(Format::Binary).unwrap(), input);
}

#[test]
fn test_adding_types_which_cannot_be_added() {
let mut datum = Datum::Text("hello".to_owned());
// operator does not exist: text + text
let result = datum.checked_add_assign(Datum::Text("goodbye".to_owned()));
assert_matches!(result, Err(Error::InvalidOperation { .. }));
}

#[test]
fn test_adding_incompatible_types() {
let mut datum = Datum::Integer(1);
let result = datum.checked_add_assign(Datum::Text("1".to_owned()));
assert_matches!(result, Err(Error::IncompatibleTypes(..)));
}
}
7 changes: 7 additions & 0 deletions pgdog-postgres-types/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Network errors.

use crate::datum::DataType;
use std::array::TryFromSliceError;

use thiserror::Error;
Expand Down Expand Up @@ -47,4 +48,10 @@ pub enum Error {

#[error("lsn decode error")]
LsnDecode,

#[error("expected {0}, got {1}")]
IncompatibleTypes(DataType, DataType),

#[error("invalid operation {op} for {ty}")]
InvalidOperation { op: &'static str, ty: DataType },
}
10 changes: 8 additions & 2 deletions pgdog-postgres-types/src/interval.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::{num::ParseIntError, ops::Add};
use std::{num::ParseIntError, ops::Add, ops::AddAssign};

use crate::Data;

use super::*;
use bytes::{Buf, BufMut, Bytes, BytesMut};

#[derive(Eq, PartialEq, Ord, PartialOrd, Default, Debug, Clone, Hash)]
#[derive(Eq, PartialEq, Ord, PartialOrd, Default, Debug, Clone, Copy, Hash)]
pub struct Interval {
years: i64,
months: i32,
Expand Down Expand Up @@ -33,6 +33,12 @@ impl Add for Interval {
}
}

impl AddAssign for Interval {
fn add_assign(&mut self, rhs: Interval) {
*self = *self + rhs
}
}

impl ToDataRowColumn for Interval {
fn to_data_row_column(&self) -> Data {
self.encode(Format::Text).unwrap().into()
Expand Down
14 changes: 13 additions & 1 deletion pgdog-postgres-types/src/numeric.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
use std::{cmp::Ordering, fmt::Display, hash::Hash, ops::Add, str::FromStr};
use std::{
cmp::Ordering,
fmt::Display,
hash::Hash,
ops::{Add, AddAssign},
str::FromStr,
};

use bytes::{Buf, BufMut, Bytes, BytesMut};
use postgres_types::{FromSql, ToSql, Type};
Expand Down Expand Up @@ -95,6 +101,12 @@ impl Add for Numeric {
}
}

impl AddAssign for Numeric {
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}

impl FromDataType for Numeric {
fn decode(bytes: &[u8], encoding: Format) -> Result<Self, Error> {
match encoding {
Expand Down
14 changes: 7 additions & 7 deletions pgdog/src/backend/pool/connection/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl<'a> Accumulator<'a> {
match self.target.function() {
AggregateFunction::Count => {
if !self.datum.is_null() {
self.datum = self.datum.clone() + column.value;
self.datum.checked_add_assign(column.value)?;
} else {
self.datum = column.value;
}
Expand All @@ -127,7 +127,7 @@ impl<'a> Accumulator<'a> {
}
AggregateFunction::Sum => {
if !self.datum.is_null() {
self.datum = self.datum.clone() + column.value;
self.datum.checked_add_assign(column.value)?;
} else {
self.datum = column.value;
}
Expand All @@ -150,8 +150,8 @@ impl<'a> Accumulator<'a> {
}

if let Some(weighted) = multiply_for_average(&column.value, &count.value) {
state.weighted_sum = state.weighted_sum.clone() + weighted;
state.total_count = state.total_count.clone() + count.value.clone();
state.weighted_sum.checked_add_assign(weighted)?;
state.total_count.checked_add_assign(count.value.clone())?;
} else {
state.supported = false;
return Ok(false);
Expand Down Expand Up @@ -300,7 +300,7 @@ impl VarianceState {
return Ok(true);
}

self.total_count = self.total_count.clone() + count.value.clone();
self.total_count.checked_add_assign(count.value.clone())?;

let Some(sum_column) = self.sum_column else {
self.supported = false;
Expand All @@ -311,7 +311,7 @@ impl VarianceState {
self.supported = false;
return Ok(false);
}
self.total_sum = self.total_sum.clone() + sum.value.clone();
self.total_sum.checked_add_assign(sum.value.clone())?;

let Some(sumsq_column) = self.sumsq_column else {
self.supported = false;
Expand All @@ -322,7 +322,7 @@ impl VarianceState {
self.supported = false;
return Ok(false);
}
self.total_sumsq = self.total_sumsq.clone() + sumsq.value.clone();
self.total_sumsq.checked_add_assign(sumsq.value.clone())?;

Ok(true)
}
Expand Down
Loading