Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 71 additions & 18 deletions src/rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
//
// Replacement:
//
// Sort(fetch=k) ← kept (sort order)
// Projection([col(a), col(b), col("_distance").alias("dist")])
// USearchNode ← executes ANN
// Projection([final output cols])
// Sort(fetch=k)
// Projection([final output cols + optional hidden _distance])
// USearchNode

use std::collections::HashMap;
use std::sync::Arc;
Expand Down Expand Up @@ -151,25 +152,33 @@ impl USearchRule {
node: Arc::new(node) as Arc<dyn UserDefinedLogicalNode>,
}));

// Build Projection over USearchNode matching the original output schema.
// Build the final user-visible projection over USearchNode output.
let dist_alias_str = dist_alias.as_deref().unwrap_or("_distance");
let new_proj_exprs = if proj_exprs_slice.is_empty() {
let final_proj_exprs = if proj_exprs_slice.is_empty() {
passthrough_projection(&vsn_df_schema, &table_ref)
} else {
remap_projections(proj_exprs_slice, dist_alias_str, &table_ref)
};
let new_proj = Projection::try_new(new_proj_exprs, node_plan).ok()?;

// Keep the Sort node so DataFusion handles ordering by _distance / dist.
// USearch returns results in arbitrary (internal) order when the underlying
// data is fetched from the TableProvider.
Some(LogicalPlan::Sort(
datafusion::logical_expr::logical_plan::Sort {
expr: sort.expr.clone(),
input: Arc::new(LogicalPlan::Projection(new_proj)),
fetch: sort.fetch,
},
))
let remapped_sort_exprs = remap_sort_exprs(&sort.expr, dist_alias.as_deref());
let needs_hidden_distance = remapped_sort_exprs.iter().any(
|e| matches!(&e.expr, Expr::Column(c) if c.relation.is_none() && c.name == "_distance"),
) && !projection_exposes_name(&final_proj_exprs, "_distance");

let mut sort_input_exprs = final_proj_exprs.clone();
if needs_hidden_distance {
sort_input_exprs.push(col("_distance"));
}

let sort_input = Projection::try_new(sort_input_exprs, node_plan).ok()?;
let sorted = LogicalPlan::Sort(datafusion::logical_expr::logical_plan::Sort {
expr: remapped_sort_exprs,
input: Arc::new(LogicalPlan::Projection(sort_input)),
fetch: sort.fetch,
});

let outer_proj_exprs = build_outer_projection(&final_proj_exprs);
let outer_proj = Projection::try_new(outer_proj_exprs, Arc::new(sorted)).ok()?;
Some(LogicalPlan::Projection(outer_proj))
}
}

Expand Down Expand Up @@ -283,7 +292,11 @@ fn dist_type_matches_metric(dist_type: &DistanceType, metric: MetricKind) -> boo
}

/// Returns true when `expr` is a call to one of the distance UDFs, possibly
/// wrapped in a single alias (`dist_udf(...) AS d`). Deeper alias nesting is
/// intentionally not unwrapped.
fn is_distance_expr(expr: &Expr) -> bool {
    // Peel at most one alias layer before inspecting the expression.
    let inner = if let Expr::Alias(alias) = expr {
        alias.expr.as_ref()
    } else {
        expr
    };
    match inner {
        Expr::ScalarFunction(sf) => is_dist_udf_name(sf.func.name()),
        _ => false,
    }
}

fn try_extract_distance(expr: &Expr) -> Option<(String, String, Vec<f64>)> {
Expand Down Expand Up @@ -322,6 +335,46 @@ fn remap_projections(
.collect()
}

/// Rewrite ORDER BY expressions so they refer to the `_distance` column that
/// the USearch node produces instead of the original distance UDF call.
///
/// - A bare column whose name equals the user's distance alias is re-emitted
///   as an unqualified column reference.
/// - A distance-UDF call (optionally aliased) becomes `_distance`.
/// - Any other expression passes through unchanged.
fn remap_sort_exprs(
    sort_exprs: &[datafusion::logical_expr::SortExpr],
    dist_alias_name: Option<&str>,
) -> Vec<datafusion::logical_expr::SortExpr> {
    let mut remapped = Vec::with_capacity(sort_exprs.len());
    for original in sort_exprs {
        let expr = match &original.expr {
            Expr::Column(column) if dist_alias_name == Some(column.name.as_str()) => {
                col(column.name.as_str())
            }
            candidate if is_distance_expr(candidate) => col("_distance"),
            passthrough => passthrough.clone(),
        };
        // Sort direction and null ordering are carried over untouched.
        remapped.push(datafusion::logical_expr::SortExpr {
            expr,
            asc: original.asc,
            nulls_first: original.nulls_first,
        });
    }
    remapped
}

/// True when the projection already produces an output column named `name`,
/// either via an explicit alias or a plain column reference (any table
/// qualifier on the column is ignored).
fn projection_exposes_name(exprs: &[Expr], name: &str) -> bool {
    exprs.iter().any(|expr| {
        if let Expr::Alias(alias) = expr {
            alias.name == name
        } else if let Expr::Column(column) = expr {
            column.name == name
        } else {
            false
        }
    })
}

fn build_outer_projection(exprs: &[Expr]) -> Vec<Expr> {
exprs
.iter()
.map(|expr| match expr {
Expr::Alias(a) => col(a.name.as_str()),
Expr::Column(c) => Expr::Column(c.clone()),
other => col(other.schema_name().to_string()),
})
Comment thread
anoop-narang marked this conversation as resolved.
.collect()
}
Comment thread
anoop-narang marked this conversation as resolved.

/// Build a passthrough Projection for SELECT * queries (no original Projection node).
/// Projects only the original table columns (not `_distance`) so the output schema
/// matches the original Sort schema. The Sort re-evaluates the distance UDF expression
Expand Down
94 changes: 93 additions & 1 deletion tests/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
use std::sync::Arc;

use arrow_array::builder::{FixedSizeListBuilder, Float32Builder};
use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, StringArray, UInt64Array};
use arrow_array::{
FixedSizeListArray, Float32Array, Int64Array, RecordBatch, StringArray, UInt64Array,
};
use arrow_schema::{DataType, Field, Schema};
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::prelude::SessionContext;
Expand Down Expand Up @@ -152,6 +154,60 @@ async fn collect_ids(ctx: &SessionContext, sql: &str) -> Vec<u64> {
ids
}

