From b40666d9a96bee13c1834c9efe0450a19c2553e8 Mon Sep 17 00:00:00 2001 From: Anoop Narang Date: Wed, 8 Apr 2026 17:57:50 +0530 Subject: [PATCH 1/3] fix(rule): preserve hidden distance for sort Keep _distance in an inner projection when ORDER BY uses a vector distance expression that is not part of the final select list. This fixes split-provider execution for queries like SELECT id ORDER BY l2_distance(vector, ARRAY[...]) LIMIT k while preserving the final output schema. Add an execution test for the direct ORDER BY shape to cover the production case. --- src/rule.rs | 88 ++++++++++++++++++++++++++++++++++++---------- tests/execution.rs | 11 ++++++ 2 files changed, 81 insertions(+), 18 deletions(-) diff --git a/src/rule.rs b/src/rule.rs index e0b8a7c..e8eca08 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -18,9 +18,10 @@ // // Replacement: // -// Sort(fetch=k) ← kept (sort order) -// Projection([col(a), col(b), col("_distance").alias("dist")]) -// USearchNode ← executes ANN +// Projection([final output cols]) +// Sort(fetch=k) +// Projection([final output cols + optional hidden _distance]) +// USearchNode use std::collections::HashMap; use std::sync::Arc; @@ -151,25 +152,33 @@ impl USearchRule { node: Arc::new(node) as Arc, })); - // Build Projection over USearchNode matching the original output schema. + // Build the final user-visible projection over USearchNode output. let dist_alias_str = dist_alias.as_deref().unwrap_or("_distance"); - let new_proj_exprs = if proj_exprs_slice.is_empty() { + let final_proj_exprs = if proj_exprs_slice.is_empty() { passthrough_projection(&vsn_df_schema, &table_ref) } else { remap_projections(proj_exprs_slice, dist_alias_str, &table_ref) }; - let new_proj = Projection::try_new(new_proj_exprs, node_plan).ok()?; - - // Keep the Sort node so DataFusion handles ordering by _distance / dist. - // USearch returns results in arbitrary (internal) order when the underlying - // data is fetched from the TableProvider. 
- Some(LogicalPlan::Sort( - datafusion::logical_expr::logical_plan::Sort { - expr: sort.expr.clone(), - input: Arc::new(LogicalPlan::Projection(new_proj)), - fetch: sort.fetch, - }, - )) + let remapped_sort_exprs = remap_sort_exprs(&sort.expr, dist_alias.as_deref()); + let needs_hidden_distance = remapped_sort_exprs.iter().any(|e| { + matches!(&e.expr, Expr::Column(c) if c.relation.is_none() && c.name == "_distance") + }) && !projection_exposes_name(&final_proj_exprs, "_distance"); + + let mut sort_input_exprs = final_proj_exprs.clone(); + if needs_hidden_distance { + sort_input_exprs.push(col("_distance")); + } + + let sort_input = Projection::try_new(sort_input_exprs, node_plan).ok()?; + let sorted = LogicalPlan::Sort(datafusion::logical_expr::logical_plan::Sort { + expr: remapped_sort_exprs, + input: Arc::new(LogicalPlan::Projection(sort_input)), + fetch: sort.fetch, + }); + + let outer_proj_exprs = build_outer_projection(&final_proj_exprs); + let outer_proj = Projection::try_new(outer_proj_exprs, Arc::new(sorted)).ok()?; + Some(LogicalPlan::Projection(outer_proj)) } } @@ -283,7 +292,11 @@ fn dist_type_matches_metric(dist_type: &DistanceType, metric: MetricKind) -> boo } fn is_distance_expr(expr: &Expr) -> bool { - matches!(expr, Expr::ScalarFunction(sf) if is_dist_udf_name(sf.func.name())) + let inner = match expr { + Expr::Alias(a) => a.expr.as_ref(), + other => other, + }; + matches!(inner, Expr::ScalarFunction(sf) if is_dist_udf_name(sf.func.name())) } fn try_extract_distance(expr: &Expr) -> Option<(String, String, Vec)> { @@ -322,6 +335,45 @@ fn remap_projections( .collect() } +fn remap_sort_exprs( + sort_exprs: &[datafusion::logical_expr::SortExpr], + dist_alias_name: Option<&str>, +) -> Vec { + sort_exprs + .iter() + .map(|sort_expr| { + let remapped_expr = match &sort_expr.expr { + Expr::Column(c) if Some(c.name.as_str()) == dist_alias_name => col(c.name.as_str()), + expr if is_distance_expr(expr) => col("_distance"), + other => other.clone(), + }; + 
datafusion::logical_expr::SortExpr { + expr: remapped_expr, + asc: sort_expr.asc, + nulls_first: sort_expr.nulls_first, + } + }) + .collect() +} + +fn projection_exposes_name(exprs: &[Expr], name: &str) -> bool { + exprs.iter().any(|expr| match expr { + Expr::Alias(a) => a.name == name, + Expr::Column(c) => c.name == name, + _ => false, + }) +} + +fn build_outer_projection(exprs: &[Expr]) -> Vec { + exprs.iter() + .filter_map(|expr| match expr { + Expr::Alias(a) => Some(col(a.name.as_str())), + Expr::Column(c) => Some(Expr::Column(c.clone())), + _ => None, + }) + .collect() +} + /// Build a passthrough Projection for SELECT * queries (no original Projection node). /// Projects only the original table columns (not `_distance`) so the output schema /// matches the original Sort schema. The Sort re-evaluates the distance UDF expression diff --git a/tests/execution.rs b/tests/execution.rs index 5fa7c33..38724b9 100644 --- a/tests/execution.rs +++ b/tests/execution.rs @@ -530,6 +530,17 @@ async fn exec_split_provider_select_specific_columns() { assert_eq!(ids.len(), 2, "expected 2 results; got {ids:?}"); } +/// SELECT specific columns without projecting the distance expression. +/// This is the production shape behind `vector_distance(...)`. +#[tokio::test] +async fn exec_split_provider_order_by_udf_direct() { + let ctx = make_split_provider_ctx("items::vector").await; + let sql = format!("SELECT id FROM items ORDER BY l2_distance(vector, {Q}) ASC LIMIT 2"); + let ids = collect_ids(&ctx, &sql).await; + assert_eq!(ids[0], 1, "closest must be row 1\nids: {ids:?}"); + assert_eq!(ids.len(), 2, "expected 2 results; got {ids:?}"); +} + /// SELECT * with distance UDF — should fall back to UDF brute-force /// (since vector column is not in lookup provider schema). 
#[tokio::test] From da9fb6db5e0f46bd5c59491043bf62c71e643e12 Mon Sep 17 00:00:00 2001 From: Anoop Narang Date: Wed, 8 Apr 2026 18:01:43 +0530 Subject: [PATCH 2/3] style(rule): format hidden distance rewrite --- src/rule.rs | 9 +++++---- tests/execution.rs | 3 ++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/rule.rs b/src/rule.rs index e8eca08..eaa54dd 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -160,9 +160,9 @@ impl USearchRule { remap_projections(proj_exprs_slice, dist_alias_str, &table_ref) }; let remapped_sort_exprs = remap_sort_exprs(&sort.expr, dist_alias.as_deref()); - let needs_hidden_distance = remapped_sort_exprs.iter().any(|e| { - matches!(&e.expr, Expr::Column(c) if c.relation.is_none() && c.name == "_distance") - }) && !projection_exposes_name(&final_proj_exprs, "_distance"); + let needs_hidden_distance = remapped_sort_exprs.iter().any( + |e| matches!(&e.expr, Expr::Column(c) if c.relation.is_none() && c.name == "_distance"), + ) && !projection_exposes_name(&final_proj_exprs, "_distance"); let mut sort_input_exprs = final_proj_exprs.clone(); if needs_hidden_distance { @@ -365,7 +365,8 @@ fn projection_exposes_name(exprs: &[Expr], name: &str) -> bool { } fn build_outer_projection(exprs: &[Expr]) -> Vec { - exprs.iter() + exprs + .iter() .filter_map(|expr| match expr { Expr::Alias(a) => Some(col(a.name.as_str())), Expr::Column(c) => Some(Expr::Column(c.clone())), diff --git a/tests/execution.rs b/tests/execution.rs index 38724b9..92bd933 100644 --- a/tests/execution.rs +++ b/tests/execution.rs @@ -531,7 +531,8 @@ async fn exec_split_provider_select_specific_columns() { } /// SELECT specific columns without projecting the distance expression. -/// This is the production shape behind `vector_distance(...)`. +/// This matches the split-provider direct ORDER BY shape used by callers that +/// rewrite higher-level search helpers into the low-level distance UDF. 
#[tokio::test] async fn exec_split_provider_order_by_udf_direct() { let ctx = make_split_provider_ctx("items::vector").await; From 8e53743e73c54fc4f7606463c27e1de18d00e4c0 Mon Sep 17 00:00:00 2001 From: Anoop Narang Date: Wed, 8 Apr 2026 18:16:18 +0530 Subject: [PATCH 3/3] test(rule): cover computed sort projections --- src/rule.rs | 8 ++--- tests/execution.rs | 82 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 85 insertions(+), 5 deletions(-) diff --git a/src/rule.rs b/src/rule.rs index eaa54dd..7896516 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -367,10 +367,10 @@ fn projection_exposes_name(exprs: &[Expr], name: &str) -> bool { fn build_outer_projection(exprs: &[Expr]) -> Vec { exprs .iter() - .filter_map(|expr| match expr { - Expr::Alias(a) => Some(col(a.name.as_str())), - Expr::Column(c) => Some(Expr::Column(c.clone())), - _ => None, + .map(|expr| match expr { + Expr::Alias(a) => col(a.name.as_str()), + Expr::Column(c) => Expr::Column(c.clone()), + other => col(other.schema_name().to_string()), }) .collect() } diff --git a/tests/execution.rs b/tests/execution.rs index 92bd933..8fc47fb 100644 --- a/tests/execution.rs +++ b/tests/execution.rs @@ -19,7 +19,9 @@ use std::sync::Arc; use arrow_array::builder::{FixedSizeListBuilder, Float32Builder}; -use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, StringArray, UInt64Array}; +use arrow_array::{ + FixedSizeListArray, Float32Array, Int64Array, RecordBatch, StringArray, UInt64Array, +}; use arrow_schema::{DataType, Field, Schema}; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::SessionContext; @@ -152,6 +154,60 @@ async fn collect_ids(ctx: &SessionContext, sql: &str) -> Vec { ids } +/// Collect a named integer column from a query result. 
+async fn collect_i64_column(ctx: &SessionContext, sql: &str, column_name: &str) -> Vec { + let df = ctx + .sql(sql) + .await + .unwrap_or_else(|e| panic!("sql() failed: {e}\nSQL: {sql}")); + let batches = df + .collect() + .await + .unwrap_or_else(|e| panic!("collect() failed: {e}\nSQL: {sql}")); + + let mut values: Vec = vec![]; + for batch in &batches { + let col_idx = batch + .schema() + .index_of(column_name) + .unwrap_or_else(|e| panic!("no '{column_name}' column in result: {e}\nSQL: {sql}")); + let column = batch.column(col_idx); + if let Some(arr) = column.as_any().downcast_ref::() { + values.extend(arr.values().iter().map(|v| *v as i64)); + } else if let Some(arr) = column.as_any().downcast_ref::() { + values.extend(arr.values()); + } else { + panic!("column '{column_name}' not Int64/UInt64\nSQL: {sql}"); + } + } + values +} + +/// Collect the first integer column from a query result. +async fn collect_first_i64_column(ctx: &SessionContext, sql: &str) -> Vec { + let df = ctx + .sql(sql) + .await + .unwrap_or_else(|e| panic!("sql() failed: {e}\nSQL: {sql}")); + let batches = df + .collect() + .await + .unwrap_or_else(|e| panic!("collect() failed: {e}\nSQL: {sql}")); + + let mut values: Vec = vec![]; + for batch in &batches { + let column = batch.column(0); + if let Some(arr) = column.as_any().downcast_ref::() { + values.extend(arr.values().iter().map(|v| *v as i64)); + } else if let Some(arr) = column.as_any().downcast_ref::() { + values.extend(arr.values()); + } else { + panic!("first result column not Int64/UInt64\nSQL: {sql}"); + } + } + values +} + const Q: &str = "ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float]"; // ═══════════════════════════════════════════════════════════════════════════════ @@ -542,6 +598,30 @@ async fn exec_split_provider_order_by_udf_direct() { assert_eq!(ids.len(), 2, "expected 2 results; got {ids:?}"); } +/// Direct ORDER BY UDF with an aliased computed projection must preserve the +/// computed output through the 
rewrite. +#[tokio::test] +async fn exec_split_provider_order_by_udf_with_computed_alias() { + let ctx = make_split_provider_ctx("items::vector").await; + let sql = format!( + "SELECT CAST(id + 1 AS BIGINT) AS id_plus FROM items ORDER BY l2_distance(vector, {Q}) ASC LIMIT 2" + ); + let values = collect_i64_column(&ctx, &sql, "id_plus").await; + assert_eq!(values, vec![2, 3], "unexpected computed values: {values:?}"); +} + +/// Direct ORDER BY UDF with an unaliased computed projection relies on the +/// outer projection rebuilding by schema name rather than by raw expression. +#[tokio::test] +async fn exec_split_provider_order_by_udf_with_computed_expr() { + let ctx = make_split_provider_ctx("items::vector").await; + let sql = format!( + "SELECT CAST(id + 1 AS BIGINT) FROM items ORDER BY l2_distance(vector, {Q}) ASC LIMIT 2" + ); + let values = collect_first_i64_column(&ctx, &sql).await; + assert_eq!(values, vec![2, 3], "unexpected computed values: {values:?}"); +} + /// SELECT * with distance UDF — should fall back to UDF brute-force /// (since vector column is not in lookup provider schema). #[tokio::test]