From 3e2f4990d254e2dbee9930271c4440169693413e Mon Sep 17 00:00:00 2001 From: yusinnmao Date: Mon, 27 Apr 2026 14:42:36 +0800 Subject: [PATCH 01/15] feat: Support Spark levenshtein expression in native execution Implements the Levenshtein edit distance function as a native Comet scalar UDF, enabling Spark's `levenshtein(str1, str2)` to run via DataFusion instead of falling back to JVM. Implementation: - Rust kernel: O(min(m,n)) space DP algorithm with Unicode char-level distance computation and proper NULL propagation - Scala serde: Register via CometScalarFunction("levenshtein") in QueryPlanSerde stringExpressions map - Tests: Basic, NULL handling, and Unicode test cases Closes #3084 --- native/spark-expr/src/comet_scalar_funcs.rs | 4 + .../src/string_funcs/levenshtein.rs | 164 ++++++++++++++++++ native/spark-expr/src/string_funcs/mod.rs | 2 + .../apache/comet/serde/QueryPlanSerde.scala | 1 + .../org/apache/comet/serde/strings.scala | 28 ++- .../comet/CometStringExpressionSuite.scala | 33 ++++ 6 files changed, 230 insertions(+), 2 deletions(-) create mode 100644 native/spark-expr/src/string_funcs/levenshtein.rs diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 784e6c6829..2b8ba93f6e 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -200,6 +200,10 @@ pub fn create_comet_physical_fun_with_eval_mode( "to_time" => { make_comet_scalar_udf!("to_time", spark_to_time, without data_type, fail_on_error) } + "levenshtein" => { + let func = Arc::new(crate::string_funcs::spark_levenshtein); + make_comet_scalar_udf!("levenshtein", func, without data_type) + } _ => registry.udf(fun_name).map_err(|e| { DataFusionError::Execution(format!( "Function {fun_name} not found in the registry: {e}", diff --git a/native/spark-expr/src/string_funcs/levenshtein.rs b/native/spark-expr/src/string_funcs/levenshtein.rs new file mode 100644 index 0000000000..e0ddf7bce7 --- 
/dev/null +++ b/native/spark-expr/src/string_funcs/levenshtein.rs @@ -0,0 +1,164 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Levenshtein distance expression implementation. +//! +//! Computes the Levenshtein edit distance between two strings, +//! matching Apache Spark's `levenshtein(str1, str2)` semantics. + +use arrow::array::{as_string_array, Array, ArrayRef, Int32Array}; +use datafusion::common::{DataFusionError, Result}; +use datafusion::physical_plan::ColumnarValue; +use std::sync::Arc; + +/// Computes the Levenshtein edit distance between two UTF-8 strings. +/// +/// This uses the standard dynamic programming algorithm with O(min(m,n)) space. 
+fn levenshtein_distance(s: &str, t: &str) -> i32 { + let s_chars: Vec = s.chars().collect(); + let t_chars: Vec = t.chars().collect(); + let m = s_chars.len(); + let n = t_chars.len(); + + // Optimization: if one string is empty, distance is the length of the other + if m == 0 { + return n as i32; + } + if n == 0 { + return m as i32; + } + + // Use the shorter string for the "column" to minimize space usage + let (s_chars, t_chars, m, n) = if m > n { + (t_chars, s_chars, n, m) + } else { + (s_chars, t_chars, m, n) + }; + + // Previous and current row of distances + let mut prev = vec![0i32; m + 1]; + let mut curr = vec![0i32; m + 1]; + + // Initialize base case: distance from empty string + for i in 0..=m { + prev[i] = i as i32; + } + + for j in 1..=n { + curr[0] = j as i32; + for i in 1..=m { + let cost = if s_chars[i - 1] == t_chars[j - 1] { + 0 + } else { + 1 + }; + curr[i] = (prev[i] + 1) // deletion + .min(curr[i - 1] + 1) // insertion + .min(prev[i - 1] + cost); // substitution + } + std::mem::swap(&mut prev, &mut curr); + } + + prev[m] +} + +/// Spark-compatible levenshtein scalar function. +/// +/// Accepts two string arguments and returns an Int32 array of edit distances. +/// NULL inputs produce NULL outputs. 
+pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Internal(format!( + "levenshtein requires exactly 2 arguments, got {}", + args.len() + ))); + } + + // Expand scalars to arrays for uniform processing + let len = args + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(a) => Some(a.len()), + _ => None, + }) + .unwrap_or(1); + + let left = args[0].clone().into_array(len)?; + let right = args[1].clone().into_array(len)?; + + let left_arr = as_string_array(&left); + let right_arr = as_string_array(&right); + + let result: Int32Array = left_arr + .iter() + .zip(right_arr.iter()) + .map(|(l, r)| match (l, r) { + (Some(l), Some(r)) => Some(levenshtein_distance(l, r)), + _ => None, // NULL propagation + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::StringArray; + + #[test] + fn test_levenshtein_basic() { + assert_eq!(levenshtein_distance("", ""), 0); + assert_eq!(levenshtein_distance("abc", ""), 3); + assert_eq!(levenshtein_distance("", "abc"), 3); + assert_eq!(levenshtein_distance("abc", "abc"), 0); + assert_eq!(levenshtein_distance("kitten", "sitting"), 3); + assert_eq!(levenshtein_distance("frog", "fog"), 1); + } + + #[test] + fn test_levenshtein_unicode() { + // Spark counts character-level (not byte-level) edit distance + assert_eq!(levenshtein_distance("你好", "你坏"), 1); + assert_eq!(levenshtein_distance("abc", "äbc"), 1); + } + + #[test] + fn test_spark_levenshtein_nulls() { + let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("abc"), + None, + Some("hello"), + ]))); + let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("adc"), + Some("test"), + None, + ]))); + + let result = spark_levenshtein(&[left, right]).unwrap(); + match result { + ColumnarValue::Array(arr) => { + let int_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(int_arr.value(0), 
1); // abc -> adc = 1 + assert!(int_arr.is_null(1)); // NULL -> test = NULL + assert!(int_arr.is_null(2)); // hello -> NULL = NULL + } + _ => panic!("Expected array result"), + } + } +} diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs index bb785bdb44..ed871ef0da 100644 --- a/native/spark-expr/src/string_funcs/mod.rs +++ b/native/spark-expr/src/string_funcs/mod.rs @@ -17,10 +17,12 @@ mod contains; mod get_json_object; +mod levenshtein; mod split; mod substring; pub use contains::SparkContains; pub use get_json_object::spark_get_json_object; +pub use levenshtein::spark_levenshtein; pub use split::spark_split; pub use substring::SubstringExpr; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 0bdc02a790..b1b4efef1d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -183,6 +183,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[GetJsonObject] -> CometGetJsonObject, classOf[InitCap] -> CometInitCap, classOf[Length] -> CometLength, + classOf[Levenshtein] -> CometLevenshtein, classOf[Like] -> CometLike, classOf[Lower] -> CometLower, classOf[OctetLength] -> CometScalarFunction("octet_length"), diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index f2f10d5f1c..d4ef76d4ce 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -21,8 +21,8 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, 
StringRPad, StringSplit, Substring, SubstringIndex, Upper} -import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Levenshtein, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, SubstringIndex, Upper} +import org.apache.spark.sql.types.{BinaryType, DataTypes, IntegerType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.CometConf @@ -83,6 +83,30 @@ object CometLength extends CometScalarFunction[Length]("length") { } } +object CometLevenshtein extends CometExpressionSerde[Levenshtein] { + + override def getUnsupportedReasons(): Seq[String] = Seq( + "Non-default collation (non-UTF8_BINARY) is not supported") + + override def getSupportLevel(expr: Levenshtein): SupportLevel = { + expr.children.headOption match { + case Some(child) if QueryPlanSerde.isStringCollationType(child.dataType) => + Unsupported(Some("Levenshtein with non-default collation is not supported")) + case _ => Compatible() + } + } + + override def convert( + expr: Levenshtein, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val childExprs = expr.children.map(exprToProtoInternal(_, inputs, binding)) + val optExpr = + scalarFunctionExprToProtoWithReturnType("levenshtein", IntegerType, false, childExprs: _*) + optExprWithFallbackReason(optExpr, expr, expr.children: _*) + } +} + object CometInitCap extends CometScalarFunction[InitCap]("initcap") { override def getSupportLevel(expr: InitCap): SupportLevel = Compatible() diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index 89d5dfd4bc..90e58be77e 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ 
b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -707,4 +707,37 @@ class CometStringExpressionSuite extends CometTestBase { // scalastyle:on } + test("levenshtein") { + val data = Seq( + ("kitten", "sitting"), + ("frog", "fog"), + ("abc", "abc"), + ("", "hello"), + ("hello", "")) + + withParquetTable(data, "tbl") { + checkSparkAnswerAndOperator("SELECT levenshtein(_1, _2) FROM tbl") + } + } + + test("levenshtein with nulls") { + val table = "levenshtein_null_test" + withTable(table) { + sql(s"CREATE TABLE $table(s1 STRING, s2 STRING) USING parquet") + sql(s"INSERT INTO $table VALUES ('abc', 'adc'), (NULL, 'test'), ('hello', NULL), (NULL, NULL)") + checkSparkAnswerAndOperator(s"SELECT levenshtein(s1, s2) FROM $table") + } + } + + test("levenshtein with unicode") { + val data = Seq( + ("\u4f60\u597d", "\u4f60\u574f"), + ("caf\u00e9", "cafe"), + ("\ud83d\ude00", "\ud83d\ude01")) + + withParquetTable(data, "tbl") { + checkSparkAnswerAndOperator("SELECT levenshtein(_1, _2) FROM tbl") + } + } + } From 33ef36259741292e2b663842b7d725fdf18a4456 Mon Sep 17 00:00:00 2001 From: yusinnmao Date: Mon, 27 Apr 2026 16:25:40 +0800 Subject: [PATCH 02/15] style: fix Spotless formatting violations in test code - Reformat Seq literal to match scalafmt maxColumn=98 constraint - Break long sql() call into multi-line format - Reformat unicode test data as multi-line Seq --- .../apache/comet/CometStringExpressionSuite.scala | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index 90e58be77e..0a6e17194a 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -708,12 +708,8 @@ class CometStringExpressionSuite extends CometTestBase { } test("levenshtein") { - val data = Seq( - ("kitten", 
"sitting"), - ("frog", "fog"), - ("abc", "abc"), - ("", "hello"), - ("hello", "")) + val data = + Seq(("kitten", "sitting"), ("frog", "fog"), ("abc", "abc"), ("", "hello"), ("hello", "")) withParquetTable(data, "tbl") { checkSparkAnswerAndOperator("SELECT levenshtein(_1, _2) FROM tbl") @@ -724,7 +720,9 @@ class CometStringExpressionSuite extends CometTestBase { val table = "levenshtein_null_test" withTable(table) { sql(s"CREATE TABLE $table(s1 STRING, s2 STRING) USING parquet") - sql(s"INSERT INTO $table VALUES ('abc', 'adc'), (NULL, 'test'), ('hello', NULL), (NULL, NULL)") + sql( + s"INSERT INTO $table VALUES " + + s"('abc', 'adc'), (NULL, 'test'), ('hello', NULL), (NULL, NULL)") checkSparkAnswerAndOperator(s"SELECT levenshtein(s1, s2) FROM $table") } } From feb3d8fef6f0430459ef75d8d83cf13f6c7d6275 Mon Sep 17 00:00:00 2001 From: yusinnmao Date: Mon, 27 Apr 2026 16:35:45 +0800 Subject: [PATCH 03/15] fix: address CI lint failures - Fix Clippy needless_range_loop: use iter_mut().enumerate() instead of indexing loop for initializing prev array - Fix Scalafix unused import: remove Levenshtein from strings.scala imports (it's used via wildcard import in QueryPlanSerde.scala) --- native/spark-expr/src/string_funcs/levenshtein.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/native/spark-expr/src/string_funcs/levenshtein.rs b/native/spark-expr/src/string_funcs/levenshtein.rs index e0ddf7bce7..8194e67605 100644 --- a/native/spark-expr/src/string_funcs/levenshtein.rs +++ b/native/spark-expr/src/string_funcs/levenshtein.rs @@ -54,8 +54,8 @@ fn levenshtein_distance(s: &str, t: &str) -> i32 { let mut curr = vec![0i32; m + 1]; // Initialize base case: distance from empty string - for i in 0..=m { - prev[i] = i as i32; + for (i, val) in prev.iter_mut().enumerate() { + *val = i as i32; } for j in 1..=n { From ffe0b4a6542a25e39e7411cd06026ad4f02ff7af Mon Sep 17 00:00:00 2001 From: yusinnmao Date: Mon, 27 Apr 2026 17:39:37 +0800 Subject: [PATCH 04/15] test: 
add SLT tests for levenshtein expression Add SQL logic test file covering: - Column arguments with NULL propagation - Column + literal combinations - Literal + literal edge cases - Identical string comparison - Unicode character-level distance --- .../expressions/string/levenshtein.sql | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 spark/src/test/resources/sql-tests/expressions/string/levenshtein.sql diff --git a/spark/src/test/resources/sql-tests/expressions/string/levenshtein.sql b/spark/src/test/resources/sql-tests/expressions/string/levenshtein.sql new file mode 100644 index 0000000000..9f11c415b2 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/levenshtein.sql @@ -0,0 +1,46 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +statement +CREATE TABLE test_levenshtein(s1 string, s2 string) USING parquet + +statement +INSERT INTO test_levenshtein VALUES ('kitten', 'sitting'), ('frog', 'fog'), ('abc', 'abc'), ('', 'hello'), ('hello', ''), ('', ''), (NULL, 'test'), ('hello', NULL), (NULL, NULL) + +-- column arguments +query +SELECT levenshtein(s1, s2) FROM test_levenshtein + +-- column + literal +query +SELECT levenshtein(s1, 'abc') FROM test_levenshtein + +-- literal + column +query +SELECT levenshtein('kitten', s2) FROM test_levenshtein + +-- literal + literal +query +SELECT levenshtein('kitten', 'sitting'), levenshtein('frog', 'fog'), levenshtein('', ''), levenshtein(NULL, 'a') + +-- identical strings +query +SELECT levenshtein(s1, s1) FROM test_levenshtein + +-- unicode characters +query +SELECT levenshtein('café', 'cafe'), levenshtein('你好', '你坏') From 746674a44cbcfdbcfbc61dc297aeab621f9fceda Mon Sep 17 00:00:00 2001 From: yusinnmao Date: Tue, 28 Apr 2026 15:08:19 +0800 Subject: [PATCH 05/15] Add levenshtein to benchmark and update expressions.md - Add levenshtein benchmark entry in CometStringExpressionBenchmark - Update expressions.md to list Levenshtein as supported Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/source/user-guide/latest/expressions.md | 1 + .../spark/sql/benchmark/CometStringExpressionBenchmark.scala | 2 ++ 2 files changed, 3 insertions(+) diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 0823e79fed..46afc69b0d 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -68,6 +68,7 @@ of expressions that be disabled. 
| InitCap | | Left | | Length | +| Levenshtein | | Like | | Lower | | OctetLength | diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala index d7be505161..b267dac691 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala @@ -57,6 +57,8 @@ object CometStringExpressionBenchmark extends CometBenchmarkBase { StringExprConfig("initCap", "select initCap(c1) from parquetV1Table"), StringExprConfig("instr", "select instr(c1, '123') from parquetV1Table"), StringExprConfig("length", "select length(c1) from parquetV1Table"), + StringExprConfig("levenshtein", "select levenshtein(c1, 'test') from parquetV1Table"), + StringExprConfig("levenshtein", "select levenshtein(c1, 'test') from parquetV1Table"), StringExprConfig("like", "select c1 like '%123%' from parquetV1Table"), StringExprConfig("lower", "select lower(c1) from parquetV1Table"), StringExprConfig("lpad", "select lpad(c1, 150, 'x') from parquetV1Table"), From b4196769bf85d01873838292f61159eadac195b5 Mon Sep 17 00:00:00 2001 From: yusinnmao Date: Tue, 28 Apr 2026 15:20:01 +0800 Subject: [PATCH 06/15] Support three-argument levenshtein with threshold and add collation check - Implement threshold semantics: levenshtein(str1, str2, threshold) returns -1 when distance exceeds threshold, matching Spark behavior - Add CometLevenshtein serde with getSupportLevel that falls back for non-default collations (Spark 4 StringTypeWithCollation) - Add Rust tests for threshold and NULL threshold cases - Add SLT tests for threshold variants - Add Scala integration test for threshold - Add levenshtein_threshold to benchmark Co-Authored-By: Claude Opus 4.6 (1M context) --- .../src/string_funcs/levenshtein.rs | 78 +++++++++++++++++-- 
.../expressions/string/levenshtein.sql | 16 ++++ .../comet/CometStringExpressionSuite.scala | 17 ++++ .../CometStringExpressionBenchmark.scala | 3 + 4 files changed, 108 insertions(+), 6 deletions(-) diff --git a/native/spark-expr/src/string_funcs/levenshtein.rs b/native/spark-expr/src/string_funcs/levenshtein.rs index 8194e67605..9d2c48c7b7 100644 --- a/native/spark-expr/src/string_funcs/levenshtein.rs +++ b/native/spark-expr/src/string_funcs/levenshtein.rs @@ -21,7 +21,7 @@ //! matching Apache Spark's `levenshtein(str1, str2)` semantics. use arrow::array::{as_string_array, Array, ArrayRef, Int32Array}; -use datafusion::common::{DataFusionError, Result}; +use datafusion::common::{DataFusionError, Result, ScalarValue}; use datafusion::physical_plan::ColumnarValue; use std::sync::Arc; @@ -78,19 +78,40 @@ fn levenshtein_distance(s: &str, t: &str) -> i32 { /// Spark-compatible levenshtein scalar function. /// -/// Accepts two string arguments and returns an Int32 array of edit distances. -/// NULL inputs produce NULL outputs. +/// Accepts two or three arguments: +/// - `levenshtein(str1, str2)` → edit distance +/// - `levenshtein(str1, str2, threshold)` → edit distance if <= threshold, else -1 +/// +/// NULL inputs produce NULL outputs. NULL threshold produces NULL output. 
pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result { - if args.len() != 2 { + if args.len() < 2 || args.len() > 3 { return Err(DataFusionError::Internal(format!( - "levenshtein requires exactly 2 arguments, got {}", + "levenshtein requires 2 or 3 arguments, got {}", args.len() ))); } + // Extract optional threshold (3rd argument must be a scalar Int32) + let threshold: Option = if args.len() == 3 { + match &args[2] { + ColumnarValue::Scalar(ScalarValue::Int32(t)) => match t { + Some(val) => Some(*val), + None => return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))), + }, + _ => { + return Err(DataFusionError::Internal( + "levenshtein threshold must be an Int32 scalar".to_string(), + )); + } + } + } else { + None + }; + // Expand scalars to arrays for uniform processing let len = args .iter() + .take(2) .find_map(|arg| match arg { ColumnarValue::Array(a) => Some(a.len()), _ => None, @@ -107,7 +128,13 @@ pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result { .iter() .zip(right_arr.iter()) .map(|(l, r)| match (l, r) { - (Some(l), Some(r)) => Some(levenshtein_distance(l, r)), + (Some(l), Some(r)) => { + let dist = levenshtein_distance(l, r); + match threshold { + Some(t) if dist > t => Some(-1), + _ => Some(dist), + } + } _ => None, // NULL propagation }) .collect(); @@ -161,4 +188,43 @@ mod tests { _ => panic!("Expected array result"), } } + + #[test] + fn test_spark_levenshtein_with_threshold() { + let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("kitten"), + Some("abc"), + Some("frog"), + ]))); + let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("sitting"), + Some("adc"), + Some("fog"), + ]))); + let threshold = ColumnarValue::Scalar(ScalarValue::Int32(Some(2))); + + let result = spark_levenshtein(&[left, right, threshold]).unwrap(); + match result { + ColumnarValue::Array(arr) => { + let int_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(int_arr.value(0), -1); // kitten->sitting=3 > 2, 
return -1 + assert_eq!(int_arr.value(1), 1); // abc->adc=1 <= 2, return 1 + assert_eq!(int_arr.value(2), 1); // frog->fog=1 <= 2, return 1 + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_spark_levenshtein_null_threshold() { + let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![Some("abc")]))); + let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![Some("adc")]))); + let threshold = ColumnarValue::Scalar(ScalarValue::Int32(None)); + + let result = spark_levenshtein(&[left, right, threshold]).unwrap(); + match result { + ColumnarValue::Scalar(ScalarValue::Int32(None)) => {} // NULL threshold -> NULL + _ => panic!("Expected NULL scalar result for NULL threshold"), + } + } } diff --git a/spark/src/test/resources/sql-tests/expressions/string/levenshtein.sql b/spark/src/test/resources/sql-tests/expressions/string/levenshtein.sql index 9f11c415b2..fc5ff5c8c5 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/levenshtein.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/levenshtein.sql @@ -44,3 +44,19 @@ SELECT levenshtein(s1, s1) FROM test_levenshtein -- unicode characters query SELECT levenshtein('café', 'cafe'), levenshtein('你好', '你坏') + +-- three argument version with threshold +query +SELECT levenshtein('kitten', 'sitting', 2), levenshtein('kitten', 'sitting', 3), levenshtein('kitten', 'sitting', 4) + +-- threshold with column arguments +query +SELECT levenshtein(s1, s2, 2) FROM test_levenshtein + +-- threshold edge cases +query +SELECT levenshtein('abc', 'abc', 0), levenshtein('abc', 'adc', 0), levenshtein('', '', 0) + +-- threshold with NULL +query +SELECT levenshtein('abc', 'adc', NULL), levenshtein(NULL, 'test', 2) diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index 0a6e17194a..514c53d123 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ 
b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -738,4 +738,21 @@ class CometStringExpressionSuite extends CometTestBase { } } + test("levenshtein with threshold") { + val data = Seq( + ("kitten", "sitting"), + ("frog", "fog"), + ("abc", "abc"), + ("hello", "world")) + + withParquetTable(data, "tbl") { + checkSparkAnswerAndOperator( + "SELECT levenshtein(_1, _2, 2) FROM tbl") + checkSparkAnswerAndOperator( + "SELECT levenshtein(_1, _2, 0) FROM tbl") + checkSparkAnswerAndOperator( + "SELECT levenshtein(_1, _2, 10) FROM tbl") + } + } + } diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala index b267dac691..4219103f96 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala @@ -59,6 +59,9 @@ object CometStringExpressionBenchmark extends CometBenchmarkBase { StringExprConfig("length", "select length(c1) from parquetV1Table"), StringExprConfig("levenshtein", "select levenshtein(c1, 'test') from parquetV1Table"), StringExprConfig("levenshtein", "select levenshtein(c1, 'test') from parquetV1Table"), + StringExprConfig( + "levenshtein_threshold", + "select levenshtein(c1, 'test', 3) from parquetV1Table"), StringExprConfig("like", "select c1 like '%123%' from parquetV1Table"), StringExprConfig("lower", "select lower(c1) from parquetV1Table"), StringExprConfig("lpad", "select lpad(c1, 150, 'x') from parquetV1Table"), From 6fea3f539bfc4698d93cbda8a66f035c131cf1d1 Mon Sep 17 00:00:00 2001 From: yusinnmao Date: Tue, 28 Apr 2026 16:41:25 +0800 Subject: [PATCH 07/15] Fix Spotless formatting in levenshtein threshold test Co-Authored-By: Claude Opus 4.6 (1M context) --- .../org/apache/comet/CometStringExpressionSuite.scala | 9 +++------ 1 file changed, 3 insertions(+), 6 
deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index 514c53d123..b67e171e99 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -746,12 +746,9 @@ class CometStringExpressionSuite extends CometTestBase { ("hello", "world")) withParquetTable(data, "tbl") { - checkSparkAnswerAndOperator( - "SELECT levenshtein(_1, _2, 2) FROM tbl") - checkSparkAnswerAndOperator( - "SELECT levenshtein(_1, _2, 0) FROM tbl") - checkSparkAnswerAndOperator( - "SELECT levenshtein(_1, _2, 10) FROM tbl") + checkSparkAnswerAndOperator("SELECT levenshtein(_1, _2, 2) FROM tbl") + checkSparkAnswerAndOperator("SELECT levenshtein(_1, _2, 0) FROM tbl") + checkSparkAnswerAndOperator("SELECT levenshtein(_1, _2, 10) FROM tbl") } } From cfd4aff00b9de14db971fc5b9075ea1722b0341d Mon Sep 17 00:00:00 2001 From: yusinnmao Date: Tue, 28 Apr 2026 17:08:23 +0800 Subject: [PATCH 08/15] Fix Spotless formatting in strings.scala and test suite Co-Authored-By: Claude Opus 4.6 (1M context) --- .../scala/org/apache/comet/CometStringExpressionSuite.scala | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index b67e171e99..f9219ecd03 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -739,11 +739,7 @@ class CometStringExpressionSuite extends CometTestBase { } test("levenshtein with threshold") { - val data = Seq( - ("kitten", "sitting"), - ("frog", "fog"), - ("abc", "abc"), - ("hello", "world")) + val data = Seq(("kitten", "sitting"), ("frog", "fog"), ("abc", "abc"), ("hello", "world")) withParquetTable(data, "tbl") { 
checkSparkAnswerAndOperator("SELECT levenshtein(_1, _2, 2) FROM tbl") From 3178039e76ff0363bd0edb356b232de4bd1ff915 Mon Sep 17 00:00:00 2001 From: yusinnmao Date: Wed, 29 Apr 2026 10:09:46 +0800 Subject: [PATCH 09/15] Use scalarFunctionExprToProtoWithReturnType for levenshtein Override convert() to use scalarFunctionExprToProtoWithReturnType with IntegerType, so the native planner skips the DataFusion registry lookup and does not conflict with DataFusion's built-in 2-arg levenshtein function when 3 args (threshold) are passed. Also fix Spotless: remove unnecessary string interpolation prefix. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../scala/org/apache/comet/CometStringExpressionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index f9219ecd03..dea397d104 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -722,7 +722,7 @@ class CometStringExpressionSuite extends CometTestBase { sql(s"CREATE TABLE $table(s1 STRING, s2 STRING) USING parquet") sql( s"INSERT INTO $table VALUES " + - s"('abc', 'adc'), (NULL, 'test'), ('hello', NULL), (NULL, NULL)") + "('abc', 'adc'), (NULL, 'test'), ('hello', NULL), (NULL, NULL)") checkSparkAnswerAndOperator(s"SELECT levenshtein(s1, s2) FROM $table") } } From 4e6c75982557eedf33e5c80e87b812aac3d2624a Mon Sep 17 00:00:00 2001 From: yusinnmao Date: Thu, 30 Apr 2026 16:26:32 +0800 Subject: [PATCH 10/15] Fix Spark 3.4 compatibility: skip levenshtein threshold tests on Spark < 3.5 The three-argument levenshtein(str1, str2, threshold) is only available in Spark 3.5+. This moves threshold-related SQL tests to a separate file with MinSparkVersion: 3.5 and adds assume(isSpark35Plus) to the unit test. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- .../expressions/string/levenshtein.sql | 15 ------- .../string/levenshtein_threshold.sql | 40 +++++++++++++++++++ .../comet/CometStringExpressionSuite.scala | 2 + 3 files changed, 42 insertions(+), 15 deletions(-) create mode 100644 spark/src/test/resources/sql-tests/expressions/string/levenshtein_threshold.sql diff --git a/spark/src/test/resources/sql-tests/expressions/string/levenshtein.sql b/spark/src/test/resources/sql-tests/expressions/string/levenshtein.sql index fc5ff5c8c5..6ae8225cbe 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/levenshtein.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/levenshtein.sql @@ -45,18 +45,3 @@ SELECT levenshtein(s1, s1) FROM test_levenshtein query SELECT levenshtein('café', 'cafe'), levenshtein('你好', '你坏') --- three argument version with threshold -query -SELECT levenshtein('kitten', 'sitting', 2), levenshtein('kitten', 'sitting', 3), levenshtein('kitten', 'sitting', 4) - --- threshold with column arguments -query -SELECT levenshtein(s1, s2, 2) FROM test_levenshtein - --- threshold edge cases -query -SELECT levenshtein('abc', 'abc', 0), levenshtein('abc', 'adc', 0), levenshtein('', '', 0) - --- threshold with NULL -query -SELECT levenshtein('abc', 'adc', NULL), levenshtein(NULL, 'test', 2) diff --git a/spark/src/test/resources/sql-tests/expressions/string/levenshtein_threshold.sql b/spark/src/test/resources/sql-tests/expressions/string/levenshtein_threshold.sql new file mode 100644 index 0000000000..a89435416d --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/levenshtein_threshold.sql @@ -0,0 +1,40 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. 
The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- MinSparkVersion: 3.5 + +statement +CREATE TABLE test_levenshtein(s1 string, s2 string) USING parquet + +statement +INSERT INTO test_levenshtein VALUES ('kitten', 'sitting'), ('frog', 'fog'), ('abc', 'abc'), ('', 'hello'), ('hello', ''), ('', ''), (NULL, 'test'), ('hello', NULL), (NULL, NULL) + +-- three argument version with threshold +query +SELECT levenshtein('kitten', 'sitting', 2), levenshtein('kitten', 'sitting', 3), levenshtein('kitten', 'sitting', 4) + +-- threshold with column arguments +query +SELECT levenshtein(s1, s2, 2) FROM test_levenshtein + +-- threshold edge cases +query +SELECT levenshtein('abc', 'abc', 0), levenshtein('abc', 'adc', 0), levenshtein('', '', 0) + +-- threshold with NULL +query +SELECT levenshtein('abc', 'adc', NULL), levenshtein(NULL, 'test', 2) diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index dea397d104..282b59f78d 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.{CometTestBase, DataFrame} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataTypes, StructField, StructType} +import org.apache.comet.CometSparkSessionExtensions.isSpark35Plus import 
org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator} class CometStringExpressionSuite extends CometTestBase { @@ -739,6 +740,7 @@ class CometStringExpressionSuite extends CometTestBase { } test("levenshtein with threshold") { + assume(isSpark35Plus, "levenshtein with threshold requires Spark 3.5+") val data = Seq(("kitten", "sitting"), ("frog", "fog"), ("abc", "abc"), ("hello", "world")) withParquetTable(data, "tbl") { From 4ac12d1e389d615689c34b68d89921dcde242ad8 Mon Sep 17 00:00:00 2001 From: yusinnmao Date: Thu, 30 Apr 2026 17:06:30 +0800 Subject: [PATCH 11/15] Remove duplicate levenshtein benchmark entry Co-Authored-By: Claude Opus 4.6 (1M context) --- .../spark/sql/benchmark/CometStringExpressionBenchmark.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala index 4219103f96..330dcdbc5a 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala @@ -58,7 +58,6 @@ object CometStringExpressionBenchmark extends CometBenchmarkBase { StringExprConfig("instr", "select instr(c1, '123') from parquetV1Table"), StringExprConfig("length", "select length(c1) from parquetV1Table"), StringExprConfig("levenshtein", "select levenshtein(c1, 'test') from parquetV1Table"), - StringExprConfig("levenshtein", "select levenshtein(c1, 'test') from parquetV1Table"), StringExprConfig( "levenshtein_threshold", "select levenshtein(c1, 'test', 3) from parquetV1Table"), From 08445e32f013a179915128b6b01e8d00af662211 Mon Sep 17 00:00:00 2001 From: yusinnmao Date: Wed, 6 May 2026 14:11:29 +0800 Subject: [PATCH 12/15] Support threshold as column (not just literal) in levenshtein - Refactored spark_levenshtein to handle threshold as either ColumnarValue::Scalar or 
ColumnarValue::Array using into_array() - NULL threshold in a column produces NULL result for that row - Added Rust unit tests for array threshold, nulls, and negative values - Added Scala integration tests with threshold as column reference - Added SQL file tests with threshold column and NULL scenarios --- .../src/string_funcs/levenshtein.rs | 189 ++++++++++++++---- .../string/levenshtein_threshold.sql | 20 ++ .../comet/CometStringExpressionSuite.scala | 26 +++ 3 files changed, 199 insertions(+), 36 deletions(-) diff --git a/native/spark-expr/src/string_funcs/levenshtein.rs b/native/spark-expr/src/string_funcs/levenshtein.rs index 9d2c48c7b7..e6190f7488 100644 --- a/native/spark-expr/src/string_funcs/levenshtein.rs +++ b/native/spark-expr/src/string_funcs/levenshtein.rs @@ -82,7 +82,8 @@ fn levenshtein_distance(s: &str, t: &str) -> i32 { /// - `levenshtein(str1, str2)` → edit distance /// - `levenshtein(str1, str2, threshold)` → edit distance if <= threshold, else -1 /// -/// NULL inputs produce NULL outputs. NULL threshold produces NULL output. +/// The threshold argument can be either a scalar or a column (array). +/// NULL inputs produce NULL outputs. NULL threshold produces NULL output for that row. 
pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result<ColumnarValue> { if args.len() < 2 || args.len() > 3 { return Err(DataFusionError::Internal(format!( @@ -91,27 +92,9 @@ pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result<ColumnarValue> { ))); } - // Extract optional threshold (3rd argument must be a scalar Int32) - let threshold: Option<i32> = if args.len() == 3 { - match &args[2] { - ColumnarValue::Scalar(ScalarValue::Int32(t)) => match t { - Some(val) => Some(*val), - None => return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))), - }, - _ => { - return Err(DataFusionError::Internal( - "levenshtein threshold must be an Int32 scalar".to_string(), - )); - } - } - } else { - None - }; - - // Expand scalars to arrays for uniform processing + // Determine array length from any array argument let len = args .iter() - .take(2) .find_map(|arg| match arg { ColumnarValue::Array(a) => Some(a.len()), _ => None, @@ -124,22 +107,56 @@ pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result<ColumnarValue> { let left_arr = as_string_array(&left); let right_arr = as_string_array(&right); - let result: Int32Array = left_arr - .iter() - .zip(right_arr.iter()) - .map(|(l, r)| match (l, r) { - (Some(l), Some(r)) => { - let dist = levenshtein_distance(l, r); - match threshold { - Some(t) if dist > t => Some(-1), - _ => Some(dist), + // Handle the optional threshold argument (scalar or array) + if args.len() == 3 { + let threshold_array = args[2].clone().into_array(len)?; + let threshold_arr = threshold_array + .as_any() + .downcast_ref::<Int32Array>() + .ok_or_else(|| { + DataFusionError::Internal( + "levenshtein threshold must be Int32".to_string(), + ) + })?; + + let result: Int32Array = left_arr + .iter() + .zip(right_arr.iter()) + .enumerate() + .map(|(i, (l, r))| { + // If threshold is NULL for this row, result is NULL + if threshold_arr.is_null(i) { + return None; } - } - _ => None, // NULL propagation - }) - .collect(); + match (l, r) { + (Some(l), Some(r)) => { + let dist = levenshtein_distance(l, r); + let t = 
threshold_arr.value(i); + if dist > t { + Some(-1) + } else { + Some(dist) + } + } + _ => None, // NULL propagation + } + }) + .collect(); - Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } else { + // No threshold: just compute distance + let result: Int32Array = left_arr + .iter() + .zip(right_arr.iter()) + .map(|(l, r)| match (l, r) { + (Some(l), Some(r)) => Some(levenshtein_distance(l, r)), + _ => None, // NULL propagation + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } } #[cfg(test)] @@ -223,8 +240,108 @@ mod tests { let result = spark_levenshtein(&[left, right, threshold]).unwrap(); match result { - ColumnarValue::Scalar(ScalarValue::Int32(None)) => {} // NULL threshold -> NULL - _ => panic!("Expected NULL scalar result for NULL threshold"), + ColumnarValue::Array(arr) => { + let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap(); + assert_eq!(int_arr.len(), 1); + assert!(int_arr.is_null(0)); // NULL threshold -> NULL result + } + _ => panic!("Expected array result with NULL for NULL threshold"), + } + } + + #[test] + fn test_spark_levenshtein_threshold_as_array() { + // threshold is a column (array) with per-row values + let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("kitten"), + Some("frog"), + Some("abc"), + Some("hello"), + ]))); + let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("sitting"), + Some("fog"), + Some("abc"), + Some("world"), + ]))); + // Per-row thresholds: 2, 5, 0, 3 + let threshold = ColumnarValue::Array(Arc::new(Int32Array::from(vec![ + Some(2), + Some(5), + Some(0), + Some(3), + ]))); + + let result = spark_levenshtein(&[left, right, threshold]).unwrap(); + match result { + ColumnarValue::Array(arr) => { + let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap(); + assert_eq!(int_arr.value(0), -1); // kitten->sitting=3 > 2, return -1 + assert_eq!(int_arr.value(1), 1); // frog->fog=1 <= 5, return 1 
+ assert_eq!(int_arr.value(2), 0); // abc->abc=0 <= 0, return 0 + assert_eq!(int_arr.value(3), -1); // hello->world=4 > 3, return -1 + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_spark_levenshtein_threshold_array_with_nulls() { + // threshold array where some values are NULL + let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("abc"), + Some("hello"), + Some("frog"), + ]))); + let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("adc"), + Some("world"), + Some("fog"), + ]))); + let threshold = ColumnarValue::Array(Arc::new(Int32Array::from(vec![ + Some(2), + None, // NULL threshold for this row + Some(0), + ]))); + + let result = spark_levenshtein(&[left, right, threshold]).unwrap(); + match result { + ColumnarValue::Array(arr) => { + let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap(); + assert_eq!(int_arr.value(0), 1); // abc->adc=1 <= 2, return 1 + assert!(int_arr.is_null(1)); // NULL threshold -> NULL + assert_eq!(int_arr.value(2), -1); // frog->fog=1 > 0, return -1 + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_spark_levenshtein_threshold_negative() { + // Negative threshold means distance always exceeds threshold → return -1 + let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("abc"), + Some("abc"), + ]))); + let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("abc"), + Some("adc"), + ]))); + let threshold = ColumnarValue::Array(Arc::new(Int32Array::from(vec![ + Some(-1), + Some(-5), + ]))); + + let result = spark_levenshtein(&[left, right, threshold]).unwrap(); + match result { + ColumnarValue::Array(arr) => { + let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap(); + // dist=0 > -1 is true, so return -1 + assert_eq!(int_arr.value(0), -1); + // dist=1 > -5 is true, so return -1 + assert_eq!(int_arr.value(1), -1); + } + _ => panic!("Expected array result"), } } } diff --git 
a/spark/src/test/resources/sql-tests/expressions/string/levenshtein_threshold.sql b/spark/src/test/resources/sql-tests/expressions/string/levenshtein_threshold.sql index a89435416d..a0756a3403 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/levenshtein_threshold.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/levenshtein_threshold.sql @@ -38,3 +38,23 @@ SELECT levenshtein('abc', 'abc', 0), levenshtein('abc', 'adc', 0), levenshtein(' -- threshold with NULL query SELECT levenshtein('abc', 'adc', NULL), levenshtein(NULL, 'test', 2) + +-- threshold as column +statement +CREATE TABLE test_levenshtein_col(s1 string, s2 string, threshold int) USING parquet + +statement +INSERT INTO test_levenshtein_col VALUES ('kitten', 'sitting', 2), ('frog', 'fog', 5), ('abc', 'abc', 0), ('hello', 'world', 3) + +query +SELECT levenshtein(s1, s2, threshold) FROM test_levenshtein_col + +-- threshold as column with NULLs +statement +CREATE TABLE test_levenshtein_col_nulls(s1 string, s2 string, threshold int) USING parquet + +statement +INSERT INTO test_levenshtein_col_nulls VALUES ('abc', 'adc', 2), ('hello', 'world', NULL), (NULL, 'test', 3), ('frog', 'fog', -1) + +query +SELECT levenshtein(s1, s2, threshold) FROM test_levenshtein_col_nulls diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index 282b59f78d..980cf96337 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -750,4 +750,30 @@ class CometStringExpressionSuite extends CometTestBase { } } + test("levenshtein with threshold as column") { + assume(isSpark35Plus, "levenshtein with threshold requires Spark 3.5+") + val table = "levenshtein_col_threshold_test" + withTable(table) { + sql(s"CREATE TABLE $table(s1 STRING, s2 STRING, threshold INT) USING parquet") + sql( + s"INSERT INTO 
$table VALUES " + + "('kitten', 'sitting', 2), ('frog', 'fog', 5), ('abc', 'abc', 0), ('hello', 'world', 3)") + // threshold as column reference + checkSparkAnswerAndOperator(s"SELECT levenshtein(s1, s2, threshold) FROM $table") + } + } + + test("levenshtein with threshold as column with nulls") { + assume(isSpark35Plus, "levenshtein with threshold requires Spark 3.5+") + val table = "levenshtein_col_threshold_null_test" + withTable(table) { + sql(s"CREATE TABLE $table(s1 STRING, s2 STRING, threshold INT) USING parquet") + sql( + s"INSERT INTO $table VALUES " + + "('abc', 'adc', 2), ('hello', 'world', NULL), (NULL, 'test', 3), ('frog', 'fog', -1)") + // NULL threshold and NULL strings should produce NULL; negative threshold returns -1 + checkSparkAnswerAndOperator(s"SELECT levenshtein(s1, s2, threshold) FROM $table") + } + } + } From 965328589cc4e32cabd1c65a15c9a84759aaf375 Mon Sep 17 00:00:00 2001 From: yusinnmao Date: Wed, 6 May 2026 14:24:20 +0800 Subject: [PATCH 13/15] Fix unused import warning: move ScalarValue to test module --- native/spark-expr/src/string_funcs/levenshtein.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/native/spark-expr/src/string_funcs/levenshtein.rs b/native/spark-expr/src/string_funcs/levenshtein.rs index e6190f7488..c73b2d4507 100644 --- a/native/spark-expr/src/string_funcs/levenshtein.rs +++ b/native/spark-expr/src/string_funcs/levenshtein.rs @@ -21,7 +21,7 @@ //! matching Apache Spark's `levenshtein(str1, str2)` semantics. 
use arrow::array::{as_string_array, Array, ArrayRef, Int32Array}; -use datafusion::common::{DataFusionError, Result, ScalarValue}; +use datafusion::common::{DataFusionError, Result}; use datafusion::physical_plan::ColumnarValue; use std::sync::Arc; @@ -163,6 +163,7 @@ pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result<ColumnarValue> { mod tests { use super::*; use arrow::array::StringArray; + use datafusion::common::ScalarValue; #[test] fn test_levenshtein_basic() { From 33cf341f81ddea9fef3b69beb1636aac8aeab1e4 Mon Sep 17 00:00:00 2001 From: yusinnmao Date: Wed, 6 May 2026 14:33:27 +0800 Subject: [PATCH 14/15] Fix Spotless formatting in threshold column tests --- .../scala/org/apache/comet/CometStringExpressionSuite.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index 980cf96337..fb734bdeb3 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -755,9 +755,8 @@ class CometStringExpressionSuite extends CometTestBase { val table = "levenshtein_col_threshold_test" withTable(table) { sql(s"CREATE TABLE $table(s1 STRING, s2 STRING, threshold INT) USING parquet") - sql( - s"INSERT INTO $table VALUES " + - "('kitten', 'sitting', 2), ('frog', 'fog', 5), ('abc', 'abc', 0), ('hello', 'world', 3)") + sql(s"INSERT INTO $table VALUES " + + "('kitten', 'sitting', 2), ('frog', 'fog', 5), ('abc', 'abc', 0), ('hello', 'world', 3)") // threshold as column reference checkSparkAnswerAndOperator(s"SELECT levenshtein(s1, s2, threshold) FROM $table") } From aaa1d0493c3132e82d01401edaf99b75a99cc108 Mon Sep 17 00:00:00 2001 From: yusinnmao Date: Wed, 6 May 2026 15:26:52 +0800 Subject: [PATCH 15/15] Run cargo fmt --- .../src/string_funcs/levenshtein.rs | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 
deletions(-) diff --git a/native/spark-expr/src/string_funcs/levenshtein.rs b/native/spark-expr/src/string_funcs/levenshtein.rs index c73b2d4507..7387773398 100644 --- a/native/spark-expr/src/string_funcs/levenshtein.rs +++ b/native/spark-expr/src/string_funcs/levenshtein.rs @@ -114,9 +114,7 @@ pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result<ColumnarValue> { .as_any() .downcast_ref::<Int32Array>() .ok_or_else(|| { - DataFusionError::Internal( - "levenshtein threshold must be Int32".to_string(), - ) + DataFusionError::Internal("levenshtein threshold must be Int32".to_string()) })?; let result: Int32Array = left_arr @@ -320,18 +318,11 @@ mod tests { #[test] fn test_spark_levenshtein_threshold_negative() { // Negative threshold means distance always exceeds threshold → return -1 - let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![ - Some("abc"), - Some("abc"), - ]))); - let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![ - Some("abc"), - Some("adc"), - ]))); - let threshold = ColumnarValue::Array(Arc::new(Int32Array::from(vec![ - Some(-1), - Some(-5), - ]))); + let left = + ColumnarValue::Array(Arc::new(StringArray::from(vec![Some("abc"), Some("abc")]))); + let right = + ColumnarValue::Array(Arc::new(StringArray::from(vec![Some("abc"), Some("adc")]))); + let threshold = ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(-1), Some(-5)]))); let result = spark_levenshtein(&[left, right, threshold]).unwrap(); match result {