diff --git a/datafusion/spark/src/function/string/is_valid_utf8.rs b/datafusion/spark/src/function/string/is_valid_utf8.rs new file mode 100644 index 0000000000000..04958a25317d2 --- /dev/null +++ b/datafusion/spark/src/function/string/is_valid_utf8.rs @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion::logical_expr::{ColumnarValue, Signature, Volatility}; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl}; + +use arrow::array::{Array, ArrayRef, BooleanArray}; +use arrow::buffer::BooleanBuffer; +use datafusion_common::cast::{ + as_binary_array, as_binary_view_array, as_large_binary_array, +}; +use datafusion_common::utils::take_function_args; +use datafusion_functions::utils::make_scalar_function; + +use std::sync::Arc; + +/// Spark-compatible `is_valid_utf8` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkIsValidUtf8 { + signature: Signature, +} + +impl Default for SparkIsValidUtf8 { + fn default() -> Self { + Self::new() + } +} + +impl SparkIsValidUtf8 { + pub fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![ + DataType::Utf8, + DataType::LargeUtf8, + DataType::Utf8View, + DataType::Binary, + DataType::BinaryView, + DataType::LargeBinary, + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkIsValidUtf8 { + fn name(&self) -> &str { + "is_valid_utf8" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + Ok(Arc::new(Field::new(self.name(), DataType::Boolean, true))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_is_valid_utf8_inner, vec![])(&args.args) + } +} + +fn spark_is_valid_utf8_inner(args: &[ArrayRef]) -> Result { + let [array] = take_function_args("is_valid_utf8", args)?; + match array.data_type() { + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => { + Ok(Arc::new(BooleanArray::new( + BooleanBuffer::new_set(array.len()), + array.nulls().cloned(), + ))) + } + DataType::Binary => Ok(Arc::new( + as_binary_array(array)? + .iter() + .map(|x| x.map(|y| str::from_utf8(y).is_ok())) + .collect::(), + )), + DataType::LargeBinary => Ok(Arc::new( + as_large_binary_array(array)? + .iter() + .map(|x| x.map(|y| str::from_utf8(y).is_ok())) + .collect::(), + )), + DataType::BinaryView => Ok(Arc::new( + as_binary_view_array(array)? + .iter() + .map(|x| x.map(|y| str::from_utf8(y).is_ok())) + .collect::(), + )), + data_type => { + internal_err!("is_valid_utf8 does not support: {data_type}") + } + } +} diff --git a/datafusion/spark/src/function/string/mod.rs b/datafusion/spark/src/function/string/mod.rs index e0f6878fdea7b..64d603cb8bb67 100644 --- a/datafusion/spark/src/function/string/mod.rs +++ b/datafusion/spark/src/function/string/mod.rs @@ -22,6 +22,7 @@ pub mod concat; pub mod elt; pub mod format_string; pub mod ilike; +pub mod is_valid_utf8; pub mod length; pub mod like; pub mod luhn_check; @@ -49,6 +50,7 @@ make_udf_function!(substring::SparkSubstring, substring); make_udf_function!(base64::SparkUnBase64, unbase64); make_udf_function!(soundex::SparkSoundex, soundex); make_udf_function!(make_valid_utf8::SparkMakeValidUtf8, make_valid_utf8); +make_udf_function!(is_valid_utf8::SparkIsValidUtf8, is_valid_utf8); pub mod expr_fn { use datafusion_functions::export_functions; @@ -115,6 +117,11 @@ pub mod expr_fn { str )); export_functions!((soundex, "Returns Soundex code of the string.", str)); + export_functions!(( + is_valid_utf8, + "Returns true if str is a valid UTF-8 string, otherwise returns false", + str + )); export_functions!(( make_valid_utf8, "Returns the original string if str is a valid UTF-8 string, otherwise returns a new string whose invalid UTF8 byte sequences are replaced using the UNICODE replacement character U+FFFD.", @@ -139,5 +146,6 @@ pub fn functions() -> Vec> { unbase64(), soundex(), make_valid_utf8(), + is_valid_utf8(), ] } diff --git a/datafusion/sqllogictest/test_files/spark/string/is_valid_utf8.slt b/datafusion/sqllogictest/test_files/spark/string/is_valid_utf8.slt new file mode 100644 index 0000000000000..9b04595334ae1 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/is_valid_utf8.slt @@ -0,0 +1,203 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE TABLE test_is_valid_utf8(value STRING) AS VALUES + (arrow_cast('Hello, world!', 'Utf8')), + (arrow_cast('Spark', 'Utf8')), + (arrow_cast('DataFusion', 'Utf8')), + (arrow_cast('ASCII only 123 !@#', 'Utf8')), + (arrow_cast(NULL, 'Utf8')); + +query B +SELECT is_valid_utf8(value) FROM test_is_valid_utf8; +---- +true +true +true +true +NULL + +query B +SELECT is_valid_utf8(NULL::string); +---- +NULL + +query B +SELECT is_valid_utf8('Hello, world!'::string); +---- +true + +query B +SELECT is_valid_utf8('😀🎉✨'::string); +---- +true + +query B +SELECT is_valid_utf8(''::string); +---- +true + +query B +SELECT is_valid_utf8('ASCII only 123 !@#'::string); +---- +true + +query B +SELECT is_valid_utf8(arrow_cast(x'C2A9', 'Binary')); +---- +true + +query B +SELECT is_valid_utf8(arrow_cast(x'C2AE', 'Binary')); +---- +true + +query B +SELECT is_valid_utf8(arrow_cast(x'E282AC', 'Binary')); +---- +true + +query B +SELECT is_valid_utf8(arrow_cast(x'E284A2', 'Binary')); +---- +true + +query B +SELECT is_valid_utf8(arrow_cast(x'F09F9880', 'Binary')); +---- +true + +query B +SELECT is_valid_utf8(arrow_cast(x'F09F8E89', 'Binary')); +---- +true + +query B +SELECT is_valid_utf8(arrow_cast(x'80', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'BF', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'808080', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'C2', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'E2', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'F0', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'E282', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'C081', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'E08080', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'F0808080', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'FE', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'FF', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'61C262', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'41BF42', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'ED9FBF', 'Binary')); +---- +true + +query B +SELECT is_valid_utf8(arrow_cast(x'EDA080', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'EDBFBF', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'F48FBFBF', 'Binary')); +---- +true + +query B +SELECT is_valid_utf8(arrow_cast(x'F4908080', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'6162C2A963', 'Binary')); +---- +true + +query B +SELECT is_valid_utf8(arrow_cast(x'6162806364', 'Binary')); +---- +false + +query B +SELECT is_valid_utf8(arrow_cast(x'610062', 'Binary')); +---- +true + +query B +SELECT is_valid_utf8(arrow_cast(x'', 'Binary')); +---- +true