From 04cfd4d270736256294fb7c4c156f50039376d6e Mon Sep 17 00:00:00 2001 From: Anoop Narang Date: Wed, 18 Mar 2026 12:26:16 +0530 Subject: [PATCH] fix(udf): remove sqrt from l2_distance to match USearch L2sq metric l2_distance UDF was computing actual L2 (with sqrt) while USearch and the rewritten execution paths all use L2sq (no sqrt). This caused the same query to return different numeric distance values depending on whether the optimizer rewrote it. Remove sqrt to match USearch's MetricKind::L2sq and DuckDB VSS's array_distance behavior. All paths now return consistent L2sq values. --- README.md | 4 ++-- src/lib.rs | 2 +- src/udf.rs | 3 ++- tests/execution.rs | 59 +++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 63 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index e7bb74f..8c086aa 100644 --- a/README.md +++ b/README.md @@ -272,11 +272,11 @@ All three distance functions are **lower-is-closer**: | SQL function | Index metric | Kernel | |---|---|---| -| `l2_distance(a, b)` | `L2sq` | `sqrt(sum((a_i - b_i)^2))` (UDF) / `sum((a_i - b_i)^2)` (index) | +| `l2_distance(a, b)` | `L2sq` | `sum((a_i - b_i)^2)` | | `cosine_distance(a, b)` | `Cos` | `1 - dot(a,b) / (norm(a) * norm(b))` | | `negative_dot_product(a, b)` | `IP` | `-(a . b)` | -Note: `l2_distance` UDF returns actual L2 (with sqrt) for human-readable distances; USearch uses L2sq internally (no sqrt). The sort order is identical. +`l2_distance` returns squared L2 (no sqrt), matching USearch's `MetricKind::L2sq`. This ensures numeric consistency between the UDF, the rewritten index path, and the brute-force path. ### Running tests diff --git a/src/lib.rs b/src/lib.rs index 538752d..5aee261 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,7 +93,7 @@ use datafusion::prelude::SessionContext; /// Register all extension components with a DataFusion [`SessionContext`]. /// /// Registers: -/// - `l2_distance(col, query)` — Euclidean distance (L2) +/// - `l2_distance(col, query)` — squared Euclidean distance (L2sq) /// - `cosine_distance(col, query)` — cosine distance /// - `negative_dot_product(col, query)` — negated inner product /// - `vector_usearch(table, query, k)` — explicit ANN table function diff --git a/src/udf.rs b/src/udf.rs index 226c89c..233fcc1 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -22,12 +22,13 @@ use datafusion::scalar::ScalarValue; type Kernel = fn(&[f32], &[f32]) -> f32; +// Returns L2sq (no sqrt) — matches USearch MetricKind::L2sq and keeps numeric +// values consistent between the UDF path and the optimizer-rewritten index path. fn l2_kernel(a: &[f32], b: &[f32]) -> f32 { a.iter() .zip(b.iter()) .map(|(x, y)| (x - y) * (x - y)) .sum::() - .sqrt() } fn cosine_kernel(a: &[f32], b: &[f32]) -> f32 { diff --git a/tests/execution.rs b/tests/execution.rs index 8d6c5e8..542b5fe 100644 --- a/tests/execution.rs +++ b/tests/execution.rs @@ -19,7 +19,7 @@ use std::sync::Arc; use arrow_array::builder::{FixedSizeListBuilder, Float32Builder}; -use arrow_array::{FixedSizeListArray, RecordBatch, StringArray, UInt64Array}; +use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, StringArray, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::SessionContext; @@ -453,3 +453,60 @@ async fn exec_parquet_native_where_no_matches() { let ids = collect_ids(&ctx, &sql).await; assert!(ids.is_empty(), "no rows should match; got {ids:?}"); } + +// ═══════════════════════════════════════════════════════════════════════════════ +// Numeric regression — l2_distance must return L2sq (no sqrt) +// ═══════════════════════════════════════════════════════════════════════════════ + +/// l2_distance must return squared L2, not actual L2. +/// Row 1 = [1,0,0,0], query = [1,0,0,0] → L2sq = 0.0 +/// Row 2 = [0,1,0,0], query = [1,0,0,0] → L2sq = 2.0 (L2 would be ~1.414) +#[tokio::test] +async fn exec_l2_distance_returns_l2sq() { + let ctx = make_exec_ctx("items::vector").await; + let sql = + format!("SELECT id, l2_distance(vector, {Q}) AS dist FROM items ORDER BY dist ASC LIMIT 4"); + let df = ctx.sql(&sql).await.expect("sql"); + let batches = df.collect().await.expect("collect"); + + let mut dists: Vec<(u64, f32)> = vec![]; + for batch in &batches { + let id_idx = batch.schema().index_of("id").unwrap(); + let dist_idx = batch.schema().index_of("dist").unwrap(); + let ids = batch + .column(id_idx) + .as_any() + .downcast_ref::() + .unwrap(); + let ds = batch + .column(dist_idx) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + dists.push((ids.value(i), ds.value(i))); + } + } + + // Row 1: exact match → 0.0 + let row1 = dists + .iter() + .find(|(id, _)| *id == 1) + .expect("row 1 missing"); + assert!( + (row1.1 - 0.0).abs() < 1e-6, + "row 1 distance must be 0.0 (L2sq); got {}", + row1.1 + ); + + // Row 2: [0,1,0,0] vs [1,0,0,0] → L2sq = 2.0, NOT sqrt(2) ≈ 1.414 + let row2 = dists + .iter() + .find(|(id, _)| *id == 2) + .expect("row 2 missing"); + assert!( + (row2.1 - 2.0).abs() < 1e-6, + "row 2 distance must be 2.0 (L2sq), not {:.4} (would be ~1.414 if L2)", + row2.1 + ); +}