diff --git a/datafusion/spark/src/function/string/mod.rs b/datafusion/spark/src/function/string/mod.rs
index 64d603cb8bb67..9c90ded5f7e1b 100644
--- a/datafusion/spark/src/function/string/mod.rs
+++ b/datafusion/spark/src/function/string/mod.rs
@@ -27,6 +27,7 @@ pub mod length;
 pub mod like;
 pub mod luhn_check;
 pub mod make_valid_utf8;
+pub mod quote;
 pub mod soundex;
 pub mod space;
 pub mod substring;
@@ -51,6 +52,7 @@ make_udf_function!(base64::SparkUnBase64, unbase64);
 make_udf_function!(soundex::SparkSoundex, soundex);
 make_udf_function!(make_valid_utf8::SparkMakeValidUtf8, make_valid_utf8);
 make_udf_function!(is_valid_utf8::SparkIsValidUtf8, is_valid_utf8);
+make_udf_function!(quote::SparkQuote, quote);
 
 pub mod expr_fn {
     use datafusion_functions::export_functions;
@@ -127,6 +129,11 @@ pub mod expr_fn {
         "Returns the original string if str is a valid UTF-8 string, otherwise returns a new string whose invalid UTF8 byte sequences are replaced using the UNICODE replacement character U+FFFD.",
         str
     ));
+    export_functions!((
+        quote,
+        "Returns str enclosed by single quotes and each instance of single quote in it is preceded by a backslash",
+        str
+    ));
 }
 
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
@@ -147,5 +154,6 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
         soundex(),
         make_valid_utf8(),
         is_valid_utf8(),
+        quote(),
     ]
 }
