From 84235335f17c1497b44de8d2309f645f1fda1385 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Wed, 22 Apr 2026 14:22:13 +0100 Subject: [PATCH 01/21] arrow dtype extension conversion Signed-off-by: Baris Palaska --- Cargo.lock | 2 + Cargo.toml | 1 + vortex-array/Cargo.toml | 1 + vortex-array/public-api.lock | 36 ++ vortex-array/src/dtype/arrow.rs | 468 ++++++++++++++++++--- vortex-array/src/extension/mod.rs | 2 +- vortex-array/src/extension/tests/mod.rs | 2 +- vortex-tensor/Cargo.toml | 1 + vortex-tensor/src/lib.rs | 2 + vortex-tensor/src/tests/arrow_roundtrip.rs | 164 ++++++++ 10 files changed, 609 insertions(+), 70 deletions(-) create mode 100644 vortex-tensor/src/tests/arrow_roundtrip.rs diff --git a/Cargo.lock b/Cargo.lock index 945d35e3e08..0ad391e531f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10255,6 +10255,7 @@ dependencies = [ "arrow-select 58.1.0", "arrow-string 58.1.0", "async-lock", + "base64", "bytes", "cfg-if", "codspeed-divan-compat", @@ -11092,6 +11093,7 @@ dependencies = [ name = "vortex-tensor" version = "0.1.0" dependencies = [ + "arrow-schema 58.1.0", "codspeed-divan-compat", "half", "itertools 0.14.0", diff --git a/Cargo.toml b/Cargo.toml index 12e0acb882d..561f696d6c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -106,6 +106,7 @@ async-lock = "3.4" async-stream = "0.3.6" async-trait = "0.1.89" base16ct = "1.0.0" +base64 = "0.22" bigdecimal = "0.4.8" bindgen = "0.72.0" bit-vec = "0.9.0" diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml index f9adbeb99db..19bb5aef43f 100644 --- a/vortex-array/Cargo.toml +++ b/vortex-array/Cargo.toml @@ -32,6 +32,7 @@ arrow-schema = { workspace = true } arrow-select = { workspace = true } arrow-string = { workspace = true } async-lock = { workspace = true } +base64 = { workspace = true } bytes = { workspace = true } cfg-if = { workspace = true } cudarc = { workspace = true, optional = true } diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 8ee3c97f17b..554e6509131 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -8492,6 +8492,26 @@ impl vortex_array::dtype::arrow::FromArrowType Self +pub trait vortex_array::dtype::arrow::FromArrowWithSession: core::marker::Sized + +pub fn vortex_array::dtype::arrow::FromArrowWithSession::from_arrow_with_session(value: T, session: &vortex_session::VortexSession) -> Self + +impl vortex_array::dtype::arrow::FromArrowWithSession<&alloc::sync::Arc> for vortex_array::dtype::DType + +pub fn vortex_array::dtype::DType::from_arrow_with_session(value: &arrow_schema::schema::SchemaRef, session: &vortex_session::VortexSession) -> Self + +impl vortex_array::dtype::arrow::FromArrowWithSession<&arrow_schema::field::Field> for vortex_array::dtype::DType + +pub fn vortex_array::dtype::DType::from_arrow_with_session(field: &arrow_schema::field::Field, session: &vortex_session::VortexSession) -> Self + +impl vortex_array::dtype::arrow::FromArrowWithSession<&arrow_schema::fields::Fields> for vortex_array::dtype::StructFields + +pub fn vortex_array::dtype::StructFields::from_arrow_with_session(value: &arrow_schema::fields::Fields, session: &vortex_session::VortexSession) -> Self + +impl vortex_array::dtype::arrow::FromArrowWithSession<&arrow_schema::schema::Schema> for vortex_array::dtype::DType + +pub fn vortex_array::dtype::DType::from_arrow_with_session(value: &arrow_schema::schema::Schema, session: &vortex_session::VortexSession) -> Self + pub trait vortex_array::dtype::arrow::TryFromArrowType: core::marker::Sized pub fn vortex_array::dtype::arrow::TryFromArrowType::try_from_arrow(value: T) -> vortex_error::VortexResult @@ -9038,6 +9058,18 @@ impl vortex_array::dtype::arrow::FromArrowType Self +impl vortex_array::dtype::arrow::FromArrowWithSession<&alloc::sync::Arc> for vortex_array::dtype::DType + +pub fn vortex_array::dtype::DType::from_arrow_with_session(value: &arrow_schema::schema::SchemaRef, session: &vortex_session::VortexSession) -> Self + +impl vortex_array::dtype::arrow::FromArrowWithSession<&arrow_schema::field::Field> for vortex_array::dtype::DType + +pub fn vortex_array::dtype::DType::from_arrow_with_session(field: &arrow_schema::field::Field, session: &vortex_session::VortexSession) -> Self + +impl vortex_array::dtype::arrow::FromArrowWithSession<&arrow_schema::schema::Schema> for vortex_array::dtype::DType + +pub fn vortex_array::dtype::DType::from_arrow_with_session(value: &arrow_schema::schema::Schema, session: &vortex_session::VortexSession) -> Self + impl vortex_flatbuffers::FlatBufferRoot for vortex_array::dtype::DType impl vortex_flatbuffers::WriteFlatBuffer for vortex_array::dtype::DType @@ -9992,6 +10024,10 @@ impl vortex_array::dtype::arrow::FromArrowType<&arrow_schema::fields::Fields> fo pub fn vortex_array::dtype::StructFields::from_arrow(value: &arrow_schema::fields::Fields) -> Self +impl vortex_array::dtype::arrow::FromArrowWithSession<&arrow_schema::fields::Fields> for vortex_array::dtype::StructFields + +pub fn vortex_array::dtype::StructFields::from_arrow_with_session(value: &arrow_schema::fields::Fields, session: &vortex_session::VortexSession) -> Self + impl core::iter::traits::collect::FromIterator<(T, V)> for vortex_array::dtype::StructFields where T: core::convert::Into, V: core::convert::Into pub fn vortex_array::dtype::StructFields::from_iter>(iter: I) -> Self diff --git a/vortex-array/src/dtype/arrow.rs b/vortex-array/src/dtype/arrow.rs index 17af749cfc0..96d19e20f19 100644 --- a/vortex-array/src/dtype/arrow.rs +++ b/vortex-array/src/dtype/arrow.rs @@ -23,19 +23,29 @@ use arrow_schema::Schema; use arrow_schema::SchemaBuilder; use arrow_schema::SchemaRef; use arrow_schema::TimeUnit as ArrowTimeUnit; +use arrow_schema::extension::EXTENSION_TYPE_METADATA_KEY; +use arrow_schema::extension::EXTENSION_TYPE_NAME_KEY; +use base64::Engine; +use base64::prelude::BASE64_STANDARD; use vortex_error::VortexError; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_error::vortex_panic; +use vortex_session::VortexSession; +use crate::LEGACY_SESSION; use crate::dtype::DType; use crate::dtype::DecimalDType; use crate::dtype::FieldName; use crate::dtype::Nullability; use crate::dtype::PType; use crate::dtype::StructFields; +use crate::dtype::extension::ExtDTypeRef; +use crate::dtype::extension::ExtId; +use crate::dtype::session::DTypeSession; +use crate::dtype::session::DTypeSessionExt; use crate::extension::datetime::AnyTemporal; use crate::extension::datetime::Date; use crate::extension::datetime::TemporalMetadata; @@ -43,6 +53,8 @@ use crate::extension::datetime::Time; use crate::extension::datetime::TimeUnit; use crate::extension::datetime::Timestamp; +const ARROW_EXT_NAME_VARIANT: &str = "arrow.parquet.variant"; + /// Trait for converting Arrow types to Vortex types. pub trait FromArrowType: Sized { /// Convert the Arrow type to a Vortex type. @@ -55,6 +67,15 @@ pub trait TryFromArrowType: Sized { fn try_from_arrow(value: T) -> VortexResult; } +/// Conversion from Arrow types to Vortex types using a session's extension dtype registry to +/// resolve `ARROW:extension:name` metadata into [`DType::Extension`] values. +/// +/// Unregistered or malformed extension metadata falls back to the storage dtype. +pub trait FromArrowWithSession: Sized { + /// Convert the Arrow type to a Vortex type. + fn from_arrow_with_session(value: T, session: &VortexSession) -> Self; +} + impl TryFromArrowType<&DataType> for PType { fn try_from_arrow(value: &DataType) -> VortexResult { match value { @@ -126,25 +147,44 @@ impl TryFrom for ArrowTimeUnit { impl FromArrowType for DType { fn from_arrow(value: SchemaRef) -> Self { - Self::from_arrow(value.as_ref()) + Self::from_arrow_with_session(value.as_ref(), &LEGACY_SESSION) } } impl FromArrowType<&Schema> for DType { fn from_arrow(value: &Schema) -> Self { + Self::from_arrow_with_session(value, &LEGACY_SESSION) + } +} + +impl FromArrowType<&Fields> for StructFields { + fn from_arrow(value: &Fields) -> Self { + Self::from_arrow_with_session(value, &LEGACY_SESSION) + } +} + +impl FromArrowWithSession<&SchemaRef> for DType { + fn from_arrow_with_session(value: &SchemaRef, session: &VortexSession) -> Self { + Self::from_arrow_with_session(value.as_ref(), session) + } +} + +impl FromArrowWithSession<&Schema> for DType { + fn from_arrow_with_session(value: &Schema, session: &VortexSession) -> Self { Self::Struct( - StructFields::from_arrow(value.fields()), + StructFields::from_arrow_with_session(value.fields(), session), Nullability::NonNullable, // Must match From for Array ) } } -impl FromArrowType<&Fields> for StructFields { - fn from_arrow(value: &Fields) -> Self { +impl FromArrowWithSession<&Fields> for StructFields { + fn from_arrow_with_session(value: &Fields, session: &VortexSession) -> Self { + let dtypes = session.dtypes(); StructFields::from_iter(value.into_iter().map(|f| { ( FieldName::from(f.name().as_str()), - DType::from_arrow(f.as_ref()), + dtype_from_field(f.as_ref(), &dtypes), ) })) } @@ -210,15 +250,108 @@ impl FromArrowType<(&DataType, Nullability)> for DType { impl FromArrowType<&Field> for DType { fn from_arrow(field: &Field) -> Self { - if field - .metadata() - .get("ARROW:extension:name") - .map(|s| s.as_str()) - == Some("arrow.parquet.variant") - { - return DType::Variant(field.is_nullable().into()); + Self::from_arrow_with_session(field, &LEGACY_SESSION) + } +} + +impl FromArrowWithSession<&Field> for DType { + fn from_arrow_with_session(field: &Field, session: &VortexSession) -> Self { + dtype_from_field(field, &session.dtypes()) + } +} + +/// Convert an Arrow Field to a [`DType`] using a pre-borrowed [`DTypeSession`] for extension +/// lookup. Used by the `&Fields` and `&Field` impls so the session handle is acquired once per +/// schema rather than once per field. +fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { + let ext_name = field.extension_type_name(); + + // Variant maps to its own DType variant, not DType::Extension. + if ext_name.is_some_and(|s| s == ARROW_EXT_NAME_VARIANT) { + return DType::Variant(field.is_nullable().into()); + } + + let storage_dtype = storage_dtype_from_field(field, dtypes); + + let Some(ext_name) = ext_name else { + return storage_dtype; + }; + + let ext_id = ExtId::new(ext_name); + let Some(plugin) = dtypes.registry().find(&ext_id) else { + tracing::warn!( + "Arrow field {:?} extension id {:?} not registered; using storage dtype", + field.name(), + ext_name, + ); + return storage_dtype; + }; + + let metadata_bytes = match decode_extension_metadata(field) { + Ok(bytes) => bytes, + Err(e) => { + tracing::warn!( + "Arrow field {:?} extension id {:?} has malformed metadata ({}); \ + using storage dtype", + field.name(), + ext_name, + e, + ); + return storage_dtype; + } + }; + + match plugin.deserialize(&metadata_bytes, storage_dtype.clone()) { + Ok(ext_ref) => DType::Extension(ext_ref), + Err(e) => { + tracing::warn!( + "Arrow field {:?} extension id {:?} failed to deserialize ({}); \ + using storage dtype", + field.name(), + ext_name, + e, + ); + storage_dtype + } + } +} + +/// Decodes base64-encoded extension metadata. Missing / empty values yield an empty vector. +fn decode_extension_metadata(field: &Field) -> VortexResult> { + match field.extension_type_metadata() { + None | Some("") => Ok(Vec::new()), + Some(s) => BASE64_STANDARD + .decode(s) + .map_err(|e| vortex_err!("failed to base64-decode {EXTENSION_TYPE_METADATA_KEY}: {e}")), + } +} + +/// Recursively build the storage [`DType`] for an Arrow Field, threading `dtypes` through +/// nested child fields so nested extensions are also resolved. +fn storage_dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { + let nullability: Nullability = field.is_nullable().into(); + match field.data_type() { + DataType::Struct(f) => DType::Struct( + StructFields::from_iter(f.into_iter().map(|child| { + ( + FieldName::from(child.name().as_str()), + dtype_from_field(child.as_ref(), dtypes), + ) + })), + nullability, + ), + DataType::List(e) + | DataType::LargeList(e) + | DataType::ListView(e) + | DataType::LargeListView(e) => { + DType::List(Arc::new(dtype_from_field(e.as_ref(), dtypes)), nullability) } - Self::from_arrow((field.data_type(), field.is_nullable().into())) + DataType::FixedSizeList(e, size) => DType::FixedSizeList( + Arc::new(dtype_from_field(e.as_ref(), dtypes)), + *size as u32, + nullability, + ), + other => DType::from_arrow((other, nullability)), } } @@ -235,23 +368,7 @@ impl DType { let mut builder = SchemaBuilder::with_capacity(struct_dtype.names().len()); for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) { - let field = if field_dtype.is_variant() { - let storage = DataType::Struct(variant_storage_fields_minimal()); - Field::new(field_name.as_ref(), storage, field_dtype.is_nullable()).with_metadata( - [( - "ARROW:extension:name".to_owned(), - "arrow.parquet.variant".to_owned(), - )] - .into(), - ) - } else { - Field::new( - field_name.as_ref(), - field_dtype.to_arrow_dtype()?, - field_dtype.is_nullable(), - ) - }; - builder.push(field); + builder.push(field_from_dtype(field_name.as_ref(), &field_dtype)?); } Ok(builder.finish()) @@ -296,26 +413,25 @@ impl DType { // There are four kinds of lists: List (32-bit offsets), Large List (64-bit), List View // (32-bit), Large List View (64-bit). We cannot both guarantee zero-copy and commit to an // Arrow dtype because we do not how large our offsets are. - DType::List(elem_dtype, _) => DataType::List(FieldRef::new(Field::new_list_field( - elem_dtype.to_arrow_dtype()?, - elem_dtype.nullability().into(), - ))), + DType::List(elem_dtype, _) => DataType::List(FieldRef::new(field_from_dtype( + Field::LIST_FIELD_DEFAULT_NAME, + elem_dtype, + )?)), DType::FixedSizeList(elem_dtype, size, _) => DataType::FixedSizeList( - FieldRef::new(Field::new_list_field( - elem_dtype.to_arrow_dtype()?, - elem_dtype.nullability().into(), - )), + FieldRef::new(field_from_dtype( + Field::LIST_FIELD_DEFAULT_NAME, + elem_dtype, + )?), *size as i32, ), DType::Struct(struct_dtype, _) => { let mut fields = Vec::with_capacity(struct_dtype.names().len()); for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields()) { - fields.push(FieldRef::from(Field::new( + fields.push(FieldRef::from(field_from_dtype( field_name.as_ref(), - field_dt.to_arrow_dtype()?, - field_dt.is_nullable(), - ))); + &field_dt, + )?)); } DataType::Struct(Fields::from(fields)) @@ -324,37 +440,90 @@ impl DType { "DType::Variant requires Arrow Field metadata; use to_arrow_schema or a Field helper" ), DType::Extension(ext_dtype) => { - // Try and match against the known extension DTypes. - if let Some(temporal) = ext_dtype.metadata_opt::() { - return Ok(match temporal { - TemporalMetadata::Timestamp(unit, tz) => { - DataType::Timestamp(ArrowTimeUnit::try_from(*unit)?, tz.clone()) - } - TemporalMetadata::Date(unit) => match unit { - TimeUnit::Days => DataType::Date32, - TimeUnit::Milliseconds => DataType::Date64, - TimeUnit::Nanoseconds | TimeUnit::Microseconds | TimeUnit::Seconds => { - vortex_panic!(InvalidArgument: "Invalid TimeUnit {} for {}", unit, ext_dtype.id()) - } - }, - TemporalMetadata::Time(unit) => match unit { - TimeUnit::Seconds => DataType::Time32(ArrowTimeUnit::Second), - TimeUnit::Milliseconds => DataType::Time32(ArrowTimeUnit::Millisecond), - TimeUnit::Microseconds => DataType::Time64(ArrowTimeUnit::Microsecond), - TimeUnit::Nanoseconds => DataType::Time64(ArrowTimeUnit::Nanosecond), - TimeUnit::Days => { - vortex_panic!(InvalidArgument: "Invalid TimeUnit {} for {}", unit, ext_dtype.id()) - } - }, - }); - }; - - vortex_bail!("Unsupported extension type \"{}\"", ext_dtype.id()) + if let Some(native) = native_arrow_dtype_for_extension(ext_dtype) { + return Ok(native); + } + // Extension identity lives on the Field (see `field_from_dtype`), not on + // DataType, so here we only encode the storage type. + ext_dtype.storage_dtype().to_arrow_dtype()? } }) } } +/// Build an Arrow [`Field`], attaching `ARROW:extension:name` and, when present, +/// `ARROW:extension:metadata` for extensions and Variant that have no native Arrow mapping. +fn field_from_dtype(name: &str, dtype: &DType) -> VortexResult { + if dtype.is_variant() { + let storage = DataType::Struct(variant_storage_fields_minimal()); + return Ok( + Field::new(name, storage, dtype.is_nullable()).with_metadata( + [( + EXTENSION_TYPE_NAME_KEY.to_owned(), + ARROW_EXT_NAME_VARIANT.to_owned(), + )] + .into(), + ), + ); + } + + if let DType::Extension(ext) = dtype { + // Native Arrow mapping carries the semantics in DataType; emitting extension metadata + // on top would break consumers that only understand native Arrow types. + if let Some(native) = native_arrow_dtype_for_extension(ext) { + return Ok(Field::new(name, native, dtype.is_nullable())); + } + + let storage_arrow = ext.storage_dtype().to_arrow_dtype()?; + let mut metadata = vec![( + EXTENSION_TYPE_NAME_KEY.to_owned(), + ext.id().as_str().to_owned(), + )]; + let ext_meta_bytes = ext.serialize_metadata()?; + if !ext_meta_bytes.is_empty() { + metadata.push(( + EXTENSION_TYPE_METADATA_KEY.to_owned(), + BASE64_STANDARD.encode(&ext_meta_bytes), + )); + } + return Ok(Field::new(name, storage_arrow, dtype.is_nullable()) + .with_metadata(metadata.into_iter().collect())); + } + + Ok(Field::new( + name, + dtype.to_arrow_dtype()?, + dtype.is_nullable(), + )) +} + +/// Returns the native Arrow [`DataType`] for extensions Arrow models directly (e.g. temporal). +/// `None` means the extension should round-trip via storage + Field metadata. +fn native_arrow_dtype_for_extension(ext_dtype: &ExtDTypeRef) -> Option { + let temporal = ext_dtype.metadata_opt::()?; + Some(match temporal { + TemporalMetadata::Timestamp(unit, tz) => { + DataType::Timestamp(ArrowTimeUnit::try_from(*unit).ok()?, tz.clone()) + } + TemporalMetadata::Date(unit) => match unit { + TimeUnit::Days => DataType::Date32, + TimeUnit::Milliseconds => DataType::Date64, + TimeUnit::Nanoseconds | TimeUnit::Microseconds | TimeUnit::Seconds => { + vortex_panic!(InvalidArgument: "Invalid TimeUnit {} for {}", unit, ext_dtype.id()) + } + }, + TemporalMetadata::Time(unit) => match unit { + TimeUnit::Seconds => DataType::Time32(ArrowTimeUnit::Second), + TimeUnit::Milliseconds => DataType::Time32(ArrowTimeUnit::Millisecond), + TimeUnit::Microseconds => DataType::Time64(ArrowTimeUnit::Microsecond), + TimeUnit::Nanoseconds => DataType::Time64(ArrowTimeUnit::Nanosecond), + TimeUnit::Days => { + vortex_panic!(InvalidArgument: "Invalid TimeUnit {} for {}", unit, ext_dtype.id()) + } + }, + }) +} + fn variant_storage_fields_minimal() -> Fields { Fields::from(vec![ Field::new("metadata", DataType::Binary, false), @@ -561,4 +730,167 @@ mod test { assert_eq!(original_dtype, roundtripped_dtype); } + + mod extension_roundtrip { + use vortex_session::VortexSession; + + use super::*; + use crate::dtype::extension::ExtDType; + use crate::dtype::session::DTypeSession; + use crate::dtype::session::DTypeSessionExt; + use crate::extension::tests::divisible_int::DivisibleInt; + use crate::extension::tests::divisible_int::Divisor; + + fn session_with_divisible_int() -> VortexSession { + let session = VortexSession::empty().with::(); + session.dtypes().register(DivisibleInt); + session + } + + fn divisible_ext(divisor: u64) -> DType { + let ext = ExtDType::::try_new( + Divisor(divisor), + DType::Primitive(PType::U64, Nullability::NonNullable), + ) + .unwrap(); + DType::Extension(ext.erased()) + } + + #[test] + fn forward_emits_name_and_base64_metadata() { + let dtype = DType::struct_([("div", divisible_ext(7))], Nullability::NonNullable); + + let schema = dtype.to_arrow_schema().unwrap(); + let field = schema.field(0); + + assert_eq!(field.data_type(), &DataType::UInt64); + assert_eq!( + field + .metadata() + .get(EXTENSION_TYPE_NAME_KEY) + .map(String::as_str), + Some("test.divisible_int"), + ); + + let meta_b64 = field.metadata().get(EXTENSION_TYPE_METADATA_KEY).unwrap(); + let decoded = BASE64_STANDARD.decode(meta_b64).unwrap(); + assert_eq!(decoded, 7u64.to_le_bytes()); + } + + #[test] + fn reverse_with_session_recovers_extension() { + let original = DType::struct_([("div", divisible_ext(42))], Nullability::NonNullable); + + let schema = original.to_arrow_schema().unwrap(); + let session = session_with_divisible_int(); + let recovered = DType::from_arrow_with_session(&schema, &session); + + assert_eq!(recovered, original); + } + + #[test] + fn reverse_without_registration_falls_back_to_storage() { + let original = DType::struct_([("div", divisible_ext(13))], Nullability::NonNullable); + + let schema = original.to_arrow_schema().unwrap(); + // DivisibleInt is not in the default DTypeSession. + let session = VortexSession::empty().with::(); + let recovered = DType::from_arrow_with_session(&schema, &session); + + let expected = DType::struct_( + [( + "div", + DType::Primitive(PType::U64, Nullability::NonNullable), + )], + Nullability::NonNullable, + ); + assert_eq!(recovered, expected); + } + + #[test] + fn nested_struct_roundtrip() { + let inner = DType::struct_([("div", divisible_ext(3))], Nullability::Nullable); + let original = DType::struct_([("inner", inner)], Nullability::NonNullable); + + let schema = original.to_arrow_schema().unwrap(); + let session = session_with_divisible_int(); + let recovered = DType::from_arrow_with_session(&schema, &session); + + assert_eq!(recovered, original); + } + + #[test] + fn list_element_roundtrip() { + let list_dtype = DType::List(Arc::new(divisible_ext(5)), Nullability::Nullable); + let original = DType::struct_([("xs", list_dtype)], Nullability::NonNullable); + + let schema = original.to_arrow_schema().unwrap(); + let session = session_with_divisible_int(); + let recovered = DType::from_arrow_with_session(&schema, &session); + + assert_eq!(recovered, original); + } + + #[test] + fn temporal_native_path_emits_no_extension_metadata() { + let ts = Timestamp::new_with_tz(TimeUnit::Microseconds, None, Nullability::Nullable); + let original = DType::struct_( + [("t", DType::Extension(ts.erased()))], + Nullability::NonNullable, + ); + + let schema = original.to_arrow_schema().unwrap(); + let field = schema.field(0); + + assert!(matches!( + field.data_type(), + DataType::Timestamp(ArrowTimeUnit::Microsecond, None) + )); + assert!(field.metadata().get(EXTENSION_TYPE_NAME_KEY).is_none()); + + let recovered = DType::from_arrow(&schema); + assert_eq!(recovered, original); + } + + #[test] + fn variant_still_roundtrips() { + let original = DType::struct_( + [("v", DType::Variant(Nullability::NonNullable))], + Nullability::NonNullable, + ); + let schema = original.to_arrow_schema().unwrap(); + let recovered = DType::from_arrow(&schema); + assert_eq!(recovered, original); + } + + #[test] + fn malformed_metadata_falls_back_to_storage() { + let field = Field::new("div", DataType::UInt64, false).with_metadata( + [ + ( + EXTENSION_TYPE_NAME_KEY.to_owned(), + "test.divisible_int".to_owned(), + ), + ( + EXTENSION_TYPE_METADATA_KEY.to_owned(), + "not_base64!!!".to_owned(), + ), + ] + .into(), + ); + let schema = Schema::new(Fields::from(vec![field])); + + let session = session_with_divisible_int(); + let recovered = DType::from_arrow_with_session(&schema, &session); + + let expected = DType::struct_( + [( + "div", + DType::Primitive(PType::U64, Nullability::NonNullable), + )], + Nullability::NonNullable, + ); + assert_eq!(recovered, expected); + } + } } diff --git a/vortex-array/src/extension/mod.rs b/vortex-array/src/extension/mod.rs index 9f81e7fb310..077af4a8337 100644 --- a/vortex-array/src/extension/mod.rs +++ b/vortex-array/src/extension/mod.rs @@ -9,7 +9,7 @@ pub mod datetime; pub mod uuid; #[cfg(test)] -mod tests; +pub(crate) mod tests; /// An empty metadata struct for extension dtypes that do not require any metadata. #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/vortex-array/src/extension/tests/mod.rs b/vortex-array/src/extension/tests/mod.rs index 31df677e61d..f4ab560fbf8 100644 --- a/vortex-array/src/extension/tests/mod.rs +++ b/vortex-array/src/extension/tests/mod.rs @@ -3,4 +3,4 @@ //! Test extension types for exercising the [`ExtVTable`] contract. -mod divisible_int; +pub(crate) mod divisible_int; diff --git a/vortex-tensor/Cargo.toml b/vortex-tensor/Cargo.toml index 2f92ce5a107..8c405ce5d70 100644 --- a/vortex-tensor/Cargo.toml +++ b/vortex-tensor/Cargo.toml @@ -31,6 +31,7 @@ num-traits = { workspace = true } prost = { workspace = true } [dev-dependencies] +arrow-schema = { workspace = true } divan = { workspace = true } mimalloc = { workspace = true } rand = { workspace = true } diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index c748bdd9f43..a47ab88b061 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -80,4 +80,6 @@ mod tests { crate::initialize(&session); session }); + + mod arrow_roundtrip; } diff --git a/vortex-tensor/src/tests/arrow_roundtrip.rs b/vortex-tensor/src/tests/arrow_roundtrip.rs new file mode 100644 index 00000000000..2b3da425971 --- /dev/null +++ b/vortex-tensor/src/tests/arrow_roundtrip.rs @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Arrow ↔ DType round-trip tests for tensor extension types. + +use std::sync::Arc; + +use arrow_schema::DataType; +use arrow_schema::TimeUnit as ArrowTimeUnit; +use arrow_schema::extension::EXTENSION_TYPE_METADATA_KEY; +use arrow_schema::extension::EXTENSION_TYPE_NAME_KEY; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::dtype::arrow::FromArrowWithSession; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::extension::datetime::TimeUnit; +use vortex_array::extension::datetime::Timestamp; + +use crate::tests::SESSION; +use crate::types::fixed_shape::FixedShapeTensor; +use crate::types::fixed_shape::FixedShapeTensorMetadata; +use crate::types::vector::Vector; + +const VECTOR_EXT_NAME: &str = "vortex.tensor.vector"; +const FIXED_SHAPE_EXT_NAME: &str = "vortex.fixed_shape_tensor"; + +fn vector_dtype(len: u32) -> DType { + let storage = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)), + len, + Nullability::NonNullable, + ); + let ext = ExtDType::::try_new(vortex_array::extension::EmptyMetadata, storage).unwrap(); + DType::Extension(ext.erased()) +} + +fn fixed_shape_dtype(metadata: FixedShapeTensorMetadata, element_count: u32) -> DType { + let storage = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)), + element_count, + Nullability::NonNullable, + ); + let ext = ExtDType::::try_new(metadata, storage).unwrap(); + DType::Extension(ext.erased()) +} + +#[test] +fn vector_forward_carries_extension_name() { + let original = DType::struct_([("embedding", vector_dtype(4))], Nullability::NonNullable); + + let schema = original.to_arrow_schema().unwrap(); + let field = schema.field(0); + + assert_eq!( + field + .metadata() + .get(EXTENSION_TYPE_NAME_KEY) + .map(String::as_str), + Some(VECTOR_EXT_NAME), + ); + // EmptyMetadata: no metadata key emitted. + assert!(field.metadata().get(EXTENSION_TYPE_METADATA_KEY).is_none()); + + let DataType::FixedSizeList(element, size) = field.data_type() else { + panic!("expected FixedSizeList, got {:?}", field.data_type()); + }; + assert_eq!(*size, 4); + assert_eq!(element.data_type(), &DataType::Float32); +} + +#[test] +fn vector_roundtrip_with_session() { + let original = DType::struct_([("embedding", vector_dtype(128))], Nullability::NonNullable); + + let schema = original.to_arrow_schema().unwrap(); + let recovered = DType::from_arrow_with_session(&schema, &SESSION); + + assert_eq!(recovered, original); +} + +#[test] +fn vector_without_registration_falls_back_to_fsl() { + let original = DType::struct_([("embedding", vector_dtype(16))], Nullability::NonNullable); + + let empty_session = vortex_session::VortexSession::empty(); + let schema = original.to_arrow_schema().unwrap(); + let recovered = DType::from_arrow_with_session(&schema, &empty_session); + + let expected = DType::struct_( + [( + "embedding", + DType::FixedSizeList( + Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)), + 16, + Nullability::NonNullable, + ), + )], + Nullability::NonNullable, + ); + assert_eq!(recovered, expected); +} + +#[test] +fn fixed_shape_tensor_metadata_roundtrip() { + let metadata = FixedShapeTensorMetadata::new(vec![2, 3, 4]) + .with_dim_names(vec!["x".into(), "y".into(), "z".into()]) + .unwrap() + .with_permutation(vec![2, 0, 1]) + .unwrap(); + + let original = DType::struct_( + [("tensor", fixed_shape_dtype(metadata, 24))], + Nullability::NonNullable, + ); + + let schema = original.to_arrow_schema().unwrap(); + let field = schema.field(0); + + assert_eq!( + field + .metadata() + .get(EXTENSION_TYPE_NAME_KEY) + .map(String::as_str), + Some(FIXED_SHAPE_EXT_NAME), + ); + assert!(field.metadata().get(EXTENSION_TYPE_METADATA_KEY).is_some()); + + let recovered = DType::from_arrow_with_session(&schema, &SESSION); + assert_eq!(recovered, original); +} + +#[test] +fn tensor_inside_nested_struct_roundtrips() { + let inner = DType::struct_([("embedding", vector_dtype(8))], Nullability::Nullable); + let original = DType::struct_( + [("inner", inner), ("id", DType::Utf8(Nullability::Nullable))], + Nullability::NonNullable, + ); + + let schema = original.to_arrow_schema().unwrap(); + let recovered = DType::from_arrow_with_session(&schema, &SESSION); + + assert_eq!(recovered, original); +} + +#[test] +fn temporal_extension_still_uses_native_arrow() { + let ts = Timestamp::new_with_tz(TimeUnit::Microseconds, None, Nullability::Nullable); + let original = DType::struct_( + [("ts", DType::Extension(ts.erased()))], + Nullability::NonNullable, + ); + + let schema = original.to_arrow_schema().unwrap(); + let field = schema.field(0); + + assert!(matches!( + field.data_type(), + DataType::Timestamp(ArrowTimeUnit::Microsecond, None) + )); + assert!(field.metadata().get(EXTENSION_TYPE_NAME_KEY).is_none()); + assert!(field.metadata().get(EXTENSION_TYPE_METADATA_KEY).is_none()); +} From 47efbdbd4ba3d4ec6f222b130d6253b36c121d92 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Wed, 22 Apr 2026 16:51:41 +0100 Subject: [PATCH 02/21] use arrow canonical extension name Signed-off-by: Baris Palaska --- Cargo.lock | 2 + vortex-array/src/dtype/arrow.rs | 69 ++++++++-- vortex-tensor/Cargo.toml | 3 + vortex-tensor/src/tests/arrow_roundtrip.rs | 11 +- .../src/types/fixed_shape/canonical.rs | 126 ++++++++++++++++++ vortex-tensor/src/types/fixed_shape/mod.rs | 2 +- vortex-tensor/src/types/fixed_shape/proto.rs | 90 ------------- vortex-tensor/src/types/fixed_shape/vtable.rs | 6 +- 8 files changed, 201 insertions(+), 108 deletions(-) create mode 100644 vortex-tensor/src/types/fixed_shape/canonical.rs delete mode 100644 vortex-tensor/src/types/fixed_shape/proto.rs diff --git a/Cargo.lock b/Cargo.lock index 0ad391e531f..10743c6fe81 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11103,6 +11103,8 @@ dependencies = [ "rand 0.10.1", "rand_distr 0.6.0", "rstest", + "serde", + "serde_json", "vortex-array", "vortex-btrblocks", "vortex-buffer", diff --git a/vortex-array/src/dtype/arrow.rs b/vortex-array/src/dtype/arrow.rs index 96d19e20f19..d78004f19c9 100644 --- a/vortex-array/src/dtype/arrow.rs +++ b/vortex-array/src/dtype/arrow.rs @@ -55,6 +55,26 @@ use crate::extension::datetime::Timestamp; const ARROW_EXT_NAME_VARIANT: &str = "arrow.parquet.variant"; +/// `(vortex_id, arrow_canonical_name)` pairs — single source of truth for bijection between +/// Vortex-internal extension ids and Arrow canonical extension names. Canonical extensions +/// serialize metadata as raw UTF-8 (typically JSON) rather than base64-wrapped bytes. +const CANONICAL_ALIASES: &[(&str, &str)] = + &[("vortex.fixed_shape_tensor", "arrow.fixed_shape_tensor")]; + +fn vortex_id_to_arrow_canonical(vortex_id: &str) -> Option<&'static str> { + CANONICAL_ALIASES + .iter() + .find(|(v, _)| *v == vortex_id) + .map(|(_, a)| *a) +} + +fn arrow_canonical_to_vortex_id(arrow_name: &str) -> Option<&'static str> { + CANONICAL_ALIASES + .iter() + .find(|(_, a)| *a == arrow_name) + .map(|(v, _)| *v) +} + /// Trait for converting Arrow types to Vortex types. pub trait FromArrowType: Sized { /// Convert the Arrow type to a Vortex type. @@ -277,7 +297,10 @@ fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { return storage_dtype; }; - let ext_id = ExtId::new(ext_name); + let canonical_alias = arrow_canonical_to_vortex_id(ext_name); + let is_canonical = canonical_alias.is_some(); + let ext_id = ExtId::new(canonical_alias.unwrap_or(ext_name)); + let Some(plugin) = dtypes.registry().find(&ext_id) else { tracing::warn!( "Arrow field {:?} extension id {:?} not registered; using storage dtype", @@ -287,7 +310,7 @@ fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { return storage_dtype; }; - let metadata_bytes = match decode_extension_metadata(field) { + let metadata_bytes = match decode_extension_metadata(field, is_canonical) { Ok(bytes) => bytes, Err(e) => { tracing::warn!( @@ -316,10 +339,15 @@ fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { } } -/// Decodes base64-encoded extension metadata. Missing / empty values yield an empty vector. -fn decode_extension_metadata(field: &Field) -> VortexResult> { +/// Decode extension metadata bytes from a Field. +/// +/// Canonical Arrow extensions store UTF-8 bytes directly (e.g. JSON). Non-canonical extensions +/// store base64-encoded bytes so that arbitrary binary plugin output survives a String-typed +/// metadata channel. +fn decode_extension_metadata(field: &Field, is_canonical: bool) -> VortexResult> { match field.extension_type_metadata() { None | Some("") => Ok(Vec::new()), + Some(s) if is_canonical => Ok(s.as_bytes().to_vec()), Some(s) => BASE64_STANDARD .decode(s) .map_err(|e| vortex_err!("failed to base64-decode {EXTENSION_TYPE_METADATA_KEY}: {e}")), @@ -475,16 +503,25 @@ fn field_from_dtype(name: &str, dtype: &DType) -> VortexResult { } let storage_arrow = ext.storage_dtype().to_arrow_dtype()?; - let mut metadata = vec![( - EXTENSION_TYPE_NAME_KEY.to_owned(), - ext.id().as_str().to_owned(), - )]; let ext_meta_bytes = ext.serialize_metadata()?; - if !ext_meta_bytes.is_empty() { - metadata.push(( - EXTENSION_TYPE_METADATA_KEY.to_owned(), + let (ext_name, meta_str) = match vortex_id_to_arrow_canonical(ext.id().as_str()) { + Some(canonical) => { + // Canonical Arrow extensions specify a UTF-8 metadata format (typically JSON), + // read as-is by arrow-rs / pyarrow. The plugin owns producing those bytes. + let s = String::from_utf8(ext_meta_bytes).map_err(|e| { + vortex_err!("canonical extension {canonical} metadata must be valid UTF-8: {e}") + })?; + (canonical.to_owned(), s) + } + None => ( + ext.id().as_str().to_owned(), BASE64_STANDARD.encode(&ext_meta_bytes), - )); + ), + }; + + let mut metadata = vec![(EXTENSION_TYPE_NAME_KEY.to_owned(), ext_name)]; + if !meta_str.is_empty() { + metadata.push((EXTENSION_TYPE_METADATA_KEY.to_owned(), meta_str)); } return Ok(Field::new(name, storage_arrow, dtype.is_nullable()) .with_metadata(metadata.into_iter().collect())); @@ -689,6 +726,14 @@ mod test { schema_null.to_arrow_schema().unwrap(); } + #[test] + fn canonical_aliases_bijection() { + for (vortex_id, arrow_name) in CANONICAL_ALIASES { + assert_eq!(vortex_id_to_arrow_canonical(vortex_id), Some(*arrow_name)); + assert_eq!(arrow_canonical_to_vortex_id(arrow_name), Some(*vortex_id)); + } + } + #[test] fn test_unicode_field_names_roundtrip() { // Regression test for https://github.com/vortex-data/vortex/issues/5979. diff --git a/vortex-tensor/Cargo.toml b/vortex-tensor/Cargo.toml index 8c405ce5d70..71d673d539a 100644 --- a/vortex-tensor/Cargo.toml +++ b/vortex-tensor/Cargo.toml @@ -29,6 +29,8 @@ half = { workspace = true } itertools = { workspace = true } num-traits = { workspace = true } prost = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } [dev-dependencies] arrow-schema = { workspace = true } @@ -37,4 +39,5 @@ mimalloc = { workspace = true } rand = { workspace = true } rand_distr = { workspace = true } rstest = { workspace = true } +serde_json = { workspace = true } vortex-btrblocks = { path = "../vortex-btrblocks" } diff --git a/vortex-tensor/src/tests/arrow_roundtrip.rs b/vortex-tensor/src/tests/arrow_roundtrip.rs index 2b3da425971..2d8a64466f8 100644 --- a/vortex-tensor/src/tests/arrow_roundtrip.rs +++ b/vortex-tensor/src/tests/arrow_roundtrip.rs @@ -23,7 +23,7 @@ use crate::types::fixed_shape::FixedShapeTensorMetadata; use crate::types::vector::Vector; const VECTOR_EXT_NAME: &str = "vortex.tensor.vector"; -const FIXED_SHAPE_EXT_NAME: &str = "vortex.fixed_shape_tensor"; +const FIXED_SHAPE_EXT_NAME: &str = "arrow.fixed_shape_tensor"; fn vector_dtype(len: u32) -> DType { let storage = DType::FixedSizeList( @@ -124,7 +124,14 @@ fn fixed_shape_tensor_metadata_roundtrip() { .map(String::as_str), Some(FIXED_SHAPE_EXT_NAME), ); - assert!(field.metadata().get(EXTENSION_TYPE_METADATA_KEY).is_some()); + + // Canonical extensions put raw JSON on the wire — pyarrow / arrow-rs read it directly + // without base64. Parse it back to confirm the on-wire format. + let meta_str = field.metadata().get(EXTENSION_TYPE_METADATA_KEY).unwrap(); + let parsed: serde_json::Value = serde_json::from_str(meta_str).unwrap(); + assert_eq!(parsed["shape"], serde_json::json!([2, 3, 4])); + assert_eq!(parsed["dim_names"], serde_json::json!(["x", "y", "z"])); + assert_eq!(parsed["permutation"], serde_json::json!([2, 0, 1])); let recovered = DType::from_arrow_with_session(&schema, &SESSION); assert_eq!(recovered, original); diff --git a/vortex-tensor/src/types/fixed_shape/canonical.rs b/vortex-tensor/src/types/fixed_shape/canonical.rs new file mode 100644 index 00000000000..e1a2d6ec300 --- /dev/null +++ b/vortex-tensor/src/types/fixed_shape/canonical.rs @@ -0,0 +1,126 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Arrow canonical [`arrow.fixed_shape_tensor`] metadata serialization. +//! +//! The wire format is a UTF-8 JSON object placed in `ARROW:extension:metadata`, matching the +//! Arrow specification and pyarrow / arrow-rs interop expectations. +//! +//! We roll our own serde rather than delegating to `arrow_schema::extension::FixedShapeTensor` +//! because arrow-rs 58 serializes the field as `"permutations"` (plural) while the Arrow +//! specification and pyarrow use `"permutation"` (singular). pyarrow silently ignores the +//! misspelled key. +//! +//! [`arrow.fixed_shape_tensor`]: https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor + +use serde::Deserialize; +use serde::Serialize; +use vortex_error::VortexResult; +use vortex_error::vortex_err; + +use crate::types::fixed_shape::FixedShapeTensorMetadata; + +#[derive(Serialize)] +struct WireRef<'a> { + shape: &'a [usize], + #[serde(skip_serializing_if = "Option::is_none")] + dim_names: Option<&'a [String]>, + #[serde(skip_serializing_if = "Option::is_none")] + permutation: Option<&'a [usize]>, +} + +#[derive(Deserialize)] +struct Wire { + shape: Vec, + #[serde(default)] + dim_names: Option>, + #[serde(default)] + permutation: Option>, +} + +/// Serialize [`FixedShapeTensorMetadata`] to the Arrow canonical JSON representation. +pub(crate) fn serialize(metadata: &FixedShapeTensorMetadata) -> VortexResult> { + let wire = WireRef { + shape: metadata.logical_shape(), + dim_names: metadata.dim_names(), + permutation: metadata.permutation(), + }; + serde_json::to_vec(&wire) + .map_err(|e| vortex_err!("fixed_shape_tensor canonical serialize: {e}")) +} + +/// Deserialize [`FixedShapeTensorMetadata`] from Arrow canonical JSON bytes. +pub(crate) fn deserialize(bytes: &[u8]) -> VortexResult { + let wire: Wire = serde_json::from_slice(bytes) + .map_err(|e| vortex_err!("fixed_shape_tensor canonical deserialize: {e}"))?; + + let mut m = FixedShapeTensorMetadata::new(wire.shape); + if let Some(names) = wire.dim_names { + m = m.with_dim_names(names)?; + } + if let Some(perm) = wire.permutation { + m = m.with_permutation(perm)?; + } + Ok(m) +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + + use super::*; + + #[rstest] + #[case::scalar_0d(FixedShapeTensorMetadata::new(vec![]))] + #[case::vector_1d(FixedShapeTensorMetadata::new(vec![5]))] + #[case::shape_only(FixedShapeTensorMetadata::new(vec![2, 3, 4]))] + #[case::with_dim_names( + FixedShapeTensorMetadata::new(vec![3, 4]) + .with_dim_names(vec!["rows".into(), "cols".into()]) + .unwrap() + )] + #[case::with_permutation( + FixedShapeTensorMetadata::new(vec![2, 3, 4]) + .with_permutation(vec![2, 0, 1]) + .unwrap() + )] + #[case::all_fields( + FixedShapeTensorMetadata::new(vec![2, 3, 4]) + .with_dim_names(vec!["x".into(), "y".into(), "z".into()]).unwrap() + .with_permutation(vec![1, 2, 0]).unwrap() + )] + fn roundtrip(#[case] metadata: FixedShapeTensorMetadata) -> VortexResult<()> { + let bytes = serialize(&metadata)?; + let decoded = deserialize(&bytes)?; + assert_eq!(decoded, metadata); + Ok(()) + } + + #[test] + fn wire_format_matches_arrow_spec() -> VortexResult<()> { + let metadata = FixedShapeTensorMetadata::new(vec![2, 3, 4]) + .with_dim_names(vec!["x".into(), "y".into(), "z".into()])? + .with_permutation(vec![1, 2, 0])?; + + let bytes = serialize(&metadata)?; + let v: serde_json::Value = + serde_json::from_slice(&bytes).map_err(|e| vortex_err!("parse wire: {e}"))?; + + assert_eq!(v["shape"], serde_json::json!([2, 3, 4])); + assert_eq!(v["dim_names"], serde_json::json!(["x", "y", "z"])); + // Arrow spec uses singular "permutation"; guard against regressions to arrow-rs's plural. + assert_eq!(v["permutation"], serde_json::json!([1, 2, 0])); + assert!(v.get("permutations").is_none()); + Ok(()) + } + + #[test] + fn omits_optional_fields_when_unset() -> VortexResult<()> { + let bytes = serialize(&FixedShapeTensorMetadata::new(vec![5]))?; + let v: serde_json::Value = + serde_json::from_slice(&bytes).map_err(|e| vortex_err!("parse wire: {e}"))?; + assert!(v.get("dim_names").is_none()); + assert!(v.get("permutation").is_none()); + Ok(()) + } +} diff --git a/vortex-tensor/src/types/fixed_shape/mod.rs b/vortex-tensor/src/types/fixed_shape/mod.rs index 48f991517ec..602a6ecd637 100644 --- a/vortex-tensor/src/types/fixed_shape/mod.rs +++ b/vortex-tensor/src/types/fixed_shape/mod.rs @@ -14,5 +14,5 @@ pub use matcher::FixedShapeTensorMatcherMetadata; mod metadata; pub use metadata::FixedShapeTensorMetadata; -mod proto; +mod canonical; mod vtable; diff --git a/vortex-tensor/src/types/fixed_shape/proto.rs b/vortex-tensor/src/types/fixed_shape/proto.rs deleted file mode 100644 index 89b3db4289d..00000000000 --- a/vortex-tensor/src/types/fixed_shape/proto.rs +++ /dev/null @@ -1,90 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Protobuf serialization for [`FixedShapeTensorMetadata`]. - -use prost::Message; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_err; - -use crate::types::fixed_shape::FixedShapeTensorMetadata; - -/// Protobuf representation of [`FixedShapeTensorMetadata`]. -/// -/// Protobuf does not distinguish between an absent repeated field and an empty one (both will -/// deserialize as an empty `Vec`). This is fine because the semantic meaning is unambiguous: -/// -/// - `logical_shape` empty: 0-dimensional (scalar) tensor. -/// - `dim_names` empty: no dimension names (`None`). -/// - `permutation` empty: no permutation, i.e., identity layout (`None`). -#[derive(Clone, PartialEq, Message)] -struct FixedShapeTensorMetadataProto { - /// The size of each logical dimension. Empty for a 0-dimensional scalar tensor. - #[prost(uint32, repeated, tag = "1")] - logical_shape: Vec, - - /// Optional human-readable names for each logical dimension. When present, must have the - /// same length as `logical_shape`. Empty means no names are set. - #[prost(string, repeated, tag = "2")] - dim_names: Vec, - - /// Optional dimension permutation mapping logical to physical indices. When present, must - /// be a permutation of `[0, 1, ..., N-1]`. Empty means identity (row-major) layout. - #[prost(uint32, repeated, tag = "3")] - permutation: Vec, -} - -/// Serializes [`FixedShapeTensorMetadata`] to protobuf bytes. -pub(crate) fn serialize(metadata: &FixedShapeTensorMetadata) -> Vec { - let logical_shape = metadata - .logical_shape() - .iter() - .map(|&d| u32::try_from(d).vortex_expect("dimension size exceeds u32")) - .collect(); - - let dim_names = metadata.dim_names().map(|n| n.to_vec()).unwrap_or_default(); - - let permutation = metadata - .permutation() - .map(|p| { - p.iter() - .map(|&i| u32::try_from(i).vortex_expect("permutation index exceeds u32")) - .collect() - }) - .unwrap_or_default(); - - let proto = FixedShapeTensorMetadataProto { - logical_shape, - dim_names, - permutation, - }; - proto.encode_to_vec() -} - -/// Deserializes [`FixedShapeTensorMetadata`] from protobuf bytes. -/// -/// For 0-dimensional tensors, all three repeated fields are empty, which correctly produces a -/// metadata with an empty shape and no names or permutation. -pub(crate) fn deserialize(bytes: &[u8]) -> VortexResult { - let proto = FixedShapeTensorMetadataProto::decode(bytes).map_err(|e| vortex_err!("{e}"))?; - - let logical_shape = proto - .logical_shape - .into_iter() - .map(|d| d as usize) - .collect(); - let mut m = FixedShapeTensorMetadata::new(logical_shape); - - // Note that this is fine for 0 dimensions since if we do not have any dimensions, we cannot - // have any names or permutations. - if !proto.dim_names.is_empty() { - m = m.with_dim_names(proto.dim_names)?; - } - if !proto.permutation.is_empty() { - let permutation = proto.permutation.into_iter().map(|i| i as usize).collect(); - m = m.with_permutation(permutation)?; - } - - Ok(m) -} diff --git a/vortex-tensor/src/types/fixed_shape/vtable.rs b/vortex-tensor/src/types/fixed_shape/vtable.rs index 7d69b9a7c44..89e1aa3c719 100644 --- a/vortex-tensor/src/types/fixed_shape/vtable.rs +++ b/vortex-tensor/src/types/fixed_shape/vtable.rs @@ -13,7 +13,7 @@ use vortex_error::vortex_ensure_eq; use crate::types::fixed_shape::FixedShapeTensor; use crate::types::fixed_shape::FixedShapeTensorMetadata; -use crate::types::fixed_shape::proto; +use crate::types::fixed_shape::canonical; impl ExtVTable for FixedShapeTensor { type Metadata = FixedShapeTensorMetadata; @@ -26,11 +26,11 @@ impl ExtVTable for FixedShapeTensor { } fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { - Ok(proto::serialize(metadata)) + canonical::serialize(metadata) } fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult { - proto::deserialize(metadata) + canonical::deserialize(metadata) } fn validate_dtype(ext_dtype: &ExtDType) -> VortexResult<()> { From ad80b531030092f8efcf1b694ed2dd61a45822db Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Thu, 23 Apr 2026 21:51:48 +0100 Subject: [PATCH 03/21] externalize canonical aliases Signed-off-by: Baris Palaska --- vortex-array/public-api.lock | 58 ++--- vortex-array/src/dtype/arrow.rs | 239 ++++++++---------- vortex-array/src/dtype/session.rs | 39 ++- vortex-tensor/src/lib.rs | 8 +- vortex-tensor/src/tests/arrow_roundtrip.rs | 4 +- vortex-tensor/src/types/fixed_shape/mod.rs | 3 + vortex-tensor/src/types/fixed_shape/vtable.rs | 3 +- 7 files changed, 186 insertions(+), 168 deletions(-) diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 554e6509131..f2573ce40b6 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -8472,45 +8472,37 @@ pub trait vortex_array::dtype::arrow::FromArrowType: core::marker::Sized pub fn vortex_array::dtype::arrow::FromArrowType::from_arrow(value: T) -> Self +pub fn vortex_array::dtype::arrow::FromArrowType::from_arrow_with_session(value: T, _session: &vortex_session::VortexSession) -> Self + impl vortex_array::dtype::arrow::FromArrowType<&arrow_schema::field::Field> for vortex_array::dtype::DType pub fn vortex_array::dtype::DType::from_arrow(field: &arrow_schema::field::Field) -> Self +pub fn vortex_array::dtype::DType::from_arrow_with_session(field: &arrow_schema::field::Field, session: &vortex_session::VortexSession) -> Self + impl vortex_array::dtype::arrow::FromArrowType<&arrow_schema::fields::Fields> for vortex_array::dtype::StructFields pub fn vortex_array::dtype::StructFields::from_arrow(value: &arrow_schema::fields::Fields) -> Self +pub fn vortex_array::dtype::StructFields::from_arrow_with_session(value: &arrow_schema::fields::Fields, session: &vortex_session::VortexSession) -> Self + impl vortex_array::dtype::arrow::FromArrowType<&arrow_schema::schema::Schema> for vortex_array::dtype::DType pub fn vortex_array::dtype::DType::from_arrow(value: &arrow_schema::schema::Schema) -> Self +pub fn vortex_array::dtype::DType::from_arrow_with_session(value: &arrow_schema::schema::Schema, session: &vortex_session::VortexSession) -> Self + impl vortex_array::dtype::arrow::FromArrowType<(&arrow_schema::datatype::DataType, vortex_array::dtype::Nullability)> for vortex_array::dtype::DType pub fn vortex_array::dtype::DType::from_arrow((data_type, nullability): (&arrow_schema::datatype::DataType, vortex_array::dtype::Nullability)) -> Self +pub fn vortex_array::dtype::DType::from_arrow_with_session(value: T, _session: &vortex_session::VortexSession) -> Self + impl vortex_array::dtype::arrow::FromArrowType> for vortex_array::dtype::DType pub fn vortex_array::dtype::DType::from_arrow(value: arrow_schema::schema::SchemaRef) -> Self -pub trait vortex_array::dtype::arrow::FromArrowWithSession: core::marker::Sized - -pub fn vortex_array::dtype::arrow::FromArrowWithSession::from_arrow_with_session(value: T, session: &vortex_session::VortexSession) -> Self - -impl vortex_array::dtype::arrow::FromArrowWithSession<&alloc::sync::Arc> for vortex_array::dtype::DType - -pub fn vortex_array::dtype::DType::from_arrow_with_session(value: &arrow_schema::schema::SchemaRef, session: &vortex_session::VortexSession) -> Self - -impl vortex_array::dtype::arrow::FromArrowWithSession<&arrow_schema::field::Field> for vortex_array::dtype::DType - -pub fn vortex_array::dtype::DType::from_arrow_with_session(field: &arrow_schema::field::Field, session: &vortex_session::VortexSession) -> Self - -impl vortex_array::dtype::arrow::FromArrowWithSession<&arrow_schema::fields::Fields> for vortex_array::dtype::StructFields - -pub fn vortex_array::dtype::StructFields::from_arrow_with_session(value: &arrow_schema::fields::Fields, session: &vortex_session::VortexSession) -> Self - -impl vortex_array::dtype::arrow::FromArrowWithSession<&arrow_schema::schema::Schema> for vortex_array::dtype::DType - -pub fn vortex_array::dtype::DType::from_arrow_with_session(value: &arrow_schema::schema::Schema, session: &vortex_session::VortexSession) -> Self +pub fn vortex_array::dtype::DType::from_arrow_with_session(value: arrow_schema::schema::SchemaRef, session: &vortex_session::VortexSession) -> Self pub trait vortex_array::dtype::arrow::TryFromArrowType: core::marker::Sized @@ -8814,10 +8806,16 @@ pub struct vortex_array::dtype::session::DTypeSession impl vortex_array::dtype::session::DTypeSession +pub fn vortex_array::dtype::session::DTypeSession::arrow_canonical_for(&self, vortex_id: &vortex_array::dtype::extension::ExtId) -> core::option::Option<&'static str> + pub fn vortex_array::dtype::session::DTypeSession::register(&self, vtable: V) +pub fn vortex_array::dtype::session::DTypeSession::register_arrow_canonical(&self, vortex_id: vortex_array::dtype::extension::ExtId, arrow_name: &'static str) + pub fn vortex_array::dtype::session::DTypeSession::registry(&self) -> &vortex_array::dtype::session::ExtDTypeRegistry +pub fn vortex_array::dtype::session::DTypeSession::vortex_id_for_arrow_canonical(&self, arrow_name: &str) -> core::option::Option + impl core::default::Default for vortex_array::dtype::session::DTypeSession pub fn vortex_array::dtype::session::DTypeSession::default() -> Self @@ -8984,6 +8982,8 @@ pub fn vortex_array::dtype::DType::to_arrow_dtype(&self) -> vortex_error::Vortex pub fn vortex_array::dtype::DType::to_arrow_schema(&self) -> vortex_error::VortexResult +pub fn vortex_array::dtype::DType::to_arrow_schema_with_session(&self, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl core::clone::Clone for vortex_array::dtype::DType pub fn vortex_array::dtype::DType::clone(&self) -> vortex_array::dtype::DType @@ -9046,29 +9046,25 @@ impl vortex_array::dtype::arrow::FromArrowType<&arrow_schema::field::Field> for pub fn vortex_array::dtype::DType::from_arrow(field: &arrow_schema::field::Field) -> Self +pub fn vortex_array::dtype::DType::from_arrow_with_session(field: &arrow_schema::field::Field, session: &vortex_session::VortexSession) -> Self + impl vortex_array::dtype::arrow::FromArrowType<&arrow_schema::schema::Schema> for vortex_array::dtype::DType pub fn vortex_array::dtype::DType::from_arrow(value: &arrow_schema::schema::Schema) -> Self +pub fn vortex_array::dtype::DType::from_arrow_with_session(value: &arrow_schema::schema::Schema, session: &vortex_session::VortexSession) -> Self + impl vortex_array::dtype::arrow::FromArrowType<(&arrow_schema::datatype::DataType, vortex_array::dtype::Nullability)> for vortex_array::dtype::DType pub fn vortex_array::dtype::DType::from_arrow((data_type, nullability): (&arrow_schema::datatype::DataType, vortex_array::dtype::Nullability)) -> Self +pub fn vortex_array::dtype::DType::from_arrow_with_session(value: T, _session: &vortex_session::VortexSession) -> Self + impl vortex_array::dtype::arrow::FromArrowType> for vortex_array::dtype::DType pub fn vortex_array::dtype::DType::from_arrow(value: arrow_schema::schema::SchemaRef) -> Self -impl vortex_array::dtype::arrow::FromArrowWithSession<&alloc::sync::Arc> for vortex_array::dtype::DType - -pub fn vortex_array::dtype::DType::from_arrow_with_session(value: &arrow_schema::schema::SchemaRef, session: &vortex_session::VortexSession) -> Self - -impl vortex_array::dtype::arrow::FromArrowWithSession<&arrow_schema::field::Field> for vortex_array::dtype::DType - -pub fn vortex_array::dtype::DType::from_arrow_with_session(field: &arrow_schema::field::Field, session: &vortex_session::VortexSession) -> Self - -impl vortex_array::dtype::arrow::FromArrowWithSession<&arrow_schema::schema::Schema> for vortex_array::dtype::DType - -pub fn vortex_array::dtype::DType::from_arrow_with_session(value: &arrow_schema::schema::Schema, session: &vortex_session::VortexSession) -> Self +pub fn vortex_array::dtype::DType::from_arrow_with_session(value: arrow_schema::schema::SchemaRef, session: &vortex_session::VortexSession) -> Self impl vortex_flatbuffers::FlatBufferRoot for vortex_array::dtype::DType @@ -10024,8 +10020,6 @@ impl vortex_array::dtype::arrow::FromArrowType<&arrow_schema::fields::Fields> fo pub fn vortex_array::dtype::StructFields::from_arrow(value: &arrow_schema::fields::Fields) -> Self -impl vortex_array::dtype::arrow::FromArrowWithSession<&arrow_schema::fields::Fields> for vortex_array::dtype::StructFields - pub fn vortex_array::dtype::StructFields::from_arrow_with_session(value: &arrow_schema::fields::Fields, session: &vortex_session::VortexSession) -> Self impl core::iter::traits::collect::FromIterator<(T, V)> for vortex_array::dtype::StructFields where T: core::convert::Into, V: core::convert::Into diff --git a/vortex-array/src/dtype/arrow.rs b/vortex-array/src/dtype/arrow.rs index d78004f19c9..cf15e3aebcc 100644 --- a/vortex-array/src/dtype/arrow.rs +++ b/vortex-array/src/dtype/arrow.rs @@ -55,30 +55,17 @@ use crate::extension::datetime::Timestamp; const ARROW_EXT_NAME_VARIANT: &str = "arrow.parquet.variant"; -/// `(vortex_id, arrow_canonical_name)` pairs — single source of truth for bijection between -/// Vortex-internal extension ids and Arrow canonical extension names. Canonical extensions -/// serialize metadata as raw UTF-8 (typically JSON) rather than base64-wrapped bytes. -const CANONICAL_ALIASES: &[(&str, &str)] = - &[("vortex.fixed_shape_tensor", "arrow.fixed_shape_tensor")]; - -fn vortex_id_to_arrow_canonical(vortex_id: &str) -> Option<&'static str> { - CANONICAL_ALIASES - .iter() - .find(|(v, _)| *v == vortex_id) - .map(|(_, a)| *a) -} - -fn arrow_canonical_to_vortex_id(arrow_name: &str) -> Option<&'static str> { - CANONICAL_ALIASES - .iter() - .find(|(_, a)| *a == arrow_name) - .map(|(v, _)| *v) -} - /// Trait for converting Arrow types to Vortex types. pub trait FromArrowType: Sized { /// Convert the Arrow type to a Vortex type. fn from_arrow(value: T) -> Self; + + /// Convert the Arrow type to a Vortex type, consulting `session` for extension lookup. + /// + /// Unregistered or malformed extension metadata falls back to the storage dtype. + fn from_arrow_with_session(value: T, _session: &VortexSession) -> Self { + Self::from_arrow(value) + } } /// Trait for converting Vortex types to Arrow types. @@ -87,15 +74,6 @@ pub trait TryFromArrowType: Sized { fn try_from_arrow(value: T) -> VortexResult; } -/// Conversion from Arrow types to Vortex types using a session's extension dtype registry to -/// resolve `ARROW:extension:name` metadata into [`DType::Extension`] values. -/// -/// Unregistered or malformed extension metadata falls back to the storage dtype. -pub trait FromArrowWithSession: Sized { - /// Convert the Arrow type to a Vortex type. - fn from_arrow_with_session(value: T, session: &VortexSession) -> Self; -} - impl TryFromArrowType<&DataType> for PType { fn try_from_arrow(value: &DataType) -> VortexResult { match value { @@ -167,29 +145,19 @@ impl TryFrom for ArrowTimeUnit { impl FromArrowType for DType { fn from_arrow(value: SchemaRef) -> Self { - Self::from_arrow_with_session(value.as_ref(), &LEGACY_SESSION) - } -} - -impl FromArrowType<&Schema> for DType { - fn from_arrow(value: &Schema) -> Self { Self::from_arrow_with_session(value, &LEGACY_SESSION) } -} -impl FromArrowType<&Fields> for StructFields { - fn from_arrow(value: &Fields) -> Self { - Self::from_arrow_with_session(value, &LEGACY_SESSION) + fn from_arrow_with_session(value: SchemaRef, session: &VortexSession) -> Self { + Self::from_arrow_with_session(value.as_ref(), session) } } -impl FromArrowWithSession<&SchemaRef> for DType { - fn from_arrow_with_session(value: &SchemaRef, session: &VortexSession) -> Self { - Self::from_arrow_with_session(value.as_ref(), session) +impl FromArrowType<&Schema> for DType { + fn from_arrow(value: &Schema) -> Self { + Self::from_arrow_with_session(value, &LEGACY_SESSION) } -} -impl FromArrowWithSession<&Schema> for DType { fn from_arrow_with_session(value: &Schema, session: &VortexSession) -> Self { Self::Struct( StructFields::from_arrow_with_session(value.fields(), session), @@ -198,7 +166,11 @@ impl FromArrowWithSession<&Schema> for DType { } } -impl FromArrowWithSession<&Fields> for StructFields { +impl FromArrowType<&Fields> for StructFields { + fn from_arrow(value: &Fields) -> Self { + Self::from_arrow_with_session(value, &LEGACY_SESSION) + } + fn from_arrow_with_session(value: &Fields, session: &VortexSession) -> Self { let dtypes = session.dtypes(); StructFields::from_iter(value.into_iter().map(|f| { @@ -272,9 +244,7 @@ impl FromArrowType<&Field> for DType { fn from_arrow(field: &Field) -> Self { Self::from_arrow_with_session(field, &LEGACY_SESSION) } -} -impl FromArrowWithSession<&Field> for DType { fn from_arrow_with_session(field: &Field, session: &VortexSession) -> Self { dtype_from_field(field, &session.dtypes()) } @@ -297,9 +267,9 @@ fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { return storage_dtype; }; - let canonical_alias = arrow_canonical_to_vortex_id(ext_name); + let canonical_alias = dtypes.vortex_id_for_arrow_canonical(ext_name); let is_canonical = canonical_alias.is_some(); - let ext_id = ExtId::new(canonical_alias.unwrap_or(ext_name)); + let ext_id = canonical_alias.unwrap_or_else(|| ExtId::new(ext_name)); let Some(plugin) = dtypes.registry().find(&ext_id) else { tracing::warn!( @@ -386,6 +356,12 @@ fn storage_dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { impl DType { /// Convert a Vortex [`DType`] into an Arrow [`Schema`]. pub fn to_arrow_schema(&self) -> VortexResult { + self.to_arrow_schema_with_session(&LEGACY_SESSION) + } + + /// Convert a Vortex [`DType`] into an Arrow [`Schema`], consulting `session` for Arrow + /// canonical extension aliases registered via [`DTypeSession::register_arrow_canonical`]. + pub fn to_arrow_schema_with_session(&self, session: &VortexSession) -> VortexResult { let DType::Struct(struct_dtype, nullable) = self else { vortex_bail!("only DType::Struct can be converted to arrow schema"); }; @@ -394,9 +370,14 @@ impl DType { vortex_bail!("top-level struct in Schema must be NonNullable"); } + let dtypes = session.dtypes(); let mut builder = SchemaBuilder::with_capacity(struct_dtype.names().len()); for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) { - builder.push(field_from_dtype(field_name.as_ref(), &field_dtype)?); + builder.push(field_from_dtype( + field_name.as_ref(), + &field_dtype, + &dtypes, + )?); } Ok(builder.finish()) @@ -404,84 +385,90 @@ impl DType { /// Returns the Arrow [`DataType`] that best corresponds to this Vortex [`DType`]. pub fn to_arrow_dtype(&self) -> VortexResult { - Ok(match self { - DType::Null => DataType::Null, - DType::Bool(_) => DataType::Boolean, - DType::Primitive(ptype, _) => match ptype { - PType::U8 => DataType::UInt8, - PType::U16 => DataType::UInt16, - PType::U32 => DataType::UInt32, - PType::U64 => DataType::UInt64, - PType::I8 => DataType::Int8, - PType::I16 => DataType::Int16, - PType::I32 => DataType::Int32, - PType::I64 => DataType::Int64, - PType::F16 => DataType::Float16, - PType::F32 => DataType::Float32, - PType::F64 => DataType::Float64, - }, - DType::Decimal(dt, _) => { - let precision = dt.precision(); - let scale = dt.scale(); - - match precision { - // This code is commented out until DataFusion improves its support for smaller decimals. - // // DECIMAL32_MAX_PRECISION - // 0..=9 => DataType::Decimal32(precision, scale), - // // DECIMAL64_MAX_PRECISION - // 10..=18 => DataType::Decimal64(precision, scale), - // DECIMAL128_MAX_PRECISION - 0..=38 => DataType::Decimal128(precision, scale), - // DECIMAL256_MAX_PRECISION - 39.. => DataType::Decimal256(precision, scale), - } + to_arrow_dtype_with_dtypes(self, &LEGACY_SESSION.dtypes()) + } +} + +fn to_arrow_dtype_with_dtypes(dtype: &DType, dtypes: &DTypeSession) -> VortexResult { + Ok(match dtype { + DType::Null => DataType::Null, + DType::Bool(_) => DataType::Boolean, + DType::Primitive(ptype, _) => match ptype { + PType::U8 => DataType::UInt8, + PType::U16 => DataType::UInt16, + PType::U32 => DataType::UInt32, + PType::U64 => DataType::UInt64, + PType::I8 => DataType::Int8, + PType::I16 => DataType::Int16, + PType::I32 => DataType::Int32, + PType::I64 => DataType::Int64, + PType::F16 => DataType::Float16, + PType::F32 => DataType::Float32, + PType::F64 => DataType::Float64, + }, + DType::Decimal(dt, _) => { + let precision = dt.precision(); + let scale = dt.scale(); + + match precision { + // This code is commented out until DataFusion improves its support for smaller decimals. + // // DECIMAL32_MAX_PRECISION + // 0..=9 => DataType::Decimal32(precision, scale), + // // DECIMAL64_MAX_PRECISION + // 10..=18 => DataType::Decimal64(precision, scale), + // DECIMAL128_MAX_PRECISION + 0..=38 => DataType::Decimal128(precision, scale), + // DECIMAL256_MAX_PRECISION + 39.. => DataType::Decimal256(precision, scale), } - DType::Utf8(_) => DataType::Utf8View, - DType::Binary(_) => DataType::BinaryView, - // There are four kinds of lists: List (32-bit offsets), Large List (64-bit), List View - // (32-bit), Large List View (64-bit). We cannot both guarantee zero-copy and commit to an - // Arrow dtype because we do not how large our offsets are. - DType::List(elem_dtype, _) => DataType::List(FieldRef::new(field_from_dtype( + } + DType::Utf8(_) => DataType::Utf8View, + DType::Binary(_) => DataType::BinaryView, + // There are four kinds of lists: List (32-bit offsets), Large List (64-bit), List View + // (32-bit), Large List View (64-bit). We cannot both guarantee zero-copy and commit to an + // Arrow dtype because we do not how large our offsets are. + DType::List(elem_dtype, _) => DataType::List(FieldRef::new(field_from_dtype( + Field::LIST_FIELD_DEFAULT_NAME, + elem_dtype, + dtypes, + )?)), + DType::FixedSizeList(elem_dtype, size, _) => DataType::FixedSizeList( + FieldRef::new(field_from_dtype( Field::LIST_FIELD_DEFAULT_NAME, elem_dtype, - )?)), - DType::FixedSizeList(elem_dtype, size, _) => DataType::FixedSizeList( - FieldRef::new(field_from_dtype( - Field::LIST_FIELD_DEFAULT_NAME, - elem_dtype, - )?), - *size as i32, - ), - DType::Struct(struct_dtype, _) => { - let mut fields = Vec::with_capacity(struct_dtype.names().len()); - for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields()) - { - fields.push(FieldRef::from(field_from_dtype( - field_name.as_ref(), - &field_dt, - )?)); - } - - DataType::Struct(Fields::from(fields)) + dtypes, + )?), + *size as i32, + ), + DType::Struct(struct_dtype, _) => { + let mut fields = Vec::with_capacity(struct_dtype.names().len()); + for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields()) { + fields.push(FieldRef::from(field_from_dtype( + field_name.as_ref(), + &field_dt, + dtypes, + )?)); } - DType::Variant(_) => vortex_bail!( - "DType::Variant requires Arrow Field metadata; use to_arrow_schema or a Field helper" - ), - DType::Extension(ext_dtype) => { - if let Some(native) = native_arrow_dtype_for_extension(ext_dtype) { - return Ok(native); - } - // Extension identity lives on the Field (see `field_from_dtype`), not on - // DataType, so here we only encode the storage type. - ext_dtype.storage_dtype().to_arrow_dtype()? + + DataType::Struct(Fields::from(fields)) + } + DType::Variant(_) => vortex_bail!( + "DType::Variant requires Arrow Field metadata; use to_arrow_schema or a Field helper" + ), + DType::Extension(ext_dtype) => { + if let Some(native) = native_arrow_dtype_for_extension(ext_dtype) { + return Ok(native); } - }) - } + // Extension identity lives on the Field (see `field_from_dtype`), not on + // DataType, so here we only encode the storage type. + to_arrow_dtype_with_dtypes(ext_dtype.storage_dtype(), dtypes)? + } + }) } /// Build an Arrow [`Field`], attaching `ARROW:extension:name` and, when present, /// `ARROW:extension:metadata` for extensions and Variant that have no native Arrow mapping. -fn field_from_dtype(name: &str, dtype: &DType) -> VortexResult { +fn field_from_dtype(name: &str, dtype: &DType, dtypes: &DTypeSession) -> VortexResult { if dtype.is_variant() { let storage = DataType::Struct(variant_storage_fields_minimal()); return Ok( @@ -502,9 +489,9 @@ fn field_from_dtype(name: &str, dtype: &DType) -> VortexResult { return Ok(Field::new(name, native, dtype.is_nullable())); } - let storage_arrow = ext.storage_dtype().to_arrow_dtype()?; + let storage_arrow = to_arrow_dtype_with_dtypes(ext.storage_dtype(), dtypes)?; let ext_meta_bytes = ext.serialize_metadata()?; - let (ext_name, meta_str) = match vortex_id_to_arrow_canonical(ext.id().as_str()) { + let (ext_name, meta_str) = match dtypes.arrow_canonical_for(&ext.id()) { Some(canonical) => { // Canonical Arrow extensions specify a UTF-8 metadata format (typically JSON), // read as-is by arrow-rs / pyarrow. The plugin owns producing those bytes. @@ -529,7 +516,7 @@ fn field_from_dtype(name: &str, dtype: &DType) -> VortexResult { Ok(Field::new( name, - dtype.to_arrow_dtype()?, + to_arrow_dtype_with_dtypes(dtype, dtypes)?, dtype.is_nullable(), )) } @@ -726,14 +713,6 @@ mod test { schema_null.to_arrow_schema().unwrap(); } - #[test] - fn canonical_aliases_bijection() { - for (vortex_id, arrow_name) in CANONICAL_ALIASES { - assert_eq!(vortex_id_to_arrow_canonical(vortex_id), Some(*arrow_name)); - assert_eq!(arrow_canonical_to_vortex_id(arrow_name), Some(*vortex_id)); - } - } - #[test] fn test_unicode_field_names_roundtrip() { // Regression test for https://github.com/vortex-data/vortex/issues/5979. diff --git a/vortex-array/src/dtype/session.rs b/vortex-array/src/dtype/session.rs index 669aab68165..0b46414284c 100644 --- a/vortex-array/src/dtype/session.rs +++ b/vortex-array/src/dtype/session.rs @@ -5,11 +5,13 @@ use std::sync::Arc; +use parking_lot::RwLock; use vortex_session::Ref; use vortex_session::SessionExt; use vortex_session::registry::Registry; use crate::dtype::extension::ExtDTypePluginRef; +use crate::dtype::extension::ExtId; use crate::dtype::extension::ExtVTable; use crate::extension::datetime::Date; use crate::extension::datetime::Time; @@ -22,15 +24,21 @@ pub type ExtDTypeRegistry = Registry; #[derive(Debug)] pub struct DTypeSession { registry: ExtDTypeRegistry, + arrow_canonical: RwLock, +} + +#[derive(Debug, Default)] +struct ArrowCanonicalAliases { + entries: Vec<(ExtId, &'static str)>, } impl Default for DTypeSession { fn default() -> Self { let this = Self { registry: Registry::default(), + arrow_canonical: RwLock::default(), }; - // Register built-in temporal extension dtypes this.register(Date); this.register(Time); this.register(Timestamp); @@ -50,6 +58,35 @@ impl DTypeSession { pub fn registry(&self) -> &ExtDTypeRegistry { &self.registry } + + /// Register an Arrow canonical extension name as an alias for a Vortex extension id. + /// Aliased extensions emit the canonical name on `ARROW:extension:name` and serialize + /// metadata as raw UTF-8 instead of base64-wrapped bytes. + pub fn register_arrow_canonical(&self, vortex_id: ExtId, arrow_name: &'static str) { + let mut aliases = self.arrow_canonical.write(); + aliases.entries.retain(|(v, _)| *v != vortex_id); + aliases.entries.push((vortex_id, arrow_name)); + } + + /// Returns the Arrow canonical extension name aliased to the given Vortex id, if any. + pub fn arrow_canonical_for(&self, vortex_id: &ExtId) -> Option<&'static str> { + self.arrow_canonical + .read() + .entries + .iter() + .find(|(v, _)| v == vortex_id) + .map(|(_, a)| *a) + } + + /// Returns the Vortex extension id aliased to the given Arrow canonical name, if any. + pub fn vortex_id_for_arrow_canonical(&self, arrow_name: &str) -> Option { + self.arrow_canonical + .read() + .entries + .iter() + .find(|(_, a)| *a == arrow_name) + .map(|(v, _)| *v) + } } /// Extension trait for accessing the DType session. diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index a47ab88b061..5196ae9e331 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -6,6 +6,7 @@ //! similarity. use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; +use vortex_array::dtype::extension::ExtId; use vortex_array::dtype::session::DTypeSessionExt; use vortex_array::scalar_fn::session::ScalarFnSessionExt; use vortex_array::session::ArraySessionExt; @@ -42,8 +43,11 @@ pub const SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV: &str = "VX_SCALAR_FN_ARRAY_TENSOR_P /// Initialize the Vortex tensor library with a Vortex session. pub fn initialize(session: &VortexSession) { - session.dtypes().register(Vector); - session.dtypes().register(FixedShapeTensor); + let dtypes = session.dtypes(); + dtypes.register(Vector); + dtypes.register(FixedShapeTensor); + dtypes.register_arrow_canonical(ExtId::new(fixed_shape::ID), "arrow.fixed_shape_tensor"); + drop(dtypes); let session_fns = session.scalar_fns(); diff --git a/vortex-tensor/src/tests/arrow_roundtrip.rs b/vortex-tensor/src/tests/arrow_roundtrip.rs index 2d8a64466f8..361767528b0 100644 --- a/vortex-tensor/src/tests/arrow_roundtrip.rs +++ b/vortex-tensor/src/tests/arrow_roundtrip.rs @@ -12,7 +12,7 @@ use arrow_schema::extension::EXTENSION_TYPE_NAME_KEY; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; -use vortex_array::dtype::arrow::FromArrowWithSession; +use vortex_array::dtype::arrow::FromArrowType; use vortex_array::dtype::extension::ExtDType; use vortex_array::extension::datetime::TimeUnit; use vortex_array::extension::datetime::Timestamp; @@ -114,7 +114,7 @@ fn fixed_shape_tensor_metadata_roundtrip() { Nullability::NonNullable, ); - let schema = original.to_arrow_schema().unwrap(); + let schema = original.to_arrow_schema_with_session(&SESSION).unwrap(); let field = schema.field(0); assert_eq!( diff --git a/vortex-tensor/src/types/fixed_shape/mod.rs b/vortex-tensor/src/types/fixed_shape/mod.rs index 602a6ecd637..1571c2588c3 100644 --- a/vortex-tensor/src/types/fixed_shape/mod.rs +++ b/vortex-tensor/src/types/fixed_shape/mod.rs @@ -3,6 +3,9 @@ //! Fixed-shape Tensor extension type. +/// Vortex extension id for [`FixedShapeTensor`]. +pub(crate) const ID: &str = "vortex.tensor.fixed_shape_tensor"; + /// The VTable for the Tensor extension type. #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct FixedShapeTensor; diff --git a/vortex-tensor/src/types/fixed_shape/vtable.rs b/vortex-tensor/src/types/fixed_shape/vtable.rs index 89e1aa3c719..362c3c65945 100644 --- a/vortex-tensor/src/types/fixed_shape/vtable.rs +++ b/vortex-tensor/src/types/fixed_shape/vtable.rs @@ -13,6 +13,7 @@ use vortex_error::vortex_ensure_eq; use crate::types::fixed_shape::FixedShapeTensor; use crate::types::fixed_shape::FixedShapeTensorMetadata; +use crate::types::fixed_shape::ID; use crate::types::fixed_shape::canonical; impl ExtVTable for FixedShapeTensor { @@ -22,7 +23,7 @@ impl ExtVTable for FixedShapeTensor { type NativeValue<'a> = &'a ScalarValue; fn id(&self) -> ExtId { - ExtId::new("vortex.fixed_shape_tensor") + ExtId::new(ID) } fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { From c4d1079bae16cc795490b10275a003dce55b31d9 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Fri, 24 Apr 2026 12:32:51 +0100 Subject: [PATCH 04/21] reuse ids Signed-off-by: Baris Palaska --- vortex-array/src/dtype/arrow.rs | 15 +++++---------- vortex-tensor/src/lib.rs | 2 +- vortex-tensor/src/tests/arrow_roundtrip.rs | 14 ++++++-------- vortex-tensor/src/types/fixed_shape/canonical.rs | 11 +++-------- vortex-tensor/src/types/fixed_shape/mod.rs | 3 +++ vortex-tensor/src/types/vector/mod.rs | 3 +++ vortex-tensor/src/types/vector/vtable.rs | 3 ++- 7 files changed, 23 insertions(+), 28 deletions(-) diff --git a/vortex-array/src/dtype/arrow.rs b/vortex-array/src/dtype/arrow.rs index cf15e3aebcc..85b3f462f89 100644 --- a/vortex-array/src/dtype/arrow.rs +++ b/vortex-array/src/dtype/arrow.rs @@ -250,9 +250,8 @@ impl FromArrowType<&Field> for DType { } } -/// Convert an Arrow Field to a [`DType`] using a pre-borrowed [`DTypeSession`] for extension -/// lookup. Used by the `&Fields` and `&Field` impls so the session handle is acquired once per -/// schema rather than once per field. +/// Convert an Arrow Field to a [`DType`] with `dtypes` already borrowed from the session, +/// so the handle is acquired once per schema rather than once per field. fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { let ext_name = field.extension_type_name(); @@ -309,11 +308,8 @@ fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { } } -/// Decode extension metadata bytes from a Field. -/// -/// Canonical Arrow extensions store UTF-8 bytes directly (e.g. JSON). Non-canonical extensions -/// store base64-encoded bytes so that arbitrary binary plugin output survives a String-typed -/// metadata channel. +/// Canonical extensions store UTF-8 bytes directly; non-canonical extensions base64-encode so +/// arbitrary binary plugin output survives the String-typed metadata channel. fn decode_extension_metadata(field: &Field, is_canonical: bool) -> VortexResult> { match field.extension_type_metadata() { None | Some("") => Ok(Vec::new()), @@ -493,8 +489,7 @@ fn field_from_dtype(name: &str, dtype: &DType, dtypes: &DTypeSession) -> VortexR let ext_meta_bytes = ext.serialize_metadata()?; let (ext_name, meta_str) = match dtypes.arrow_canonical_for(&ext.id()) { Some(canonical) => { - // Canonical Arrow extensions specify a UTF-8 metadata format (typically JSON), - // read as-is by arrow-rs / pyarrow. The plugin owns producing those bytes. + // Canonical wire: raw UTF-8 (typically JSON), read as-is by arrow-rs / pyarrow. let s = String::from_utf8(ext_meta_bytes).map_err(|e| { vortex_err!("canonical extension {canonical} metadata must be valid UTF-8: {e}") })?; diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 5196ae9e331..91e999bea79 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -46,7 +46,7 @@ pub fn initialize(session: &VortexSession) { let dtypes = session.dtypes(); dtypes.register(Vector); dtypes.register(FixedShapeTensor); - dtypes.register_arrow_canonical(ExtId::new(fixed_shape::ID), "arrow.fixed_shape_tensor"); + dtypes.register_arrow_canonical(ExtId::new(fixed_shape::ID), fixed_shape::ARROW_EXT_NAME); drop(dtypes); let session_fns = session.scalar_fns(); diff --git a/vortex-tensor/src/tests/arrow_roundtrip.rs b/vortex-tensor/src/tests/arrow_roundtrip.rs index 361767528b0..903c8b9f2cc 100644 --- a/vortex-tensor/src/tests/arrow_roundtrip.rs +++ b/vortex-tensor/src/tests/arrow_roundtrip.rs @@ -18,13 +18,12 @@ use vortex_array::extension::datetime::TimeUnit; use vortex_array::extension::datetime::Timestamp; use crate::tests::SESSION; +use crate::types::fixed_shape; use crate::types::fixed_shape::FixedShapeTensor; use crate::types::fixed_shape::FixedShapeTensorMetadata; +use crate::types::vector; use crate::types::vector::Vector; -const VECTOR_EXT_NAME: &str = "vortex.tensor.vector"; -const FIXED_SHAPE_EXT_NAME: &str = "arrow.fixed_shape_tensor"; - fn vector_dtype(len: u32) -> DType { let storage = DType::FixedSizeList( Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)), @@ -57,9 +56,9 @@ fn vector_forward_carries_extension_name() { .metadata() .get(EXTENSION_TYPE_NAME_KEY) .map(String::as_str), - Some(VECTOR_EXT_NAME), + Some(vector::ID), ); - // EmptyMetadata: no metadata key emitted. + // EmptyMetadata → no metadata key. assert!(field.metadata().get(EXTENSION_TYPE_METADATA_KEY).is_none()); let DataType::FixedSizeList(element, size) = field.data_type() else { @@ -122,11 +121,10 @@ fn fixed_shape_tensor_metadata_roundtrip() { .metadata() .get(EXTENSION_TYPE_NAME_KEY) .map(String::as_str), - Some(FIXED_SHAPE_EXT_NAME), + Some(fixed_shape::ARROW_EXT_NAME), ); - // Canonical extensions put raw JSON on the wire — pyarrow / arrow-rs read it directly - // without base64. Parse it back to confirm the on-wire format. + // Canonical wire: raw JSON, not base64. let meta_str = field.metadata().get(EXTENSION_TYPE_METADATA_KEY).unwrap(); let parsed: serde_json::Value = serde_json::from_str(meta_str).unwrap(); assert_eq!(parsed["shape"], serde_json::json!([2, 3, 4])); diff --git a/vortex-tensor/src/types/fixed_shape/canonical.rs b/vortex-tensor/src/types/fixed_shape/canonical.rs index e1a2d6ec300..35da49a6edb 100644 --- a/vortex-tensor/src/types/fixed_shape/canonical.rs +++ b/vortex-tensor/src/types/fixed_shape/canonical.rs @@ -1,15 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! Arrow canonical [`arrow.fixed_shape_tensor`] metadata serialization. +//! Arrow canonical [`arrow.fixed_shape_tensor`] JSON metadata serialization. //! -//! The wire format is a UTF-8 JSON object placed in `ARROW:extension:metadata`, matching the -//! Arrow specification and pyarrow / arrow-rs interop expectations. -//! -//! We roll our own serde rather than delegating to `arrow_schema::extension::FixedShapeTensor` -//! because arrow-rs 58 serializes the field as `"permutations"` (plural) while the Arrow -//! specification and pyarrow use `"permutation"` (singular). pyarrow silently ignores the -//! misspelled key. +//! Hand-rolled rather than reusing `arrow_schema::extension::FixedShapeTensor` because arrow-rs +//! 58 emits `"permutations"` (plural) while the spec and pyarrow use `"permutation"`. //! //! [`arrow.fixed_shape_tensor`]: https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor diff --git a/vortex-tensor/src/types/fixed_shape/mod.rs b/vortex-tensor/src/types/fixed_shape/mod.rs index 1571c2588c3..b9e6c1d4219 100644 --- a/vortex-tensor/src/types/fixed_shape/mod.rs +++ b/vortex-tensor/src/types/fixed_shape/mod.rs @@ -6,6 +6,9 @@ /// Vortex extension id for [`FixedShapeTensor`]. pub(crate) const ID: &str = "vortex.tensor.fixed_shape_tensor"; +/// Arrow canonical extension name [`ID`] aliases to. +pub(crate) const ARROW_EXT_NAME: &str = "arrow.fixed_shape_tensor"; + /// The VTable for the Tensor extension type. #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct FixedShapeTensor; diff --git a/vortex-tensor/src/types/vector/mod.rs b/vortex-tensor/src/types/vector/mod.rs index d077a183713..e81ee4b0674 100644 --- a/vortex-tensor/src/types/vector/mod.rs +++ b/vortex-tensor/src/types/vector/mod.rs @@ -3,6 +3,9 @@ //! Vector extension type for fixed-length float vectors (e.g., embeddings). +/// Vortex extension id for [`Vector`]. +pub(crate) const ID: &str = "vortex.tensor.vector"; + use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; diff --git a/vortex-tensor/src/types/vector/vtable.rs b/vortex-tensor/src/types/vector/vtable.rs index d870e7a8e0d..709ac4d939b 100644 --- a/vortex-tensor/src/types/vector/vtable.rs +++ b/vortex-tensor/src/types/vector/vtable.rs @@ -11,6 +11,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; +use crate::types::vector::ID; use crate::types::vector::Vector; impl ExtVTable for Vector { @@ -20,7 +21,7 @@ impl ExtVTable for Vector { type NativeValue<'a> = &'a ScalarValue; fn id(&self) -> ExtId { - ExtId::new("vortex.tensor.vector") + ExtId::new(ID) } fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult> { From 62d70e22d95421fc6fd5ffa675d97dc15ec298e2 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Mon, 27 Apr 2026 09:33:25 +0100 Subject: [PATCH 05/21] cached id Signed-off-by: Baris Palaska --- vortex-tensor/src/lib.rs | 3 +-- vortex-tensor/src/tests/arrow_roundtrip.rs | 2 +- vortex-tensor/src/types/fixed_shape/mod.rs | 6 ++---- vortex-tensor/src/types/fixed_shape/vtable.rs | 7 +++++-- vortex-tensor/src/types/vector/mod.rs | 5 ++--- vortex-tensor/src/types/vector/vtable.rs | 7 +++++-- 6 files changed, 16 insertions(+), 14 deletions(-) diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 91e999bea79..4196b32b7cf 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -6,7 +6,6 @@ //! similarity. use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; -use vortex_array::dtype::extension::ExtId; use vortex_array::dtype::session::DTypeSessionExt; use vortex_array::scalar_fn::session::ScalarFnSessionExt; use vortex_array::session::ArraySessionExt; @@ -46,7 +45,7 @@ pub fn initialize(session: &VortexSession) { let dtypes = session.dtypes(); dtypes.register(Vector); dtypes.register(FixedShapeTensor); - dtypes.register_arrow_canonical(ExtId::new(fixed_shape::ID), fixed_shape::ARROW_EXT_NAME); + dtypes.register_arrow_canonical(*fixed_shape::ID, fixed_shape::ARROW_EXT_NAME); drop(dtypes); let session_fns = session.scalar_fns(); diff --git a/vortex-tensor/src/tests/arrow_roundtrip.rs b/vortex-tensor/src/tests/arrow_roundtrip.rs index 903c8b9f2cc..313eb343aea 100644 --- a/vortex-tensor/src/tests/arrow_roundtrip.rs +++ b/vortex-tensor/src/tests/arrow_roundtrip.rs @@ -56,7 +56,7 @@ fn vector_forward_carries_extension_name() { .metadata() .get(EXTENSION_TYPE_NAME_KEY) .map(String::as_str), - Some(vector::ID), + Some(vector::ID.as_str()), ); // EmptyMetadata → no metadata key. assert!(field.metadata().get(EXTENSION_TYPE_METADATA_KEY).is_none()); diff --git a/vortex-tensor/src/types/fixed_shape/mod.rs b/vortex-tensor/src/types/fixed_shape/mod.rs index b9e6c1d4219..565dcb66a47 100644 --- a/vortex-tensor/src/types/fixed_shape/mod.rs +++ b/vortex-tensor/src/types/fixed_shape/mod.rs @@ -3,10 +3,7 @@ //! Fixed-shape Tensor extension type. -/// Vortex extension id for [`FixedShapeTensor`]. -pub(crate) const ID: &str = "vortex.tensor.fixed_shape_tensor"; - -/// Arrow canonical extension name [`ID`] aliases to. +/// Arrow canonical extension name aliased to [`ID`]. pub(crate) const ARROW_EXT_NAME: &str = "arrow.fixed_shape_tensor"; /// The VTable for the Tensor extension type. @@ -22,3 +19,4 @@ pub use metadata::FixedShapeTensorMetadata; mod canonical; mod vtable; +pub(crate) use vtable::ID; diff --git a/vortex-tensor/src/types/fixed_shape/vtable.rs b/vortex-tensor/src/types/fixed_shape/vtable.rs index 362c3c65945..5377a64df0b 100644 --- a/vortex-tensor/src/types/fixed_shape/vtable.rs +++ b/vortex-tensor/src/types/fixed_shape/vtable.rs @@ -10,12 +10,15 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_ensure_eq; +use vortex_session::registry::CachedId; use crate::types::fixed_shape::FixedShapeTensor; use crate::types::fixed_shape::FixedShapeTensorMetadata; -use crate::types::fixed_shape::ID; use crate::types::fixed_shape::canonical; +/// Vortex extension id for [`FixedShapeTensor`]. +pub(crate) static ID: CachedId = CachedId::new("vortex.tensor.fixed_shape_tensor"); + impl ExtVTable for FixedShapeTensor { type Metadata = FixedShapeTensorMetadata; @@ -23,7 +26,7 @@ impl ExtVTable for FixedShapeTensor { type NativeValue<'a> = &'a ScalarValue; fn id(&self) -> ExtId { - ExtId::new(ID) + *ID } fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { diff --git a/vortex-tensor/src/types/vector/mod.rs b/vortex-tensor/src/types/vector/mod.rs index e81ee4b0674..35bde7f16c5 100644 --- a/vortex-tensor/src/types/vector/mod.rs +++ b/vortex-tensor/src/types/vector/mod.rs @@ -3,9 +3,6 @@ //! Vector extension type for fixed-length float vectors (e.g., embeddings). -/// Vortex extension id for [`Vector`]. -pub(crate) const ID: &str = "vortex.tensor.vector"; - use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; @@ -60,3 +57,5 @@ pub use matcher::AnyVector; pub use matcher::VectorMatcherMetadata; mod vtable; +#[cfg(test)] +pub(crate) use vtable::ID; diff --git a/vortex-tensor/src/types/vector/vtable.rs b/vortex-tensor/src/types/vector/vtable.rs index 709ac4d939b..a6418fc3d47 100644 --- a/vortex-tensor/src/types/vector/vtable.rs +++ b/vortex-tensor/src/types/vector/vtable.rs @@ -10,10 +10,13 @@ use vortex_array::scalar::ScalarValue; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; +use vortex_session::registry::CachedId; -use crate::types::vector::ID; use crate::types::vector::Vector; +/// Vortex extension id for [`Vector`]. +pub(crate) static ID: CachedId = CachedId::new("vortex.tensor.vector"); + impl ExtVTable for Vector { type Metadata = EmptyMetadata; @@ -21,7 +24,7 @@ impl ExtVTable for Vector { type NativeValue<'a> = &'a ScalarValue; fn id(&self) -> ExtId { - ExtId::new(ID) + *ID } fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult> { From c3c455dfba12957c73bd199e78c476507abb2210 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Mon, 27 Apr 2026 11:29:50 +0100 Subject: [PATCH 06/21] arcswap Signed-off-by: Baris Palaska --- Cargo.lock | 1 + vortex-array/Cargo.toml | 1 + vortex-array/public-api.lock | 2 +- vortex-array/src/dtype/arrow.rs | 2 +- vortex-array/src/dtype/session.rs | 115 +++++++++++++++++++++++------- 5 files changed, 93 insertions(+), 28 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 10743c6fe81..dbf6a1f970a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10244,6 +10244,7 @@ name = "vortex-array" version = "0.1.0" dependencies = [ "arbitrary", + "arc-swap", "arcref", "arrow-arith 58.1.0", "arrow-array 58.1.0", diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml index 19bb5aef43f..a56f4317333 100644 --- a/vortex-array/Cargo.toml +++ b/vortex-array/Cargo.toml @@ -21,6 +21,7 @@ workspace = true [dependencies] arbitrary = { workspace = true, optional = true } +arc-swap = { workspace = true } arcref = { workspace = true } arrow-arith = { workspace = true } arrow-array = { workspace = true, features = ["ffi"] } diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index f2573ce40b6..7d9385b4307 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -8806,7 +8806,7 @@ pub struct vortex_array::dtype::session::DTypeSession impl vortex_array::dtype::session::DTypeSession -pub fn vortex_array::dtype::session::DTypeSession::arrow_canonical_for(&self, vortex_id: &vortex_array::dtype::extension::ExtId) -> core::option::Option<&'static str> +pub fn vortex_array::dtype::session::DTypeSession::arrow_canonical_for(&self, vortex_id: &vortex_array::dtype::extension::ExtId) -> core::option::Option pub fn vortex_array::dtype::session::DTypeSession::register(&self, vtable: V) diff --git a/vortex-array/src/dtype/arrow.rs b/vortex-array/src/dtype/arrow.rs index 85b3f462f89..7cbb11887e9 100644 --- a/vortex-array/src/dtype/arrow.rs +++ b/vortex-array/src/dtype/arrow.rs @@ -493,7 +493,7 @@ fn field_from_dtype(name: &str, dtype: &DType, dtypes: &DTypeSession) -> VortexR let s = String::from_utf8(ext_meta_bytes).map_err(|e| { vortex_err!("canonical extension {canonical} metadata must be valid UTF-8: {e}") })?; - (canonical.to_owned(), s) + (canonical.as_str().to_owned(), s) } None => ( ext.id().as_str().to_owned(), diff --git a/vortex-array/src/dtype/session.rs b/vortex-array/src/dtype/session.rs index 0b46414284c..306350e3c86 100644 --- a/vortex-array/src/dtype/session.rs +++ b/vortex-array/src/dtype/session.rs @@ -5,10 +5,11 @@ use std::sync::Arc; -use parking_lot::RwLock; +use arc_swap::ArcSwap; use vortex_session::Ref; use vortex_session::SessionExt; use vortex_session::registry::Registry; +use vortex_utils::aliases::hash_map::HashMap; use crate::dtype::extension::ExtDTypePluginRef; use crate::dtype::extension::ExtId; @@ -20,23 +21,21 @@ use crate::extension::datetime::Timestamp; /// Registry for extension dtypes. pub type ExtDTypeRegistry = Registry; +/// Bidirectional alias map between Vortex extension ids and Arrow canonical names. +type ArrowCanonicalMap = HashMap; + /// Session for managing extension dtypes. #[derive(Debug)] pub struct DTypeSession { registry: ExtDTypeRegistry, - arrow_canonical: RwLock, -} - -#[derive(Debug, Default)] -struct ArrowCanonicalAliases { - entries: Vec<(ExtId, &'static str)>, + arrow_canonical: ArcSwap, } impl Default for DTypeSession { fn default() -> Self { let this = Self { registry: Registry::default(), - arrow_canonical: RwLock::default(), + arrow_canonical: ArcSwap::new(Arc::new(ArrowCanonicalMap::default())), }; this.register(Date); @@ -59,33 +58,34 @@ impl DTypeSession { &self.registry } - /// Register an Arrow canonical extension name as an alias for a Vortex extension id. - /// Aliased extensions emit the canonical name on `ARROW:extension:name` and serialize - /// metadata as raw UTF-8 instead of base64-wrapped bytes. + /// Alias an Arrow canonical extension name to a Vortex extension id. Aliased extensions + /// emit the canonical name on `ARROW:extension:name` and serialize metadata as raw UTF-8 + /// instead of base64-wrapped bytes. Re-registering evicts the previous mapping. pub fn register_arrow_canonical(&self, vortex_id: ExtId, arrow_name: &'static str) { - let mut aliases = self.arrow_canonical.write(); - aliases.entries.retain(|(v, _)| *v != vortex_id); - aliases.entries.push((vortex_id, arrow_name)); + let arrow_id = ExtId::new(arrow_name); + self.arrow_canonical.rcu(|prev| { + let mut next = (**prev).clone(); + if let Some(stale) = next.insert(vortex_id, arrow_id) { + next.remove(&stale); + } + if let Some(stale) = next.insert(arrow_id, vortex_id) { + next.remove(&stale); + } + Arc::new(next) + }); } /// Returns the Arrow canonical extension name aliased to the given Vortex id, if any. - pub fn arrow_canonical_for(&self, vortex_id: &ExtId) -> Option<&'static str> { - self.arrow_canonical - .read() - .entries - .iter() - .find(|(v, _)| v == vortex_id) - .map(|(_, a)| *a) + pub fn arrow_canonical_for(&self, vortex_id: &ExtId) -> Option { + self.arrow_canonical.load().get(vortex_id).copied() } /// Returns the Vortex extension id aliased to the given Arrow canonical name, if any. pub fn vortex_id_for_arrow_canonical(&self, arrow_name: &str) -> Option { self.arrow_canonical - .read() - .entries - .iter() - .find(|(_, a)| *a == arrow_name) - .map(|(v, _)| *v) + .load() + .get(&ExtId::new(arrow_name)) + .copied() } } @@ -100,3 +100,66 @@ impl DTypeSessionExt for S { self.get::() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn arrow_canonical_re_registration_is_clean() { + let session = DTypeSession::default(); + let v = ExtId::new("vortex.test"); + + session.register_arrow_canonical(v, "arrow.foo"); + assert_eq!( + session.arrow_canonical_for(&v), + Some(ExtId::new("arrow.foo")) + ); + assert_eq!(session.vortex_id_for_arrow_canonical("arrow.foo"), Some(v)); + + session.register_arrow_canonical(v, "arrow.bar"); + assert_eq!( + session.arrow_canonical_for(&v), + Some(ExtId::new("arrow.bar")) + ); + assert_eq!(session.vortex_id_for_arrow_canonical("arrow.bar"), Some(v)); + assert_eq!(session.vortex_id_for_arrow_canonical("arrow.foo"), None); + } + + /// `(a → b, b → a)` then `register(a, c)` should leave `(a → c, c → a)` only. + #[test] + fn rebind_vortex_id_to_new_arrow_name() { + let session = DTypeSession::default(); + let a = ExtId::new("vortex.a"); + let b = ExtId::new("arrow.b"); + let c = ExtId::new("arrow.c"); + + session.register_arrow_canonical(a, "arrow.b"); + assert_eq!(session.arrow_canonical_for(&a), Some(b)); + assert_eq!(session.vortex_id_for_arrow_canonical("arrow.b"), Some(a)); + + session.register_arrow_canonical(a, "arrow.c"); + + assert_eq!(session.arrow_canonical_for(&a), Some(c)); + assert_eq!(session.vortex_id_for_arrow_canonical("arrow.c"), Some(a)); + assert_eq!(session.vortex_id_for_arrow_canonical("arrow.b"), None); + } + + /// `(a → b, b → a)` then `register(c, b)` should leave `(c → b, b → c)` only. + #[test] + fn steal_arrow_name_from_another_vortex_id() { + let session = DTypeSession::default(); + let a = ExtId::new("vortex.a"); + let b = ExtId::new("arrow.b"); + let c = ExtId::new("vortex.c"); + + session.register_arrow_canonical(a, "arrow.b"); + assert_eq!(session.arrow_canonical_for(&a), Some(b)); + + session.register_arrow_canonical(c, "arrow.b"); + + assert_eq!(session.arrow_canonical_for(&c), Some(b)); + assert_eq!(session.vortex_id_for_arrow_canonical("arrow.b"), Some(c)); + assert_eq!(session.arrow_canonical_for(&a), None); + } +} From a91c8619a29f10a55d826bd5a325984f110aa8ad Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Mon, 27 Apr 2026 11:44:31 +0100 Subject: [PATCH 07/21] clippy Signed-off-by: Baris Palaska --- vortex-array/src/dtype/session.rs | 40 +++++++++++++++---------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/vortex-array/src/dtype/session.rs b/vortex-array/src/dtype/session.rs index 306350e3c86..75246a0c59a 100644 --- a/vortex-array/src/dtype/session.rs +++ b/vortex-array/src/dtype/session.rs @@ -126,40 +126,40 @@ mod tests { assert_eq!(session.vortex_id_for_arrow_canonical("arrow.foo"), None); } - /// `(a → b, b → a)` then `register(a, c)` should leave `(a → c, c → a)` only. + /// `(vid → old, old → vid)` then `register(vid, new)` should leave `(vid → new, new → vid)`. #[test] fn rebind_vortex_id_to_new_arrow_name() { let session = DTypeSession::default(); - let a = ExtId::new("vortex.a"); - let b = ExtId::new("arrow.b"); - let c = ExtId::new("arrow.c"); + let vid = ExtId::new("vortex.a"); + let old = ExtId::new("arrow.b"); + let new = ExtId::new("arrow.c"); - session.register_arrow_canonical(a, "arrow.b"); - assert_eq!(session.arrow_canonical_for(&a), Some(b)); - assert_eq!(session.vortex_id_for_arrow_canonical("arrow.b"), Some(a)); + session.register_arrow_canonical(vid, "arrow.b"); + assert_eq!(session.arrow_canonical_for(&vid), Some(old)); + assert_eq!(session.vortex_id_for_arrow_canonical("arrow.b"), Some(vid)); - session.register_arrow_canonical(a, "arrow.c"); + session.register_arrow_canonical(vid, "arrow.c"); - assert_eq!(session.arrow_canonical_for(&a), Some(c)); - assert_eq!(session.vortex_id_for_arrow_canonical("arrow.c"), Some(a)); + assert_eq!(session.arrow_canonical_for(&vid), Some(new)); + assert_eq!(session.vortex_id_for_arrow_canonical("arrow.c"), Some(vid)); assert_eq!(session.vortex_id_for_arrow_canonical("arrow.b"), None); } - /// `(a → b, b → a)` then `register(c, b)` should leave `(c → b, b → c)` only. + /// `(old → name, name → old)` then `register(new, name)` should leave `(new → name, name → new)`. #[test] fn steal_arrow_name_from_another_vortex_id() { let session = DTypeSession::default(); - let a = ExtId::new("vortex.a"); - let b = ExtId::new("arrow.b"); - let c = ExtId::new("vortex.c"); + let old = ExtId::new("vortex.a"); + let name = ExtId::new("arrow.b"); + let new = ExtId::new("vortex.c"); - session.register_arrow_canonical(a, "arrow.b"); - assert_eq!(session.arrow_canonical_for(&a), Some(b)); + session.register_arrow_canonical(old, "arrow.b"); + assert_eq!(session.arrow_canonical_for(&old), Some(name)); - session.register_arrow_canonical(c, "arrow.b"); + session.register_arrow_canonical(new, "arrow.b"); - assert_eq!(session.arrow_canonical_for(&c), Some(b)); - assert_eq!(session.vortex_id_for_arrow_canonical("arrow.b"), Some(c)); - assert_eq!(session.arrow_canonical_for(&a), None); + assert_eq!(session.arrow_canonical_for(&new), Some(name)); + assert_eq!(session.vortex_id_for_arrow_canonical("arrow.b"), Some(new)); + assert_eq!(session.arrow_canonical_for(&old), None); } } From 343139a98fc71cd1037e0afec4de871aee55091b Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Mon, 27 Apr 2026 12:25:49 +0100 Subject: [PATCH 08/21] refactor tests, dont copy metadata bytes Signed-off-by: Baris Palaska --- vortex-array/src/dtype/arrow.rs | 10 +- vortex-tensor/src/types/fixed_shape/vtable.rs | 109 +++++++++--------- vortex-tensor/src/types/vector/vtable.rs | 90 ++++++--------- 3 files changed, 96 insertions(+), 113 deletions(-) diff --git a/vortex-array/src/dtype/arrow.rs b/vortex-array/src/dtype/arrow.rs index 7cbb11887e9..362a1caef8e 100644 --- a/vortex-array/src/dtype/arrow.rs +++ b/vortex-array/src/dtype/arrow.rs @@ -13,6 +13,7 @@ //! For this reason, it's recommended to do as much computation as possible within Vortex, and then //! materialize an Arrow ArrayRef at the very end of the processing chain. +use std::borrow::Cow; use std::sync::Arc; use arrow_schema::DataType; @@ -293,7 +294,7 @@ fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { } }; - match plugin.deserialize(&metadata_bytes, storage_dtype.clone()) { + match plugin.deserialize(metadata_bytes.as_ref(), storage_dtype.clone()) { Ok(ext_ref) => DType::Extension(ext_ref), Err(e) => { tracing::warn!( @@ -310,12 +311,13 @@ fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { /// Canonical extensions store UTF-8 bytes directly; non-canonical extensions base64-encode so /// arbitrary binary plugin output survives the String-typed metadata channel. -fn decode_extension_metadata(field: &Field, is_canonical: bool) -> VortexResult> { +fn decode_extension_metadata(field: &Field, is_canonical: bool) -> VortexResult> { match field.extension_type_metadata() { - None | Some("") => Ok(Vec::new()), - Some(s) if is_canonical => Ok(s.as_bytes().to_vec()), + None | Some("") => Ok(Cow::Borrowed(&[])), + Some(s) if is_canonical => Ok(Cow::Borrowed(s.as_bytes())), Some(s) => BASE64_STANDARD .decode(s) + .map(Cow::Owned) .map_err(|e| vortex_err!("failed to base64-decode {EXTENSION_TYPE_METADATA_KEY}: {e}")), } } diff --git a/vortex-tensor/src/types/fixed_shape/vtable.rs b/vortex-tensor/src/types/fixed_shape/vtable.rs index ac07da2c872..85d70068bed 100644 --- a/vortex-tensor/src/types/fixed_shape/vtable.rs +++ b/vortex-tensor/src/types/fixed_shape/vtable.rs @@ -145,69 +145,64 @@ mod tests { assert_roundtrip(&metadata?) } - /// Constructs a `FixedShapeTensor` ext dtype wrapped in `DType::Extension`. - fn tensor_dtype( - metadata: FixedShapeTensorMetadata, - element: PType, - list_size: u32, - ) -> VortexResult { + fn tensor_dtype(metadata: FixedShapeTensorMetadata, element: PType, list_size: u32) -> DType { let storage = DType::FixedSizeList( Arc::new(DType::Primitive(element, Nullability::NonNullable)), list_size, Nullability::NonNullable, ); - Ok(DType::Extension( - ExtDType::::try_new(metadata, storage)?.erased(), - )) + DType::Extension( + ExtDType::::try_new(metadata, storage) + .unwrap() + .erased(), + ) } - #[test] - fn tensor_widens_element_when_metadata_matches() -> VortexResult<()> { - let metadata = FixedShapeTensorMetadata::new(vec![2, 3]); - let lhs = tensor_dtype(metadata.clone(), PType::F32, 6)?; - let rhs = tensor_dtype(metadata.clone(), PType::F64, 6)?; - let expected = tensor_dtype(metadata, PType::F64, 6)?; - assert_eq!(lhs.least_supertype(&rhs), Some(expected)); - Ok(()) - } - - #[test] - fn tensor_different_shape_returns_none() -> VortexResult<()> { - let lhs = tensor_dtype(FixedShapeTensorMetadata::new(vec![2, 3]), PType::F32, 6)?; - let rhs = tensor_dtype(FixedShapeTensorMetadata::new(vec![3, 2]), PType::F32, 6)?; - assert_eq!(lhs.least_supertype(&rhs), None); - Ok(()) - } - - #[test] - fn tensor_different_permutation_returns_none() -> VortexResult<()> { - let lhs_metadata = - FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0, 1])?; - let rhs_metadata = - FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![1, 0])?; - let lhs = tensor_dtype(lhs_metadata, PType::F32, 6)?; - let rhs = tensor_dtype(rhs_metadata, PType::F32, 6)?; - assert_eq!(lhs.least_supertype(&rhs), None); - Ok(()) - } - - #[test] - fn tensor_different_dim_names_returns_none() -> VortexResult<()> { - let lhs_metadata = FixedShapeTensorMetadata::new(vec![2, 3]) - .with_dim_names(vec!["x".into(), "y".into()])?; - let rhs_metadata = FixedShapeTensorMetadata::new(vec![2, 3]) - .with_dim_names(vec!["rows".into(), "cols".into()])?; - let lhs = tensor_dtype(lhs_metadata, PType::F32, 6)?; - let rhs = tensor_dtype(rhs_metadata, PType::F32, 6)?; - assert_eq!(lhs.least_supertype(&rhs), None); - Ok(()) - } - - #[test] - fn tensor_vs_non_extension_returns_none() -> VortexResult<()> { - let lhs = tensor_dtype(FixedShapeTensorMetadata::new(vec![2, 3]), PType::F32, 6)?; - let rhs = DType::Primitive(PType::F32, Nullability::NonNullable); - assert_eq!(lhs.least_supertype(&rhs), None); - Ok(()) + #[rstest] + #[case::widens_element_when_metadata_matches( + tensor_dtype(FixedShapeTensorMetadata::new(vec![2, 3]), PType::F32, 6), + tensor_dtype(FixedShapeTensorMetadata::new(vec![2, 3]), PType::F64, 6), + Some(tensor_dtype(FixedShapeTensorMetadata::new(vec![2, 3]), PType::F64, 6)), + )] + #[case::different_shape_returns_none( + tensor_dtype(FixedShapeTensorMetadata::new(vec![2, 3]), PType::F32, 6), + tensor_dtype(FixedShapeTensorMetadata::new(vec![3, 2]), PType::F32, 6), + None, + )] + #[case::different_permutation_returns_none( + tensor_dtype( + FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0, 1]).unwrap(), + PType::F32, 6, + ), + tensor_dtype( + FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![1, 0]).unwrap(), + PType::F32, 6, + ), + None, + )] + #[case::different_dim_names_returns_none( + tensor_dtype( + FixedShapeTensorMetadata::new(vec![2, 3]) + .with_dim_names(vec!["x".into(), "y".into()]).unwrap(), + PType::F32, 6, + ), + tensor_dtype( + FixedShapeTensorMetadata::new(vec![2, 3]) + .with_dim_names(vec!["rows".into(), "cols".into()]).unwrap(), + PType::F32, 6, + ), + None, + )] + #[case::vs_non_extension_returns_none( + tensor_dtype(FixedShapeTensorMetadata::new(vec![2, 3]), PType::F32, 6), + DType::Primitive(PType::F32, Nullability::NonNullable), + None, + )] + fn tensor_least_supertype( + #[case] lhs: DType, + #[case] rhs: DType, + #[case] expected: Option, + ) { + assert_eq!(lhs.least_supertype(&rhs), expected); } } diff --git a/vortex-tensor/src/types/vector/vtable.rs b/vortex-tensor/src/types/vector/vtable.rs index f42ae1caeb5..59c4f7494c3 100644 --- a/vortex-tensor/src/types/vector/vtable.rs +++ b/vortex-tensor/src/types/vector/vtable.rs @@ -142,60 +142,46 @@ mod tests { Ok(()) } - /// Constructs a `Vector` ext dtype wrapped in `DType::Extension`. - fn vector_dtype(ptype: PType, dims: u32) -> VortexResult { - vector_dtype_with_outer(ptype, dims, Nullability::NonNullable) - } - - /// Constructs a `Vector` ext dtype with the given outer `Nullability`, wrapped in - /// `DType::Extension`. - fn vector_dtype_with_outer(ptype: PType, dims: u32, outer: Nullability) -> VortexResult { + fn vector_dtype(ptype: PType, dims: u32, outer: Nullability) -> DType { let storage = vector_storage_dtype(ptype, dims, outer); - Ok(DType::Extension( - ExtDType::::try_new(EmptyMetadata, storage)?.erased(), - )) - } - - #[test] - fn vector_widens_float_precision() -> VortexResult<()> { - let lhs = vector_dtype(PType::F32, 768)?; - let rhs = vector_dtype(PType::F64, 768)?; - let expected = vector_dtype(PType::F64, 768)?; - assert_eq!(lhs.least_supertype(&rhs), Some(expected)); - Ok(()) - } - - #[test] - fn vector_dim_mismatch_returns_none() -> VortexResult<()> { - let lhs = vector_dtype(PType::F32, 768)?; - let rhs = vector_dtype(PType::F32, 1024)?; - assert_eq!(lhs.least_supertype(&rhs), None); - Ok(()) - } - - #[test] - fn vector_vs_non_extension_returns_none() -> VortexResult<()> { - let lhs = vector_dtype(PType::F32, 768)?; - let rhs = DType::Primitive(PType::F32, Nullability::NonNullable); - assert_eq!(lhs.least_supertype(&rhs), None); - Ok(()) - } - - #[test] - fn vector_unions_outer_nullability_with_float_widening() -> VortexResult<()> { - let lhs = vector_dtype_with_outer(PType::F32, 4, Nullability::NonNullable)?; - let rhs = vector_dtype_with_outer(PType::F64, 4, Nullability::Nullable)?; - let expected = vector_dtype_with_outer(PType::F64, 4, Nullability::Nullable)?; - assert_eq!(lhs.least_supertype(&rhs), Some(expected)); - Ok(()) + DType::Extension( + ExtDType::::try_new(EmptyMetadata, storage) + .unwrap() + .erased(), + ) } - #[test] - fn vector_same_ptype_unions_outer_nullability() -> VortexResult<()> { - let lhs = vector_dtype_with_outer(PType::F32, 4, Nullability::NonNullable)?; - let rhs = vector_dtype_with_outer(PType::F32, 4, Nullability::Nullable)?; - let expected = vector_dtype_with_outer(PType::F32, 4, Nullability::Nullable)?; - assert_eq!(lhs.least_supertype(&rhs), Some(expected)); - Ok(()) + #[rstest] + #[case::widens_float_precision( + vector_dtype(PType::F32, 768, Nullability::NonNullable), + vector_dtype(PType::F64, 768, Nullability::NonNullable), + Some(vector_dtype(PType::F64, 768, Nullability::NonNullable)) + )] + #[case::dim_mismatch_returns_none( + vector_dtype(PType::F32, 768, Nullability::NonNullable), + vector_dtype(PType::F32, 1024, Nullability::NonNullable), + None + )] + #[case::vs_non_extension_returns_none( + vector_dtype(PType::F32, 768, Nullability::NonNullable), + DType::Primitive(PType::F32, Nullability::NonNullable), + None + )] + #[case::unions_outer_nullability_with_float_widening( + vector_dtype(PType::F32, 4, Nullability::NonNullable), + vector_dtype(PType::F64, 4, Nullability::Nullable), + Some(vector_dtype(PType::F64, 4, Nullability::Nullable)) + )] + #[case::same_ptype_unions_outer_nullability( + vector_dtype(PType::F32, 4, Nullability::NonNullable), + vector_dtype(PType::F32, 4, Nullability::Nullable), + Some(vector_dtype(PType::F32, 4, Nullability::Nullable)) + )] + fn vector_least_supertype( + #[case] lhs: DType, + #[case] rhs: DType, + #[case] expected: Option, + ) { + assert_eq!(lhs.least_supertype(&rhs), expected); } } From fd6654be9de71543f78a15b0b4616f38a4309d28 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Mon, 27 Apr 2026 12:41:48 +0100 Subject: [PATCH 09/21] rm unnecessary stuff Signed-off-by: Baris Palaska --- vortex-tensor/Cargo.toml | 1 - vortex-tensor/src/lib.rs | 1 - 2 files changed, 2 deletions(-) diff --git a/vortex-tensor/Cargo.toml b/vortex-tensor/Cargo.toml index 71d673d539a..d3de27f5fb1 100644 --- a/vortex-tensor/Cargo.toml +++ b/vortex-tensor/Cargo.toml @@ -39,5 +39,4 @@ mimalloc = { workspace = true } rand = { workspace = true } rand_distr = { workspace = true } rstest = { workspace = true } -serde_json = { workspace = true } vortex-btrblocks = { path = "../vortex-btrblocks" } diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 2e5c647a2dc..cdf77536a6f 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -51,7 +51,6 @@ pub fn initialize(session: &VortexSession) { dtypes.register(Vector); dtypes.register(FixedShapeTensor); dtypes.register_arrow_canonical(*fixed_shape::ID, fixed_shape::ARROW_EXT_NAME); - drop(dtypes); let session_fns = session.scalar_fns(); From 80f52f1a0bcb187114d4035810a20fb8ae540e84 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Mon, 27 Apr 2026 12:44:51 +0100 Subject: [PATCH 10/21] better name Signed-off-by: Baris Palaska --- vortex-array/src/dtype/arrow.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vortex-array/src/dtype/arrow.rs b/vortex-array/src/dtype/arrow.rs index 362a1caef8e..a403701804f 100644 --- a/vortex-array/src/dtype/arrow.rs +++ b/vortex-array/src/dtype/arrow.rs @@ -383,11 +383,11 @@ impl DType { /// Returns the Arrow [`DataType`] that best corresponds to this Vortex [`DType`]. pub fn to_arrow_dtype(&self) -> VortexResult { - to_arrow_dtype_with_dtypes(self, &LEGACY_SESSION.dtypes()) + arrow_dtype_from_dtype(self, &LEGACY_SESSION.dtypes()) } } -fn to_arrow_dtype_with_dtypes(dtype: &DType, dtypes: &DTypeSession) -> VortexResult { +fn arrow_dtype_from_dtype(dtype: &DType, dtypes: &DTypeSession) -> VortexResult { Ok(match dtype { DType::Null => DataType::Null, DType::Bool(_) => DataType::Boolean, @@ -459,7 +459,7 @@ fn to_arrow_dtype_with_dtypes(dtype: &DType, dtypes: &DTypeSession) -> VortexRes } // Extension identity lives on the Field (see `field_from_dtype`), not on // DataType, so here we only encode the storage type. - to_arrow_dtype_with_dtypes(ext_dtype.storage_dtype(), dtypes)? + arrow_dtype_from_dtype(ext_dtype.storage_dtype(), dtypes)? } }) } @@ -487,7 +487,7 @@ fn field_from_dtype(name: &str, dtype: &DType, dtypes: &DTypeSession) -> VortexR return Ok(Field::new(name, native, dtype.is_nullable())); } - let storage_arrow = to_arrow_dtype_with_dtypes(ext.storage_dtype(), dtypes)?; + let storage_arrow = arrow_dtype_from_dtype(ext.storage_dtype(), dtypes)?; let ext_meta_bytes = ext.serialize_metadata()?; let (ext_name, meta_str) = match dtypes.arrow_canonical_for(&ext.id()) { Some(canonical) => { @@ -513,7 +513,7 @@ fn field_from_dtype(name: &str, dtype: &DType, dtypes: &DTypeSession) -> VortexR Ok(Field::new( name, - to_arrow_dtype_with_dtypes(dtype, dtypes)?, + arrow_dtype_from_dtype(dtype, dtypes)?, dtype.is_nullable(), )) } From 784d2834c8b4c3bc9c4bd3453d408a80b9363c8b Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Mon, 27 Apr 2026 15:55:02 +0100 Subject: [PATCH 11/21] newtype, static ID Signed-off-by: Baris Palaska --- vortex-array/src/dtype/session.rs | 58 +++++++++++++------ vortex-tensor/src/lib.rs | 7 ++- vortex-tensor/src/tests/arrow_roundtrip.rs | 7 +-- vortex-tensor/src/types/fixed_shape/mod.rs | 11 ++-- vortex-tensor/src/types/fixed_shape/vtable.rs | 2 +- vortex-tensor/src/types/vector/mod.rs | 2 - vortex-tensor/src/types/vector/vtable.rs | 2 +- 7 files changed, 57 insertions(+), 32 deletions(-) diff --git a/vortex-array/src/dtype/session.rs b/vortex-array/src/dtype/session.rs index 75246a0c59a..a297957984d 100644 --- a/vortex-array/src/dtype/session.rs +++ b/vortex-array/src/dtype/session.rs @@ -21,21 +21,53 @@ use crate::extension::datetime::Timestamp; /// Registry for extension dtypes. pub type ExtDTypeRegistry = Registry; -/// Bidirectional alias map between Vortex extension ids and Arrow canonical names. -type ArrowCanonicalMap = HashMap; +/// Bidirectional alias map between Vortex extension ids and Arrow canonical extension names. +/// +/// Aliased extensions emit the canonical name on `ARROW:extension:name` and serialize metadata +/// as raw UTF-8 instead of base64-wrapped bytes. Lookups are lock-free; updates clone-and-swap. +#[derive(Debug, Default)] +struct ArrowCanonicalAliases(ArcSwap>); + +impl ArrowCanonicalAliases { + /// Alias `vortex_id` to the Arrow canonical `arrow_name`. Re-registering evicts the previous + /// mapping for either side, so the bidirectional invariant holds after every call. + fn register(&self, vortex_id: ExtId, arrow_name: &'static str) { + let arrow_id = ExtId::new(arrow_name); + self.0.rcu(|prev| { + let mut next = (**prev).clone(); + if let Some(stale) = next.insert(vortex_id, arrow_id) { + next.remove(&stale); + } + if let Some(stale) = next.insert(arrow_id, vortex_id) { + next.remove(&stale); + } + Arc::new(next) + }); + } + + /// Returns the Arrow canonical extension name aliased to `vortex_id`, if any. + fn arrow_canonical_for(&self, vortex_id: &ExtId) -> Option { + self.0.load().get(vortex_id).copied() + } + + /// Returns the Vortex extension id aliased to `arrow_name`, if any. + fn vortex_id_for_arrow_canonical(&self, arrow_name: &str) -> Option { + self.0.load().get(&ExtId::new(arrow_name)).copied() + } +} /// Session for managing extension dtypes. #[derive(Debug)] pub struct DTypeSession { registry: ExtDTypeRegistry, - arrow_canonical: ArcSwap, + arrow_canonical: ArrowCanonicalAliases, } impl Default for DTypeSession { fn default() -> Self { let this = Self { registry: Registry::default(), - arrow_canonical: ArcSwap::new(Arc::new(ArrowCanonicalMap::default())), + arrow_canonical: ArrowCanonicalAliases::default(), }; this.register(Date); @@ -62,30 +94,18 @@ impl DTypeSession { /// emit the canonical name on `ARROW:extension:name` and serialize metadata as raw UTF-8 /// instead of base64-wrapped bytes. Re-registering evicts the previous mapping. pub fn register_arrow_canonical(&self, vortex_id: ExtId, arrow_name: &'static str) { - let arrow_id = ExtId::new(arrow_name); - self.arrow_canonical.rcu(|prev| { - let mut next = (**prev).clone(); - if let Some(stale) = next.insert(vortex_id, arrow_id) { - next.remove(&stale); - } - if let Some(stale) = next.insert(arrow_id, vortex_id) { - next.remove(&stale); - } - Arc::new(next) - }); + self.arrow_canonical.register(vortex_id, arrow_name); } /// Returns the Arrow canonical extension name aliased to the given Vortex id, if any. pub fn arrow_canonical_for(&self, vortex_id: &ExtId) -> Option { - self.arrow_canonical.load().get(vortex_id).copied() + self.arrow_canonical.arrow_canonical_for(vortex_id) } /// Returns the Vortex extension id aliased to the given Arrow canonical name, if any. pub fn vortex_id_for_arrow_canonical(&self, arrow_name: &str) -> Option { self.arrow_canonical - .load() - .get(&ExtId::new(arrow_name)) - .copied() + .vortex_id_for_arrow_canonical(arrow_name) } } diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index cdf77536a6f..11eea6e9a59 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -11,6 +11,7 @@ )] use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; +use vortex_array::dtype::extension::ExtVTable; use vortex_array::dtype::session::DTypeSessionExt; use vortex_array::scalar_fn::session::ScalarFnSessionExt; use vortex_array::session::ArraySessionExt; @@ -50,7 +51,11 @@ pub fn initialize(session: &VortexSession) { let dtypes = session.dtypes(); dtypes.register(Vector); dtypes.register(FixedShapeTensor); - dtypes.register_arrow_canonical(*fixed_shape::ID, fixed_shape::ARROW_EXT_NAME); + dtypes.register_arrow_canonical(FixedShapeTensor.id(), FixedShapeTensor::ARROW_EXT_NAME); + // Drop the dashmap shard ref before acquiring `scalar_fns` — `or_insert_with` inside + // `session.get` takes a write lock, which would deadlock if `dtypes` were still alive + // and the two TypeIds happened to hash to the same shard. + drop(dtypes); let session_fns = session.scalar_fns(); diff --git a/vortex-tensor/src/tests/arrow_roundtrip.rs b/vortex-tensor/src/tests/arrow_roundtrip.rs index 313eb343aea..3975d8e2157 100644 --- a/vortex-tensor/src/tests/arrow_roundtrip.rs +++ b/vortex-tensor/src/tests/arrow_roundtrip.rs @@ -14,14 +14,13 @@ use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; use vortex_array::dtype::arrow::FromArrowType; use vortex_array::dtype::extension::ExtDType; +use vortex_array::dtype::extension::ExtVTable; use vortex_array::extension::datetime::TimeUnit; use vortex_array::extension::datetime::Timestamp; use crate::tests::SESSION; -use crate::types::fixed_shape; use crate::types::fixed_shape::FixedShapeTensor; use crate::types::fixed_shape::FixedShapeTensorMetadata; -use crate::types::vector; use crate::types::vector::Vector; fn vector_dtype(len: u32) -> DType { @@ -56,7 +55,7 @@ fn vector_forward_carries_extension_name() { .metadata() .get(EXTENSION_TYPE_NAME_KEY) .map(String::as_str), - Some(vector::ID.as_str()), + Some(Vector.id().as_str()), ); // EmptyMetadata → no metadata key. assert!(field.metadata().get(EXTENSION_TYPE_METADATA_KEY).is_none()); @@ -121,7 +120,7 @@ fn fixed_shape_tensor_metadata_roundtrip() { .metadata() .get(EXTENSION_TYPE_NAME_KEY) .map(String::as_str), - Some(fixed_shape::ARROW_EXT_NAME), + Some(FixedShapeTensor::ARROW_EXT_NAME), ); // Canonical wire: raw JSON, not base64. diff --git a/vortex-tensor/src/types/fixed_shape/mod.rs b/vortex-tensor/src/types/fixed_shape/mod.rs index 565dcb66a47..423c1db700a 100644 --- a/vortex-tensor/src/types/fixed_shape/mod.rs +++ b/vortex-tensor/src/types/fixed_shape/mod.rs @@ -3,13 +3,17 @@ //! Fixed-shape Tensor extension type. -/// Arrow canonical extension name aliased to [`ID`]. -pub(crate) const ARROW_EXT_NAME: &str = "arrow.fixed_shape_tensor"; - /// The VTable for the Tensor extension type. #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct FixedShapeTensor; +impl FixedShapeTensor { + /// Arrow canonical extension name aliased to this type's [`ExtVTable::id`]. + /// + /// [`ExtVTable::id`]: vortex_array::dtype::extension::ExtVTable::id + pub(crate) const ARROW_EXT_NAME: &'static str = "arrow.fixed_shape_tensor"; +} + mod matcher; pub use matcher::AnyFixedShapeTensor; pub use matcher::FixedShapeTensorMatcherMetadata; @@ -19,4 +23,3 @@ pub use metadata::FixedShapeTensorMetadata; mod canonical; mod vtable; -pub(crate) use vtable::ID; diff --git a/vortex-tensor/src/types/fixed_shape/vtable.rs b/vortex-tensor/src/types/fixed_shape/vtable.rs index 85d70068bed..cc789ea39d7 100644 --- a/vortex-tensor/src/types/fixed_shape/vtable.rs +++ b/vortex-tensor/src/types/fixed_shape/vtable.rs @@ -17,7 +17,7 @@ use crate::types::fixed_shape::FixedShapeTensorMetadata; use crate::types::fixed_shape::canonical; /// Vortex extension id for [`FixedShapeTensor`]. -pub(crate) static ID: CachedId = CachedId::new("vortex.tensor.fixed_shape_tensor"); +static ID: CachedId = CachedId::new("vortex.tensor.fixed_shape_tensor"); impl ExtVTable for FixedShapeTensor { type Metadata = FixedShapeTensorMetadata; diff --git a/vortex-tensor/src/types/vector/mod.rs b/vortex-tensor/src/types/vector/mod.rs index 8f2d16d9a9b..3763220fc82 100644 --- a/vortex-tensor/src/types/vector/mod.rs +++ b/vortex-tensor/src/types/vector/mod.rs @@ -81,5 +81,3 @@ pub use matcher::AnyVector; pub use matcher::VectorMatcherMetadata; mod vtable; -#[cfg(test)] -pub(crate) use vtable::ID; diff --git a/vortex-tensor/src/types/vector/vtable.rs b/vortex-tensor/src/types/vector/vtable.rs index 59c4f7494c3..83d807a2cd3 100644 --- a/vortex-tensor/src/types/vector/vtable.rs +++ b/vortex-tensor/src/types/vector/vtable.rs @@ -14,7 +14,7 @@ use crate::types::vector::Vector; use crate::types::vector::validate_vector_storage_dtype; /// Vortex extension id for [`Vector`]. -pub(crate) static ID: CachedId = CachedId::new("vortex.tensor.vector"); +static ID: CachedId = CachedId::new("vortex.tensor.vector"); impl ExtVTable for Vector { type Metadata = EmptyMetadata; From 3b67a4dcfb036c5ff5aad714d93e80b3c0fe72a6 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Mon, 27 Apr 2026 17:43:13 +0100 Subject: [PATCH 12/21] keep on-disk metadata as proto, convert to json only at the arrow boundary Signed-off-by: Baris Palaska --- vortex-array/src/dtype/arrow.rs | 45 +++-- vortex-array/src/dtype/session.rs | 180 +++++++++++++----- vortex-tensor/src/lib.rs | 14 +- .../src/types/fixed_shape/canonical.rs | 55 ++++-- vortex-tensor/src/types/fixed_shape/mod.rs | 5 +- vortex-tensor/src/types/fixed_shape/proto.rs | 70 +++++++ vortex-tensor/src/types/fixed_shape/vtable.rs | 6 +- 7 files changed, 285 insertions(+), 90 deletions(-) create mode 100644 vortex-tensor/src/types/fixed_shape/proto.rs diff --git a/vortex-array/src/dtype/arrow.rs b/vortex-array/src/dtype/arrow.rs index a403701804f..61f3a6de046 100644 --- a/vortex-array/src/dtype/arrow.rs +++ b/vortex-array/src/dtype/arrow.rs @@ -13,7 +13,6 @@ //! For this reason, it's recommended to do as much computation as possible within Vortex, and then //! materialize an Arrow ArrayRef at the very end of the processing chain. -use std::borrow::Cow; use std::sync::Arc; use arrow_schema::DataType; @@ -45,6 +44,7 @@ use crate::dtype::PType; use crate::dtype::StructFields; use crate::dtype::extension::ExtDTypeRef; use crate::dtype::extension::ExtId; +use crate::dtype::session::ArrowCanonicalCodec; use crate::dtype::session::DTypeSession; use crate::dtype::session::DTypeSessionExt; use crate::extension::datetime::AnyTemporal; @@ -268,8 +268,10 @@ fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { }; let canonical_alias = dtypes.vortex_id_for_arrow_canonical(ext_name); - let is_canonical = canonical_alias.is_some(); - let ext_id = canonical_alias.unwrap_or_else(|| ExtId::new(ext_name)); + let (ext_id, codec) = match canonical_alias { + Some((vortex_id, codec)) => (vortex_id, Some(codec)), + None => (ExtId::new(ext_name), None), + }; let Some(plugin) = dtypes.registry().find(&ext_id) else { tracing::warn!( @@ -280,7 +282,7 @@ fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { return storage_dtype; }; - let metadata_bytes = match decode_extension_metadata(field, is_canonical) { + let metadata_bytes = match decode_extension_metadata(field, codec) { Ok(bytes) => bytes, Err(e) => { tracing::warn!( @@ -294,7 +296,7 @@ fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { } }; - match plugin.deserialize(metadata_bytes.as_ref(), storage_dtype.clone()) { + match plugin.deserialize(&metadata_bytes, storage_dtype.clone()) { Ok(ext_ref) => DType::Extension(ext_ref), Err(e) => { tracing::warn!( @@ -309,16 +311,20 @@ fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { } } -/// Canonical extensions store UTF-8 bytes directly; non-canonical extensions base64-encode so -/// arbitrary binary plugin output survives the String-typed metadata channel. -fn decode_extension_metadata(field: &Field, is_canonical: bool) -> VortexResult> { +/// Non-canonical extensions base64-encode arbitrary binary metadata to survive Arrow's +/// String-typed metadata channel; canonical extensions go through the registered codec. +fn decode_extension_metadata( + field: &Field, + codec: Option, +) -> VortexResult> { match field.extension_type_metadata() { - None | Some("") => Ok(Cow::Borrowed(&[])), - Some(s) if is_canonical => Ok(Cow::Borrowed(s.as_bytes())), - Some(s) => BASE64_STANDARD - .decode(s) - .map(Cow::Owned) - .map_err(|e| vortex_err!("failed to base64-decode {EXTENSION_TYPE_METADATA_KEY}: {e}")), + None | Some("") => Ok(Vec::new()), + Some(s) => match codec { + Some(codec) => (codec.from_json)(s), + None => BASE64_STANDARD.decode(s).map_err(|e| { + vortex_err!("failed to base64-decode {EXTENSION_TYPE_METADATA_KEY}: {e}") + }), + }, } } @@ -490,13 +496,10 @@ fn field_from_dtype(name: &str, dtype: &DType, dtypes: &DTypeSession) -> VortexR let storage_arrow = arrow_dtype_from_dtype(ext.storage_dtype(), dtypes)?; let ext_meta_bytes = ext.serialize_metadata()?; let (ext_name, meta_str) = match dtypes.arrow_canonical_for(&ext.id()) { - Some(canonical) => { - // Canonical wire: raw UTF-8 (typically JSON), read as-is by arrow-rs / pyarrow. - let s = String::from_utf8(ext_meta_bytes).map_err(|e| { - vortex_err!("canonical extension {canonical} metadata must be valid UTF-8: {e}") - })?; - (canonical.as_str().to_owned(), s) - } + Some((canonical, codec)) => ( + canonical.as_str().to_owned(), + (codec.to_json)(&ext_meta_bytes)?, + ), None => ( ext.id().as_str().to_owned(), BASE64_STANDARD.encode(&ext_meta_bytes), diff --git a/vortex-array/src/dtype/session.rs b/vortex-array/src/dtype/session.rs index a297957984d..1a78b370339 100644 --- a/vortex-array/src/dtype/session.rs +++ b/vortex-array/src/dtype/session.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use arc_swap::ArcSwap; +use vortex_error::VortexResult; use vortex_session::Ref; use vortex_session::SessionExt; use vortex_session::registry::Registry; @@ -21,37 +22,56 @@ use crate::extension::datetime::Timestamp; /// Registry for extension dtypes. pub type ExtDTypeRegistry = Registry; -/// Bidirectional alias map between Vortex extension ids and Arrow canonical extension names. +/// Converters between an extension's on-disk metadata bytes and the Arrow canonical JSON wire. /// -/// Aliased extensions emit the canonical name on `ARROW:extension:name` and serialize metadata -/// as raw UTF-8 instead of base64-wrapped bytes. Lookups are lock-free; updates clone-and-swap. +/// Bundled with the alias at registration time so [`ExtVTable`] stays Arrow-unaware. +#[derive(Copy, Clone, Debug)] +pub struct ArrowCanonicalCodec { + pub to_json: fn(&[u8]) -> VortexResult, + pub from_json: fn(&str) -> VortexResult>, +} + +#[derive(Copy, Clone, Debug)] +struct AliasEntry { + /// Forward entries point at the Arrow canonical id; reverse entries point at the Vortex id. + partner: ExtId, + /// Same codec value in both directions of a registration; eviction relies on this. + codec: ArrowCanonicalCodec, +} + #[derive(Debug, Default)] -struct ArrowCanonicalAliases(ArcSwap>); +struct ArrowCanonicalAliases(ArcSwap>); impl ArrowCanonicalAliases { - /// Alias `vortex_id` to the Arrow canonical `arrow_name`. Re-registering evicts the previous - /// mapping for either side, so the bidirectional invariant holds after every call. - fn register(&self, vortex_id: ExtId, arrow_name: &'static str) { + /// Re-registering evicts the previous mapping for either side so the bidirectional invariant + /// holds after every call. + fn register(&self, vortex_id: ExtId, arrow_name: &'static str, codec: ArrowCanonicalCodec) { let arrow_id = ExtId::new(arrow_name); + let forward = AliasEntry { + partner: arrow_id, + codec, + }; + let reverse = AliasEntry { + partner: vortex_id, + codec, + }; self.0.rcu(|prev| { let mut next = (**prev).clone(); - if let Some(stale) = next.insert(vortex_id, arrow_id) { - next.remove(&stale); + if let Some(stale) = next.insert(vortex_id, forward) { + next.remove(&stale.partner); } - if let Some(stale) = next.insert(arrow_id, vortex_id) { - next.remove(&stale); + if let Some(stale) = next.insert(arrow_id, reverse) { + next.remove(&stale.partner); } Arc::new(next) }); } - /// Returns the Arrow canonical extension name aliased to `vortex_id`, if any. - fn arrow_canonical_for(&self, vortex_id: &ExtId) -> Option { + fn arrow_canonical_for(&self, vortex_id: &ExtId) -> Option { self.0.load().get(vortex_id).copied() } - /// Returns the Vortex extension id aliased to `arrow_name`, if any. - fn vortex_id_for_arrow_canonical(&self, arrow_name: &str) -> Option { + fn vortex_id_for_arrow_canonical(&self, arrow_name: &str) -> Option { self.0.load().get(&ExtId::new(arrow_name)).copied() } } @@ -90,22 +110,32 @@ impl DTypeSession { &self.registry } - /// Alias an Arrow canonical extension name to a Vortex extension id. Aliased extensions - /// emit the canonical name on `ARROW:extension:name` and serialize metadata as raw UTF-8 - /// instead of base64-wrapped bytes. Re-registering evicts the previous mapping. - pub fn register_arrow_canonical(&self, vortex_id: ExtId, arrow_name: &'static str) { - self.arrow_canonical.register(vortex_id, arrow_name); + /// Alias `arrow_name` to `vortex_id` with the codec used at the Arrow boundary. + /// Re-registering evicts the previous mapping for either side. + pub fn register_arrow_canonical( + &self, + vortex_id: ExtId, + arrow_name: &'static str, + codec: ArrowCanonicalCodec, + ) { + self.arrow_canonical.register(vortex_id, arrow_name, codec); } - /// Returns the Arrow canonical extension name aliased to the given Vortex id, if any. - pub fn arrow_canonical_for(&self, vortex_id: &ExtId) -> Option { - self.arrow_canonical.arrow_canonical_for(vortex_id) + /// Returns the Arrow canonical name and codec aliased to `vortex_id`, if any. + pub fn arrow_canonical_for(&self, vortex_id: &ExtId) -> Option<(ExtId, ArrowCanonicalCodec)> { + self.arrow_canonical + .arrow_canonical_for(vortex_id) + .map(|e| (e.partner, e.codec)) } - /// Returns the Vortex extension id aliased to the given Arrow canonical name, if any. - pub fn vortex_id_for_arrow_canonical(&self, arrow_name: &str) -> Option { + /// Returns the Vortex id and codec aliased to `arrow_name`, if any. + pub fn vortex_id_for_arrow_canonical( + &self, + arrow_name: &str, + ) -> Option<(ExtId, ArrowCanonicalCodec)> { self.arrow_canonical .vortex_id_for_arrow_canonical(arrow_name) + .map(|e| (e.partner, e.codec)) } } @@ -123,27 +153,46 @@ impl DTypeSessionExt for S { #[cfg(test)] mod tests { + use vortex_error::vortex_err; + use super::*; + const TEST_CODEC: ArrowCanonicalCodec = ArrowCanonicalCodec { + to_json: |bytes| { + String::from_utf8(bytes.to_vec()).map_err(|e| vortex_err!("non-utf8 test bytes: {e}")) + }, + from_json: |s| Ok(s.as_bytes().to_vec()), + }; + #[test] fn arrow_canonical_re_registration_is_clean() { let session = DTypeSession::default(); let v = ExtId::new("vortex.test"); - session.register_arrow_canonical(v, "arrow.foo"); + session.register_arrow_canonical(v, "arrow.foo", TEST_CODEC); assert_eq!( - session.arrow_canonical_for(&v), + session.arrow_canonical_for(&v).map(|(id, _)| id), Some(ExtId::new("arrow.foo")) ); - assert_eq!(session.vortex_id_for_arrow_canonical("arrow.foo"), Some(v)); + assert_eq!( + session + .vortex_id_for_arrow_canonical("arrow.foo") + .map(|(id, _)| id), + Some(v) + ); - session.register_arrow_canonical(v, "arrow.bar"); + session.register_arrow_canonical(v, "arrow.bar", TEST_CODEC); assert_eq!( - session.arrow_canonical_for(&v), + session.arrow_canonical_for(&v).map(|(id, _)| id), Some(ExtId::new("arrow.bar")) ); - assert_eq!(session.vortex_id_for_arrow_canonical("arrow.bar"), Some(v)); - assert_eq!(session.vortex_id_for_arrow_canonical("arrow.foo"), None); + assert_eq!( + session + .vortex_id_for_arrow_canonical("arrow.bar") + .map(|(id, _)| id), + Some(v) + ); + assert!(session.vortex_id_for_arrow_canonical("arrow.foo").is_none()); } /// `(vid → old, old → vid)` then `register(vid, new)` should leave `(vid → new, new → vid)`. @@ -154,15 +203,31 @@ mod tests { let old = ExtId::new("arrow.b"); let new = ExtId::new("arrow.c"); - session.register_arrow_canonical(vid, "arrow.b"); - assert_eq!(session.arrow_canonical_for(&vid), Some(old)); - assert_eq!(session.vortex_id_for_arrow_canonical("arrow.b"), Some(vid)); + session.register_arrow_canonical(vid, "arrow.b", TEST_CODEC); + assert_eq!( + session.arrow_canonical_for(&vid).map(|(id, _)| id), + Some(old) + ); + assert_eq!( + session + .vortex_id_for_arrow_canonical("arrow.b") + .map(|(id, _)| id), + Some(vid) + ); - session.register_arrow_canonical(vid, "arrow.c"); + session.register_arrow_canonical(vid, "arrow.c", TEST_CODEC); - assert_eq!(session.arrow_canonical_for(&vid), Some(new)); - assert_eq!(session.vortex_id_for_arrow_canonical("arrow.c"), Some(vid)); - assert_eq!(session.vortex_id_for_arrow_canonical("arrow.b"), None); + assert_eq!( + session.arrow_canonical_for(&vid).map(|(id, _)| id), + Some(new) + ); + assert_eq!( + session + .vortex_id_for_arrow_canonical("arrow.c") + .map(|(id, _)| id), + Some(vid) + ); + assert!(session.vortex_id_for_arrow_canonical("arrow.b").is_none()); } /// `(old → name, name → old)` then `register(new, name)` should leave `(new → name, name → new)`. @@ -173,13 +238,38 @@ mod tests { let name = ExtId::new("arrow.b"); let new = ExtId::new("vortex.c"); - session.register_arrow_canonical(old, "arrow.b"); - assert_eq!(session.arrow_canonical_for(&old), Some(name)); + session.register_arrow_canonical(old, "arrow.b", TEST_CODEC); + assert_eq!( + session.arrow_canonical_for(&old).map(|(id, _)| id), + Some(name) + ); + + session.register_arrow_canonical(new, "arrow.b", TEST_CODEC); + + assert_eq!( + session.arrow_canonical_for(&new).map(|(id, _)| id), + Some(name) + ); + assert_eq!( + session + .vortex_id_for_arrow_canonical("arrow.b") + .map(|(id, _)| id), + Some(new) + ); + assert!(session.arrow_canonical_for(&old).is_none()); + } + + #[test] + fn codec_round_trips_through_lookup() { + let session = DTypeSession::default(); + let vid = ExtId::new("vortex.x"); - session.register_arrow_canonical(new, "arrow.b"); + session.register_arrow_canonical(vid, "arrow.x", TEST_CODEC); - assert_eq!(session.arrow_canonical_for(&new), Some(name)); - assert_eq!(session.vortex_id_for_arrow_canonical("arrow.b"), Some(new)); - assert_eq!(session.arrow_canonical_for(&old), None); + let (_, codec) = session.arrow_canonical_for(&vid).unwrap(); + let json = (codec.to_json)(b"hello").unwrap(); + assert_eq!(json, "hello"); + let bytes = (codec.from_json)(&json).unwrap(); + assert_eq!(bytes, b"hello"); } } diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 11eea6e9a59..41aa729433f 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -12,6 +12,7 @@ use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; use vortex_array::dtype::extension::ExtVTable; +use vortex_array::dtype::session::ArrowCanonicalCodec; use vortex_array::dtype::session::DTypeSessionExt; use vortex_array::scalar_fn::session::ScalarFnSessionExt; use vortex_array::session::ArraySessionExt; @@ -51,10 +52,15 @@ pub fn initialize(session: &VortexSession) { let dtypes = session.dtypes(); dtypes.register(Vector); dtypes.register(FixedShapeTensor); - dtypes.register_arrow_canonical(FixedShapeTensor.id(), FixedShapeTensor::ARROW_EXT_NAME); - // Drop the dashmap shard ref before acquiring `scalar_fns` — `or_insert_with` inside - // `session.get` takes a write lock, which would deadlock if `dtypes` were still alive - // and the two TypeIds happened to hash to the same shard. + dtypes.register_arrow_canonical( + FixedShapeTensor.id(), + FixedShapeTensor::ARROW_EXT_NAME, + ArrowCanonicalCodec { + to_json: fixed_shape::proto_to_json, + from_json: fixed_shape::json_to_proto, + }, + ); + // Release the shard read before `scalar_fns` may take a write on the same shard. drop(dtypes); let session_fns = session.scalar_fns(); diff --git a/vortex-tensor/src/types/fixed_shape/canonical.rs b/vortex-tensor/src/types/fixed_shape/canonical.rs index 35da49a6edb..e5ff5389f72 100644 --- a/vortex-tensor/src/types/fixed_shape/canonical.rs +++ b/vortex-tensor/src/types/fixed_shape/canonical.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! Arrow canonical [`arrow.fixed_shape_tensor`] JSON metadata serialization. +//! Arrow canonical [`arrow.fixed_shape_tensor`] JSON wire ⇄ on-disk proto adapters. //! //! Hand-rolled rather than reusing `arrow_schema::extension::FixedShapeTensor` because arrow-rs //! 58 emits `"permutations"` (plural) while the spec and pyarrow use `"permutation"`. @@ -14,6 +14,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_err; use crate::types::fixed_shape::FixedShapeTensorMetadata; +use crate::types::fixed_shape::proto; #[derive(Serialize)] struct WireRef<'a> { @@ -33,20 +34,18 @@ struct Wire { permutation: Option>, } -/// Serialize [`FixedShapeTensorMetadata`] to the Arrow canonical JSON representation. -pub(crate) fn serialize(metadata: &FixedShapeTensorMetadata) -> VortexResult> { +fn metadata_to_json(metadata: &FixedShapeTensorMetadata) -> VortexResult { let wire = WireRef { shape: metadata.logical_shape(), dim_names: metadata.dim_names(), permutation: metadata.permutation(), }; - serde_json::to_vec(&wire) + serde_json::to_string(&wire) .map_err(|e| vortex_err!("fixed_shape_tensor canonical serialize: {e}")) } -/// Deserialize [`FixedShapeTensorMetadata`] from Arrow canonical JSON bytes. -pub(crate) fn deserialize(bytes: &[u8]) -> VortexResult { - let wire: Wire = serde_json::from_slice(bytes) +fn metadata_from_json(json: &str) -> VortexResult { + let wire: Wire = serde_json::from_str(json) .map_err(|e| vortex_err!("fixed_shape_tensor canonical deserialize: {e}"))?; let mut m = FixedShapeTensorMetadata::new(wire.shape); @@ -59,6 +58,16 @@ pub(crate) fn deserialize(bytes: &[u8]) -> VortexResult VortexResult { + let metadata = proto::deserialize(proto_bytes)?; + metadata_to_json(&metadata) +} + +pub(crate) fn json_to_proto(json: &str) -> VortexResult> { + let metadata = metadata_from_json(json)?; + Ok(proto::serialize(&metadata)) +} + #[cfg(test)] mod tests { use rstest::rstest; @@ -84,22 +93,40 @@ mod tests { .with_dim_names(vec!["x".into(), "y".into(), "z".into()]).unwrap() .with_permutation(vec![1, 2, 0]).unwrap() )] - fn roundtrip(#[case] metadata: FixedShapeTensorMetadata) -> VortexResult<()> { - let bytes = serialize(&metadata)?; - let decoded = deserialize(&bytes)?; + fn json_roundtrip(#[case] metadata: FixedShapeTensorMetadata) -> VortexResult<()> { + let json = metadata_to_json(&metadata)?; + let decoded = metadata_from_json(&json)?; assert_eq!(decoded, metadata); Ok(()) } + #[rstest] + #[case::shape_only(FixedShapeTensorMetadata::new(vec![2, 3, 4]))] + #[case::all_fields( + FixedShapeTensorMetadata::new(vec![2, 3, 4]) + .with_dim_names(vec!["x".into(), "y".into(), "z".into()]).unwrap() + .with_permutation(vec![1, 2, 0]).unwrap() + )] + fn proto_to_json_to_proto_roundtrip( + #[case] metadata: FixedShapeTensorMetadata, + ) -> VortexResult<()> { + let proto_bytes = proto::serialize(&metadata); + let json = proto_to_json(&proto_bytes)?; + let proto_again = json_to_proto(&json)?; + let metadata_again = proto::deserialize(&proto_again)?; + assert_eq!(metadata_again, metadata); + Ok(()) + } + #[test] fn wire_format_matches_arrow_spec() -> VortexResult<()> { let metadata = FixedShapeTensorMetadata::new(vec![2, 3, 4]) .with_dim_names(vec!["x".into(), "y".into(), "z".into()])? .with_permutation(vec![1, 2, 0])?; - let bytes = serialize(&metadata)?; + let json = metadata_to_json(&metadata)?; let v: serde_json::Value = - serde_json::from_slice(&bytes).map_err(|e| vortex_err!("parse wire: {e}"))?; + serde_json::from_str(&json).map_err(|e| vortex_err!("parse wire: {e}"))?; assert_eq!(v["shape"], serde_json::json!([2, 3, 4])); assert_eq!(v["dim_names"], serde_json::json!(["x", "y", "z"])); @@ -111,9 +138,9 @@ mod tests { #[test] fn omits_optional_fields_when_unset() -> VortexResult<()> { - let bytes = serialize(&FixedShapeTensorMetadata::new(vec![5]))?; + let json = metadata_to_json(&FixedShapeTensorMetadata::new(vec![5]))?; let v: serde_json::Value = - serde_json::from_slice(&bytes).map_err(|e| vortex_err!("parse wire: {e}"))?; + serde_json::from_str(&json).map_err(|e| vortex_err!("parse wire: {e}"))?; assert!(v.get("dim_names").is_none()); assert!(v.get("permutation").is_none()); Ok(()) diff --git a/vortex-tensor/src/types/fixed_shape/mod.rs b/vortex-tensor/src/types/fixed_shape/mod.rs index 423c1db700a..d34c3b85b8e 100644 --- a/vortex-tensor/src/types/fixed_shape/mod.rs +++ b/vortex-tensor/src/types/fixed_shape/mod.rs @@ -8,9 +8,6 @@ pub struct FixedShapeTensor; impl FixedShapeTensor { - /// Arrow canonical extension name aliased to this type's [`ExtVTable::id`]. - /// - /// [`ExtVTable::id`]: vortex_array::dtype::extension::ExtVTable::id pub(crate) const ARROW_EXT_NAME: &'static str = "arrow.fixed_shape_tensor"; } @@ -22,4 +19,6 @@ mod metadata; pub use metadata::FixedShapeTensorMetadata; mod canonical; +mod proto; mod vtable; +pub(crate) use canonical::{json_to_proto, proto_to_json}; diff --git a/vortex-tensor/src/types/fixed_shape/proto.rs b/vortex-tensor/src/types/fixed_shape/proto.rs new file mode 100644 index 00000000000..dcb157fadbc --- /dev/null +++ b/vortex-tensor/src/types/fixed_shape/proto.rs @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! On-disk protobuf serialization for [`FixedShapeTensorMetadata`]. +//! +//! The Arrow JSON wire is a separate concern; see [`super::canonical`] for the proto↔JSON +//! adapters invoked at the Arrow boundary. + +use prost::Message; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_err; + +use crate::types::fixed_shape::FixedShapeTensorMetadata; + +/// Empty repeated fields collapse with absent ones in proto, which matches our semantics: +/// empty `logical_shape` is a scalar; empty `dim_names`/`permutation` mean `None`. +#[derive(Clone, PartialEq, Message)] +struct FixedShapeTensorMetadataProto { + #[prost(uint32, repeated, tag = "1")] + logical_shape: Vec, + #[prost(string, repeated, tag = "2")] + dim_names: Vec, + #[prost(uint32, repeated, tag = "3")] + permutation: Vec, +} + +pub(crate) fn serialize(metadata: &FixedShapeTensorMetadata) -> Vec { + let logical_shape = metadata + .logical_shape() + .iter() + .map(|&d| u32::try_from(d).vortex_expect("dimension size exceeds u32")) + .collect(); + + let dim_names = metadata.dim_names().map(|n| n.to_vec()).unwrap_or_default(); + + let permutation = metadata + .permutation() + .map(|p| { + p.iter() + .map(|&i| u32::try_from(i).vortex_expect("permutation index exceeds u32")) + .collect() + }) + .unwrap_or_default(); + + let proto = FixedShapeTensorMetadataProto { + logical_shape, + dim_names, + permutation, + }; + proto.encode_to_vec() +} + +pub(crate) fn deserialize(bytes: &[u8]) -> VortexResult { + let proto = FixedShapeTensorMetadataProto::decode(bytes).map_err(|e| vortex_err!("{e}"))?; + let logical_shape = proto + .logical_shape + .into_iter() + .map(|d| d as usize) + .collect(); + let mut m = FixedShapeTensorMetadata::new(logical_shape); + if !proto.dim_names.is_empty() { + m = m.with_dim_names(proto.dim_names)?; + } + if !proto.permutation.is_empty() { + let permutation = proto.permutation.into_iter().map(|i| i as usize).collect(); + m = m.with_permutation(permutation)?; + } + Ok(m) +} diff --git a/vortex-tensor/src/types/fixed_shape/vtable.rs b/vortex-tensor/src/types/fixed_shape/vtable.rs index cc789ea39d7..97eadbb55fb 100644 --- a/vortex-tensor/src/types/fixed_shape/vtable.rs +++ b/vortex-tensor/src/types/fixed_shape/vtable.rs @@ -14,7 +14,7 @@ use vortex_session::registry::CachedId; use crate::types::fixed_shape::FixedShapeTensor; use crate::types::fixed_shape::FixedShapeTensorMetadata; -use crate::types::fixed_shape::canonical; +use crate::types::fixed_shape::proto; /// Vortex extension id for [`FixedShapeTensor`]. static ID: CachedId = CachedId::new("vortex.tensor.fixed_shape_tensor"); @@ -30,11 +30,11 @@ impl ExtVTable for FixedShapeTensor { } fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { - canonical::serialize(metadata) + Ok(proto::serialize(metadata)) } fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult { - canonical::deserialize(metadata) + proto::deserialize(metadata) } fn least_supertype(ext_dtype: &ExtDType, other: &DType) -> Option { From 1ddb32b2b3d6309be124413757477222cf982252 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Mon, 27 Apr 2026 17:46:09 +0100 Subject: [PATCH 13/21] revert comments Signed-off-by: Baris Palaska --- vortex-tensor/src/types/fixed_shape/proto.rs | 32 ++++++++++++++++---- 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/vortex-tensor/src/types/fixed_shape/proto.rs b/vortex-tensor/src/types/fixed_shape/proto.rs index dcb157fadbc..89b3db4289d 100644 --- a/vortex-tensor/src/types/fixed_shape/proto.rs +++ b/vortex-tensor/src/types/fixed_shape/proto.rs @@ -1,10 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! On-disk protobuf serialization for [`FixedShapeTensorMetadata`]. -//! -//! The Arrow JSON wire is a separate concern; see [`super::canonical`] for the proto↔JSON -//! adapters invoked at the Arrow boundary. +//! Protobuf serialization for [`FixedShapeTensorMetadata`]. use prost::Message; use vortex_error::VortexExpect; @@ -13,18 +10,32 @@ use vortex_error::vortex_err; use crate::types::fixed_shape::FixedShapeTensorMetadata; -/// Empty repeated fields collapse with absent ones in proto, which matches our semantics: -/// empty `logical_shape` is a scalar; empty `dim_names`/`permutation` mean `None`. +/// Protobuf representation of [`FixedShapeTensorMetadata`]. +/// +/// Protobuf does not distinguish between an absent repeated field and an empty one (both will +/// deserialize as an empty `Vec`). This is fine because the semantic meaning is unambiguous: +/// +/// - `logical_shape` empty: 0-dimensional (scalar) tensor. +/// - `dim_names` empty: no dimension names (`None`). +/// - `permutation` empty: no permutation, i.e., identity layout (`None`). #[derive(Clone, PartialEq, Message)] struct FixedShapeTensorMetadataProto { + /// The size of each logical dimension. Empty for a 0-dimensional scalar tensor. #[prost(uint32, repeated, tag = "1")] logical_shape: Vec, + + /// Optional human-readable names for each logical dimension. When present, must have the + /// same length as `logical_shape`. Empty means no names are set. #[prost(string, repeated, tag = "2")] dim_names: Vec, + + /// Optional dimension permutation mapping logical to physical indices. When present, must + /// be a permutation of `[0, 1, ..., N-1]`. Empty means identity (row-major) layout. #[prost(uint32, repeated, tag = "3")] permutation: Vec, } +/// Serializes [`FixedShapeTensorMetadata`] to protobuf bytes. pub(crate) fn serialize(metadata: &FixedShapeTensorMetadata) -> Vec { let logical_shape = metadata .logical_shape() @@ -51,14 +62,22 @@ pub(crate) fn serialize(metadata: &FixedShapeTensorMetadata) -> Vec { proto.encode_to_vec() } +/// Deserializes [`FixedShapeTensorMetadata`] from protobuf bytes. +/// +/// For 0-dimensional tensors, all three repeated fields are empty, which correctly produces a +/// metadata with an empty shape and no names or permutation. pub(crate) fn deserialize(bytes: &[u8]) -> VortexResult { let proto = FixedShapeTensorMetadataProto::decode(bytes).map_err(|e| vortex_err!("{e}"))?; + let logical_shape = proto .logical_shape .into_iter() .map(|d| d as usize) .collect(); let mut m = FixedShapeTensorMetadata::new(logical_shape); + + // Note that this is fine for 0 dimensions since if we do not have any dimensions, we cannot + // have any names or permutations. if !proto.dim_names.is_empty() { m = m.with_dim_names(proto.dim_names)?; } @@ -66,5 +85,6 @@ pub(crate) fn deserialize(bytes: &[u8]) -> VortexResult Date: Mon, 27 Apr 2026 17:55:38 +0100 Subject: [PATCH 14/21] public api, fmt Signed-off-by: Baris Palaska --- vortex-array/public-api.lock | 22 +++++++++++++++++++--- vortex-tensor/src/types/fixed_shape/mod.rs | 3 ++- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 9abe79c9ee7..e50e30b03b5 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -8816,19 +8816,35 @@ pub mod vortex_array::dtype::serde pub mod vortex_array::dtype::session +pub struct vortex_array::dtype::session::ArrowCanonicalCodec + +pub vortex_array::dtype::session::ArrowCanonicalCodec::from_json: fn(&str) -> vortex_error::VortexResult> + +pub vortex_array::dtype::session::ArrowCanonicalCodec::to_json: fn(&[u8]) -> vortex_error::VortexResult + +impl core::clone::Clone for vortex_array::dtype::session::ArrowCanonicalCodec + +pub fn vortex_array::dtype::session::ArrowCanonicalCodec::clone(&self) -> vortex_array::dtype::session::ArrowCanonicalCodec + +impl core::fmt::Debug for vortex_array::dtype::session::ArrowCanonicalCodec + +pub fn vortex_array::dtype::session::ArrowCanonicalCodec::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::marker::Copy for vortex_array::dtype::session::ArrowCanonicalCodec + pub struct vortex_array::dtype::session::DTypeSession impl vortex_array::dtype::session::DTypeSession -pub fn vortex_array::dtype::session::DTypeSession::arrow_canonical_for(&self, vortex_id: &vortex_array::dtype::extension::ExtId) -> core::option::Option +pub fn vortex_array::dtype::session::DTypeSession::arrow_canonical_for(&self, vortex_id: &vortex_array::dtype::extension::ExtId) -> core::option::Option<(vortex_array::dtype::extension::ExtId, vortex_array::dtype::session::ArrowCanonicalCodec)> pub fn vortex_array::dtype::session::DTypeSession::register(&self, vtable: V) -pub fn vortex_array::dtype::session::DTypeSession::register_arrow_canonical(&self, vortex_id: vortex_array::dtype::extension::ExtId, arrow_name: &'static str) +pub fn vortex_array::dtype::session::DTypeSession::register_arrow_canonical(&self, vortex_id: vortex_array::dtype::extension::ExtId, arrow_name: &'static str, codec: vortex_array::dtype::session::ArrowCanonicalCodec) pub fn vortex_array::dtype::session::DTypeSession::registry(&self) -> &vortex_array::dtype::session::ExtDTypeRegistry -pub fn vortex_array::dtype::session::DTypeSession::vortex_id_for_arrow_canonical(&self, arrow_name: &str) -> core::option::Option +pub fn vortex_array::dtype::session::DTypeSession::vortex_id_for_arrow_canonical(&self, arrow_name: &str) -> core::option::Option<(vortex_array::dtype::extension::ExtId, vortex_array::dtype::session::ArrowCanonicalCodec)> impl core::default::Default for vortex_array::dtype::session::DTypeSession diff --git a/vortex-tensor/src/types/fixed_shape/mod.rs b/vortex-tensor/src/types/fixed_shape/mod.rs index d34c3b85b8e..147fbdf1966 100644 --- a/vortex-tensor/src/types/fixed_shape/mod.rs +++ b/vortex-tensor/src/types/fixed_shape/mod.rs @@ -21,4 +21,5 @@ pub use metadata::FixedShapeTensorMetadata; mod canonical; mod proto; mod vtable; -pub(crate) use canonical::{json_to_proto, proto_to_json}; +pub(crate) use canonical::json_to_proto; +pub(crate) use canonical::proto_to_json; From 2978adda6c010b04a6b3a2e5845664b432c4d108 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Mon, 27 Apr 2026 18:08:49 +0100 Subject: [PATCH 15/21] refactor registry Signed-off-by: Baris Palaska --- vortex-array/public-api.lock | 6 +- vortex-array/src/dtype/arrow.rs | 8 +- vortex-array/src/dtype/session.rs | 184 ++++++++++++------------------ vortex-tensor/src/lib.rs | 3 +- 4 files changed, 80 insertions(+), 121 deletions(-) diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 3a667ec2973..7219b06a6f9 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -8836,15 +8836,15 @@ pub struct vortex_array::dtype::session::DTypeSession impl vortex_array::dtype::session::DTypeSession -pub fn vortex_array::dtype::session::DTypeSession::arrow_canonical_for(&self, vortex_id: &vortex_array::dtype::extension::ExtId) -> core::option::Option<(vortex_array::dtype::extension::ExtId, vortex_array::dtype::session::ArrowCanonicalCodec)> +pub fn vortex_array::dtype::session::DTypeSession::arrow_alias_for(&self, vortex_id: &vortex_array::dtype::extension::ExtId) -> core::option::Option<(vortex_array::dtype::extension::ExtId, vortex_array::dtype::session::ArrowCanonicalCodec)> pub fn vortex_array::dtype::session::DTypeSession::register(&self, vtable: V) -pub fn vortex_array::dtype::session::DTypeSession::register_arrow_canonical(&self, vortex_id: vortex_array::dtype::extension::ExtId, arrow_name: &'static str, codec: vortex_array::dtype::session::ArrowCanonicalCodec) +pub fn vortex_array::dtype::session::DTypeSession::register_arrow_canonical(&self, vortex_id: vortex_array::dtype::extension::ExtId, arrow_id: vortex_array::dtype::extension::ExtId, codec: vortex_array::dtype::session::ArrowCanonicalCodec) pub fn vortex_array::dtype::session::DTypeSession::registry(&self) -> &vortex_array::dtype::session::ExtDTypeRegistry -pub fn vortex_array::dtype::session::DTypeSession::vortex_id_for_arrow_canonical(&self, arrow_name: &str) -> core::option::Option<(vortex_array::dtype::extension::ExtId, vortex_array::dtype::session::ArrowCanonicalCodec)> +pub fn vortex_array::dtype::session::DTypeSession::vortex_alias_for(&self, arrow_id: &vortex_array::dtype::extension::ExtId) -> core::option::Option<(vortex_array::dtype::extension::ExtId, vortex_array::dtype::session::ArrowCanonicalCodec)> impl core::default::Default for vortex_array::dtype::session::DTypeSession diff --git a/vortex-array/src/dtype/arrow.rs b/vortex-array/src/dtype/arrow.rs index 61f3a6de046..125220b275f 100644 --- a/vortex-array/src/dtype/arrow.rs +++ b/vortex-array/src/dtype/arrow.rs @@ -267,10 +267,10 @@ fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { return storage_dtype; }; - let canonical_alias = dtypes.vortex_id_for_arrow_canonical(ext_name); - let (ext_id, codec) = match canonical_alias { + let arrow_id = ExtId::new(ext_name); + let (ext_id, codec) = match dtypes.vortex_alias_for(&arrow_id) { Some((vortex_id, codec)) => (vortex_id, Some(codec)), - None => (ExtId::new(ext_name), None), + None => (arrow_id, None), }; let Some(plugin) = dtypes.registry().find(&ext_id) else { @@ -495,7 +495,7 @@ fn field_from_dtype(name: &str, dtype: &DType, dtypes: &DTypeSession) -> VortexR let storage_arrow = arrow_dtype_from_dtype(ext.storage_dtype(), dtypes)?; let ext_meta_bytes = ext.serialize_metadata()?; - let (ext_name, meta_str) = match dtypes.arrow_canonical_for(&ext.id()) { + let (ext_name, meta_str) = match dtypes.arrow_alias_for(&ext.id()) { Some((canonical, codec)) => ( canonical.as_str().to_owned(), (codec.to_json)(&ext_meta_bytes)?, diff --git a/vortex-array/src/dtype/session.rs b/vortex-array/src/dtype/session.rs index 1a78b370339..2ce4e8671c2 100644 --- a/vortex-array/src/dtype/session.rs +++ b/vortex-array/src/dtype/session.rs @@ -31,48 +31,53 @@ pub struct ArrowCanonicalCodec { pub from_json: fn(&str) -> VortexResult>, } -#[derive(Copy, Clone, Debug)] -struct AliasEntry { - /// Forward entries point at the Arrow canonical id; reverse entries point at the Vortex id. - partner: ExtId, - /// Same codec value in both directions of a registration; eviction relies on this. - codec: ArrowCanonicalCodec, +/// Forward map is the canonical source: each Vortex extension owns its codec and points at the +/// Arrow canonical name it serializes as. Reverse map is a lookup index for the read path, +/// taking an Arrow name back to the Vortex id whose codec applies. +#[derive(Default, Clone)] +struct AliasState { + forward: HashMap, + reverse: HashMap, } #[derive(Debug, Default)] -struct ArrowCanonicalAliases(ArcSwap>); +struct ArrowCanonicalAliases(ArcSwap); impl ArrowCanonicalAliases { - /// Re-registering evicts the previous mapping for either side so the bidirectional invariant - /// holds after every call. - fn register(&self, vortex_id: ExtId, arrow_name: &'static str, codec: ArrowCanonicalCodec) { - let arrow_id = ExtId::new(arrow_name); - let forward = AliasEntry { - partner: arrow_id, - codec, - }; - let reverse = AliasEntry { - partner: vortex_id, - codec, - }; + /// Re-registering evicts any prior alias touching either id so both directions agree. + fn register(&self, vortex_id: ExtId, arrow_id: ExtId, codec: ArrowCanonicalCodec) { self.0.rcu(|prev| { let mut next = (**prev).clone(); - if let Some(stale) = next.insert(vortex_id, forward) { - next.remove(&stale.partner); + if let Some((stale_arrow, _)) = next.forward.remove(&vortex_id) { + next.reverse.remove(&stale_arrow); } - if let Some(stale) = next.insert(arrow_id, reverse) { - next.remove(&stale.partner); + if let Some(stale_vortex) = next.reverse.remove(&arrow_id) { + next.forward.remove(&stale_vortex); } + next.forward.insert(vortex_id, (arrow_id, codec)); + next.reverse.insert(arrow_id, vortex_id); Arc::new(next) }); } - fn arrow_canonical_for(&self, vortex_id: &ExtId) -> Option { - self.0.load().get(vortex_id).copied() + fn arrow_alias_for(&self, vortex_id: &ExtId) -> Option<(ExtId, ArrowCanonicalCodec)> { + self.0.load().forward.get(vortex_id).copied() } - fn vortex_id_for_arrow_canonical(&self, arrow_name: &str) -> Option { - self.0.load().get(&ExtId::new(arrow_name)).copied() + fn vortex_alias_for(&self, arrow_id: &ExtId) -> Option<(ExtId, ArrowCanonicalCodec)> { + let state = self.0.load(); + let vortex_id = *state.reverse.get(arrow_id)?; + let (_, codec) = *state.forward.get(&vortex_id)?; + Some((vortex_id, codec)) + } +} + +impl std::fmt::Debug for AliasState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AliasState") + .field("forward", &self.forward) + .field("reverse", &self.reverse) + .finish() } } @@ -110,32 +115,25 @@ impl DTypeSession { &self.registry } - /// Alias `arrow_name` to `vortex_id` with the codec used at the Arrow boundary. + /// Alias `arrow_id` to `vortex_id` with the codec used at the Arrow boundary. /// Re-registering evicts the previous mapping for either side. pub fn register_arrow_canonical( &self, vortex_id: ExtId, - arrow_name: &'static str, + arrow_id: ExtId, codec: ArrowCanonicalCodec, ) { - self.arrow_canonical.register(vortex_id, arrow_name, codec); + self.arrow_canonical.register(vortex_id, arrow_id, codec); } - /// Returns the Arrow canonical name and codec aliased to `vortex_id`, if any. - pub fn arrow_canonical_for(&self, vortex_id: &ExtId) -> Option<(ExtId, ArrowCanonicalCodec)> { - self.arrow_canonical - .arrow_canonical_for(vortex_id) - .map(|e| (e.partner, e.codec)) + /// Returns the Arrow canonical id and codec aliased to `vortex_id`, if any. + pub fn arrow_alias_for(&self, vortex_id: &ExtId) -> Option<(ExtId, ArrowCanonicalCodec)> { + self.arrow_canonical.arrow_alias_for(vortex_id) } - /// Returns the Vortex id and codec aliased to `arrow_name`, if any. - pub fn vortex_id_for_arrow_canonical( - &self, - arrow_name: &str, - ) -> Option<(ExtId, ArrowCanonicalCodec)> { - self.arrow_canonical - .vortex_id_for_arrow_canonical(arrow_name) - .map(|e| (e.partner, e.codec)) + /// Returns the Vortex id and codec aliased to `arrow_id`, if any. + pub fn vortex_alias_for(&self, arrow_id: &ExtId) -> Option<(ExtId, ArrowCanonicalCodec)> { + self.arrow_canonical.vortex_alias_for(arrow_id) } } @@ -168,31 +166,17 @@ mod tests { fn arrow_canonical_re_registration_is_clean() { let session = DTypeSession::default(); let v = ExtId::new("vortex.test"); + let foo = ExtId::new("arrow.foo"); + let bar = ExtId::new("arrow.bar"); + + session.register_arrow_canonical(v, foo, TEST_CODEC); + assert_eq!(session.arrow_alias_for(&v).map(|(id, _)| id), Some(foo)); + assert_eq!(session.vortex_alias_for(&foo).map(|(id, _)| id), Some(v)); - session.register_arrow_canonical(v, "arrow.foo", TEST_CODEC); - assert_eq!( - session.arrow_canonical_for(&v).map(|(id, _)| id), - Some(ExtId::new("arrow.foo")) - ); - assert_eq!( - session - .vortex_id_for_arrow_canonical("arrow.foo") - .map(|(id, _)| id), - Some(v) - ); - - session.register_arrow_canonical(v, "arrow.bar", TEST_CODEC); - assert_eq!( - session.arrow_canonical_for(&v).map(|(id, _)| id), - Some(ExtId::new("arrow.bar")) - ); - assert_eq!( - session - .vortex_id_for_arrow_canonical("arrow.bar") - .map(|(id, _)| id), - Some(v) - ); - assert!(session.vortex_id_for_arrow_canonical("arrow.foo").is_none()); + session.register_arrow_canonical(v, bar, TEST_CODEC); + assert_eq!(session.arrow_alias_for(&v).map(|(id, _)| id), Some(bar)); + assert_eq!(session.vortex_alias_for(&bar).map(|(id, _)| id), Some(v)); + assert!(session.vortex_alias_for(&foo).is_none()); } /// `(vid → old, old → vid)` then `register(vid, new)` should leave `(vid → new, new → vid)`. @@ -203,31 +187,15 @@ mod tests { let old = ExtId::new("arrow.b"); let new = ExtId::new("arrow.c"); - session.register_arrow_canonical(vid, "arrow.b", TEST_CODEC); - assert_eq!( - session.arrow_canonical_for(&vid).map(|(id, _)| id), - Some(old) - ); - assert_eq!( - session - .vortex_id_for_arrow_canonical("arrow.b") - .map(|(id, _)| id), - Some(vid) - ); - - session.register_arrow_canonical(vid, "arrow.c", TEST_CODEC); - - assert_eq!( - session.arrow_canonical_for(&vid).map(|(id, _)| id), - Some(new) - ); - assert_eq!( - session - .vortex_id_for_arrow_canonical("arrow.c") - .map(|(id, _)| id), - Some(vid) - ); - assert!(session.vortex_id_for_arrow_canonical("arrow.b").is_none()); + session.register_arrow_canonical(vid, old, TEST_CODEC); + assert_eq!(session.arrow_alias_for(&vid).map(|(id, _)| id), Some(old)); + assert_eq!(session.vortex_alias_for(&old).map(|(id, _)| id), Some(vid)); + + session.register_arrow_canonical(vid, new, TEST_CODEC); + + assert_eq!(session.arrow_alias_for(&vid).map(|(id, _)| id), Some(new)); + assert_eq!(session.vortex_alias_for(&new).map(|(id, _)| id), Some(vid)); + assert!(session.vortex_alias_for(&old).is_none()); } /// `(old → name, name → old)` then `register(new, name)` should leave `(new → name, name → new)`. @@ -238,35 +206,25 @@ mod tests { let name = ExtId::new("arrow.b"); let new = ExtId::new("vortex.c"); - session.register_arrow_canonical(old, "arrow.b", TEST_CODEC); - assert_eq!( - session.arrow_canonical_for(&old).map(|(id, _)| id), - Some(name) - ); - - session.register_arrow_canonical(new, "arrow.b", TEST_CODEC); - - assert_eq!( - session.arrow_canonical_for(&new).map(|(id, _)| id), - Some(name) - ); - assert_eq!( - session - .vortex_id_for_arrow_canonical("arrow.b") - .map(|(id, _)| id), - Some(new) - ); - assert!(session.arrow_canonical_for(&old).is_none()); + session.register_arrow_canonical(old, name, TEST_CODEC); + assert_eq!(session.arrow_alias_for(&old).map(|(id, _)| id), Some(name)); + + session.register_arrow_canonical(new, name, TEST_CODEC); + + assert_eq!(session.arrow_alias_for(&new).map(|(id, _)| id), Some(name)); + assert_eq!(session.vortex_alias_for(&name).map(|(id, _)| id), Some(new)); + assert!(session.arrow_alias_for(&old).is_none()); } #[test] fn codec_round_trips_through_lookup() { let session = DTypeSession::default(); let vid = ExtId::new("vortex.x"); + let aid = ExtId::new("arrow.x"); - session.register_arrow_canonical(vid, "arrow.x", TEST_CODEC); + session.register_arrow_canonical(vid, aid, TEST_CODEC); - let (_, codec) = session.arrow_canonical_for(&vid).unwrap(); + let (_, codec) = session.arrow_alias_for(&vid).unwrap(); let json = (codec.to_json)(b"hello").unwrap(); assert_eq!(json, "hello"); let bytes = (codec.from_json)(&json).unwrap(); diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 41aa729433f..5e46bb153d1 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -11,6 +11,7 @@ )] use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; +use vortex_array::dtype::extension::ExtId; use vortex_array::dtype::extension::ExtVTable; use vortex_array::dtype::session::ArrowCanonicalCodec; use vortex_array::dtype::session::DTypeSessionExt; @@ -54,7 +55,7 @@ pub fn initialize(session: &VortexSession) { dtypes.register(FixedShapeTensor); dtypes.register_arrow_canonical( FixedShapeTensor.id(), - FixedShapeTensor::ARROW_EXT_NAME, + ExtId::new(FixedShapeTensor::ARROW_EXT_NAME), ArrowCanonicalCodec { to_json: fixed_shape::proto_to_json, from_json: fixed_shape::json_to_proto, From b61239c1dbca7a5a286c041314ed221166152136 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Tue, 28 Apr 2026 11:33:46 +0100 Subject: [PATCH 16/21] nit Signed-off-by: Baris Palaska --- vortex-tensor/src/lib.rs | 3 +-- vortex-tensor/src/tests/arrow_roundtrip.rs | 2 +- vortex-tensor/src/types/fixed_shape/mod.rs | 8 +++++++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 5e46bb153d1..28ee8795825 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -11,7 +11,6 @@ )] use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; -use vortex_array::dtype::extension::ExtId; use vortex_array::dtype::extension::ExtVTable; use vortex_array::dtype::session::ArrowCanonicalCodec; use vortex_array::dtype::session::DTypeSessionExt; @@ -55,7 +54,7 @@ pub fn initialize(session: &VortexSession) { dtypes.register(FixedShapeTensor); dtypes.register_arrow_canonical( FixedShapeTensor.id(), - ExtId::new(FixedShapeTensor::ARROW_EXT_NAME), + FixedShapeTensor::arrow_ext_id(), ArrowCanonicalCodec { to_json: fixed_shape::proto_to_json, from_json: fixed_shape::json_to_proto, diff --git a/vortex-tensor/src/tests/arrow_roundtrip.rs b/vortex-tensor/src/tests/arrow_roundtrip.rs index 3975d8e2157..37baf8449c5 100644 --- a/vortex-tensor/src/tests/arrow_roundtrip.rs +++ b/vortex-tensor/src/tests/arrow_roundtrip.rs @@ -120,7 +120,7 @@ fn fixed_shape_tensor_metadata_roundtrip() { .metadata() .get(EXTENSION_TYPE_NAME_KEY) .map(String::as_str), - Some(FixedShapeTensor::ARROW_EXT_NAME), + Some(FixedShapeTensor::arrow_ext_id().as_str()), ); // Canonical wire: raw JSON, not base64. diff --git a/vortex-tensor/src/types/fixed_shape/mod.rs b/vortex-tensor/src/types/fixed_shape/mod.rs index 147fbdf1966..94d91e74095 100644 --- a/vortex-tensor/src/types/fixed_shape/mod.rs +++ b/vortex-tensor/src/types/fixed_shape/mod.rs @@ -3,12 +3,18 @@ //! Fixed-shape Tensor extension type. +use vortex_array::dtype::extension::ExtId; +use vortex_session::registry::CachedId; + /// The VTable for the Tensor extension type. #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct FixedShapeTensor; impl FixedShapeTensor { - pub(crate) const ARROW_EXT_NAME: &'static str = "arrow.fixed_shape_tensor"; + pub(crate) fn arrow_ext_id() -> ExtId { + static ID: CachedId = CachedId::new("arrow.fixed_shape_tensor"); + *ID + } } mod matcher; From 2b2a9a3105331821edfb09878fd91970960bb39e Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Tue, 28 Apr 2026 17:02:23 +0100 Subject: [PATCH 17/21] unwrap extension to storage in execute_arrow Signed-off-by: Baris Palaska --- vortex-array/src/arrow/executor/mod.rs | 42 ++++++++++++++++++++++ vortex-tensor/src/tests/arrow_roundtrip.rs | 29 +++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/vortex-array/src/arrow/executor/mod.rs b/vortex-array/src/arrow/executor/mod.rs index 890e7f8a46a..5cb8531637b 100644 --- a/vortex-array/src/arrow/executor/mod.rs +++ b/vortex-array/src/arrow/executor/mod.rs @@ -30,8 +30,10 @@ use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use crate::ArrayRef; +use crate::arrays::ExtensionArray; use crate::arrays::List; use crate::arrays::VarBin; +use crate::arrays::extension::ExtensionArrayExt; use crate::arrays::list::ListArrayExt; use crate::arrays::varbin::VarBinArrayExt; use crate::arrow::executor::bool::to_arrow_bool; @@ -87,6 +89,12 @@ impl ArrowArrayExecutor for ArrayRef { data_type: Option<&DataType>, ctx: &mut ExecutionCtx, ) -> VortexResult { + // Extension identity lives on Field metadata; dispatch on the storage array. + if matches!(self.dtype(), DType::Extension(_)) { + let ext = self.execute::(ctx)?; + return ext.storage_array().clone().execute_arrow(data_type, ctx); + } + let len = self.len(); // Resolve the DataType if it is a leaf type @@ -186,6 +194,40 @@ impl ArrowArrayExecutor for ArrayRef { } } +#[cfg(test)] +mod tests { + use arrow_array::cast::AsArray; + use arrow_array::types::UInt64Type; + use arrow_schema::DataType; + + use super::*; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::array::IntoArray; + use crate::arrays::ExtensionArray; + use crate::arrays::PrimitiveArray; + use crate::extension::tests::divisible_int::DivisibleInt; + use crate::extension::tests::divisible_int::Divisor; + + #[test] + fn execute_arrow_unwraps_extension_to_storage() { + let storage = PrimitiveArray::from_iter(0u64..6).into_array(); + let ext = ExtensionArray::try_new_from_vtable(DivisibleInt, Divisor(1), storage) + .unwrap() + .into_array(); + + let arrow = ext + .execute_arrow( + Some(&DataType::UInt64), + &mut LEGACY_SESSION.create_execution_ctx(), + ) + .unwrap(); + + let primitives = arrow.as_primitive::(); + assert_eq!(primitives.values(), &[0, 1, 2, 3, 4, 5]); + } +} + /// Determine the preferred (cheapest) Arrow type for an array. /// /// For most arrays, this returns the canonical Arrow type from `dtype.to_arrow_dtype()`. diff --git a/vortex-tensor/src/tests/arrow_roundtrip.rs b/vortex-tensor/src/tests/arrow_roundtrip.rs index 37baf8449c5..b3a165d43c8 100644 --- a/vortex-tensor/src/tests/arrow_roundtrip.rs +++ b/vortex-tensor/src/tests/arrow_roundtrip.rs @@ -9,6 +9,7 @@ use arrow_schema::DataType; use arrow_schema::TimeUnit as ArrowTimeUnit; use arrow_schema::extension::EXTENSION_TYPE_METADATA_KEY; use arrow_schema::extension::EXTENSION_TYPE_NAME_KEY; +use vortex_array::arrays::StructArray; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; @@ -148,6 +149,34 @@ fn tensor_inside_nested_struct_roundtrips() { assert_eq!(recovered, original); } +#[test] +fn vector_record_batch_round_trip_carries_field_metadata() { + let vector_array = Vector::constant_array(&[1.0f32, 2.0, 3.0, 4.0], 2).unwrap(); + let struct_array = StructArray::from_fields(&[("embedding", vector_array)]).unwrap(); + + let dtype = DType::struct_([("embedding", vector_dtype(4))], Nullability::NonNullable); + let schema = dtype.to_arrow_schema_with_session(&SESSION).unwrap(); + let rb = struct_array.into_record_batch_with_schema(&schema).unwrap(); + + let column = rb.column(0); + let DataType::FixedSizeList(_, size) = column.data_type() else { + panic!( + "expected storage FixedSizeList, got {:?}", + column.data_type() + ); + }; + assert_eq!(*size, 4); + + assert_eq!( + rb.schema() + .field(0) + .metadata() + .get(EXTENSION_TYPE_NAME_KEY) + .map(String::as_str), + Some(Vector.id().as_str()), + ); +} + #[test] fn temporal_extension_still_uses_native_arrow() { let ts = Timestamp::new_with_tz(TimeUnit::Microseconds, None, Nullability::Nullable); From 2155f2085524daaf4fe9bd14cfc61d8387bbde82 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Tue, 28 Apr 2026 17:02:51 +0100 Subject: [PATCH 18/21] clippy Signed-off-by: Baris Palaska --- vortex-array/src/arrow/executor/mod.rs | 68 +++++++++++++------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/vortex-array/src/arrow/executor/mod.rs b/vortex-array/src/arrow/executor/mod.rs index 5cb8531637b..66e7afd7543 100644 --- a/vortex-array/src/arrow/executor/mod.rs +++ b/vortex-array/src/arrow/executor/mod.rs @@ -194,40 +194,6 @@ impl ArrowArrayExecutor for ArrayRef { } } -#[cfg(test)] -mod tests { - use arrow_array::cast::AsArray; - use arrow_array::types::UInt64Type; - use arrow_schema::DataType; - - use super::*; - use crate::LEGACY_SESSION; - use crate::VortexSessionExecute; - use crate::array::IntoArray; - use crate::arrays::ExtensionArray; - use crate::arrays::PrimitiveArray; - use crate::extension::tests::divisible_int::DivisibleInt; - use crate::extension::tests::divisible_int::Divisor; - - #[test] - fn execute_arrow_unwraps_extension_to_storage() { - let storage = PrimitiveArray::from_iter(0u64..6).into_array(); - let ext = ExtensionArray::try_new_from_vtable(DivisibleInt, Divisor(1), storage) - .unwrap() - .into_array(); - - let arrow = ext - .execute_arrow( - Some(&DataType::UInt64), - &mut LEGACY_SESSION.create_execution_ctx(), - ) - .unwrap(); - - let primitives = arrow.as_primitive::(); - assert_eq!(primitives.values(), &[0, 1, 2, 3, 4, 5]); - } -} - /// Determine the preferred (cheapest) Arrow type for an array. /// /// For most arrays, this returns the canonical Arrow type from `dtype.to_arrow_dtype()`. @@ -270,3 +236,37 @@ fn preferred_arrow_type(array: &ArrayRef) -> VortexResult { // Everything else: use canonical dtype conversion array.dtype().to_arrow_dtype() } + +#[cfg(test)] +mod tests { + use arrow_array::cast::AsArray; + use arrow_array::types::UInt64Type; + use arrow_schema::DataType; + + use super::*; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::array::IntoArray; + use crate::arrays::ExtensionArray; + use crate::arrays::PrimitiveArray; + use crate::extension::tests::divisible_int::DivisibleInt; + use crate::extension::tests::divisible_int::Divisor; + + #[test] + fn execute_arrow_unwraps_extension_to_storage() { + let storage = PrimitiveArray::from_iter(0u64..6).into_array(); + let ext = ExtensionArray::try_new_from_vtable(DivisibleInt, Divisor(1), storage) + .unwrap() + .into_array(); + + let arrow = ext + .execute_arrow( + Some(&DataType::UInt64), + &mut LEGACY_SESSION.create_execution_ctx(), + ) + .unwrap(); + + let primitives = arrow.as_primitive::(); + assert_eq!(primitives.values(), &[0, 1, 2, 3, 4, 5]); + } +} From 990a4b803b57e2abdb64ccd9b6a4c3b5b9938bd2 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Tue, 28 Apr 2026 19:10:06 +0100 Subject: [PATCH 19/21] recover extension identity in FromArrowArray Signed-off-by: Baris Palaska --- vortex-array/public-api.lock | 152 ++++++++++++++++++++- vortex-array/src/arrow/convert.rs | 147 ++++++++++++++++---- vortex-array/src/arrow/mod.rs | 18 ++- vortex-array/src/dtype/arrow.rs | 51 ++++--- vortex-tensor/src/tests/arrow_roundtrip.rs | 60 +++++++- 5 files changed, 379 insertions(+), 49 deletions(-) diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 49a771bbbd1..3cd8c34d1cf 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -7038,158 +7038,234 @@ pub fn vortex_array::ArrayRef::execute_record_batch(self, schema: &arrow_schema: pub fn vortex_array::ArrayRef::execute_record_batches(self, schema: &arrow_schema::schema::Schema, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> -pub trait vortex_array::arrow::FromArrowArray +pub trait vortex_array::arrow::FromArrowArray: core::marker::Sized -pub fn vortex_array::arrow::FromArrowArray::from_arrow(array: A, nullable: bool) -> vortex_error::VortexResult where Self: core::marker::Sized +pub fn vortex_array::arrow::FromArrowArray::from_arrow(array: A, nullable: bool) -> vortex_error::VortexResult + +pub fn vortex_array::arrow::FromArrowArray::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult impl vortex_array::arrow::FromArrowArray<&arrow_array::array::boolean_array::BooleanArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::boolean_array::BooleanArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::fixed_size_list_array::FixedSizeListArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::fixed_size_list_array::FixedSizeListArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: &arrow_array::array::fixed_size_list_array::FixedSizeListArray, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::null_array::NullArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::null_array::NullArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::struct_array::StructArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::struct_array::StructArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(value: &arrow_array::array::struct_array::StructArray, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::record_batch::RecordBatch> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::record_batch::RecordBatch, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: &arrow_array::record_batch::RecordBatch, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&dyn arrow_array::array::Array> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &dyn arrow_array::array::Array, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: &dyn arrow_array::array::Array, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: arrow_array::record_batch::RecordBatch, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: arrow_array::record_batch::RecordBatch, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::dictionary_array::DictionaryArray> for vortex_array::arrays::dict::DictArray pub fn vortex_array::arrays::dict::DictArray::from_arrow(array: &arrow_array::array::dictionary_array::DictionaryArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::arrays::dict::DictArray::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::list_view_array::GenericListViewArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::list_view_array::GenericListViewArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::list_array::GenericListArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::list_array::GenericListArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(value: &arrow_array::array::list_array::GenericListArray, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::byte_array::GenericByteArray> for vortex_array::ArrayRef where ::Offset: vortex_array::dtype::IntegerPType pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::byte_array::GenericByteArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::byte_view_array::GenericByteViewArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::byte_view_array::GenericByteViewArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + pub trait vortex_array::arrow::IntoArrowArray pub fn vortex_array::arrow::IntoArrowArray::into_arrow(self, data_type: &arrow_schema::datatype::DataType) -> vortex_error::VortexResult @@ -22594,130 +22670,194 @@ impl vortex_array::arrow::FromArrowArray<&arrow_array::array::boolean_array::Boo pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::boolean_array::BooleanArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::fixed_size_list_array::FixedSizeListArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::fixed_size_list_array::FixedSizeListArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: &arrow_array::array::fixed_size_list_array::FixedSizeListArray, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::null_array::NullArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::null_array::NullArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::struct_array::StructArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::struct_array::StructArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(value: &arrow_array::array::struct_array::StructArray, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::record_batch::RecordBatch> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::record_batch::RecordBatch, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: &arrow_array::record_batch::RecordBatch, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&dyn arrow_array::array::Array> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &dyn arrow_array::array::Array, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: &dyn arrow_array::array::Array, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: arrow_array::record_batch::RecordBatch, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: arrow_array::record_batch::RecordBatch, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::IntoArrowArray for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::into_arrow(self, data_type: &arrow_schema::datatype::DataType) -> vortex_error::VortexResult @@ -22786,18 +22926,26 @@ impl, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::list_array::GenericListArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::list_array::GenericListArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(value: &arrow_array::array::list_array::GenericListArray, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::byte_array::GenericByteArray> for vortex_array::ArrayRef where ::Offset: vortex_array::dtype::IntegerPType pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::byte_array::GenericByteArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::byte_view_array::GenericByteViewArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::byte_view_array::GenericByteViewArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl core::convert::AsRef for vortex_array::Array pub fn vortex_array::Array::as_ref(&self) -> &vortex_array::ArrayRef diff --git a/vortex-array/src/arrow/convert.rs b/vortex-array/src/arrow/convert.rs index 6eafa6033c5..b9cd26c7b80 100644 --- a/vortex-array/src/arrow/convert.rs +++ b/vortex-array/src/arrow/convert.rs @@ -56,6 +56,7 @@ use arrow_buffer::ScalarBuffer; use arrow_buffer::buffer::NullBuffer; use arrow_buffer::buffer::OffsetBuffer; use arrow_schema::DataType; +use arrow_schema::Field; use arrow_schema::TimeUnit as ArrowTimeUnit; use itertools::Itertools; use vortex_buffer::Alignment; @@ -66,12 +67,15 @@ use vortex_error::VortexExpect as _; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_panic; +use vortex_session::VortexSession; use crate::ArrayRef; use crate::IntoArray; +use crate::LEGACY_SESSION; use crate::arrays::BoolArray; use crate::arrays::DecimalArray; use crate::arrays::DictArray; +use crate::arrays::ExtensionArray; use crate::arrays::FixedSizeListArray; use crate::arrays::ListArray; use crate::arrays::ListViewArray; @@ -87,7 +91,9 @@ use crate::dtype::DecimalDType; use crate::dtype::IntegerPType; use crate::dtype::NativePType; use crate::dtype::PType; +use crate::dtype::arrow::resolve_extension_dtype; use crate::dtype::i256; +use crate::dtype::session::DTypeSessionExt; use crate::extension::datetime::TimeUnit; use crate::validity::Validity; @@ -380,23 +386,33 @@ fn remove_nulls(data: arrow_data::ArrayData) -> arrow_data::ArrayData { impl FromArrowArray<&ArrowStructArray> for ArrayRef { fn from_arrow(value: &ArrowStructArray, nullable: bool) -> VortexResult { + Self::from_arrow_with_session(value, nullable, &LEGACY_SESSION) + } + + fn from_arrow_with_session( + value: &ArrowStructArray, + nullable: bool, + session: &VortexSession, + ) -> VortexResult { + let columns = value + .columns() + .iter() + .zip(value.fields()) + .map(|(c, field)| { + // Arrow pushes down nulls, even into non-nullable fields. So we strip them + // out here because Vortex is a little more strict. + let storage = if c.null_count() > 0 && !field.is_nullable() { + let stripped = make_array(remove_nulls(c.into_data())); + Self::from_arrow_with_session(stripped.as_ref(), false, session)? + } else { + Self::from_arrow_with_session(c.as_ref(), field.is_nullable(), session)? + }; + wrap_extension_if_field_has_metadata(storage, field.as_ref(), session) + }) + .collect::>>()?; Ok(StructArray::try_new( value.column_names().iter().copied().collect(), - value - .columns() - .iter() - .zip(value.fields()) - .map(|(c, field)| { - // Arrow pushes down nulls, even into non-nullable fields. So we strip them - // out here because Vortex is a little more strict. - if c.null_count() > 0 && !field.is_nullable() { - let stripped = make_array(remove_nulls(c.into_data())); - Self::from_arrow(stripped.as_ref(), false) - } else { - Self::from_arrow(c.as_ref(), field.is_nullable()) - } - }) - .collect::>>()?, + columns, value.len(), nulls(value.nulls(), nullable), )? @@ -406,14 +422,30 @@ impl FromArrowArray<&ArrowStructArray> for ArrayRef { impl FromArrowArray<&GenericListArray> for ArrayRef { fn from_arrow(value: &GenericListArray, nullable: bool) -> VortexResult { - // Extract the validity of the underlying element array. - let elements_are_nullable = match value.data_type() { - DataType::List(field) => field.is_nullable(), - DataType::LargeList(field) => field.is_nullable(), + Self::from_arrow_with_session(value, nullable, &LEGACY_SESSION) + } + + fn from_arrow_with_session( + value: &GenericListArray, + nullable: bool, + session: &VortexSession, + ) -> VortexResult { + let elements_field = match value.data_type() { + DataType::List(field) => field.clone(), + DataType::LargeList(field) => field.clone(), dt => vortex_panic!("Invalid data type for ListArray: {dt}"), }; - let elements = Self::from_arrow(value.values().as_ref(), elements_are_nullable)?; + let elements_storage = Self::from_arrow_with_session( + value.values().as_ref(), + elements_field.is_nullable(), + session, + )?; + let elements = wrap_extension_if_field_has_metadata( + elements_storage, + elements_field.as_ref(), + session, + )?; // `offsets` are always non-nullable. let offsets = value.offsets().clone().into_array(); @@ -445,12 +477,25 @@ impl FromArrowArray<&GenericListViewArray> impl FromArrowArray<&ArrowFixedSizeListArray> for ArrayRef { fn from_arrow(array: &ArrowFixedSizeListArray, nullable: bool) -> VortexResult { + Self::from_arrow_with_session(array, nullable, &LEGACY_SESSION) + } + + fn from_arrow_with_session( + array: &ArrowFixedSizeListArray, + nullable: bool, + session: &VortexSession, + ) -> VortexResult { let DataType::FixedSizeList(field, list_size) = array.data_type() else { vortex_panic!("Invalid data type for ListArray: {}", array.data_type()); }; + let elements_storage = + Self::from_arrow_with_session(array.values().as_ref(), field.is_nullable(), session)?; + let elements = + wrap_extension_if_field_has_metadata(elements_storage, field.as_ref(), session)?; + Ok(FixedSizeListArray::try_new( - Self::from_arrow(array.values().as_ref(), field.is_nullable())?, + elements, *list_size as u32, nulls(array.nulls(), nullable), array.len(), @@ -494,6 +539,30 @@ fn nulls(nulls: Option<&NullBuffer>, nullable: bool) -> Validity { } impl FromArrowArray<&dyn ArrowArray> for ArrayRef { + fn from_arrow_with_session( + array: &dyn ArrowArray, + nullable: bool, + session: &VortexSession, + ) -> VortexResult { + match array.data_type() { + DataType::Struct(_) => { + Self::from_arrow_with_session(array.as_struct(), nullable, session) + } + DataType::List(_) => { + Self::from_arrow_with_session(array.as_list::(), nullable, session) + } + DataType::LargeList(_) => { + Self::from_arrow_with_session(array.as_list::(), nullable, session) + } + DataType::FixedSizeList(..) => { + Self::from_arrow_with_session(array.as_fixed_size_list(), nullable, session) + } + // Other arrays don't carry child Fields, so session-aware dispatch is identical to + // the legacy path; fall through to `from_arrow`. + _ => Self::from_arrow(array, nullable), + } + } + fn from_arrow(array: &dyn ArrowArray, nullable: bool) -> VortexResult { match array.data_type() { DataType::Boolean => Self::from_arrow(array.as_boolean(), nullable), @@ -617,13 +686,45 @@ impl FromArrowArray<&dyn ArrowArray> for ArrayRef { impl FromArrowArray for ArrayRef { fn from_arrow(array: RecordBatch, nullable: bool) -> VortexResult { - ArrayRef::from_arrow(&arrow_array::StructArray::from(array), nullable) + Self::from_arrow_with_session(array, nullable, &LEGACY_SESSION) + } + + fn from_arrow_with_session( + array: RecordBatch, + nullable: bool, + session: &VortexSession, + ) -> VortexResult { + Self::from_arrow_with_session(&arrow_array::StructArray::from(array), nullable, session) } } impl FromArrowArray<&RecordBatch> for ArrayRef { fn from_arrow(array: &RecordBatch, nullable: bool) -> VortexResult { - Self::from_arrow(array.clone(), nullable) + Self::from_arrow_with_session(array, nullable, &LEGACY_SESSION) + } + + fn from_arrow_with_session( + array: &RecordBatch, + nullable: bool, + session: &VortexSession, + ) -> VortexResult { + Self::from_arrow_with_session(array.clone(), nullable, session) + } +} + +/// Inverse of `field_from_dtype` (in `dtype/arrow.rs`): if `field` carries +/// `ARROW:extension:name` metadata for a registered extension, rewrap `storage` as an +/// `ExtensionArray`; otherwise fall through to `storage`. Diagnostic warnings live in +/// [`resolve_extension_dtype`]. +fn wrap_extension_if_field_has_metadata( + storage: ArrayRef, + field: &Field, + session: &VortexSession, +) -> VortexResult { + let dtypes = session.dtypes(); + match resolve_extension_dtype(field, &dtypes, storage.dtype()) { + Some(ext_dtype) => Ok(ExtensionArray::try_new(ext_dtype, storage)?.into_array()), + None => Ok(storage), } } diff --git a/vortex-array/src/arrow/mod.rs b/vortex-array/src/arrow/mod.rs index efc83aa6af6..52e32bcf3d7 100644 --- a/vortex-array/src/arrow/mod.rs +++ b/vortex-array/src/arrow/mod.rs @@ -6,6 +6,7 @@ use arrow_array::ArrayRef as ArrowArrayRef; use arrow_schema::DataType; use vortex_error::VortexResult; +use vortex_session::VortexSession; mod convert; mod datum; @@ -24,10 +25,19 @@ use crate::ArrayRef; use crate::LEGACY_SESSION; use crate::VortexSessionExecute; -pub trait FromArrowArray { - fn from_arrow(array: A, nullable: bool) -> VortexResult - where - Self: Sized; +pub trait FromArrowArray: Sized { + fn from_arrow(array: A, nullable: bool) -> VortexResult; + + /// Same conversion, with session for resolving `ARROW:extension:name` field metadata to + /// registered extension dtypes. The default ignores the session — override on impls that + /// see Arrow `Field`s (RecordBatch, Struct, List, FSL). + fn from_arrow_with_session( + array: A, + nullable: bool, + _session: &VortexSession, + ) -> VortexResult { + Self::from_arrow(array, nullable) + } } #[deprecated(note = "Use `execute_arrow(None, ctx)` or `execute_arrow(Some(dt), ctx)` instead")] diff --git a/vortex-array/src/dtype/arrow.rs b/vortex-array/src/dtype/arrow.rs index 125220b275f..f55954ff639 100644 --- a/vortex-array/src/dtype/arrow.rs +++ b/vortex-array/src/dtype/arrow.rs @@ -254,18 +254,36 @@ impl FromArrowType<&Field> for DType { /// Convert an Arrow Field to a [`DType`] with `dtypes` already borrowed from the session, /// so the handle is acquired once per schema rather than once per field. fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { - let ext_name = field.extension_type_name(); - - // Variant maps to its own DType variant, not DType::Extension. - if ext_name.is_some_and(|s| s == ARROW_EXT_NAME_VARIANT) { + if field + .extension_type_name() + .is_some_and(|s| s == ARROW_EXT_NAME_VARIANT) + { return DType::Variant(field.is_nullable().into()); } let storage_dtype = storage_dtype_from_field(field, dtypes); + match resolve_extension_dtype(field, dtypes, &storage_dtype) { + Some(ext_ref) => DType::Extension(ext_ref), + None => storage_dtype, + } +} - let Some(ext_name) = ext_name else { - return storage_dtype; - }; +/// Resolve the [`ExtDTypeRef`] for an Arrow Field whose `ARROW:extension:name` metadata names +/// a registered Vortex extension. Returns `None` for unregistered extensions, malformed +/// metadata, or fields with no extension name; `tracing::warn!` reports the anomaly so callers +/// can simply fall back to the storage representation. +/// +/// Used on both the dtype side ([`dtype_from_field`]) and the array side +/// (`wrap_extension_if_field_has_metadata`); only the final wrap differs. +pub(crate) fn resolve_extension_dtype( + field: &Field, + dtypes: &DTypeSession, + storage_dtype: &DType, +) -> Option { + let ext_name = field.extension_type_name()?; + if ext_name == ARROW_EXT_NAME_VARIANT { + return None; + } let arrow_id = ExtId::new(ext_name); let (ext_id, codec) = match dtypes.vortex_alias_for(&arrow_id) { @@ -275,38 +293,33 @@ fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { let Some(plugin) = dtypes.registry().find(&ext_id) else { tracing::warn!( - "Arrow field {:?} extension id {:?} not registered; using storage dtype", + "Arrow field {:?} extension id {ext_name:?} not registered; using storage dtype", field.name(), - ext_name, ); - return storage_dtype; + return None; }; let metadata_bytes = match decode_extension_metadata(field, codec) { Ok(bytes) => bytes, Err(e) => { tracing::warn!( - "Arrow field {:?} extension id {:?} has malformed metadata ({}); \ + "Arrow field {:?} extension id {ext_name:?} has malformed metadata ({e}); \ using storage dtype", field.name(), - ext_name, - e, ); - return storage_dtype; + return None; } }; match plugin.deserialize(&metadata_bytes, storage_dtype.clone()) { - Ok(ext_ref) => DType::Extension(ext_ref), + Ok(ext_ref) => Some(ext_ref), Err(e) => { tracing::warn!( - "Arrow field {:?} extension id {:?} failed to deserialize ({}); \ + "Arrow field {:?} extension id {ext_name:?} failed to deserialize ({e}); \ using storage dtype", field.name(), - ext_name, - e, ); - storage_dtype + None } } } diff --git a/vortex-tensor/src/tests/arrow_roundtrip.rs b/vortex-tensor/src/tests/arrow_roundtrip.rs index b3a165d43c8..423cb97b294 100644 --- a/vortex-tensor/src/tests/arrow_roundtrip.rs +++ b/vortex-tensor/src/tests/arrow_roundtrip.rs @@ -9,15 +9,23 @@ use arrow_schema::DataType; use arrow_schema::TimeUnit as ArrowTimeUnit; use arrow_schema::extension::EXTENSION_TYPE_METADATA_KEY; use arrow_schema::extension::EXTENSION_TYPE_NAME_KEY; +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::StructArray; +use vortex_array::arrow::FromArrowArray; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; use vortex_array::dtype::arrow::FromArrowType; use vortex_array::dtype::extension::ExtDType; use vortex_array::dtype::extension::ExtVTable; +use vortex_array::extension::EmptyMetadata; use vortex_array::extension::datetime::TimeUnit; use vortex_array::extension::datetime::Timestamp; +use vortex_array::validity::Validity; use crate::tests::SESSION; use crate::types::fixed_shape::FixedShapeTensor; @@ -30,7 +38,7 @@ fn vector_dtype(len: u32) -> DType { len, Nullability::NonNullable, ); - let ext = ExtDType::::try_new(vortex_array::extension::EmptyMetadata, storage).unwrap(); + let ext = ExtDType::::try_new(EmptyMetadata, storage).unwrap(); DType::Extension(ext.erased()) } @@ -195,3 +203,53 @@ fn temporal_extension_still_uses_native_arrow() { assert!(field.metadata().get(EXTENSION_TYPE_NAME_KEY).is_none()); assert!(field.metadata().get(EXTENSION_TYPE_METADATA_KEY).is_none()); } + +/// Build a storage FSL with `num_rows` rows, each of `elements_per_row` elements. +fn fsl_f32_storage(elements_per_row: u32, num_rows: usize) -> ArrayRef { + let total = elements_per_row as usize * num_rows; + let elements = PrimitiveArray::from_iter((0..total).map(|i| i as f32)); + FixedSizeListArray::try_new( + elements.into_array(), + elements_per_row, + Validity::NonNullable, + num_rows, + ) + .unwrap() + .into_array() +} + +#[test] +fn vector_record_batch_round_trip() { + let vector_array = + ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, fsl_f32_storage(4, 2)) + .unwrap() + .into_array(); + let original = StructArray::from_fields(&[("embedding", vector_array)]).unwrap(); + + let dtype = DType::struct_([("embedding", vector_dtype(4))], Nullability::NonNullable); + let schema = dtype.to_arrow_schema_with_session(&SESSION).unwrap(); + let rb = original.into_record_batch_with_schema(&schema).unwrap(); + + let recovered = ArrayRef::from_arrow_with_session(rb, false, &SESSION).unwrap(); + assert_eq!(recovered.dtype(), &dtype); +} + +#[test] +fn fixed_shape_tensor_record_batch_round_trip() { + let metadata = FixedShapeTensorMetadata::new(vec![2, 2]) + .with_dim_names(vec!["row".into(), "col".into()]) + .unwrap(); + let tensor_dtype = fixed_shape_dtype(metadata.clone(), 4); + let dtype = DType::struct_([("tensor", tensor_dtype.clone())], Nullability::NonNullable); + let schema = dtype.to_arrow_schema_with_session(&SESSION).unwrap(); + + let tensor_array = + ExtensionArray::try_new_from_vtable(FixedShapeTensor, metadata, fsl_f32_storage(4, 3)) + .unwrap() + .into_array(); + let original = StructArray::from_fields(&[("tensor", tensor_array)]).unwrap(); + let rb = original.into_record_batch_with_schema(&schema).unwrap(); + + let recovered = ArrayRef::from_arrow_with_session(rb, false, &SESSION).unwrap(); + assert_eq!(recovered.dtype(), &dtype); +} From fdaf5d7f39703962928e550207e65d5356dd1cb4 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Tue, 28 Apr 2026 19:29:42 +0100 Subject: [PATCH 20/21] clippy Signed-off-by: Baris Palaska --- vortex-array/src/arrow/convert.rs | 13 +++++-------- vortex-tensor/src/tests/arrow_roundtrip.rs | 2 +- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/vortex-array/src/arrow/convert.rs b/vortex-array/src/arrow/convert.rs index b9cd26c7b80..556df3a725a 100644 --- a/vortex-array/src/arrow/convert.rs +++ b/vortex-array/src/arrow/convert.rs @@ -430,9 +430,9 @@ impl FromArrowArray<&GenericListArray> for nullable: bool, session: &VortexSession, ) -> VortexResult { - let elements_field = match value.data_type() { - DataType::List(field) => field.clone(), - DataType::LargeList(field) => field.clone(), + let elements_field: &Field = match value.data_type() { + DataType::List(field) => field.as_ref(), + DataType::LargeList(field) => field.as_ref(), dt => vortex_panic!("Invalid data type for ListArray: {dt}"), }; @@ -441,11 +441,8 @@ impl FromArrowArray<&GenericListArray> for elements_field.is_nullable(), session, )?; - let elements = wrap_extension_if_field_has_metadata( - elements_storage, - elements_field.as_ref(), - session, - )?; + let elements = + wrap_extension_if_field_has_metadata(elements_storage, elements_field, session)?; // `offsets` are always non-nullable. let offsets = value.offsets().clone().into_array(); diff --git a/vortex-tensor/src/tests/arrow_roundtrip.rs b/vortex-tensor/src/tests/arrow_roundtrip.rs index 423cb97b294..baa17a480ca 100644 --- a/vortex-tensor/src/tests/arrow_roundtrip.rs +++ b/vortex-tensor/src/tests/arrow_roundtrip.rs @@ -240,7 +240,7 @@ fn fixed_shape_tensor_record_batch_round_trip() { .with_dim_names(vec!["row".into(), "col".into()]) .unwrap(); let tensor_dtype = fixed_shape_dtype(metadata.clone(), 4); - let dtype = DType::struct_([("tensor", tensor_dtype.clone())], Nullability::NonNullable); + let dtype = DType::struct_([("tensor", tensor_dtype)], Nullability::NonNullable); let schema = dtype.to_arrow_schema_with_session(&SESSION).unwrap(); let tensor_array = From dfdbf10389c0f7006773cab055eb2b5a09d7f67e Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Wed, 29 Apr 2026 12:23:01 +0100 Subject: [PATCH 21/21] simpler, no registry Signed-off-by: Baris Palaska --- vortex-array/public-api.lock | 78 ++++++-- vortex-array/src/arrow/record_batch.rs | 14 +- vortex-array/src/dtype/arrow.rs | 16 +- vortex-array/src/dtype/extension/plugin.rs | 8 + vortex-array/src/dtype/extension/vtable.rs | 33 ++++ vortex-array/src/dtype/session.rs | 179 ++---------------- vortex-tensor/public-api.lock | 2 + vortex-tensor/src/lib.rs | 10 - vortex-tensor/src/tests/arrow_roundtrip.rs | 12 +- vortex-tensor/src/types/fixed_shape/mod.rs | 2 - vortex-tensor/src/types/fixed_shape/vtable.rs | 14 ++ 11 files changed, 161 insertions(+), 207 deletions(-) diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 90667c25f37..65024fb8fbc 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -8866,6 +8866,38 @@ pub fn vortex_array::dtype::PType::try_from_arrow(value: &arrow_schema::datatype pub mod vortex_array::dtype::extension +pub struct vortex_array::dtype::extension::ArrowCanonicalAlias + +pub vortex_array::dtype::extension::ArrowCanonicalAlias::arrow_id: vortex_array::dtype::extension::ExtId + +pub vortex_array::dtype::extension::ArrowCanonicalAlias::codec: vortex_array::dtype::extension::ArrowCanonicalCodec + +impl core::clone::Clone for vortex_array::dtype::extension::ArrowCanonicalAlias + +pub fn vortex_array::dtype::extension::ArrowCanonicalAlias::clone(&self) -> vortex_array::dtype::extension::ArrowCanonicalAlias + +impl core::fmt::Debug for vortex_array::dtype::extension::ArrowCanonicalAlias + +pub fn vortex_array::dtype::extension::ArrowCanonicalAlias::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::marker::Copy for vortex_array::dtype::extension::ArrowCanonicalAlias + +pub struct vortex_array::dtype::extension::ArrowCanonicalCodec + +pub vortex_array::dtype::extension::ArrowCanonicalCodec::from_json: fn(&str) -> vortex_error::VortexResult> + +pub vortex_array::dtype::extension::ArrowCanonicalCodec::to_json: fn(&[u8]) -> vortex_error::VortexResult + +impl core::clone::Clone for vortex_array::dtype::extension::ArrowCanonicalCodec + +pub fn vortex_array::dtype::extension::ArrowCanonicalCodec::clone(&self) -> vortex_array::dtype::extension::ArrowCanonicalCodec + +impl core::fmt::Debug for vortex_array::dtype::extension::ArrowCanonicalCodec + +pub fn vortex_array::dtype::extension::ArrowCanonicalCodec::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::marker::Copy for vortex_array::dtype::extension::ArrowCanonicalCodec + pub struct vortex_array::dtype::extension::ExtDType impl vortex_array::dtype::extension::ExtDType @@ -8980,12 +9012,16 @@ pub fn vortex_array::dtype::extension::ExtDTypeRef::hash( pub trait vortex_array::dtype::extension::ExtDTypePlugin: 'static + core::marker::Send + core::marker::Sync + core::fmt::Debug +pub fn vortex_array::dtype::extension::ExtDTypePlugin::arrow_canonical(&self) -> core::option::Option + pub fn vortex_array::dtype::extension::ExtDTypePlugin::deserialize(&self, data: &[u8], storage_dtype: vortex_array::dtype::DType) -> vortex_error::VortexResult pub fn vortex_array::dtype::extension::ExtDTypePlugin::id(&self) -> vortex_array::dtype::extension::ExtId impl vortex_array::dtype::extension::ExtDTypePlugin for V +pub fn V::arrow_canonical(&self) -> core::option::Option + pub fn V::deserialize(&self, data: &[u8], storage_dtype: vortex_array::dtype::DType) -> core::result::Result pub fn V::id(&self) -> vortex_session::registry::Id @@ -8996,6 +9032,8 @@ pub type vortex_array::dtype::extension::ExtVTable::Metadata: 'static + core::ma pub type vortex_array::dtype::extension::ExtVTable::NativeValue<'a>: core::fmt::Display +pub fn vortex_array::dtype::extension::ExtVTable::arrow_canonical(&self) -> core::option::Option + pub fn vortex_array::dtype::extension::ExtVTable::can_coerce_from(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::dtype::extension::ExtVTable::can_coerce_to(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool @@ -9020,6 +9058,8 @@ pub type vortex_array::extension::datetime::Date::Metadata = vortex_array::exten pub type vortex_array::extension::datetime::Date::NativeValue<'a> = vortex_array::extension::datetime::DateValue +pub fn vortex_array::extension::datetime::Date::arrow_canonical(&self) -> core::option::Option + pub fn vortex_array::extension::datetime::Date::can_coerce_from(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::extension::datetime::Date::can_coerce_to(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool @@ -9044,6 +9084,8 @@ pub type vortex_array::extension::datetime::Time::Metadata = vortex_array::exten pub type vortex_array::extension::datetime::Time::NativeValue<'a> = vortex_array::extension::datetime::TimeValue +pub fn vortex_array::extension::datetime::Time::arrow_canonical(&self) -> core::option::Option + pub fn vortex_array::extension::datetime::Time::can_coerce_from(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::extension::datetime::Time::can_coerce_to(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool @@ -9068,6 +9110,8 @@ pub type vortex_array::extension::datetime::Timestamp::Metadata = vortex_array:: pub type vortex_array::extension::datetime::Timestamp::NativeValue<'a> = vortex_array::extension::datetime::TimestampValue<'a> +pub fn vortex_array::extension::datetime::Timestamp::arrow_canonical(&self) -> core::option::Option + pub fn vortex_array::extension::datetime::Timestamp::can_coerce_from(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::extension::datetime::Timestamp::can_coerce_to(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool @@ -9092,6 +9136,8 @@ pub type vortex_array::extension::uuid::Uuid::Metadata = vortex_array::extension pub type vortex_array::extension::uuid::Uuid::NativeValue<'a> = uuid::Uuid +pub fn vortex_array::extension::uuid::Uuid::arrow_canonical(&self) -> core::option::Option + pub fn vortex_array::extension::uuid::Uuid::can_coerce_from(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::extension::uuid::Uuid::can_coerce_to(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool @@ -9150,35 +9196,17 @@ pub mod vortex_array::dtype::serde pub mod vortex_array::dtype::session -pub struct vortex_array::dtype::session::ArrowCanonicalCodec - -pub vortex_array::dtype::session::ArrowCanonicalCodec::from_json: fn(&str) -> vortex_error::VortexResult> - -pub vortex_array::dtype::session::ArrowCanonicalCodec::to_json: fn(&[u8]) -> vortex_error::VortexResult - -impl core::clone::Clone for vortex_array::dtype::session::ArrowCanonicalCodec - -pub fn vortex_array::dtype::session::ArrowCanonicalCodec::clone(&self) -> vortex_array::dtype::session::ArrowCanonicalCodec - -impl core::fmt::Debug for vortex_array::dtype::session::ArrowCanonicalCodec - -pub fn vortex_array::dtype::session::ArrowCanonicalCodec::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result - -impl core::marker::Copy for vortex_array::dtype::session::ArrowCanonicalCodec - pub struct vortex_array::dtype::session::DTypeSession impl vortex_array::dtype::session::DTypeSession -pub fn vortex_array::dtype::session::DTypeSession::arrow_alias_for(&self, vortex_id: &vortex_array::dtype::extension::ExtId) -> core::option::Option<(vortex_array::dtype::extension::ExtId, vortex_array::dtype::session::ArrowCanonicalCodec)> +pub fn vortex_array::dtype::session::DTypeSession::arrow_alias_for(&self, vortex_id: &vortex_array::dtype::extension::ExtId) -> core::option::Option pub fn vortex_array::dtype::session::DTypeSession::register(&self, vtable: V) -pub fn vortex_array::dtype::session::DTypeSession::register_arrow_canonical(&self, vortex_id: vortex_array::dtype::extension::ExtId, arrow_id: vortex_array::dtype::extension::ExtId, codec: vortex_array::dtype::session::ArrowCanonicalCodec) - pub fn vortex_array::dtype::session::DTypeSession::registry(&self) -> &vortex_array::dtype::session::ExtDTypeRegistry -pub fn vortex_array::dtype::session::DTypeSession::vortex_alias_for(&self, arrow_id: &vortex_array::dtype::extension::ExtId) -> core::option::Option<(vortex_array::dtype::extension::ExtId, vortex_array::dtype::session::ArrowCanonicalCodec)> +pub fn vortex_array::dtype::session::DTypeSession::vortex_alias_for(&self, arrow_id: &vortex_array::dtype::extension::ExtId) -> core::option::Option<(vortex_array::dtype::extension::ExtId, vortex_array::dtype::extension::ArrowCanonicalAlias)> impl core::default::Default for vortex_array::dtype::session::DTypeSession @@ -12994,6 +13022,8 @@ pub type vortex_array::extension::datetime::Date::Metadata = vortex_array::exten pub type vortex_array::extension::datetime::Date::NativeValue<'a> = vortex_array::extension::datetime::DateValue +pub fn vortex_array::extension::datetime::Date::arrow_canonical(&self) -> core::option::Option + pub fn vortex_array::extension::datetime::Date::can_coerce_from(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::extension::datetime::Date::can_coerce_to(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool @@ -13050,6 +13080,8 @@ pub type vortex_array::extension::datetime::Time::Metadata = vortex_array::exten pub type vortex_array::extension::datetime::Time::NativeValue<'a> = vortex_array::extension::datetime::TimeValue +pub fn vortex_array::extension::datetime::Time::arrow_canonical(&self) -> core::option::Option + pub fn vortex_array::extension::datetime::Time::can_coerce_from(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::extension::datetime::Time::can_coerce_to(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool @@ -13108,6 +13140,8 @@ pub type vortex_array::extension::datetime::Timestamp::Metadata = vortex_array:: pub type vortex_array::extension::datetime::Timestamp::NativeValue<'a> = vortex_array::extension::datetime::TimestampValue<'a> +pub fn vortex_array::extension::datetime::Timestamp::arrow_canonical(&self) -> core::option::Option + pub fn vortex_array::extension::datetime::Timestamp::can_coerce_from(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::extension::datetime::Timestamp::can_coerce_to(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool @@ -13190,6 +13224,8 @@ pub type vortex_array::extension::uuid::Uuid::Metadata = vortex_array::extension pub type vortex_array::extension::uuid::Uuid::NativeValue<'a> = uuid::Uuid +pub fn vortex_array::extension::uuid::Uuid::arrow_canonical(&self) -> core::option::Option + pub fn vortex_array::extension::uuid::Uuid::can_coerce_from(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::extension::uuid::Uuid::can_coerce_to(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool @@ -22200,6 +22236,8 @@ impl vortex_array::Array pub fn vortex_array::Array::into_record_batch_with_schema(self, schema: impl core::convert::AsRef) -> vortex_error::VortexResult +pub fn vortex_array::Array::into_record_batch_with_schema_with_session(self, schema: impl core::convert::AsRef, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::Array pub fn vortex_array::Array::remove_column_owned(&self, name: impl core::convert::Into) -> core::option::Option<(Self, vortex_array::ArrayRef)> diff --git a/vortex-array/src/arrow/record_batch.rs b/vortex-array/src/arrow/record_batch.rs index b57c307aed0..b2f7a5e89ba 100644 --- a/vortex-array/src/arrow/record_batch.rs +++ b/vortex-array/src/arrow/record_batch.rs @@ -6,6 +6,7 @@ use arrow_array::cast::AsArray; use arrow_schema::DataType; use arrow_schema::Schema; use vortex_error::VortexResult; +use vortex_session::VortexSession; use crate::LEGACY_SESSION; use crate::VortexSessionExecute; @@ -17,11 +18,22 @@ impl StructArray { pub fn into_record_batch_with_schema( self, schema: impl AsRef, + ) -> VortexResult { + self.into_record_batch_with_schema_with_session(schema, &LEGACY_SESSION) + } + + /// Same as [`Self::into_record_batch_with_schema`], but routes execution through `session` + /// so canonical Arrow extension aliases declared on registered vtables apply uniformly to + /// both schema construction and array conversion. + pub fn into_record_batch_with_schema_with_session( + self, + schema: impl AsRef, + session: &VortexSession, ) -> VortexResult { let data_type = DataType::Struct(schema.as_ref().fields.clone()); let array_ref = self .into_array() - .execute_arrow(Some(&data_type), &mut LEGACY_SESSION.create_execution_ctx())?; + .execute_arrow(Some(&data_type), &mut session.create_execution_ctx())?; Ok(RecordBatch::from(array_ref.as_struct())) } } diff --git a/vortex-array/src/dtype/arrow.rs b/vortex-array/src/dtype/arrow.rs index f55954ff639..4972e075673 100644 --- a/vortex-array/src/dtype/arrow.rs +++ b/vortex-array/src/dtype/arrow.rs @@ -42,9 +42,9 @@ use crate::dtype::FieldName; use crate::dtype::Nullability; use crate::dtype::PType; use crate::dtype::StructFields; +use crate::dtype::extension::ArrowCanonicalCodec; use crate::dtype::extension::ExtDTypeRef; use crate::dtype::extension::ExtId; -use crate::dtype::session::ArrowCanonicalCodec; use crate::dtype::session::DTypeSession; use crate::dtype::session::DTypeSessionExt; use crate::extension::datetime::AnyTemporal; @@ -287,7 +287,7 @@ pub(crate) fn resolve_extension_dtype( let arrow_id = ExtId::new(ext_name); let (ext_id, codec) = match dtypes.vortex_alias_for(&arrow_id) { - Some((vortex_id, codec)) => (vortex_id, Some(codec)), + Some((vortex_id, alias)) => (vortex_id, Some(alias.codec)), None => (arrow_id, None), }; @@ -376,8 +376,10 @@ impl DType { self.to_arrow_schema_with_session(&LEGACY_SESSION) } - /// Convert a Vortex [`DType`] into an Arrow [`Schema`], consulting `session` for Arrow - /// canonical extension aliases registered via [`DTypeSession::register_arrow_canonical`]. + /// Convert a Vortex [`DType`] into an Arrow [`Schema`], consulting `session` for canonical + /// Arrow extension aliases declared by registered vtables via [`ExtVTable::arrow_canonical`]. + /// + /// [`ExtVTable::arrow_canonical`]: crate::dtype::extension::ExtVTable::arrow_canonical pub fn to_arrow_schema_with_session(&self, session: &VortexSession) -> VortexResult { let DType::Struct(struct_dtype, nullable) = self else { vortex_bail!("only DType::Struct can be converted to arrow schema"); @@ -509,9 +511,9 @@ fn field_from_dtype(name: &str, dtype: &DType, dtypes: &DTypeSession) -> VortexR let storage_arrow = arrow_dtype_from_dtype(ext.storage_dtype(), dtypes)?; let ext_meta_bytes = ext.serialize_metadata()?; let (ext_name, meta_str) = match dtypes.arrow_alias_for(&ext.id()) { - Some((canonical, codec)) => ( - canonical.as_str().to_owned(), - (codec.to_json)(&ext_meta_bytes)?, + Some(alias) => ( + alias.arrow_id.as_str().to_owned(), + (alias.codec.to_json)(&ext_meta_bytes)?, ), None => ( ext.id().as_str().to_owned(), diff --git a/vortex-array/src/dtype/extension/plugin.rs b/vortex-array/src/dtype/extension/plugin.rs index 7d242d85274..8eff8591af6 100644 --- a/vortex-array/src/dtype/extension/plugin.rs +++ b/vortex-array/src/dtype/extension/plugin.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use vortex_error::VortexResult; use crate::dtype::DType; +use crate::dtype::extension::ArrowCanonicalAlias; use crate::dtype::extension::ExtDType; use crate::dtype::extension::ExtDTypeRef; use crate::dtype::extension::ExtId; @@ -26,6 +27,9 @@ pub trait ExtDTypePlugin: 'static + Send + Sync + Debug { /// Returns the ID for this extension type. fn id(&self) -> ExtId; + /// See [`ExtVTable::arrow_canonical`]. + fn arrow_canonical(&self) -> Option; + /// Deserialize an extension type from serialized metadata. fn deserialize(&self, data: &[u8], storage_dtype: DType) -> VortexResult; } @@ -35,6 +39,10 @@ impl ExtDTypePlugin for V { ExtVTable::id(self) } + fn arrow_canonical(&self) -> Option { + ExtVTable::arrow_canonical(self) + } + fn deserialize(&self, data: &[u8], storage_dtype: DType) -> VortexResult { let metadata = ExtVTable::deserialize_metadata(self, data)?; Ok(ExtDType::try_with_vtable(self.clone(), metadata, storage_dtype)?.erased()) diff --git a/vortex-array/src/dtype/extension/vtable.rs b/vortex-array/src/dtype/extension/vtable.rs index d4a00fbdec4..a34356bcc6d 100644 --- a/vortex-array/src/dtype/extension/vtable.rs +++ b/vortex-array/src/dtype/extension/vtable.rs @@ -12,6 +12,30 @@ use crate::dtype::extension::ExtDType; use crate::dtype::extension::ExtId; use crate::scalar::ScalarValue; +/// Converters between an extension's on-disk metadata bytes and the canonical Arrow JSON wire. +/// +/// Each Vortex extension that maps to a canonical Arrow extension owns the codec used at the +/// Arrow boundary so [`ExtVTable`] stays Arrow-unaware in the storage path. +#[derive(Copy, Clone, Debug)] +pub struct ArrowCanonicalCodec { + /// Convert raw extension metadata bytes into the JSON string Arrow consumers expect. + pub to_json: fn(&[u8]) -> VortexResult, + /// Parse the JSON string Arrow consumers produce back into raw extension metadata bytes. + pub from_json: fn(&str) -> VortexResult>, +} + +/// Identifies the canonical Arrow extension this Vortex extension serializes as. +/// +/// Returned by [`ExtVTable::arrow_canonical`]. The `arrow_id` is the name written into +/// `ARROW:extension:name`; the `codec` round-trips metadata bytes through Arrow's JSON wire. +#[derive(Copy, Clone, Debug)] +pub struct ArrowCanonicalAlias { + /// The canonical Arrow extension id (e.g. `arrow.fixed_shape_tensor`). + pub arrow_id: ExtId, + /// Converters between Vortex on-disk metadata bytes and Arrow's JSON wire. + pub codec: ArrowCanonicalCodec, +} + /// The public API for defining new extension types. /// /// This is the non-object-safe trait that plugin authors implement to define a new extension type. @@ -28,6 +52,15 @@ pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash { /// Returns the ID for this extension type. fn id(&self) -> ExtId; + /// Optional canonical Arrow extension this type serializes as at the Arrow boundary. + /// + /// Override to map this Vortex extension to a registered canonical Arrow extension + /// (e.g. `arrow.fixed_shape_tensor`). The default `None` means the type round-trips + /// through base64-encoded metadata under its own [`ExtId`]. + fn arrow_canonical(&self) -> Option { + None + } + // Methods related to the extension `DType`. /// Serialize the metadata into a byte vector. diff --git a/vortex-array/src/dtype/session.rs b/vortex-array/src/dtype/session.rs index 0f0db8236c0..7e1d3fcafe8 100644 --- a/vortex-array/src/dtype/session.rs +++ b/vortex-array/src/dtype/session.rs @@ -6,14 +6,12 @@ use std::any::Any; use std::sync::Arc; -use arc_swap::ArcSwap; -use vortex_error::VortexResult; use vortex_session::Ref; use vortex_session::SessionExt; use vortex_session::SessionVar; use vortex_session::registry::Registry; -use vortex_utils::aliases::hash_map::HashMap; +use crate::dtype::extension::ArrowCanonicalAlias; use crate::dtype::extension::ExtDTypePluginRef; use crate::dtype::extension::ExtId; use crate::dtype::extension::ExtVTable; @@ -24,77 +22,16 @@ use crate::extension::datetime::Timestamp; /// Registry for extension dtypes. pub type ExtDTypeRegistry = Registry; -/// Converters between an extension's on-disk metadata bytes and the Arrow canonical JSON wire. -/// -/// Bundled with the alias at registration time so [`ExtVTable`] stays Arrow-unaware. -#[derive(Copy, Clone, Debug)] -pub struct ArrowCanonicalCodec { - pub to_json: fn(&[u8]) -> VortexResult, - pub from_json: fn(&str) -> VortexResult>, -} - -/// Forward map is the canonical source: each Vortex extension owns its codec and points at the -/// Arrow canonical name it serializes as. Reverse map is a lookup index for the read path, -/// taking an Arrow name back to the Vortex id whose codec applies. -#[derive(Default, Clone)] -struct AliasState { - forward: HashMap, - reverse: HashMap, -} - -#[derive(Debug, Default)] -struct ArrowCanonicalAliases(ArcSwap); - -impl ArrowCanonicalAliases { - /// Re-registering evicts any prior alias touching either id so both directions agree. - fn register(&self, vortex_id: ExtId, arrow_id: ExtId, codec: ArrowCanonicalCodec) { - self.0.rcu(|prev| { - let mut next = (**prev).clone(); - if let Some((stale_arrow, _)) = next.forward.remove(&vortex_id) { - next.reverse.remove(&stale_arrow); - } - if let Some(stale_vortex) = next.reverse.remove(&arrow_id) { - next.forward.remove(&stale_vortex); - } - next.forward.insert(vortex_id, (arrow_id, codec)); - next.reverse.insert(arrow_id, vortex_id); - Arc::new(next) - }); - } - - fn arrow_alias_for(&self, vortex_id: &ExtId) -> Option<(ExtId, ArrowCanonicalCodec)> { - self.0.load().forward.get(vortex_id).copied() - } - - fn vortex_alias_for(&self, arrow_id: &ExtId) -> Option<(ExtId, ArrowCanonicalCodec)> { - let state = self.0.load(); - let vortex_id = *state.reverse.get(arrow_id)?; - let (_, codec) = *state.forward.get(&vortex_id)?; - Some((vortex_id, codec)) - } -} - -impl std::fmt::Debug for AliasState { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("AliasState") - .field("forward", &self.forward) - .field("reverse", &self.reverse) - .finish() - } -} - /// Session for managing extension dtypes. #[derive(Debug)] pub struct DTypeSession { registry: ExtDTypeRegistry, - arrow_canonical: ArrowCanonicalAliases, } impl Default for DTypeSession { fn default() -> Self { let this = Self { registry: Registry::default(), - arrow_canonical: ArrowCanonicalAliases::default(), }; this.register(Date); @@ -117,6 +54,9 @@ impl SessionVar for DTypeSession { impl DTypeSession { /// Register an extension DType with the Vortex session. + /// + /// The vtable's [`ExtVTable::arrow_canonical`] is consulted lazily on lookup, so the alias + /// has a single source of truth (the vtable itself) and no per-session bookkeeping. pub fn register(&self, vtable: V) { self.registry .register(vtable.id(), Arc::new(vtable) as ExtDTypePluginRef); @@ -127,25 +67,19 @@ impl DTypeSession { &self.registry } - /// Alias `arrow_id` to `vortex_id` with the codec used at the Arrow boundary. - /// Re-registering evicts the previous mapping for either side. - pub fn register_arrow_canonical( - &self, - vortex_id: ExtId, - arrow_id: ExtId, - codec: ArrowCanonicalCodec, - ) { - self.arrow_canonical.register(vortex_id, arrow_id, codec); - } - - /// Returns the Arrow canonical id and codec aliased to `vortex_id`, if any. - pub fn arrow_alias_for(&self, vortex_id: &ExtId) -> Option<(ExtId, ArrowCanonicalCodec)> { - self.arrow_canonical.arrow_alias_for(vortex_id) + /// Returns the canonical Arrow alias declared by `vortex_id`'s vtable, if any. + pub fn arrow_alias_for(&self, vortex_id: &ExtId) -> Option { + self.registry.find(vortex_id)?.arrow_canonical() } - /// Returns the Vortex id and codec aliased to `arrow_id`, if any. - pub fn vortex_alias_for(&self, arrow_id: &ExtId) -> Option<(ExtId, ArrowCanonicalCodec)> { - self.arrow_canonical.vortex_alias_for(arrow_id) + /// Returns the Vortex id and alias for canonical Arrow extension `arrow_id`, if a vtable + /// declaring that alias is registered. Linear scan over registered vtables — cheap given + /// the small number of extensions in practice. + pub fn vortex_alias_for(&self, arrow_id: &ExtId) -> Option<(ExtId, ArrowCanonicalAlias)> { + self.registry.items().find_map(|plugin| { + let alias = plugin.arrow_canonical()?; + (alias.arrow_id == *arrow_id).then(|| (plugin.id(), alias)) + }) } } @@ -160,86 +94,3 @@ impl DTypeSessionExt for S { self.get::() } } - -#[cfg(test)] -mod tests { - use vortex_error::vortex_err; - - use super::*; - - const TEST_CODEC: ArrowCanonicalCodec = ArrowCanonicalCodec { - to_json: |bytes| { - String::from_utf8(bytes.to_vec()).map_err(|e| vortex_err!("non-utf8 test bytes: {e}")) - }, - from_json: |s| Ok(s.as_bytes().to_vec()), - }; - - #[test] - fn arrow_canonical_re_registration_is_clean() { - let session = DTypeSession::default(); - let v = ExtId::new("vortex.test"); - let foo = ExtId::new("arrow.foo"); - let bar = ExtId::new("arrow.bar"); - - session.register_arrow_canonical(v, foo, TEST_CODEC); - assert_eq!(session.arrow_alias_for(&v).map(|(id, _)| id), Some(foo)); - assert_eq!(session.vortex_alias_for(&foo).map(|(id, _)| id), Some(v)); - - session.register_arrow_canonical(v, bar, TEST_CODEC); - assert_eq!(session.arrow_alias_for(&v).map(|(id, _)| id), Some(bar)); - assert_eq!(session.vortex_alias_for(&bar).map(|(id, _)| id), Some(v)); - assert!(session.vortex_alias_for(&foo).is_none()); - } - - /// `(vid → old, old → vid)` then `register(vid, new)` should leave `(vid → new, new → vid)`. - #[test] - fn rebind_vortex_id_to_new_arrow_name() { - let session = DTypeSession::default(); - let vid = ExtId::new("vortex.a"); - let old = ExtId::new("arrow.b"); - let new = ExtId::new("arrow.c"); - - session.register_arrow_canonical(vid, old, TEST_CODEC); - assert_eq!(session.arrow_alias_for(&vid).map(|(id, _)| id), Some(old)); - assert_eq!(session.vortex_alias_for(&old).map(|(id, _)| id), Some(vid)); - - session.register_arrow_canonical(vid, new, TEST_CODEC); - - assert_eq!(session.arrow_alias_for(&vid).map(|(id, _)| id), Some(new)); - assert_eq!(session.vortex_alias_for(&new).map(|(id, _)| id), Some(vid)); - assert!(session.vortex_alias_for(&old).is_none()); - } - - /// `(old → name, name → old)` then `register(new, name)` should leave `(new → name, name → new)`. - #[test] - fn steal_arrow_name_from_another_vortex_id() { - let session = DTypeSession::default(); - let old = ExtId::new("vortex.a"); - let name = ExtId::new("arrow.b"); - let new = ExtId::new("vortex.c"); - - session.register_arrow_canonical(old, name, TEST_CODEC); - assert_eq!(session.arrow_alias_for(&old).map(|(id, _)| id), Some(name)); - - session.register_arrow_canonical(new, name, TEST_CODEC); - - assert_eq!(session.arrow_alias_for(&new).map(|(id, _)| id), Some(name)); - assert_eq!(session.vortex_alias_for(&name).map(|(id, _)| id), Some(new)); - assert!(session.arrow_alias_for(&old).is_none()); - } - - #[test] - fn codec_round_trips_through_lookup() { - let session = DTypeSession::default(); - let vid = ExtId::new("vortex.x"); - let aid = ExtId::new("arrow.x"); - - session.register_arrow_canonical(vid, aid, TEST_CODEC); - - let (_, codec) = session.arrow_alias_for(&vid).unwrap(); - let json = (codec.to_json)(b"hello").unwrap(); - assert_eq!(json, "hello"); - let bytes = (codec.from_json)(&json).unwrap(); - assert_eq!(bytes, b"hello"); - } -} diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 7300bb399e6..79857dcdb3b 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -126,6 +126,8 @@ pub type vortex_tensor::fixed_shape::FixedShapeTensor::Metadata = vortex_tensor: pub type vortex_tensor::fixed_shape::FixedShapeTensor::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue +pub fn vortex_tensor::fixed_shape::FixedShapeTensor::arrow_canonical(&self) -> core::option::Option + pub fn vortex_tensor::fixed_shape::FixedShapeTensor::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult pub fn vortex_tensor::fixed_shape::FixedShapeTensor::id(&self) -> vortex_array::dtype::extension::ExtId diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 28ee8795825..963e130a7b4 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -11,8 +11,6 @@ )] use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; -use vortex_array::dtype::extension::ExtVTable; -use vortex_array::dtype::session::ArrowCanonicalCodec; use vortex_array::dtype::session::DTypeSessionExt; use vortex_array::scalar_fn::session::ScalarFnSessionExt; use vortex_array::session::ArraySessionExt; @@ -52,14 +50,6 @@ pub fn initialize(session: &VortexSession) { let dtypes = session.dtypes(); dtypes.register(Vector); dtypes.register(FixedShapeTensor); - dtypes.register_arrow_canonical( - FixedShapeTensor.id(), - FixedShapeTensor::arrow_ext_id(), - ArrowCanonicalCodec { - to_json: fixed_shape::proto_to_json, - from_json: fixed_shape::json_to_proto, - }, - ); // Release the shard read before `scalar_fns` may take a write on the same shard. drop(dtypes); diff --git a/vortex-tensor/src/tests/arrow_roundtrip.rs b/vortex-tensor/src/tests/arrow_roundtrip.rs index baa17a480ca..068f299c6cf 100644 --- a/vortex-tensor/src/tests/arrow_roundtrip.rs +++ b/vortex-tensor/src/tests/arrow_roundtrip.rs @@ -164,7 +164,9 @@ fn vector_record_batch_round_trip_carries_field_metadata() { let dtype = DType::struct_([("embedding", vector_dtype(4))], Nullability::NonNullable); let schema = dtype.to_arrow_schema_with_session(&SESSION).unwrap(); - let rb = struct_array.into_record_batch_with_schema(&schema).unwrap(); + let rb = struct_array + .into_record_batch_with_schema_with_session(&schema, &SESSION) + .unwrap(); let column = rb.column(0); let DataType::FixedSizeList(_, size) = column.data_type() else { @@ -228,7 +230,9 @@ fn vector_record_batch_round_trip() { let dtype = DType::struct_([("embedding", vector_dtype(4))], Nullability::NonNullable); let schema = dtype.to_arrow_schema_with_session(&SESSION).unwrap(); - let rb = original.into_record_batch_with_schema(&schema).unwrap(); + let rb = original + .into_record_batch_with_schema_with_session(&schema, &SESSION) + .unwrap(); let recovered = ArrayRef::from_arrow_with_session(rb, false, &SESSION).unwrap(); assert_eq!(recovered.dtype(), &dtype); @@ -248,7 +252,9 @@ fn fixed_shape_tensor_record_batch_round_trip() { .unwrap() .into_array(); let original = StructArray::from_fields(&[("tensor", tensor_array)]).unwrap(); - let rb = original.into_record_batch_with_schema(&schema).unwrap(); + let rb = original + .into_record_batch_with_schema_with_session(&schema, &SESSION) + .unwrap(); let recovered = ArrayRef::from_arrow_with_session(rb, false, &SESSION).unwrap(); assert_eq!(recovered.dtype(), &dtype); diff --git a/vortex-tensor/src/types/fixed_shape/mod.rs b/vortex-tensor/src/types/fixed_shape/mod.rs index 94d91e74095..64285228986 100644 --- a/vortex-tensor/src/types/fixed_shape/mod.rs +++ b/vortex-tensor/src/types/fixed_shape/mod.rs @@ -27,5 +27,3 @@ pub use metadata::FixedShapeTensorMetadata; mod canonical; mod proto; mod vtable; -pub(crate) use canonical::json_to_proto; -pub(crate) use canonical::proto_to_json; diff --git a/vortex-tensor/src/types/fixed_shape/vtable.rs b/vortex-tensor/src/types/fixed_shape/vtable.rs index 97eadbb55fb..2648fc559b3 100644 --- a/vortex-tensor/src/types/fixed_shape/vtable.rs +++ b/vortex-tensor/src/types/fixed_shape/vtable.rs @@ -2,6 +2,8 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_array::dtype::DType; +use vortex_array::dtype::extension::ArrowCanonicalAlias; +use vortex_array::dtype::extension::ArrowCanonicalCodec; use vortex_array::dtype::extension::ExtDType; use vortex_array::dtype::extension::ExtId; use vortex_array::dtype::extension::ExtVTable; @@ -14,6 +16,8 @@ use vortex_session::registry::CachedId; use crate::types::fixed_shape::FixedShapeTensor; use crate::types::fixed_shape::FixedShapeTensorMetadata; +use crate::types::fixed_shape::canonical::json_to_proto; +use crate::types::fixed_shape::canonical::proto_to_json; use crate::types::fixed_shape::proto; /// Vortex extension id for [`FixedShapeTensor`]. @@ -29,6 +33,16 @@ impl ExtVTable for FixedShapeTensor { *ID } + fn arrow_canonical(&self) -> Option { + Some(ArrowCanonicalAlias { + arrow_id: Self::arrow_ext_id(), + codec: ArrowCanonicalCodec { + to_json: proto_to_json, + from_json: json_to_proto, + }, + }) + } + fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { Ok(proto::serialize(metadata)) }