From 18adf69456b68f60673ac892074f60447b362250 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 29 Jan 2026 15:35:54 -0500 Subject: [PATCH 01/21] Extension Scalar Signed-off-by: Nicholas Gates --- Cargo.lock | 1 + vortex-dtype/src/extension/mod.rs | 3 +- vortex-scalar/Cargo.toml | 1 + vortex-scalar/src/datetime/date.rs | 66 ++++ vortex-scalar/src/datetime/mod.rs | 35 ++ vortex-scalar/src/datetime/time.rs | 68 ++++ vortex-scalar/src/datetime/timestamp.rs | 115 ++++++ vortex-scalar/src/extension.rs | 460 ---------------------- vortex-scalar/src/extension/mod.rs | 488 ++++++++++++++++++++++++ vortex-scalar/src/extension/vtable.rs | 41 ++ vortex-scalar/src/lib.rs | 6 +- vortex-scalar/src/scalar.rs | 35 +- 12 files changed, 852 insertions(+), 467 deletions(-) create mode 100644 vortex-scalar/src/datetime/date.rs create mode 100644 vortex-scalar/src/datetime/mod.rs create mode 100644 vortex-scalar/src/datetime/time.rs create mode 100644 vortex-scalar/src/datetime/timestamp.rs delete mode 100644 vortex-scalar/src/extension.rs create mode 100644 vortex-scalar/src/extension/mod.rs create mode 100644 vortex-scalar/src/extension/vtable.rs diff --git a/Cargo.lock b/Cargo.lock index 952184756aa..863543638f4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10960,6 +10960,7 @@ dependencies = [ "arrow-array 57.2.0", "bytes", "itertools 0.14.0", + "jiff", "num-traits", "paste", "prost 0.14.3", diff --git a/vortex-dtype/src/extension/mod.rs b/vortex-dtype/src/extension/mod.rs index d287120aa04..79de5de72a3 100644 --- a/vortex-dtype/src/extension/mod.rs +++ b/vortex-dtype/src/extension/mod.rs @@ -29,7 +29,7 @@ use crate::Nullability; pub type ExtID = ArcRef; /// An extension data type. -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct ExtDType(Arc>); // Convenience impls for zero-sized VTables @@ -262,6 +262,7 @@ trait ExtDTypeImpl: 'static + Send + Sync + private::Sealed { fn with_nullability(&self, nullability: Nullability) -> ExtDTypeRef; } +#[derive(Debug)] struct ExtDTypeAdapter { vtable: V, metadata: V::Metadata, diff --git a/vortex-scalar/Cargo.toml b/vortex-scalar/Cargo.toml index d69c03f8679..dd52ccbc0ec 100644 --- a/vortex-scalar/Cargo.toml +++ b/vortex-scalar/Cargo.toml @@ -21,6 +21,7 @@ arbitrary = { workspace = true, optional = true } arrow-array = { workspace = true } bytes = { workspace = true } itertools = { workspace = true } +jiff = { workspace = true } num-traits = { workspace = true } paste = { workspace = true } prost = { workspace = true } diff --git a/vortex-scalar/src/datetime/date.rs b/vortex-scalar/src/datetime/date.rs new file mode 100644 index 00000000000..2e4b9278734 --- /dev/null +++ b/vortex-scalar/src/datetime/date.rs @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use jiff::Span; +use vortex_dtype::DType; +use vortex_dtype::ExtDType; +use vortex_dtype::Nullability; +use vortex_dtype::PType; +use vortex_dtype::datetime::Date; +use vortex_dtype::datetime::TimeUnit; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_err; + +use crate::Scalar; +use crate::ScalarValue; +use crate::datetime::SpanExt; +use crate::extension::ExtScalarVTable; + +impl ExtScalarVTable for Date { + type Value = jiff::civil::Date; + + fn zero(&self, _metadata: &Self::Metadata) -> Self::Value { + jiff::civil::Date::new(1970, 1, 1).vortex_expect("failed to create epoch date") + } + + fn unpack(&self, dtype: &ExtDType, storage: &ScalarValue) -> VortexResult { + let v = storage + .as_pvalue()? + .vortex_expect("storage is non-null") + .cast::(); + let span = Span::from_unit_length(v, *dtype.metadata()); + let epoch = jiff::civil::Date::new(1970, 1, 1)?; + Ok(epoch.checked_add(span)?) + } + + fn pack( + &self, + metadata: &Self::Metadata, + value: Self::Value, + nullability: Nullability, + ) -> VortexResult { + let epoch = jiff::civil::Date::new(1970, 1, 1)?; + let span = value - epoch; + let length = span.get_unit_length(*metadata); + + match metadata { + TimeUnit::Milliseconds => Ok(Scalar::primitive(length, nullability)), + TimeUnit::Days => { + let length = + i32::try_from(length).map_err(|_| vortex_err!("date does not fit in i32"))?; + Ok(Scalar::primitive(length, nullability)) + } + _ => unreachable!("Date only supports Milliseconds and Days time units"), + } + } + + fn pack_null(&self, metadata: &Self::Metadata) -> VortexResult { + let ptype = match metadata { + TimeUnit::Milliseconds => PType::I64, + TimeUnit::Days => PType::I32, + _ => unreachable!("Date only supports Milliseconds and Days time units"), + }; + Ok(Scalar::null(DType::Primitive(ptype, Nullability::Nullable))) + } +} diff --git a/vortex-scalar/src/datetime/mod.rs b/vortex-scalar/src/datetime/mod.rs new file mode 100644 index 00000000000..9d811a4c3ed --- /dev/null +++ b/vortex-scalar/src/datetime/mod.rs @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_dtype::datetime::TimeUnit; + +pub mod date; +pub mod time; +pub mod timestamp; + +trait SpanExt { + fn get_unit_length(&self, time_unit: TimeUnit) -> i64; + fn from_unit_length(length: i64, time_unit: TimeUnit) -> Self; +} + +impl SpanExt for jiff::Span { + fn get_unit_length(&self, time_unit: TimeUnit) -> i64 { + match time_unit { + TimeUnit::Nanoseconds => self.get_nanoseconds(), + TimeUnit::Microseconds => self.get_microseconds(), + TimeUnit::Milliseconds => self.get_milliseconds(), + TimeUnit::Seconds => self.get_seconds(), + TimeUnit::Days => self.get_days() as _, + } + } + + fn from_unit_length(length: i64, time_unit: TimeUnit) -> Self { + match time_unit { + TimeUnit::Nanoseconds => jiff::Span::new().nanoseconds(length), + TimeUnit::Microseconds => jiff::Span::new().microseconds(length), + TimeUnit::Milliseconds => jiff::Span::new().milliseconds(length), + TimeUnit::Seconds => jiff::Span::new().seconds(length), + TimeUnit::Days => jiff::Span::new().days(length), + } + } +} diff --git a/vortex-scalar/src/datetime/time.rs b/vortex-scalar/src/datetime/time.rs new file mode 100644 index 00000000000..751b111fcc7 --- /dev/null +++ b/vortex-scalar/src/datetime/time.rs @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use jiff::Span; +use vortex_dtype::DType; +use vortex_dtype::ExtDType; +use vortex_dtype::Nullability; +use vortex_dtype::PType; +use vortex_dtype::datetime::Time; +use vortex_dtype::datetime::TimeUnit; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_err; + +use crate::Scalar; +use crate::ScalarValue; +use crate::datetime::SpanExt; +use crate::extension::ExtScalarVTable; + +impl ExtScalarVTable for Time { + type Value = jiff::civil::Time; + + fn zero(&self, _metadata: &Self::Metadata) -> Self::Value { + jiff::civil::Time::MIN + } + + fn unpack(&self, dtype: &ExtDType, storage: &ScalarValue) -> VortexResult { + let v = storage + .as_pvalue()? + .vortex_expect("storage is non-null") + .cast::(); + let span = Span::from_unit_length(v, *dtype.metadata()); + let epoch = jiff::civil::Time::MIN; + Ok(epoch.checked_add(span)?) + } + + fn pack( + &self, + metadata: &Self::Metadata, + value: Self::Value, + nullability: Nullability, + ) -> VortexResult { + let epoch = jiff::civil::Time::MIN; + let span = value - epoch; + let length = span.get_unit_length(*metadata); + + Ok(match metadata { + TimeUnit::Nanoseconds | TimeUnit::Microseconds => { + Scalar::primitive(length, nullability) + } + TimeUnit::Milliseconds | TimeUnit::Seconds => { + let length = + i32::try_from(length).map_err(|_| vortex_err!("time does not fit in i32"))?; + Scalar::primitive(length, nullability) + } + TimeUnit::Days => unreachable!("TimeUnit::Days is not supported for Time types"), + }) + } + + fn pack_null(&self, metadata: &Self::Metadata) -> VortexResult { + let ptype = match metadata { + TimeUnit::Nanoseconds | TimeUnit::Microseconds => PType::I64, + TimeUnit::Milliseconds | TimeUnit::Seconds => PType::I32, + TimeUnit::Days => unreachable!("TimeUnit::Days is not supported for Time types"), + }; + Ok(Scalar::null(DType::Primitive(ptype, Nullability::Nullable))) + } +} diff --git a/vortex-scalar/src/datetime/timestamp.rs b/vortex-scalar/src/datetime/timestamp.rs new file mode 100644 index 00000000000..42a34477dfc --- /dev/null +++ b/vortex-scalar/src/datetime/timestamp.rs @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::fmt::Display; +use std::fmt::Formatter; + +use jiff::Span; +use vortex_dtype::DType; +use vortex_dtype::ExtDType; +use vortex_dtype::Nullability; +use vortex_dtype::PType; +use vortex_dtype::datetime::Timestamp; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; + +use crate::Scalar; +use crate::ScalarValue; +use crate::datetime::SpanExt; +use crate::extension::ExtScalarVTable; + +#[derive(Clone, Debug, Hash)] +pub enum TimestampValue { + Zoned(jiff::Zoned), + Unzoned(jiff::civil::DateTime), +} + +impl PartialEq for TimestampValue { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (TimestampValue::Zoned(a), TimestampValue::Zoned(b)) => a == b, + (TimestampValue::Unzoned(a), TimestampValue::Unzoned(b)) => a == b, + _ => false, + } + } +} + +impl Display for TimestampValue { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + TimestampValue::Zoned(z) => write!(f, "{}", z), + TimestampValue::Unzoned(dt) => write!(f, "{}", dt), + } + } +} + +impl ExtScalarVTable for Timestamp { + type Value = TimestampValue; + + fn zero(&self, metadata: &Self::Metadata) -> Self::Value { + match &metadata.tz { + None => { + let epoch = jiff::civil::DateTime::new(1970, 1, 1, 0, 0, 0, 0) + .vortex_expect("failed to create epoch datetime"); + TimestampValue::Unzoned(epoch) + } + Some(tz) => { + let epoch = jiff::Timestamp::UNIX_EPOCH; + TimestampValue::Zoned( + epoch + .in_tz(tz.as_ref()) + .vortex_expect("failed to create zoned epoch"), + ) + } + } + } + + fn unpack(&self, dtype: &ExtDType, storage: &ScalarValue) -> VortexResult { + let v = storage + .as_pvalue()? + .vortex_expect("storage is non-null") + .cast::(); + + Ok(match &dtype.metadata().tz { + None => { + let epoch = jiff::civil::DateTime::new(1970, 1, 1, 0, 0, 0, 0)?; + let span = Span::from_unit_length(v, dtype.metadata().unit); + TimestampValue::Unzoned(epoch.checked_add(span)?) + } + Some(tz) => { + let epoch = jiff::Timestamp::UNIX_EPOCH; + let span = Span::from_unit_length(v, dtype.metadata().unit); + TimestampValue::Zoned(epoch.checked_add(span)?.in_tz(tz.as_ref())?) + } + }) + } + + fn pack( + &self, + metadata: &Self::Metadata, + value: Self::Value, + nullability: Nullability, + ) -> VortexResult { + match value { + TimestampValue::Zoned(zoned) => { + let epoch = jiff::Timestamp::UNIX_EPOCH; + let span = zoned.timestamp() - epoch; + let length = span.get_unit_length(metadata.unit); + Ok(Scalar::primitive(length, nullability)) + } + TimestampValue::Unzoned(datetime) => { + let epoch = jiff::civil::DateTime::new(1970, 1, 1, 0, 0, 0, 0)?; + let span = datetime - epoch; + let length = span.get_unit_length(metadata.unit); + Ok(Scalar::primitive(length, nullability)) + } + } + } + + fn pack_null(&self, _metadata: &Self::Metadata) -> VortexResult { + Ok(Scalar::null(DType::Primitive( + PType::I64, + Nullability::Nullable, + ))) + } +} diff --git a/vortex-scalar/src/extension.rs b/vortex-scalar/src/extension.rs deleted file mode 100644 index d64ef9f22c4..00000000000 --- a/vortex-scalar/src/extension.rs +++ /dev/null @@ -1,460 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::fmt::Display; -use std::fmt::Formatter; -use std::hash::Hash; - -use vortex_dtype::DType; -use vortex_dtype::ExtDType; -use vortex_dtype::datetime::AnyTemporal; -use vortex_dtype::extension::ExtDTypeRef; -use vortex_dtype::extension::ExtDTypeVTable; -use vortex_error::VortexError; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; - -use crate::Scalar; -use crate::ScalarValue; - -/// A scalar value representing an extension type. -/// -/// Extension types allow wrapping a storage type with custom semantics. -#[derive(Debug, Clone)] -pub struct ExtScalar<'a> { - ext_dtype: &'a ExtDTypeRef, - value: &'a ScalarValue, -} - -impl Display for ExtScalar<'_> { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - // Specialized handling for date/time/timestamp builtin extension types. - if let Some(temporal) = self.ext_dtype.metadata_opt::() { - let maybe_timestamp = self - .storage() - .as_primitive() - .as_::() - .map(|maybe_timestamp| temporal.to_jiff(maybe_timestamp)) - .transpose() - .map_err(|_| std::fmt::Error)?; - - match maybe_timestamp { - None => write!(f, "null"), - Some(v) => write!(f, "{v}"), - } - } else { - write!(f, "{}({})", self.ext_dtype().id(), self.storage()) - } - } -} - -impl PartialEq for ExtScalar<'_> { - fn eq(&self, other: &Self) -> bool { - self.ext_dtype.eq_ignore_nullability(other.ext_dtype) && self.storage() == other.storage() - } -} - -impl Eq for ExtScalar<'_> {} - -// Ord is not implemented since it's undefined for different Extension DTypes -impl PartialOrd for ExtScalar<'_> { - fn partial_cmp(&self, other: &Self) -> Option { - if !self.ext_dtype.eq_ignore_nullability(other.ext_dtype) { - return None; - } - self.storage().partial_cmp(&other.storage()) - } -} - -impl Hash for ExtScalar<'_> { - fn hash(&self, state: &mut H) { - self.ext_dtype.hash(state); - self.storage().hash(state); - } -} - -impl<'a> ExtScalar<'a> { - /// Creates a new extension scalar from a data type and scalar value. - /// - /// # Errors - /// - /// Returns an error if the data type is not an extension type. - pub fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult { - let DType::Extension(ext_dtype) = dtype else { - vortex_bail!("Expected extension scalar, found {}", dtype) - }; - - Ok(Self { ext_dtype, value }) - } - - /// Returns the storage scalar of the extension scalar. - pub fn storage(&self) -> Scalar { - Scalar::new(self.ext_dtype.storage_dtype().clone(), self.value.clone()) - } - - /// Returns the extension data type. - pub fn ext_dtype(&self) -> &'a ExtDTypeRef { - self.ext_dtype - } - - pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { - if self.value.is_null() && !dtype.is_nullable() { - vortex_bail!( - "cannot cast extension dtype with id {} and storage type {} to {}", - self.ext_dtype.id(), - self.ext_dtype.storage_dtype(), - dtype - ); - } - - if self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) { - // Casting from an extension type to the underlying storage type is OK. - return Ok(Scalar::new(dtype.clone(), self.value.clone())); - } - - if let DType::Extension(ext_dtype) = dtype - && self.ext_dtype.eq_ignore_nullability(ext_dtype) - { - return Ok(Scalar::new(dtype.clone(), self.value.clone())); - } - - vortex_bail!( - "cannot cast extension dtype with id {} and storage type {} to {}", - self.ext_dtype.id(), - self.ext_dtype.storage_dtype(), - dtype - ); - } -} - -impl<'a> TryFrom<&'a Scalar> for ExtScalar<'a> { - type Error = VortexError; - - fn try_from(scalar: &'a Scalar) -> Result { - ExtScalar::try_new(scalar.dtype(), scalar.value()) - } -} - -impl Scalar { - /// Creates a new extension scalar wrapping the given storage value. - pub fn extension(options: V::Metadata, value: Scalar) -> Self { - let ext_dtype = ExtDType::::try_new(options, value.dtype().clone()) - .vortex_expect("Failed to create extension dtype"); - Self::new(DType::Extension(ext_dtype.erased()), value.value().clone()) - } - - /// Creates a new extension scalar wrapping the given storage value. - pub fn extension_ref(ext_dtype: ExtDTypeRef, value: Scalar) -> Self { - assert_eq!(ext_dtype.storage_dtype(), value.dtype()); - Self::new(DType::Extension(ext_dtype), value.value().clone()) - } -} - -#[cfg(test)] -mod tests { - use vortex_dtype::DType; - use vortex_dtype::ExtDType; - use vortex_dtype::ExtID; - use vortex_dtype::Nullability; - use vortex_dtype::PType; - use vortex_dtype::extension::EmptyMetadata; - use vortex_dtype::extension::ExtDTypeVTable; - use vortex_error::VortexResult; - - use crate::ExtScalar; - use crate::InnerScalarValue; - use crate::Scalar; - use crate::ScalarValue; - - #[derive(Debug, Clone, Default)] - struct TestExt; - impl ExtDTypeVTable for TestExt { - type Metadata = EmptyMetadata; - - fn id(&self) -> ExtID { - ExtID::new_ref("test_ext") - } - - fn validate(&self, _options: &Self::Metadata, _storage_dtype: &DType) -> VortexResult<()> { - Ok(()) - } - } - - impl TestExt { - fn new_non_nullable() -> ExtDType { - ExtDType::try_new( - EmptyMetadata, - DType::Primitive(PType::I32, Nullability::NonNullable), - ) - .unwrap() - } - } - - #[test] - fn test_ext_scalar_equality() { - let scalar1 = Scalar::extension::( - EmptyMetadata, - Scalar::primitive(42i32, Nullability::NonNullable), - ); - let scalar2 = Scalar::extension::( - EmptyMetadata, - Scalar::primitive(42i32, Nullability::NonNullable), - ); - let scalar3 = Scalar::extension::( - EmptyMetadata, - Scalar::primitive(43i32, Nullability::NonNullable), - ); - - let ext1 = ExtScalar::try_from(&scalar1).unwrap(); - let ext2 = ExtScalar::try_from(&scalar2).unwrap(); - let ext3 = ExtScalar::try_from(&scalar3).unwrap(); - - assert_eq!(ext1, ext2); - assert_ne!(ext1, ext3); - } - - #[test] - fn test_ext_scalar_partial_ord() { - let scalar1 = Scalar::extension::( - EmptyMetadata, - Scalar::primitive(10i32, Nullability::NonNullable), - ); - let scalar2 = Scalar::extension::( - EmptyMetadata, - Scalar::primitive(20i32, Nullability::NonNullable), - ); - - let ext1 = ExtScalar::try_from(&scalar1).unwrap(); - let ext2 = ExtScalar::try_from(&scalar2).unwrap(); - - assert!(ext1 < ext2); - assert!(ext2 > ext1); - } - - #[test] - fn test_ext_scalar_partial_ord_different_types() { - #[derive(Clone, Debug, Default)] - struct TestExt2; - impl ExtDTypeVTable for TestExt2 { - type Metadata = EmptyMetadata; - - fn id(&self) -> ExtID { - ExtID::new_ref("test_ext_2") - } - - fn validate( - &self, - _options: &Self::Metadata, - _storage_dtype: &DType, - ) -> VortexResult<()> { - Ok(()) - } - } - - let scalar1 = Scalar::extension::( - EmptyMetadata, - Scalar::primitive(10i32, Nullability::NonNullable), - ); - let scalar2 = Scalar::extension::( - EmptyMetadata, - Scalar::primitive(20i32, Nullability::NonNullable), - ); - - let ext1 = ExtScalar::try_from(&scalar1).unwrap(); - let ext2 = ExtScalar::try_from(&scalar2).unwrap(); - - // Different extension types should not be comparable - assert_eq!(ext1.partial_cmp(&ext2), None); - } - - #[test] - fn test_ext_scalar_hash() { - use vortex_utils::aliases::hash_set::HashSet; - - let scalar1 = Scalar::extension::( - EmptyMetadata, - Scalar::primitive(42i32, Nullability::NonNullable), - ); - let scalar2 = Scalar::extension::( - EmptyMetadata, - Scalar::primitive(42i32, Nullability::NonNullable), - ); - - let mut set = HashSet::new(); - set.insert(scalar2); - set.insert(scalar1); - - // Same value should hash the same - assert_eq!(set.len(), 1); - - // Different value should hash differently - let scalar3 = Scalar::extension::( - EmptyMetadata, - Scalar::primitive(43i32, Nullability::NonNullable), - ); - set.insert(scalar3); - assert_eq!(set.len(), 2); - } - - #[test] - fn test_ext_scalar_storage() { - let storage_scalar = Scalar::primitive(42i32, Nullability::NonNullable); - let ext_scalar = Scalar::extension::(EmptyMetadata, storage_scalar.clone()); - - let ext = ExtScalar::try_from(&ext_scalar).unwrap(); - assert_eq!(ext.storage(), storage_scalar); - } - - #[test] - fn test_ext_scalar_ext_dtype() { - let ext_dtype = TestExt::new_non_nullable(); - let scalar = Scalar::extension::( - EmptyMetadata.clone(), - Scalar::primitive(42i32, Nullability::NonNullable), - ); - - let ext = ExtScalar::try_from(&scalar).unwrap(); - assert_eq!(ext.ext_dtype().id(), ext_dtype.id()); - assert_eq!(ext.ext_dtype(), &ext_dtype.erased()); - } - - #[test] - fn test_ext_scalar_cast_to_storage() { - let scalar = Scalar::extension::( - EmptyMetadata, - Scalar::primitive(42i32, Nullability::NonNullable), - ); - - let ext = ExtScalar::try_from(&scalar).unwrap(); - - // Cast to storage type - let casted = ext - .cast(&DType::Primitive(PType::I32, Nullability::NonNullable)) - .unwrap(); - assert_eq!( - casted.dtype(), - &DType::Primitive(PType::I32, Nullability::NonNullable) - ); - assert_eq!(casted.as_primitive().typed_value::(), Some(42)); - - // Cast to nullable storage type - let casted_nullable = ext - .cast(&DType::Primitive(PType::I32, Nullability::Nullable)) - .unwrap(); - assert_eq!( - casted_nullable.dtype(), - &DType::Primitive(PType::I32, Nullability::Nullable) - ); - assert_eq!( - casted_nullable.as_primitive().typed_value::(), - Some(42) - ); - } - - #[test] - fn test_ext_scalar_cast_to_self() { - let ext_dtype = TestExt::new_non_nullable(); - - let scalar = Scalar::extension::( - EmptyMetadata, - Scalar::primitive(42i32, Nullability::NonNullable), - ); - - let ext = ExtScalar::try_from(&scalar).unwrap(); - let ext_dtype = ext_dtype.erased(); - - // Cast to same extension type - let casted = ext.cast(&DType::Extension(ext_dtype.clone())).unwrap(); - assert_eq!(casted.dtype(), &DType::Extension(ext_dtype.clone())); - - // Cast to nullable version of same extension type - let nullable_ext = DType::Extension(ext_dtype).as_nullable(); - let casted_nullable = ext.cast(&nullable_ext).unwrap(); - assert_eq!(casted_nullable.dtype(), &nullable_ext); - } - - #[test] - fn test_ext_scalar_cast_incompatible() { - let scalar = Scalar::extension::( - EmptyMetadata, - Scalar::primitive(42i32, Nullability::NonNullable), - ); - - let ext = ExtScalar::try_from(&scalar).unwrap(); - - // Cast to incompatible type should fail - let result = ext.cast(&DType::Utf8(Nullability::NonNullable)); - assert!(result.is_err()); - } - - #[test] - fn test_ext_scalar_cast_null_to_non_nullable() { - let scalar = Scalar::extension::( - EmptyMetadata, - Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), - ); - - let ext = ExtScalar::try_from(&scalar).unwrap(); - - // Cast null to non-nullable should fail - let result = ext.cast(&DType::Primitive(PType::I32, Nullability::NonNullable)); - assert!(result.is_err()); - } - - #[test] - fn test_ext_scalar_try_new_non_extension() { - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let value = ScalarValue(InnerScalarValue::Primitive(crate::PValue::I32(42))); - - let result = ExtScalar::try_new(&dtype, &value); - assert!(result.is_err()); - } - - #[test] - fn test_ext_scalar_with_metadata() { - #[derive(Clone, Debug, Default)] - struct TestExtMetadata; - impl ExtDTypeVTable for TestExtMetadata { - type Metadata = usize; - - fn id(&self) -> ExtID { - ExtID::new_ref("test_ext_metadata") - } - - fn validate( - &self, - _options: &Self::Metadata, - _storage_dtype: &DType, - ) -> VortexResult<()> { - Ok(()) - } - } - - let scalar = Scalar::extension::( - 1234, - Scalar::primitive(42i32, Nullability::NonNullable), - ); - - let ext = ExtScalar::try_from(&scalar).unwrap(); - assert_eq!(ext.ext_dtype().metadata::(), &1234); - } - - #[test] - fn test_ext_scalar_equality_ignores_nullability() { - let scalar1 = Scalar::extension::( - EmptyMetadata, - Scalar::primitive(42i32, Nullability::NonNullable), - ); - let scalar2 = Scalar::extension::( - EmptyMetadata, - Scalar::primitive(42i32, Nullability::Nullable), - ); - - let ext1 = ExtScalar::try_from(&scalar1).unwrap(); - let ext2 = ExtScalar::try_from(&scalar2).unwrap(); - - // Equality should ignore nullability differences - assert_eq!(ext1, ext2); - } -} diff --git a/vortex-scalar/src/extension/mod.rs b/vortex-scalar/src/extension/mod.rs new file mode 100644 index 00000000000..02929c58263 --- /dev/null +++ b/vortex-scalar/src/extension/mod.rs @@ -0,0 +1,488 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +mod vtable; + +use std::any::type_name; + +use vortex_dtype::DType; +use vortex_dtype::ExtDType; +use vortex_dtype::ExtDTypeRef; +use vortex_dtype::Nullability; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; +use vortex_error::vortex_panic; +pub use vtable::*; + +use crate::Scalar; +use crate::ScalarValue; + +/// A typed extension scalar. +#[derive(Debug, Clone)] +pub struct ExtScalar { + ext_dtype: ExtDType, + value: Option, +} + +impl ExtScalar { + /// Creates a new extension scalar from a data type and scalar value. + /// + /// # Errors + /// + /// Returns an error if the data type is not an extension type. + pub fn try_from_scalar(dtype: &ExtDTypeRef, value: &ScalarValue) -> VortexResult { + let ext_dtype = dtype.clone().try_downcast::().map_err(|_| { + vortex_err!( + "Expected extension dtype of type {}, got {}", + type_name::(), + dtype.id() + ) + })?; + let vtable = V::default(); + + if value.is_null() { + vortex_ensure!( + ext_dtype.storage_dtype().is_nullable(), + "Cannot create non-nullable extension scalar of type {} with null value", + ext_dtype.id() + ); + return Ok(Self { + ext_dtype, + value: None, + }); + } + + let value = vtable.unpack(&ext_dtype, value)?; + Ok(Self { + ext_dtype, + value: Some(value), + }) + } +} + +impl ExtScalar { + /// Get a reference to the extension DType. + pub fn ext_dtype(&self) -> &ExtDType { + &self.ext_dtype + } + + /// Get a reference to the scalar value. + pub fn value(&self) -> Option<&V::Value> { + self.value.as_ref() + } + + pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { + if self.value.is_none() && !dtype.is_nullable() { + vortex_bail!( + "cannot cast extension dtype with id {} and storage type {} to {}", + self.ext_dtype.id(), + self.ext_dtype.storage_dtype(), + dtype + ); + } + + if self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) { + // Casting from an extension type to the underlying storage type is OK. + return Ok(Scalar::new(dtype.clone(), self.value.clone())); + } + + if let DType::Extension(ext_dtype) = dtype + && self.ext_dtype.eq_ignore_nullability(ext_dtype) + { + return Ok(Scalar::new(dtype.clone(), self.value.clone())); + } + + vortex_bail!( + "cannot cast extension dtype with id {} and storage type {} to {}", + self.ext_dtype.id(), + self.ext_dtype.storage_dtype(), + dtype + ); + } +} + +/// A type-erased extension scalar. +#[derive(Debug, Clone)] +pub struct ExtScalarRef<'a> { + ext_dtype: &'a ExtDTypeRef, + value: &'a ScalarValue, +} + +impl ExtScalarRef<'_> { + pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { + if self.value.is_null() && !dtype.is_nullable() { + vortex_bail!( + "cannot cast extension dtype with id {} and storage type {} to {}", + self.ext_dtype.id(), + self.ext_dtype.storage_dtype(), + dtype + ); + } + + if self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) { + // Casting from an extension type to the underlying storage type is OK. + return Ok(Scalar::new(dtype.clone(), self.value.clone())); + } + + if let DType::Extension(ext_dtype) = dtype + && self.ext_dtype.eq_ignore_nullability(ext_dtype) + { + return Ok(Scalar::new(dtype.clone(), self.value.clone())); + } + + vortex_bail!( + "cannot cast extension dtype with id {} and storage type {} to {}", + self.ext_dtype.id(), + self.ext_dtype.storage_dtype(), + dtype + ); + } +} + +impl Scalar { + /// Creates a new extension scalar wrapping the given storage value. + pub fn extension( + metadata: V::Metadata, + value: Option, + nullability: Nullability, + ) -> VortexResult { + if value.is_none() && nullability == Nullability::NonNullable { + vortex_bail!( + "Cannot create non-nullable extension scalar of type {} with null value", + type_name::(), + ); + } + + let vtable = V::default(); + + let storage = match value { + None => vtable.pack_null(&metadata), + Some(value) => vtable.pack(&metadata, value, nullability), + }?; + + let ext_dtype = ExtDType::::try_new(metadata, storage.dtype().clone()) + .vortex_expect("Failed to create extension dtype"); + + Ok(Self::new( + DType::Extension(ext_dtype.erased()), + storage.into_value(), + )) + } + + /// Creates a new extension scalar wrapping the given storage value. + pub fn extension_ref(ext_dtype: ExtDTypeRef, value: Scalar) -> Self { + assert_eq!(ext_dtype.storage_dtype(), value.dtype()); + Self::new(DType::Extension(ext_dtype), value.value().clone()) + } +} + +// #[cfg(test)] +// mod tests { +// use vortex_dtype::DType; +// use vortex_dtype::ExtDType; +// use vortex_dtype::ExtID; +// use vortex_dtype::Nullability; +// use vortex_dtype::PType; +// use vortex_dtype::extension::EmptyMetadata; +// use vortex_dtype::extension::ExtDTypeVTable; +// use vortex_error::VortexResult; +// +// use crate::ExtScalar; +// use crate::InnerScalarValue; +// use crate::Scalar; +// use crate::ScalarValue; +// +// #[derive(Debug, Clone, Default)] +// struct TestExt; +// impl ExtDTypeVTable for TestExt { +// type Metadata = EmptyMetadata; +// +// fn id(&self) -> ExtID { +// ExtID::new_ref("test_ext") +// } +// +// fn validate(&self, _options: &Self::Metadata, _storage_dtype: &DType) -> VortexResult<()> { +// Ok(()) +// } +// } +// +// impl TestExt { +// fn new_non_nullable() -> ExtDType { +// ExtDType::try_new( +// EmptyMetadata, +// DType::Primitive(PType::I32, Nullability::NonNullable), +// ) +// .unwrap() +// } +// } +// +// #[test] +// fn test_ext_scalar_equality() { +// let scalar1 = Scalar::extension::( +// EmptyMetadata, +// Scalar::primitive(42i32, Nullability::NonNullable), +// ); +// let scalar2 = Scalar::extension::( +// EmptyMetadata, +// Scalar::primitive(42i32, Nullability::NonNullable), +// ); +// let scalar3 = Scalar::extension::( +// EmptyMetadata, +// Scalar::primitive(43i32, Nullability::NonNullable), +// ); +// +// let ext1 = ExtScalar::try_from(&scalar1).unwrap(); +// let ext2 = ExtScalar::try_from(&scalar2).unwrap(); +// let ext3 = ExtScalar::try_from(&scalar3).unwrap(); +// +// assert_eq!(ext1, ext2); +// assert_ne!(ext1, ext3); +// } +// +// #[test] +// fn test_ext_scalar_partial_ord() { +// let scalar1 = Scalar::extension::( +// EmptyMetadata, +// Scalar::primitive(10i32, Nullability::NonNullable), +// ); +// let scalar2 = Scalar::extension::( +// EmptyMetadata, +// Scalar::primitive(20i32, Nullability::NonNullable), +// ); +// +// let ext1 = ExtScalar::try_from(&scalar1).unwrap(); +// let ext2 = ExtScalar::try_from(&scalar2).unwrap(); +// +// assert!(ext1 < ext2); +// assert!(ext2 > ext1); +// } +// +// #[test] +// fn test_ext_scalar_partial_ord_different_types() { +// #[derive(Clone, Debug, Default)] +// struct TestExt2; +// impl ExtDTypeVTable for TestExt2 { +// type Metadata = EmptyMetadata; +// +// fn id(&self) -> ExtID { +// ExtID::new_ref("test_ext_2") +// } +// +// fn validate( +// &self, +// _options: &Self::Metadata, +// _storage_dtype: &DType, +// ) -> VortexResult<()> { +// Ok(()) +// } +// } +// +// let scalar1 = Scalar::extension::( +// EmptyMetadata, +// Scalar::primitive(10i32, Nullability::NonNullable), +// ); +// let scalar2 = Scalar::extension::( +// EmptyMetadata, +// Scalar::primitive(20i32, Nullability::NonNullable), +// ); +// +// let ext1 = ExtScalar::try_from(&scalar1).unwrap(); +// let ext2 = ExtScalar::try_from(&scalar2).unwrap(); +// +// // Different extension types should not be comparable +// assert_eq!(ext1.partial_cmp(&ext2), None); +// } +// +// #[test] +// fn test_ext_scalar_hash() { +// use vortex_utils::aliases::hash_set::HashSet; +// +// let scalar1 = Scalar::extension::( +// EmptyMetadata, +// Scalar::primitive(42i32, Nullability::NonNullable), +// ); +// let scalar2 = Scalar::extension::( +// EmptyMetadata, +// Scalar::primitive(42i32, Nullability::NonNullable), +// ); +// +// let mut set = HashSet::new(); +// set.insert(scalar2); +// set.insert(scalar1); +// +// // Same value should hash the same +// assert_eq!(set.len(), 1); +// +// // Different value should hash differently +// let scalar3 = Scalar::extension::( +// EmptyMetadata, +// Scalar::primitive(43i32, Nullability::NonNullable), +// ); +// set.insert(scalar3); +// assert_eq!(set.len(), 2); +// } +// +// #[test] +// fn test_ext_scalar_storage() { +// let storage_scalar = Scalar::primitive(42i32, Nullability::NonNullable); +// let ext_scalar = Scalar::extension::(EmptyMetadata, storage_scalar.clone()); +// +// let ext = ExtScalar::try_from(&ext_scalar).unwrap(); +// assert_eq!(ext.storage(), storage_scalar); +// } +// +// #[test] +// fn test_ext_scalar_ext_dtype() { +// let ext_dtype = TestExt::new_non_nullable(); +// let scalar = Scalar::extension::( +// EmptyMetadata.clone(), +// Scalar::primitive(42i32, Nullability::NonNullable), +// ); +// +// let ext = ExtScalar::try_from(&scalar).unwrap(); +// assert_eq!(ext.ext_dtype().id(), ext_dtype.id()); +// assert_eq!(ext.ext_dtype(), &ext_dtype.erased()); +// } +// +// #[test] +// fn test_ext_scalar_cast_to_storage() { +// let scalar = Scalar::extension::( +// EmptyMetadata, +// Scalar::primitive(42i32, Nullability::NonNullable), +// ); +// +// let ext = ExtScalar::try_from(&scalar).unwrap(); +// +// // Cast to storage type +// let casted = ext +// .cast(&DType::Primitive(PType::I32, Nullability::NonNullable)) +// .unwrap(); +// assert_eq!( +// casted.dtype(), +// &DType::Primitive(PType::I32, Nullability::NonNullable) +// ); +// assert_eq!(casted.as_primitive().typed_value::(), Some(42)); +// +// // Cast to nullable storage type +// let casted_nullable = ext +// .cast(&DType::Primitive(PType::I32, Nullability::Nullable)) +// .unwrap(); +// assert_eq!( +// casted_nullable.dtype(), +// &DType::Primitive(PType::I32, Nullability::Nullable) +// ); +// assert_eq!( +// casted_nullable.as_primitive().typed_value::(), +// Some(42) +// ); +// } +// +// #[test] +// fn test_ext_scalar_cast_to_self() { +// let ext_dtype = TestExt::new_non_nullable(); +// +// let scalar = Scalar::extension::( +// EmptyMetadata, +// Scalar::primitive(42i32, Nullability::NonNullable), +// ); +// +// let ext = ExtScalar::try_from(&scalar).unwrap(); +// let ext_dtype = ext_dtype.erased(); +// +// // Cast to same extension type +// let casted = ext.cast(&DType::Extension(ext_dtype.clone())).unwrap(); +// assert_eq!(casted.dtype(), &DType::Extension(ext_dtype.clone())); +// +// // Cast to nullable version of same extension type +// let nullable_ext = DType::Extension(ext_dtype).as_nullable(); +// let casted_nullable = ext.cast(&nullable_ext).unwrap(); +// assert_eq!(casted_nullable.dtype(), &nullable_ext); +// } +// +// #[test] +// fn test_ext_scalar_cast_incompatible() { +// let scalar = Scalar::extension::( +// EmptyMetadata, +// Scalar::primitive(42i32, Nullability::NonNullable), +// ); +// +// let ext = ExtScalar::try_from(&scalar).unwrap(); +// +// // Cast to incompatible type should fail +// let result = ext.cast(&DType::Utf8(Nullability::NonNullable)); +// assert!(result.is_err()); +// } +// +// #[test] +// fn test_ext_scalar_cast_null_to_non_nullable() { +// let scalar = Scalar::extension::( +// EmptyMetadata, +// Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), +// ); +// +// let ext = ExtScalar::try_from(&scalar).unwrap(); +// +// // Cast null to non-nullable should fail +// let result = ext.cast(&DType::Primitive(PType::I32, Nullability::NonNullable)); +// assert!(result.is_err()); +// } +// +// #[test] +// fn test_ext_scalar_try_new_non_extension() { +// let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); +// let value = ScalarValue(InnerScalarValue::Primitive(crate::PValue::I32(42))); +// +// let result = ExtScalar::try_new(&dtype, &value); +// assert!(result.is_err()); +// } +// +// #[test] +// fn test_ext_scalar_with_metadata() { +// #[derive(Clone, Debug, Default)] +// struct TestExtMetadata; +// impl ExtDTypeVTable for TestExtMetadata { +// type Metadata = usize; +// +// fn id(&self) -> ExtID { +// ExtID::new_ref("test_ext_metadata") +// } +// +// fn validate( +// &self, +// _options: &Self::Metadata, +// _storage_dtype: &DType, +// ) -> VortexResult<()> { +// Ok(()) +// } +// } +// +// let scalar = Scalar::extension::( +// 1234, +// Scalar::primitive(42i32, Nullability::NonNullable), +// ); +// +// let ext = ExtScalar::try_from(&scalar).unwrap(); +// assert_eq!(ext.ext_dtype().metadata::(), &1234); +// } +// +// #[test] +// fn test_ext_scalar_equality_ignores_nullability() { +// let scalar1 = Scalar::extension::( +// EmptyMetadata, +// Scalar::primitive(42i32, Nullability::NonNullable), +// ); +// let scalar2 = Scalar::extension::( +// EmptyMetadata, +// Scalar::primitive(42i32, Nullability::Nullable), +// ); +// +// let ext1 = ExtScalar::try_from(&scalar1).unwrap(); +// let ext2 = ExtScalar::try_from(&scalar2).unwrap(); +// +// // Equality should ignore nullability differences +// assert_eq!(ext1, ext2); +// } +// } diff --git a/vortex-scalar/src/extension/vtable.rs b/vortex-scalar/src/extension/vtable.rs new file mode 100644 index 00000000000..18ebb26ed90 --- /dev/null +++ b/vortex-scalar/src/extension/vtable.rs @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::fmt::Debug; +use std::fmt::Display; +use std::hash::Hash; + +use vortex_dtype::ExtDType; +use vortex_dtype::Nullability; +use vortex_dtype::extension::ExtDTypeVTable; +use vortex_error::VortexResult; + +use crate::Scalar; +use crate::ScalarValue; + +/// API for defining the scalar behavior of an extension DType. +pub trait ExtScalarVTable: ExtDTypeVTable { + /// The native value type for this extension scalar. + /// The `Default` trait should return a value representing `zero`. + // TODO(ngates): require total ordering? + type Value: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Hash; + + /// Return the `zero` value for this extension scalar. + fn zero(&self, metadata: &Self::Metadata) -> Self::Value; + + /// Unpack the native value from the given scalar. + /// + /// Note that the storage scalar is guaranteed to be non-null. + fn unpack(&self, dtype: &ExtDType, storage: &ScalarValue) -> VortexResult; + + /// Pack the native value into the storage scalar. + fn pack( + &self, + metadata: &Self::Metadata, + value: Self::Value, + nullability: Nullability, + ) -> VortexResult; + + /// Pack a null value into the storage scalar. + fn pack_null(&self, metadata: &Self::Metadata) -> VortexResult; +} diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index 2cbb62be58a..ca50450bc60 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -16,7 +16,7 @@ mod binary; mod bool; mod decimal; mod display; -mod extension; +pub mod extension; mod list; mod null; mod primitive; @@ -30,7 +30,8 @@ mod utf8; pub use binary::*; pub use bool::*; pub use decimal::*; -pub use extension::*; +pub use extension::ExtScalar; +pub use extension::ExtScalarRef; pub use list::*; pub use primitive::*; pub use pvalue::*; @@ -39,5 +40,6 @@ pub use scalar_value::*; pub use struct_::*; pub use utf8::*; +mod datetime; #[cfg(test)] mod tests; diff --git a/vortex-scalar/src/scalar.rs b/vortex-scalar/src/scalar.rs index ec9d0ca4580..89d02530997 100644 --- a/vortex-scalar/src/scalar.rs +++ b/vortex-scalar/src/scalar.rs @@ -18,6 +18,7 @@ use vortex_error::vortex_bail; use vortex_error::vortex_err; use super::*; +use crate::extension::ExtScalarVTable; /// A single logical item, composed of both a [`ScalarValue`] and a logical [`DType`]. /// @@ -41,6 +42,8 @@ pub struct Scalar { impl Scalar { /// Creates a new scalar with the given data type and value. + /// + // TODO(ngates): make this unsafe. There's no guarantee that the value matches the dtype. pub fn new(dtype: DType, value: ScalarValue) -> Self { if !dtype.is_nullable() { assert!( @@ -455,13 +458,37 @@ impl Scalar { /// # Panics /// /// Panics if the scalar is not an extension type. - pub fn as_extension(&self) -> ExtScalar<'_> { - ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension") + pub fn as_extension(&self) -> ExtScalar { + self.as_extension_opt::() + .vortex_expect("Failed to convert scalar to extension") + } + + /// Returns a view of the scalar as an extension scalar if it has an extension type. + pub fn as_extension_opt(&self) -> Option> { + let DType::Extension(ext_dtype) = &self.dtype else { + return None; + }; + if !ext_dtype.is::() { + return None; + } + Some( + ExtScalar::try_from_scalar(ext_dtype, &self.value) + .vortex_expect("Failed to convert scalar to extension"), + ) + } + + /// Returns a view of the scalar as an extension scalar. + /// + /// # Panics + /// + /// Panics if the scalar is not an extension type. + pub fn as_extension_ref(&self) -> ExtScalarRef<'_> { + ExtScalarRef::try_from(self).vortex_expect("Failed to convert scalar to extension") } /// Returns a view of the scalar as an extension scalar if it has an extension type. - pub fn as_extension_opt(&self) -> Option> { - matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension()) + pub fn as_extension_ref_opt(&self) -> Option> { + matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension_ref()) } } From 65ce7a4520d33dca700ae1f5b39ecd20c5f32f72 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 29 Jan 2026 16:09:41 -0500 Subject: [PATCH 02/21] Extension Scalar Signed-off-by: Nicholas Gates --- vortex-dtype/src/datetime/date.rs | 2 +- vortex-dtype/src/datetime/time.rs | 2 +- vortex-dtype/src/datetime/timestamp.rs | 2 +- vortex-dtype/src/extension/mod.rs | 20 +- vortex-dtype/src/extension/vtable.rs | 2 +- vortex-scalar/src/datetime/mod.rs | 6 +- vortex-scalar/src/extension/mod.rs | 277 ++++++++++++++++--------- vortex-scalar/src/extension/vtable.rs | 8 + vortex-scalar/src/lib.rs | 1 + vortex-scalar/src/scalar.rs | 10 +- vortex-scalar/src/tests/mod.rs | 23 ++ 11 files changed, 235 insertions(+), 118 deletions(-) diff --git a/vortex-dtype/src/datetime/date.rs b/vortex-dtype/src/datetime/date.rs index 66738a2cbb4..5204b8904c5 100644 --- a/vortex-dtype/src/datetime/date.rs +++ b/vortex-dtype/src/datetime/date.rs @@ -15,7 +15,7 @@ use crate::extension::ExtDTypeVTable; use crate::extension::ExtID; /// Date DType. -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct Date; impl Date { diff --git a/vortex-dtype/src/datetime/time.rs b/vortex-dtype/src/datetime/time.rs index f1f45014058..563e97cb47c 100644 --- a/vortex-dtype/src/datetime/time.rs +++ b/vortex-dtype/src/datetime/time.rs @@ -15,7 +15,7 @@ use crate::extension::ExtDTypeVTable; use crate::extension::ExtID; /// Time DType. -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct Time; impl Time { diff --git a/vortex-dtype/src/datetime/timestamp.rs b/vortex-dtype/src/datetime/timestamp.rs index 56bbdcdd094..4d7a7f6ecf0 100644 --- a/vortex-dtype/src/datetime/timestamp.rs +++ b/vortex-dtype/src/datetime/timestamp.rs @@ -22,7 +22,7 @@ use crate::extension::ExtDTypeVTable; use crate::extension::ExtID; /// Timestamp DType. -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct Timestamp; impl Timestamp { diff --git a/vortex-dtype/src/extension/mod.rs b/vortex-dtype/src/extension/mod.rs index 79de5de72a3..5155ba832ca 100644 --- a/vortex-dtype/src/extension/mod.rs +++ b/vortex-dtype/src/extension/mod.rs @@ -29,7 +29,7 @@ use crate::Nullability; pub type ExtID = ArcRef; /// An extension data type. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ExtDType(Arc>); // Convenience impls for zero-sized VTables @@ -60,6 +60,11 @@ impl ExtDType { self.0.id() } + /// Returns the vtable of the extension type. + pub fn vtable(&self) -> &V { + &self.0.vtable + } + /// Returns the metadata of the extension type. pub fn metadata(&self) -> &V::Metadata { &self.0.metadata @@ -211,7 +216,7 @@ impl ExtDTypeRef { /// Wrapper for type-erased extension dtype metadata. pub struct ExtDTypeMetadata<'a> { - pub(super) ext_dtype: &'a ExtDTypeRef, + ext_dtype: &'a ExtDTypeRef, } impl ExtDTypeMetadata<'_> { @@ -249,7 +254,7 @@ impl Hash for ExtDTypeMetadata<'_> { } /// An object-safe trait encapsulating the behavior for extension DTypes. -trait ExtDTypeImpl: 'static + Send + Sync + private::Sealed { +trait ExtDTypeImpl: 'static + Send + Sync { fn as_any(&self) -> &dyn Any; fn id(&self) -> ExtID; fn storage_dtype(&self) -> &DType; @@ -262,7 +267,7 @@ trait ExtDTypeImpl: 'static + Send + Sync + private::Sealed { fn with_nullability(&self, nullability: Nullability) -> ExtDTypeRef; } -#[derive(Debug)] +#[derive(Debug, Hash, PartialEq, Eq)] struct ExtDTypeAdapter { vtable: V, metadata: V::Metadata, @@ -315,10 +320,3 @@ impl ExtDTypeImpl for ExtDTypeAdapter { .vortex_expect("Extension DType {} incorrect fails validation with the same storage type but different nullability").erased() } } - -mod private { - use super::ExtDTypeAdapter; - - pub trait Sealed {} - impl Sealed for ExtDTypeAdapter {} -} diff --git a/vortex-dtype/src/extension/vtable.rs b/vortex-dtype/src/extension/vtable.rs index b650c187c5e..0d0e494227e 100644 --- a/vortex-dtype/src/extension/vtable.rs +++ b/vortex-dtype/src/extension/vtable.rs @@ -14,7 +14,7 @@ use crate::ExtID; use crate::extension::ExtDTypeRef; /// The public API for defining new extension DTypes. -pub trait ExtDTypeVTable: 'static + Sized + Send + Sync + Clone + Debug { +pub trait ExtDTypeVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash { /// Associated type containing the deserialized metadata for this extension type type Metadata: 'static + Send + Sync + Clone + Debug + Display + Eq + Hash; diff --git a/vortex-scalar/src/datetime/mod.rs b/vortex-scalar/src/datetime/mod.rs index 9d811a4c3ed..17882061a9c 100644 --- a/vortex-scalar/src/datetime/mod.rs +++ b/vortex-scalar/src/datetime/mod.rs @@ -3,9 +3,9 @@ use vortex_dtype::datetime::TimeUnit; -pub mod date; -pub mod time; -pub mod timestamp; +// pub mod date; +// pub mod time; +// pub mod timestamp; trait SpanExt { fn get_unit_length(&self, time_unit: TimeUnit) -> i64; diff --git a/vortex-scalar/src/extension/mod.rs b/vortex-scalar/src/extension/mod.rs index 02929c58263..24c2cbaff6d 100644 --- a/vortex-scalar/src/extension/mod.rs +++ b/vortex-scalar/src/extension/mod.rs @@ -3,142 +3,227 @@ mod vtable; +use std::any::Any; use std::any::type_name; +use std::fmt::Debug; +use std::fmt::Display; +use std::fmt::Formatter; +use std::hash::Hash; +use std::hash::Hasher; +use std::mem::discriminant; +use std::sync::Arc; use vortex_dtype::DType; use vortex_dtype::ExtDType; use vortex_dtype::ExtDTypeRef; +use vortex_dtype::ExtID; use vortex_dtype::Nullability; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_error::vortex_ensure; -use vortex_error::vortex_err; -use vortex_error::vortex_panic; pub use vtable::*; use crate::Scalar; -use crate::ScalarValue; /// A typed extension scalar. #[derive(Debug, Clone)] -pub struct ExtScalar { - ext_dtype: ExtDType, - value: Option, -} +pub struct ExtScalar(Arc>); impl ExtScalar { - /// Creates a new extension scalar from a data type and scalar value. - /// - /// # Errors - /// - /// Returns an error if the data type is not an extension type. - pub fn try_from_scalar(dtype: &ExtDTypeRef, value: &ScalarValue) -> VortexResult { - let ext_dtype = dtype.clone().try_downcast::().map_err(|_| { - vortex_err!( - "Expected extension dtype of type {}, got {}", - type_name::(), - dtype.id() - ) - })?; - let vtable = V::default(); - - if value.is_null() { - vortex_ensure!( - ext_dtype.storage_dtype().is_nullable(), - "Cannot create non-nullable extension scalar of type {} with null value", - ext_dtype.id() - ); - return Ok(Self { - ext_dtype, - value: None, - }); - } - - let value = vtable.unpack(&ext_dtype, value)?; - Ok(Self { - ext_dtype, - value: Some(value), - }) + /// Creates a new extension scalar from a scalar value. + pub fn try_new( + metadata: V::Metadata, + value: Option, + nullability: Nullability, + ) -> VortexResult { + Self::try_with_vtable(V::default(), metadata, value, nullability) } } impl ExtScalar { - /// Get a reference to the extension DType. - pub fn ext_dtype(&self) -> &ExtDType { - &self.ext_dtype + /// Creates a new extension scalar from a vtable, metadata, and scalar value. + pub fn try_with_vtable( + vtable: V, + metadata: V::Metadata, + value: Option, + nullability: Nullability, + ) -> VortexResult { + let storage_dtype = vtable.storage_dtype(&metadata, nullability)?; + let dtype = ExtDType::::try_with_vtable(vtable, metadata, storage_dtype)?; + Ok(Self(Arc::new(ExtScalarAdapter:: { dtype, value }))) } - /// Get a reference to the scalar value. + /// Returns the identifier of the extension scalar. + pub fn id(&self) -> ExtID { + self.0.dtype.id() + } + + /// Returns the vtable of this extension scalar. + pub fn vtable(&self) -> &V { + self.0.dtype.vtable() + } + + /// Returns the value of this extension scalar. pub fn value(&self) -> Option<&V::Value> { - self.value.as_ref() + self.0.value.as_ref() } - pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { - if self.value.is_none() && !dtype.is_nullable() { - vortex_bail!( - "cannot cast extension dtype with id {} and storage type {} to {}", - self.ext_dtype.id(), - self.ext_dtype.storage_dtype(), - dtype - ); - } + /// Erase the concrete type information, returning a type-erased extension scalar. + pub fn erased(self) -> ExtScalarRef { + ExtScalarRef(self.0) + } - if self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) { - // Casting from an extension type to the underlying storage type is OK. - return Ok(Scalar::new(dtype.clone(), self.value.clone())); - } + // pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { + // if self.value.is_none() && !dtype.is_nullable() { + // vortex_bail!( + // "cannot cast extension dtype with id {} and storage type {} to {}", + // self.ext_dtype.id(), + // self.ext_dtype.storage_dtype(), + // dtype + // ); + // } + // + // if self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) { + // // Casting from an extension type to the underlying storage type is OK. + // return Ok(Scalar::new(dtype.clone(), self.value.clone())); + // } + // + // if let DType::Extension(ext_dtype) = dtype + // && self.ext_dtype.eq_ignore_nullability(ext_dtype) + // { + // return Ok(Scalar::new(dtype.clone(), self.value.clone())); + // } + // + // vortex_bail!( + // "cannot cast extension dtype with id {} and storage type {} to {}", + // self.ext_dtype.id(), + // self.ext_dtype.storage_dtype(), + // dtype + // ); + // } +} - if let DType::Extension(ext_dtype) = dtype - && self.ext_dtype.eq_ignore_nullability(ext_dtype) - { - return Ok(Scalar::new(dtype.clone(), self.value.clone())); - } +/// A type-erased extension scalar. +#[derive(Clone)] +pub struct ExtScalarRef(Arc); - vortex_bail!( - "cannot cast extension dtype with id {} and storage type {} to {}", - self.ext_dtype.id(), - self.ext_dtype.storage_dtype(), - dtype - ); +impl Display for ExtScalarRef { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.0.value_display(f) } } -/// A type-erased extension scalar. -#[derive(Debug, Clone)] -pub struct ExtScalarRef<'a> { - ext_dtype: &'a ExtDTypeRef, - value: &'a ScalarValue, +impl ExtScalarRef { + /// Returns the identifier of the extension scalar. + pub fn id(&self) -> ExtID { + self.0.id() + } + + /// Returns the type-erased dtype of this extension scalar. + pub fn dtype_erased(&self) -> ExtDTypeRef { + self.0.dtype_erased() + } + + /// Returns the type-erased value of this extension scalar. + pub fn value_erased(&self) -> ExtScalarValue<'_> { + ExtScalarValue { scalar: self } + } } -impl ExtScalarRef<'_> { - pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { - if self.value.is_null() && !dtype.is_nullable() { - vortex_bail!( - "cannot cast extension dtype with id {} and storage type {} to {}", - self.ext_dtype.id(), - self.ext_dtype.storage_dtype(), - dtype - ); +/// A type-erased reference to an extension scalar value. +pub struct ExtScalarValue<'a> { + scalar: &'a ExtScalarRef, +} + +impl Display for ExtScalarValue<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.scalar.0.value_display(f) + } +} + +impl Debug for ExtScalarValue<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.scalar.0.value_debug(f) + } +} + +impl PartialEq for ExtScalarValue<'_> { + fn eq(&self, other: &Self) -> bool { + if self.scalar.dtype_erased() != other.scalar.dtype_erased() { + return false; } + self.scalar.0.value_eq(other.scalar.0.value_any()) + } +} +impl Eq for ExtScalarValue<'_> {} + +trait ExtScalarImpl: 'static + Send + Sync { + fn as_any(&self) -> &dyn Any; + fn id(&self) -> ExtID; + fn dtype_erased(&self) -> ExtDTypeRef; + fn value_any(&self) -> Option<&dyn Any>; + fn value_debug(&self, f: &mut Formatter<'_>) -> std::fmt::Result; + fn value_display(&self, f: &mut Formatter<'_>) -> std::fmt::Result; + fn value_eq(&self, other: Option<&dyn Any>) -> bool; + fn value_hash(&self, state: &mut dyn Hasher); +} - if self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) { - // Casting from an extension type to the underlying storage type is OK. - return Ok(Scalar::new(dtype.clone(), self.value.clone())); +#[derive(Debug)] +struct ExtScalarAdapter { + dtype: ExtDType, + value: Option, +} + +impl ExtScalarImpl for ExtScalarAdapter { + fn as_any(&self) -> &dyn Any { + self + } + + fn id(&self) -> ExtID { + self.dtype.id() + } + + fn dtype_erased(&self) -> ExtDTypeRef { + self.dtype.clone().erased() + } + + fn value_any(&self) -> Option<&dyn Any> { + self.value.as_ref().map(|v| v as &dyn Any) + } + + fn value_debug(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match &self.value { + None => return write!(f, "null"), + Some(value) => ::fmt(value, f), + } + } + + fn value_display(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match &self.value { + None => write!(f, "null"), + Some(value) => ::fmt(value, f), } + } - if let DType::Extension(ext_dtype) = dtype - && self.ext_dtype.eq_ignore_nullability(ext_dtype) - { - return Ok(Scalar::new(dtype.clone(), self.value.clone())); + fn value_eq(&self, other: Option<&dyn Any>) -> bool { + match (&self.value, other) { + (None, None) => true, + (Some(_), None) | (None, Some(_)) => false, + (Some(value), Some(other)) => { + let Some(other) = other.downcast_ref::() else { + return false; + }; + ::eq(value, other) + } } + } - vortex_bail!( - "cannot cast extension dtype with id {} and storage type {} to {}", - self.ext_dtype.id(), - self.ext_dtype.storage_dtype(), - dtype - ); + fn value_hash(&self, mut state: &mut dyn Hasher) { + self.dtype.hash(&mut state); + discriminant(&self.value).hash(&mut state); + if let Some(value) = self.value.as_ref() { + ::hash(value, &mut state); + } } } diff --git a/vortex-scalar/src/extension/vtable.rs b/vortex-scalar/src/extension/vtable.rs index 18ebb26ed90..757191e007c 100644 --- a/vortex-scalar/src/extension/vtable.rs +++ b/vortex-scalar/src/extension/vtable.rs @@ -5,6 +5,7 @@ use std::fmt::Debug; use std::fmt::Display; use std::hash::Hash; +use vortex_dtype::DType; use vortex_dtype::ExtDType; use vortex_dtype::Nullability; use vortex_dtype::extension::ExtDTypeVTable; @@ -23,6 +24,13 @@ pub trait ExtScalarVTable: ExtDTypeVTable { /// Return the `zero` value for this extension scalar. fn zero(&self, metadata: &Self::Metadata) -> Self::Value; + /// Returns the storage dtype for this extension scalar. + fn storage_dtype( + &self, + metadata: &Self::Metadata, + nullability: Nullability, + ) -> VortexResult; + /// Unpack the native value from the given scalar. /// /// Note that the storage scalar is guaranteed to be non-null. diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index ca50450bc60..b776d77c14f 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -41,5 +41,6 @@ pub use struct_::*; pub use utf8::*; mod datetime; + #[cfg(test)] mod tests; diff --git a/vortex-scalar/src/scalar.rs b/vortex-scalar/src/scalar.rs index 89d02530997..d2509f2121d 100644 --- a/vortex-scalar/src/scalar.rs +++ b/vortex-scalar/src/scalar.rs @@ -482,13 +482,15 @@ impl Scalar { /// # Panics /// /// Panics if the scalar is not an extension type. - pub fn as_extension_ref(&self) -> ExtScalarRef<'_> { - ExtScalarRef::try_from(self).vortex_expect("Failed to convert scalar to extension") + pub fn as_extension_ref(&self) -> ExtScalarRef { + // ExtScalarRef::try_from(self).vortex_expect("Failed to convert scalar to extension") + todo!() } /// Returns a view of the scalar as an extension scalar if it has an extension type. - pub fn as_extension_ref_opt(&self) -> Option> { - matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension_ref()) + pub fn as_extension_ref_opt(&self) -> Option { + todo!() + // matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension_ref()) } } diff --git a/vortex-scalar/src/tests/mod.rs b/vortex-scalar/src/tests/mod.rs index bdae7b9d4b1..3a3f9bd844c 100644 --- a/vortex-scalar/src/tests/mod.rs +++ b/vortex-scalar/src/tests/mod.rs @@ -12,8 +12,31 @@ mod round_trip; use std::sync::LazyLock; +use vortex_dtype::DType; +use vortex_dtype::ExtID; +use vortex_dtype::extension::EmptyMetadata; +use vortex_dtype::extension::ExtDTypeVTable; use vortex_dtype::session::DTypeSession; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; use vortex_session::VortexSession; pub(crate) static SESSION: LazyLock = LazyLock::new(|| VortexSession::empty().with::()); + +/// We define a dummy extension type here for testing purposes. +#[derive(Debug, Clone, Default)] +struct Even; + +impl ExtDTypeVTable for Even { + type Metadata = EmptyMetadata; + + fn id(&self) -> ExtID { + ExtID::new_ref("test.even") + } + + fn validate(&self, _options: &Self::Metadata, storage_dtype: &DType) -> VortexResult<()> { + vortex_ensure!(storage_dtype.is_primitive()); + Ok(()) + } +} From 5a57ba7a1dd9e57f1685544a3bf7aa8268a47999 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 30 Jan 2026 12:20:47 -0500 Subject: [PATCH 03/21] Scan API Signed-off-by: Nicholas Gates --- vortex-scalar/src/arrow/mod.rs | 2 +- vortex-scalar/src/arrow/tests.rs | 142 ++++++---------------- vortex-scalar/src/datetime/date.rs | 24 ++-- vortex-scalar/src/datetime/mod.rs | 37 +++++- vortex-scalar/src/datetime/time.rs | 24 ++-- vortex-scalar/src/datetime/timestamp.rs | 60 ++++------ vortex-scalar/src/display.rs | 2 +- vortex-scalar/src/extension/matcher.rs | 35 ++++++ vortex-scalar/src/extension/mod.rs | 150 +++++++++++++++++++++--- vortex-scalar/src/extension/vtable.rs | 18 +-- vortex-scalar/src/lib.rs | 1 + vortex-scalar/src/scalar.rs | 28 +++-- vortex-scalar/src/tests/casting.rs | 12 +- vortex-scalar/src/tests/mod.rs | 50 +++++++- vortex-scalar/src/tests/primitives.rs | 6 +- vortex-scalar/src/v2.rs | 146 +++++++++++++++++++++++ 16 files changed, 514 insertions(+), 223 deletions(-) create mode 100644 vortex-scalar/src/extension/matcher.rs create mode 100644 vortex-scalar/src/v2.rs diff --git a/vortex-scalar/src/arrow/mod.rs b/vortex-scalar/src/arrow/mod.rs index 50213d920cf..6891474a059 100644 --- a/vortex-scalar/src/arrow/mod.rs +++ b/vortex-scalar/src/arrow/mod.rs @@ -128,7 +128,7 @@ impl TryFrom<&Scalar> for Arc { vortex_bail!("Cannot convert extension scalar {} to Arrow", ext.id()) }; - let storage_scalar = value.as_extension().storage(); + let storage_scalar = value.as_extension_storage(); let primitive = storage_scalar .as_primitive_opt() .ok_or_else(|| vortex_err!("Expected primitive scalar"))?; diff --git a/vortex-scalar/src/arrow/tests.rs b/vortex-scalar/src/arrow/tests.rs index e4eb451ac42..6d090b328c0 100644 --- a/vortex-scalar/src/arrow/tests.rs +++ b/vortex-scalar/src/arrow/tests.rs @@ -13,11 +13,11 @@ use vortex_dtype::datetime::Time; use vortex_dtype::datetime::TimeUnit; use vortex_dtype::datetime::Timestamp; use vortex_dtype::datetime::TimestampOptions; -use vortex_dtype::extension::ExtDTypeVTable; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; +use vortex_dtype::extension::EmptyMetadata; use crate::Scalar; +use crate::datetime::timestamp::TimestampValue; +use crate::tests::Even; #[test] fn test_null_scalar_to_arrow() { @@ -260,123 +260,70 @@ fn test_list_scalar_to_arrow_todo() { #[test] #[should_panic(expected = "Cannot convert extension scalar")] fn test_non_temporal_extension_to_arrow_todo() { - use vortex_dtype::ExtID; - - #[derive(Debug, Clone, Default)] - struct SomeExt; - impl ExtDTypeVTable for SomeExt { - type Metadata = String; - - fn id(&self) -> ExtID { - ExtID::new_ref("some_ext") - } - - fn serialize(&self, _options: &Self::Metadata) -> VortexResult> { - vortex_bail!("not implemented") - } - - fn deserialize(&self, _data: &[u8]) -> VortexResult { - vortex_bail!("not implemented") - } - - fn validate(&self, _options: &Self::Metadata, _storage_dtype: &DType) -> VortexResult<()> { - Ok(()) - } - } - - let scalar = Scalar::extension::( - "".into(), - Scalar::primitive(42i32, Nullability::NonNullable), - ); - + let scalar = + Scalar::extension::(EmptyMetadata, Some(32), Nullability::NonNullable).unwrap(); Arc::::try_from(&scalar).unwrap(); } #[rstest] -#[case(TimeUnit::Nanoseconds, PType::I64, 123456789i64)] -#[case(TimeUnit::Microseconds, PType::I64, 123456789i64)] -#[case(TimeUnit::Milliseconds, PType::I32, 123456i64)] -#[case(TimeUnit::Seconds, PType::I32, 1234i64)] -fn test_temporal_time_to_arrow( - #[case] time_unit: TimeUnit, - #[case] ptype: PType, - #[case] value: i64, -) { - let scalar = Scalar::extension::