diff --git a/datafusion/spark/src/function/string/quote.rs b/datafusion/spark/src/function/string/quote.rs
new file mode 100644
index 0000000000000..39ad8bf841764
--- /dev/null
+++ b/datafusion/spark/src/function/string/quote.rs
@@ -0,0 +1,139 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray};
+use arrow::datatypes::DataType;
+use datafusion::logical_expr::{Coercion, ColumnarValue, Signature, TypeSignatureClass};
+use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
+use datafusion_common::types::{NativeType, logical_string};
+use datafusion_common::utils::take_function_args;
+use datafusion_common::{Result, exec_err};
+use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Volatility};
+use datafusion_functions::utils::make_scalar_function;
+
+use std::sync::Arc;
+
+/// Spark-compatible `quote` expression.
+///
+/// Wraps the input string in single quotes and puts a backslash in front of
+/// every single quote inside it. `NULL` inputs stay `NULL`.
+///
+/// <https://spark.apache.org/docs/latest/api/sql/index.html#quote>
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkQuote {
+    signature: Signature,
+}
+
+impl Default for SparkQuote {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkQuote {
+    /// Creates the UDF with a one-argument signature: string inputs are taken
+    /// as-is and any other input type is implicitly coerced to a string.
+    pub fn new() -> Self {
+        let str_coercion = Coercion::new_implicit(
+            TypeSignatureClass::Native(logical_string()),
+            vec![TypeSignatureClass::Any],
+            NativeType::String,
+        );
+        Self {
+            signature: Signature::coercible(vec![str_coercion], Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkQuote {
+    fn name(&self) -> &str {
+        "quote"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// `LargeUtf8` input yields `LargeUtf8`; everything else yields `Utf8`.
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        // Slice pattern instead of `arg_types[0]` so a wrong arity cannot panic.
+        match arg_types {
+            [DataType::LargeUtf8] => Ok(DataType::LargeUtf8),
+            _ => Ok(DataType::Utf8),
+        }
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        make_scalar_function(spark_quote_inner, vec![])(&args.args)
+    }
+}
+
+/// Dispatches on the physical string type of the single argument.
+fn spark_quote_inner(arg: &[ArrayRef]) -> Result<ArrayRef> {
+    let [array] = take_function_args("quote", arg)?;
+    match array.data_type() {
+        DataType::Utf8 => quote_array::<i32>(array),
+        DataType::LargeUtf8 => quote_array::<i64>(array),
+        DataType::Utf8View => quote_view(array),
+        other => {
+            exec_err!("unsupported data type {other:?} for function `quote`")
+        }
+    }
+}
+
+/// Quotes every value of a `Utf8` (`T = i32`) or `LargeUtf8` (`T = i64`) array.
+///
+/// The result uses the same offset width as the input so that a `LargeUtf8`
+/// argument really produces the `LargeUtf8` array promised by `return_type`.
+fn quote_array<T: OffsetSizeTrait>(array: &ArrayRef) -> Result<ArrayRef> {
+    let str_array = as_generic_string_array::<T>(array)?;
+    let result = str_array
+        .iter()
+        .map(|s| s.map(compute_quote))
+        .collect::<GenericStringArray<T>>();
+    Ok(Arc::new(result))
+}
+
+/// Quotes every value of a `Utf8View` array; the result is a plain `Utf8` array.
+fn quote_view(str_view: &ArrayRef) -> Result<ArrayRef> {
+    let str_array = as_string_view_array(str_view)?;
+    let result = str_array
+        .iter()
+        .map(|opt_str| opt_str.map(compute_quote))
+        .collect::<StringArray>();
+    Ok(Arc::new(result) as ArrayRef)
+}
+
+const QUOTE_CHAR: char = '\'';
+const ESCAPE_CHAR: char = '\\';
+
+/// Returns `s` wrapped in single quotes, with each embedded single quote
+/// preceded by a backslash (`it's` becomes `'it\'s'`). No other character,
+/// including the backslash itself, is escaped.
+fn compute_quote(s: &str) -> String {
+    // Room for the two enclosing quotes; escapes grow the buffer on demand.
+    let mut quoted = String::with_capacity(s.len() + 2);
+    quoted.push(QUOTE_CHAR);
+    for c in s.chars() {
+        if c == QUOTE_CHAR {
+            quoted.push(ESCAPE_CHAR);
+        }
+        quoted.push(c);
+    }
+    quoted.push(QUOTE_CHAR);
+    quoted
+}
diff --git a/datafusion/sqllogictest/test_files/spark/string/quote.slt b/datafusion/sqllogictest/test_files/spark/string/quote.slt
new file mode 100644
index 0000000000000..856c8d9e5c516
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/spark/string/quote.slt
@@ -0,0 +1,180 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+
+#   http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+query T
+SELECT quote(arrow_cast(127, 'Int8'));
+----
+'127'
+
+query T
+SELECT quote(arrow_cast(-128, 'Int8'));
+----
+'-128'
+
+query T
+SELECT quote(arrow_cast(32767, 'Int16'));
+----
+'32767'
+
+query T
+SELECT quote(arrow_cast(-32768, 'Int16'));
+----
+'-32768'
+
+query T
+SELECT quote(arrow_cast(2147483647, 'Int32'));
+----
+'2147483647'
+
+query T
+SELECT quote(arrow_cast(-2147483648, 'Int32'));
+----
+'-2147483648'
+
+query T
+SELECT quote(arrow_cast(9223372036854775807, 'Int64'));
+----
+'9223372036854775807'
+
+query T
+SELECT quote(arrow_cast(-9223372036854775808, 'Int64'));
+----
+'-9223372036854775808'
+
+query T
+SELECT quote(arrow_cast(3.14, 'Float32'));
+----
+'3.14'
+
+query T
+SELECT quote(arrow_cast(2.718281828459045, 'Float64'));
+----
+'2.718281828459045'
+
+query T
+SELECT quote(arrow_cast(0, 'UInt8'));
+----
+'0'
+
+query T
+SELECT quote(arrow_cast(255, 'UInt8'));
+----
+'255'
+
+query T
+SELECT quote(arrow_cast(65535, 'UInt16'));
+----
+'65535'
+
+query T
+SELECT quote(arrow_cast(4294967295, 'UInt32'));
+----
+'4294967295'
+
+query T
+SELECT quote(arrow_cast(18446744073709551615, 'UInt64'));
+----
+'18446744073709551615'
+
+query T
+SELECT quote('special chars: !@#$%^&*()');
+----
+'special chars: !@#$%^&*()'
+
+query T
+SELECT quote('tab\tseparated');
+----
+'tab\tseparated'
+
+query T
+SELECT quote('carriage\rreturn');
+----
+'carriage\rreturn'
+
+query T
+SELECT quote('backslash\\test');
+----
+'backslash\\test'
+
+query T
+SELECT quote('quote\"inside\"');
+----
+'quote\"inside\"'
+
+query T
+SELECT quote('mixed\nescape\tchars\r\n');
+----
+'mixed\nescape\tchars\r\n'
+
+query T
+SELECT quote('unicode: 你好, 世界');
+----
+'unicode: 你好, 世界'
+
+query T
+SELECT quote('emoji: 😀🎉❤️🚀');
+----
+'emoji: 😀🎉❤️🚀'
+
+query T
+SELECT quote(arrow_cast('2024-01-15', 'Date32'));
+----
+'2024-01-15'
+
+query T
+SELECT quote(arrow_cast('2024-01-15T12:30:45', 'Timestamp(µs)'));
+----
+'2024-01-15T12:30:45'
+
+query T
+SELECT quote('special\n\t\r');
+----
+'special\n\t\r'
+
+# Embedded single quotes are escaped with a backslash
+query T
+SELECT quote('Don''t');
+----
+'Don\'t'
+
+query T
+SELECT quote('''');
+----
+'\''
+
+# Empty string
+query T
+SELECT quote('');
+----
+''
+
+# LargeUtf8 and Utf8View inputs
+query T
+SELECT quote(arrow_cast('it''s', 'LargeUtf8'));
+----
+'it\'s'
+
+query T
+SELECT quote(arrow_cast('it''s', 'Utf8View'));
+----
+'it\'s'
+
+# NULL input
+query T
+SELECT quote(arrow_cast(NULL, 'Utf8'));
+----
+NULL