/// Collect a named integer column from a query result.
///
/// Runs `sql` on `ctx`, locates `column_name` in every result batch, and
/// widens UInt64 values to i64 so both integer column types compare uniformly.
/// Panics (with the offending SQL) on query failure, a missing column, or an
/// unexpected column type.
async fn collect_i64_column(ctx: &SessionContext, sql: &str, column_name: &str) -> Vec<i64> {
    let df = ctx
        .sql(sql)
        .await
        .unwrap_or_else(|e| panic!("sql() failed: {e}\nSQL: {sql}"));
    let batches = df
        .collect()
        .await
        .unwrap_or_else(|e| panic!("collect() failed: {e}\nSQL: {sql}"));

    let mut values = Vec::<i64>::new();
    for batch in &batches {
        let idx = batch
            .schema()
            .index_of(column_name)
            .unwrap_or_else(|e| panic!("no '{column_name}' column in result: {e}\nSQL: {sql}"));
        let any = batch.column(idx).as_any();
        if let Some(u64s) = any.downcast_ref::<UInt64Array>() {
            values.extend(u64s.values().iter().map(|v| *v as i64));
        } else if let Some(i64s) = any.downcast_ref::<Int64Array>() {
            values.extend(i64s.values());
        } else {
            panic!("column '{column_name}' not Int64/UInt64\nSQL: {sql}");
        }
    }
    values
}

/// Collect the first integer column from a query result.
///
/// Like `collect_i64_column`, but reads column 0 of every batch by position
/// instead of by name — useful for queries whose projected expression has no
/// alias. Panics (with the offending SQL) on query failure or an unexpected
/// column type.
async fn collect_first_i64_column(ctx: &SessionContext, sql: &str) -> Vec<i64> {
    let df = ctx
        .sql(sql)
        .await
        .unwrap_or_else(|e| panic!("sql() failed: {e}\nSQL: {sql}"));
    let batches = df
        .collect()
        .await
        .unwrap_or_else(|e| panic!("collect() failed: {e}\nSQL: {sql}"));

    let mut values = Vec::<i64>::new();
    for batch in &batches {
        let any = batch.column(0).as_any();
        if let Some(u64s) = any.downcast_ref::<UInt64Array>() {
            values.extend(u64s.values().iter().map(|v| *v as i64));
        } else if let Some(i64s) = any.downcast_ref::<Int64Array>() {
            values.extend(i64s.values());
        } else {
            panic!("first result column not Int64/UInt64\nSQL: {sql}");
        }
    }
    values
}
Comment thread
anoop-narang marked this conversation as resolved.

const Q: &str = "ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float]";

// ═══════════════════════════════════════════════════════════════════════════════
Expand Down Expand Up @@ -530,6 +586,42 @@ async fn exec_split_provider_select_specific_columns() {
assert_eq!(ids.len(), 2, "expected 2 results; got {ids:?}");
}

/// SELECT specific columns without projecting the distance expression.
/// This matches the split-provider direct ORDER BY shape used by callers that
/// rewrite higher-level search helpers into the low-level distance UDF.
#[tokio::test]
async fn exec_split_provider_order_by_udf_direct() {
    let ctx = make_split_provider_ctx("items::vector").await;
    let ids = collect_ids(
        &ctx,
        &format!("SELECT id FROM items ORDER BY l2_distance(vector, {Q}) ASC LIMIT 2"),
    )
    .await;
    assert_eq!(ids[0], 1, "closest must be row 1\nids: {ids:?}");
    assert_eq!(ids.len(), 2, "expected 2 results; got {ids:?}");
}

/// Direct ORDER BY UDF with an aliased computed projection must preserve the
/// computed output through the rewrite.
#[tokio::test]
async fn exec_split_provider_order_by_udf_with_computed_alias() {
    let ctx = make_split_provider_ctx("items::vector").await;
    let values = collect_i64_column(
        &ctx,
        &format!(
            "SELECT CAST(id + 1 AS BIGINT) AS id_plus FROM items ORDER BY l2_distance(vector, {Q}) ASC LIMIT 2"
        ),
        "id_plus",
    )
    .await;
    assert_eq!(values, vec![2, 3], "unexpected computed values: {values:?}");
}

/// Direct ORDER BY UDF with an unaliased computed projection relies on the
/// outer projection rebuilding by schema name rather than by raw expression.
#[tokio::test]
async fn exec_split_provider_order_by_udf_with_computed_expr() {
    let ctx = make_split_provider_ctx("items::vector").await;
    let values = collect_first_i64_column(
        &ctx,
        &format!(
            "SELECT CAST(id + 1 AS BIGINT) FROM items ORDER BY l2_distance(vector, {Q}) ASC LIMIT 2"
        ),
    )
    .await;
    assert_eq!(values, vec![2, 3], "unexpected computed values: {values:?}");
}

/// SELECT * with distance UDF — should fall back to UDF brute-force
/// (since vector column is not in lookup provider schema).
#[tokio::test]
Expand Down
Loading