From 539503f56a6c763dcd0109c39edd6c665373d0ee Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Tue, 28 Apr 2026 15:11:40 -0400 Subject: [PATCH 1/6] add `NormalizedVector` extension type Signed-off-by: Connor Tsui --- vortex-tensor/public-api.lock | 68 +- vortex-tensor/src/encodings/l2_denorm.rs | 74 +- .../src/encodings/turboquant/compress.rs | 228 +++--- vortex-tensor/src/encodings/turboquant/mod.rs | 2 +- .../src/encodings/turboquant/tests/mod.rs | 7 +- .../encodings/turboquant/tests/roundtrip.rs | 16 +- .../encodings/turboquant/tests/structural.rs | 2 +- vortex-tensor/src/lib.rs | 3 + vortex-tensor/src/matcher.rs | 21 +- .../src/scalar_fns/cosine_similarity.rs | 151 ++-- vortex-tensor/src/scalar_fns/inner_product.rs | 217 +++-- vortex-tensor/src/scalar_fns/l2_denorm.rs | 746 +++++++++++------- vortex-tensor/src/scalar_fns/l2_norm.rs | 148 +++- .../src/scalar_fns/sorf_transform/mod.rs | 29 +- .../src/scalar_fns/sorf_transform/tests.rs | 142 ++++ .../src/scalar_fns/sorf_transform/vtable.rs | 30 +- vortex-tensor/src/types/mod.rs | 1 + .../src/types/normalized_vector/matcher.rs | 124 +++ .../src/types/normalized_vector/mod.rs | 192 +++++ .../src/types/normalized_vector/vtable.rs | 129 +++ vortex-tensor/src/types/vector/matcher.rs | 5 + vortex-tensor/src/utils.rs | 50 +- vortex/benches/single_encoding_throughput.rs | 16 +- 23 files changed, 1812 insertions(+), 589 deletions(-) create mode 100644 vortex-tensor/src/types/normalized_vector/matcher.rs create mode 100644 vortex-tensor/src/types/normalized_vector/mod.rs create mode 100644 vortex-tensor/src/types/normalized_vector/vtable.rs diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 7300bb399e6..5a30355caac 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -82,7 +82,7 @@ pub fn vortex_tensor::encodings::turboquant::tq_validate_vector_dtype(dtype: &vo pub fn vortex_tensor::encodings::turboquant::turboquant_encode(input: 
vortex_array::array::erased::ArrayRef, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult -pub unsafe fn vortex_tensor::encodings::turboquant::turboquant_encode_unchecked(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_tensor::encodings::turboquant::turboquant_encode_normalized(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub mod vortex_tensor::fixed_shape @@ -218,12 +218,16 @@ pub enum vortex_tensor::matcher::TensorMatch<'a> pub vortex_tensor::matcher::TensorMatch::FixedShapeTensor(vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>) +pub vortex_tensor::matcher::TensorMatch::NormalizedVector(vortex_tensor::vector::VectorMatcherMetadata) + pub vortex_tensor::matcher::TensorMatch::Vector(vortex_tensor::vector::VectorMatcherMetadata) impl vortex_tensor::matcher::TensorMatch<'_> pub fn vortex_tensor::matcher::TensorMatch<'_>::element_ptype(self) -> vortex_array::dtype::ptype::PType +pub fn vortex_tensor::matcher::TensorMatch<'_>::is_normalized(self) -> bool + pub fn vortex_tensor::matcher::TensorMatch<'_>::list_size(self) -> u32 impl<'a> core::clone::Clone for vortex_tensor::matcher::TensorMatch<'a> @@ -252,6 +256,66 @@ pub type vortex_tensor::matcher::AnyTensor::Match<'a> = vortex_tensor::matcher:: pub fn vortex_tensor::matcher::AnyTensor::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option +pub mod vortex_tensor::normalized_vector + +pub struct vortex_tensor::normalized_vector::AnyNormalizedVector + +impl 
vortex_array::dtype::extension::matcher::Matcher for vortex_tensor::normalized_vector::AnyNormalizedVector + +pub type vortex_tensor::normalized_vector::AnyNormalizedVector::Match<'a> = vortex_tensor::vector::VectorMatcherMetadata + +pub fn vortex_tensor::normalized_vector::AnyNormalizedVector::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option + +pub struct vortex_tensor::normalized_vector::NormalizedVector + +impl vortex_tensor::normalized_vector::NormalizedVector + +pub unsafe fn vortex_tensor::normalized_vector::NormalizedVector::new_unchecked(fsl: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_tensor::normalized_vector::NormalizedVector::try_new(fsl: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub unsafe fn vortex_tensor::normalized_vector::NormalizedVector::wrap_vector_unchecked(vector: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult + +impl core::clone::Clone for vortex_tensor::normalized_vector::NormalizedVector + +pub fn vortex_tensor::normalized_vector::NormalizedVector::clone(&self) -> vortex_tensor::normalized_vector::NormalizedVector + +impl core::cmp::Eq for vortex_tensor::normalized_vector::NormalizedVector + +impl core::cmp::PartialEq for vortex_tensor::normalized_vector::NormalizedVector + +pub fn vortex_tensor::normalized_vector::NormalizedVector::eq(&self, other: &vortex_tensor::normalized_vector::NormalizedVector) -> bool + +impl core::default::Default for vortex_tensor::normalized_vector::NormalizedVector + +pub fn vortex_tensor::normalized_vector::NormalizedVector::default() -> vortex_tensor::normalized_vector::NormalizedVector + +impl core::fmt::Debug for vortex_tensor::normalized_vector::NormalizedVector + +pub fn vortex_tensor::normalized_vector::NormalizedVector::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash 
for vortex_tensor::normalized_vector::NormalizedVector + +pub fn vortex_tensor::normalized_vector::NormalizedVector::hash<__H: core::hash::Hasher>(&self, state: &mut __H) + +impl core::marker::StructuralPartialEq for vortex_tensor::normalized_vector::NormalizedVector + +impl vortex_array::dtype::extension::vtable::ExtVTable for vortex_tensor::normalized_vector::NormalizedVector + +pub type vortex_tensor::normalized_vector::NormalizedVector::Metadata = vortex_array::extension::EmptyMetadata + +pub type vortex_tensor::normalized_vector::NormalizedVector::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue + +pub fn vortex_tensor::normalized_vector::NormalizedVector::deserialize_metadata(&self, _metadata: &[u8]) -> vortex_error::VortexResult + +pub fn vortex_tensor::normalized_vector::NormalizedVector::id(&self) -> vortex_array::dtype::extension::ExtId + +pub fn vortex_tensor::normalized_vector::NormalizedVector::serialize_metadata(&self, _metadata: &Self::Metadata) -> vortex_error::VortexResult> + +pub fn vortex_tensor::normalized_vector::NormalizedVector::unpack_native<'a>(_ext_dtype: &'a vortex_array::dtype::extension::typed::ExtDType, storage_value: &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult + +pub fn vortex_tensor::normalized_vector::NormalizedVector::validate_dtype(ext_dtype: &vortex_array::dtype::extension::typed::ExtDType) -> vortex_error::VortexResult<()> + pub mod vortex_tensor::scalar_fns pub mod vortex_tensor::scalar_fns::cosine_similarity @@ -384,8 +448,6 @@ pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::validity(&self, _options: pub fn vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm(input: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult -pub fn vortex_tensor::scalar_fns::l2_denorm::validate_l2_normalized_rows_against_norms(normalized: &vortex_array::array::erased::ArrayRef, norms: 
core::option::Option<&vortex_array::array::erased::ArrayRef>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<()> - pub mod vortex_tensor::scalar_fns::l2_norm pub struct vortex_tensor::scalar_fns::l2_norm::L2Norm diff --git a/vortex-tensor/src/encodings/l2_denorm.rs b/vortex-tensor/src/encodings/l2_denorm.rs index 6cb4fcb0626..68b8c1b31d8 100644 --- a/vortex-tensor/src/encodings/l2_denorm.rs +++ b/vortex-tensor/src/encodings/l2_denorm.rs @@ -14,8 +14,8 @@ use vortex_compressor::scheme::Scheme; use vortex_compressor::stats::ArrayAndStats; use vortex_error::VortexResult; -use crate::matcher::AnyTensor; use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; +use crate::types::vector::AnyVector; #[derive(Debug)] pub struct L2DenormScheme; @@ -26,10 +26,14 @@ impl Scheme for L2DenormScheme { } fn matches(&self, canonical: &Canonical) -> bool { - matches!( - canonical, - Canonical::Extension(ext) if ext.ext_dtype().is::() - ) + let Canonical::Extension(ext) = canonical else { + return false; + }; + + // `AnyVector` is the strict matcher for plain `Vector` only, so a `NormalizedVector` + // input is naturally excluded here (it would already carry an authoritative unit-norm + // representation and does not need re-normalization). + ext.ext_dtype().is::() } fn expected_compression_ratio( @@ -38,6 +42,7 @@ impl Scheme for L2DenormScheme { _compress_ctx: CompressorContext, _exec_ctx: &mut ExecutionCtx, ) -> CompressionEstimate { + // We almost always want to pre-normalize our data if the vector is not already normalized. 
CompressionEstimate::Verdict(EstimateVerdict::AlwaysUse) } @@ -52,3 +57,62 @@ impl Scheme for L2DenormScheme { Ok(l2_denorm.into_array()) } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use vortex_array::Canonical; + use vortex_array::IntoArray; + use vortex_array::arrays::ExtensionArray; + use vortex_array::arrays::FixedSizeListArray; + use vortex_array::arrays::PrimitiveArray; + use vortex_array::dtype::DType; + use vortex_array::dtype::Nullability; + use vortex_array::dtype::PType; + use vortex_array::dtype::extension::ExtDType; + use vortex_array::extension::EmptyMetadata; + use vortex_array::validity::Validity; + use vortex_compressor::scheme::Scheme; + use vortex_error::VortexResult; + + use super::L2DenormScheme; + use crate::types::fixed_shape::FixedShapeTensor; + use crate::types::fixed_shape::FixedShapeTensorMetadata; + use crate::types::vector::Vector; + + fn fsl_storage(elements: &[f32], list_size: u32) -> VortexResult { + let len = elements.len() / list_size as usize; + let elements = PrimitiveArray::from_iter(elements.iter().copied()).into_array(); + FixedSizeListArray::try_new(elements, list_size, Validity::NonNullable, len) + } + + #[test] + fn matches_vector() -> VortexResult<()> { + let fsl = fsl_storage(&[1.0, 0.0], 2)?; + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + let canonical = Canonical::Extension(ExtensionArray::new(ext_dtype, fsl.into_array())); + + assert!(L2DenormScheme.matches(&canonical)); + Ok(()) + } + + #[test] + fn rejects_fixed_shape_tensor() -> VortexResult<()> { + let fsl = fsl_storage(&[1.0, 0.0, 0.0, 1.0], 4)?; + let storage_dtype = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)), + 4, + Nullability::NonNullable, + ); + let ext_dtype = ExtDType::::try_new( + FixedShapeTensorMetadata::new(vec![2, 2]), + storage_dtype, + )? 
+ .erased(); + let canonical = Canonical::Extension(ExtensionArray::new(ext_dtype, fsl.into_array())); + + assert!(!L2DenormScheme.matches(&canonical)); + Ok(()) + } +} diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index ca32faa6ec9..0a1b4dded13 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -15,6 +15,7 @@ use vortex_array::ArrayView; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::Extension; +use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::dict::DictArray; @@ -34,13 +35,13 @@ use crate::encodings::turboquant::MIN_DIMENSION; use crate::encodings::turboquant::centroids::compute_centroid_boundaries; use crate::encodings::turboquant::centroids::compute_or_get_centroids; use crate::encodings::turboquant::centroids::find_nearest_centroid; +use crate::normalized_vector::AnyNormalizedVector; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfOptions; use crate::scalar_fns::sorf_transform::SorfTransform; -use crate::types::vector::AnyVector; -use crate::types::vector::Vector; +use crate::types::normalized_vector::NormalizedVector; use crate::utils::cast_to_f32; /// Configuration for TurboQuant encoding. 
@@ -66,7 +67,7 @@ impl Default for TurboQuantConfig { /// Apply the full TurboQuant compression pipeline to a [`Vector`](crate::vector::Vector) /// extension array: normalize the rows via [`normalize_as_l2_denorm`], quantize the normalized -/// child via [`turboquant_encode_unchecked`], and reattach the stored norms as the outer +/// child via [`turboquant_encode_normalized`], and reattach the stored norms as the outer /// [`L2Denorm`] wrapper. /// /// The returned array has the canonical TurboQuant shape: @@ -80,8 +81,8 @@ impl Default for TurboQuantConfig { /// /// # Errors /// -/// Returns an error if `input` is not a tensor-like extension array, if normalization fails, or -/// if [`turboquant_encode_unchecked`] rejects the input shape. +/// Returns an error if `input` is not a tensor-like extension array, if normalization fails, or if +/// [`turboquant_encode_normalized`] rejects the input shape. pub fn turboquant_encode( input: ArrayRef, config: &TurboQuantConfig, @@ -89,6 +90,8 @@ pub fn turboquant_encode( ) -> VortexResult { // We must normalize the array before we can encode it with TurboQuant. let l2_denorm = normalize_as_l2_denorm(input, ctx)?; + + // This is guaranteed to be a `NormalizedVector` extension type. let normalized = l2_denorm.child_at(0).clone(); let norms = l2_denorm.child_at(1).clone(); let num_rows = l2_denorm.len(); @@ -97,55 +100,56 @@ pub fn turboquant_encode( .as_opt::() .vortex_expect("normalize_as_l2_denorm always produces an Extension array child"); - // SAFETY: `normalize_as_l2_denorm` guarantees every row is unit-norm (or zero for null rows). - let tq = unsafe { turboquant_encode_unchecked(normalized_ext, config, ctx) }?; + let tq = turboquant_encode_normalized(normalized_ext, config, ctx)?; // SAFETY: TurboQuant is a lossy approximation of the normalized child, so we intentionally - // bypass the strict normalized-row validation when reattaching the stored norms. 
+ // bypass the strict normalized-row and zero-row validation when reattaching the stored norms. Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array()) } -/// Encode a non-nullable, L2-normalized [`Vector`](crate::vector::Vector) extension array into a -/// `ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))])`, without validating the unit-norm -/// precondition. -/// -/// # Safety -/// -/// The caller must ensure: -/// -/// - The input dtype is non-nullable. -/// - Every row is L2-normalized (unit norm) or is a zero vector. +/// Encode a non-nullable [`NormalizedVector`](crate::normalized_vector::NormalizedVector) +/// extension array into +/// a `ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))])`, without validating the +/// unit-norm precondition. /// /// Passing non-unit-norm vectors will not cause memory unsafety, but will produce silently /// incorrect quantization results. -pub unsafe fn turboquant_encode_unchecked( +pub fn turboquant_encode_normalized( ext: ArrayView, config: &TurboQuantConfig, ctx: &mut ExecutionCtx, ) -> VortexResult { let ext_dtype = ext.dtype().clone(); - let storage = ext.storage_array(); - let fsl = storage.clone().execute::(ctx)?; + + let vector_metadata = ext_dtype.as_extension().metadata::(); + let element_ptype = vector_metadata.element_ptype(); + let dimensions = vector_metadata.dimensions(); + + // `NormalizedVector` storage is `Extension(Vector(FSL))`; drill past the inner `Vector` to + // reach the underlying `FixedSizeList`. 
+ let inner_vector: ExtensionArray = ext.storage_array().clone().execute(ctx)?; + let fsl: FixedSizeListArray = inner_vector.storage_array().clone().execute(ctx)?; vortex_ensure!( config.bit_width >= 1 && config.bit_width <= MAX_BIT_WIDTH, "bit_width must be 1-{MAX_BIT_WIDTH}, got {}", config.bit_width ); - let dimension = fsl.list_size(); vortex_ensure!( - dimension >= MIN_DIMENSION, - "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimension}", + dimensions >= MIN_DIMENSION, + "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}", ); - let vector_metadata = ext_dtype.as_extension().metadata::(); - let element_ptype = vector_metadata.element_ptype(); - - let seed = config.seed; let num_rows = fsl.len(); + let sorf_options = SorfOptions { + seed: config.seed, + num_rounds: config.num_rounds, + dimensions, + element_ptype, + }; if fsl.is_empty() { - let padded_dim = dimension.next_power_of_two(); + let padded_dim = dimensions.next_power_of_two(); let empty_codes = PrimitiveArray::empty::(Nullability::NonNullable); let empty_centroids = PrimitiveArray::empty::(Nullability::NonNullable); let empty_dict = @@ -156,77 +160,128 @@ pub unsafe fn turboquant_encode_unchecked( Validity::NonNullable, 0, )?; - let empty_padded_vector = Vector::try_new_vector_array(empty_fsl.into_array())?; - - let sorf_options = SorfOptions { - seed, - num_rounds: config.num_rounds, - dimensions: dimension, - element_ptype, - }; + // SAFETY: An empty FSL contains no rows, so the unit-norm-or-zero invariant holds + // vacuously. 
+ let empty_padded_vector = + unsafe { NormalizedVector::new_unchecked(empty_fsl.into_array()) }?; + return Ok( SorfTransform::try_new_array(&sorf_options, empty_padded_vector, 0)?.into_array(), ); } - let core = turboquant_quantize_core(&fsl, seed, config.bit_width, config.num_rounds, ctx)?; - let quantized_fsl = - build_quantized_fsl(num_rows, core.all_indices, core.centroids, core.padded_dim)?; - let padded_vector = Vector::try_new_vector_array(quantized_fsl)?; + let quantized_fsl = turboquant_quantize_fsl(&fsl, config.bit_width, &sorf_options, ctx)?; - let sorf_options = SorfOptions { - seed, - num_rounds: config.num_rounds, - dimensions: dimension, - element_ptype, - }; - Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array()) -} + // NB: The quantized rows are approximately unit-norm by construction; downstream callers + // (notably the enclosing `L2Denorm` wrapper) treat the stored-norm + NormalizedVector claim as + // authoritative rather than decode-verified. + + // SAFETY: TurboQuant is a lossy approximation of the already-unit-norm input. + let padded_vector = unsafe { NormalizedVector::new_unchecked(quantized_fsl) }?; -/// Shared intermediate results from the quantization loop. -struct QuantizationResult { - centroids: Buffer, - all_indices: Buffer, - padded_dim: usize, + Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array()) } -/// Core quantization: rotate and quantize already-normalized rows. +/// Rotate and quantize already-normalized rows into a dict-encoded `FixedSizeList`. +/// +/// The input `fsl` must contain non-nullable, unit-norm vectors of float values (already +/// L2-normalized). Null vectors are not supported and must be zeroed out before reaching this +/// function. The rotation and centroid lookup happen in f32. +/// +/// The returned array is `FSL(DictArray(codes, centroids), padded_dim)`. 
 The `FixedSizeList` has +/// Dict-encoded elements, where each row of `padded_dim` u8 codes indexes into the centroid +/// codebook. +/// +/// This allows the FSL (via the Dict-encoded elements) to be independently sliced, taken, or +/// executed (dequantized) without knowledge of the rotation. /// -/// The input `fsl` must contain non-nullable, unit-norm vectors (already L2-normalized). Null -/// vectors are not supported and must be zeroed out before reaching this function. The rotation -/// and centroid lookup happen in f32. -fn turboquant_quantize_core( +/// Internally, this function: +/// +/// 1. Builds a [`SorfMatrix`] structured rotation from the seed/rounds in `sorf_options`. +/// 2. For each row, zero-pads to the next power of 2, applies the rotation, and maps each rotated +/// coordinate to its nearest centroid index via binary search on precomputed boundaries. +/// 3. Packs the per-row centroid indices and the shared centroid codebook into a `DictArray`-backed +/// `FixedSizeListArray`. +fn turboquant_quantize_fsl( fsl: &FixedSizeListArray, - seed: u64, bit_width: u8, - num_rounds: u8, + sorf_options: &SorfOptions, ctx: &mut ExecutionCtx, -) -> VortexResult { - let dimension = fsl.list_size() as usize; +) -> VortexResult { + let dimensions = fsl.list_size() as usize; let num_rows = fsl.len(); - let rotation = SorfMatrix::try_new(seed, dimension, num_rounds as usize)?; + vortex_ensure!(!fsl.dtype().is_nullable()); + + let rotation = SorfMatrix::try_new( + sorf_options.seed, + dimensions, + sorf_options.num_rounds as usize, + )?; let padded_dim = rotation.padded_dim(); let padded_dim_u32 = u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32"); + // Compute the centroids for the given (dimension, bit_width) combination (or retrieve it from a + // previous computation) + let centroids = compute_or_get_centroids(padded_dim_u32, bit_width)?; + + // Extract out the elements of the FSL and cast to f32.
 In the f64 case, we intentionally lose + // information here because we are already going to be quantizing to a smaller set of centroids, + // so we are fine with this loss. let elements_prim: PrimitiveArray = fsl.elements().clone().execute(ctx)?; let f32_elements = cast_to_f32(elements_prim)?; - let centroids = compute_or_get_centroids(padded_dim_u32, bit_width)?; - let boundaries = compute_centroid_boundaries(&centroids); + // Take the float values and quantize by finding the closest centroid in the codebook to each + // and recording the index of that centroid. + let all_indices = rotate_and_quantize( + f32_elements.as_slice(), + num_rows, + dimensions, + &rotation, + &centroids, + ); + + // Build the Dict-encoded FSL from the centroid indices and codebook. Everything is non-null + // since our input is non-null. + let codes = PrimitiveArray::new::(all_indices, Validity::NonNullable); + let values = PrimitiveArray::new::(centroids, Validity::NonNullable); + let dict = DictArray::try_new(codes.into_array(), values.into_array())?; + + Ok(FixedSizeListArray::try_new( + dict.into_array(), + padded_dim_u32, + Validity::NonNullable, + num_rows, + )? + .into_array()) +} + +/// Rotate each row via the structured rotation and quantize every rotated coordinate to its nearest +/// centroid index via binary search on precomputed boundaries. +/// +/// Returns a flat [`Buffer`] of length `num_rows * padded_dim` containing the per-coordinate +/// centroid indices.
+fn rotate_and_quantize( + f32_slice: &[f32], + num_rows: usize, + dimensions: usize, + rotation: &SorfMatrix, + centroids: &[f32], +) -> Buffer { + let padded_dim = rotation.padded_dim(); + let boundaries = compute_centroid_boundaries(centroids); let mut all_indices = BufferMut::::with_capacity(num_rows * padded_dim); let mut padded = vec![0.0f32; padded_dim]; let mut rotated = vec![0.0f32; padded_dim]; - let f32_slice = f32_elements.as_slice(); for row in 0..num_rows { - let x = &f32_slice[row * dimension..(row + 1) * dimension]; + let x = &f32_slice[row * dimensions..][..dimensions]; // Zero-pad to the next power of 2. - padded[..dimension].copy_from_slice(x); - padded[dimension..].fill(0.0); + padded[..dimensions].copy_from_slice(x); + padded[dimensions..].fill(0.0); rotation.rotate(&padded, &mut rotated); @@ -235,36 +290,5 @@ fn turboquant_quantize_core( } } - Ok(QuantizationResult { - centroids, - all_indices: all_indices.freeze(), - padded_dim, - }) -} - -/// Build a quantized representation: `FSL(DictArray(codes, centroids), padded_dim)`. -/// -/// This is a Dict-encoded FixedSizeList where each row of `padded_dim` u8 codes indexes into the -/// centroid codebook. The Dict can be independently sliced, taken, or executed (dequantized) -/// without knowledge of the rotation. -fn build_quantized_fsl( - num_rows: usize, - all_indices: Buffer, - centroids: Buffer, - padded_dim: usize, -) -> VortexResult { - let codes = PrimitiveArray::new::(all_indices, Validity::NonNullable); - let centroids_array = PrimitiveArray::new::(centroids, Validity::NonNullable); - - let dict = DictArray::try_new(codes.into_array(), centroids_array.into_array())?; - - let padded_dim_u32 = - u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32"); - Ok(FixedSizeListArray::try_new( - dict.into_array(), - padded_dim_u32, - Validity::NonNullable, - num_rows, - )? 
- .into_array()) + all_indices.freeze() } diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs index 50cef7b721e..bc404a8843e 100644 --- a/vortex-tensor/src/encodings/turboquant/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -134,7 +134,7 @@ pub(crate) mod compress; mod scheme; pub use compress::TurboQuantConfig; pub use compress::turboquant_encode; -pub use compress::turboquant_encode_unchecked; +pub use compress::turboquant_encode_normalized; pub use scheme::TurboQuantScheme; /// Minimum vector dimension for TurboQuant encoding. diff --git a/vortex-tensor/src/encodings/turboquant/tests/mod.rs b/vortex-tensor/src/encodings/turboquant/tests/mod.rs index ec4182dcc3d..4a21c40f0ba 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/mod.rs @@ -91,9 +91,10 @@ fn unwrap_codes_centroids_norms( .child_at(0) .clone(); - // Vector wrapping FSL(Dict(codes, centroids)) - let padded_vector: ExtensionArray = padded_vector_child.execute(ctx)?; - let fsl: FixedSizeListArray = padded_vector.storage_array().clone().execute(ctx)?; + // NormalizedVector wrapping Vector wrapping FSL(Dict(codes, centroids)). 
+ let normalized_vector: ExtensionArray = padded_vector_child.execute(ctx)?; + let inner_vector: ExtensionArray = normalized_vector.storage_array().clone().execute(ctx)?; + let fsl: FixedSizeListArray = inner_vector.storage_array().clone().execute(ctx)?; let dict = fsl .elements() .as_opt::() diff --git a/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs b/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs index d82be3cf714..3a0d5a061ee 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs @@ -11,10 +11,9 @@ use vortex_array::arrays::PrimitiveArray; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_error::VortexResult; -use vortex_error::vortex_err; use super::*; -use crate::encodings::turboquant::turboquant_encode_unchecked; +use crate::encodings::turboquant::turboquant_encode_normalized; use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; #[rstest] @@ -185,7 +184,7 @@ fn rejects_invalid_bit_width(#[case] bit_width: u8) { let normalized_ext = normalized .as_opt::() .expect("normalized child should be Extension"); - assert!(unsafe { turboquant_encode_unchecked(normalized_ext, &config, &mut ctx) }.is_err()); + assert!(turboquant_encode_normalized(normalized_ext, &config, &mut ctx).is_err()); } #[test] @@ -196,7 +195,8 @@ fn all_zero_vectors_roundtrip() -> VortexResult<()> { let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); let fsl = FixedSizeListArray::try_new( elements.into_array(), - dim.try_into().map_err(|e| vortex_err!("{e}"))?, + dim.try_into() + .expect("somehow got dimension greater than u32::MAX"), Validity::NonNullable, num_rows, )?; @@ -245,7 +245,7 @@ fn f64_input_encodes_successfully() -> VortexResult<()> { let num_rows = 10; let dim = 128; let mut rng = StdRng::seed_from_u64(99); - let normal = Normal::new(0.0f64, 1.0).map_err(|e| vortex_err!("{e}"))?; + let normal = Normal::new(0.0f64, 
1.0).unwrap(); let mut buf = BufferMut::::with_capacity(num_rows * dim); for _ in 0..(num_rows * dim) { @@ -254,7 +254,7 @@ fn f64_input_encodes_successfully() -> VortexResult<()> { let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); let fsl = FixedSizeListArray::try_new( elements.into_array(), - dim.try_into().map_err(|e| vortex_err!("{e}"))?, + dim.try_into().unwrap(), Validity::NonNullable, num_rows, )?; @@ -278,7 +278,7 @@ fn f16_input_encodes_successfully() -> VortexResult<()> { let num_rows = 10; let dim = 128; let mut rng = StdRng::seed_from_u64(99); - let normal = Normal::new(0.0f32, 1.0).map_err(|e| vortex_err!("{e}"))?; + let normal = Normal::new(0.0f32, 1.0).unwrap(); let mut buf = BufferMut::::with_capacity(num_rows * dim); for _ in 0..(num_rows * dim) { @@ -287,7 +287,7 @@ fn f16_input_encodes_successfully() -> VortexResult<()> { let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); let fsl = FixedSizeListArray::try_new( elements.into_array(), - dim.try_into().map_err(|e| vortex_err!("{e}"))?, + dim.try_into().unwrap(), Validity::NonNullable, num_rows, )?; diff --git a/vortex-tensor/src/encodings/turboquant/tests/structural.rs b/vortex-tensor/src/encodings/turboquant/tests/structural.rs index 3913cf3d8fe..cf6cc2c3fb7 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/structural.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/structural.rs @@ -13,7 +13,7 @@ use vortex_error::VortexResult; use super::*; -/// Verify that the centroids stored in the DictArray match what `compute_or_get_centroids()` computes. +/// Verify that the centroids stored in the DictArray match what `get_centroids()` computes. 
#[test] fn stored_centroids_match_computed() -> VortexResult<()> { let fsl = make_fsl(10, 128, 42); diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 7beadc02e93..72fd137ecaf 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -22,6 +22,7 @@ use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_norm::L2Norm; use crate::scalar_fns::sorf_transform::SorfTransform; use crate::types::fixed_shape::FixedShapeTensor; +use crate::types::normalized_vector::NormalizedVector; use crate::types::vector::Vector; pub mod matcher; @@ -30,6 +31,7 @@ pub mod scalar_fns; mod types; pub use types::fixed_shape; +pub use types::normalized_vector; pub use types::vector; pub mod encodings; @@ -48,6 +50,7 @@ pub const SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV: &str = "VX_SCALAR_FN_ARRAY_TENSOR_P /// Initialize the Vortex tensor library with a Vortex session. pub fn initialize(session: &VortexSession) { session.dtypes().register(Vector); + session.dtypes().register(NormalizedVector); session.dtypes().register(FixedShapeTensor); let session_fns = session.scalar_fns(); diff --git a/vortex-tensor/src/matcher.rs b/vortex-tensor/src/matcher.rs index 4566dcb3a38..65fff74b728 100644 --- a/vortex-tensor/src/matcher.rs +++ b/vortex-tensor/src/matcher.rs @@ -9,6 +9,7 @@ use vortex_array::dtype::extension::Matcher; use crate::types::fixed_shape::AnyFixedShapeTensor; use crate::types::fixed_shape::FixedShapeTensorMatcherMetadata; +use crate::types::normalized_vector::AnyNormalizedVector; use crate::types::vector::AnyVector; use crate::types::vector::VectorMatcherMetadata; @@ -18,6 +19,7 @@ use crate::types::vector::VectorMatcherMetadata; /// /// - `FixedShapeTensor` /// - `Vector` +/// - `NormalizedVector` pub struct AnyTensor; /// The matched variant of a tensor-like extension type. @@ -30,6 +32,10 @@ pub enum TensorMatch<'a> { /// /// Note that we store an owned type here wrapping (copyable) data from the dtype. 
Vector(VectorMatcherMetadata), + + /// A [`NormalizedVector`](crate::normalized_vector::NormalizedVector) extension over + /// [`Vector`](crate::vector::Vector) storage. + NormalizedVector(VectorMatcherMetadata), } impl TensorMatch<'_> { @@ -37,7 +43,7 @@ impl TensorMatch<'_> { pub fn element_ptype(self) -> PType { match self { Self::FixedShapeTensor(metadata) => metadata.element_ptype(), - Self::Vector(metadata) => metadata.element_ptype(), + Self::Vector(metadata) | Self::NormalizedVector(metadata) => metadata.element_ptype(), } } @@ -45,9 +51,15 @@ impl TensorMatch<'_> { pub fn list_size(self) -> u32 { match self { Self::FixedShapeTensor(metadata) => metadata.flat_list_size(), - Self::Vector(metadata) => metadata.dimensions(), + Self::Vector(metadata) | Self::NormalizedVector(metadata) => metadata.dimensions(), } } + + /// Returns `true` when the dtype is a + /// [`NormalizedVector`](crate::normalized_vector::NormalizedVector). + pub fn is_normalized(self) -> bool { + matches!(self, Self::NormalizedVector(_)) + } } impl Matcher for AnyTensor { @@ -58,11 +70,14 @@ impl Matcher for AnyTensor { return Some(TensorMatch::FixedShapeTensor(metadata)); } - // Special logic for vectors to get convenience metadata (instead of `EmptyMetadata`). 
if let Some(metadata) = ext_dtype.metadata_opt::() { return Some(TensorMatch::Vector(metadata)); } + if let Some(metadata) = ext_dtype.metadata_opt::() { + return Some(TensorMatch::NormalizedVector(metadata)); + } + None } } diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 7819e0b46f0..88203fe37f2 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -34,11 +34,10 @@ use vortex_error::VortexResult; use vortex_session::VortexSession; use crate::scalar_fns::inner_product::InnerProduct; -use crate::scalar_fns::l2_denorm::DenormOrientation; +use crate::scalar_fns::l2_denorm::NormalForm; use crate::scalar_fns::l2_denorm::try_build_constant_l2_denorm; use crate::scalar_fns::l2_norm::L2Norm; use crate::utils::BinaryTensorOpMetadata; -use crate::utils::extract_l2_denorm_children; use crate::utils::validate_binary_tensor_float_inputs; /// Cosine similarity between two columns. @@ -141,15 +140,21 @@ impl ScalarFnVTable for CosineSimilarity { rhs_ref = sfn.into_array(); } - // Take any L2Denorm-wrapped fast path that applies. - match DenormOrientation::classify(&lhs_ref, &rhs_ref) { - DenormOrientation::Both { lhs, rhs } => { - return self.execute_both_denorm(lhs, rhs, len); + // Classify each operand by its normal form. When both operands carry a known unit-norm + // representation, cosine similarity collapses to the dot product of the unit vectors. 
+ let lhs_form = NormalForm::classify(&lhs_ref); + let rhs_form = NormalForm::classify(&rhs_ref); + match (lhs_form.normalized_array(), rhs_form.normalized_array()) { + (Some(unit_lhs), Some(unit_rhs)) => { + return self.execute_both_unit(unit_lhs, unit_rhs, &lhs_ref, &rhs_ref, len); } - DenormOrientation::One { denorm, plain } => { - return self.execute_one_denorm(denorm, plain, len, ctx); + (Some(unit_lhs), None) => { + return self.execute_one_unit(unit_lhs, &rhs_ref, &lhs_ref, len, ctx); } - DenormOrientation::Neither => {} + (None, Some(unit_rhs)) => { + return self.execute_one_unit(unit_rhs, &lhs_ref, &rhs_ref, len, ctx); + } + (None, None) => {} } // Compute combined validity. @@ -237,22 +242,20 @@ impl ScalarFnArrayVTable for CosineSimilarity { } impl CosineSimilarity { - /// Both sides are `L2Denorm`: treat the normalized children as authoritative, so - /// `cosine_similarity = dot(n_l, n_r)`. - fn execute_both_denorm( + /// Both sides carry a known unit-norm representation: cosine similarity collapses to the + /// dot product of the unit children. + fn execute_both_unit( &self, + unit_lhs: &ArrayRef, + unit_rhs: &ArrayRef, lhs_ref: &ArrayRef, rhs_ref: &ArrayRef, len: usize, ) -> VortexResult { let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; - let (normalized_l, _) = extract_l2_denorm_children(lhs_ref); - let (normalized_r, _) = extract_l2_denorm_children(rhs_ref); - - // `L2Denorm` makes the normalized children authoritative, so their dot product is the - // cosine similarity even for lossy storage wrappers. - let dot = InnerProduct::try_new_array(normalized_l, normalized_r, len)?.into_array(); + let dot = + InnerProduct::try_new_array(unit_lhs.clone(), unit_rhs.clone(), len)?.into_array(); if !matches!(validity, Validity::NonNullable) { // Masking always changes the nullability to nullable. 
@@ -262,22 +265,21 @@ impl CosineSimilarity { } } - /// One side is `L2Denorm`: treat the normalized child as authoritative, so - /// `cosine_similarity = dot(n, b) / ||b||`. - /// - /// The caller must pass the denorm array as `denorm_ref` and the plain array as `plain_ref`. - fn execute_one_denorm( + /// Exactly one side carries a unit-norm representation: cosine similarity reduces to + /// `dot(unit, other) / ||other||`. The norms of the unit side are implicitly `1.0` (naked + /// `NormalizedVector`) or stored separately (the outer `L2Denorm` wrapper, which is not + /// needed here since cosine ignores magnitude). + fn execute_one_unit( &self, - denorm_ref: &ArrayRef, + unit: &ArrayRef, plain_ref: &ArrayRef, + unit_ref: &ArrayRef, len: usize, ctx: &mut ExecutionCtx, ) -> VortexResult { - let validity = denorm_ref.validity()?.and(plain_ref.validity()?)?; + let validity = unit_ref.validity()?.and(plain_ref.validity()?)?; - let (normalized, _) = extract_l2_denorm_children(denorm_ref); - - let dot_arr = InnerProduct::try_new_array(normalized, plain_ref.clone(), len)?; + let dot_arr = InnerProduct::try_new_array(unit.clone(), plain_ref.clone(), len)?; let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?; let norm_arr = L2Norm::try_new_array(plain_ref.clone(), len)?; @@ -326,6 +328,7 @@ mod tests { use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::constant_tensor_array; use crate::utils::test_helpers::l2_denorm_array; + use crate::utils::test_helpers::normalized_vector_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; @@ -514,13 +517,25 @@ mod tests { Ok(()) } + /// Naked [`NormalizedVector`](crate::normalized_vector::NormalizedVector) operands take the + /// fast path: cosine similarity collapses to the dot product without computing norms. 
+ #[test] + fn naked_normalized_vector_cosine() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let lhs = normalized_vector_array(2, &[0.6, 0.8, 1.0, 0.0], &mut ctx)?; + let rhs = normalized_vector_array(2, &[0.6, 0.8, 0.0, 1.0], &mut ctx)?; + // Row 0: identical -> 1.0, Row 1: orthogonal -> 0.0. + assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]); + Ok(()) + } + #[test] fn both_denorm_self_similarity() -> VortexResult<()> { // [3.0, 4.0] has norm 5.0, normalized [0.6, 0.8]. // [1.0, 0.0] has norm 1.0, normalized [1.0, 0.0]. let mut ctx = SESSION.create_execution_ctx(); - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let lhs = l2_denorm_array(2, &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let rhs = l2_denorm_array(2, &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; // Self-similarity should always be 1.0. assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 1.0]); @@ -532,8 +547,8 @@ mod tests { // [3.0, 0.0] normalized [1.0, 0.0], norm 3.0. // [0.0, 4.0] normalized [0.0, 1.0], norm 4.0. let mut ctx = SESSION.create_execution_ctx(); - let lhs = l2_denorm_array(&[2], &[1.0, 0.0], &[3.0], &mut ctx)?; - let rhs = l2_denorm_array(&[2], &[0.0, 1.0], &[4.0], &mut ctx)?; + let lhs = l2_denorm_array(2, &[1.0, 0.0], &[3.0], &mut ctx)?; + let rhs = l2_denorm_array(2, &[0.0, 1.0], &[4.0], &mut ctx)?; assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.0]); Ok(()) @@ -543,8 +558,8 @@ mod tests { fn both_denorm_zero_norm() -> VortexResult<()> { // Zero-norm row: normalized is [0.0, 0.0], norm is 0.0. 
let mut ctx = SESSION.create_execution_ctx(); - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0], &mut ctx)?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let lhs = l2_denorm_array(2, &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0], &mut ctx)?; + let rhs = l2_denorm_array(2, &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; // Row 0: dot([0.6, 0.8], [0.6, 0.8]) = 1.0, row 1: dot([0,0], [1,0]) = 0.0. assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]); @@ -557,8 +572,8 @@ mod tests { // RHS is plain [3.0, 4.0]. // cosine_similarity([3.0, 4.0], [3.0, 4.0]) = 1.0. let mut ctx = SESSION.create_execution_ctx(); - let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; - let rhs = tensor_array(&[2], &[3.0, 4.0])?; + let lhs = l2_denorm_array(2, &[0.6, 0.8], &[5.0], &mut ctx)?; + let rhs = vector_array(2, &[3.0, 4.0])?; assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]); Ok(()) @@ -569,8 +584,8 @@ mod tests { // LHS is plain [1.0, 0.0], RHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0]. // cosine_similarity([1.0, 0.0], [3.0, 4.0]) = 3.0 / (1.0 * 5.0) = 0.6. let mut ctx = SESSION.create_execution_ctx(); - let lhs = tensor_array(&[2], &[1.0, 0.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; + let lhs = vector_array(2, &[1.0, 0.0])?; + let rhs = l2_denorm_array(2, &[0.6, 0.8], &[5.0], &mut ctx)?; assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.6]); Ok(()) @@ -580,9 +595,9 @@ mod tests { fn both_denorm_null_norms() -> VortexResult<()> { // Row 0: valid, row 1: null (via nullable norms on rhs). 
let mut ctx = SESSION.create_execution_ctx(); - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let lhs = l2_denorm_array(2, &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; - let normalized_r = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?; + let normalized_r = normalized_vector_array(2, &[0.6, 0.8, 1.0, 0.0], &mut ctx)?; let norms_r = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array(); let rhs = L2Denorm::try_new_array(normalized_r, norms_r, 2, &mut ctx)?.into_array(); @@ -698,9 +713,45 @@ mod tests { Ok(()) } + #[test] + fn serde_round_trip_mixed_vector_and_normalized_vector() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let lhs = normalized_vector_array(2, &[0.6, 0.8, 1.0, 0.0], &mut ctx)?; + let rhs = vector_array(2, &[3.0, 4.0, 0.0, 1.0])?; + let original = CosineSimilarity::try_new_array(lhs.clone(), rhs.clone(), 2)?.into_array(); + + let plugin = ScalarFnArrayPlugin::new(CosineSimilarity); + let metadata = plugin + .serialize(&original, &SESSION)? 
+ .expect("CosineSimilarity serialize must produce metadata"); + + let children = vec![lhs, rhs]; + let recovered = plugin.deserialize( + original.dtype(), + original.len(), + &metadata, + &[], + &children, + &SESSION, + )?; + + assert_eq!(recovered.dtype(), original.dtype()); + assert_eq!(recovered.len(), original.len()); + assert_eq!(recovered.encoding_id(), original.encoding_id()); + Ok(()) + } + #[rstest] - #[case::vector(cosine_vector_lhs(), cosine_vector_rhs(), 2)] - #[case::fixed_shape_tensor(cosine_tensor_lhs(), cosine_tensor_rhs(), 2)] + #[case::vector( + vector_array(3, &[1.0, 0.0, 0.0, 3.0, 4.0, 0.0]).unwrap(), + vector_array(3, &[0.0, 1.0, 0.0, 3.0, 4.0, 0.0]).unwrap(), + 2, + )] + #[case::fixed_shape_tensor( + tensor_array(&[2], &[1.0, 0.0, 3.0, 4.0]).unwrap(), + tensor_array(&[2], &[0.0, 1.0, 3.0, 4.0]).unwrap(), + 2, + )] fn serde_round_trip( #[case] lhs: ArrayRef, #[case] rhs: ArrayRef, @@ -728,20 +779,4 @@ mod tests { assert_eq!(recovered.encoding_id(), original.encoding_id()); Ok(()) } - - fn cosine_vector_lhs() -> ArrayRef { - vector_array(3, &[1.0, 0.0, 0.0, 3.0, 4.0, 0.0]).expect("valid vector array") - } - - fn cosine_vector_rhs() -> ArrayRef { - vector_array(3, &[0.0, 1.0, 0.0, 3.0, 4.0, 0.0]).expect("valid vector array") - } - - fn cosine_tensor_lhs() -> ArrayRef { - tensor_array(&[2], &[1.0, 0.0, 3.0, 4.0]).expect("valid tensor array") - } - - fn cosine_tensor_rhs() -> ArrayRef { - tensor_array(&[2], &[0.0, 1.0, 3.0, 4.0]).expect("valid tensor array") - } } diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index d60938dfbd8..05e52666c05 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -46,14 +46,13 @@ use vortex_error::VortexResult; use vortex_session::VortexSession; use crate::matcher::AnyTensor; -use crate::scalar_fns::l2_denorm::DenormOrientation; +use crate::scalar_fns::l2_denorm::NormalForm; use 
crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfTransform; use crate::types::vector::Vector; use crate::utils::BinaryTensorOpMetadata; use crate::utils::extract_constant_flat_row; use crate::utils::extract_flat_elements; -use crate::utils::extract_l2_denorm_children; use crate::utils::validate_binary_tensor_float_inputs; /// Inner product (dot product) between two columns. @@ -141,15 +140,16 @@ impl ScalarFnVTable for InnerProduct { let rhs_ref = args.get(1)?; let len = args.row_count(); - // Take any L2Denorm-wrapped fast path that applies. - match DenormOrientation::classify(&lhs_ref, &rhs_ref) { - DenormOrientation::Both { lhs, rhs } => { - return self.execute_both_denorm(lhs, rhs, len, ctx); - } - DenormOrientation::One { denorm, plain } => { - return self.execute_one_denorm(denorm, plain, len, ctx); - } - DenormOrientation::Neither => {} + // Take the unit-norm fast path only when at least one operand wraps stored norms (the + // `Denormalized` form). For naked `NormalizedVector` operands the fall-through dot + // product already computes the right thing (and short-circuiting here would recurse + // back into `InnerProduct`). + let lhs_form = NormalForm::classify(&lhs_ref); + let rhs_form = NormalForm::classify(&rhs_ref); + if matches!(lhs_form, NormalForm::Denormalized { .. }) + || matches!(rhs_form, NormalForm::Denormalized { .. }) + { + return self.execute_unit_form(&lhs_form, &rhs_form, &lhs_ref, &rhs_ref, len, ctx); } // Reduction case 1: `InnerProduct(SorfTransform(x), const)` rewrites to @@ -169,9 +169,14 @@ impl ScalarFnVTable for InnerProduct { // Compute combined validity. let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; + // Drill past any `NormalizedVector` wrapper so we always work with the underlying + // `Vector` extension array. 
+ let lhs_inner = crate::types::normalized_vector::inner_vector_array(&lhs_ref, ctx)?; + let rhs_inner = crate::types::normalized_vector::inner_vector_array(&rhs_ref, ctx)?; + // Canonicalize so we can perform the math directly. - let lhs: ExtensionArray = lhs_ref.execute(ctx)?; - let rhs: ExtensionArray = rhs_ref.execute(ctx)?; + let lhs: ExtensionArray = lhs_inner.execute(ctx)?; + let rhs: ExtensionArray = rhs_inner.execute(ctx)?; // We validated that both inputs have the same type. let ext = lhs.dtype().as_extension(); @@ -247,9 +252,14 @@ impl ScalarFnArrayVTable for InnerProduct { } impl InnerProduct { - /// Both sides are `L2Denorm`: `inner_product = s_l * s_r * dot(n_l, n_r)`. - fn execute_both_denorm( + /// Inner product over operands that may carry a unit-norm representation: + /// `inner_product = scale_l * scale_r * dot(unit_l, unit_r)`, where `scale = 1` for naked + /// `Normalized` operands, `scale = stored_norms` for `Denormalized` operands, and the + /// `unit_*` operands are the input itself for `Plain` operands. + fn execute_unit_form( &self, + lhs_form: &NormalForm<'_>, + rhs_form: &NormalForm<'_>, lhs_ref: &ArrayRef, rhs_ref: &ArrayRef, len: usize, @@ -257,50 +267,42 @@ impl InnerProduct { ) -> VortexResult { let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; - let (normalized_l, norms_l) = extract_l2_denorm_children(lhs_ref); - let (normalized_r, norms_r) = extract_l2_denorm_children(rhs_ref); - - let norms_l: PrimitiveArray = norms_l.execute(ctx)?; - let norms_r: PrimitiveArray = norms_r.execute(ctx)?; - - let dot: PrimitiveArray = InnerProduct::try_new_array(normalized_l, normalized_r, len)? + // For each operand, take its unit-norm representation if it has one; fall back to the + // operand itself (the `Plain` case feeds the regular dot path with no scaling). 
+ let unit_lhs = lhs_form + .normalized_array() + .cloned() + .unwrap_or_else(|| lhs_ref.clone()); + let unit_rhs = rhs_form + .normalized_array() + .cloned() + .unwrap_or_else(|| rhs_ref.clone()); + + let dot: PrimitiveArray = InnerProduct::try_new_array(unit_lhs, unit_rhs, len)? .into_array() .execute(ctx)?; - match_each_float_ptype!(dot.ptype(), |T| { - let dots = dot.as_slice::(); - let nl = norms_l.as_slice::(); - let nr = norms_r.as_slice::(); - let buffer: Buffer = (0..len).map(|i| nl[i] * nr[i] * dots[i]).collect(); - - // SAFETY: The buffer length equals `len`, which matches the source validity length. - Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array()) - }) - } - - /// One side is `L2Denorm`: `inner_product = s * dot(n, other)`. - /// - /// The caller must pass the denorm array as `denorm_ref` and the plain array as `plain_ref`. - fn execute_one_denorm( - &self, - denorm_ref: &ArrayRef, - plain_ref: &ArrayRef, - len: usize, - ctx: &mut ExecutionCtx, - ) -> VortexResult { - let validity = denorm_ref.validity()?.and(plain_ref.validity()?)?; - - let (normalized, norms) = extract_l2_denorm_children(denorm_ref); - let denorm_norms: PrimitiveArray = norms.execute(ctx)?; - - let dot: PrimitiveArray = InnerProduct::try_new_array(normalized, plain_ref.clone(), len)? 
- .into_array() - .execute(ctx)?; + let lhs_scale = norms_for_scaling(lhs_form, ctx)?; + let rhs_scale = norms_for_scaling(rhs_form, ctx)?; match_each_float_ptype!(dot.ptype(), |T| { let dots = dot.as_slice::(); - let ns = denorm_norms.as_slice::(); - let buffer: Buffer = (0..len).map(|i| ns[i] * dots[i]).collect(); + let buffer: Buffer = match (lhs_scale.as_ref(), rhs_scale.as_ref()) { + (Some(nl), Some(nr)) => { + let nl = nl.as_slice::(); + let nr = nr.as_slice::(); + (0..len).map(|i| nl[i] * nr[i] * dots[i]).collect() + } + (Some(nl), None) => { + let nl = nl.as_slice::(); + (0..len).map(|i| nl[i] * dots[i]).collect() + } + (None, Some(nr)) => { + let nr = nr.as_slice::(); + (0..len).map(|i| nr[i] * dots[i]).collect() + } + (None, None) => dots.iter().copied().collect(), + }; // SAFETY: The buffer length equals `len`, which matches the source validity length. Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array()) @@ -498,6 +500,21 @@ impl InnerProduct { } } +/// Materialize the per-row scaling factor for an operand classified by [`NormalForm`]. +/// +/// - `Plain`: no scaling needed (the operand itself enters the dot product). +/// - `Normalized`: implicit scaling of `1.0`, returned as `None` so the caller skips the multiply. +/// - `Denormalized`: returns the materialized stored norms. +fn norms_for_scaling( + form: &NormalForm<'_>, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + match form { + NormalForm::Plain | NormalForm::Normalized { .. } => Ok(None), + NormalForm::Denormalized { norms, .. } => Ok(Some(norms.clone().execute(ctx)?)), + } +} + /// Return the storage constant for a canonical tensor-like constant query. 
fn constant_tensor_storage(array: &ArrayRef) -> Option { let constant = array.as_opt::()?; @@ -581,6 +598,7 @@ mod tests { use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::l2_denorm_array; + use crate::utils::test_helpers::normalized_vector_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; @@ -701,8 +719,8 @@ mod tests { // RHS: [1.0, 0.0] = L2Denorm([1.0, 0.0], 1.0). // dot([3.0, 4.0], [1.0, 0.0]) = 3.0. let mut ctx = SESSION.create_execution_ctx(); - let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; - let rhs = l2_denorm_array(&[2], &[1.0, 0.0], &[1.0], &mut ctx)?; + let lhs = l2_denorm_array(2, &[0.6, 0.8], &[5.0], &mut ctx)?; + let rhs = l2_denorm_array(2, &[1.0, 0.0], &[1.0], &mut ctx)?; // Expected: 5.0 * 1.0 * dot([0.6, 0.8], [1.0, 0.0]) = 5.0 * 0.6 = 3.0. assert_close(&eval_inner_product(lhs, rhs, 1)?, &[3.0]); @@ -714,8 +732,8 @@ mod tests { // Row 0: [3.0, 4.0] dot [3.0, 4.0] = 25.0. // Row 1: [1.0, 0.0] dot [0.0, 1.0] = 0.0. let mut ctx = SESSION.create_execution_ctx(); - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 1.0], &[5.0, 1.0], &mut ctx)?; + let lhs = l2_denorm_array(2, &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let rhs = l2_denorm_array(2, &[0.6, 0.8, 0.0, 1.0], &[5.0, 1.0], &mut ctx)?; assert_close(&eval_inner_product(lhs, rhs, 2)?, &[25.0, 0.0]); Ok(()) @@ -727,8 +745,8 @@ mod tests { // RHS: plain [1.0, 2.0]. // dot([3.0, 4.0], [1.0, 2.0]) = 3.0 + 8.0 = 11.0. 
let mut ctx = SESSION.create_execution_ctx(); - let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; - let rhs = tensor_array(&[2], &[1.0, 2.0])?; + let lhs = l2_denorm_array(2, &[0.6, 0.8], &[5.0], &mut ctx)?; + let rhs = vector_array(2, &[1.0, 2.0])?; assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]); Ok(()) @@ -740,8 +758,8 @@ mod tests { // RHS: L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0]. // dot([1.0, 2.0], [3.0, 4.0]) = 3.0 + 8.0 = 11.0. let mut ctx = SESSION.create_execution_ctx(); - let lhs = tensor_array(&[2], &[1.0, 2.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; + let lhs = vector_array(2, &[1.0, 2.0])?; + let rhs = l2_denorm_array(2, &[0.6, 0.8], &[5.0], &mut ctx)?; assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]); Ok(()) @@ -750,12 +768,16 @@ mod tests { #[test] fn both_denorm_null_norms() -> VortexResult<()> { // Row 0: valid, row 1: null (via nullable norms on lhs). - let normalized_l = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?; + let normalized_l = normalized_vector_array( + 2, + &[0.6, 0.8, 1.0, 0.0], + &mut SESSION.create_execution_ctx(), + )?; let norms_l = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array(); let mut ctx = SESSION.create_execution_ctx(); let lhs = L2Denorm::try_new_array(normalized_l, norms_l, 2, &mut ctx)?.into_array(); - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let rhs = l2_denorm_array(2, &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; let scalar_fn = InnerProduct::new().erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?; @@ -768,9 +790,58 @@ mod tests { Ok(()) } + /// Naked [`NormalizedVector`](crate::normalized_vector::NormalizedVector) operands fall + /// through to the regular dot path (no extra scaling). The result is just `dot(lhs, rhs)`. 
+ #[test] + fn naked_normalized_vector_dot() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let lhs = normalized_vector_array(2, &[0.6, 0.8, 1.0, 0.0], &mut ctx)?; + let rhs = normalized_vector_array(2, &[0.6, 0.8, 0.0, 1.0], &mut ctx)?; + + // Row 0: dot([0.6,0.8],[0.6,0.8]) = 1.0, Row 1: dot([1.0,0.0],[0.0,1.0]) = 0.0. + assert_close(&eval_inner_product(lhs, rhs, 2)?, &[1.0, 0.0]); + Ok(()) + } + + #[test] + fn serde_round_trip_mixed_vector_and_normalized_vector() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let lhs = normalized_vector_array(2, &[0.6, 0.8, 1.0, 0.0], &mut ctx)?; + let rhs = vector_array(2, &[3.0, 4.0, 0.0, 1.0])?; + let original = InnerProduct::try_new_array(lhs.clone(), rhs.clone(), 2)?.into_array(); + + let plugin = ScalarFnArrayPlugin::new(InnerProduct); + let metadata = plugin + .serialize(&original, &SESSION)? + .expect("InnerProduct serialize must produce metadata"); + + let children = vec![lhs, rhs]; + let recovered = plugin.deserialize( + original.dtype(), + original.len(), + &metadata, + &[], + &children, + &SESSION, + )?; + + assert_eq!(recovered.dtype(), original.dtype()); + assert_eq!(recovered.len(), original.len()); + assert_eq!(recovered.encoding_id(), original.encoding_id()); + Ok(()) + } + #[rstest] - #[case::vector(inner_product_vector_lhs(), inner_product_vector_rhs(), 2)] - #[case::fixed_shape_tensor(inner_product_tensor_lhs(), inner_product_tensor_rhs(), 2)] + #[case::vector( + vector_array(3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(), + vector_array(3, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap(), + 2, + )] + #[case::fixed_shape_tensor( + tensor_array(&[2], &[1.0, 2.0, 3.0, 4.0]).unwrap(), + tensor_array(&[2], &[5.0, 6.0, 7.0, 8.0]).unwrap(), + 2, + )] fn serde_round_trip( #[case] lhs: ArrayRef, #[case] rhs: ArrayRef, @@ -799,22 +870,6 @@ mod tests { Ok(()) } - fn inner_product_vector_lhs() -> ArrayRef { - vector_array(3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("valid 
vector array") - } - - fn inner_product_vector_rhs() -> ArrayRef { - vector_array(3, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).expect("valid vector array") - } - - fn inner_product_tensor_lhs() -> ArrayRef { - tensor_array(&[2], &[1.0, 2.0, 3.0, 4.0]).expect("valid tensor array") - } - - fn inner_product_tensor_rhs() -> ArrayRef { - tensor_array(&[2], &[5.0, 6.0, 7.0, 8.0]).expect("valid tensor array") - } - // ---- Tests for the `SorfTransform + constant` and `Dict + constant` fast paths ---- #[allow( diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index 1bdd81833d9..0230e36d8f2 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -31,9 +31,11 @@ use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; +use vortex_array::dtype::extension::ExtDType; use vortex_array::dtype::proto::dtype as pb; use vortex_array::expr::Expression; use vortex_array::expr::and; +use vortex_array::extension::EmptyMetadata; use vortex_array::match_each_float_ptype; use vortex_array::scalar::Scalar; use vortex_array::scalar::ScalarValue; @@ -59,32 +61,35 @@ use vortex_session::VortexSession; use crate::matcher::AnyTensor; use crate::scalar_fns::l2_norm::L2Norm; +use crate::types::normalized_vector::NormalizedVector; +use crate::types::normalized_vector::inner_vector_array; +use crate::types::normalized_vector::vector_fsl_storage_dtype; +use crate::types::vector::AnyVector; +use crate::types::vector::Vector; use crate::utils::extract_constant_flat_row; use crate::utils::extract_flat_elements; +use crate::utils::extract_l2_denorm_children; use crate::utils::unit_norm_tolerance; -use crate::utils::validate_tensor_float_input; -/// Re-applies authoritative L2 norms to a normalized tensor column. +/// Re-applies authoritative L2 norms to a normalized vector column. 
/// -/// Computes `normalized * norm` on each row over the flat backing buffer of each tensor-like type. +/// Computes `normalized * norm` on each row over the flat backing buffer of the vector-shaped +/// child. /// -/// The normalized input must be a tensor-like extension array with a float element type and each -/// non-null row is semantically required to already be L2-normalized. +/// The first child must be vector-shaped and semantically suitable for L2 denormalization. Exact +/// callers should use [`try_new_array`](Self::try_new_array), which verifies that plain +/// [`Vector`] inputs are row-wise unit-norm (or zero). Lossy encodings may use +/// [`new_array_unchecked`](Self::new_array_unchecked) when the decoded child is only an +/// approximation but the stored norms are still authoritative. /// -/// The norms input must be a primitive float column with the same element type as the normalized -/// tensor elements. -/// -/// [`L2Denorm`] is the norm-splitting wrapper used throughout the tensor crate. Callers that build -/// it through [`try_new_array`](Self::try_new_array) get an exact unit-norm invariant on the -/// `normalized` child. -/// -/// Advanced callers can also use [`new_array_unchecked`](Self::new_array_unchecked) to attach -/// authoritative stored norms to a lossy approximation of that child, such as quantized normalized -/// vectors. +/// The norms input must be a primitive float column with the same element type as the +/// normalized vector elements. /// /// Downstream readthrough rules intentionally treat the stored norms and normalized child as the /// encoding contract, even when that differs slightly from recomputing over fully decoded /// coordinates. +/// +/// [`NormalizedVector`]: crate::normalized_vector::NormalizedVector #[derive(Clone)] pub struct L2Denorm; @@ -99,45 +104,64 @@ impl L2Denorm { /// Constructs a validated [`ScalarFnArray`] that lazily re-applies `norms` to `normalized`. 
/// - /// This is the correct constructor for [`L2Denorm`] arrays. In addition to the structural - /// checks performed by [`ScalarFnArray::try_new`], it validates that every valid row of the - /// `normalized` child has L2 norm `1.0` (or `0.0` for zero rows), within the tolerance implied - /// by the child element precision. It also validates that stored norms are non-negative, and - /// that any row with stored norm `0.0` has an all-zero normalized row. + /// In addition to the structural checks performed by [`ScalarFnArray::try_new`], this + /// constructor verifies that plain [`Vector`] children are row-wise unit-norm (or zero), that + /// stored norms are non-negative, and that any row with stored norm `0.0` has an all-zero + /// normalized row. + /// + /// Plain [`Vector`] children are promoted to [`NormalizedVector`] after validation so that + /// downstream execution paths can rely on the unit-norm marker. /// /// # Errors /// /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype - /// mismatches) or if the `normalized` child is not row-wise L2-normalized. + /// mismatches), if a stored norm is negative, or if a zero-norm row is paired with a + /// non-zero normalized row. pub fn try_new_array( normalized: ArrayRef, norms: ArrayRef, len: usize, ctx: &mut ExecutionCtx, ) -> VortexResult { - validate_l2_normalized_rows_against_norms(&normalized, Some(&norms), ctx)?; + validate_norms_against_normalized(&normalized, &norms, ctx)?; + + // Promote plain `Vector` children to `NormalizedVector`. The unit-norm invariant is + // verified by `validate_norms_against_normalized`, so the `wrap_vector_unchecked` wrap is + // safe by construction. + let normalized = if normalized + .dtype() + .as_extension_opt() + .is_some_and(|ext| ext.is::()) + { + normalized + } else { + // SAFETY: row-wise unit-norm (or zero) was just verified for the plain `Vector` input + // above. 
Wrap the `Vector` extension array as a `NormalizedVector` without unpacking + // to FSL storage. + unsafe { NormalizedVector::wrap_vector_unchecked(normalized) }? + }; - // SAFETY: We just validated that it is normalized. + // SAFETY: The validation above established the exact L2Denorm invariants. unsafe { Self::new_array_unchecked(normalized, norms, len) } } - /// Constructs an [`L2Denorm`] array without validating that the `normalized` child is actually - /// row-wise L2-normalized. + /// Constructs an [`L2Denorm`] array without validating the normalized-child invariant. /// - /// This escape hatch is intended for advanced callers that already established, or - /// intentionally relax, the normalized-child invariant. Structural validation still runs via - /// [`ScalarFnArray::try_new`]. + /// Structural validation still runs via [`ScalarFnArray::try_new`]. Use this when the + /// normalized child is a lossy approximation whose rows may not be exactly unit-norm or may not + /// preserve exact zero-ness. /// /// # Safety /// - /// The caller must ensure the `normalized` child is semantically suitable for L2 - /// denormalization. For exact wrappers, that means every valid row is unit-norm or zero. + /// The caller must ensure the first child is semantically suitable for L2 denormalization. + /// For exact wrappers, every valid row must be unit-norm or zero and stored norms must be + /// non-negative. Lossy encodings may deliberately relax the exact row invariant while still + /// treating the stored norms as authoritative. /// - /// Lossy encodings may deliberately relax that invariant while still treating the stored norms - /// as authoritative. + /// # Errors /// - /// Violating the intended contract will not cause memory unsafety, but may produce incorrect - /// results. + /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype + /// mismatches). 
pub unsafe fn new_array_unchecked( normalized: ArrayRef, norms: ArrayRef, @@ -183,8 +207,19 @@ impl ScalarFnVTable for L2Denorm { let normalized = &arg_dtypes[0]; let norms = &arg_dtypes[1]; - let tensor_match = validate_tensor_float_input(normalized)?; - let element_ptype = tensor_match.element_ptype(); + let ext = normalized.as_extension_opt().ok_or_else(|| { + vortex_err!( + "L2Denorm normalized child must be a Vector or NormalizedVector, got \ + {normalized}", + ) + })?; + let normalized_metadata = ext.metadata_opt::().ok_or_else(|| { + vortex_err!( + "L2Denorm normalized child must be a Vector or NormalizedVector, got \ + {normalized}", + ) + })?; + let element_ptype = normalized_metadata.element_ptype(); let DType::Primitive(norms_ptype, _) = norms else { vortex_bail!("L2Denorm norms must be a primitive float array, got {norms}"); @@ -196,7 +231,17 @@ impl ScalarFnVTable for L2Denorm { got {norms_ptype}", ); - Ok(normalized.union_nullability(norms.nullability())) + // The denormalized output has the same FSL storage shape as the normalized child but is + // no longer guaranteed unit-norm, so it surfaces as a plain `Vector` extension type. + let fsl_dtype = vector_fsl_storage_dtype(ext).ok_or_else(|| { + vortex_err!( + "L2Denorm normalized child must be a Vector or NormalizedVector, got \ + {normalized}", + ) + })?; + let plain_vector = + DType::Extension(ExtDType::::try_new(EmptyMetadata, fsl_dtype)?.erased()); + Ok(plain_vector.union_nullability(norms.nullability())) } fn execute( @@ -207,9 +252,19 @@ impl ScalarFnVTable for L2Denorm { ) -> VortexResult { let normalized_ref = args.get(0)?; let norms_ref = args.get(1)?; - let output_dtype = normalized_ref - .dtype() - .union_nullability(norms_ref.dtype().nullability()); + // Output is a plain `Vector` (not `NormalizedVector`) because denormalized values are no + // longer guaranteed unit-norm. Drill through any `NormalizedVector` wrapper to get the + // underlying FSL. 
+ let fsl_dtype = vector_fsl_storage_dtype(normalized_ref.dtype().as_extension()) + .ok_or_else(|| { + vortex_err!( + "L2Denorm normalized child must be a Vector or NormalizedVector, got {}", + normalized_ref.dtype(), + ) + })?; + let output_dtype = + DType::Extension(ExtDType::::try_new(EmptyMetadata, fsl_dtype)?.erased()) + .union_nullability(norms_ref.dtype().nullability()); let validity = normalized_ref.validity()?.and(norms_ref.validity()?)?; if let Some(const_norms) = norms_ref.as_opt::() { @@ -232,7 +287,10 @@ impl ScalarFnVTable for L2Denorm { } } - let normalized: ExtensionArray = normalized_ref.execute(ctx)?; + // Drill past any `NormalizedVector` wrapper so we always work with the underlying + // `Vector` extension array. + let vector_ref = inner_vector_array(&normalized_ref, ctx)?; + let normalized: ExtensionArray = vector_ref.execute(ctx)?; let norms: PrimitiveArray = norms_ref.execute(ctx)?; let row_count = args.row_count(); @@ -366,7 +424,7 @@ fn execute_l2_denorm_constant_norms( .and_then(|ext| ext.metadata_opt::()) .ok_or_else(|| { vortex_err!( - "L2Denorm normalized child must be a tensor-like extension, got {}", + "L2Denorm normalized child must be a Vector or NormalizedVector, got {}", normalized_ref.dtype(), ) })?; @@ -375,13 +433,25 @@ fn execute_l2_denorm_constant_norms( norm_scalar.dtype().as_ptype(), tensor_match.list_size() as usize, ); + + // Drill past any outer `NormalizedVector` wrapper so we always work with the inner plain + // `Vector` extension array (and its `FixedSizeList` storage). + let vector_ref = inner_vector_array(&normalized_ref, ctx)?; + if err.abs() < tolerance { - return Ok(normalized_ref); + // The output dtype is the sibling plain `Vector`. Rewrap the vector storage so the + // executed array's dtype matches `output_dtype`. + let normalized: ExtensionArray = vector_ref.execute(ctx)?; + return Ok(ExtensionArray::try_new( + output_dtype.as_extension().clone(), + normalized.storage_array().clone(), + )? 
+ .into_array()); } // Even if the norms are not all 1, if they are all the same then we can multiply // the entire elements array by the same number. - let normalized: ExtensionArray = normalized_ref.execute(ctx)?; + let normalized: ExtensionArray = vector_ref.execute(ctx)?; let storage_fsl: FixedSizeListArray = normalized.storage_array().clone().execute(ctx)?; // Replace the elements array with an array that multiplies it by the constant @@ -408,35 +478,70 @@ fn execute_l2_denorm_constant_norms( } /// Builds an unexecuted [`L2Denorm`] expression by normalizing `input` and reattaching the exact -/// norms as the norms child. +/// norms as the `norms` child. /// /// The returned array is a lazy `L2Denorm(normalized, norms)` scalar function array. /// /// # Normalized child /// -/// The normalized child is always **non-nullable** with [`Validity::NonNullable`]. Every non-null -/// row with a positive L2 norm is divided by its norm to produce a unit-norm vector. +/// For plain [`Vector`] (and [`FixedShapeTensor`]) input, every non-null row with a positive L2 +/// norm is divided by its norm to produce a unit-norm vector. The normalized child is forced +/// **non-nullable** with [`Validity::NonNullable`] so optimized kernels over normalized vectors +/// only have to reason about unit-norm vs. zero rows, not nulls. Rows that are null in the +/// original input are **zeroed out** in the normalized output to avoid leaking undefined +/// physical storage values into downstream encodings (like TurboQuant). /// -/// Rows that are null in the original input are **zeroed out** in the normalized output. This is -/// necessary because null rows may have undefined (garbage) physical storage values, and we do not -/// want to let those propagate into downstream encodings (like TurboQuant). +/// For [`NormalizedVector`] input, the function takes a fast path that returns the input +/// unchanged as the normalized child and asks [`L2Norm`] for the per-row norms. 
The fast path +/// preserves the input's outer nullability rather than rewriting null rows to zero, since the +/// caller has already committed to a [`NormalizedVector`] shape and we do not want to +/// re-allocate storage just to coerce nullability. /// /// # Nullability /// -/// Nullability is tracked entirely by the norms child. Null input rows produce null norms via -/// [`L2Norm`]'s validity propagation. When the [`L2Denorm`] wrapper is executed, its validity is -/// `and(normalized_validity, norms_validity)`, which correctly identifies originally-null rows -/// since the normalized child is all-valid and the norms child carries the original nulls. +/// Nullability is tracked entirely by the `norms` child. Null input rows produce null `norms` via +/// [`L2Norm`]'s validity propagation. When the [`L2Denorm`] wrapper is executed, the output +/// validity is `and(normalized_validity, norms_validity)`, which correctly identifies +/// originally-null rows. +/// +/// Because this helper computes exact `norms` and (on the slow path) divides by them, the +/// returned `normalized` child satisfies the unit-norm invariant required by [`L2Denorm`]. /// -/// Because this helper computes exact norms first and then divides by those norms, the returned -/// `normalized` child satisfies the strict unit-norm invariant required by [`L2Denorm`]. 
+/// [`FixedShapeTensor`]: crate::fixed_shape::FixedShapeTensor +/// [`NormalizedVector`]: crate::normalized_vector::NormalizedVector pub fn normalize_as_l2_denorm( input: ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult { let row_count = input.len(); - let tensor_match = validate_tensor_float_input(input.dtype())?; - let tensor_flat_size = tensor_match.list_size() as usize; + let tensor_metadata = input + .dtype() + .as_extension_opt() + .and_then(|ext| ext.metadata_opt::()) + .ok_or_else(|| { + vortex_err!( + "normalize_as_l2_denorm requires a tensor-like extension input, got {}", + input.dtype(), + ) + })?; + let tensor_flat_size = tensor_metadata.list_size() as usize; + + // Fast path: input is already a `NormalizedVector`. The slow path below would compute exact + // norms and divide every row by its norm, but for a `NormalizedVector` the divisor is always + // 1.0 (or 0.0 for zero rows). Skip the divide entirely and reattach `L2Norm`'s + // short-circuited per-row 0.0 / 1.0 norms. Crucially, this preserves the invariant required + // by [`L2Denorm::try_new_array`] that a zero-norm row is paired with an all-zero normalized + // row, because [`L2Norm`]'s `NormalizedVector` short-circuit emits 0.0 exactly when the row + // is all zero. + if tensor_metadata.is_normalized() { + let norms_sfn = L2Norm::try_new_array(input.clone(), row_count)?; + let norms_array: ArrayRef = norms_sfn.into_array().execute(ctx)?; + + // SAFETY: `input` is a `NormalizedVector`, so every valid row is unit-norm or zero by + // type. `norms_array` was produced by `L2Norm`, so stored norms are non-negative and a + // zero norm is always paired with an all-zero row. + return unsafe { L2Denorm::new_array_unchecked(input, norms_array, row_count) }; + } // Constant fast path: if the input is a constant-backed extension, normalize the single // stored row once and return an `L2Denorm` whose children are both `ConstantArray`s. 
@@ -451,11 +556,10 @@ pub fn normalize_as_l2_denorm( let norms_validity = primitive_norms.validity()?; let input: ExtensionArray = input.execute(ctx)?; - let normalized_dtype = input.dtype().as_nonnullable(); let flat = extract_flat_elements(input.storage_array(), tensor_flat_size, ctx)?; // Normalize all of the vectors. - let normalized = match_each_float_ptype!(flat.ptype(), |T| { + let normalized_storage = match_each_float_ptype!(flat.ptype(), |T| { let norm_values = primitive_norms.as_slice::(); let total_elements = row_count * tensor_flat_size; @@ -478,24 +582,19 @@ pub fn normalize_as_l2_denorm( } // Since L2Denorm's validity is the `and` of its child validities, we can make the - // normalized array non-nullable. - build_tensor_array( - normalized_dtype, - tensor_flat_size, - row_count, - Validity::NonNullable, - elements.freeze(), - ) + // normalized child non-nullable. + build_normalized_storage(tensor_flat_size, row_count, elements.freeze()) })?; // SAFETY: // - `norms_array` was produced by `L2Norm(input)`, so every stored norm is non-negative and // null rows already carry null validity through that child. // - For every valid row, we either emit all zeros when the norm is zero or divide every - // element by the exact stored norm, so the normalized child is unit-norm (or zero) by + // element by the exact stored norm, so the normalized storage is unit-norm (or zero) by // construction. - // - Null rows are zeroed out above to avoid propagating arbitrary physical storage values into - // downstream lossy encodings. + // - Null rows are zeroed out above to avoid propagating arbitrary physical storage values + // into downstream lossy encodings. 
+ let normalized = unsafe { NormalizedVector::new_unchecked(normalized_storage) }?; unsafe { L2Denorm::new_array_unchecked(normalized, norms_array, row_count) } } @@ -526,16 +625,18 @@ pub(crate) fn try_build_constant_l2_denorm( return Ok(None); } - // The caller is expected to have already validated that `input` is an `AnyTensor` - // extension dtype. - let tensor_match = input + // Only promote vector-family inputs: wrapping FST rows as `NormalizedVector` would be a + // family change, so `FixedShapeTensor` constants fall back to the generic fast path with + // per-row division. + let Some(vector_metadata) = input .dtype() - .as_extension() - .metadata_opt::() - .vortex_expect("caller validated input has AnyTensor metadata"); - let list_size = tensor_match.list_size() as usize; + .as_extension_opt() + .and_then(|ext| ext.metadata_opt::()) + else { + return Ok(None); + }; + let list_size = vector_metadata.dimensions() as usize; let original_nullability = input.dtype().nullability(); - let ext_dtype = input.dtype().as_extension().clone(); let storage_fsl_nullability = storage.dtype().nullability(); // Materialize just the single stored row; this does not expand the constant to the full @@ -551,8 +652,8 @@ pub(crate) fn try_build_constant_l2_denorm( } let norm_t: T = sum_sq.sqrt(); - // Zero-norm rows must be stored as all-zeros so [`L2Denorm`]'s unit-norm-or-zero - // invariant holds. This mirrors the per-row logic in `normalize_as_l2_denorm`. + // Zero-norm rows must be stored as all-zeros so the `NormalizedVector` invariant holds. + // This mirrors the per-row logic in `normalize_as_l2_denorm`. let element_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable); let children: Vec = if norm_t == T::zero() { (0..list_size) @@ -564,23 +665,23 @@ pub(crate) fn try_build_constant_l2_denorm( .collect() }; - // The rebuilt FSL scalar preserves the original storage FSL's nullability so the - // resulting `ExtensionArray::new` call accepts the same extension dtype. 
let fsl_scalar = Scalar::fixed_size_list(element_dtype, children, storage_fsl_nullability); let norms_scalar = Scalar::primitive(norm_t, original_nullability); (fsl_scalar, norms_scalar) }); let normalized_storage = ConstantArray::new(normalized_fsl_scalar, len).into_array(); - let normalized_ext = ExtensionArray::new(ext_dtype, normalized_storage).into_array(); + // SAFETY: The single stored row is either `v / ||v||` (unit norm within floating-point + // tolerance) or all zeros when `||v|| == 0`. This is the invariant required by + // `NormalizedVector::new_unchecked`. + let normalized = unsafe { NormalizedVector::new_unchecked(normalized_storage) }?; let norms_array = ConstantArray::new(norms_scalar, len).into_array(); - // SAFETY: Each row of `normalized_ext` is either `v / ||v||` (unit norm within floating - // point tolerance) or all zeros when `||v|| == 0`. Stored norms are non-negative by - // construction (`sqrt`). These are exactly the invariants required by - // [`L2Denorm::new_array_unchecked`]. - let wrapped = unsafe { L2Denorm::new_array_unchecked(normalized_ext, norms_array, len)? }; - Ok(Some(wrapped)) + // SAFETY: The single stored row is exactly normalized above (or all zeros), and the norm was + // computed with `sqrt`, so it is non-negative. + Ok(Some(unsafe { + L2Denorm::new_array_unchecked(normalized, norms_array, len)? + })) } /// Rebuilds a tensor-like extension array from flat primitive elements. @@ -596,97 +697,136 @@ fn build_tensor_array( validity: Validity, elements: Buffer, ) -> VortexResult { + let storage = build_fsl_storage(tensor_flat_size, row_count, validity, elements)?.into_array(); + Ok(ExtensionArray::new(dtype.as_extension().clone(), storage).into_array()) +} + +/// Build a non-nullable [`FixedSizeListArray`] suitable for wrapping as a +/// [`NormalizedVector`] storage. 
+fn build_normalized_storage( + tensor_flat_size: usize, + row_count: usize, + elements: Buffer, +) -> VortexResult { + Ok( + build_fsl_storage(tensor_flat_size, row_count, Validity::NonNullable, elements)? + .into_array(), + ) +} + +/// Build a [`FixedSizeListArray`] from a flat element buffer and a per-row validity. +fn build_fsl_storage( + tensor_flat_size: usize, + row_count: usize, + validity: Validity, + elements: Buffer, +) -> VortexResult { let list_size = u32::try_from(tensor_flat_size).vortex_expect("tensor flat size must fit into `u32`"); - // SAFETY: Validity has no length (because tensor elements are always non-nullable). let elements = unsafe { PrimitiveArray::new_unchecked(elements, Validity::NonNullable) }; - - let storage = - FixedSizeListArray::try_new(elements.into_array(), list_size, validity, row_count)?; - - Ok(ExtensionArray::new(dtype.as_extension().clone(), storage.into_array()).into_array()) + FixedSizeListArray::try_new(elements.into_array(), list_size, validity, row_count) } -/// Validates that `normalized` and (when supplied) the matching `norms` jointly satisfy the -/// [`L2Denorm`] invariants: +// TODO(connor): Need better logic here to check against `NormalizedVector` vs `Vector`. +/// Cross-check that `normalized` and `norms` agree on per-row zero-ness, and that stored norms +/// are non-negative. Unit-norm enforcement on the rows lives on the +/// [`NormalizedVector`](crate::normalized_vector::NormalizedVector) dtype itself. /// -/// - Every valid row of `normalized` has L2 norm `1.0` or `0.0` (within element-precision -/// tolerance). -/// - When `norms` is supplied, every stored norm is non-negative and any row whose stored norm is -/// `0.0` is exactly the zero vector in `normalized`. 
-pub fn validate_l2_normalized_rows_against_norms( +/// We match against [`AnyTensor`] for symmetry with the rest of the tensor pipeline, but +/// downstream construction in [`L2Denorm::return_dtype`] only succeeds for `Vector` and +/// `NormalizedVector` storage (see [`vector_fsl_storage_dtype`]). A `FixedShapeTensor` operand +/// will pass this validator and then be rejected later, which is why the user-visible error +/// message names only the two supported shapes. +fn validate_norms_against_normalized( normalized: &ArrayRef, - norms: Option<&ArrayRef>, + norms: &ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { + let tensor_match = normalized + .dtype() + .as_extension_opt() + .and_then(|ext| ext.metadata_opt::()) + .ok_or_else(|| { + vortex_err!( + "L2Denorm normalized child must be a Vector or NormalizedVector, got {}", + normalized.dtype(), + ) + })?; let row_count = normalized.len(); - if row_count == 0 { - return Ok(()); - } - - let tensor_match = validate_tensor_float_input(normalized.dtype())?; let element_ptype = tensor_match.element_ptype(); + let tolerance = unit_norm_tolerance(element_ptype, tensor_match.list_size() as usize); let tensor_flat_size = tensor_match.list_size() as usize; - let tolerance = unit_norm_tolerance(element_ptype, tensor_flat_size); + let skip_unit_norm_check = tensor_match.is_normalized(); - if let Some(norms) = norms { - vortex_ensure_eq!( - norms.dtype().as_ptype(), - element_ptype, - "L2Denorm norms ptype must match normalized element ptype" + vortex_ensure_eq!( + norms.len(), + row_count, + "L2Denorm normalized and norms children must have the same length" + ); + + let DType::Primitive(norms_ptype, _) = norms.dtype() else { + vortex_bail!( + "L2Denorm norms must be a primitive float array, got {}", + norms.dtype() ); - } + }; + vortex_ensure_eq!( + *norms_ptype, + element_ptype, + "L2Denorm norms ptype must match normalized element ptype" + ); - let normalized: ExtensionArray = normalized.clone().execute(ctx)?; - 
let normalized_validity = normalized.as_ref().validity()?; + if row_count == 0 { + return Ok(()); + } - let flat = extract_flat_elements(normalized.storage_array(), tensor_flat_size, ctx)?; - let norms = norms - .map(|norms| norms.clone().execute::(ctx)) - .transpose()?; + // Drill past any outer `NormalizedVector` wrapper so we always iterate the FSL of the + // inner plain `Vector`. + let vector_ref = inner_vector_array(normalized, ctx)?; + let vector_ext: ExtensionArray = vector_ref.execute(ctx)?; + let normalized_validity = vector_ext.as_ref().validity()?; - let combined_validity = match &norms { - Some(norms) => normalized_validity.and(norms.validity()?)?, - None => normalized_validity, - }; + let flat = extract_flat_elements(vector_ext.storage_array(), tensor_flat_size, ctx)?; + let norms_prim: PrimitiveArray = norms.clone().execute(ctx)?; + let combined_validity = normalized_validity.and(norms_prim.validity()?)?; match_each_float_ptype!(element_ptype, |T| { - let stored_norms = norms.as_ref().map(|norms| norms.as_slice::()); + let stored_norms = norms_prim.as_slice::(); for i in 0..row_count { if !combined_validity.is_valid(i)? 
{ continue; } + let stored_norm_f64 = ToPrimitive::to_f64(&stored_norms[i]).unwrap_or(f64::NAN); + vortex_ensure!( + stored_norm_f64 >= 0.0, + "L2Denorm norms must be non-negative, but row {i} has {stored_norm_f64:.6}", + ); + let (row_norm_sq, is_zero_row) = flat.row::(i) .iter() - .fold((0.0f64, true), |(sum_sq, is_zero), x| { + .fold((0.0f64, true), |(sum_sq, all_zero), x| { let value = ToPrimitive::to_f64(x).unwrap_or(f64::NAN); - (sum_sq + value * value, is_zero && value.abs() <= tolerance) + (sum_sq + value * value, all_zero && value.abs() <= tolerance) }); - let row_norm = row_norm_sq.sqrt(); - - vortex_ensure!( - row_norm == 0.0 || (row_norm - 1.0).abs() <= tolerance, - "L2Denorm normalized child must have L2 norm 1.0 or 0.0, but row {i} has \ - {row_norm:.6}", - ); - if let Some(stored_norms) = stored_norms { - let stored_norm_f64 = ToPrimitive::to_f64(&stored_norms[i]).unwrap_or(f64::NAN); + if !skip_unit_norm_check { + let row_norm = row_norm_sq.sqrt(); vortex_ensure!( - stored_norm_f64 >= 0.0, - "L2Denorm norms must be non-negative, but row {i} has {stored_norm_f64:.6}", + row_norm == 0.0 || (row_norm - 1.0).abs() <= tolerance, + "L2Denorm normalized child row {i} has L2 norm {row_norm:.6}, \ + expected 1.0 or 0.0", ); + } - if stored_norm_f64 == 0.0 { - vortex_ensure!( - is_zero_row, - "L2Denorm normalized child must be all zeros when norms row {i} is 0.0", - ); - } + if stored_norm_f64 == 0.0 { + vortex_ensure!( + is_zero_row, + "L2Denorm normalized child must be all zeros when norms row {i} is 0.0", + ); } } }); @@ -694,47 +834,59 @@ pub fn validate_l2_normalized_rows_against_norms( Ok(()) } -/// Classification of a binary operand pair by which side (if any) is wrapped in [`L2Denorm`]. +/// Per-operand classification of a tensor argument by whether it carries an authoritative unit-norm +/// representation, and whether stored norms accompany it. /// -/// Symmetric binary tensor operators (e.g. 
[`CosineSimilarity`], [`InnerProduct`]) have identical -/// fast paths for "only the lhs is denormalized" and "only the rhs is denormalized", and a separate -/// fast path for "both are denormalized". Rather than hand-rolling the commutative swap at every -/// call site, callers classify their operands with [`Self::classify`] and pattern-match on the -/// returned variant. +/// Symmetric binary tensor operators ([`CosineSimilarity`], [`InnerProduct`]) and unary ones +/// ([`L2Norm`]) take a fast path whenever an operand carries a unit-norm representation. Callers +/// classify each operand individually via [`Self::classify`] and pattern-match on the resulting +/// variant. /// /// [`CosineSimilarity`]: crate::scalar_fns::cosine_similarity::CosineSimilarity /// [`InnerProduct`]: crate::scalar_fns::inner_product::InnerProduct -pub(crate) enum DenormOrientation<'a> { - /// Both operands are [`ExactScalarFn`] arrays. - Both { - lhs: &'a ArrayRef, - rhs: &'a ArrayRef, - }, - /// Exactly one operand is an [`ExactScalarFn`]; the other is plain. - One { - denorm: &'a ArrayRef, - plain: &'a ArrayRef, +pub(crate) enum NormalForm<'a> { + /// A plain `Vector`. + Plain, + + /// An already-normalized `NormalizedVector`, which has implicit norms of `1.0`. + Normalized { array: &'a ArrayRef }, + + /// Decomposed `L2Denorm(normalized: NormalizedVector, norms)`. + /// + /// Note that `normalized` is _always_ non-null, and the validity is stored in `norms`. + Denormalized { + normalized: ArrayRef, + norms: ArrayRef, }, - /// Neither operand is an [`ExactScalarFn`]. - Neither, } -impl<'a> DenormOrientation<'a> { - /// Classify `(lhs, rhs)` by which side (if any) is wrapped in [`L2Denorm`]. 
- pub(crate) fn classify(lhs: &'a ArrayRef, rhs: &'a ArrayRef) -> Self { - let lhs_denorm = lhs.is::>(); - let rhs_denorm = rhs.is::>(); - match (lhs_denorm, rhs_denorm) { - (true, true) => Self::Both { lhs, rhs }, - (true, false) => Self::One { - denorm: lhs, - plain: rhs, - }, - (false, true) => Self::One { - denorm: rhs, - plain: lhs, - }, - (false, false) => Self::Neither, +impl<'a> NormalForm<'a> { + /// Classify `array` by its tensor extension type and (if present) any wrapping `L2Denorm`. + pub(crate) fn classify(array: &'a ArrayRef) -> Self { + if array.is::>() { + let (normalized, norms) = extract_l2_denorm_children(array); + return Self::Denormalized { normalized, norms }; + } + + let is_normalized = array + .dtype() + .as_extension_opt() + .is_some_and(|ext| ext.is::()); + + if is_normalized { + Self::Normalized { array } + } else { + Self::Plain + } + } + + /// Returns the unit-norm "shape" of the operand suitable for inner-product fast paths, if + /// one exists. For [`Self::Plain`] this returns `None`. + pub(crate) fn normalized_array(&self) -> Option<&ArrayRef> { + match self { + Self::Plain => None, + Self::Normalized { array } => Some(array), + Self::Denormalized { normalized, .. } => Some(normalized), } } } @@ -769,17 +921,25 @@ mod tests { use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; - use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows_against_norms; use crate::tests::SESSION; + use crate::types::normalized_vector::NormalizedVector; use crate::types::vector::Vector; use crate::utils::test_helpers::assert_close; - use crate::utils::test_helpers::constant_tensor_array; - use crate::utils::test_helpers::tensor_array; + use crate::utils::test_helpers::normalized_vector_array; use crate::utils::test_helpers::vector_array; - /// Evaluates L2 denorm on a tensor/vector array and returns the executed array. 
- fn eval_l2_denorm(normalized: ArrayRef, norms: ArrayRef, len: usize) -> VortexResult { + /// Evaluates L2 denorm on a [`Vector`] (rewrapped as a [`NormalizedVector`]) and the matching + /// norms, returning the executed array. Convenience wrapper for tests that already hold a + /// pre-normalized [`Vector`] (possibly wrapped in another encoding such as `MaskedArray`). + fn eval_l2_denorm( + vector_input: ArrayRef, + norms: ArrayRef, + len: usize, + ) -> VortexResult { let mut ctx = SESSION.create_execution_ctx(); + let canonical: ExtensionArray = vector_input.execute(&mut ctx)?; + let storage = canonical.storage_array().clone(); + let normalized = NormalizedVector::try_new(storage, &mut ctx)?; let result = L2Denorm::try_new_array(normalized, norms, len, &mut ctx)?; result.into_array().execute(&mut ctx) } @@ -827,17 +987,6 @@ mod tests { Ok(()) } - #[test] - fn l2_denorm_fixed_shape_tensors() -> VortexResult<()> { - let lhs = tensor_array(&[2, 2], &[0.5, 0.5, 0.5, 0.5, 1.0, 0.0, 0.0, 0.0])?; - let rhs = PrimitiveArray::from_iter([4.0f64, 2.0]).into_array(); - let actual = eval_l2_denorm(lhs, rhs, 2)?; - let expected = tensor_array(&[2, 2], &[2.0, 2.0, 2.0, 2.0, 2.0, 0.0, 0.0, 0.0])?; - - assert_tensor_arrays_eq(actual, expected)?; - Ok(()) - } - #[test] fn l2_denorm_null_propagation() -> VortexResult<()> { let lhs = vector_array(2, &[0.6, 0.8, 1.0, 0.0, 0.0, 0.0])?; @@ -878,10 +1027,21 @@ mod tests { } #[test] - fn l2_denorm_rejects_integer_tensor_lhs() -> VortexResult<()> { - let lhs = tensor_array(&[2], &[1i32, 2, 3, 4])?; + fn l2_denorm_accepts_plain_unit_vector_lhs() -> VortexResult<()> { + let lhs = vector_array(2, &[1.0, 0.0, 0.0, 1.0])?; let rhs = PrimitiveArray::from_iter([1.0f64, 1.0]).into_array(); + let mut ctx = SESSION.create_execution_ctx(); + let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); + assert!(result.is_ok()); + Ok(()) + } + + #[test] + fn l2_denorm_rejects_unnormalized_plain_vector_lhs() -> VortexResult<()> { + let lhs = 
vector_array(2, &[3.0, 4.0, 0.0, 1.0])?; + let rhs = PrimitiveArray::from_iter([5.0f64, 1.0]).into_array(); + let mut ctx = SESSION.create_execution_ctx(); let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); assert!(result.is_err()); @@ -890,49 +1050,61 @@ mod tests { #[test] fn l2_denorm_rejects_mismatched_rhs_ptype() -> VortexResult<()> { - let lhs = vector_array(2, &[1.0, 0.0, 0.0, 1.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = normalized_vector_array(2, &[1.0, 0.0, 0.0, 1.0], &mut ctx)?; let rhs = PrimitiveArray::from_iter([1.0f32, 1.0]).into_array(); - let mut ctx = SESSION.create_execution_ctx(); let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); assert!(result.is_err()); Ok(()) } #[test] - fn validate_l2_normalized_rows_accepts_normalized_f16_input() -> VortexResult<()> { - let input = vector_array(2, &[3.0f32, 4.0, 0.0, 0.0].map(half::f16::from_f32))?; + fn l2_denorm_rejects_non_primitive_rhs_without_panic() -> VortexResult<()> { let mut ctx = SESSION.create_execution_ctx(); - let roundtrip = normalize_as_l2_denorm(input, &mut ctx)?; - validate_l2_normalized_rows_against_norms(&roundtrip.child_at(0).clone(), None, &mut ctx)?; + let lhs = normalized_vector_array(2, &[1.0, 0.0, 0.0, 1.0], &mut ctx)?; + let rhs = vector_array(2, &[1.0f64, 0.0, 0.0, 1.0])?; + + let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); + assert!(result.is_err()); Ok(()) } #[test] - fn validate_l2_normalized_rows_rejects_unnormalized_input() -> VortexResult<()> { - let input = vector_array(2, &[3.0, 4.0, 1.0, 0.0])?; + fn l2_denorm_rejects_length_mismatch_without_panic() -> VortexResult<()> { let mut ctx = SESSION.create_execution_ctx(); - let result = validate_l2_normalized_rows_against_norms(&input, None, &mut ctx); + let lhs = normalized_vector_array(2, &[1.0, 0.0, 0.0, 1.0], &mut ctx)?; + let rhs = PrimitiveArray::from_iter([1.0f64]).into_array(); + + let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); 
assert!(result.is_err()); Ok(()) } #[test] - fn l2_denorm_try_new_array_rejects_unnormalized_child() -> VortexResult<()> { - let normalized = vector_array(2, &[3.0, 4.0, 1.0, 0.0])?; - let norms = PrimitiveArray::from_iter([5.0f64, 1.0]).into_array(); + fn normalized_vector_try_new_accepts_normalized_f16_input() -> VortexResult<()> { let mut ctx = SESSION.create_execution_ctx(); + let elements = [3.0f32, 4.0, 0.0, 0.0].map(half::f16::from_f32); + let roundtrip = normalize_as_l2_denorm(vector_array(2, &elements)?, &mut ctx)?; + // The first child is already a `NormalizedVector` by construction. + let normalized = roundtrip.child_at(0).clone(); + assert!(normalized.dtype().as_extension().is::(),); + Ok(()) + } - let result = L2Denorm::try_new_array(normalized, norms, 2, &mut ctx); + #[test] + fn normalized_vector_try_new_rejects_unnormalized_input() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let result = normalized_vector_array(2, &[3.0, 4.0, 1.0, 0.0], &mut ctx); assert!(result.is_err()); Ok(()) } #[test] fn l2_denorm_try_new_array_rejects_nonzero_row_with_zero_norm() -> VortexResult<()> { - let normalized = vector_array(2, &[1.0, 0.0, 0.0, 0.0])?; - let norms = PrimitiveArray::from_iter([0.0f64, 0.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); + let normalized = normalized_vector_array(2, &[1.0, 0.0, 0.0, 0.0], &mut ctx)?; + let norms = PrimitiveArray::from_iter([0.0f64, 0.0]).into_array(); let result = L2Denorm::try_new_array(normalized, norms, 2, &mut ctx); assert!(result.is_err()); @@ -941,9 +1113,9 @@ mod tests { #[test] fn l2_denorm_try_new_array_rejects_negative_norms() -> VortexResult<()> { - let normalized = vector_array(2, &[1.0, 0.0, 0.0, 1.0])?; - let norms = PrimitiveArray::from_iter([1.0f64, -1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); + let normalized = normalized_vector_array(2, &[1.0, 0.0, 0.0, 1.0], &mut ctx)?; + let norms = PrimitiveArray::from_iter([1.0f64, -1.0]).into_array(); 
let result = L2Denorm::try_new_array(normalized, norms, 2, &mut ctx); assert!(result.is_err()); @@ -951,10 +1123,14 @@ mod tests { } #[test] - fn l2_denorm_new_array_unchecked_accepts_unnormalized_child() -> VortexResult<()> { - let normalized = vector_array(2, &[3.0, 4.0, 1.0, 0.0])?; - let norms = PrimitiveArray::from_iter([5.0f64, 1.0]).into_array(); + fn l2_denorm_new_array_unchecked_skips_zero_row_cross_check() -> VortexResult<()> { + // `L2Denorm::new_array_unchecked` accepts a NormalizedVector + norms without the zero-norm + // cross-check; useful for lossy encodings (e.g. TurboQuant). + let mut ctx = SESSION.create_execution_ctx(); + let normalized = normalized_vector_array(2, &[1.0, 0.0, 0.0, 1.0], &mut ctx)?; + let norms = PrimitiveArray::from_iter([0.0f64, 1.0]).into_array(); + // SAFETY: This test intentionally exercises the lossy escape hatch. let result = unsafe { L2Denorm::new_array_unchecked(normalized, norms, 2) }; assert!(result.is_ok()); Ok(()) @@ -971,28 +1147,6 @@ mod tests { Ok(()) } - #[test] - fn normalize_as_l2_denorm_roundtrips_fixed_shape_tensors() -> VortexResult<()> { - let input = tensor_array(&[2, 2], &[1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0])?; - let mut ctx = SESSION.create_execution_ctx(); - let roundtrip = normalize_as_l2_denorm(input.clone(), &mut ctx)?; - let actual = roundtrip.into_array().execute(&mut ctx)?; - - assert_tensor_arrays_eq(actual, input)?; - Ok(()) - } - - #[test] - fn normalize_as_l2_denorm_supports_constant_tensors() -> VortexResult<()> { - let input = constant_tensor_array(&[2], &[3.0, 4.0], 3)?; - let mut ctx = SESSION.create_execution_ctx(); - let roundtrip = normalize_as_l2_denorm(input.clone(), &mut ctx)?; - let actual = roundtrip.into_array().execute(&mut ctx)?; - - assert_tensor_arrays_eq(actual, input)?; - Ok(()) - } - #[test] fn normalize_as_l2_denorm_supports_constant_vectors() -> VortexResult<()> { let input = Vector::constant_array(&[3.0, 4.0], 2)?; @@ -1013,16 +1167,18 @@ mod tests { let mut ctx = 
SESSION.create_execution_ctx(); let roundtrip = normalize_as_l2_denorm(input, &mut ctx)?; - // The normalized child must be an extension array whose storage is still constant. + // The normalized child is a `NormalizedVector(Vector(Constant))`. Drill past both + // extension layers and confirm the innermost FSL storage is still constant-backed. let normalized = roundtrip.child_at(0).clone(); let normalized_ext = normalized .as_opt::() .expect("normalized child should be an Extension array"); + let inner_vector = normalized_ext + .storage_array() + .as_opt::() + .expect("NormalizedVector storage should be a Vector extension array"); assert!( - normalized_ext - .storage_array() - .as_opt::() - .is_some(), + inner_vector.storage_array().as_opt::().is_some(), "normalized storage should stay constant after the fast path" ); @@ -1047,8 +1203,11 @@ mod tests { let input = vector_array(2, &[0.0, 0.0, 3.0, 4.0])?; let mut ctx = SESSION.create_execution_ctx(); let roundtrip = normalize_as_l2_denorm(input.clone(), &mut ctx)?; + // Normalized child is a `NormalizedVector` wrapping a `Vector` wrapping the FSL; drill + // past the outer `NormalizedVector` to reach the underlying `Vector`. let normalized: ExtensionArray = roundtrip.child_at(0).clone().execute(&mut ctx)?; - let storage: FixedSizeListArray = normalized.storage_array().clone().execute(&mut ctx)?; + let vector: ExtensionArray = normalized.storage_array().clone().execute(&mut ctx)?; + let storage: FixedSizeListArray = vector.storage_array().clone().execute(&mut ctx)?; let elements: PrimitiveArray = storage.elements().clone().execute(&mut ctx)?; let actual = roundtrip.into_array().execute(&mut ctx)?; @@ -1057,6 +1216,56 @@ mod tests { Ok(()) } + /// `NormalizedVector` input takes the fast path: re-applying norms must reconstruct the + /// original element values bit-for-bit (since the divisor in the slow path would be 1.0 + /// for unit rows and 0.0 for zero rows). 
+ #[test] + fn normalize_as_l2_denorm_normalized_vector_round_trips_values() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let elements = [0.6, 0.8, 0.0, 0.0, 1.0, 0.0]; + let input = normalized_vector_array(2, &elements, &mut ctx)?; + + let roundtrip = normalize_as_l2_denorm(input, &mut ctx)?; + let executed: ExtensionArray = roundtrip.into_array().execute(&mut ctx)?; + let storage: FixedSizeListArray = executed.storage_array().clone().execute(&mut ctx)?; + let executed_elements: PrimitiveArray = storage.elements().clone().execute(&mut ctx)?; + + assert_close(executed_elements.as_slice::(), &elements); + Ok(()) + } + + /// The `NormalizedVector` fast path borrows `L2Norm`'s short-circuit, which emits `1.0` for + /// unit rows and `0.0` for zero rows. Tag the zero row with norm `0.0` here (not `1.0`) so a + /// downstream `L2Norm` over the resulting `L2Denorm` continues to read the right value. + #[test] + fn normalize_as_l2_denorm_normalized_vector_emits_unit_or_zero_norms() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let input = normalized_vector_array(2, &[0.6, 0.8, 0.0, 0.0, 1.0, 0.0], &mut ctx)?; + + let l2_denorm = normalize_as_l2_denorm(input, &mut ctx)?; + let norms: PrimitiveArray = l2_denorm.child_at(1).clone().execute(&mut ctx)?; + + assert_close(norms.as_slice::(), &[1.0, 0.0, 1.0]); + Ok(()) + } + + /// The `NormalizedVector` fast path returns the input unchanged as the `normalized` child + /// rather than re-allocating storage to satisfy the slow path's "always non-nullable" + /// invariant. Verify that the child dtype is still a `NormalizedVector` extension after the + /// fast path. 
+ #[test] + fn normalize_as_l2_denorm_normalized_vector_preserves_normalized_child_dtype() + -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let input = normalized_vector_array(2, &[0.6, 0.8, 1.0, 0.0], &mut ctx)?; + + let l2_denorm = normalize_as_l2_denorm(input, &mut ctx)?; + let normalized = l2_denorm.child_at(0).clone(); + + assert!(normalized.dtype().as_extension().is::()); + Ok(()) + } + /// Builds a non-nullable constant f64 norms array of length `len`. fn constant_f64_norms(value: f64, len: usize) -> ArrayRef { ConstantArray::new(Scalar::primitive(value, Nullability::NonNullable), len).into_array() @@ -1099,16 +1308,33 @@ mod tests { Ok(()) } + /// Regression: a non-nullable [`NormalizedVector`] child paired with a nullable-dtype + /// constant norms array (whose value happens to be non-null `1.0`) used to panic in the + /// constant-unit fast path because the extension's declared storage nullability no longer + /// matched the storage array's own nullability. The fix is on the [`ExtensionArray`] side, + /// where storage-dtype matching will ignore outer nullability. That relaxation is not yet on + /// this branch, so the test is ignored until the `ExtensionArray::try_new` change lands. #[test] - fn l2_denorm_constant_nonunit_norms_scales_fixed_shape_tensors() -> VortexResult<()> { - // The same constant-scaling fast path must also cover multi-dimensional fixed-shape - // tensors, where the backing elements buffer spans more than one slot per row. 
- let normalized = tensor_array(&[2, 2], &[0.5, 0.5, 0.5, 0.5, 1.0, 0.0, 0.0, 0.0])?; - let norms = constant_f64_norms(4.0, 2); + #[ignore = "depends on ExtensionArray::try_new ignoring outer storage nullability"] + fn l2_denorm_constant_unit_norms_nullable_scalar_nonnullable_normalized() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let normalized = normalized_vector_array(3, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &mut ctx)?; + let nullable_unit_norms = + ConstantArray::new(Scalar::primitive(1.0f64, Nullability::Nullable), 2).into_array(); - let actual = eval_l2_denorm(normalized, norms, 2)?; - let expected = tensor_array(&[2, 2], &[2.0, 2.0, 2.0, 2.0, 4.0, 0.0, 0.0, 0.0])?; - assert_tensor_arrays_eq(actual, expected)?; + let result = L2Denorm::try_new_array(normalized, nullable_unit_norms, 2, &mut ctx)?; + let actual: ArrayRef = result.into_array().execute(&mut ctx)?; + + // The output surfaces as a plain `Vector` whose outer nullability is the union of the + // two children (nullable here, since the norms child was nullable). + assert!(actual.dtype().as_extension().is::()); + assert!(actual.dtype().is_nullable()); + + // The element values round-trip: scaling unit vectors by `1.0` is a no-op. + let ext: ExtensionArray = actual.execute(&mut ctx)?; + let storage: FixedSizeListArray = ext.storage_array().clone().execute(&mut ctx)?; + let elements: PrimitiveArray = storage.elements().clone().execute(&mut ctx)?; + assert_close(elements.as_slice::(), &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0]); Ok(()) } @@ -1117,8 +1343,7 @@ mod tests { /// inherits the input's nullability, giving us two different per-child nullabilities to /// round-trip. 
#[rstest] - #[case::vector(l2_denorm_vector_input())] - #[case::fixed_shape_tensor(l2_denorm_tensor_input())] + #[case::vector(vector_array(3, &[3.0, 4.0, 0.0, 0.0, 0.0, 0.0]).unwrap())] fn serde_round_trip(#[case] input: ArrayRef) -> VortexResult<()> { let mut ctx = SESSION.create_execution_ctx(); let original = normalize_as_l2_denorm(input, &mut ctx)?.into_array(); @@ -1145,13 +1370,4 @@ mod tests { assert_eq!(recovered.encoding_id(), original.encoding_id()); Ok(()) } - - fn l2_denorm_vector_input() -> ArrayRef { - vector_array(3, &[3.0, 4.0, 0.0, 0.0, 0.0, 0.0]).expect("valid vector array") - } - - fn l2_denorm_tensor_input() -> ArrayRef { - tensor_array(&[2, 2], &[1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0]) - .expect("valid tensor array") - } } diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index 5d741eef55e..67b196eba4e 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -6,6 +6,8 @@ use std::fmt::Formatter; use num_traits::Float; +use num_traits::One; +use num_traits::Zero; use prost::Message; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; @@ -17,7 +19,6 @@ use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFn as ScalarFnArrayEncoding; use vortex_array::arrays::ScalarFnArray; use vortex_array::arrays::extension::ExtensionArrayExt; -use vortex_array::arrays::scalar_fn::ExactScalarFn; use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::arrays::scalar_fn::ScalarFnArrayView; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; @@ -25,6 +26,7 @@ use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; use vortex_array::dtype::proto::dtype as pb; use vortex_array::expr::Expression; use vortex_array::match_each_float_ptype; @@ -40,14 +42,12 @@ use 
vortex_array::serde::ArrayChildren; use vortex_buffer::Buffer; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_ensure_eq; use vortex_error::vortex_err; use vortex_session::VortexSession; use crate::matcher::AnyTensor; -use crate::scalar_fns::l2_denorm::L2Denorm; +use crate::scalar_fns::l2_denorm::NormalForm; use crate::utils::extract_flat_elements; -use crate::utils::extract_l2_denorm_children; use crate::utils::validate_tensor_float_input; /// L2 norm (Euclidean norm) of a tensor or vector column. @@ -57,10 +57,11 @@ use crate::utils::validate_tensor_float_input; /// The input must be a tensor-like extension array with a float element type. The output is a float /// column of the same float type. /// -/// When the input is wrapped in [`L2Denorm`], this operator treats the stored norms as -/// authoritative. For lossy encodings such as TurboQuant, that means `L2Norm` may intentionally -/// read the stored norms instead of re-deriving them from fully decoded coordinates. That behavior -/// is part of the lossy storage contract, not a separate lossy-compute mode. +/// When the input is wrapped in [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm), this operator +/// treats the stored norms as authoritative. For lossy encodings such as TurboQuant, that means +/// `L2Norm` may intentionally read the stored norms instead of re-deriving them from fully decoded +/// coordinates. That behavior is part of the lossy storage contract, not a separate lossy-compute +/// mode. #[derive(Clone)] pub struct L2Norm; @@ -115,6 +116,7 @@ impl ScalarFnVTable for L2Norm { let tensor_match = validate_tensor_float_input(input_dtype)?; let ptype = tensor_match.element_ptype(); + // Inherit the nullability from the vectors themselves. 
let nullability = Nullability::from(input_dtype.is_nullable()); Ok(DType::Primitive(ptype, nullability)) } @@ -137,13 +139,23 @@ impl ScalarFnVTable for L2Norm { let norm_dtype = DType::Primitive(element_ptype, ext.nullability()); - // L2Norm(L2Denorm(normalized, norms)) is defined to read back the authoritative stored - // norms. Exact callers of lossy encodings like TurboQuant opt into that storage semantics - // instead of forcing a decode-and-recompute path here. - if input_ref.is::>() { - let (_, norms) = extract_l2_denorm_children(&input_ref); - vortex_ensure_eq!(norms.dtype(), &norm_dtype); - return Ok(norms); + // Short-circuit when the input carries a unit-norm representation already. + match NormalForm::classify(&input_ref) { + NormalForm::Denormalized { norms, .. } => { + return Ok(norms); + } + NormalForm::Normalized { .. } => { + // A naked `NormalizedVector` row is either unit norm or the zero vector by type. + // We still have to distinguish those two cases and preserve row validity. + return execute_normalized_vector_norms( + &input_ref, + element_ptype, + tensor_flat_size, + row_count, + ctx, + ); + } + NormalForm::Plain => {} } // Optimize for the constant array case. @@ -172,6 +184,9 @@ impl ScalarFnVTable for L2Norm { return Ok(norms); } + // Drill past any `NormalizedVector` wrapper so we always work with the underlying + // `Vector` extension array. + let input_ref = crate::types::normalized_vector::inner_vector_array(&input_ref, ctx)?; let input: ExtensionArray = input_ref.execute(ctx)?; let validity = input.as_ref().validity()?; @@ -261,6 +276,38 @@ fn l2_norm_row(v: &[T]) -> T { sum_sq.sqrt() } +/// Computes L2 norms for a [`NormalizedVector`](crate::normalized_vector::NormalizedVector) +/// without taking square roots: valid rows are either all-zero (`0.0`) or unit-norm (`1.0`). 
+fn execute_normalized_vector_norms( + input_ref: &ArrayRef, + element_ptype: PType, + tensor_flat_size: usize, + row_count: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult { + // `NormalizedVector` storage is `Extension(Vector(FSL))`; drill to the inner `Vector` to + // reach the underlying FSL. + let vector_ref = crate::types::normalized_vector::inner_vector_array(input_ref, ctx)?; + let input: ExtensionArray = vector_ref.execute(ctx)?; + let validity = input.as_ref().validity()?; + let flat = extract_flat_elements(input.storage_array(), tensor_flat_size, ctx)?; + + match_each_float_ptype!(element_ptype, |T| { + let buffer: Buffer = (0..row_count) + .map(|i| { + if flat.row::(i).iter().all(|&x| x == T::zero()) { + T::zero() + } else { + T::one() + } + }) + .collect(); + + // SAFETY: The buffer length equals `row_count`, which matches the source validity length. + Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array()) + }) +} + #[cfg(test)] mod tests { @@ -289,6 +336,7 @@ mod tests { use crate::types::vector::Vector; use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::literal_vector_array; + use crate::utils::test_helpers::normalized_vector_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; @@ -418,8 +466,8 @@ mod tests { } #[rstest] - #[case::fixed_shape_tensor(l2_norm_tensor_child(), 2)] - #[case::vector(l2_norm_vector_child(), 2)] + #[case::fixed_shape_tensor(tensor_array(&[3], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(), 2)] + #[case::vector(vector_array(3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(), 2)] fn serde_round_trip(#[case] child: ArrayRef, #[case] len: usize) -> VortexResult<()> { let original = L2Norm::try_new_array(child.clone(), len)?.into_array(); @@ -444,11 +492,69 @@ mod tests { Ok(()) } - fn l2_norm_tensor_child() -> ArrayRef { - tensor_array(&[3], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("valid tensor array") + /// A naked 
[`NormalizedVector`](crate::normalized_vector::NormalizedVector) input must + /// short-circuit to `1.0` for unit rows and `0.0` for zero rows without taking square roots. + #[test] + fn naked_normalized_vector_returns_unit_norms() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let input = normalized_vector_array(2, &[1.0, 0.0, 0.6, 0.8, 0.0, 0.0], &mut ctx)?; + assert_close(&eval_l2_norm(input, 3)?, &[1.0, 1.0, 0.0]); + Ok(()) + } + + #[test] + fn naked_normalized_vector_preserves_nulls() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let input = normalized_vector_array(2, &[1.0, 0.0, 0.0, 0.0], &mut ctx)?; + let input = MaskedArray::try_new(input, Validity::from_iter([true, false]))?.into_array(); + + let result = L2Norm::try_new_array(input, 2)?; + let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; + + assert!(prim.is_valid(0, &mut ctx)?); + assert!(!prim.is_valid(1, &mut ctx)?); + assert_close(&[prim.as_slice::()[0]], &[1.0]); + Ok(()) } - fn l2_norm_vector_child() -> ArrayRef { - vector_array(3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("valid vector array") + /// `L2Norm(L2Denorm(normalized, norms))` must return the stored norms verbatim — that is the + /// `NormalForm::Denormalized` short-circuit's whole purpose. We use a deliberately oddball + /// norm value (`7.0`) that no row could plausibly produce from a unit-norm child, so a + /// regression that fell through to the recompute path would round-trip a different number. 
+ #[test] + fn denormalized_input_returns_stored_norms() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let l2_denorm = crate::utils::test_helpers::l2_denorm_array( + 2, + &[1.0, 0.0, 0.6, 0.8], + &[7.0, 5.0], + &mut ctx, + )?; + + let result = L2Norm::try_new_array(l2_denorm, 2)?; + let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; + + assert_close(prim.as_slice::(), &[7.0, 5.0]); + Ok(()) + } + + /// The `Denormalized` short-circuit must propagate null rows in the stored norms child, + /// since validity on a `L2Denorm` lives entirely in its norms. + #[test] + fn denormalized_input_preserves_norm_nulls() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let normalized = normalized_vector_array(2, &[0.6, 0.8, 1.0, 0.0], &mut ctx)?; + let norms = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array(); + let l2_denorm = + crate::scalar_fns::l2_denorm::L2Denorm::try_new_array(normalized, norms, 2, &mut ctx)? + .into_array(); + + let result = L2Norm::try_new_array(l2_denorm, 2)?; + let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; + + assert!(prim.is_valid(0, &mut ctx)?); + assert!(!prim.is_valid(1, &mut ctx)?); + assert_close(&[prim.as_slice::()[0]], &[5.0]); + Ok(()) } } diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs b/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs index 26d38e87a1e..a6862465eba 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs @@ -8,10 +8,10 @@ //! Walsh-Hadamard transform to achieve O(d log d) matrix-vector products instead of the O(d^2) cost //! of a dense orthogonal matrix. //! -//! This module wraps a [`Vector`] extension array whose dimension is the padded SORF dimension -//! (e.g. a `Vector` wrapping `FSL(Dict(codes, centroids))`) and applies the inverse SORF transform -//! 
at execution time, producing a [`Vector`] extension array with the original (pre-padding) -//! dimensionality. +//! This module wraps a [`Vector`] or [`NormalizedVector`] extension array whose dimension is the +//! padded SORF dimension (e.g. a `Vector` wrapping `FSL(Dict(codes, centroids))`) and applies the +//! inverse SORF transform at execution time, producing a plain [`Vector`] extension array with the +//! original (pre-padding) dimensionality. //! //! The transform parameters are stored as a deterministic seed in [`SorfOptions`], so the //! [`SorfMatrix`] is reconstructed cheaply at decode time. Sign diagonals are defined by Vortex's @@ -19,9 +19,9 @@ //! //! # Input element type: `f32` only (TODO(connor): for now...) //! -//! The child [`Vector`] **must** have `f32` storage elements. This is a hard constraint that is -//! enforced by `SorfTransform`'s `return_dtype` check. Callers with `f16` or `f64` source data need -//! to cast to `f32` before wrapping in a [`Vector`] and handing it to SorfTransform. +//! The child vector extension **must** have `f32` storage elements. This is a hard constraint that +//! is enforced by `SorfTransform`'s `return_dtype` check. Callers with `f16` or `f64` source data +//! need to cast to `f32` before wrapping in a vector extension and handing it to SorfTransform. //! //! The reason for this constraint is that TurboQuant (the only production caller today) stores its //! dictionary centroids as `f32`, and the SORF transform itself operates internally in `f32`. @@ -40,6 +40,7 @@ //! //! [sorf-paper]: https://proceedings.neurips.cc/paper_files/paper/2016/file/53adaf494dc89ef7196d73636eb2451b-Paper.pdf //! [`Vector`]: crate::vector::Vector +//! [`NormalizedVector`]: crate::normalized_vector::NormalizedVector use std::fmt; use std::fmt::Formatter; @@ -59,10 +60,10 @@ mod vtable; /// Inverse SORF orthogonal transform scalar function. 
/// -/// Takes a [`Vector`](crate::vector::Vector) extension child at the padded dimension with `f32` -/// storage, applies the inverse structured Walsh-Hadamard orthogonal transform, truncates to the -/// original (pre-padding) dimension, casts element-wise to [`SorfOptions::element_ptype`], and -/// wraps the result in a new [`Vector`](crate::vector::Vector) extension array. +/// Takes a vector extension child at the padded dimension with `f32` storage, applies the inverse +/// structured Walsh-Hadamard orthogonal transform, truncates to the original (pre-padding) +/// dimension, casts element-wise to [`SorfOptions::element_ptype`], and wraps the result in a new +/// plain [`Vector`](crate::vector::Vector) extension array. /// /// See the [module-level docs](crate::scalar_fns::sorf_transform) for the rationale behind the /// `f32`-only input constraint. @@ -96,16 +97,18 @@ impl SorfTransform { /// Constructs a validated [`ScalarFnArray`] that lazily applies the inverse SORF transform. /// - /// The `child` must be a [`Vector`] extension array (or an array that executes to one) with: + /// The `child` must be a [`Vector`] or [`NormalizedVector`] extension array (or an array that + /// executes to one) with: /// /// - dimension equal to `padded_dim` (i.e. `options.dimension.next_power_of_two()`), and /// - `f32` storage elements. This is a hard requirement today; see the /// [module-level docs](crate::scalar_fns::sorf_transform) for the rationale. /// - /// The output [`Vector`] has dimension `options.dimension` and element type + /// The output plain [`Vector`] has dimension `options.dimension` and element type /// `options.element_ptype`. 
/// /// [`Vector`]: crate::vector::Vector + /// [`NormalizedVector`]: crate::normalized_vector::NormalizedVector pub fn try_new_array( options: &SorfOptions, child: ArrayRef, diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs index 46abc66db71..0934657dc7c 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs @@ -7,6 +7,7 @@ use std::sync::Arc; +use prost::Message; use vortex_array::ArrayPlugin; use vortex_array::ArrayRef; use vortex_array::IntoArray; @@ -14,9 +15,11 @@ use vortex_array::VortexSessionExecute; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::ScalarFn as ScalarFnArrayEncoding; use vortex_array::arrays::dict::DictArray; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; @@ -36,6 +39,7 @@ use crate::encodings::turboquant::centroids::compute_centroid_boundaries; use crate::encodings::turboquant::centroids::compute_or_get_centroids; use crate::encodings::turboquant::centroids::find_nearest_centroid; use crate::tests::SESSION; +use crate::types::normalized_vector::NormalizedVector; use crate::types::vector::Vector; /// Build a unit-normalized input vector array and forward-transform + quantize it, returning @@ -361,6 +365,24 @@ fn rejects_non_vector_extension_child_at_construction() { assert!(err.to_string().contains("Vector extension")); } +#[test] +fn accepts_normalized_vector_child_and_returns_plain_vector() -> VortexResult<()> { + let options = default_options(128, 42); + let mut values = vec![0.0f32; 128]; + values[0] = 1.0; + let elements = 
PrimitiveArray::from_iter(values).into_array(); + let fsl = FixedSizeListArray::try_new(elements, 128, Validity::NonNullable, 1)?; + let mut ctx = SESSION.create_execution_ctx(); + let child = NormalizedVector::try_new(fsl.into_array(), &mut ctx)?; + + let sorf = SorfTransform::try_new_array(&options, child, 1)?.into_array(); + assert!(sorf.dtype().as_extension().is::()); + + let result: ExtensionArray = sorf.execute(&mut ctx)?; + assert!(result.dtype().as_extension().is::()); + Ok(()) +} + #[test] fn rejects_wrong_padded_dimension_at_construction() { // Options say dimension=128 so padded_dim should be 128. Pass a Vector<256> instead. @@ -455,6 +477,32 @@ fn trivial_padded_vector(padded_dim: u32, num_rows: usize, validity: Validity) - ExtensionArray::new(ext_dtype, fsl.into_array()).into_array() } +fn trivial_padded_normalized_vector( + padded_dim: u32, + num_rows: usize, + validity: Validity, +) -> VortexResult { + let elements = PrimitiveArray::new( + Buffer::::zeroed(num_rows * padded_dim as usize), + Validity::NonNullable, + ); + let fsl = FixedSizeListArray::try_new(elements.into_array(), padded_dim, validity, num_rows)?; + let mut ctx = SESSION.create_execution_ctx(); + NormalizedVector::try_new(fsl.into_array(), &mut ctx) +} + +#[derive(Clone, prost::Message)] +struct LegacySorfTransformMetadata { + #[prost(uint64, tag = "1")] + seed: u64, + #[prost(uint32, tag = "2")] + num_rounds: u32, + #[prost(uint32, tag = "3")] + dimension: u32, + #[prost(enumeration = "PType", tag = "4")] + element_ptype: i32, +} + #[rstest::rstest] // Non-power-of-two dimension to exercise `padded_dim = dim.next_power_of_two()`. 
#[case::power_of_two_dim(128, Validity::NonNullable)] @@ -491,5 +539,99 @@ fn serde_round_trip(#[case] dimensions: u32, #[case] validity: Validity) -> Vort assert_eq!(recovered.dtype(), original.dtype()); assert_eq!(recovered.len(), original.len()); assert_eq!(recovered.encoding_id(), original.encoding_id()); + let recovered_scalar_fn = recovered.as_::(); + assert!( + recovered_scalar_fn + .child_at(0) + .dtype() + .as_extension() + .is::() + ); + Ok(()) +} + +#[test] +fn serde_round_trip_preserves_normalized_vector_child_dtype() -> VortexResult<()> { + let dimension = 128; + let num_rows = 4; + let options = default_options(dimension, 42); + let child = trivial_padded_normalized_vector( + dimension.next_power_of_two(), + num_rows, + Validity::NonNullable, + )?; + let original = SorfTransform::try_new_array(&options, child.clone(), num_rows)?.into_array(); + + let plugin = ScalarFnArrayPlugin::new(SorfTransform); + let metadata = plugin + .serialize(&original, &SESSION)? + .expect("SorfTransform serialize must produce metadata"); + + let children = vec![child]; + let recovered = plugin.deserialize( + original.dtype(), + original.len(), + &metadata, + &[], + &children, + &SESSION, + )?; + + assert_eq!(recovered.dtype(), original.dtype()); + assert_eq!(recovered.len(), original.len()); + assert_eq!(recovered.encoding_id(), original.encoding_id()); + let recovered_scalar_fn = recovered.as_::(); + assert!( + recovered_scalar_fn + .child_at(0) + .dtype() + .as_extension() + .is::() + ); + Ok(()) +} + +#[test] +fn serde_legacy_metadata_derives_plain_vector_child_dtype() -> VortexResult<()> { + let dimension = 128; + let num_rows = 4; + let options = default_options(dimension, 42); + let child = trivial_padded_vector( + dimension.next_power_of_two(), + num_rows, + Validity::NonNullable, + ); + let original = SorfTransform::try_new_array(&options, child.clone(), num_rows)?.into_array(); + + let legacy_metadata = LegacySorfTransformMetadata { + seed: options.seed, + 
num_rounds: u32::from(options.num_rounds), + dimension: options.dimensions, + element_ptype: options.element_ptype as i32, + } + .encode_to_vec(); + + let plugin = ScalarFnArrayPlugin::new(SorfTransform); + let children = vec![child]; + let recovered = plugin.deserialize( + original.dtype(), + original.len(), + &legacy_metadata, + &[], + &children, + &SESSION, + )?; + + assert_eq!(recovered.dtype(), original.dtype()); + assert_eq!(recovered.len(), original.len()); + assert_eq!(recovered.encoding_id(), original.encoding_id()); + let recovered_scalar_fn = recovered.as_::(); + assert!( + recovered_scalar_fn + .child_at(0) + .dtype() + .as_extension() + .is::() + ); Ok(()) } diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs index 827f8e6a796..ebe633ed7f6 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs @@ -50,7 +50,7 @@ use super::SorfOptions; use super::SorfTransform; use super::rotation::SorfMatrix; use super::validate_sorf_options; -use crate::types::vector::AnyVector; +use crate::matcher::AnyTensor; use crate::types::vector::Vector; impl ScalarFnVTable for SorfTransform { @@ -88,14 +88,17 @@ impl ScalarFnVTable for SorfTransform { let child_dtype = &arg_dtypes[0]; let vector_metadata = child_dtype .as_extension_opt() - .and_then(|ext| ext.metadata_opt::()) + .and_then(|ext| ext.metadata_opt::()) .ok_or_else(|| { - vortex_err!("SorfTransform child must be a Vector extension, got {child_dtype}") + vortex_err!( + "SorfTransform child must be a Vector or NormalizedVector extension, got \ + {child_dtype}" + ) })?; let expected_padded = options.dimensions.next_power_of_two(); vortex_ensure_eq!( - vector_metadata.dimensions(), + vector_metadata.list_size(), expected_padded, "SorfTransform child Vector must have dimension {expected_padded} (next power of two \ for dimension {})", @@ -120,6 +123,8 @@ impl ScalarFnVTable for 
SorfTransform { child_dtype.nullability(), ); + // The inverse SORF transform does not preserve unit norm on the output, even when the + // child is a [`NormalizedVector`]. Surface the output as a plain [`Vector`]. let _ = vector_metadata; let ext_dtype = ExtDType::::try_new(EmptyMetadata, storage_dtype)?.erased(); @@ -151,10 +156,12 @@ impl ScalarFnVTable for SorfTransform { }); } - // Execute the child to get the Vector extension wrapping an FSL of f32 coordinates. The - // `return_dtype` check guarantees the child is a `Vector`, so the - // materialized FSL elements are always f32. - let child_ext: ExtensionArray = args.get(0)?.execute(ctx)?; + // Execute the child to get either a `Vector` extension or a `NormalizedVector` + // wrapping a `Vector` over an FSL of f32 coordinates. The `return_dtype` check guarantees + // the shape is `Vector` at the FSL level, so drill past any + // `NormalizedVector` wrapper before unpacking. + let child_ref = crate::types::normalized_vector::inner_vector_array(&args.get(0)?, ctx)?; + let child_ext: ExtensionArray = child_ref.execute(ctx)?; let child_validity = child_ext.as_ref().validity()?; let child_fsl: FixedSizeListArray = child_ext.storage_array().clone().execute(ctx)?; let padded_dim = @@ -198,9 +205,10 @@ impl ScalarFnVTable for SorfTransform { /// Metadata for a serialized [`SorfTransform`] array. /// -/// Stores the full [`SorfOptions`] inline along with the child [`DType`]. Older metadata omitted -/// this field; deserialization derives the legacy plain-`Vector` child dtype from the parent dtype -/// in that case. +/// Stores the full [`SorfOptions`] inline along with the child [`DType`]. The child dtype records +/// whether the input was a plain [`Vector`] or [`NormalizedVector`](crate::normalized_vector::NormalizedVector). +/// Older metadata omitted this field; deserialization derives the legacy plain-`Vector` child dtype +/// from the parent dtype in that case. 
#[derive(Clone, prost::Message)] pub(super) struct SorfTransformMetadata { #[prost(uint64, tag = "1")] diff --git a/vortex-tensor/src/types/mod.rs b/vortex-tensor/src/types/mod.rs index 3ecd2826743..97aa932f9d6 100644 --- a/vortex-tensor/src/types/mod.rs +++ b/vortex-tensor/src/types/mod.rs @@ -4,4 +4,5 @@ //! Internal homes for tensor extension types. pub mod fixed_shape; +pub mod normalized_vector; pub mod vector; diff --git a/vortex-tensor/src/types/normalized_vector/matcher.rs b/vortex-tensor/src/types/normalized_vector/matcher.rs new file mode 100644 index 00000000000..483fa45c12b --- /dev/null +++ b/vortex-tensor/src/types/normalized_vector/matcher.rs @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::dtype::DType; +use vortex_array::dtype::extension::ExtDTypeRef; +use vortex_array::dtype::extension::Matcher; +use vortex_error::VortexExpect; +use vortex_error::vortex_panic; + +use crate::types::normalized_vector::NormalizedVector; +use crate::types::vector::Vector; +use crate::types::vector::VectorMatcherMetadata; + +/// Matcher that accepts only the [`NormalizedVector`] extension type. +/// +/// Use this when a consumer requires the unit-norm guarantee. Callers that accept any +/// vector-shaped extension should use [`AnyTensor`](crate::matcher::AnyTensor). +pub struct AnyNormalizedVector; + +impl Matcher for AnyNormalizedVector { + type Match<'a> = VectorMatcherMetadata; + + fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option> { + if !ext_dtype.is::() { + return None; + } + + // `NormalizedVector` stores a `Vector(FixedSizeList)`. Drill into the inner + // `Vector` to recover the dimension and element dtype. 
+ let DType::Extension(inner_ext) = ext_dtype.storage_dtype() else { + vortex_panic!( + "`NormalizedVector` storage must be `DType::Extension(Vector)`, \ + got {}", + ext_dtype.storage_dtype(), + ) + }; + if !inner_ext.is::() { + vortex_panic!( + "`NormalizedVector` inner extension must be `Vector`, got {}", + inner_ext.id(), + ) + } + let DType::FixedSizeList(element_dtype, list_size, _) = inner_ext.storage_dtype() else { + vortex_panic!( + "inner `Vector` storage must be `FixedSizeList`, got {}", + inner_ext.storage_dtype(), + ) + }; + assert!(element_dtype.is_float(), "element dtype must be float"); + assert!( + !element_dtype.is_nullable(), + "element dtype must be non-nullable" + ); + + let metadata = VectorMatcherMetadata::try_new(element_dtype.as_ptype(), *list_size) + .vortex_expect("`NormalizedVector` inner Vector did not have float elements"); + + Some(metadata) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use vortex_array::dtype::DType; + use vortex_array::dtype::Nullability; + use vortex_array::dtype::PType; + use vortex_array::dtype::extension::ExtDType; + use vortex_array::extension::EmptyMetadata; + use vortex_error::VortexResult; + + use super::*; + use crate::types::vector::AnyVector; + use crate::types::vector::Vector; + + fn fsl_storage(element_ptype: PType, dimensions: u32) -> DType { + DType::FixedSizeList( + Arc::new(DType::Primitive(element_ptype, Nullability::NonNullable)), + dimensions, + Nullability::NonNullable, + ) + } + + fn nv_storage(element_ptype: PType, dimensions: u32) -> VortexResult { + let vector = + ExtDType::::try_new(EmptyMetadata, fsl_storage(element_ptype, dimensions))? + .erased(); + Ok(DType::Extension(vector)) + } + + #[test] + fn matches_normalized_vector_dtype() -> VortexResult<()> { + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, nv_storage(PType::F32, 128)?)? 
+ .erased(); + + let metadata = ext_dtype.metadata::(); + assert_eq!(metadata.element_ptype(), PType::F32); + assert_eq!(metadata.dimensions(), 128); + Ok(()) + } + + #[test] + fn rejects_plain_vector() -> VortexResult<()> { + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, fsl_storage(PType::F32, 128))?.erased(); + + assert!(ext_dtype.metadata_opt::().is_none()); + Ok(()) + } + + #[test] + fn any_vector_does_not_match_normalized_vector() -> VortexResult<()> { + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, nv_storage(PType::F32, 128)?)? + .erased(); + + // `AnyVector` is strict: it only matches plain `Vector`. Use `AnyTensor` to accept + // both `Vector` and `NormalizedVector`. + assert!(ext_dtype.metadata_opt::().is_none()); + Ok(()) + } +} diff --git a/vortex-tensor/src/types/normalized_vector/mod.rs b/vortex-tensor/src/types/normalized_vector/mod.rs new file mode 100644 index 00000000000..605f6670c27 --- /dev/null +++ b/vortex-tensor/src/types/normalized_vector/mod.rs @@ -0,0 +1,192 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Normalized vector extension type over [`Vector`](crate::vector::Vector) storage whose +//! rows are guaranteed (or asserted, for lossy encodings) to have unit L2 norm. + +use num_traits::ToPrimitive; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::extension::ExtensionArrayExt; +use vortex_array::extension::EmptyMetadata; +use vortex_array::match_each_float_ptype; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +use crate::types::vector::AnyVector; +use crate::types::vector::Vector; +use crate::utils::extract_flat_elements; +use crate::utils::unit_norm_tolerance; + +/// Extension type over [`Vector`](crate::vector::Vector) storage that asserts every valid row is +/// L2-normalized (unit-norm) or the zero vector. 
+/// +/// The storage dtype is `DType::Extension(Vector(FixedSizeList))`, i.e. a +/// [`Vector`](crate::vector::Vector) extension array. Downstream operators such as +/// [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm), +/// [`L2Norm`](crate::scalar_fns::l2_norm::L2Norm), +/// [`InnerProduct`](crate::scalar_fns::inner_product::InnerProduct), and +/// [`CosineSimilarity`](crate::scalar_fns::cosine_similarity::CosineSimilarity) short-circuit +/// arithmetic when they see this type. +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +pub struct NormalizedVector; + +impl NormalizedVector { + /// Wraps a [`FixedSizeList`](vortex_array::arrays::FixedSizeListArray) of float elements + /// as a [`NormalizedVector`] extension array, wrapping the FSL in a + /// [`Vector`](crate::vector::Vector) first. + /// + /// Every valid row is checked to be unit-norm or the zero vector before returning. + /// + /// # Errors + /// + /// Returns an error if `fsl` is not a `FixedSizeList` of non-nullable float elements, or if + /// any valid row's L2 norm is not `1.0` (or `0.0`) within the tolerance implied by the + /// element precision. + pub fn try_new(fsl: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let vector = Vector::try_new_vector_array(fsl)?; + // Validate before wrapping so we iterate the inner `Vector` storage directly. The + // `ExtensionArray::try_new_from_vtable` call below runs `validate_dtype` (which only + // checks the storage dtype shape), but the unit-norm check is a bulk row operation we + // run explicitly here. + validate_unit_norm_rows(&vector, ctx)?; + Ok( + ExtensionArray::try_new_from_vtable(NormalizedVector, EmptyMetadata, vector)? + .into_array(), + ) + } + + /// Wraps `fsl` as a [`NormalizedVector`] extension array **without** validating that rows + /// are unit-norm. The FSL is still wrapped in a [`Vector`](crate::vector::Vector) first. + /// + /// # Safety + /// + /// Every valid row must be unit-norm or the zero vector. 
Lossy approximations (e.g. + /// TurboQuant) deliberately relax this, but still treat the claim as authoritative + /// downstream. Violating this does not cause memory unsafety but will produce silently + /// incorrect results. + /// + /// # Errors + /// + /// Returns an error if `fsl` is not a `FixedSizeList` of non-nullable float elements. + pub unsafe fn new_unchecked(fsl: ArrayRef) -> VortexResult { + let vector = Vector::try_new_vector_array(fsl)?; + Ok( + ExtensionArray::try_new_from_vtable(NormalizedVector, EmptyMetadata, vector)? + .into_array(), + ) + } + + /// Wraps an already-constructed [`Vector`](crate::vector::Vector) extension array as a + /// [`NormalizedVector`] **without** validating that rows are unit-norm. + /// + /// # Safety + /// + /// Every valid row of `vector` must be unit-norm or the zero vector. + /// + /// # Errors + /// + /// Returns an error if `vector.dtype()` is not a `Vector` extension dtype. + pub unsafe fn wrap_vector_unchecked(vector: ArrayRef) -> VortexResult { + Ok( + ExtensionArray::try_new_from_vtable(NormalizedVector, EmptyMetadata, vector)? + .into_array(), + ) + } +} + +/// Validates that every valid row of a [`Vector`](crate::vector::Vector) extension array has L2 +/// norm `1.0` or `0.0` within the element-precision tolerance. +/// +/// The input is expected to be a `Vector` extension array (not a raw `FixedSizeList`), matching +/// the storage of a `NormalizedVector`. 
+pub(crate) fn validate_unit_norm_rows( + vector_array: &ArrayRef, + ctx: &mut ExecutionCtx, +) -> VortexResult<()> { + let row_count = vector_array.len(); + if row_count == 0 { + return Ok(()); + } + + let vector_metadata = vector_array.dtype().as_extension().metadata::(); + let element_ptype = vector_metadata.element_ptype(); + let dim = vector_metadata.dimensions() as usize; + let tolerance = unit_norm_tolerance(element_ptype, dim); + + let ext: ExtensionArray = vector_array.clone().execute(ctx)?; + let validity = ext.as_ref().validity()?; + let flat = extract_flat_elements(ext.storage_array(), dim, ctx)?; + + match_each_float_ptype!(element_ptype, |T| { + for i in 0..row_count { + if !validity.is_valid(i)? { + continue; + } + + let row_norm_sq = flat.row::(i).iter().fold(0.0f64, |sum_sq, x| { + let value = ToPrimitive::to_f64(x).unwrap_or(f64::NAN); + sum_sq + value * value + }); + let row_norm = row_norm_sq.sqrt(); + + vortex_ensure!( + row_norm == 0.0 || (row_norm - 1.0).abs() <= tolerance, + "NormalizedVector row {i} has L2 norm {row_norm:.6}, expected 1.0 or 0.0", + ); + } + }); + + Ok(()) +} + +/// Returns the underlying `FixedSizeList` storage dtype for a vector-shaped extension dtype. +/// +/// For a plain [`Vector`], this is the direct storage dtype. For a [`NormalizedVector`] +/// it drills through one extra extension layer. +pub(crate) fn vector_fsl_storage_dtype( + ext: &vortex_array::dtype::extension::ExtDTypeRef, +) -> Option { + use vortex_array::dtype::DType; + if ext.is::() { + Some(ext.storage_dtype().clone()) + } else if ext.is::() { + let DType::Extension(inner) = ext.storage_dtype() else { + return None; + }; + if !inner.is::() { + return None; + } + Some(inner.storage_dtype().clone()) + } else { + None + } +} + +/// Returns the underlying `Vector` extension array inside a vector-shaped extension array. +/// +/// For a [`NormalizedVector`] array, this executes the outer extension and returns its +/// `Vector` storage child. 
For a plain [`Vector`] array, it returns the array itself unchanged
+(without canonicalizing to an `ExtensionArray`).
+pub(crate) fn inner_vector_array(
+    array: &ArrayRef,
+    ctx: &mut ExecutionCtx,
+) -> VortexResult {
+    let is_normalized = array
+        .dtype()
+        .as_extension_opt()
+        .is_some_and(|ext| ext.is::());
+    if is_normalized {
+        let ext: ExtensionArray = array.clone().execute(ctx)?;
+        Ok(ext.storage_array().clone())
+    } else {
+        Ok(array.clone())
+    }
+}
+
+mod matcher;
+mod vtable;
+
+pub use matcher::AnyNormalizedVector;
diff --git a/vortex-tensor/src/types/normalized_vector/vtable.rs b/vortex-tensor/src/types/normalized_vector/vtable.rs
new file mode 100644
index 00000000000..f50a0ac98fd
--- /dev/null
+++ b/vortex-tensor/src/types/normalized_vector/vtable.rs
@@ -0,0 +1,129 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+use vortex_array::dtype::DType;
+use vortex_array::dtype::extension::ExtDType;
+use vortex_array::dtype::extension::ExtId;
+use vortex_array::dtype::extension::ExtVTable;
+use vortex_array::extension::EmptyMetadata;
+use vortex_array::scalar::ScalarValue;
+use vortex_error::VortexResult;
+use vortex_error::vortex_bail;
+use vortex_error::vortex_ensure;
+
+use crate::types::normalized_vector::NormalizedVector;
+use crate::types::vector::Vector;
+
+impl ExtVTable for NormalizedVector {
+    type Metadata = EmptyMetadata;
+    type NativeValue<'a> = &'a ScalarValue;
+
+    fn id(&self) -> ExtId {
+        ExtId::new("vortex.tensor.normalized_vector")
+    }
+
+    fn validate_dtype(ext_dtype: &ExtDType) -> VortexResult<()> {
+        // Storage must be an extension-wrapped `Vector`. The inner `Vector` vtable's
+        // `validate_dtype` already ran when the inner `ExtDType` was constructed, so we
+        // only need to confirm the storage is in fact a `Vector` extension.
+ let DType::Extension(inner) = ext_dtype.storage_dtype() else { + vortex_bail!( + "`NormalizedVector` storage must be an extension type, got {}", + ext_dtype.storage_dtype(), + ); + }; + vortex_ensure!( + inner.is::(), + "`NormalizedVector` storage must be a `Vector` extension, got {}", + inner.id(), + ); + Ok(()) + } + + fn unpack_native<'a>( + _ext_dtype: &'a ExtDType, + storage_value: &'a ScalarValue, + ) -> VortexResult> { + // Per-scalar validation is a no-op: unit-norm is enforced in bulk by + // `validate_unit_norm_rows` at array construction, matching how `L2Denorm` + // validates up front rather than on each scalar access. + Ok(storage_value) + } + + fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult> { + Ok(Vec::new()) + } + + fn deserialize_metadata(&self, _metadata: &[u8]) -> VortexResult { + Ok(EmptyMetadata) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use rstest::rstest; + use vortex_array::dtype::DType; + use vortex_array::dtype::Nullability; + use vortex_array::dtype::PType; + use vortex_array::dtype::extension::ExtDType; + use vortex_array::dtype::extension::ExtVTable; + use vortex_array::extension::EmptyMetadata; + use vortex_error::VortexResult; + + use crate::types::normalized_vector::NormalizedVector; + use crate::types::vector::Vector; + + /// The NormalizedVector storage dtype is `DType::Extension(Vector(FSL))`. 
+ fn nv_storage_dtype(ptype: PType, size: u32, nullability: Nullability) -> VortexResult { + let fsl = DType::FixedSizeList( + Arc::new(DType::Primitive(ptype, Nullability::NonNullable)), + size, + nullability, + ); + let vector = ExtDType::::try_new(EmptyMetadata, fsl)?.erased(); + Ok(DType::Extension(vector)) + } + + #[rstest] + #[case::f16(PType::F16)] + #[case::f32(PType::F32)] + #[case::f64(PType::F64)] + fn validate_accepts_float_types(#[case] ptype: PType) -> VortexResult<()> { + let storage = nv_storage_dtype(ptype, 64, Nullability::NonNullable)?; + ExtDType::::try_new(EmptyMetadata, storage)?; + Ok(()) + } + + #[rstest] + #[case::nullable(Nullability::Nullable)] + #[case::non_nullable(Nullability::NonNullable)] + fn validate_accepts_any_outer_nullability( + #[case] nullability: Nullability, + ) -> VortexResult<()> { + let storage = nv_storage_dtype(PType::F32, 64, nullability)?; + ExtDType::::try_new(EmptyMetadata, storage)?; + Ok(()) + } + + #[test] + fn validate_rejects_non_extension_storage() { + let storage = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)), + 64, + Nullability::NonNullable, + ); + assert!(ExtDType::::try_new(EmptyMetadata, storage).is_err()); + } + + #[test] + fn roundtrip_metadata() -> VortexResult<()> { + let vtable = NormalizedVector; + let bytes = vtable.serialize_metadata(&EmptyMetadata)?; + assert!(bytes.is_empty()); + let deserialized = vtable.deserialize_metadata(&bytes)?; + assert_eq!(deserialized, EmptyMetadata); + Ok(()) + } +} diff --git a/vortex-tensor/src/types/vector/matcher.rs b/vortex-tensor/src/types/vector/matcher.rs index 7ac75f097db..9b61f769453 100644 --- a/vortex-tensor/src/types/vector/matcher.rs +++ b/vortex-tensor/src/types/vector/matcher.rs @@ -12,6 +12,11 @@ use vortex_error::vortex_panic; use crate::types::vector::Vector; +/// Matcher that accepts only the [`Vector`] extension type. 
+/// +/// Use [`AnyTensor`](crate::matcher::AnyTensor) instead when +/// [`NormalizedVector`](crate::normalized_vector::NormalizedVector) or `FixedShapeTensor` +/// should also match. pub struct AnyVector; /// Convenience metadata for vectors. diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index 9dc097e11e0..06944a1e02a 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -94,18 +94,43 @@ pub fn validate_tensor_float_input(input_dtype: &DType) -> VortexResult( lhs: &'a DType, rhs: &DType, ) -> VortexResult> { + let dtypes_match = lhs.eq_ignore_nullability(rhs) || vector_shapes_match(lhs, rhs); vortex_ensure!( - lhs.eq_ignore_nullability(rhs), + dtypes_match, "binary tensor expression expects inputs to have the same dtype, got {lhs} and {rhs}" ); validate_tensor_float_input(lhs) } +/// Returns `true` when `lhs` and `rhs` are both within the vector extension family (plain +/// `Vector` or `NormalizedVector`) and share the same float ptype and dimension. +fn vector_shapes_match(lhs: &DType, rhs: &DType) -> bool { + use crate::types::normalized_vector::AnyNormalizedVector; + use crate::types::vector::AnyVector; + + fn vector_family_match(dtype: &DType) -> Option { + let ext = dtype.as_extension_opt()?; + ext.metadata_opt::() + .or_else(|| ext.metadata_opt::()) + } + + matches!( + (vector_family_match(lhs), vector_family_match(rhs)), + (Some(l), Some(r)) + if l.element_ptype() == r.element_ptype() && l.dimensions() == r.dimensions() + ) +} + /// Cast a float [`PrimitiveArray`] to a `Buffer`. 
/// /// Several operations in this crate (SORF transform, TurboQuant quantization) work exclusively @@ -334,6 +359,7 @@ pub mod test_helpers { use crate::scalar_fns::l2_denorm::L2Denorm; use crate::types::fixed_shape::FixedShapeTensor; use crate::types::fixed_shape::FixedShapeTensorMetadata; + use crate::types::normalized_vector::NormalizedVector; use crate::types::vector::Vector; /// Builds a `FixedSizeList` storage array from flat `elements`. The row count is @@ -372,6 +398,16 @@ pub mod test_helpers { Vector::try_new_vector_array(flat_fsl(elements, dim)) } + /// Builds a [`NormalizedVector`] extension array from pre-normalized `elements` and a vector + /// dimension size. The caller must ensure each row is unit-norm or the zero vector. + pub fn normalized_vector_array( + dim: u32, + elements: &[T], + ctx: &mut ExecutionCtx, + ) -> VortexResult { + NormalizedVector::try_new(flat_fsl(elements, dim), ctx) + } + /// Builds a [`FixedShapeTensor`] extension array whose storage is a [`ConstantArray`], /// representing a single query tensor broadcast to `len` rows. pub fn constant_tensor_array>( @@ -399,17 +435,21 @@ pub mod test_helpers { ConstantArray::new(ext_scalar, len).into_array() } - /// Creates an [`L2Denorm`] scalar function array from pre-normalized tensor elements and + /// Creates an [`L2Denorm`] scalar function array from pre-normalized vector elements and /// matching norms. The caller must ensure every row of `normalized_elements` is unit-norm or /// zero. + /// + /// `dim` is the vector dimension (the inner `FixedSizeList` width). The number of rows is + /// inferred from `normalized_elements.len() / dim`. 
pub fn l2_denorm_array( - shape: &[usize], + dim: u32, normalized_elements: &[T], norms: &[T], ctx: &mut ExecutionCtx, ) -> VortexResult { let len = norms.len(); - let normalized = tensor_array(shape, normalized_elements)?; + let storage = flat_fsl(normalized_elements, dim); + let normalized = NormalizedVector::try_new(storage, ctx)?; let norms = PrimitiveArray::new(Buffer::copy_from(norms), Validity::NonNullable).into_array(); Ok(L2Denorm::try_new_array(normalized, norms, len, ctx)?.into_array()) diff --git a/vortex/benches/single_encoding_throughput.rs b/vortex/benches/single_encoding_throughput.rs index be253187956..10547ba011d 100644 --- a/vortex/benches/single_encoding_throughput.rs +++ b/vortex/benches/single_encoding_throughput.rs @@ -504,7 +504,7 @@ mod turboquant_benches { use vortex_array::VortexSessionExecute; use vortex_buffer::BufferMut; use vortex_tensor::encodings::turboquant::TurboQuantConfig; - use vortex_tensor::encodings::turboquant::turboquant_encode_unchecked; + use vortex_tensor::encodings::turboquant::turboquant_encode_normalized; use vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm; use vortex_tensor::vector::Vector; @@ -573,10 +573,9 @@ mod turboquant_benches { .as_ref() .as_opt::() .expect("normalized benchmark input should be an Extension array"); - // SAFETY: Benchmark inputs are normalized once up front so the timed - // region measures only TurboQuant encoding. - unsafe { turboquant_encode_unchecked(normalized, &config, ctx) } - .unwrap() + // Benchmark inputs are normalized once up front so the timed region + // measures only TurboQuant encoding. 
+ turboquant_encode_normalized(normalized, &config, ctx).unwrap() }); } } @@ -588,10 +587,9 @@ mod turboquant_benches { let normalized_ext = setup_normalized_vector_ext($dim); let config = turboquant_config($bits); let mut ctx = SESSION.create_execution_ctx(); - let compressed = unsafe { - turboquant_encode_unchecked(normalized_ext.as_view(), &config, &mut ctx) - } - .unwrap(); + let compressed = + turboquant_encode_normalized(normalized_ext.as_view(), &config, &mut ctx) + .unwrap(); with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) .with_inputs(|| (&compressed, SESSION.create_execution_ctx())) .bench_refs(|(a, ctx)| { From 766a8a0dfd5bd2d9b4095a1e33ebf3be99742e2b Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Tue, 28 Apr 2026 15:37:51 -0400 Subject: [PATCH 2/6] fix bugs Signed-off-by: Connor Tsui --- vortex-tensor/src/encodings/l2_denorm.rs | 9 +- .../encodings/turboquant/tests/structural.rs | 6 +- vortex-tensor/src/matcher.rs | 11 +- .../src/scalar_fns/cosine_similarity.rs | 36 ++- vortex-tensor/src/scalar_fns/inner_product.rs | 45 ++-- .../src/scalar_fns/sorf_transform/tests.rs | 83 ++----- .../src/scalar_fns/sorf_transform/vtable.rs | 222 +++++++++++------- .../src/types/normalized_vector/matcher.rs | 12 +- vortex-tensor/src/types/vector/matcher.rs | 66 +++++- vortex-tensor/src/utils.rs | 11 +- 10 files changed, 290 insertions(+), 211 deletions(-) diff --git a/vortex-tensor/src/encodings/l2_denorm.rs b/vortex-tensor/src/encodings/l2_denorm.rs index 68b8c1b31d8..64b876d4afa 100644 --- a/vortex-tensor/src/encodings/l2_denorm.rs +++ b/vortex-tensor/src/encodings/l2_denorm.rs @@ -14,6 +14,7 @@ use vortex_compressor::scheme::Scheme; use vortex_compressor::stats::ArrayAndStats; use vortex_error::VortexResult; +use crate::normalized_vector::AnyNormalizedVector; use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; use crate::types::vector::AnyVector; @@ -30,10 +31,10 @@ impl Scheme for L2DenormScheme { return false; }; - // `AnyVector` is the strict 
matcher for plain `Vector` only, so a `NormalizedVector` - // input is naturally excluded here (it would already carry an authoritative unit-norm - // representation and does not need re-normalization). - ext.ext_dtype().is::() + // `AnyVector` matches any vector-shaped extension; we explicitly exclude `NormalizedVector` + // here because a normalized input already carries an authoritative unit-norm representation + // and does not need re-normalization. + ext.ext_dtype().is::() && !ext.ext_dtype().is::() } fn expected_compression_ratio( diff --git a/vortex-tensor/src/encodings/turboquant/tests/structural.rs b/vortex-tensor/src/encodings/turboquant/tests/structural.rs index cf6cc2c3fb7..f65137d5d75 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/structural.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/structural.rs @@ -12,8 +12,10 @@ use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; use vortex_error::VortexResult; use super::*; +use crate::encodings::turboquant::centroids::compute_or_get_centroids; -/// Verify that the centroids stored in the DictArray match what `get_centroids()` computes. +/// Verify that the centroids stored in the DictArray match what `compute_or_get_centroids()` +/// computes. #[test] fn stored_centroids_match_computed() -> VortexResult<()> { let fsl = make_fsl(10, 128, 42); @@ -30,7 +32,7 @@ fn stored_centroids_match_computed() -> VortexResult<()> { let stored = centroids.as_slice::(); // padded_dim for dim=128 is 128. 
- let computed = crate::encodings::turboquant::centroids::compute_or_get_centroids(128, 3)?; + let computed = compute_or_get_centroids(128, 3)?; assert_eq!(stored.len(), computed.len()); for i in 0..stored.len() { diff --git a/vortex-tensor/src/matcher.rs b/vortex-tensor/src/matcher.rs index 65fff74b728..786f78036d9 100644 --- a/vortex-tensor/src/matcher.rs +++ b/vortex-tensor/src/matcher.rs @@ -70,14 +70,17 @@ impl Matcher for AnyTensor { return Some(TensorMatch::FixedShapeTensor(metadata)); } - if let Some(metadata) = ext_dtype.metadata_opt::() { - return Some(TensorMatch::Vector(metadata)); - } - + // Check `AnyNormalizedVector` first because `AnyVector` is inclusive: it would otherwise + // match `NormalizedVector` and we'd lose the normalized variant in the returned + // `TensorMatch`. if let Some(metadata) = ext_dtype.metadata_opt::() { return Some(TensorMatch::NormalizedVector(metadata)); } + if let Some(metadata) = ext_dtype.metadata_opt::() { + return Some(TensorMatch::Vector(metadata)); + } + None } } diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 88203fe37f2..ae5ed3748dc 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -140,26 +140,28 @@ impl ScalarFnVTable for CosineSimilarity { rhs_ref = sfn.into_array(); } + // The combined validity always comes from the original operands. Compute it once up + // front so the unit-form helpers below can take it directly without re-deriving from + // an `L2Denorm` wrapper they no longer hold. + let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; + // Classify each operand by its normal form. When both operands carry a known unit-norm // representation, cosine similarity collapses to the dot product of the unit vectors. 
let lhs_form = NormalForm::classify(&lhs_ref); let rhs_form = NormalForm::classify(&rhs_ref); match (lhs_form.normalized_array(), rhs_form.normalized_array()) { (Some(unit_lhs), Some(unit_rhs)) => { - return self.execute_both_unit(unit_lhs, unit_rhs, &lhs_ref, &rhs_ref, len); + return self.execute_both_unit(unit_lhs, unit_rhs, validity, len); } (Some(unit_lhs), None) => { - return self.execute_one_unit(unit_lhs, &rhs_ref, &lhs_ref, len, ctx); + return self.execute_one_unit(unit_lhs, &rhs_ref, validity, len, ctx); } (None, Some(unit_rhs)) => { - return self.execute_one_unit(unit_rhs, &lhs_ref, &rhs_ref, len, ctx); + return self.execute_one_unit(unit_rhs, &lhs_ref, validity, len, ctx); } (None, None) => {} } - // Compute combined validity. - let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; - // Compute inner product and norms as columnar operations, and propagate the options. let norm_lhs_arr = L2Norm::try_new_array(lhs_ref.clone(), len)?; let norm_rhs_arr = L2Norm::try_new_array(rhs_ref.clone(), len)?; @@ -248,12 +250,9 @@ impl CosineSimilarity { &self, unit_lhs: &ArrayRef, unit_rhs: &ArrayRef, - lhs_ref: &ArrayRef, - rhs_ref: &ArrayRef, + validity: Validity, len: usize, ) -> VortexResult { - let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; - let dot = InnerProduct::try_new_array(unit_lhs.clone(), unit_rhs.clone(), len)?.into_array(); @@ -266,23 +265,22 @@ impl CosineSimilarity { } /// Exactly one side carries a unit-norm representation: cosine similarity reduces to - /// `dot(unit, other) / ||other||`. The norms of the unit side are implicitly `1.0` (naked - /// `NormalizedVector`) or stored separately (the outer `L2Denorm` wrapper, which is not - /// needed here since cosine ignores magnitude). + /// `dot(unit, plain) / ||plain||`. 
The norms of the unit side are implicitly `1.0` (naked + /// `NormalizedVector`) or stored separately on the outer `L2Denorm` wrapper, which the + /// caller has already stripped — cosine ignores magnitude on the unit side, so the wrapper + /// is not needed here. fn execute_one_unit( &self, unit: &ArrayRef, - plain_ref: &ArrayRef, - unit_ref: &ArrayRef, + plain: &ArrayRef, + validity: Validity, len: usize, ctx: &mut ExecutionCtx, ) -> VortexResult { - let validity = unit_ref.validity()?.and(plain_ref.validity()?)?; - - let dot_arr = InnerProduct::try_new_array(unit.clone(), plain_ref.clone(), len)?; + let dot_arr = InnerProduct::try_new_array(unit.clone(), plain.clone(), len)?; let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?; - let norm_arr = L2Norm::try_new_array(plain_ref.clone(), len)?; + let norm_arr = L2Norm::try_new_array(plain.clone(), len)?; let plain_norm: PrimitiveArray = norm_arr.into_array().execute(ctx)?; // TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation. diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index 05e52666c05..9cdae9903d6 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -253,9 +253,9 @@ impl ScalarFnArrayVTable for InnerProduct { impl InnerProduct { /// Inner product over operands that may carry a unit-norm representation: - /// `inner_product = scale_l * scale_r * dot(unit_l, unit_r)`, where `scale = 1` for naked - /// `Normalized` operands, `scale = stored_norms` for `Denormalized` operands, and the - /// `unit_*` operands are the input itself for `Plain` operands. + /// `inner_product = scale_l * scale_r * dot(unit_l, unit_r)`, where the `(unit, scale)` pair + /// for each operand is `(operand, None)` for `Plain`, `(NV, None)` for naked `Normalized`, + /// and `(NV, Some(stored_norms))` for `Denormalized`. See [`decompose_for_unit_form`]. 
fn execute_unit_form( &self, lhs_form: &NormalForm<'_>, @@ -267,24 +267,13 @@ impl InnerProduct { ) -> VortexResult { let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; - // For each operand, take its unit-norm representation if it has one; fall back to the - // operand itself (the `Plain` case feeds the regular dot path with no scaling). - let unit_lhs = lhs_form - .normalized_array() - .cloned() - .unwrap_or_else(|| lhs_ref.clone()); - let unit_rhs = rhs_form - .normalized_array() - .cloned() - .unwrap_or_else(|| rhs_ref.clone()); + let (unit_lhs, lhs_scale) = decompose_for_unit_form(lhs_form, lhs_ref, ctx)?; + let (unit_rhs, rhs_scale) = decompose_for_unit_form(rhs_form, rhs_ref, ctx)?; let dot: PrimitiveArray = InnerProduct::try_new_array(unit_lhs, unit_rhs, len)? .into_array() .execute(ctx)?; - let lhs_scale = norms_for_scaling(lhs_form, ctx)?; - let rhs_scale = norms_for_scaling(rhs_form, ctx)?; - match_each_float_ptype!(dot.ptype(), |T| { let dots = dot.as_slice::(); let buffer: Buffer = match (lhs_scale.as_ref(), rhs_scale.as_ref()) { @@ -500,18 +489,26 @@ impl InnerProduct { } } -/// Materialize the per-row scaling factor for an operand classified by [`NormalForm`]. +/// Decompose an operand classified by [`NormalForm`] into the `(unit_operand, optional_scale)` +/// pair consumed by [`InnerProduct::execute_unit_form`]: /// -/// - `Plain`: no scaling needed (the operand itself enters the dot product). -/// - `Normalized`: implicit scaling of `1.0`, returned as `None` so the caller skips the multiply. -/// - `Denormalized`: returns the materialized stored norms. -fn norms_for_scaling( +/// - `Plain`: `(original, None)`. The unscaled operand IS its own "unit" for dot purposes; no +/// per-row multiply is needed. +/// - `Normalized`: `(NV, None)`. Implicit per-row scale of `1.0` — the unit child enters the dot +/// directly with no multiply. +/// - `Denormalized`: `(NV, Some(stored_norms))`. 
The dot is computed over the unit child and the +/// caller multiplies by the materialized stored norms afterward. +fn decompose_for_unit_form( form: &NormalForm<'_>, + original: &ArrayRef, ctx: &mut ExecutionCtx, -) -> VortexResult> { +) -> VortexResult<(ArrayRef, Option)> { match form { - NormalForm::Plain | NormalForm::Normalized { .. } => Ok(None), - NormalForm::Denormalized { norms, .. } => Ok(Some(norms.clone().execute(ctx)?)), + NormalForm::Plain => Ok((original.clone(), None)), + NormalForm::Normalized { array } => Ok(((*array).clone(), None)), + NormalForm::Denormalized { normalized, norms } => { + Ok((normalized.clone(), Some(norms.clone().execute(ctx)?))) + } } } diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs index 0934657dc7c..6f906d57d69 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs @@ -7,7 +7,6 @@ use std::sync::Arc; -use prost::Message; use vortex_array::ArrayPlugin; use vortex_array::ArrayRef; use vortex_array::IntoArray; @@ -366,7 +365,7 @@ fn rejects_non_vector_extension_child_at_construction() { } #[test] -fn accepts_normalized_vector_child_and_returns_plain_vector() -> VortexResult<()> { +fn accepts_normalized_vector_child_and_mirrors_kind() -> VortexResult<()> { let options = default_options(128, 42); let mut values = vec![0.0f32; 128]; values[0] = 1.0; @@ -375,11 +374,34 @@ fn accepts_normalized_vector_child_and_returns_plain_vector() -> VortexResult<() let mut ctx = SESSION.create_execution_ctx(); let child = NormalizedVector::try_new(fsl.into_array(), &mut ctx)?; + // The output mirrors the child's wrapper kind: a `NormalizedVector` child produces a + // `NormalizedVector` parent. 
The orthogonal inverse rotation preserves L2 norm and the + // truncated coordinates were near-zero pre-rotation, so the output is approximately + // unit-norm (lossy contract documented on `NormalizedVector::new_unchecked`). let sorf = SorfTransform::try_new_array(&options, child, 1)?.into_array(); + assert!(sorf.dtype().as_extension().is::()); + + let result: ExtensionArray = sorf.execute(&mut ctx)?; + assert!(result.dtype().as_extension().is::()); + Ok(()) +} + +/// A plain [`Vector`] child should still produce a plain [`Vector`] parent. +#[test] +fn accepts_plain_vector_child_and_mirrors_kind() -> VortexResult<()> { + let options = default_options(128, 42); + let elements = PrimitiveArray::from_iter([0.0f32; 128]).into_array(); + let fsl = FixedSizeListArray::try_new(elements, 128, Validity::NonNullable, 1)?; + let child = wrap_as_vector(fsl, Validity::NonNullable)?; + + let sorf = SorfTransform::try_new_array(&options, child.into_array(), 1)?.into_array(); assert!(sorf.dtype().as_extension().is::()); + assert!(!sorf.dtype().as_extension().is::()); + let mut ctx = SESSION.create_execution_ctx(); let result: ExtensionArray = sorf.execute(&mut ctx)?; assert!(result.dtype().as_extension().is::()); + assert!(!result.dtype().as_extension().is::()); Ok(()) } @@ -491,18 +513,6 @@ fn trivial_padded_normalized_vector( NormalizedVector::try_new(fsl.into_array(), &mut ctx) } -#[derive(Clone, prost::Message)] -struct LegacySorfTransformMetadata { - #[prost(uint64, tag = "1")] - seed: u64, - #[prost(uint32, tag = "2")] - num_rounds: u32, - #[prost(uint32, tag = "3")] - dimension: u32, - #[prost(enumeration = "PType", tag = "4")] - element_ptype: i32, -} - #[rstest::rstest] // Non-power-of-two dimension to exercise `padded_dim = dim.next_power_of_two()`. 
#[case::power_of_two_dim(128, Validity::NonNullable)] @@ -590,48 +600,3 @@ fn serde_round_trip_preserves_normalized_vector_child_dtype() -> VortexResult<() ); Ok(()) } - -#[test] -fn serde_legacy_metadata_derives_plain_vector_child_dtype() -> VortexResult<()> { - let dimension = 128; - let num_rows = 4; - let options = default_options(dimension, 42); - let child = trivial_padded_vector( - dimension.next_power_of_two(), - num_rows, - Validity::NonNullable, - ); - let original = SorfTransform::try_new_array(&options, child.clone(), num_rows)?.into_array(); - - let legacy_metadata = LegacySorfTransformMetadata { - seed: options.seed, - num_rounds: u32::from(options.num_rounds), - dimension: options.dimensions, - element_ptype: options.element_ptype as i32, - } - .encode_to_vec(); - - let plugin = ScalarFnArrayPlugin::new(SorfTransform); - let children = vec![child]; - let recovered = plugin.deserialize( - original.dtype(), - original.len(), - &legacy_metadata, - &[], - &children, - &SESSION, - )?; - - assert_eq!(recovered.dtype(), original.dtype()); - assert_eq!(recovered.len(), original.len()); - assert_eq!(recovered.encoding_id(), original.encoding_id()); - let recovered_scalar_fn = recovered.as_::(); - assert!( - recovered_scalar_fn - .child_at(0) - .dtype() - .as_extension() - .is::() - ); - Ok(()) -} diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs index ebe633ed7f6..c404f38e27c 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs @@ -16,10 +16,8 @@ use vortex_array::IntoArray; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::ScalarFn as ScalarFnArrayEncoding; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use 
vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::arrays::scalar_fn::ScalarFnArrayView; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; @@ -28,7 +26,6 @@ use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; use vortex_array::dtype::extension::ExtDType; -use vortex_array::dtype::proto::dtype as pb; use vortex_array::expr::Expression; use vortex_array::extension::EmptyMetadata; use vortex_array::match_each_float_ptype; @@ -51,6 +48,9 @@ use super::SorfTransform; use super::rotation::SorfMatrix; use super::validate_sorf_options; use crate::matcher::AnyTensor; +use crate::types::normalized_vector::NormalizedVector; +use crate::types::normalized_vector::inner_vector_array; +use crate::types::vector::AnyVector; use crate::types::vector::Vector; impl ScalarFnVTable for SorfTransform { @@ -117,18 +117,29 @@ impl ScalarFnVTable for SorfTransform { ); let output_elem_dtype = DType::Primitive(options.element_ptype, Nullability::NonNullable); - let storage_dtype = DType::FixedSizeList( + let fsl_dtype = DType::FixedSizeList( Arc::new(output_elem_dtype), options.dimensions, child_dtype.nullability(), ); - // The inverse SORF transform does not preserve unit norm on the output, even when the - // child is a [`NormalizedVector`]. Surface the output as a plain [`Vector`]. - let _ = vector_metadata; - let ext_dtype = ExtDType::::try_new(EmptyMetadata, storage_dtype)?.erased(); - - Ok(DType::Extension(ext_dtype)) + // The output mirrors the child's wrapper kind: if the child was a `NormalizedVector` the + // output is also surfaced as a `NormalizedVector` (the orthogonal inverse rotation + // preserves L2 norm and the truncation drops coordinates that were zero pre-rotation, so + // the output is approximately unit-norm under the same lossy contract that + // `NormalizedVector::new_unchecked` documents). 
+ if vector_metadata.is_normalized() { + let inner_vector = ExtDType::::try_new(EmptyMetadata, fsl_dtype)?.erased(); + let outer = ExtDType::::try_new( + EmptyMetadata, + DType::Extension(inner_vector), + )? + .erased(); + Ok(DType::Extension(outer)) + } else { + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl_dtype)?.erased(); + Ok(DType::Extension(ext_dtype)) + } } fn execute( @@ -140,50 +151,59 @@ impl ScalarFnVTable for SorfTransform { let dim = options.dimensions as usize; let num_rows = args.row_count(); - if num_rows == 0 { - let child_dtype = args.get(0)?.dtype().clone(); - let validity = Validity::from(child_dtype.nullability()); + let child_arg = args.get(0)?; + let is_normalized_child = child_arg + .dtype() + .as_extension_opt() + .is_some_and(|ext| ext.is::()); + + let fsl_array: ArrayRef = if num_rows == 0 { + let validity = Validity::from(child_arg.dtype().nullability()); - return match_each_float_ptype!(options.element_ptype, |T| { + match_each_float_ptype!(options.element_ptype, |T| { let elements = PrimitiveArray::empty::(Nullability::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - options.dimensions, - validity, - 0, - )?; - Vector::try_new_vector_array(fsl.into_array()) - }); - } + FixedSizeListArray::try_new(elements.into_array(), options.dimensions, validity, 0) + })? + .into_array() + } else { + // Execute the child to get either a `Vector` extension or a `NormalizedVector` + // wrapping a `Vector` over an FSL of f32 coordinates. The `return_dtype` check + // guarantees the shape is `Vector` at the FSL level, so drill past + // any `NormalizedVector` wrapper before unpacking. 
+ let child_ref = inner_vector_array(&child_arg, ctx)?; + let child_ext: ExtensionArray = child_ref.execute(ctx)?; + let child_validity = child_ext.as_ref().validity()?; + let child_fsl: FixedSizeListArray = child_ext.storage_array().clone().execute(ctx)?; + let padded_dim = + usize::try_from(child_fsl.list_size()).vortex_expect("list_size fits usize"); + + let elements_prim: PrimitiveArray = child_fsl.elements().clone().execute(ctx)?; + let f32_elements = elements_prim.into_buffer::(); + + // Reconstruct the orthogonal transform matrix from the seed. + let rotation = SorfMatrix::try_new(options.seed, dim, options.num_rounds as usize)?; + + // Inverse transform each row, truncate to original dimension, cast to target type. + match_each_float_ptype!(options.element_ptype, |T| { + inverse_rotate_typed::( + &f32_elements, + &rotation, + dim, + padded_dim, + num_rows, + child_validity, + ) + })? + }; - // Execute the child to get either a `Vector` extension or a `NormalizedVector` - // wrapping a `Vector` over an FSL of f32 coordinates. The `return_dtype` check guarantees - // the shape is `Vector` at the FSL level, so drill past any - // `NormalizedVector` wrapper before unpacking. - let child_ref = crate::types::normalized_vector::inner_vector_array(&args.get(0)?, ctx)?; - let child_ext: ExtensionArray = child_ref.execute(ctx)?; - let child_validity = child_ext.as_ref().validity()?; - let child_fsl: FixedSizeListArray = child_ext.storage_array().clone().execute(ctx)?; - let padded_dim = - usize::try_from(child_fsl.list_size()).vortex_expect("list_size fits usize"); - - let elements_prim: PrimitiveArray = child_fsl.elements().clone().execute(ctx)?; - let f32_elements = elements_prim.into_buffer::(); - - // Reconstruct the orthogonal transform matrix from the seed. - let rotation = SorfMatrix::try_new(options.seed, dim, options.num_rounds as usize)?; - - // Inverse transform each row, truncate to original dimension, cast to target type. 
- match_each_float_ptype!(options.element_ptype, |T| { - inverse_rotate_typed::( - &f32_elements, - &rotation, - dim, - padded_dim, - num_rows, - child_validity, - ) - }) + // SAFETY: When `is_normalized_child` is `true`, the input child was a + // `NormalizedVector` (its dtype was checked above), so every valid row was unit-norm or + // zero by type. Inverse SORF is orthogonal (norm-preserving), and the truncated tail + // coordinates were zero pre-rotation up to quantization noise — so each row of + // `fsl_array` is approximately unit-norm under the same lossy contract that + // [`NormalizedVector::new_unchecked`] documents. When `is_normalized_child` is `false` + // [`wrap_output`] takes the trivially-safe `Vector` branch. + unsafe { wrap_output(fsl_array, is_normalized_child) } } fn validity( @@ -205,10 +225,10 @@ impl ScalarFnVTable for SorfTransform { /// Metadata for a serialized [`SorfTransform`] array. /// -/// Stores the full [`SorfOptions`] inline along with the child [`DType`]. The child dtype records -/// whether the input was a plain [`Vector`] or [`NormalizedVector`](crate::normalized_vector::NormalizedVector). -/// Older metadata omitted this field; deserialization derives the legacy plain-`Vector` child dtype -/// from the parent dtype in that case. +/// Stores the full [`SorfOptions`] inline. The child dtype is fully derivable from the parent +/// dtype: the parent's outer wrapper (plain `Vector` or `NormalizedVector`) mirrors the child's +/// wrapper kind, the inner FSL nullability is propagated through `return_dtype`, and +/// `padded_dim`/`f32` are determined by [`SorfOptions`]. 
#[derive(Clone, prost::Message)] pub(super) struct SorfTransformMetadata { #[prost(uint64, tag = "1")] @@ -220,8 +240,6 @@ pub(super) struct SorfTransformMetadata { dimension: u32, #[prost(enumeration = "PType", tag = "4")] element_ptype: i32, - #[prost(message, optional, tag = "5")] - child_dtype: Option, } impl ScalarFnArrayVTable for SorfTransform { @@ -230,12 +248,7 @@ impl ScalarFnArrayVTable for SorfTransform { view: &ScalarFnArrayView, _session: &VortexSession, ) -> VortexResult>> { - let scalar_fn_array = view.as_::(); - let child_dtype = Some(scalar_fn_array.child_at(0).dtype().try_into()?); - let metadata = SorfTransformMetadata { - child_dtype, - ..SorfTransformMetadata::from(view.options) - }; + let metadata = SorfTransformMetadata::from(view.options); Ok(Some(metadata.encode_to_vec())) } @@ -247,31 +260,59 @@ impl ScalarFnArrayVTable for SorfTransform { children: &dyn ArrayChildren, session: &VortexSession, ) -> VortexResult> { + let _ = session; let metadata = SorfTransformMetadata::decode(metadata) .map_err(|e| vortex_err!("Failed to decode SorfTransformMetadata: {e}"))?; let options = metadata.to_options()?; - // `return_dtype` sets the output FSL's nullability to the child's nullability (see - // `return_dtype` above), so we read the child nullability back from the parent dtype. - let child_nullability = dtype + // The parent dtype must be a vector-shaped extension produced by `return_dtype`: either + // a plain `Vector` (when the child was a plain `Vector`) or a `NormalizedVector` (when + // the child was a `NormalizedVector`). `AnyVector` matches both, and its `try_match` + // panics on a structurally malformed `NormalizedVector`, so a successful match also + // guarantees the inner drill below is well-formed. + let parent_ext = dtype .as_extension_opt() + .filter(|ext| ext.is::()) .ok_or_else(|| { - vortex_err!("SorfTransform parent dtype must be a Vector extension, got {dtype}") - })? 
- .storage_dtype() - .nullability(); + vortex_err!( + "SorfTransform parent dtype must be a `Vector` or `NormalizedVector` \ + extension, got {dtype}", + ) + })?; + let is_normalized = parent_ext.is::(); + + // The child's FSL nullability matches the parent's inner FSL nullability (set by + // `return_dtype` from the original child's outer nullability). Drill into the parent + // wrapper to recover it; `AnyVector` already validated the structural shape. + let parent_fsl_dtype = if is_normalized { + let DType::Extension(inner) = parent_ext.storage_dtype() else { + unreachable!( + "`AnyVector` matcher guarantees a `NormalizedVector` parent wraps a \ + `Vector` extension" + ) + }; + inner.storage_dtype() + } else { + parent_ext.storage_dtype() + }; + let fsl_nullability = parent_fsl_dtype.nullability(); + let padded_dim = options.dimensions.next_power_of_two(); - let child_storage = DType::FixedSizeList( + let child_fsl = DType::FixedSizeList( Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)), padded_dim, - child_nullability, + fsl_nullability, ); - let child_dtype = match metadata.child_dtype.as_ref() { - Some(dtype) => DType::from_proto(dtype, session)?, - None => { - let child_ext = ExtDType::::try_new(EmptyMetadata, child_storage)?.erased(); - DType::Extension(child_ext) - } + let inner_vector = ExtDType::::try_new(EmptyMetadata, child_fsl)?.erased(); + let child_dtype = if is_normalized { + let nv = ExtDType::::try_new( + EmptyMetadata, + DType::Extension(inner_vector), + )? + .erased(); + DType::Extension(nv) + } else { + DType::Extension(inner_vector) }; let child = children.get(0, &child_dtype, len)?; @@ -291,7 +332,9 @@ fn float_from_f32(v: f32) -> T { } /// Apply the inverse SORF transform on f32 data, truncate to the original dimension, cast each -/// element to `T`, and build a plain [`Vector`](crate::vector::Vector) extension array. +/// element to `T`, and return the resulting `FixedSizeList` storage array. 
The caller is +/// responsible for wrapping the FSL in the appropriate vector-family extension via +/// [`wrap_output`]. fn inverse_rotate_typed( f32_elements: &[f32], rotation: &SorfMatrix, @@ -317,7 +360,25 @@ fn inverse_rotate_typed( let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); let fsl = FixedSizeListArray::try_new(elements.into_array(), dim_u32, validity, num_rows)?; - Vector::try_new_vector_array(fsl.into_array()) + Ok(fsl.into_array()) +} + +/// Wraps `fsl` as either a [`Vector`] or [`NormalizedVector`] extension array, mirroring the kind +/// of the upstream `SorfTransform` child. +/// +/// # Safety +/// +/// When `is_normalized` is `true`, every valid row of `fsl` must be approximately unit-norm or +/// zero in the lossy sense documented by [`NormalizedVector::new_unchecked`]. +/// +/// When `is_normalized` is `false` the function takes the safe `Vector` branch. +unsafe fn wrap_output(fsl: ArrayRef, is_normalized: bool) -> VortexResult { + if is_normalized { + // SAFETY: Forwarded from the function-level safety contract above. + unsafe { NormalizedVector::new_unchecked(fsl) } + } else { + Vector::try_new_vector_array(fsl) + } } impl From<&SorfOptions> for SorfTransformMetadata { @@ -327,7 +388,6 @@ impl From<&SorfOptions> for SorfTransformMetadata { num_rounds: u32::from(options.num_rounds), dimension: options.dimensions, element_ptype: options.element_ptype as i32, - child_dtype: None, } } } diff --git a/vortex-tensor/src/types/normalized_vector/matcher.rs b/vortex-tensor/src/types/normalized_vector/matcher.rs index 483fa45c12b..ec5dc2b729c 100644 --- a/vortex-tensor/src/types/normalized_vector/matcher.rs +++ b/vortex-tensor/src/types/normalized_vector/matcher.rs @@ -111,14 +111,18 @@ mod tests { } #[test] - fn any_vector_does_not_match_normalized_vector() -> VortexResult<()> { + fn any_vector_matches_normalized_vector() -> VortexResult<()> { let ext_dtype = ExtDType::::try_new(EmptyMetadata, nv_storage(PType::F32, 128)?)? 
.erased(); - // `AnyVector` is strict: it only matches plain `Vector`. Use `AnyTensor` to accept - // both `Vector` and `NormalizedVector`. - assert!(ext_dtype.metadata_opt::().is_none()); + // `AnyVector` is the inclusive matcher: it matches both `Vector` and `NormalizedVector`. + // Callers that need to distinguish the two should pair it with an + // [`AnyNormalizedVector`] check, or use [`AnyTensor`](crate::matcher::AnyTensor) to also + // accept `FixedShapeTensor`. + let metadata = ext_dtype.metadata::(); + assert_eq!(metadata.element_ptype(), PType::F32); + assert_eq!(metadata.dimensions(), 128); Ok(()) } } diff --git a/vortex-tensor/src/types/vector/matcher.rs b/vortex-tensor/src/types/vector/matcher.rs index 9b61f769453..a646e8009fe 100644 --- a/vortex-tensor/src/types/vector/matcher.rs +++ b/vortex-tensor/src/types/vector/matcher.rs @@ -10,13 +10,16 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_error::vortex_panic; +use crate::types::normalized_vector::NormalizedVector; use crate::types::vector::Vector; -/// Matcher that accepts only the [`Vector`] extension type. +/// Matcher that accepts any vector-shaped extension type — both plain +/// [`Vector`] and [`NormalizedVector`](crate::normalized_vector::NormalizedVector). /// -/// Use [`AnyTensor`](crate::matcher::AnyTensor) instead when -/// [`NormalizedVector`](crate::normalized_vector::NormalizedVector) or `FixedShapeTensor` -/// should also match. +/// To match a plain [`Vector`] only (excluding [`NormalizedVector`]), pair this matcher with a +/// negated `is::()` check; to match a `NormalizedVector` only, use +/// [`AnyNormalizedVector`](crate::normalized_vector::AnyNormalizedVector) directly. Use +/// [`AnyTensor`](crate::matcher::AnyTensor) when `FixedShapeTensor` should also match. pub struct AnyVector; /// Convenience metadata for vectors. 
@@ -44,11 +47,32 @@ impl Matcher for AnyVector { type Match<'a> = VectorMatcherMetadata; fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option> { - if !ext_dtype.is::() { + // Walk to the inner `FixedSizeList` for whichever vector-shaped wrapper this is. Plain + // `Vector` stores the FSL directly; `NormalizedVector` wraps a `Vector` extension which + // in turn stores the FSL. + let fsl_dtype = if ext_dtype.is::() { + let DType::Extension(inner) = ext_dtype.storage_dtype() else { + vortex_panic!( + "`NormalizedVector` storage must be `DType::Extension(Vector)`, got {}", + ext_dtype.storage_dtype(), + ) + }; + + if !inner.is::() { + vortex_panic!( + "`NormalizedVector` inner extension must be `Vector`, got {}", + inner.id(), + ) + } + + inner.storage_dtype() + } else if ext_dtype.is::() { + ext_dtype.storage_dtype() + } else { return None; - } + }; - let DType::FixedSizeList(element_dtype, list_size, _) = ext_dtype.storage_dtype() else { + let DType::FixedSizeList(element_dtype, list_size, _) = fsl_dtype else { vortex_panic!("`Vector` type somehow did not have a `FixedSizeList` storage type") }; @@ -117,6 +141,18 @@ mod tests { ) } + fn normalized_vector_storage_dtype( + element_ptype: PType, + dimensions: u32, + ) -> VortexResult { + let inner = ExtDType::::try_new( + EmptyMetadata, + vector_storage_dtype(element_ptype, dimensions), + )? + .erased(); + Ok(DType::Extension(inner)) + } + #[test] fn matches_vector_dtype_metadata() -> VortexResult<()> { let ext_dtype = @@ -129,6 +165,22 @@ mod tests { Ok(()) } + #[test] + fn matches_normalized_vector_dtype_metadata() -> VortexResult<()> { + let ext_dtype = ExtDType::::try_new( + EmptyMetadata, + normalized_vector_storage_dtype(PType::F32, 256)?, + )? + .erased(); + + // `AnyVector` is the inclusive matcher: it matches `NormalizedVector` too and surfaces + // the inner `Vector`'s element ptype and dimensionality. 
+ let metadata = ext_dtype.metadata::(); + assert_eq!(metadata.element_ptype(), PType::F32); + assert_eq!(metadata.dimensions(), 256); + Ok(()) + } + #[test] fn does_not_match_fixed_shape_tensor() -> VortexResult<()> { let ext_dtype = ExtDType::::try_new( diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index 06944a1e02a..a50e8fa168f 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -32,6 +32,8 @@ use vortex_session::VortexSession; use crate::matcher::AnyTensor; use crate::matcher::TensorMatch; use crate::scalar_fns::l2_denorm::L2Denorm; +use crate::types::vector::VectorMatcherMetadata; +use crate::vector::AnyVector; /// Safety factor for unit-norm tolerance. Applied as a constant multiplier on the probabilistic /// `√d · ε` bound so that legitimate round-off noise clears the check with headroom. @@ -115,13 +117,8 @@ pub fn validate_binary_tensor_float_inputs<'a>( /// Returns `true` when `lhs` and `rhs` are both within the vector extension family (plain /// `Vector` or `NormalizedVector`) and share the same float ptype and dimension. 
fn vector_shapes_match(lhs: &DType, rhs: &DType) -> bool { - use crate::types::normalized_vector::AnyNormalizedVector; - use crate::types::vector::AnyVector; - - fn vector_family_match(dtype: &DType) -> Option { - let ext = dtype.as_extension_opt()?; - ext.metadata_opt::() - .or_else(|| ext.metadata_opt::()) + fn vector_family_match(dtype: &DType) -> Option { + dtype.as_extension_opt()?.metadata_opt::() } matches!( From 0d3cb51f80ba8ad26dc97079ac83e31b8c4eed58 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Tue, 28 Apr 2026 16:56:49 -0400 Subject: [PATCH 3/6] clean up Signed-off-by: Connor Tsui --- vortex-tensor/public-api.lock | 8 +- .../src/encodings/turboquant/compress.rs | 87 ++++++------- vortex-tensor/src/encodings/turboquant/mod.rs | 29 ----- .../src/encodings/turboquant/scheme.rs | 28 ++++- .../src/scalar_fns/cosine_similarity.rs | 22 ++-- vortex-tensor/src/scalar_fns/inner_product.rs | 71 +++++++---- vortex-tensor/src/scalar_fns/l2_denorm.rs | 5 +- .../src/scalar_fns/sorf_transform/mod.rs | 8 +- .../src/scalar_fns/sorf_transform/rotation.rs | 8 +- .../src/scalar_fns/sorf_transform/vtable.rs | 115 +++++++----------- .../src/types/normalized_vector/matcher.rs | 2 +- vortex-tensor/src/types/vector/matcher.rs | 27 +++- 12 files changed, 209 insertions(+), 201 deletions(-) diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 5a30355caac..d07b61afa73 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -78,8 +78,6 @@ pub const vortex_tensor::encodings::turboquant::MAX_CENTROIDS: usize pub const vortex_tensor::encodings::turboquant::MIN_DIMENSION: u32 -pub fn vortex_tensor::encodings::turboquant::tq_validate_vector_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult - pub fn vortex_tensor::encodings::turboquant::turboquant_encode(input: vortex_array::array::erased::ArrayRef, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) 
-> vortex_error::VortexResult pub fn vortex_tensor::encodings::turboquant::turboquant_encode_normalized(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult @@ -552,7 +550,7 @@ pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::clone(&self) -> impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::sorf_transform::SorfTransform -pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, session: &vortex_session::VortexSession) -> vortex_error::VortexResult> +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult> pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult>> @@ -640,7 +638,9 @@ pub fn vortex_tensor::vector::VectorMatcherMetadata::dimensions(&self) -> u32 pub fn vortex_tensor::vector::VectorMatcherMetadata::element_ptype(&self) -> vortex_array::dtype::ptype::PType -pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32) -> vortex_error::VortexResult +pub fn vortex_tensor::vector::VectorMatcherMetadata::is_normalized(self) -> bool + +pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32, is_normalized: bool) -> vortex_error::VortexResult impl 
core::clone::Clone for vortex_tensor::vector::VectorMatcherMetadata diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index 0a1b4dded13..610d7964752 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -3,10 +3,12 @@ //! TurboQuant encoding (quantization) logic. //! -//! The input to [`turboquant_encode`] must be a non-nullable [`Vector`](crate::vector::Vector) -//! extension array whose rows are already L2-normalized (unit norm). Normalization is handled -//! externally by [`normalize_as_l2_denorm`](crate::scalar_fns::l2_denorm::normalize_as_l2_denorm), -//! which the [`TurboQuantScheme`] calls before invoking this function. +//! The input to [`turboquant_encode`] must be a non-nullable [`Vector`] extension array whose rows +//! are already L2-normalized (unit norm). Normalization is handled externally by +//! [`normalize_as_l2_denorm`], which the [`TurboQuantScheme`] calls before invoking this function. +//! +//! If you already have a [`NormalizedVector`] array, then use the [`turboquant_encode_normalized`] +//! function instead. //! //! 
[`TurboQuantScheme`]: crate::encodings::turboquant::TurboQuantScheme @@ -23,6 +25,7 @@ use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::dtype::Nullability; +use vortex_array::match_each_float_ptype; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_buffer::BufferMut; @@ -42,6 +45,8 @@ use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfOptions; use crate::scalar_fns::sorf_transform::SorfTransform; use crate::types::normalized_vector::NormalizedVector; +#[expect(unused, reason = "docs")] +use crate::types::vector::Vector; use crate::utils::cast_to_f32; /// Configuration for TurboQuant encoding. @@ -55,6 +60,7 @@ pub struct TurboQuantConfig { pub num_rounds: u8, } +// TODO(connor): We should be able to modify this more easily from the `TurboQuantScheme`! impl Default for TurboQuantConfig { fn default() -> Self { Self { @@ -65,10 +71,10 @@ impl Default for TurboQuantConfig { } } -/// Apply the full TurboQuant compression pipeline to a [`Vector`](crate::vector::Vector) -/// extension array: normalize the rows via [`normalize_as_l2_denorm`], quantize the normalized -/// child via [`turboquant_encode_normalized`], and reattach the stored norms as the outer -/// [`L2Denorm`] wrapper. +/// Apply the full TurboQuant compression pipeline to a [`Vector`] extension array: normalize the +/// rows via [`normalize_as_l2_denorm`], quantize the normalized child via +/// [`turboquant_encode_normalized`], and reattach the stored norms as the outer [`L2Denorm`] +/// wrapper. /// /// The returned array has the canonical TurboQuant shape: /// @@ -91,8 +97,7 @@ pub fn turboquant_encode( // We must normalize the array before we can encode it with TurboQuant. let l2_denorm = normalize_as_l2_denorm(input, ctx)?; - // This is guaranteed to be a `NormalizedVector` extension type. 
- let normalized = l2_denorm.child_at(0).clone(); + let normalized = l2_denorm.child_at(0).clone(); // Guaranteed to be a `NormalizedVector`.. let norms = l2_denorm.child_at(1).clone(); let num_rows = l2_denorm.len(); @@ -107,13 +112,9 @@ pub fn turboquant_encode( Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array()) } -/// Encode a non-nullable [`NormalizedVector`](crate::normalized_vector::NormalizedVector) -/// extension array into -/// a `ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))])`, without validating the -/// unit-norm precondition. -/// -/// Passing non-unit-norm vectors will not cause memory unsafety, but will produce silently -/// incorrect quantization results. +/// Encode a non-nullable [`NormalizedVector`] extension array into a +/// `ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))])`, without validating the unit-norm +/// precondition. pub fn turboquant_encode_normalized( ext: ArrayView, config: &TurboQuantConfig, @@ -125,7 +126,7 @@ pub fn turboquant_encode_normalized( let element_ptype = vector_metadata.element_ptype(); let dimensions = vector_metadata.dimensions(); - // `NormalizedVector` storage is `Extension(Vector(FSL))`; drill past the inner `Vector` to + // `NormalizedVector` storage is `Extension(Vector(FSL))`, so drill past the inner `Vector` to // reach the underlying `FixedSizeList`. let inner_vector: ExtensionArray = ext.storage_array().clone().execute(ctx)?; let fsl: FixedSizeListArray = inner_vector.storage_array().clone().execute(ctx)?; @@ -141,6 +142,28 @@ pub fn turboquant_encode_normalized( ); let num_rows = fsl.len(); + + // No data to quantize: short-circuit by returning an empty `NormalizedVector` directly at + // the final output shape `(dimensions, element_ptype)`. 
The non-empty path only goes + // through `SorfTransform` because the inverse rotation reshapes + // `(padded_dim, f32) → (dimensions, element_ptype)`; with zero rows there is no rotation + // to apply and we can construct an FSL with the destination dtype straight away. + if num_rows == 0 { + return match_each_float_ptype!(element_ptype, |T| { + let elements = PrimitiveArray::empty::(Nullability::NonNullable); + let empty_fsl = FixedSizeListArray::try_new( + elements.into_array(), + dimensions, + Validity::NonNullable, + 0, + )?; + + // SAFETY: An empty FSL contains no rows, so the unit-norm-or-zero invariant holds + // vacuously. + unsafe { NormalizedVector::new_unchecked(empty_fsl.into_array()) } + }); + } + let sorf_options = SorfOptions { seed: config.seed, num_rounds: config.num_rounds, @@ -148,28 +171,6 @@ pub fn turboquant_encode_normalized( element_ptype, }; - if fsl.is_empty() { - let padded_dim = dimensions.next_power_of_two(); - let empty_codes = PrimitiveArray::empty::(Nullability::NonNullable); - let empty_centroids = PrimitiveArray::empty::(Nullability::NonNullable); - let empty_dict = - DictArray::try_new(empty_codes.into_array(), empty_centroids.into_array())?; - let empty_fsl = FixedSizeListArray::try_new( - empty_dict.into_array(), - padded_dim, - Validity::NonNullable, - 0, - )?; - // SAFETY: An empty FSL contains no rows, so the unit-norm-or-zero invariant holds - // vacuously. 
- let empty_padded_vector = - unsafe { NormalizedVector::new_unchecked(empty_fsl.into_array()) }?; - - return Ok( - SorfTransform::try_new_array(&sorf_options, empty_padded_vector, 0)?.into_array(), - ); - } - let quantized_fsl = turboquant_quantize_fsl(&fsl, config.bit_width, &sorf_options, ctx)?; // NB: The quantized rows are approximately unit-norm by construction; downstream callers @@ -182,7 +183,7 @@ pub fn turboquant_encode_normalized( Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array()) } -/// Rotate and quantize already-normalized rows into a dict-encoded `FixedSizeList`. +/// Rotate and quantize already-normalized vector rows into a dict-encoded `FixedSizeList`. /// /// The input `fsl` must contain non-nullable, unit-norm vectors of float values (already /// L2-normalized). Null vectors are not supported and must be zeroed out before reaching this @@ -208,11 +209,11 @@ fn turboquant_quantize_fsl( sorf_options: &SorfOptions, ctx: &mut ExecutionCtx, ) -> VortexResult { + vortex_ensure!(!fsl.dtype().is_nullable()); + let dimensions = fsl.list_size() as usize; let num_rows = fsl.len(); - vortex_ensure!(!fsl.dtype().is_nullable()); - let rotation = SorfMatrix::try_new( sorf_options.seed, dimensions, diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs index bc404a8843e..93f70e13e9a 100644 --- a/vortex-tensor/src/encodings/turboquant/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -149,34 +149,5 @@ pub const MAX_BIT_WIDTH: u8 = 8; /// Maximum supported number of centroids in the scalar quantizer codebook. 
pub const MAX_CENTROIDS: usize = 1usize << (MAX_BIT_WIDTH as usize); -use vortex_array::dtype::DType; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; -use vortex_error::vortex_err; - -use crate::types::vector::AnyVector; -use crate::types::vector::VectorMatcherMetadata; - -/// Validates that `dtype` is a [`Vector`](crate::vector::Vector) extension type with -/// dimension >= [`MIN_DIMENSION`]. -/// -/// Returns the validated vector metadata on success. -pub fn tq_validate_vector_dtype(dtype: &DType) -> VortexResult { - let vector_metadata = dtype - .as_extension_opt() - .and_then(|ext| ext.metadata_opt::()) - .ok_or_else(|| { - vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}") - })?; - - let dimensions = vector_metadata.dimensions(); - vortex_ensure!( - dimensions >= MIN_DIMENSION, - "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}", - ); - - Ok(vector_metadata) -} - #[cfg(test)] mod tests; diff --git a/vortex-tensor/src/encodings/turboquant/scheme.rs b/vortex-tensor/src/encodings/turboquant/scheme.rs index d4362096bd2..b40bd52cb17 100644 --- a/vortex-tensor/src/encodings/turboquant/scheme.rs +++ b/vortex-tensor/src/encodings/turboquant/scheme.rs @@ -22,6 +22,7 @@ use vortex_array::ArrayRef; use vortex_array::Canonical; use vortex_array::ExecutionCtx; +use vortex_array::dtype::DType; use vortex_compressor::CascadingCompressor; use vortex_compressor::ctx::CompressorContext; use vortex_compressor::estimate::CompressionEstimate; @@ -30,11 +31,15 @@ use vortex_compressor::scheme::Scheme; use vortex_compressor::stats::ArrayAndStats; use vortex_error::VortexExpect; use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; use crate::encodings::turboquant::MAX_CENTROIDS; +use crate::encodings::turboquant::MIN_DIMENSION; use crate::encodings::turboquant::TurboQuantConfig; -use crate::encodings::turboquant::tq_validate_vector_dtype; use 
crate::encodings::turboquant::turboquant_encode; +use crate::vector::AnyVector; +use crate::vector::VectorMatcherMetadata; /// TurboQuant compression scheme for [`Vector`] extension types. /// @@ -133,6 +138,27 @@ fn estimate_compression_ratio(element_bit_width: u8, dimensions: u32, num_vector uncompressed_size_bits as f64 / compressed_size_bits as f64 } +/// Validates that `dtype` is a [`Vector`](crate::vector::Vector) extension type with +/// dimension >= [`MIN_DIMENSION`]. +/// +/// Returns the validated vector metadata on success. +pub fn tq_validate_vector_dtype(dtype: &DType) -> VortexResult { + let vector_metadata = dtype + .as_extension_opt() + .and_then(|ext| ext.metadata_opt::()) + .ok_or_else(|| { + vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}") + })?; + + let dimensions = vector_metadata.dimensions(); + vortex_ensure!( + dimensions >= MIN_DIMENSION, + "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}", + ); + + Ok(vector_metadata) +} + #[cfg(test)] mod tests { use rstest::rstest; diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index ae5ed3748dc..c429ad7bb01 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -35,7 +35,7 @@ use vortex_session::VortexSession; use crate::scalar_fns::inner_product::InnerProduct; use crate::scalar_fns::l2_denorm::NormalForm; -use crate::scalar_fns::l2_denorm::try_build_constant_l2_denorm; +use crate::scalar_fns::l2_denorm::try_build_constant_l2_denorm_from_constant; use crate::scalar_fns::l2_norm::L2Norm; use crate::utils::BinaryTensorOpMetadata; use crate::utils::validate_binary_tensor_float_inputs; @@ -133,26 +133,29 @@ impl ScalarFnVTable for CosineSimilarity { // If either side is a constant tensor-like extension array, eagerly normalize the single // stored row and re-wrap it as an `L2Denorm` whose children are both `ConstantArray`s. 
// The L2Denorm fast path below then picks it up. - if let Some(sfn) = try_build_constant_l2_denorm(&lhs_ref, len, ctx)? { + if let Some(sfn) = try_build_constant_l2_denorm_from_constant(&lhs_ref, len, ctx)? { lhs_ref = sfn.into_array(); } - if let Some(sfn) = try_build_constant_l2_denorm(&rhs_ref, len, ctx)? { + if let Some(sfn) = try_build_constant_l2_denorm_from_constant(&rhs_ref, len, ctx)? { rhs_ref = sfn.into_array(); } - // The combined validity always comes from the original operands. Compute it once up - // front so the unit-form helpers below can take it directly without re-deriving from - // an `L2Denorm` wrapper they no longer hold. + // The combined validity always comes from the original operands. Compute it once up front + // so the unit-form helpers below can take it directly without re-deriving from an + // `L2Denorm` wrapper they no longer hold. let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; - // Classify each operand by its normal form. When both operands carry a known unit-norm - // representation, cosine similarity collapses to the dot product of the unit vectors. + // Classify each operand by its normal form. let lhs_form = NormalForm::classify(&lhs_ref); let rhs_form = NormalForm::classify(&rhs_ref); match (lhs_form.normalized_array(), rhs_form.normalized_array()) { (Some(unit_lhs), Some(unit_rhs)) => { + // When both operands carry a known unit-norm representation, cosine similarity + // collapses to the dot product of the unit vectors. return self.execute_both_unit(unit_lhs, unit_rhs, validity, len); } + // When one operand carries a unit-norm representation, then we can skip one of the + // division steps. (Some(unit_lhs), None) => { return self.execute_one_unit(unit_lhs, &rhs_ref, validity, len, ctx); } @@ -174,7 +177,7 @@ impl ScalarFnVTable for CosineSimilarity { let norm_r: PrimitiveArray = norm_rhs_arr.into_array().execute(ctx)?; // TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation. 
- // TODO(connor): This can be written in a more SIMD-friendly manner. + // TODO(connor): This can probably be written in a more SIMD-friendly manner. match_each_float_ptype!(dot.ptype(), |T| { let dots = dot.as_slice::(); let norms_l = norm_l.as_slice::(); @@ -236,6 +239,7 @@ impl ScalarFnArrayVTable for CosineSimilarity { ) -> VortexResult> { let reconstructed = BinaryTensorOpMetadata::decode_children(metadata, len, children, session)?; + Ok(ScalarFnArrayParts { options: EmptyOptions, children: reconstructed, diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index 9cdae9903d6..f1bcfe3a2e9 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -49,6 +49,7 @@ use crate::matcher::AnyTensor; use crate::scalar_fns::l2_denorm::NormalForm; use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfTransform; +use crate::types::normalized_vector::inner_vector_array; use crate::types::vector::Vector; use crate::utils::BinaryTensorOpMetadata; use crate::utils::extract_constant_flat_row; @@ -123,7 +124,8 @@ impl ScalarFnVTable for InnerProduct { let lhs = &arg_dtypes[0]; let rhs = &arg_dtypes[1]; - // TODO(connor): relax the float-only gate once integer tensors are supported. + // TODO(connor): Relax the float-only gate once integer tensors are supported, since inner + // product is defined for integer tensors. let tensor_match = validate_binary_tensor_float_inputs(lhs, rhs)?; let ptype = tensor_match.element_ptype(); let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable()); @@ -140,16 +142,16 @@ impl ScalarFnVTable for InnerProduct { let rhs_ref = args.get(1)?; let len = args.row_count(); - // Take the unit-norm fast path only when at least one operand wraps stored norms (the - // `Denormalized` form). 
For naked `NormalizedVector` operands the fall-through dot - // product already computes the right thing (and short-circuiting here would recurse - // back into `InnerProduct`). + // Take the factored fast path only when at least one operand wraps stored norms (the + // `Denormalized` form). Routing through this lets us extract the stored norms instead of + // canonicalizing the `L2Denorm` ScalarFnArray, which would materialize `unit · norms` + // row-by-row before the dot — an avoidable `O(N·D)` pass. let lhs_form = NormalForm::classify(&lhs_ref); let rhs_form = NormalForm::classify(&rhs_ref); if matches!(lhs_form, NormalForm::Denormalized { .. }) || matches!(rhs_form, NormalForm::Denormalized { .. }) { - return self.execute_unit_form(&lhs_form, &rhs_form, &lhs_ref, &rhs_ref, len, ctx); + return self.execute_factored_dot(&lhs_form, &rhs_form, &lhs_ref, &rhs_ref, len, ctx); } // Reduction case 1: `InnerProduct(SorfTransform(x), const)` rewrites to @@ -169,10 +171,10 @@ impl ScalarFnVTable for InnerProduct { // Compute combined validity. let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; - // Drill past any `NormalizedVector` wrapper so we always work with the underlying - // `Vector` extension array. - let lhs_inner = crate::types::normalized_vector::inner_vector_array(&lhs_ref, ctx)?; - let rhs_inner = crate::types::normalized_vector::inner_vector_array(&rhs_ref, ctx)?; + // Drill past any `NormalizedVector` wrapper so we always work with the underlying `Vector` + // extension array. + let lhs_inner = inner_vector_array(&lhs_ref, ctx)?; + let rhs_inner = inner_vector_array(&rhs_ref, ctx)?; // Canonicalize so we can perform the math directly. 
let lhs: ExtensionArray = lhs_inner.execute(ctx)?; @@ -244,6 +246,7 @@ impl ScalarFnArrayVTable for InnerProduct { ) -> VortexResult> { let reconstructed = BinaryTensorOpMetadata::decode_children(metadata, len, children, session)?; + Ok(ScalarFnArrayParts { options: EmptyOptions, children: reconstructed, @@ -252,11 +255,15 @@ impl ScalarFnArrayVTable for InnerProduct { } impl InnerProduct { - /// Inner product over operands that may carry a unit-norm representation: - /// `inner_product = scale_l * scale_r * dot(unit_l, unit_r)`, where the `(unit, scale)` pair - /// for each operand is `(operand, None)` for `Plain`, `(NV, None)` for naked `Normalized`, - /// and `(NV, Some(stored_norms))` for `Denormalized`. See [`decompose_for_unit_form`]. - fn execute_unit_form( + /// Compute `` after factoring each operand into a `(vector, optional_scale)` pair + /// via [`factor_operand`]. The math is ` = scale_l · scale_r · `, where + /// a `None` scale acts as `1.0` (skipping the per-row multiply). + /// + /// This is **not** restricted to unit-norm operands. `Plain` factors as `(operand, None)` with + /// `scale = 1`, and the formula still holds: `` = + /// `scale_r · `. The win over the standard path is avoiding canonicalizing the + /// `L2Denorm` ScalarFnArray (which would materialize `unit · norms` per row before the dot). + fn execute_factored_dot( &self, lhs_form: &NormalForm<'_>, rhs_form: &NormalForm<'_>, @@ -267,13 +274,16 @@ impl InnerProduct { ) -> VortexResult { let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; - let (unit_lhs, lhs_scale) = decompose_for_unit_form(lhs_form, lhs_ref, ctx)?; - let (unit_rhs, rhs_scale) = decompose_for_unit_form(rhs_form, rhs_ref, ctx)?; + let (vec_lhs, lhs_scale) = factor_operand(lhs_form, lhs_ref, ctx)?; + let (vec_rhs, rhs_scale) = factor_operand(rhs_form, rhs_ref, ctx)?; - let dot: PrimitiveArray = InnerProduct::try_new_array(unit_lhs, unit_rhs, len)? 
+ // NB: The call into `dot(vec_l, vec_r)` here dispatches back through `InnerProduct`, which + // lets the SORF and Dict reductions fire on TurboQuant's `SorfTransform` child. + let dot: PrimitiveArray = InnerProduct::try_new_array(vec_lhs, vec_rhs, len)? .into_array() .execute(ctx)?; + // TODO(connor): This should use the binary `Mul` expressions. match_each_float_ptype!(dot.ptype(), |T| { let dots = dot.as_slice::(); let buffer: Buffer = match (lhs_scale.as_ref(), rhs_scale.as_ref()) { @@ -489,16 +499,23 @@ impl InnerProduct { } } -/// Decompose an operand classified by [`NormalForm`] into the `(unit_operand, optional_scale)` -/// pair consumed by [`InnerProduct::execute_unit_form`]: +/// Factor an operand classified by [`NormalForm`] into the `(vector, optional_scale)` pair consumed +/// by [`InnerProduct::execute_factored_dot`]. The factorization satisfies +/// `original = scale · vector` (with `scale = 1` when the returned scale is `None`), so the inner +/// product distributes as ` = scale_l · scale_r · `. +/// +/// The "vector" component is **not** required to be unit-norm: for `Plain` operands the entire +/// operand is returned as the "vector" with an implicit scale of `1`. The point of the +/// factorization is to surface the stored norms of `Denormalized` operands so they can be applied +/// after the dot, not to assert anything about the geometry of the vector component. /// -/// - `Plain`: `(original, None)`. The unscaled operand IS its own "unit" for dot purposes; no -/// per-row multiply is needed. -/// - `Normalized`: `(NV, None)`. Implicit per-row scale of `1.0` — the unit child enters the dot -/// directly with no multiply. -/// - `Denormalized`: `(NV, Some(stored_norms))`. The dot is computed over the unit child and the -/// caller multiplies by the materialized stored norms afterward. -fn decompose_for_unit_form( +/// - `Plain`: `(original, None)`. Implicit `scale = 1`; the operand passes through to the dot +/// unchanged. 
+/// - `Normalized`: `(NV, None)`. Implicit `scale = 1`; the unit-norm child passes through to +/// the dot unchanged. +/// - `Denormalized`: `(NV, Some(stored_norms))`. The dot is computed over the unit-norm child +/// and the caller multiplies row-wise by the materialized stored norms afterward. +fn factor_operand( form: &NormalForm<'_>, original: &ArrayRef, ctx: &mut ExecutionCtx, diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index 0230e36d8f2..f5530cbb524 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -533,6 +533,7 @@ pub fn normalize_as_l2_denorm( // by [`L2Denorm::try_new_array`] that a zero-norm row is paired with an all-zero normalized // row, because [`L2Norm`]'s `NormalizedVector` short-circuit emits 0.0 exactly when the row // is all zero. + // This also has the added benefit of correcting any lossy-encoded `NormalizedVector` arrays. if tensor_metadata.is_normalized() { let norms_sfn = L2Norm::try_new_array(input.clone(), row_count)?; let norms_array: ArrayRef = norms_sfn.into_array().execute(ctx)?; @@ -545,7 +546,7 @@ pub fn normalize_as_l2_denorm( // Constant fast path: if the input is a constant-backed extension, normalize the single // stored row once and return an `L2Denorm` whose children are both `ConstantArray`s. - if let Some(wrapped) = try_build_constant_l2_denorm(&input, row_count, ctx)? { + if let Some(wrapped) = try_build_constant_l2_denorm_from_constant(&input, row_count, ctx)? { return Ok(wrapped); } @@ -609,7 +610,7 @@ pub fn normalize_as_l2_denorm( /// /// This is helpful in some of the reduction steps for cosine similarity execution into inner /// product execution. 
-pub(crate) fn try_build_constant_l2_denorm( +pub(crate) fn try_build_constant_l2_denorm_from_constant( input: &ArrayRef, len: usize, ctx: &mut ExecutionCtx, diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs b/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs index a6862465eba..3d042bff604 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs @@ -78,11 +78,14 @@ pub struct SorfTransform; pub struct SorfOptions { /// Seed used to generate the structured sign diagonals via Vortex's frozen SplitMix64 stream. pub seed: u64, + /// Number of sign-diagonal + WHT rounds in the structured orthogonal transform. pub num_rounds: u8, + /// Original vector dimension (before power-of-2 padding). The output /// [`Vector`](crate::vector::Vector) has this dimension. pub dimensions: u32, + /// Element type of the output [`Vector`](crate::vector::Vector). The child input must always /// be `f32`, but the output can be any float type (`F16`, `F32`, `F64`); the final /// `f32 -> element_ptype` cast happens while building the output. @@ -90,7 +93,8 @@ pub struct SorfOptions { } impl SorfTransform { - /// Creates a new [`TypedScalarFnInstance`] wrapping the SORF inverse transform with the given options. + /// Creates a new [`TypedScalarFnInstance`] wrapping the SORF inverse transform with the given + /// options. pub fn new(options: &SorfOptions) -> TypedScalarFnInstance { TypedScalarFnInstance::new(SorfTransform, options.clone()) } @@ -121,7 +125,7 @@ impl SorfTransform { } /// Checks that the SORF configuration is valid. 
-pub(crate) fn validate_sorf_options(options: &SorfOptions) -> VortexResult<()> { +pub(super) fn validate_sorf_options(options: &SorfOptions) -> VortexResult<()> { vortex_ensure!( options.num_rounds >= 1, "SorfTransform num_rounds must be >= 1, got {}", diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs b/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs index ff8aebd0f11..ea6e776ecdb 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs @@ -61,17 +61,19 @@ pub struct SorfMatrix { /// Indexed as `round * padded_dim + i`. `0x00000000` = multiply by +1 (no-op), `0x80000000` = /// multiply by -1 (flip sign bit). sign_masks: Vec, + /// The number of sign-diagonal + WHT rounds. num_rounds: usize, /// The padded dimension (next power of 2 >= dimension). padded_dim: usize, - /// Normalization factor: `padded_dim^(-num_rounds/2)`, applied once at the end. + /// Normalization factor: `padded_dim^(-num_rounds/2)`, applied once at the end. This is stored + /// for convenience. norm_factor: f32, } impl SorfMatrix { - /// Create a new structured Walsh-Hadamard-based orthogonal transform from a deterministic - /// seed. + // TODO(connor): Should this just only allow power-of-2 dimensions? Require the caller to do it? + /// Create a new structured Walsh-Hadamard-based orthogonal transform from a deterministic seed. /// /// The seed is expanded using Vortex's frozen local SplitMix64 stream. 
Signs are generated in /// round-major, block-major order, with each `u64` contributing 64 sign bits in diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs index c404f38e27c..c7f59b040a2 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs @@ -47,7 +47,6 @@ use super::SorfOptions; use super::SorfTransform; use super::rotation::SorfMatrix; use super::validate_sorf_options; -use crate::matcher::AnyTensor; use crate::types::normalized_vector::NormalizedVector; use crate::types::normalized_vector::inner_vector_array; use crate::types::vector::AnyVector; @@ -88,25 +87,24 @@ impl ScalarFnVTable for SorfTransform { let child_dtype = &arg_dtypes[0]; let vector_metadata = child_dtype .as_extension_opt() - .and_then(|ext| ext.metadata_opt::()) + .and_then(|ext| ext.metadata_opt::()) .ok_or_else(|| { vortex_err!( - "SorfTransform child must be a Vector or NormalizedVector extension, got \ - {child_dtype}" + "SorfTransform child must be a Vector or NormalizedVector extension, \ + got {child_dtype}" ) })?; let expected_padded = options.dimensions.next_power_of_two(); vortex_ensure_eq!( - vector_metadata.list_size(), + vector_metadata.dimensions(), expected_padded, "SorfTransform child Vector must have dimension {expected_padded} (next power of two \ for dimension {})", options.dimensions, ); - // For now, the child Vector storage must be f32. TurboQuant stores its centroids as f32, - // and the SORF transform itself operates in f32, so any other input type would require an + // For now, the child Vector storage must be f32, so any other input type would require an // implicit cast that we do not yet support. The output element type is independently // specified via `options.element_ptype` and is built below. 
vortex_ensure_eq!( @@ -123,23 +121,17 @@ impl ScalarFnVTable for SorfTransform { child_dtype.nullability(), ); - // The output mirrors the child's wrapper kind: if the child was a `NormalizedVector` the - // output is also surfaced as a `NormalizedVector` (the orthogonal inverse rotation - // preserves L2 norm and the truncation drops coordinates that were zero pre-rotation, so - // the output is approximately unit-norm under the same lossy contract that - // `NormalizedVector::new_unchecked` documents). - if vector_metadata.is_normalized() { + // The output mirrors the child's wrapper kind, so if the child was a `NormalizedVector`, + // the output is also a `NormalizedVector`. + let inner = if vector_metadata.is_normalized() { let inner_vector = ExtDType::::try_new(EmptyMetadata, fsl_dtype)?.erased(); - let outer = ExtDType::::try_new( - EmptyMetadata, - DType::Extension(inner_vector), - )? - .erased(); - Ok(DType::Extension(outer)) + ExtDType::::try_new(EmptyMetadata, DType::Extension(inner_vector))? 
+ .erased() } else { - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl_dtype)?.erased(); - Ok(DType::Extension(ext_dtype)) - } + ExtDType::::try_new(EmptyMetadata, fsl_dtype)?.erased() + }; + + Ok(DType::Extension(inner)) } fn execute( @@ -152,34 +144,32 @@ impl ScalarFnVTable for SorfTransform { let num_rows = args.row_count(); let child_arg = args.get(0)?; - let is_normalized_child = child_arg + let child_is_normalized = child_arg .dtype() .as_extension_opt() .is_some_and(|ext| ext.is::()); - let fsl_array: ArrayRef = if num_rows == 0 { + let fsl_array = if num_rows == 0 { let validity = Validity::from(child_arg.dtype().nullability()); + let elements = match_each_float_ptype!(options.element_ptype, |T| { + PrimitiveArray::empty::(Nullability::NonNullable) + }) + .into_array(); - match_each_float_ptype!(options.element_ptype, |T| { - let elements = PrimitiveArray::empty::(Nullability::NonNullable); - FixedSizeListArray::try_new(elements.into_array(), options.dimensions, validity, 0) - })? - .into_array() + FixedSizeListArray::try_new(elements, options.dimensions, validity, 0)?.into_array() } else { - // Execute the child to get either a `Vector` extension or a `NormalizedVector` - // wrapping a `Vector` over an FSL of f32 coordinates. The `return_dtype` check - // guarantees the shape is `Vector` at the FSL level, so drill past - // any `NormalizedVector` wrapper before unpacking. + // Execute the child, since we cannot apply rotations over compressed data. 
let child_ref = inner_vector_array(&child_arg, ctx)?; let child_ext: ExtensionArray = child_ref.execute(ctx)?; let child_validity = child_ext.as_ref().validity()?; let child_fsl: FixedSizeListArray = child_ext.storage_array().clone().execute(ctx)?; - let padded_dim = - usize::try_from(child_fsl.list_size()).vortex_expect("list_size fits usize"); let elements_prim: PrimitiveArray = child_fsl.elements().clone().execute(ctx)?; let f32_elements = elements_prim.into_buffer::(); + let padded_dim = + usize::try_from(child_fsl.list_size()).vortex_expect("list_size fits usize"); + // Reconstruct the orthogonal transform matrix from the seed. let rotation = SorfMatrix::try_new(options.seed, dim, options.num_rounds as usize)?; @@ -196,14 +186,9 @@ impl ScalarFnVTable for SorfTransform { })? }; - // SAFETY: When `is_normalized_child` is `true`, the input child was a - // `NormalizedVector` (its dtype was checked above), so every valid row was unit-norm or - // zero by type. Inverse SORF is orthogonal (norm-preserving), and the truncated tail - // coordinates were zero pre-rotation up to quantization noise — so each row of - // `fsl_array` is approximately unit-norm under the same lossy contract that - // [`NormalizedVector::new_unchecked`] documents. When `is_normalized_child` is `false` - // [`wrap_output`] takes the trivially-safe `Vector` branch. - unsafe { wrap_output(fsl_array, is_normalized_child) } + // SAFETY: We used the matcher to check if the child was normalized, so this must be + // correct. 
+ unsafe { wrap_vector_storage(fsl_array, child_is_normalized) } } fn validity( @@ -258,18 +243,12 @@ impl ScalarFnArrayVTable for SorfTransform { len: usize, metadata: &[u8], children: &dyn ArrayChildren, - session: &VortexSession, + _session: &VortexSession, ) -> VortexResult> { - let _ = session; let metadata = SorfTransformMetadata::decode(metadata) .map_err(|e| vortex_err!("Failed to decode SorfTransformMetadata: {e}"))?; let options = metadata.to_options()?; - // The parent dtype must be a vector-shaped extension produced by `return_dtype`: either - // a plain `Vector` (when the child was a plain `Vector`) or a `NormalizedVector` (when - // the child was a `NormalizedVector`). `AnyVector` matches both, and its `try_match` - // panics on a structurally malformed `NormalizedVector`, so a successful match also - // guarantees the inner drill below is well-formed. let parent_ext = dtype .as_extension_opt() .filter(|ext| ext.is::()) @@ -279,40 +258,28 @@ impl ScalarFnArrayVTable for SorfTransform { extension, got {dtype}", ) })?; - let is_normalized = parent_ext.is::(); - - // The child's FSL nullability matches the parent's inner FSL nullability (set by - // `return_dtype` from the original child's outer nullability). Drill into the parent - // wrapper to recover it; `AnyVector` already validated the structural shape. - let parent_fsl_dtype = if is_normalized { - let DType::Extension(inner) = parent_ext.storage_dtype() else { - unreachable!( - "`AnyVector` matcher guarantees a `NormalizedVector` parent wraps a \ - `Vector` extension" - ) - }; - inner.storage_dtype() - } else { - parent_ext.storage_dtype() - }; - let fsl_nullability = parent_fsl_dtype.nullability(); + + // The nullability of the parent extension type is the same as the storage type. 
+ let fsl_nullability = parent_ext.nullability(); let padded_dim = options.dimensions.next_power_of_two(); - let child_fsl = DType::FixedSizeList( + let child_fsl_dtype = DType::FixedSizeList( Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)), padded_dim, fsl_nullability, ); - let inner_vector = ExtDType::::try_new(EmptyMetadata, child_fsl)?.erased(); - let child_dtype = if is_normalized { + let inner_vector_dtype = + ExtDType::::try_new(EmptyMetadata, child_fsl_dtype)?.erased(); + + let child_dtype = if parent_ext.is::() { let nv = ExtDType::::try_new( EmptyMetadata, - DType::Extension(inner_vector), + DType::Extension(inner_vector_dtype), )? .erased(); DType::Extension(nv) } else { - DType::Extension(inner_vector) + DType::Extension(inner_vector_dtype) }; let child = children.get(0, &child_dtype, len)?; @@ -348,7 +315,7 @@ fn inverse_rotate_typed( let mut unrotated = vec![0.0f32; padded_dim]; for row in 0..num_rows { - let row_data = &f32_elements[row * padded_dim..(row + 1) * padded_dim]; + let row_data = &f32_elements[row * padded_dim..][..padded_dim]; rotation.inverse_rotate(row_data, &mut unrotated); @@ -372,7 +339,7 @@ fn inverse_rotate_typed( /// zero in the lossy sense documented by [`NormalizedVector::new_unchecked`]. /// /// When `is_normalized` is `false` the function takes the safe `Vector` branch. -unsafe fn wrap_output(fsl: ArrayRef, is_normalized: bool) -> VortexResult { +unsafe fn wrap_vector_storage(fsl: ArrayRef, is_normalized: bool) -> VortexResult { if is_normalized { // SAFETY: Forwarded from the function-level safety contract above. 
unsafe { NormalizedVector::new_unchecked(fsl) } diff --git a/vortex-tensor/src/types/normalized_vector/matcher.rs b/vortex-tensor/src/types/normalized_vector/matcher.rs index ec5dc2b729c..b1974e13b52 100644 --- a/vortex-tensor/src/types/normalized_vector/matcher.rs +++ b/vortex-tensor/src/types/normalized_vector/matcher.rs @@ -52,7 +52,7 @@ impl Matcher for AnyNormalizedVector { "element dtype must be non-nullable" ); - let metadata = VectorMatcherMetadata::try_new(element_dtype.as_ptype(), *list_size) + let metadata = VectorMatcherMetadata::try_new(element_dtype.as_ptype(), *list_size, true) .vortex_expect("`NormalizedVector` inner Vector did not have float elements"); Some(metadata) diff --git a/vortex-tensor/src/types/vector/matcher.rs b/vortex-tensor/src/types/vector/matcher.rs index a646e8009fe..e6c878cb721 100644 --- a/vortex-tensor/src/types/vector/matcher.rs +++ b/vortex-tensor/src/types/vector/matcher.rs @@ -41,6 +41,9 @@ pub struct VectorMatcherMetadata { /// The number of dimensions of the vector. This is always fixed. dimensions: u32, + + ///`true` when the dtype is a [`NormalizedVector`]. + is_normalized: bool, } impl Matcher for AnyVector { @@ -50,7 +53,7 @@ impl Matcher for AnyVector { // Walk to the inner `FixedSizeList` for whichever vector-shaped wrapper this is. Plain // `Vector` stores the FSL directly; `NormalizedVector` wraps a `Vector` extension which // in turn stores the FSL. 
- let fsl_dtype = if ext_dtype.is::() { + let (fsl_dtype, is_normalized) = if ext_dtype.is::() { let DType::Extension(inner) = ext_dtype.storage_dtype() else { vortex_panic!( "`NormalizedVector` storage must be `DType::Extension(Vector)`, got {}", @@ -65,9 +68,9 @@ impl Matcher for AnyVector { ) } - inner.storage_dtype() + (inner.storage_dtype(), true) } else if ext_dtype.is::() { - ext_dtype.storage_dtype() + (ext_dtype.storage_dtype(), false) } else { return None; }; @@ -85,8 +88,9 @@ impl Matcher for AnyVector { ); let element_ptype = element_dtype.as_ptype(); - let vector_metadata = VectorMatcherMetadata::try_new(element_ptype, dimensions) - .vortex_expect("`Vector` type somehow did not have float elements"); + let vector_metadata = + VectorMatcherMetadata::try_new(element_ptype, dimensions, is_normalized) + .vortex_expect("`Vector` type somehow did not have float elements"); Some(vector_metadata) } @@ -98,12 +102,17 @@ impl VectorMatcherMetadata { /// # Errors /// /// Returns an error if the element type is not a float. - pub fn try_new(element_ptype: PType, dimensions: u32) -> VortexResult { + pub fn try_new( + element_ptype: PType, + dimensions: u32, + is_normalized: bool, + ) -> VortexResult { vortex_ensure!(element_ptype.is_float()); Ok(Self { element_ptype, dimensions, + is_normalized, }) } @@ -116,6 +125,12 @@ impl VectorMatcherMetadata { pub fn dimensions(&self) -> u32 { self.dimensions } + + /// Returns `true` when the dtype is a + /// [`NormalizedVector`](crate::normalized_vector::NormalizedVector). 
+ pub fn is_normalized(self) -> bool { + self.is_normalized + } } #[cfg(test)] From 2b4588c5ef456aa5f06fae802e855cf1f2b3dc0c Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Tue, 28 Apr 2026 17:32:21 -0400 Subject: [PATCH 4/6] move utils around Signed-off-by: Connor Tsui --- .../src/scalar_fns/cosine_similarity.rs | 5 +-- vortex-tensor/src/scalar_fns/inner_product.rs | 2 +- vortex-tensor/src/scalar_fns/l2_denorm.rs | 5 ++- vortex-tensor/src/scalar_fns/l2_norm.rs | 5 ++- .../src/scalar_fns/sorf_transform/vtable.rs | 2 +- .../src/types/normalized_vector/mod.rs | 44 ------------------- vortex-tensor/src/types/vector/matcher.rs | 4 +- vortex-tensor/src/utils.rs | 44 +++++++++++++++++++ 8 files changed, 55 insertions(+), 56 deletions(-) diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index c429ad7bb01..a9369ad5ca0 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -269,10 +269,7 @@ impl CosineSimilarity { } /// Exactly one side carries a unit-norm representation: cosine similarity reduces to - /// `dot(unit, plain) / ||plain||`. The norms of the unit side are implicitly `1.0` (naked - /// `NormalizedVector`) or stored separately on the outer `L2Denorm` wrapper, which the - /// caller has already stripped — cosine ignores magnitude on the unit side, so the wrapper - /// is not needed here. + /// `dot(unit, plain) / ||plain||`. 
fn execute_one_unit( &self, unit: &ArrayRef, diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index f1bcfe3a2e9..ddd3d907c20 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -49,11 +49,11 @@ use crate::matcher::AnyTensor; use crate::scalar_fns::l2_denorm::NormalForm; use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfTransform; -use crate::types::normalized_vector::inner_vector_array; use crate::types::vector::Vector; use crate::utils::BinaryTensorOpMetadata; use crate::utils::extract_constant_flat_row; use crate::utils::extract_flat_elements; +use crate::utils::inner_vector_array; use crate::utils::validate_binary_tensor_float_inputs; /// Inner product (dot product) between two columns. diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index f5530cbb524..76302f94e84 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -62,14 +62,14 @@ use vortex_session::VortexSession; use crate::matcher::AnyTensor; use crate::scalar_fns::l2_norm::L2Norm; use crate::types::normalized_vector::NormalizedVector; -use crate::types::normalized_vector::inner_vector_array; -use crate::types::normalized_vector::vector_fsl_storage_dtype; use crate::types::vector::AnyVector; use crate::types::vector::Vector; use crate::utils::extract_constant_flat_row; use crate::utils::extract_flat_elements; use crate::utils::extract_l2_denorm_children; +use crate::utils::inner_vector_array; use crate::utils::unit_norm_tolerance; +use crate::utils::vector_fsl_storage_dtype; /// Re-applies authoritative L2 norms to a normalized vector column. /// @@ -599,6 +599,7 @@ pub fn normalize_as_l2_denorm( unsafe { L2Denorm::new_array_unchecked(normalized, norms_array, row_count) } } +// TODO(connor): This does not handle `NormalizedVector` correctly!!! 
/// Attempts to build an [`L2Denorm`] whose two children are both [`ConstantArray`]s by eagerly /// normalizing `input`'s single stored row. /// diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index 67b196eba4e..bcabe823466 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -48,6 +48,7 @@ use vortex_session::VortexSession; use crate::matcher::AnyTensor; use crate::scalar_fns::l2_denorm::NormalForm; use crate::utils::extract_flat_elements; +use crate::utils::inner_vector_array; use crate::utils::validate_tensor_float_input; /// L2 norm (Euclidean norm) of a tensor or vector column. @@ -186,7 +187,7 @@ impl ScalarFnVTable for L2Norm { // Drill past any `NormalizedVector` wrapper so we always work with the underlying // `Vector` extension array. - let input_ref = crate::types::normalized_vector::inner_vector_array(&input_ref, ctx)?; + let input_ref = inner_vector_array(&input_ref, ctx)?; let input: ExtensionArray = input_ref.execute(ctx)?; let validity = input.as_ref().validity()?; @@ -287,7 +288,7 @@ fn execute_normalized_vector_norms( ) -> VortexResult { // `NormalizedVector` storage is `Extension(Vector(FSL))`; drill to the inner `Vector` to // reach the underlying FSL. 
- let vector_ref = crate::types::normalized_vector::inner_vector_array(input_ref, ctx)?; + let vector_ref = inner_vector_array(input_ref, ctx)?; let input: ExtensionArray = vector_ref.execute(ctx)?; let validity = input.as_ref().validity()?; let flat = extract_flat_elements(input.storage_array(), tensor_flat_size, ctx)?; diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs index c7f59b040a2..72592543125 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs @@ -48,9 +48,9 @@ use super::SorfTransform; use super::rotation::SorfMatrix; use super::validate_sorf_options; use crate::types::normalized_vector::NormalizedVector; -use crate::types::normalized_vector::inner_vector_array; use crate::types::vector::AnyVector; use crate::types::vector::Vector; +use crate::utils::inner_vector_array; impl ScalarFnVTable for SorfTransform { type Options = SorfOptions; diff --git a/vortex-tensor/src/types/normalized_vector/mod.rs b/vortex-tensor/src/types/normalized_vector/mod.rs index 605f6670c27..aa44f670743 100644 --- a/vortex-tensor/src/types/normalized_vector/mod.rs +++ b/vortex-tensor/src/types/normalized_vector/mod.rs @@ -142,50 +142,6 @@ pub(crate) fn validate_unit_norm_rows( Ok(()) } -/// Returns the underlying `FixedSizeList` storage dtype for a vector-shaped extension dtype. -/// -/// For a plain [`Vector`], this is the direct storage dtype. For a [`NormalizedVector`] -/// it drills through one extra extension layer. 
-pub(crate) fn vector_fsl_storage_dtype( - ext: &vortex_array::dtype::extension::ExtDTypeRef, -) -> Option { - use vortex_array::dtype::DType; - if ext.is::() { - Some(ext.storage_dtype().clone()) - } else if ext.is::() { - let DType::Extension(inner) = ext.storage_dtype() else { - return None; - }; - if !inner.is::() { - return None; - } - Some(inner.storage_dtype().clone()) - } else { - None - } -} - -/// Returns the underlying `Vector` extension array inside a vector-shaped extension array. -/// -/// For a [`NormalizedVector`] array, this executes the outer extension and returns its -/// `Vector` storage child. For a plain [`Vector`] array, it returns the array itself (after -/// canonicalizing to an `ExtensionArray`). -pub(crate) fn inner_vector_array( - array: &ArrayRef, - ctx: &mut ExecutionCtx, -) -> VortexResult { - let is_normalized = array - .dtype() - .as_extension_opt() - .is_some_and(|ext| ext.is::()); - if is_normalized { - let ext: ExtensionArray = array.clone().execute(ctx)?; - Ok(ext.storage_array().clone()) - } else { - Ok(array.clone()) - } -} - mod matcher; mod vtable; diff --git a/vortex-tensor/src/types/vector/matcher.rs b/vortex-tensor/src/types/vector/matcher.rs index e6c878cb721..e7a4b5c39e9 100644 --- a/vortex-tensor/src/types/vector/matcher.rs +++ b/vortex-tensor/src/types/vector/matcher.rs @@ -13,8 +13,8 @@ use vortex_error::vortex_panic; use crate::types::normalized_vector::NormalizedVector; use crate::types::vector::Vector; -/// Matcher that accepts any vector-shaped extension type — both plain -/// [`Vector`] and [`NormalizedVector`](crate::normalized_vector::NormalizedVector). +/// Matcher that accepts any vector-shaped extension type (both plain [`Vector`] and +/// [`NormalizedVector`]). 
/// /// To match a plain [`Vector`] only (excluding [`NormalizedVector`]), pair this matcher with a /// negated `is::()` check; to match a `NormalizedVector` only, use diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index a50e8fa168f..2358da592a8 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -8,9 +8,11 @@ use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::Constant; use vortex_array::arrays::ConstantArray; +use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFn; +use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; use vortex_array::arrays::primitive::PrimitiveArrayExt; use vortex_array::arrays::scalar_fn::ExactScalarFn; @@ -19,6 +21,7 @@ use vortex_array::arrays::scalar_fn::ScalarFnArrayView; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::PType; +use vortex_array::dtype::extension::ExtDTypeRef; use vortex_array::dtype::proto::dtype as pb; use vortex_array::scalar_fn::ScalarFnVTable; use vortex_buffer::Buffer; @@ -31,9 +34,11 @@ use vortex_session::VortexSession; use crate::matcher::AnyTensor; use crate::matcher::TensorMatch; +use crate::normalized_vector::NormalizedVector; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::types::vector::VectorMatcherMetadata; use crate::vector::AnyVector; +use crate::vector::Vector; /// Safety factor for unit-norm tolerance. Applied as a constant multiplier on the probabilistic /// `√d · ε` bound so that legitimate round-off noise clears the check with headroom. @@ -128,6 +133,45 @@ fn vector_shapes_match(lhs: &DType, rhs: &DType) -> bool { ) } +/// Returns the underlying `FixedSizeList` storage dtype for a vector-shaped extension dtype. +/// +/// For a plain [`Vector`], this is the direct storage dtype. 
For a [`NormalizedVector`] +/// it drills through one extra extension layer. +pub fn vector_fsl_storage_dtype(ext: &ExtDTypeRef) -> Option { + use vortex_array::dtype::DType; + if ext.is::() { + Some(ext.storage_dtype().clone()) + } else if ext.is::() { + let DType::Extension(inner) = ext.storage_dtype() else { + return None; + }; + if !inner.is::() { + return None; + } + Some(inner.storage_dtype().clone()) + } else { + None + } +} + +/// Returns the underlying `Vector` extension array inside a vector-shaped extension array. +/// +/// For a [`NormalizedVector`] array, this executes the outer extension and returns its +/// `Vector` storage child. For a plain [`Vector`] array, it returns the array itself (after +/// canonicalizing to an `ExtensionArray`). +pub fn inner_vector_array(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let is_normalized = array + .dtype() + .as_extension_opt() + .is_some_and(|ext| ext.is::()); + if is_normalized { + let ext: ExtensionArray = array.clone().execute(ctx)?; + Ok(ext.storage_array().clone()) + } else { + Ok(array.clone()) + } +} + /// Cast a float [`PrimitiveArray`] to a `Buffer`. 
/// /// Several operations in this crate (SORF transform, TurboQuant quantization) work exclusively From d98e8894319374e5c2973dc611845ce781511a5b Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 29 Apr 2026 13:21:15 -0400 Subject: [PATCH 5/6] fix sorf not normalized bug Signed-off-by: Connor Tsui --- .../src/encodings/turboquant/compress.rs | 30 +- vortex-tensor/src/encodings/turboquant/mod.rs | 6 +- .../src/encodings/turboquant/scheme.rs | 7 +- .../src/encodings/turboquant/tests/mod.rs | 9 +- .../encodings/turboquant/tests/structural.rs | 39 +++ .../src/scalar_fns/cosine_similarity.rs | 2 +- vortex-tensor/src/scalar_fns/inner_product.rs | 34 ++- vortex-tensor/src/scalar_fns/l2_denorm.rs | 279 ++++++++---------- .../src/scalar_fns/sorf_transform/mod.rs | 7 +- .../src/scalar_fns/sorf_transform/tests.rs | 17 +- .../src/scalar_fns/sorf_transform/vtable.rs | 93 ++---- .../src/types/normalized_vector/mod.rs | 36 +++ vortex-tensor/src/utils.rs | 8 +- 13 files changed, 302 insertions(+), 265 deletions(-) diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index 610d7964752..7e4f37d59a2 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -80,14 +80,14 @@ impl Default for TurboQuantConfig { /// /// ```text /// ScalarFnArray(L2Denorm, [ -/// ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))]), +/// NormalizedVector(ScalarFnArray(SorfTransform, [NormalizedVector(Vector(FSL(Dict)))])), /// norms, /// ]) /// ``` /// /// # Errors /// -/// Returns an error if `input` is not a tensor-like extension array, if normalization fails, or if +/// Returns an error if `input` is not a vector-family extension array, if normalization fails, or if /// [`turboquant_encode_normalized`] rejects the input shape. 
pub fn turboquant_encode( input: ArrayRef, @@ -97,7 +97,15 @@ pub fn turboquant_encode( // We must normalize the array before we can encode it with TurboQuant. let l2_denorm = normalize_as_l2_denorm(input, ctx)?; - let normalized = l2_denorm.child_at(0).clone(); // Guaranteed to be a `NormalizedVector`.. + let normalized = l2_denorm.child_at(0).clone(); + vortex_ensure!( + normalized + .dtype() + .as_extension_opt() + .is_some_and(|ext| ext.is::()), + "TurboQuant requires a Vector or NormalizedVector input, got normalized child {}", + normalized.dtype(), + ); let norms = l2_denorm.child_at(1).clone(); let num_rows = l2_denorm.len(); @@ -112,9 +120,9 @@ pub fn turboquant_encode( Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array()) } -/// Encode a non-nullable [`NormalizedVector`] extension array into a -/// `ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))])`, without validating the unit-norm -/// precondition. +/// Encode a non-nullable [`NormalizedVector`] extension array into a lossy +/// `NormalizedVector(ScalarFnArray(SorfTransform, [NormalizedVector(Vector(FSL(Dict)))]))`, +/// without validating the decoded unit-norm precondition. pub fn turboquant_encode_normalized( ext: ArrayView, config: &TurboQuantConfig, @@ -173,14 +181,14 @@ pub fn turboquant_encode_normalized( let quantized_fsl = turboquant_quantize_fsl(&fsl, config.bit_width, &sorf_options, ctx)?; - // NB: The quantized rows are approximately unit-norm by construction; downstream callers - // (notably the enclosing `L2Denorm` wrapper) treat the stored-norm + NormalizedVector claim as - // authoritative rather than decode-verified. - // SAFETY: TurboQuant is a lossy approximation of the already-unit-norm input. 
let padded_vector = unsafe { NormalizedVector::new_unchecked(quantized_fsl) }?; - Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array()) + let sorf = SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array(); + // SAFETY: Inverse SORF followed by truncation can lose energy, and quantization is already + // lossy, so this is a semantic assertion made by TurboQuant rather than an exact validation. + // Downstream vector operators treat the compressed unit-vector claim as authoritative. + unsafe { NormalizedVector::wrap_vector_unchecked(sorf) } } /// Rotate and quantize already-normalized vector rows into a dict-encoded `FixedSizeList`. diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs index 93f70e13e9a..71eb873810a 100644 --- a/vortex-tensor/src/encodings/turboquant/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -29,13 +29,15 @@ //! //! ```text //! ScalarFnArray(L2Denorm, [ -//! ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))]), +//! NormalizedVector(ScalarFnArray(SorfTransform, [NormalizedVector(Vector(FSL(Dict)))])), //! norms //! ]) //! ``` //! //! When executed, the tree automatically decompresses: Dict dequantizes codes → SorfTransform -//! inverse-rotates → L2Denorm re-applies norms → original vectors (approximately). +//! inverse-rotates → L2Denorm re-applies norms → original vectors (approximately). The +//! `NormalizedVector` wrappers mark the unit-vector contract that the lossy encoding treats as +//! authoritative. //! //! [`L2Denorm`]: crate::scalar_fns::l2_denorm::L2Denorm //! 
[`SorfTransform`]: crate::scalar_fns::sorf_transform::SorfTransform diff --git a/vortex-tensor/src/encodings/turboquant/scheme.rs b/vortex-tensor/src/encodings/turboquant/scheme.rs index b40bd52cb17..bc8dcbfcf7f 100644 --- a/vortex-tensor/src/encodings/turboquant/scheme.rs +++ b/vortex-tensor/src/encodings/turboquant/scheme.rs @@ -7,10 +7,9 @@ //! //! ```text //! ScalarFnArray(L2Denorm, [ -//! ScalarFnArray( -//! SorfTransform, -//! FSL(Dict(codes, centroids)) -//! ), +//! NormalizedVector(ScalarFnArray(SorfTransform, [ +//! NormalizedVector(Vector(FSL(Dict(codes, centroids)))) +//! ])), //! norms //! ]) //! ``` diff --git a/vortex-tensor/src/encodings/turboquant/tests/mod.rs b/vortex-tensor/src/encodings/turboquant/tests/mod.rs index 4a21c40f0ba..c9e7b8410fc 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/mod.rs @@ -16,6 +16,7 @@ use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::Dict; +use vortex_array::arrays::Extension; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; @@ -71,7 +72,7 @@ fn make_vector_ext(fsl: &FixedSizeListArray) -> ArrayRef { .vortex_expect("test FSL satisfies Vector storage constraints") } -/// Unwrap an L2Denorm ScalarFnArray into (sorf_child, norms_child). +/// Unwrap an L2Denorm ScalarFnArray into (normalized_sorf_child, norms_child). 
fn unwrap_l2denorm(encoded: &ArrayRef) -> (ArrayRef, ArrayRef) { let sfn = encoded .as_opt::() @@ -84,7 +85,11 @@ fn unwrap_codes_centroids_norms( encoded: &ArrayRef, ctx: &mut vortex_array::ExecutionCtx, ) -> VortexResult<(PrimitiveArray, PrimitiveArray, PrimitiveArray)> { - let (sorf_child, norms_child) = unwrap_l2denorm(encoded); + let (normalized_sorf_child, norms_child) = unwrap_l2denorm(encoded); + let normalized_sorf = normalized_sorf_child + .as_opt::() + .expect("expected NormalizedVector wrapping SorfTransform"); + let sorf_child = normalized_sorf.storage_array(); let padded_vector_child = sorf_child .as_opt::() .expect("expected SorfTransform ScalarFnArray") diff --git a/vortex-tensor/src/encodings/turboquant/tests/structural.rs b/vortex-tensor/src/encodings/turboquant/tests/structural.rs index f65137d5d75..4b153596ec4 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/structural.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/structural.rs @@ -4,15 +4,18 @@ //! Tests that verify the internal structure of the encoded tree. use vortex_array::VortexSessionExecute; +use vortex_array::arrays::Extension; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::ScalarFn; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; use vortex_error::VortexResult; use super::*; use crate::encodings::turboquant::centroids::compute_or_get_centroids; +use crate::types::normalized_vector::NormalizedVector; /// Verify that the centroids stored in the DictArray match what `compute_or_get_centroids()` /// computes. @@ -110,6 +113,42 @@ fn encoded_dtype_is_vector_extension() -> VortexResult<()> { Ok(()) } +/// Verify the L2Denorm child keeps the normalized-vector marker even though SorfTransform itself +/// returns a plain Vector. 
+#[test] +fn encoded_l2_denorm_child_is_normalized_sorf_transform() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: 123, + num_rounds: 2, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(ext, &config, &mut ctx)?; + + let (normalized_child, _norms) = unwrap_l2denorm(&encoded); + assert!( + normalized_child + .dtype() + .as_extension() + .is::(), + "L2Denorm child should carry NormalizedVector dtype" + ); + + let normalized_ext = normalized_child + .as_opt::() + .expect("normalized child should be an Extension array"); + assert!( + normalized_ext + .storage_array() + .as_opt::() + .is_some(), + "NormalizedVector storage should be the SorfTransform ScalarFnArray" + ); + Ok(()) +} + /// Verify approximate cosine similarity in the quantized domain. #[test] fn cosine_similarity_quantized_accuracy() -> VortexResult<()> { diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index a9369ad5ca0..dd7fc58e455 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -130,7 +130,7 @@ impl ScalarFnVTable for CosineSimilarity { let mut rhs_ref = args.get(1)?; let len = args.row_count(); - // If either side is a constant tensor-like extension array, eagerly normalize the single + // If either side is a constant vector extension array, eagerly normalize the single // stored row and re-wrap it as an `L2Denorm` whose children are both `ConstantArray`s. // The L2Denorm fast path below then picks it up. if let Some(sfn) = try_build_constant_l2_denorm_from_constant(&lhs_ref, len, ctx)? 
{ diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index ddd3d907c20..d0215b7e25e 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -145,7 +145,7 @@ impl ScalarFnVTable for InnerProduct { // Take the factored fast path only when at least one operand wraps stored norms (the // `Denormalized` form). Routing through this lets us extract the stored norms instead of // canonicalizing the `L2Denorm` ScalarFnArray, which would materialize `unit · norms` - // row-by-row before the dotan avoidable `O(N·D)` pass. + // row-by-row before the dot, an avoidable `O(N·D)` pass. let lhs_form = NormalForm::classify(&lhs_ref); let rhs_form = NormalForm::classify(&rhs_ref); if matches!(lhs_form, NormalForm::Denormalized { .. }) @@ -154,28 +154,29 @@ impl ScalarFnVTable for InnerProduct { return self.execute_factored_dot(&lhs_form, &rhs_form, &lhs_ref, &rhs_ref, len, ctx); } + // Peel any `NormalizedVector` wrapper before checking reduction cases. TurboQuant marks + // the decoded SORF output as normalized, but the optimization patterns still live on the + // inner vector-shaped storage. + let lhs_inner = inner_vector_array(&lhs_ref, ctx)?; + let rhs_inner = inner_vector_array(&rhs_ref, ctx)?; + // Reduction case 1: `InnerProduct(SorfTransform(x), const)` rewrites to // `InnerProduct(x, forward_rotate(zero_pad(const)))`. Re-executes recursively so // case 2 can fire on the rewritten tree. - if let Some(rewritten) = self.try_execute_sorf_constant(&lhs_ref, &rhs_ref, len, ctx)? { + if let Some(rewritten) = self.try_execute_sorf_constant(&lhs_inner, &rhs_inner, len, ctx)? { return Ok(rewritten); } // Reduction case 2: `InnerProduct(Vector[FSL(Dict(u8, f32))], const)` is computed by // gather-summing `q[j] * values[codes[j] as usize]` per row, reading the codebook // directly instead of decoding the column into dense vectors. 
- if let Some(result) = self.try_execute_dict_constant(&lhs_ref, &rhs_ref, len, ctx)? { + if let Some(result) = self.try_execute_dict_constant(&lhs_inner, &rhs_inner, len, ctx)? { return Ok(result); } // Compute combined validity. let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; - // Drill past any `NormalizedVector` wrapper so we always work with the underlying `Vector` - // extension array. - let lhs_inner = inner_vector_array(&lhs_ref, ctx)?; - let rhs_inner = inner_vector_array(&rhs_ref, ctx)?; - // Canonicalize so we can perform the math directly. let lhs: ExtensionArray = lhs_inner.execute(ctx)?; let rhs: ExtensionArray = rhs_inner.execute(ctx)?; @@ -511,10 +512,10 @@ impl InnerProduct { /// /// - `Plain`: `(original, None)`. Implicit `scale = 1`; the operand passes through to the dot /// unchanged. -/// - `Normalized`: `(NV, None)`. Implicit `scale = 1`; the unit-norm child passes through to -/// the dot unchanged. -/// - `Denormalized`: `(NV, Some(stored_norms))`. The dot is computed over the unit-norm child -/// and the caller multiplies row-wise by the materialized stored norms afterward. +/// - `Normalized`: `(inner_vector, None)`. Implicit `scale = 1`; the unit-norm child passes +/// through to the dot unchanged, with the wrapper peeled so SORF/dict reductions can still fire. +/// - `Denormalized`: `(inner_vector, Some(stored_norms))`. The dot is computed over the unit +/// child and the caller multiplies row-wise by the materialized stored norms afterward. 
fn factor_operand( form: &NormalForm<'_>, original: &ArrayRef, @@ -522,10 +523,11 @@ fn factor_operand( ) -> VortexResult<(ArrayRef, Option)> { match form { NormalForm::Plain => Ok((original.clone(), None)), - NormalForm::Normalized { array } => Ok(((*array).clone(), None)), - NormalForm::Denormalized { normalized, norms } => { - Ok((normalized.clone(), Some(norms.clone().execute(ctx)?))) - } + NormalForm::Normalized { array } => Ok((inner_vector_array(array, ctx)?, None)), + NormalForm::Denormalized { normalized, norms } => Ok(( + inner_vector_array(normalized, ctx)?, + Some(norms.clone().execute(ctx)?), + )), } } diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index 76302f94e84..6c1eb022891 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! L2 denormalization expression for tensor-like types. +//! L2 denormalization expression for normalized vectors. use std::fmt::Formatter; @@ -59,8 +59,8 @@ use vortex_error::vortex_ensure_eq; use vortex_error::vortex_err; use vortex_session::VortexSession; -use crate::matcher::AnyTensor; use crate::scalar_fns::l2_norm::L2Norm; +use crate::types::normalized_vector::AnyNormalizedVector; use crate::types::normalized_vector::NormalizedVector; use crate::types::vector::AnyVector; use crate::types::vector::Vector; @@ -73,12 +73,11 @@ use crate::utils::vector_fsl_storage_dtype; /// Re-applies authoritative L2 norms to a normalized vector column. /// -/// Computes `normalized * norm` on each row over the flat backing buffer of the vector-shaped -/// child. +/// Computes `normalized * norm` on each row over the flat backing buffer of the vector child. /// -/// The first child must be vector-shaped and semantically suitable for L2 denormalization. 
Exact -/// callers should use [`try_new_array`](Self::try_new_array), which verifies that plain -/// [`Vector`] inputs are row-wise unit-norm (or zero). Lossy encodings may use +/// The first child must be a [`NormalizedVector`]. Exact callers should use +/// [`try_new_array`](Self::try_new_array), which verifies that stored norms are non-negative and +/// that a zero stored norm is paired with an all-zero normalized row. Lossy encodings may use /// [`new_array_unchecked`](Self::new_array_unchecked) when the decoded child is only an /// approximation but the stored norms are still authoritative. /// @@ -105,12 +104,8 @@ impl L2Denorm { /// Constructs a validated [`ScalarFnArray`] that lazily re-applies `norms` to `normalized`. /// /// In addition to the structural checks performed by [`ScalarFnArray::try_new`], this - /// constructor verifies that plain [`Vector`] children are row-wise unit-norm (or zero), that - /// stored norms are non-negative, and that any row with stored norm `0.0` has an all-zero - /// normalized row. - /// - /// Plain [`Vector`] children are promoted to [`NormalizedVector`] after validation so that - /// downstream execution paths can rely on the unit-norm marker. + /// constructor verifies that the first child is a [`NormalizedVector`], that stored norms are + /// non-negative, and that any row with stored norm `0.0` has an all-zero normalized row. /// /// # Errors /// @@ -125,38 +120,22 @@ impl L2Denorm { ) -> VortexResult { validate_norms_against_normalized(&normalized, &norms, ctx)?; - // Promote plain `Vector` children to `NormalizedVector`. The unit-norm invariant is - // verified by `validate_norms_against_normalized`, so the `wrap_vector_unchecked` wrap is - // safe by construction. - let normalized = if normalized - .dtype() - .as_extension_opt() - .is_some_and(|ext| ext.is::()) - { - normalized - } else { - // SAFETY: row-wise unit-norm (or zero) was just verified for the plain `Vector` input - // above. 
Wrap the `Vector` extension array as a `NormalizedVector` without unpacking - // to FSL storage. - unsafe { NormalizedVector::wrap_vector_unchecked(normalized) }? - }; - // SAFETY: The validation above established the exact L2Denorm invariants. unsafe { Self::new_array_unchecked(normalized, norms, len) } } - /// Constructs an [`L2Denorm`] array without validating the normalized-child invariant. + /// Constructs an [`L2Denorm`] array without validating row values against `norms`. /// - /// Structural validation still runs via [`ScalarFnArray::try_new`]. Use this when the - /// normalized child is a lossy approximation whose rows may not be exactly unit-norm or may not - /// preserve exact zero-ness. + /// Structural validation still runs via [`ScalarFnArray::try_new`], so the first child must be + /// a [`NormalizedVector`]. Use this when the normalized child is a lossy approximation whose + /// rows may not be exactly unit-norm or may not preserve exact zero-ness. /// /// # Safety /// - /// The caller must ensure the first child is semantically suitable for L2 denormalization. - /// For exact wrappers, every valid row must be unit-norm or zero and stored norms must be - /// non-negative. Lossy encodings may deliberately relax the exact row invariant while still - /// treating the stored norms as authoritative. + /// The caller must ensure the first child is semantically suitable for L2 denormalization and + /// is wrapped as a [`NormalizedVector`]. For exact wrappers, every valid row must be unit-norm + /// or zero and stored norms must be non-negative. Lossy encodings may deliberately relax the + /// exact row invariant while still treating the stored norms as authoritative. 
/// /// # Errors /// @@ -207,41 +186,7 @@ impl ScalarFnVTable for L2Denorm { let normalized = &arg_dtypes[0]; let norms = &arg_dtypes[1]; - let ext = normalized.as_extension_opt().ok_or_else(|| { - vortex_err!( - "L2Denorm normalized child must be a Vector or NormalizedVector, got \ - {normalized}", - ) - })?; - let normalized_metadata = ext.metadata_opt::().ok_or_else(|| { - vortex_err!( - "L2Denorm normalized child must be a Vector or NormalizedVector, got \ - {normalized}", - ) - })?; - let element_ptype = normalized_metadata.element_ptype(); - - let DType::Primitive(norms_ptype, _) = norms else { - vortex_bail!("L2Denorm norms must be a primitive float array, got {norms}"); - }; - vortex_ensure_eq!( - *norms_ptype, - element_ptype, - "L2Denorm norms dtype must match normalized element dtype ({element_ptype}), \ - got {norms_ptype}", - ); - - // The denormalized output has the same FSL storage shape as the normalized child but is - // no longer guaranteed unit-norm, so it surfaces as a plain `Vector` extension type. - let fsl_dtype = vector_fsl_storage_dtype(ext).ok_or_else(|| { - vortex_err!( - "L2Denorm normalized child must be a Vector or NormalizedVector, got \ - {normalized}", - ) - })?; - let plain_vector = - DType::Extension(ExtDType::::try_new(EmptyMetadata, fsl_dtype)?.erased()); - Ok(plain_vector.union_nullability(norms.nullability())) + l2_denorm_output_dtype(normalized, norms) } fn execute( @@ -252,19 +197,7 @@ impl ScalarFnVTable for L2Denorm { ) -> VortexResult { let normalized_ref = args.get(0)?; let norms_ref = args.get(1)?; - // Output is a plain `Vector` (not `NormalizedVector`) because denormalized values are no - // longer guaranteed unit-norm. Drill through any `NormalizedVector` wrapper to get the - // underlying FSL. 
- let fsl_dtype = vector_fsl_storage_dtype(normalized_ref.dtype().as_extension()) - .ok_or_else(|| { - vortex_err!( - "L2Denorm normalized child must be a Vector or NormalizedVector, got {}", - normalized_ref.dtype(), - ) - })?; - let output_dtype = - DType::Extension(ExtDType::::try_new(EmptyMetadata, fsl_dtype)?.erased()) - .union_nullability(norms_ref.dtype().nullability()); + let output_dtype = l2_denorm_output_dtype(normalized_ref.dtype(), norms_ref.dtype())?; let validity = normalized_ref.validity()?.and(norms_ref.validity()?)?; if let Some(const_norms) = norms_ref.as_opt::() { @@ -294,12 +227,12 @@ impl ScalarFnVTable for L2Denorm { let norms: PrimitiveArray = norms_ref.execute(ctx)?; let row_count = args.row_count(); - let tensor_match = normalized + let vector_metadata = normalized .dtype() .as_extension() - .metadata_opt::() + .metadata_opt::() .vortex_expect("we already validated this in `return_dtype`"); - let tensor_flat_size = tensor_match.list_size() as usize; + let tensor_flat_size = vector_metadata.dimensions() as usize; let flat = extract_flat_elements(normalized.storage_array(), tensor_flat_size, ctx)?; @@ -344,6 +277,40 @@ impl ScalarFnVTable for L2Denorm { } } +/// Returns the denormalized output dtype for a normalized vector child and matching norms. 
+fn l2_denorm_output_dtype(normalized: &DType, norms: &DType) -> VortexResult { + let normalized_ext = normalized.as_extension_opt().ok_or_else(|| { + vortex_err!("L2Denorm normalized child must be a NormalizedVector, got {normalized}") + })?; + let normalized_metadata = normalized_ext + .metadata_opt::() + .ok_or_else(|| { + vortex_err!("L2Denorm normalized child must be a NormalizedVector, got {normalized}") + })?; + let element_ptype = normalized_metadata.element_ptype(); + + let DType::Primitive(norms_ptype, _) = norms else { + vortex_bail!("L2Denorm norms must be a primitive float array, got {norms}"); + }; + vortex_ensure!( + norms_ptype.is_float(), + "L2Denorm norms must be a primitive float array, got {norms}", + ); + vortex_ensure_eq!( + *norms_ptype, + element_ptype, + "L2Denorm norms dtype must match normalized element dtype ({element_ptype}), \ + got {norms_ptype}", + ); + + let fsl_dtype = vector_fsl_storage_dtype(normalized_ext).ok_or_else(|| { + vortex_err!("L2Denorm normalized child must be a NormalizedVector, got {normalized}") + })?; + let output = DType::Extension(ExtDType::::try_new(EmptyMetadata, fsl_dtype)?.erased()); + + Ok(output.union_nullability(norms.nullability())) +} + /// Metadata for a serialized [`L2Denorm`] array: both children's full [`DType`]s. The parent's /// dtype is `normalized.union_nullability(norms.nullability())`, which loses both children's /// individual nullabilities, so we persist them directly. 
@@ -418,20 +385,20 @@ fn execute_l2_denorm_constant_norms( .vortex_expect("we know that this is a float, so it must fit in f64") - 1.0f64; - let tensor_match = normalized_ref + let normalized_metadata = normalized_ref .dtype() .as_extension_opt() - .and_then(|ext| ext.metadata_opt::()) + .and_then(|ext| ext.metadata_opt::()) .ok_or_else(|| { vortex_err!( - "L2Denorm normalized child must be a Vector or NormalizedVector, got {}", + "L2Denorm normalized child must be a NormalizedVector, got {}", normalized_ref.dtype(), ) })?; let tolerance = unit_norm_tolerance( norm_scalar.dtype().as_ptype(), - tensor_match.list_size() as usize, + normalized_metadata.dimensions() as usize, ); // Drill past any outer `NormalizedVector` wrapper so we always work with the inner plain @@ -477,19 +444,19 @@ fn execute_l2_denorm_constant_norms( Ok(ExtensionArray::new(output_dtype.as_extension().clone(), new_fsl.into_array()).into_array()) } -/// Builds an unexecuted [`L2Denorm`] expression by normalizing `input` and reattaching the exact -/// norms as the `norms` child. +/// Builds an unexecuted [`L2Denorm`] expression by normalizing a vector input and reattaching the +/// exact norms as the `norms` child. /// /// The returned array is a lazy `L2Denorm(normalized, norms)` scalar function array. /// /// # Normalized child /// -/// For plain [`Vector`] (and [`FixedShapeTensor`]) input, every non-null row with a positive L2 -/// norm is divided by its norm to produce a unit-norm vector. The normalized child is forced -/// **non-nullable** with [`Validity::NonNullable`] so optimized kernels over normalized vectors -/// only have to reason about unit-norm vs. zero rows, not nulls. Rows that are null in the -/// original input are **zeroed out** in the normalized output to avoid leaking undefined -/// physical storage values into downstream encodings (like TurboQuant). 
+/// For plain [`Vector`] input, every non-null row with a positive L2 norm is divided by its norm +/// to produce a unit-norm vector, and the normalized child is promoted to [`NormalizedVector`]. +/// The normalized child is forced **non-nullable** with [`Validity::NonNullable`] so optimized +/// kernels only have to reason about unit-norm vs. zero rows, not nulls. Rows that are null in the +/// original input are **zeroed out** in the normalized output to avoid leaking undefined physical +/// storage values into downstream encodings. /// /// For [`NormalizedVector`] input, the function takes a fast path that returns the input /// unchanged as the normalized child and asks [`L2Norm`] for the per-row norms. The fast path @@ -507,24 +474,23 @@ fn execute_l2_denorm_constant_norms( /// Because this helper computes exact `norms` and (on the slow path) divides by them, the /// returned `normalized` child satisfies the unit-norm invariant required by [`L2Denorm`]. /// -/// [`FixedShapeTensor`]: crate::fixed_shape::FixedShapeTensor /// [`NormalizedVector`]: crate::normalized_vector::NormalizedVector pub fn normalize_as_l2_denorm( input: ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult { let row_count = input.len(); - let tensor_metadata = input - .dtype() + let input_dtype = input.dtype().clone(); + let vector_metadata = input_dtype .as_extension_opt() - .and_then(|ext| ext.metadata_opt::()) + .and_then(|ext| ext.metadata_opt::()) .ok_or_else(|| { vortex_err!( - "normalize_as_l2_denorm requires a tensor-like extension input, got {}", - input.dtype(), + "normalize_as_l2_denorm requires a Vector or NormalizedVector extension input, \ + got {input_dtype}", ) })?; - let tensor_flat_size = tensor_metadata.list_size() as usize; + let tensor_flat_size = vector_metadata.dimensions() as usize; // Fast path: input is already a `NormalizedVector`. 
The slow path below would compute exact // norms and divide every row by its norm, but for a `NormalizedVector` the divisor is always @@ -534,7 +500,7 @@ pub fn normalize_as_l2_denorm( // row, because [`L2Norm`]'s `NormalizedVector` short-circuit emits 0.0 exactly when the row // is all zero. // This also has the added benefit of correcting any lossy-encoded `NormalizedVector` arrays. - if tensor_metadata.is_normalized() { + if vector_metadata.is_normalized() { let norms_sfn = L2Norm::try_new_array(input.clone(), row_count)?; let norms_array: ArrayRef = norms_sfn.into_array().execute(ctx)?; @@ -599,12 +565,12 @@ pub fn normalize_as_l2_denorm( unsafe { L2Denorm::new_array_unchecked(normalized, norms_array, row_count) } } -// TODO(connor): This does not handle `NormalizedVector` correctly!!! /// Attempts to build an [`L2Denorm`] whose two children are both [`ConstantArray`]s by eagerly /// normalizing `input`'s single stored row. /// -/// Returns `Ok(None)` when `input` is not a tensor-like extension array whose storage is a -/// [`ConstantArray`] with a non-null fixed-size-list scalar. +/// Returns `Ok(None)` when `input` is not a plain vector extension array whose storage is a +/// [`ConstantArray`] with a non-null fixed-size-list scalar, or when it is already a +/// [`NormalizedVector`]. /// /// When `input` matches, the returned [`ScalarFnArray`] is equivalent to [`normalize_as_l2_denorm`] /// but runs in `O(list_size)` time instead of `O(row_count * list_size)`. @@ -627,9 +593,6 @@ pub(crate) fn try_build_constant_l2_denorm_from_constant( return Ok(None); } - // Only promote vector-family inputs: wrapping FST rows as `NormalizedVector` would be a - // family change, so `FixedShapeTensor` constants fall back to the generic fast path with - // per-row division. 
let Some(vector_metadata) = input .dtype() .as_extension_opt() @@ -637,6 +600,10 @@ pub(crate) fn try_build_constant_l2_denorm_from_constant( else { return Ok(None); }; + if vector_metadata.is_normalized() { + return Ok(None); + } + let list_size = vector_metadata.dimensions() as usize; let original_nullability = input.dtype().nullability(); let storage_fsl_nullability = storage.dtype().nullability(); @@ -674,8 +641,7 @@ pub(crate) fn try_build_constant_l2_denorm_from_constant( let normalized_storage = ConstantArray::new(normalized_fsl_scalar, len).into_array(); // SAFETY: The single stored row is either `v / ||v||` (unit norm within floating-point - // tolerance) or all zeros when `||v|| == 0`. This is the invariant required by - // `NormalizedVector::new_unchecked`. + // tolerance) or all zeros when `||v|| == 0`. let normalized = unsafe { NormalizedVector::new_unchecked(normalized_storage) }?; let norms_array = ConstantArray::new(norms_scalar, len).into_array(); @@ -686,7 +652,7 @@ pub(crate) fn try_build_constant_l2_denorm_from_constant( })) } -/// Rebuilds a tensor-like extension array from flat primitive elements. +/// Rebuilds a vector extension array from flat primitive elements. /// /// # Errors /// @@ -725,41 +691,32 @@ fn build_fsl_storage( ) -> VortexResult { let list_size = u32::try_from(tensor_flat_size).vortex_expect("tensor flat size must fit into `u32`"); - // SAFETY: Validity has no length (because tensor elements are always non-nullable). + // SAFETY: Validity has no length (because vector elements are always non-nullable). let elements = unsafe { PrimitiveArray::new_unchecked(elements, Validity::NonNullable) }; FixedSizeListArray::try_new(elements.into_array(), list_size, validity, row_count) } -// TODO(connor): Need better logic here to check against `NormalizedVector` vs `Vector`. /// Cross-check that `normalized` and `norms` agree on per-row zero-ness, and that stored norms /// are non-negative. 
Unit-norm enforcement on the rows lives on the /// [`NormalizedVector`](crate::normalized_vector::NormalizedVector) dtype itself. -/// -/// We match against [`AnyTensor`] for symmetry with the rest of the tensor pipeline, but -/// downstream construction in [`L2Denorm::return_dtype`] only succeeds for `Vector` and -/// `NormalizedVector` storage (see [`vector_fsl_storage_dtype`]). A `FixedShapeTensor` operand -/// will pass this validator and then be rejected later, which is why the user-visible error -/// message names only the two supported shapes. fn validate_norms_against_normalized( normalized: &ArrayRef, norms: &ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { - let tensor_match = normalized + let vector_metadata = normalized .dtype() .as_extension_opt() - .and_then(|ext| ext.metadata_opt::()) + .and_then(|ext| ext.metadata_opt::()) .ok_or_else(|| { vortex_err!( - "L2Denorm normalized child must be a Vector or NormalizedVector, got {}", + "L2Denorm normalized child must be a NormalizedVector, got {}", normalized.dtype(), ) })?; let row_count = normalized.len(); - let element_ptype = tensor_match.element_ptype(); - let tolerance = unit_norm_tolerance(element_ptype, tensor_match.list_size() as usize); - let tensor_flat_size = tensor_match.list_size() as usize; - let skip_unit_norm_check = tensor_match.is_normalized(); + let element_ptype = vector_metadata.element_ptype(); + let tensor_flat_size = vector_metadata.dimensions() as usize; vortex_ensure_eq!( norms.len(), @@ -783,11 +740,11 @@ fn validate_norms_against_normalized( return Ok(()); } - // Drill past any outer `NormalizedVector` wrapper so we always iterate the FSL of the - // inner plain `Vector`. + // Drill past the outer `NormalizedVector` wrapper so we always iterate the FSL of the inner + // plain `Vector`. 
let vector_ref = inner_vector_array(normalized, ctx)?; let vector_ext: ExtensionArray = vector_ref.execute(ctx)?; - let normalized_validity = vector_ext.as_ref().validity()?; + let normalized_validity = normalized.validity()?; let flat = extract_flat_elements(vector_ext.storage_array(), tensor_flat_size, ctx)?; let norms_prim: PrimitiveArray = norms.clone().execute(ctx)?; @@ -807,22 +764,10 @@ fn validate_norms_against_normalized( "L2Denorm norms must be non-negative, but row {i} has {stored_norm_f64:.6}", ); - let (row_norm_sq, is_zero_row) = - flat.row::(i) - .iter() - .fold((0.0f64, true), |(sum_sq, all_zero), x| { - let value = ToPrimitive::to_f64(x).unwrap_or(f64::NAN); - (sum_sq + value * value, all_zero && value.abs() <= tolerance) - }); - - if !skip_unit_norm_check { - let row_norm = row_norm_sq.sqrt(); - vortex_ensure!( - row_norm == 0.0 || (row_norm - 1.0).abs() <= tolerance, - "L2Denorm normalized child row {i} has L2 norm {row_norm:.6}, \ - expected 1.0 or 0.0", - ); - } + let is_zero_row = flat.row::(i).iter().all(|x| { + let value = ToPrimitive::to_f64(x).unwrap_or(f64::NAN); + value == 0.0 + }); if stored_norm_f64 == 0.0 { vortex_ensure!( @@ -853,9 +798,11 @@ pub(crate) enum NormalForm<'a> { /// An already-normalized `NormalizedVector`, which has implicit norms of `1.0`. Normalized { array: &'a ArrayRef }, - /// Decomposed `L2Denorm(normalized: NormalizedVector, norms)`. + /// Decomposed `L2Denorm(normalized, norms)`. /// - /// Note that `normalized` is _always_ non-null, and the validity is stored in `norms`. + /// The normalized child is a `NormalizedVector` by structural contract. It is usually + /// non-null, with validity stored in `norms`, except when callers use + /// [`L2Denorm::new_array_unchecked`] directly. 
Denormalized { normalized: ArrayRef, norms: ArrayRef, @@ -928,6 +875,7 @@ mod tests { use crate::types::vector::Vector; use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::normalized_vector_array; + use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; /// Evaluates L2 denorm on a [`Vector`] (rewrapped as a [`NormalizedVector`]) and the matching @@ -1029,13 +977,13 @@ mod tests { } #[test] - fn l2_denorm_accepts_plain_unit_vector_lhs() -> VortexResult<()> { + fn l2_denorm_rejects_plain_unit_vector_lhs() -> VortexResult<()> { let lhs = vector_array(2, &[1.0, 0.0, 0.0, 1.0])?; let rhs = PrimitiveArray::from_iter([1.0f64, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); - assert!(result.is_ok()); + assert!(result.is_err()); Ok(()) } @@ -1138,6 +1086,18 @@ mod tests { Ok(()) } + #[test] + fn l2_denorm_new_array_unchecked_rejects_plain_vector_lhs() -> VortexResult<()> { + let vector = vector_array(2, &[1.0, 0.0, 0.0, 1.0])?; + let norms = PrimitiveArray::from_iter([1.0f64, 1.0]).into_array(); + + // SAFETY: This deliberately checks that structural validation still rejects a plain + // `Vector` child. 
+ let result = unsafe { L2Denorm::new_array_unchecked(vector, norms, 2) }; + assert!(result.is_err()); + Ok(()) + } + #[test] fn normalize_as_l2_denorm_roundtrips_vectors() -> VortexResult<()> { let input = vector_array(3, &[3.0, 4.0, 0.0, 0.0, 0.0, 0.0])?; @@ -1149,6 +1109,15 @@ mod tests { Ok(()) } + #[test] + fn normalize_as_l2_denorm_rejects_fixed_shape_tensor() -> VortexResult<()> { + let input = tensor_array(&[2, 2], &[3.0, 4.0, 0.0, 0.0])?; + let mut ctx = SESSION.create_execution_ctx(); + + assert!(normalize_as_l2_denorm(input, &mut ctx).is_err()); + Ok(()) + } + #[test] fn normalize_as_l2_denorm_supports_constant_vectors() -> VortexResult<()> { let input = Vector::constant_array(&[3.0, 4.0], 2)?; diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs b/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs index 3d042bff604..7666e58cdac 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs @@ -34,9 +34,10 @@ //! //! The output [`Vector`]'s element type is whatever [`SorfOptions::element_ptype`] is set to. It //! does **not** have to match the child's `f32` storage: we apply an explicit `f32 -> T` cast -//! while materializing the output. This lets SorfTransform hand its result directly to a -//! downstream consumer (e.g. [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm)) whose -//! element-type expectation may differ from the `f32` the transform operated on internally. +//! while materializing the output. Callers that intentionally treat the decoded output as +//! normalized (for example TurboQuant) must wrap the result as a +//! [`NormalizedVector`](crate::normalized_vector::NormalizedVector) before handing it to consumers +//! such as [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm). //! //! [sorf-paper]: https://proceedings.neurips.cc/paper_files/paper/2016/file/53adaf494dc89ef7196d73636eb2451b-Paper.pdf //! 
[`Vector`]: crate::vector::Vector diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs index 6f906d57d69..d65bf5f15c0 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs @@ -365,7 +365,7 @@ fn rejects_non_vector_extension_child_at_construction() { } #[test] -fn accepts_normalized_vector_child_and_mirrors_kind() -> VortexResult<()> { +fn accepts_normalized_vector_child_but_returns_plain_vector() -> VortexResult<()> { let options = default_options(128, 42); let mut values = vec![0.0f32; 128]; values[0] = 1.0; @@ -374,21 +374,20 @@ fn accepts_normalized_vector_child_and_mirrors_kind() -> VortexResult<()> { let mut ctx = SESSION.create_execution_ctx(); let child = NormalizedVector::try_new(fsl.into_array(), &mut ctx)?; - // The output mirrors the child's wrapper kind: a `NormalizedVector` child produces a - // `NormalizedVector` parent. The orthogonal inverse rotation preserves L2 norm and the - // truncated coordinates were near-zero pre-rotation, so the output is approximately - // unit-norm (lossy contract documented on `NormalizedVector::new_unchecked`). + // A `NormalizedVector` child is accepted, but the output is a plain `Vector`: inverse SORF is + // followed by truncation, which cannot generally preserve the unit-norm invariant. let sorf = SorfTransform::try_new_array(&options, child, 1)?.into_array(); - assert!(sorf.dtype().as_extension().is::()); + assert!(sorf.dtype().as_extension().is::()); + assert!(!sorf.dtype().as_extension().is::()); let result: ExtensionArray = sorf.execute(&mut ctx)?; - assert!(result.dtype().as_extension().is::()); + assert!(result.dtype().as_extension().is::()); + assert!(!result.dtype().as_extension().is::()); Ok(()) } -/// A plain [`Vector`] child should still produce a plain [`Vector`] parent. 
#[test] -fn accepts_plain_vector_child_and_mirrors_kind() -> VortexResult<()> { +fn accepts_plain_vector_child_and_returns_plain_vector() -> VortexResult<()> { let options = default_options(128, 42); let elements = PrimitiveArray::from_iter([0.0f32; 128]).into_array(); let fsl = FixedSizeListArray::try_new(elements, 128, Validity::NonNullable, 1)?; diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs index 72592543125..9551a942c0a 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs @@ -16,8 +16,10 @@ use vortex_array::IntoArray; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::ScalarFn as ScalarFnArrayEncoding; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::arrays::scalar_fn::ScalarFnArrayView; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; @@ -26,6 +28,7 @@ use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; use vortex_array::dtype::extension::ExtDType; +use vortex_array::dtype::proto::dtype as pb; use vortex_array::expr::Expression; use vortex_array::extension::EmptyMetadata; use vortex_array::match_each_float_ptype; @@ -47,7 +50,6 @@ use super::SorfOptions; use super::SorfTransform; use super::rotation::SorfMatrix; use super::validate_sorf_options; -use crate::types::normalized_vector::NormalizedVector; use crate::types::vector::AnyVector; use crate::types::vector::Vector; use crate::utils::inner_vector_array; @@ -121,17 +123,12 @@ impl ScalarFnVTable for SorfTransform { child_dtype.nullability(), ); - // The output mirrors the 
child's wrapper kind, so if the child was a `NormalizedVector`, - // the output is also a `NormalizedVector`. - let inner = if vector_metadata.is_normalized() { - let inner_vector = ExtDType::::try_new(EmptyMetadata, fsl_dtype)?.erased(); - ExtDType::::try_new(EmptyMetadata, DType::Extension(inner_vector))? - .erased() - } else { - ExtDType::::try_new(EmptyMetadata, fsl_dtype)?.erased() - }; + // The inverse SORF rotation is orthogonal over the padded dimension, but this scalar + // function then truncates back to the original dimension. Truncation can drop energy, so + // even a `NormalizedVector` child cannot generally produce a `NormalizedVector` parent. + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl_dtype)?.erased(); - Ok(DType::Extension(inner)) + Ok(DType::Extension(ext_dtype)) } fn execute( @@ -144,10 +141,6 @@ impl ScalarFnVTable for SorfTransform { let num_rows = args.row_count(); let child_arg = args.get(0)?; - let child_is_normalized = child_arg - .dtype() - .as_extension_opt() - .is_some_and(|ext| ext.is::()); let fsl_array = if num_rows == 0 { let validity = Validity::from(child_arg.dtype().nullability()); @@ -186,9 +179,7 @@ impl ScalarFnVTable for SorfTransform { })? }; - // SAFETY: We used the matcher to check if the child was normalized, so this must be - // correct. - unsafe { wrap_vector_storage(fsl_array, child_is_normalized) } + Vector::try_new_vector_array(fsl_array) } fn validity( @@ -210,10 +201,9 @@ impl ScalarFnVTable for SorfTransform { /// Metadata for a serialized [`SorfTransform`] array. /// -/// Stores the full [`SorfOptions`] inline. The child dtype is fully derivable from the parent -/// dtype: the parent's outer wrapper (plain `Vector` or `NormalizedVector`) mirrors the child's -/// wrapper kind, the inner FSL nullability is propagated through `return_dtype`, and -/// `padded_dim`/`f32` are determined by [`SorfOptions`]. +/// Stores the full [`SorfOptions`] inline along with the child [`DType`]. 
Older metadata omitted +/// this field; deserialization derives the legacy plain-`Vector` child dtype from the parent dtype +/// in that case. #[derive(Clone, prost::Message)] pub(super) struct SorfTransformMetadata { #[prost(uint64, tag = "1")] @@ -225,6 +215,8 @@ pub(super) struct SorfTransformMetadata { dimension: u32, #[prost(enumeration = "PType", tag = "4")] element_ptype: i32, + #[prost(message, optional, tag = "5")] + child_dtype: Option, } impl ScalarFnArrayVTable for SorfTransform { @@ -233,7 +225,12 @@ impl ScalarFnArrayVTable for SorfTransform { view: &ScalarFnArrayView, _session: &VortexSession, ) -> VortexResult>> { - let metadata = SorfTransformMetadata::from(view.options); + let scalar_fn_array = view.as_::(); + let child_dtype = Some(scalar_fn_array.child_at(0).dtype().try_into()?); + let metadata = SorfTransformMetadata { + child_dtype, + ..SorfTransformMetadata::from(view.options) + }; Ok(Some(metadata.encode_to_vec())) } @@ -251,12 +248,9 @@ impl ScalarFnArrayVTable for SorfTransform { let parent_ext = dtype .as_extension_opt() - .filter(|ext| ext.is::()) + .filter(|ext| ext.is::()) .ok_or_else(|| { - vortex_err!( - "SorfTransform parent dtype must be a `Vector` or `NormalizedVector` \ - extension, got {dtype}", - ) + vortex_err!("SorfTransform parent dtype must be a `Vector` extension, got {dtype}",) })?; // The nullability of the parent extension type is the same as the storage type. @@ -268,18 +262,13 @@ impl ScalarFnArrayVTable for SorfTransform { padded_dim, fsl_nullability, ); - let inner_vector_dtype = - ExtDType::::try_new(EmptyMetadata, child_fsl_dtype)?.erased(); - - let child_dtype = if parent_ext.is::() { - let nv = ExtDType::::try_new( - EmptyMetadata, - DType::Extension(inner_vector_dtype), - )? 
- .erased(); - DType::Extension(nv) - } else { - DType::Extension(inner_vector_dtype) + let child_dtype = match metadata.child_dtype.as_ref() { + Some(dtype) => DType::from_proto(dtype, _session)?, + None => { + let child_ext = + ExtDType::::try_new(EmptyMetadata, child_fsl_dtype)?.erased(); + DType::Extension(child_ext) + } }; let child = children.get(0, &child_dtype, len)?; @@ -299,9 +288,8 @@ fn float_from_f32(v: f32) -> T { } /// Apply the inverse SORF transform on f32 data, truncate to the original dimension, cast each -/// element to `T`, and return the resulting `FixedSizeList` storage array. The caller is -/// responsible for wrapping the FSL in the appropriate vector-family extension via -/// [`wrap_output`]. +/// element to `T`, and return the resulting `FixedSizeList` storage array. The caller wraps the +/// FSL as a plain [`Vector`](crate::vector::Vector) extension array. fn inverse_rotate_typed( f32_elements: &[f32], rotation: &SorfMatrix, @@ -330,24 +318,6 @@ fn inverse_rotate_typed( Ok(fsl.into_array()) } -/// Wraps `fsl` as either a [`Vector`] or [`NormalizedVector`] extension array, mirroring the kind -/// of the upstream `SorfTransform` child. -/// -/// # Safety -/// -/// When `is_normalized` is `true`, every valid row of `fsl` must be approximately unit-norm or -/// zero in the lossy sense documented by [`NormalizedVector::new_unchecked`]. -/// -/// When `is_normalized` is `false` the function takes the safe `Vector` branch. -unsafe fn wrap_vector_storage(fsl: ArrayRef, is_normalized: bool) -> VortexResult { - if is_normalized { - // SAFETY: Forwarded from the function-level safety contract above. 
- unsafe { NormalizedVector::new_unchecked(fsl) } - } else { - Vector::try_new_vector_array(fsl) - } -} - impl From<&SorfOptions> for SorfTransformMetadata { fn from(options: &SorfOptions) -> Self { Self { @@ -355,6 +325,7 @@ impl From<&SorfOptions> for SorfTransformMetadata { num_rounds: u32::from(options.num_rounds), dimension: options.dimensions, element_ptype: options.element_ptype as i32, + child_dtype: None, } } } diff --git a/vortex-tensor/src/types/normalized_vector/mod.rs b/vortex-tensor/src/types/normalized_vector/mod.rs index aa44f670743..7242bab9868 100644 --- a/vortex-tensor/src/types/normalized_vector/mod.rs +++ b/vortex-tensor/src/types/normalized_vector/mod.rs @@ -142,6 +142,42 @@ pub(crate) fn validate_unit_norm_rows( Ok(()) } +#[cfg(test)] +mod tests { + use half::f16; + use vortex_array::IntoArray; + use vortex_array::VortexSessionExecute; + use vortex_array::arrays::FixedSizeListArray; + use vortex_array::arrays::PrimitiveArray; + use vortex_array::dtype::PType; + use vortex_array::validity::Validity; + use vortex_error::VortexResult; + + use super::NormalizedVector; + use crate::tests::SESSION; + use crate::utils::unit_norm_tolerance; + + #[test] + fn f16_unit_norm_tolerance_is_capped() { + assert!(unit_norm_tolerance(PType::F16, 768) <= 1e-3); + } + + #[test] + fn try_new_rejects_f16_row_outside_capped_tolerance() -> VortexResult<()> { + let dim = 768u32; + let dim_usize = usize::try_from(dim).expect("dim fits usize"); + let mut values = vec![f16::from_f32(0.0); dim_usize]; + values[0] = f16::from_f32(0.99); + + let elements = PrimitiveArray::from_iter(values).into_array(); + let fsl = FixedSizeListArray::try_new(elements, dim, Validity::NonNullable, 1)?; + let mut ctx = SESSION.create_execution_ctx(); + + assert!(NormalizedVector::try_new(fsl.into_array(), &mut ctx).is_err()); + Ok(()) + } +} + mod matcher; mod vtable; diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index 2358da592a8..bab89ce2b6c 100644 --- 
a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -44,6 +44,12 @@ use crate::vector::Vector; /// `√d · ε` bound so that legitimate round-off noise clears the check with headroom. pub(crate) const SAFETY_FACTOR: usize = 10; +/// Upper bound for unit-norm validation drift. +/// +/// This keeps low-precision element types (especially f16) from accepting vectors whose norms are +/// materially different from 1.0 at common embedding dimensions. +pub(crate) const MAX_UNIT_NORM_TOLERANCE: f64 = 1e-3; + /// Returns the acceptable unit-norm drift for the given element precision and dimension count. /// /// Uses the `c · √d · ε` bound where ε is machine epsilon and d is the vector dimension. Under @@ -64,7 +70,7 @@ pub fn unit_norm_tolerance(element_ptype: PType, dimensions: usize) -> f64 { let dimensions_root = (dimensions as f64).sqrt(); - SAFETY_FACTOR as f64 * machine_epsilon * dimensions_root + (SAFETY_FACTOR as f64 * machine_epsilon * dimensions_root).min(MAX_UNIT_NORM_TOLERANCE) } /// Extracts the `(normalized, norms)` children from an [`L2Denorm`] scalar function array. From 6e0e6aa7274a770bdf24f6fae2bb7fd918169dd3 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 29 Apr 2026 14:03:01 -0400 Subject: [PATCH 6/6] fix validity and type bugs Signed-off-by: Connor Tsui --- .../src/encodings/turboquant/scheme.rs | 97 ++++++++++++++++++- vortex-tensor/src/scalar_fns/l2_denorm.rs | 22 +++-- 2 files changed, 107 insertions(+), 12 deletions(-) diff --git a/vortex-tensor/src/encodings/turboquant/scheme.rs b/vortex-tensor/src/encodings/turboquant/scheme.rs index bc8dcbfcf7f..af782c4abb2 100644 --- a/vortex-tensor/src/encodings/turboquant/scheme.rs +++ b/vortex-tensor/src/encodings/turboquant/scheme.rs @@ -3,7 +3,8 @@ //! TurboQuant compression scheme. //! -//! The scheme is a thin [`Scheme`] adapter over [`turboquant_encode`], which produces: +//! Plain [`Vector`](crate::vector::Vector) inputs are normalized and encoded via +//! 
[`turboquant_encode`], which produces: //! //! ```text //! ScalarFnArray(L2Denorm, [ @@ -14,13 +15,19 @@ //! ]) //! ``` //! +//! Non-nullable [`NormalizedVector`](crate::normalized_vector::NormalizedVector) inputs skip the +//! outer [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm) wrapper and are encoded directly via +//! [`turboquant_encode_normalized`]. +//! //! Decompression is automatic: executing the outer array walks the ScalarFn tree. //! //! [`turboquant_encode`]: crate::encodings::turboquant::turboquant_encode +//! [`turboquant_encode_normalized`]: crate::encodings::turboquant::turboquant_encode_normalized use vortex_array::ArrayRef; use vortex_array::Canonical; use vortex_array::ExecutionCtx; +use vortex_array::arrays::Extension; use vortex_array::dtype::DType; use vortex_compressor::CascadingCompressor; use vortex_compressor::ctx::CompressorContext; @@ -37,6 +44,7 @@ use crate::encodings::turboquant::MAX_CENTROIDS; use crate::encodings::turboquant::MIN_DIMENSION; use crate::encodings::turboquant::TurboQuantConfig; use crate::encodings::turboquant::turboquant_encode; +use crate::encodings::turboquant::turboquant_encode_normalized; use crate::vector::AnyVector; use crate::vector::VectorMatcherMetadata; @@ -105,7 +113,29 @@ impl Scheme for TurboQuantScheme { _compress_ctx: CompressorContext, exec_ctx: &mut ExecutionCtx, ) -> VortexResult { - turboquant_encode(data.array().clone(), &TurboQuantConfig::default(), exec_ctx) + // TODO(connor): If we ever add scheme vtables with metadata, we would need to pass in the + // config as a parameter here. 
+ let config = TurboQuantConfig::default(); + turboquant_encode_for_scheme(data.array().clone(), &config, exec_ctx) + } +} + +fn turboquant_encode_for_scheme( + input: ArrayRef, + config: &TurboQuantConfig, + exec_ctx: &mut ExecutionCtx, +) -> VortexResult { + let vector_metadata = tq_validate_vector_dtype(input.dtype())?; + if vector_metadata.is_normalized() { + let ext = input.as_opt::().ok_or_else(|| { + vortex_err!( + "TurboQuant normalized input must be an Extension array, got {}", + input.encoding_id() + ) + })?; + turboquant_encode_normalized(ext, config, exec_ctx) + } else { + turboquant_encode(input, config, exec_ctx) } } @@ -137,8 +167,9 @@ fn estimate_compression_ratio(element_bit_width: u8, dimensions: u32, num_vector uncompressed_size_bits as f64 / compressed_size_bits as f64 } -/// Validates that `dtype` is a [`Vector`](crate::vector::Vector) extension type with -/// dimension >= [`MIN_DIMENSION`]. +/// Validates that `dtype` is a plain [`Vector`](crate::vector::Vector) or non-nullable +/// [`NormalizedVector`](crate::normalized_vector::NormalizedVector) extension type with dimension +/// >= [`MIN_DIMENSION`]. /// /// Returns the validated vector metadata on success. 
pub fn tq_validate_vector_dtype(dtype: &DType) -> VortexResult { @@ -154,6 +185,11 @@ pub fn tq_validate_vector_dtype(dtype: &DType) -> VortexResult= MIN_DIMENSION, "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}", ); + vortex_ensure!( + !vector_metadata.is_normalized() || !dtype.is_nullable(), + "TurboQuant cannot encode nullable NormalizedVector inputs because normalized encode has \ + no norms child to carry validity", + ); Ok(vector_metadata) } @@ -161,8 +197,19 @@ pub fn tq_validate_vector_dtype(dtype: &DType) -> VortexResult VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let mut values = vec![0.0f32; 2 * 128]; + values[0] = 1.0; + values[128 + 1] = 1.0; + let input = normalized_vector_array(128, &values, &mut ctx)?; + + let encoded = turboquant_encode_for_scheme(input, &TurboQuantConfig::default(), &mut ctx)?; + + assert!(encoded.dtype().as_extension().is::()); + assert!( + encoded.as_opt::().is_none(), + "NormalizedVector scheme path should not add an outer L2Denorm ScalarFnArray", + ); + Ok(()) + } + + #[test] + fn validate_rejects_nullable_normalized_vector() -> VortexResult<()> { + let dim = 128u32; + let mut values = BufferMut::::with_capacity(2 * dim as usize); + for row in 0..2 { + for col in 0..dim { + values.push(if col == row { 1.0 } else { 0.0 }); + } + } + let elements = PrimitiveArray::new::(values.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim, + Validity::from_iter([true, false]), + 2, + )?; + let mut ctx = SESSION.create_execution_ctx(); + let normalized = NormalizedVector::try_new(fsl.into_array(), &mut ctx)?; + + assert_eq!(normalized.dtype().nullability(), Nullability::Nullable); + assert!(tq_validate_vector_dtype(normalized.dtype()).is_err()); + Ok(()) + } + /// Power-of-2 dimensions should have better ratios than their non-power-of-2 /// predecessors due to no padding waste. 
#[test] diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index 6c1eb022891..c3841cbb600 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -406,12 +406,21 @@ fn execute_l2_denorm_constant_norms( let vector_ref = inner_vector_array(&normalized_ref, ctx)?; if err.abs() < tolerance { - // The output dtype is the sibling plain `Vector`. Rewrap the vector storage so the - // executed array's dtype matches `output_dtype`. + // The output dtype is the sibling plain `Vector`. Rebuild the FSL wrapper with the + // combined validity so the executed array's storage nullability matches `output_dtype`. let normalized: ExtensionArray = vector_ref.execute(ctx)?; + + let storage_fsl: FixedSizeListArray = normalized.storage_array().clone().execute(ctx)?; + let new_fsl = FixedSizeListArray::try_new( + storage_fsl.elements().clone(), + storage_fsl.list_size(), + new_validity, + storage_fsl.len(), + )?; + return Ok(ExtensionArray::try_new( output_dtype.as_extension().clone(), - normalized.storage_array().clone(), + new_fsl.into_array(), )? .into_array()); } @@ -1280,13 +1289,10 @@ mod tests { } /// Regression: a non-nullable [`NormalizedVector`] child paired with a nullable-dtype - /// constant norms array (whose value happens to be non-null `1.0`) used to panic in the + /// constant norms array (whose value happens to be non-null `1.0`) used to fail in the /// constant-unit fast path because the extension's declared storage nullability no longer - /// matched the storage array's own nullability. The fix is on the [`ExtensionArray`] side, - /// where storage-dtype matching will ignore outer nullability. That relaxation is not yet on - /// this branch, so the test is ignored until the `ExtensionArray::try_new` change lands. + /// matched the storage array's own nullability. 
#[test] - #[ignore = "depends on ExtensionArray::try_new ignoring outer storage nullability"] fn l2_denorm_constant_unit_norms_nullable_scalar_nonnullable_normalized() -> VortexResult<()> { let mut ctx = SESSION.create_execution_ctx(); let normalized = normalized_vector_array(3, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &mut ctx)?;