From bf6f1f738729dfb4ba484e46ee24ce945cc20a55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=ADa=20Adriana?= Date: Fri, 29 May 2026 09:47:56 +0200 Subject: [PATCH] refactor: wrap HigherOrderUDFImpl in a concrete HigherOrderUDF struct (#22593) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Part of https://github.com/apache/datafusion/issues/21172 `HigherOrderUDF` was the only UDF kind defined as a trait that callers used directly via `Arc<dyn HigherOrderUDF>`. The other UDFs — `ScalarUDF`, `AggregateUDF`, `WindowUDF` — are concrete structs that wrap their respective `*Impl` trait, which makes inherent methods like `with_aliases` ergonomic to call on the function object. With the trait-only setup, adding aliases to an existing higher-order function required an extension trait import or a free helper function. This PR brings higher order functions in line with the other UDFs so the same `with_aliases` pattern works. - Rename the `HigherOrderUDF` trait to `HigherOrderUDFImpl`, matching `ScalarUDFImpl`/`AggregateUDFImpl`. Add a concrete `HigherOrderUDF` struct wrapping `Arc<dyn HigherOrderUDFImpl>`, with the same shape as `ScalarUDF`: new_from_impl, new_from_shared_impl, inner, with_aliases, From, and delegate methods for every trait method. `with_aliases` is backed by a private `AliasedHigherOrderUDFImpl` decorator (same pattern as `AliasedScalarUDFImpl`). - Update `Expr::HigherOrderFunction`, `FunctionRegistry`, the `create_higher_order!` singleton macro, and all consumer files (across several crates) to use `Arc<HigherOrderUDF>` instead of `Arc<dyn HigherOrderUDF>`. Existing impls (`ArrayFilter`, `ArrayTransform`, `ArrayAnyMatch`) now implement `HigherOrderUDFImpl`; their public constructors continue to return `Arc<HigherOrderUDF>` so external call sites need no changes. 
Callers can now write: `array_filter_higher_order_function().with_aliases(["filter"])` exactly like the existing scalar pattern: `make_array_udf().as_ref().clone().with_aliases(["array_construct"])` Covered by existing tests. Yes, this is a user-facing change: any code referring to `Arc<dyn HigherOrderUDF>` needs to become `Arc<HigherOrderUDF>`, and any code that wrote `impl HigherOrderUDF for MyHOF` needs to write `impl HigherOrderUDFImpl for MyHOF`. Constructing a `HigherOrderUDF` from an impl is `HigherOrderUDF::new_from_impl(my_impl)` (or `my_impl.into()`). --- .../examples/sql_ops/frontend.rs | 2 +- .../core/src/bin/print_functions_docs.rs | 2 +- .../src/datasource/listing_table_factory.rs | 2 +- datafusion/core/src/execution/context/mod.rs | 10 +- .../core/src/execution/session_state.rs | 26 +- .../src/execution/session_state_defaults.rs | 2 +- datafusion/core/tests/optimizer/mod.rs | 2 +- .../datasource-arrow/src/file_format.rs | 2 +- datafusion/datasource/src/url.rs | 2 +- datafusion/execution/src/task.rs | 12 +- datafusion/expr/src/expr.rs | 6 +- datafusion/expr/src/higher_order_function.rs | 364 ++++++++++++++++-- datafusion/expr/src/lib.rs | 4 +- datafusion/expr/src/planner.rs | 2 +- datafusion/expr/src/registry.rs | 16 +- .../expr/src/type_coercion/functions.rs | 48 +-- datafusion/expr/src/udf_eq.rs | 4 +- datafusion/ffi/src/session/mod.rs | 4 +- .../functions-nested/src/array_any_match.rs | 6 +- .../functions-nested/src/array_filter.rs | 6 +- .../functions-nested/src/array_transform.rs | 6 +- .../functions-nested/src/lambda_utils.rs | 2 +- datafusion/functions-nested/src/lib.rs | 4 +- .../functions-nested/src/macros_lambda.rs | 6 +- .../optimizer/tests/optimizer_integration.rs | 2 +- .../src/higher_order_function.rs | 32 +- datafusion/proto/src/logical_plan/mod.rs | 4 +- datafusion/session/src/session.rs | 2 +- datafusion/spark/src/lib.rs | 2 +- datafusion/sql/examples/sql.rs | 2 +- datafusion/sql/src/expr/function.rs | 2 +- datafusion/sql/src/expr/mod.rs | 2 +- datafusion/sql/src/unparser/expr.rs | 15 +- 
datafusion/sql/tests/common/mod.rs | 6 +- datafusion/sql/tests/sql_integration.rs | 10 +- 35 files changed, 457 insertions(+), 162 deletions(-) diff --git a/datafusion-examples/examples/sql_ops/frontend.rs b/datafusion-examples/examples/sql_ops/frontend.rs index b34c720a78198..27eb97ee7ab25 100644 --- a/datafusion-examples/examples/sql_ops/frontend.rs +++ b/datafusion-examples/examples/sql_ops/frontend.rs @@ -154,7 +154,7 @@ impl ContextProvider for MyContextProvider { None } - fn get_higher_order_meta(&self, _name: &str) -> Option> { + fn get_higher_order_meta(&self, _name: &str) -> Option> { None } diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index c34865a32d532..86f433ac8e12c 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -287,7 +287,7 @@ impl DocProvider for WindowUDF { } } -impl DocProvider for Arc { +impl DocProvider for Arc { fn get_name(&self) -> String { self.name().to_string() } diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index a1eb7ffb64b7d..349d941cc2bda 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -587,7 +587,7 @@ mod tests { } fn higher_order_functions( &self, - ) -> &HashMap> { + ) -> &HashMap> { unimplemented!() } fn aggregate_functions( diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 67dbe6b7402ed..f414efc64671a 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1626,7 +1626,7 @@ impl SessionContext { /// - `SELECT "my_HIGHER_ORDER_FUNC"(x)` will look for a function named `"my_HIGHER_ORDER_FUNC"` /// /// Any functions registered with the function name or its aliases will be overwritten with this new function - pub fn 
register_higher_order_function(&self, f: Arc) { + pub fn register_higher_order_function(&self, f: Arc) { let mut state = self.state.write(); state.register_higher_order_function(f).ok(); } @@ -2063,7 +2063,7 @@ impl FunctionRegistry for SessionContext { self.state.read().udf(name) } - fn higher_order_function(&self, name: &str) -> Result> { + fn higher_order_function(&self, name: &str) -> Result> { self.state.read().higher_order_function(name) } @@ -2081,8 +2081,8 @@ impl FunctionRegistry for SessionContext { fn register_higher_order_function( &mut self, - function: Arc, - ) -> Result>> { + function: Arc, + ) -> Result>> { self.state.write().register_higher_order_function(function) } @@ -2221,7 +2221,7 @@ pub enum RegisterFunction { /// Window user defined function Window(Arc), /// Higher-order user defined function - HigherOrder(Arc), + HigherOrder(Arc), /// Table user defined function Table(String, Arc), } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index de5e6b97c1af9..d21eeda93a127 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -162,7 +162,7 @@ pub struct SessionState { /// Scalar functions that are registered with the context scalar_functions: HashMap>, /// Higher order functions that are registered with the context - higher_order_functions: HashMap>, + higher_order_functions: HashMap>, /// Aggregate functions registered in the context aggregate_functions: HashMap>, /// Window functions registered in the context @@ -286,7 +286,7 @@ impl Session for SessionState { &self.scalar_functions } - fn higher_order_functions(&self) -> &HashMap> { + fn higher_order_functions(&self) -> &HashMap> { &self.higher_order_functions } @@ -934,7 +934,7 @@ impl SessionState { } /// Return reference to higher_order_functions - pub fn higher_order_functions(&self) -> &HashMap> { + pub fn higher_order_functions(&self) -> &HashMap> { 
&self.higher_order_functions } @@ -1034,7 +1034,7 @@ pub struct SessionStateBuilder { catalog_list: Option>, table_functions: Option>>, scalar_functions: Option>>, - higher_order_functions: Option>>, + higher_order_functions: Option>>, aggregate_functions: Option>>, window_functions: Option>>, extension_types: Option, @@ -1371,7 +1371,7 @@ impl SessionStateBuilder { /// Set the map of [`HigherOrderUDF`]s pub fn with_higher_order_functions( mut self, - higher_order_functions: Vec>, + higher_order_functions: Vec>, ) -> Self { self.higher_order_functions = Some(higher_order_functions); self @@ -1791,9 +1791,7 @@ impl SessionStateBuilder { } /// Returns the current scalar_functions value - pub fn higher_order_functions( - &mut self, - ) -> &mut Option>> { + pub fn higher_order_functions(&mut self) -> &mut Option>> { &mut self.higher_order_functions } @@ -2016,7 +2014,7 @@ impl ContextProvider for SessionContextProvider<'_> { self.state.scalar_functions().get(name).cloned() } - fn get_higher_order_meta(&self, name: &str) -> Option> { + fn get_higher_order_meta(&self, name: &str) -> Option> { self.state.higher_order_functions().get(name).cloned() } @@ -2106,7 +2104,7 @@ impl FunctionRegistry for SessionState { fn higher_order_function( &self, name: &str, - ) -> datafusion_common::Result> { + ) -> datafusion_common::Result> { self.higher_order_functions .get(name) .cloned() @@ -2142,8 +2140,8 @@ impl FunctionRegistry for SessionState { fn register_higher_order_function( &mut self, - function: Arc, - ) -> datafusion_common::Result>> { + function: Arc, + ) -> datafusion_common::Result>> { function.aliases().iter().for_each(|alias| { self.higher_order_functions .insert(alias.clone(), Arc::clone(&function)); @@ -2191,7 +2189,7 @@ impl FunctionRegistry for SessionState { fn deregister_higher_order_function( &mut self, name: &str, - ) -> datafusion_common::Result>> { + ) -> datafusion_common::Result>> { let function = self.higher_order_functions.remove(name); if let 
Some(function) = &function { for alias in function.aliases() { @@ -2677,7 +2675,7 @@ mod tests { self.state.scalar_functions().get(name).cloned() } - fn get_higher_order_meta(&self, name: &str) -> Option> { + fn get_higher_order_meta(&self, name: &str) -> Option> { self.state.higher_order_functions().get(name).cloned() } diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index 5e85c1bbc5e9e..584879cb197b5 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -114,7 +114,7 @@ impl SessionStateDefaults { } /// returns the list of default [`HigherOrderUDF`]s - pub fn default_higher_order_functions() -> Vec> { + pub fn default_higher_order_functions() -> Vec> { #[cfg(feature = "nested_expressions")] return functions_nested::all_default_higher_order_functions(); diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index c8208ef3efa90..0bfe1fac68795 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -216,7 +216,7 @@ impl ContextProvider for MyContextProvider { self.udfs.get(name).cloned() } - fn get_higher_order_meta(&self, _name: &str) -> Option> { + fn get_higher_order_meta(&self, _name: &str) -> Option> { None } diff --git a/datafusion/datasource-arrow/src/file_format.rs b/datafusion/datasource-arrow/src/file_format.rs index 9297486ad66e7..ee3c676a1bc0a 100644 --- a/datafusion/datasource-arrow/src/file_format.rs +++ b/datafusion/datasource-arrow/src/file_format.rs @@ -594,7 +594,7 @@ mod tests { unimplemented!() } - fn higher_order_functions(&self) -> &HashMap> { + fn higher_order_functions(&self) -> &HashMap> { unimplemented!() } diff --git a/datafusion/datasource/src/url.rs b/datafusion/datasource/src/url.rs index 14f9b2af0021d..4bf99fc325e2c 100644 --- a/datafusion/datasource/src/url.rs +++ b/datafusion/datasource/src/url.rs 
@@ -1210,7 +1210,7 @@ mod tests { unimplemented!() } - fn higher_order_functions(&self) -> &HashMap> { + fn higher_order_functions(&self) -> &HashMap> { unimplemented!() } diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index 0de0c937f2211..18825e1d8d19d 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -59,7 +59,7 @@ pub struct TaskContext { /// Scalar functions associated with this task context scalar_functions: HashMap>, /// Higher order functions associated with this task context - higher_order_functions: HashMap>, + higher_order_functions: HashMap>, /// Aggregate functions associated with this task context aggregate_functions: HashMap>, /// Window functions associated with this task context @@ -98,7 +98,7 @@ impl TaskContext { session_id: String, session_config: SessionConfig, scalar_functions: HashMap>, - higher_order_functions: HashMap>, + higher_order_functions: HashMap>, aggregate_functions: HashMap>, window_functions: HashMap>, runtime: Arc, @@ -144,7 +144,7 @@ impl TaskContext { &self.scalar_functions } - pub fn higher_order_functions(&self) -> &HashMap> { + pub fn higher_order_functions(&self) -> &HashMap> { &self.higher_order_functions } @@ -182,7 +182,7 @@ impl FunctionRegistry for TaskContext { }) } - fn higher_order_function(&self, name: &str) -> Result> { + fn higher_order_function(&self, name: &str) -> Result> { let result = self.higher_order_functions.get(name); result.cloned().ok_or_else(|| { @@ -236,8 +236,8 @@ impl FunctionRegistry for TaskContext { fn register_higher_order_function( &mut self, - function: Arc, - ) -> Result>> { + function: Arc, + ) -> Result>> { function.aliases().iter().for_each(|alias| { self.higher_order_functions .insert(alias.clone(), Arc::clone(&function)); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index d6276b944c334..4220bbd097dfc 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -437,14 
+437,14 @@ pub enum Expr { #[derive(Clone, Eq, PartialOrd, Debug)] pub struct HigherOrderFunction { /// The function - pub func: Arc, + pub func: Arc, /// List of expressions to feed to the functions as arguments pub args: Vec, } impl HigherOrderFunction { /// Create a new `HigherOrderFunction` from a [`HigherOrderUDF`] - pub fn new(func: Arc, args: Vec) -> Self { + pub fn new(func: Arc, args: Vec) -> Self { Self { func, args } } @@ -452,7 +452,7 @@ impl HigherOrderFunction { self.func.name() } - /// Invokes the inner function [`HigherOrderUDF::lambda_parameters`] + /// Invokes the inner function [`crate::HigherOrderUDFImpl::lambda_parameters`] /// using the arguments of this invocation. This expression lambda /// variables must be already resolved either by coming from the /// default sql planner or by calling [Expr::resolve_lambda_variables] diff --git a/datafusion/expr/src/higher_order_function.rs b/datafusion/expr/src/higher_order_function.rs index 00522ad97b9e2..413714f498164 100644 --- a/datafusion/expr/src/higher_order_function.rs +++ b/datafusion/expr/src/higher_order_function.rs @@ -22,6 +22,7 @@ use crate::expr::{ schema_name_from_exprs_comma_separated_without_space, }; use crate::type_coercion::functions::value_fields_with_higher_order_udf; +use crate::udf_eq::UdfEq; use crate::{ColumnarValue, Documentation, Expr, ExprSchemable}; use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::{DataType, FieldRef, Schema}; @@ -67,14 +68,14 @@ pub enum HigherOrderTypeSignature { /// function. /// /// If this signature is specified, - /// DataFusion will call [`HigherOrderUDF::coerce_value_types`] to prepare argument types. + /// DataFusion will call [`HigherOrderUDFImpl::coerce_value_types`] to prepare argument types. UserDefined, /// One or more lambdas or arguments with arbitrary types VariadicAny, /// The specified number of lambdas or arguments with arbitrary types. 
Any(usize), /// Exactly the specified arguments in the given order, with arbitrary types. - /// DataFusion will call [`HigherOrderUDF::coerce_value_types`] to prepare the value + /// DataFusion will call [`HigherOrderUDFImpl::coerce_value_types`] to prepare the value /// argument types. Exact(Vec>), } @@ -91,7 +92,7 @@ pub struct HigherOrderSignature { pub type_signature: HigherOrderTypeSignature, /// The volatility of the function. See [Volatility] for more information. pub volatility: Volatility, - /// The max number of times to call [HigherOrderUDF::lambda_parameters] before raising an error. + /// The max number of times to call [HigherOrderUDFImpl::lambda_parameters] before raising an error. /// Used to guard against implementations that causes an infinite loop by endlessly returning /// [LambdaParametersProgress::Partial]. Defaults to 256 pub lambda_parameters_max_iterations: usize, @@ -137,7 +138,7 @@ impl HigherOrderSignature { } /// Exactly the specified arguments in the given order, with arbitrary types. - /// DataFusion will call [`HigherOrderUDF::coerce_value_types`] to prepare the value + /// DataFusion will call [`HigherOrderUDFImpl::coerce_value_types`] to prepare the value /// argument types. 
/// /// # Example @@ -158,13 +159,13 @@ impl HigherOrderSignature { } } -impl PartialEq for dyn HigherOrderUDF { +impl PartialEq for dyn HigherOrderUDFImpl { fn eq(&self, other: &Self) -> bool { self.dyn_eq(other as _) } } -impl PartialOrd for dyn HigherOrderUDF { +impl PartialOrd for dyn HigherOrderUDFImpl { fn partial_cmp(&self, other: &Self) -> Option { let mut cmp = self.name().cmp(other.name()); if cmp == Ordering::Equal { @@ -193,15 +194,15 @@ impl PartialOrd for dyn HigherOrderUDF { } } -impl Eq for dyn HigherOrderUDF {} +impl Eq for dyn HigherOrderUDFImpl {} -impl Hash for dyn HigherOrderUDF { +impl Hash for dyn HigherOrderUDFImpl { fn hash(&self, state: &mut H) { self.dyn_hash(state) } } -/// Arguments passed to [`HigherOrderUDF::invoke_with_args`] when invoking a +/// Arguments passed to [`HigherOrderUDFImpl::invoke_with_args`] when invoking a /// higher order function. #[derive(Debug, Clone)] pub struct HigherOrderFunctionArgs { @@ -210,7 +211,7 @@ pub struct HigherOrderFunctionArgs { /// Field associated with each arg, if it exists /// For lambdas, it will be the field of the result of /// the lambda if evaluated with the parameters - /// returned from [`HigherOrderUDF::lambda_parameters`] + /// returned from [`HigherOrderUDFImpl::lambda_parameters`] pub arg_fields: Vec>, /// The number of rows in record batch being evaluated pub number_rows: usize, @@ -284,7 +285,7 @@ impl LambdaArgument { /// Evaluate this lambda /// `args` should evaluate to the value of each parameter - /// of the correspondent lambda returned in [HigherOrderUDF::lambda_parameters]. + /// of the correspondent lambda returned in [HigherOrderUDFImpl::lambda_parameters]. /// /// `spread_captures` is responsible for transforming the captured column arrays /// so they align with the evaluation batch. 
Captures are snapshotted from the @@ -390,13 +391,13 @@ fn merge_captures_with_variables( /// such as the type of the arguments, any scalar arguments and if the /// arguments can (ever) be null /// -/// See [`HigherOrderUDF::return_field_from_args`] for more information +/// See [`HigherOrderUDFImpl::return_field_from_args`] for more information #[derive(Clone, Debug)] pub struct HigherOrderReturnFieldArgs<'a> { /// The data types of the arguments to the function /// /// If argument `i` to the function is a lambda, it will be the field of the result of the - /// lambda if evaluated with the parameters returned from [`HigherOrderUDF::lambda_parameters`] + /// lambda if evaluated with the parameters returned from [`HigherOrderUDFImpl::lambda_parameters`] /// /// For example, with `array_transform([1], v -> v == 5)` /// this field will be @@ -426,19 +427,19 @@ pub enum ValueOrLambda { } /// Represents a step during the resolution of the parameters of all lambdas of a given -/// higher-order function via [HigherOrderUDF::lambda_parameters]. It's valid that the +/// higher-order function via [HigherOrderUDFImpl::lambda_parameters]. It's valid that the /// fields of a given lambda changes between steps, and is up to the implementation to /// provide during the function evaluation the parameters that matches the fields returned -/// at the [LambdaParametersProgress::Complete] step. See [HigherOrderUDF::lambda_parameters] +/// at the [LambdaParametersProgress::Complete] step. See [HigherOrderUDFImpl::lambda_parameters] /// docs for more details pub enum LambdaParametersProgress { /// The parameters of some lambdas are unknown due to a dependency on another lambda output field /// or are placeholders due to a dependency on it's own output field. It's perfectly valid to /// contain only `Some`'s and not a single `None`, representing lambdas that depends only on itself - /// and not on others. 
[HigherOrderUDF::lambda_parameters] will be called again with the output + /// and not on others. [HigherOrderUDFImpl::lambda_parameters] will be called again with the output /// field of all lambdas with known parameters. Partial(Vec>>), - /// There are no unmet dependencies and all parameters are known, [HigherOrderUDF::lambda_parameters] + /// There are no unmet dependencies and all parameters are known, [HigherOrderUDFImpl::lambda_parameters] /// will not be called again Complete(Vec>), } @@ -448,10 +449,13 @@ pub enum LambdaParametersProgress { /// This trait exposes the full API for implementing user defined functions and /// can be used to implement any function. /// +/// New higher order functions typically implement this trait and are then +/// wrapped in a [`HigherOrderUDF`] for registration with DataFusion. +/// /// See [`array_transform.rs`] for a commented complete implementation /// /// [`array_transform.rs`]: https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/array_transform.rs -pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any { +pub trait HigherOrderUDFImpl: Debug + DynEq + DynHash + Send + Sync + Any { /// Returns this function's name fn name(&self) -> &str; @@ -546,11 +550,11 @@ pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any { /// /// For functions which lambda parameters depends on the output of other lambdas, or on their own lambda, /// this can return [LambdaParametersProgress::Partial] until all dependencies are met. Note that for - /// lambda with cyclic dependencies, you likely want to use [HigherOrderUDF::coerce_values_for_lambdas] too. + /// lambda with cyclic dependencies, you likely want to use [HigherOrderUDFImpl::coerce_values_for_lambdas] too. /// Take as an example a flexible array_reduce with the signature `(arr: [V], initial_value: I, (ACC, V) -> ACC, (ACC) -> O) -> O`. 
/// It has a cyclic dependency in the merge lambda, and a dependency of the finish lambda in the merge lambda, /// and only requires the initial value to be *coercible* to the output of the merge lambda, which is defined by - /// it's [HigherOrderUDF::coerce_values_for_lambdas] implementation. The expression + /// it's [HigherOrderUDFImpl::coerce_values_for_lambdas] implementation. The expression /// /// `array_reduce([1.2, 2.1], 0, (acc, v) -> acc + v + 1.5, v -> v > 5.1)` /// @@ -658,7 +662,7 @@ pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any { ) -> Result; /// Coerce value arguments of a function call to types that the function can evaluate also taking into - /// account the *output type of it's lambdas*. This differs from [HigherOrderUDF::coerce_value_types] + /// account the *output type of it's lambdas*. This differs from [HigherOrderUDFImpl::coerce_value_types] /// that only has access to the type of it's value arguments because it's called before the output type /// of lambdas are known. /// @@ -744,7 +748,7 @@ pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any { /// Setting this to true prevents certain optimizations such as common /// subexpression elimination /// - /// When overriding this function to return `true`, [HigherOrderUDF::conditional_arguments] can also be + /// When overriding this function to return `true`, [HigherOrderUDFImpl::conditional_arguments] can also be /// overridden to report more accurately which arguments are eagerly evaluated and which ones /// lazily. fn short_circuits(&self) -> bool { @@ -768,7 +772,7 @@ pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any { /// Implementations must ensure that the two returned `Vec`s are disjunct, /// and that each argument from `args` is present in one the two `Vec`s. 
/// - /// When overriding this function, [HigherOrderUDF::short_circuits] must + /// When overriding this function, [HigherOrderUDFImpl::short_circuits] must /// be overridden to return `true`. fn conditional_arguments<'a>( &self, @@ -783,7 +787,7 @@ pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any { /// Coerce value arguments of a function call to types that the function can evaluate. /// Note that if you need to coerce values based on the output type of lambdas, you - /// must use [HigherOrderUDF::coerce_values_for_lambdas], as this function is used before + /// must use [HigherOrderUDFImpl::coerce_values_for_lambdas], as this function is used before /// the output type of lambdas are known /// /// See the [type coercion module](crate::type_coercion) @@ -806,7 +810,7 @@ pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any { ) } - /// Returns the documentation for this HigherOrderUDF. + /// Returns the documentation for this function. /// /// Documentation can be accessed programmatically as well as generating /// publicly facing documentation. @@ -815,6 +819,296 @@ pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any { } } +/// Logical representation of a Higher Order User Defined Function. +/// +/// A higher order function takes one or more lambda arguments in addition to +/// regular value arguments. This struct contains the information DataFusion +/// needs to plan and invoke functions you supply such as name, type signature, +/// return type, and actual implementation. 
+#[derive(Debug, Clone)] +pub struct HigherOrderUDF { + inner: Arc, +} + +impl PartialEq for HigherOrderUDF { + fn eq(&self, other: &Self) -> bool { + self.inner.as_ref().dyn_eq(other.inner.as_ref()) + } +} + +impl PartialOrd for HigherOrderUDF { + fn partial_cmp(&self, other: &Self) -> Option { + let mut cmp = self.name().cmp(other.name()); + if cmp == Ordering::Equal { + cmp = self.signature().partial_cmp(other.signature())?; + } + if cmp == Ordering::Equal { + cmp = self.aliases().partial_cmp(other.aliases())?; + } + // Contract for PartialOrd and PartialEq consistency requires that + // a == b if and only if partial_cmp(a, b) == Some(Equal). + if cmp == Ordering::Equal && self != other { + // Functions may have other properties besides name and signature + // that differentiate two instances (e.g. type, or arbitrary parameters). + // We cannot return Some(Equal) in such case. + return None; + } + debug_assert!( + cmp == Ordering::Equal || self != other, + "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \ + The functions compare as equal, but they are not equal based on general properties that \ + the PartialOrd implementation observes,", + self.name(), + other.name() + ); + Some(cmp) + } +} + +impl Eq for HigherOrderUDF {} + +impl Hash for HigherOrderUDF { + fn hash(&self, state: &mut H) { + self.inner.dyn_hash(state) + } +} + +impl HigherOrderUDF { + /// Create a new `HigherOrderUDF` from a [`HigherOrderUDFImpl`] trait object. + /// + /// Note this is the same as using the `From` impl (`HigherOrderUDF::from`). + pub fn new_from_impl(fun: F) -> HigherOrderUDF + where + F: HigherOrderUDFImpl + 'static, + { + Self::new_from_shared_impl(Arc::new(fun)) + } + + /// Create a new `HigherOrderUDF` from a shared [`HigherOrderUDFImpl`] trait object. + pub fn new_from_shared_impl(fun: Arc) -> HigherOrderUDF { + Self { inner: fun } + } + + /// Return the underlying [`HigherOrderUDFImpl`] trait object for this function. 
+ pub fn inner(&self) -> &Arc { + &self.inner + } + + /// Adds additional names that can be used to invoke this function, in + /// addition to `name`. + /// + /// If you implement [`HigherOrderUDFImpl`] directly you should return aliases + /// directly. + pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { + Self::new_from_impl(AliasedHigherOrderUDFImpl::new( + Arc::clone(&self.inner), + aliases, + )) + } + + /// Returns this function's name. + /// + /// See [`HigherOrderUDFImpl::name`] for more details. + pub fn name(&self) -> &str { + self.inner.name() + } + + /// Returns the aliases for this function. + /// + /// See [`HigherOrderUDF::with_aliases`] for more details. + pub fn aliases(&self) -> &[String] { + self.inner.aliases() + } + + /// Returns this function's schema_name. + /// + /// See [`HigherOrderUDFImpl::schema_name`] for more details. + pub fn schema_name(&self, args: &[Expr]) -> Result { + self.inner.schema_name(args) + } + + /// Returns this function's [`HigherOrderSignature`]. + pub fn signature(&self) -> &HigherOrderSignature { + self.inner.signature() + } + + /// Returns the parameters of all lambdas of this function for the current step. + /// + /// See [`HigherOrderUDFImpl::lambda_parameters`] for more details. + pub fn lambda_parameters( + &self, + step: usize, + fields: &[ValueOrLambda>], + ) -> Result { + self.inner.lambda_parameters(step, fields) + } + + /// Coerce value arguments based on lambda output types. + /// + /// See [`HigherOrderUDFImpl::coerce_values_for_lambdas`] for more details. + pub fn coerce_values_for_lambdas( + &self, + fields: &[ValueOrLambda], + ) -> Result>> { + self.inner.coerce_values_for_lambdas(fields) + } + + /// Returns the return field of the function given its arguments. + /// + /// See [`HigherOrderUDFImpl::return_field_from_args`] for more details. 
+ pub fn return_field_from_args( + &self, + args: HigherOrderReturnFieldArgs, + ) -> Result { + self.inner.return_field_from_args(args) + } + + /// Whether List or LargeList arguments should have non-empty null sublists + /// cleaned before invoking this function. + pub fn clear_null_values(&self) -> bool { + self.inner.clear_null_values() + } + + /// Invoke the function returning the appropriate result. + /// + /// See [`HigherOrderUDFImpl::invoke_with_args`] for more details. + pub fn invoke_with_args( + &self, + args: HigherOrderFunctionArgs, + ) -> Result { + self.inner.invoke_with_args(args) + } + + /// Returns true if some of this function's subexpressions may not be evaluated. + /// + /// See [`HigherOrderUDFImpl::short_circuits`] for more details. + pub fn short_circuits(&self) -> bool { + self.inner.short_circuits() + } + + /// Returns which arguments are evaluated eagerly vs lazily. + /// + /// See [`HigherOrderUDFImpl::conditional_arguments`] for more details. + pub fn conditional_arguments<'a>( + &self, + args: &'a [Expr], + ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> { + self.inner.conditional_arguments(args) + } + + /// Coerce value arguments of a function call to types that the function can evaluate. + /// + /// See [`HigherOrderUDFImpl::coerce_value_types`] for more details. + pub fn coerce_value_types(&self, arg_types: &[DataType]) -> Result> { + self.inner.coerce_value_types(arg_types) + } + + /// Returns the documentation for this function, if any. + pub fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() + } +} + +impl From for HigherOrderUDF +where + F: HigherOrderUDFImpl + 'static, +{ + fn from(fun: F) -> Self { + Self::new_from_impl(fun) + } +} + +/// `HigherOrderUDFImpl` that adds aliases to the underlying function. It is +/// better to implement [`HigherOrderUDFImpl`], which supports aliases, directly +/// if possible. 
+#[derive(Debug, PartialEq, Eq, Hash)] +struct AliasedHigherOrderUDFImpl { + inner: UdfEq>, + aliases: Vec, +} + +impl AliasedHigherOrderUDFImpl { + fn new( + inner: Arc, + new_aliases: impl IntoIterator, + ) -> Self { + let mut aliases = inner.aliases().to_vec(); + aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); + Self { + inner: inner.into(), + aliases, + } + } +} + +#[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method +impl HigherOrderUDFImpl for AliasedHigherOrderUDFImpl { + fn name(&self) -> &str { + self.inner.name() + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn schema_name(&self, args: &[Expr]) -> Result { + self.inner.schema_name(args) + } + + fn signature(&self) -> &HigherOrderSignature { + self.inner.signature() + } + + fn lambda_parameters( + &self, + step: usize, + fields: &[ValueOrLambda>], + ) -> Result { + self.inner.lambda_parameters(step, fields) + } + + fn coerce_values_for_lambdas( + &self, + fields: &[ValueOrLambda], + ) -> Result>> { + self.inner.coerce_values_for_lambdas(fields) + } + + fn return_field_from_args( + &self, + args: HigherOrderReturnFieldArgs, + ) -> Result { + self.inner.return_field_from_args(args) + } + + fn clear_null_values(&self) -> bool { + self.inner.clear_null_values() + } + + fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result { + self.inner.invoke_with_args(args) + } + + fn short_circuits(&self) -> bool { + self.inner.short_circuits() + } + + fn conditional_arguments<'a>( + &self, + args: &'a [Expr], + ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> { + self.inner.conditional_arguments(args) + } + + fn coerce_value_types(&self, arg_types: &[DataType]) -> Result> { + self.inner.coerce_value_types(arg_types) + } + + fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() + } +} + pub(crate) fn resolve_lambda_variables( expr: Expr, schema: &DFSchema, @@ -854,7 +1148,7 @@ pub(crate) fn 
resolve_lambda_variables( } fn resolve_higher_order_function( - func: Arc, + func: Arc, args: Vec, schema: &DFSchema, // a map of lambda variable name => a never empty stack of fields [ [..shadowed], in_scope ] @@ -1083,8 +1377,8 @@ mod tests { use datafusion_expr_common::signature::Volatility; use crate::{ - Expr, HigherOrderSignature, HigherOrderUDF, LambdaParametersProgress, - ValueOrLambda, col, + Expr, HigherOrderSignature, HigherOrderUDF, HigherOrderUDFImpl, + LambdaParametersProgress, ValueOrLambda, col, expr::{HigherOrderFunction, LambdaVariable}, lambda, lambda_var, lit, }; @@ -1095,7 +1389,7 @@ mod tests { field: &'static str, signature: HigherOrderSignature, } - impl HigherOrderUDF for TestHigherOrderUDF { + impl HigherOrderUDFImpl for TestHigherOrderUDF { fn name(&self) -> &str { self.name } @@ -1158,12 +1452,12 @@ mod tests { assert_eq!(b.partial_cmp(&o), Some(Ordering::Less)); } - fn test_func(name: &'static str, parameter: &'static str) -> Arc { - Arc::new(TestHigherOrderUDF { + fn test_func(name: &'static str, parameter: &'static str) -> Arc { + Arc::new(HigherOrderUDF::new_from_impl(TestHigherOrderUDF { name, field: parameter, signature: HigherOrderSignature::variadic_any(Volatility::Immutable), - }) + })) } fn hash(value: &T) -> u64 { @@ -1177,7 +1471,7 @@ mod tests { signature: HigherOrderSignature, } - impl HigherOrderUDF for MockArrayReduce { + impl HigherOrderUDFImpl for MockArrayReduce { fn name(&self) -> &str { "array_reduce" } @@ -1274,9 +1568,9 @@ mod tests { )])) .unwrap(); - let func = Arc::new(MockArrayReduce { + let func = Arc::new(HigherOrderUDF::new_from_impl(MockArrayReduce { signature: HigherOrderSignature::variadic_any(Volatility::Immutable), - }) as _; + })); /* array_reduce( diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index da7f20783bd06..b52a784df931a 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -117,8 +117,8 @@ pub use function::{ }; pub use higher_order_function::{ 
HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, HigherOrderSignature, - HigherOrderTypeSignature, HigherOrderUDF, LambdaArgument, LambdaParametersProgress, - ValueOrLambda, + HigherOrderTypeSignature, HigherOrderUDF, HigherOrderUDFImpl, LambdaArgument, + LambdaParametersProgress, ValueOrLambda, }; pub use literal::{ Literal, TimestampLiteral, lit, lit_timestamp_nano, lit_with_metadata, diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index d69f4ac5fe23f..00f197357295d 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -104,7 +104,7 @@ pub trait ContextProvider { fn get_function_meta(&self, name: &str) -> Option>; /// Return the higher order function with a given name, if any - fn get_higher_order_meta(&self, name: &str) -> Option>; + fn get_higher_order_meta(&self, name: &str) -> Option>; /// Return the aggregate function with a given name, if any fn get_aggregate_meta(&self, name: &str) -> Option>; diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index f03cc5936c6ed..4b9744d9573b6 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -56,7 +56,7 @@ pub trait FunctionRegistry { /// Returns a reference to the user defined higher order function named /// `name`. - fn higher_order_function(&self, name: &str) -> Result>; + fn higher_order_function(&self, name: &str) -> Result>; /// Returns a reference to the user defined aggregate function (udaf) named /// `name`. @@ -81,8 +81,8 @@ pub trait FunctionRegistry { /// for example if the registry is read only. 
fn register_higher_order_function( &mut self, - _function: Arc, - ) -> Result>> { + _function: Arc, + ) -> Result>> { not_impl_err!("Registering HigherOrderUDF") } /// Registers a new [`AggregateUDF`], returning any previously registered @@ -122,7 +122,7 @@ pub trait FunctionRegistry { fn deregister_higher_order_function( &mut self, _name: &str, - ) -> Result>> { + ) -> Result>> { not_impl_err!("Deregistering HigherOrderUDF") } @@ -198,7 +198,7 @@ pub struct MemoryFunctionRegistry { /// Window Functions udwfs: HashMap>, /// Higher Order Functions - higher_order_functions: HashMap>, + higher_order_functions: HashMap>, } impl MemoryFunctionRegistry { @@ -219,7 +219,7 @@ impl FunctionRegistry for MemoryFunctionRegistry { .ok_or_else(|| plan_datafusion_err!("Function {name} not found")) } - fn higher_order_function(&self, name: &str) -> Result> { + fn higher_order_function(&self, name: &str) -> Result> { self.higher_order_functions .get(name) .cloned() @@ -245,8 +245,8 @@ impl FunctionRegistry for MemoryFunctionRegistry { } fn register_higher_order_function( &mut self, - function: Arc, - ) -> Result>> { + function: Arc, + ) -> Result>> { Ok(self .higher_order_functions .insert(function.name().into(), function)) diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 1f625e33d31ef..c3802590bcacc 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -158,7 +158,7 @@ pub fn fields_with_udf( /// argument must be coerced to match `signature`. /// For lambda arguments, returns a clone of the associated data /// -/// Note this does not invokes [HigherOrderUDF::coerce_values_for_lambdas]. +/// Note this does not invokes [crate::HigherOrderUDFImpl::coerce_values_for_lambdas]. 
/// If that's required, use [value_fields_with_higher_order_udf_and_lambdas] /// instead /// @@ -166,7 +166,7 @@ pub fn fields_with_udf( /// [`type_coercion`](crate::type_coercion) module. pub fn value_fields_with_higher_order_udf( current_fields: &[ValueOrLambda], - func: &dyn HigherOrderUDF, + func: &HigherOrderUDF, ) -> Result>> { match func.signature().type_signature { HigherOrderTypeSignature::UserDefined => { @@ -306,7 +306,7 @@ pub fn value_fields_with_higher_order_udf( } /// Performs type coercion for higher order function arguments, -/// including those defined by [HigherOrderUDF::coerce_values_for_lambdas], +/// including those defined by [crate::HigherOrderUDFImpl::coerce_values_for_lambdas], /// if it returns `Some(...)` instead of the default `None`. Note that /// compared to [value_fields_with_higher_order_udf], this function requires /// the [ValueOrLambda::Lambda] variant to contain the output field of the lambda. @@ -319,7 +319,7 @@ pub fn value_fields_with_higher_order_udf( /// [`type_coercion`](crate::type_coercion) module. 
pub fn value_fields_with_higher_order_udf_and_lambdas( current_fields: &[ValueOrLambda], - func: &dyn HigherOrderUDF, + func: &HigherOrderUDF, ) -> Result>> { let mut new_fields = value_fields_with_higher_order_udf(current_fields, func)?; @@ -1169,7 +1169,7 @@ fn coerced_from<'a>( mod tests { use crate::{ HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, HigherOrderSignature, - Volatility, + HigherOrderUDFImpl, Volatility, }; use super::*; @@ -1901,7 +1901,7 @@ mod tests { coerced_value_types: Vec, } - impl HigherOrderUDF for MockHigherOrderUDF { + impl HigherOrderUDFImpl for MockHigherOrderUDF { fn name(&self) -> &str { "mock_higher_order_function" } @@ -1962,10 +1962,10 @@ mod tests { #[test] fn test_higher_order_function_user_defined_type_coercion() { - let fun = MockHigherOrderUDF { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { signature: HigherOrderSignature::user_defined(Volatility::Immutable), coerced_value_types: vec![DataType::new_large_list(DataType::Int32, false)], - }; + }); let new_fields = value_fields_with_higher_order_udf( &[ @@ -1996,10 +1996,10 @@ mod tests { #[test] fn test_higher_order_function_coerce_values_for_lambdas() { - let fun = MockHigherOrderUDF { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { signature: HigherOrderSignature::variadic_any(Volatility::Immutable), coerced_value_types: vec![], - }; + }); let new_fields = value_fields_with_higher_order_udf_and_lambdas( &[ @@ -2032,10 +2032,10 @@ mod tests { #[test] fn test_higher_order_function_user_defined_type_coercion_bad_args() { - let fun = MockHigherOrderUDF { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { signature: HigherOrderSignature::user_defined(Volatility::Immutable), coerced_value_types: vec![DataType::Int32], - }; + }); let err = value_fields_with_higher_order_udf::<()>(&[], &fun).unwrap_err(); @@ -2047,10 +2047,10 @@ mod tests { #[test] fn test_higher_order_function_faulty_user_defined_type_coercion() { - let fun = 
MockHigherOrderUDF { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { signature: HigherOrderSignature::user_defined(Volatility::Immutable), coerced_value_types: vec![DataType::Int32, DataType::Int32], - }; + }); let err = value_fields_with_higher_order_udf::<()>( &[ValueOrLambda::Value(Arc::new(Field::new( @@ -2070,10 +2070,10 @@ mod tests { #[test] fn test_higher_order_function_any_signature() { - let fun = MockHigherOrderUDF { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { signature: HigherOrderSignature::any(1, Volatility::Immutable), coerced_value_types: vec![], - }; + }); let new_fields = value_fields_with_higher_order_udf(&[ValueOrLambda::Lambda(())], &fun) @@ -2085,10 +2085,10 @@ mod tests { #[test] fn test_higher_order_function_any_signature_bad_args() { - let fun = MockHigherOrderUDF { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { signature: HigherOrderSignature::any(1, Volatility::Immutable), coerced_value_types: vec![], - }; + }); let err = value_fields_with_higher_order_udf::<()>(&[], &fun).unwrap_err(); @@ -2100,13 +2100,13 @@ mod tests { #[test] fn test_higher_order_function_exact_signature() { - let fun = MockHigherOrderUDF { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { signature: HigherOrderSignature::exact( vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())], Volatility::Immutable, ), coerced_value_types: vec![DataType::new_large_list(DataType::Int32, false)], - }; + }); let new_fields = value_fields_with_higher_order_udf( &[ @@ -2137,13 +2137,13 @@ mod tests { #[test] fn test_higher_order_function_exact_signature_wrong_value_count() { - let fun = MockHigherOrderUDF { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { signature: HigherOrderSignature::exact( vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())], Volatility::Immutable, ), coerced_value_types: vec![], - }; + }); let err = value_fields_with_higher_order_udf::<()>( &[ValueOrLambda::Lambda(()), 
ValueOrLambda::Lambda(())], @@ -2159,13 +2159,13 @@ mod tests { #[test] fn test_higher_order_function_exact_signature_wrong_lambda_count() { - let fun = MockHigherOrderUDF { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { signature: HigherOrderSignature::exact( vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())], Volatility::Immutable, ), coerced_value_types: vec![], - }; + }); let err = value_fields_with_higher_order_udf::<()>( &[ diff --git a/datafusion/expr/src/udf_eq.rs b/datafusion/expr/src/udf_eq.rs index 5fb0266aef5dd..8766b483137f4 100644 --- a/datafusion/expr/src/udf_eq.rs +++ b/datafusion/expr/src/udf_eq.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::{AggregateUDFImpl, HigherOrderUDF, ScalarUDFImpl, WindowUDFImpl}; +use crate::{AggregateUDFImpl, HigherOrderUDFImpl, ScalarUDFImpl, WindowUDFImpl}; use std::any::Any; use std::fmt::Debug; use std::hash::{DefaultHasher, Hash, Hasher}; @@ -94,7 +94,7 @@ impl UdfPointer for Arc { } } -impl UdfPointer for Arc { +impl UdfPointer for Arc { fn equals(&self, other: &Self::Target) -> bool { self.as_ref().dyn_eq(other) } diff --git a/datafusion/ffi/src/session/mod.rs b/datafusion/ffi/src/session/mod.rs index dfc9d1c7dfebd..6ddb879feb217 100644 --- a/datafusion/ffi/src/session/mod.rs +++ b/datafusion/ffi/src/session/mod.rs @@ -378,7 +378,7 @@ pub struct ForeignSession { session: FFI_SessionRef, config: SessionConfig, scalar_functions: HashMap>, - higher_order_functions: HashMap>, + higher_order_functions: HashMap>, aggregate_functions: HashMap>, window_functions: HashMap>, extension_types: ExtensionTypeRegistryRef, @@ -590,7 +590,7 @@ impl Session for ForeignSession { &self.scalar_functions } - fn higher_order_functions(&self) -> &HashMap> { + fn higher_order_functions(&self) -> &HashMap> { &self.higher_order_functions } diff --git a/datafusion/functions-nested/src/array_any_match.rs 
b/datafusion/functions-nested/src/array_any_match.rs index 3ce43a23c2124..a14cc8c260ecb 100644 --- a/datafusion/functions-nested/src/array_any_match.rs +++ b/datafusion/functions-nested/src/array_any_match.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`HigherOrderUDF`] definitions for array_any_match function. +//! [`datafusion_expr::HigherOrderUDF`] definitions for array_any_match function. use arrow::{ array::{Array, AsArray, BooleanArray, BooleanBuilder, new_null_array}, @@ -31,7 +31,7 @@ use datafusion_common::{ }; use datafusion_expr::{ ColumnarValue, Documentation, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, - HigherOrderSignature, HigherOrderUDF, LambdaParametersProgress, ValueOrLambda, + HigherOrderSignature, HigherOrderUDFImpl, LambdaParametersProgress, ValueOrLambda, Volatility, }; use datafusion_macros::user_doc; @@ -106,7 +106,7 @@ fn any_match_for_range( if any_null { None } else { Some(false) } } -impl HigherOrderUDF for ArrayAnyMatch { +impl HigherOrderUDFImpl for ArrayAnyMatch { fn name(&self) -> &str { "array_any_match" } diff --git a/datafusion/functions-nested/src/array_filter.rs b/datafusion/functions-nested/src/array_filter.rs index f8b7fc35404a8..a1fa8268a31a9 100644 --- a/datafusion/functions-nested/src/array_filter.rs +++ b/datafusion/functions-nested/src/array_filter.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`HigherOrderUDF`] definitions for array_filter function. +//! [`datafusion_expr::HigherOrderUDF`] definitions for array_filter function. 
use arrow::{ array::{ @@ -32,7 +32,7 @@ use datafusion_common::{ }; use datafusion_expr::{ ColumnarValue, Documentation, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, - HigherOrderSignature, HigherOrderUDF, LambdaParametersProgress, ValueOrLambda, + HigherOrderSignature, HigherOrderUDFImpl, LambdaParametersProgress, ValueOrLambda, Volatility, }; use datafusion_macros::user_doc; @@ -96,7 +96,7 @@ impl ArrayFilter { } } -impl HigherOrderUDF for ArrayFilter { +impl HigherOrderUDFImpl for ArrayFilter { fn name(&self) -> &str { "array_filter" } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index a0415749f45e2..1c1c5077344e1 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`HigherOrderUDF`] definitions for array_transform function. +//! [`datafusion_expr::HigherOrderUDF`] definitions for array_transform function. 
use arrow::{ array::{Array, ArrayRef, AsArray, LargeListArray, ListArray}, @@ -28,7 +28,7 @@ use datafusion_common::{ }; use datafusion_expr::{ ColumnarValue, Documentation, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, - HigherOrderSignature, HigherOrderUDF, LambdaParametersProgress, ValueOrLambda, + HigherOrderSignature, HigherOrderUDFImpl, LambdaParametersProgress, ValueOrLambda, Volatility, }; use datafusion_macros::user_doc; @@ -89,7 +89,7 @@ impl ArrayTransform { } } -impl HigherOrderUDF for ArrayTransform { +impl HigherOrderUDFImpl for ArrayTransform { fn name(&self) -> &str { "array_transform" } diff --git a/datafusion/functions-nested/src/lambda_utils.rs b/datafusion/functions-nested/src/lambda_utils.rs index cb8682d4bd18b..0f208ce5d26b2 100644 --- a/datafusion/functions-nested/src/lambda_utils.rs +++ b/datafusion/functions-nested/src/lambda_utils.rs @@ -153,7 +153,7 @@ pub(crate) mod test_utils { } pub(crate) fn eval_hof_on_i32_list( - func: Arc, + func: Arc, list: impl Array + Clone + 'static, lambda_body: Expr, ) -> Result { diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 1e6dc68cb23ae..45918735e3d7b 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -201,7 +201,7 @@ pub fn all_default_nested_functions() -> Vec> { ] } -pub fn all_default_higher_order_functions() -> Vec> { +pub fn all_default_higher_order_functions() -> Vec> { vec![ array_any_match::array_any_match_higher_order_function(), array_filter::array_filter_higher_order_function(), @@ -220,7 +220,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { Ok(()) as Result<()> })?; - let functions: Vec> = all_default_higher_order_functions(); + let functions: Vec> = all_default_higher_order_functions(); functions.into_iter().try_for_each(|function| { let existing_function = registry.register_higher_order_function(function)?; if let Some(existing_function) = 
existing_function { diff --git a/datafusion/functions-nested/src/macros_lambda.rs b/datafusion/functions-nested/src/macros_lambda.rs index 8c15d8aed13b6..c8fe670844b2d 100644 --- a/datafusion/functions-nested/src/macros_lambda.rs +++ b/datafusion/functions-nested/src/macros_lambda.rs @@ -95,11 +95,11 @@ macro_rules! create_higher_order { ($UDF:ident, $HIGHER_ORDER_UDF_FN:ident, $CTOR:path) => { #[doc = concat!("HigherOrderFunction that returns a [`HigherOrderUDF`](datafusion_expr::HigherOrderUDF) for ")] #[doc = stringify!($UDF)] - pub fn $HIGHER_ORDER_UDF_FN() -> std::sync::Arc { + pub fn $HIGHER_ORDER_UDF_FN() -> std::sync::Arc { // Singleton instance of [`$UDF`], ensures the UDF is only created once - static INSTANCE: std::sync::LazyLock> = + static INSTANCE: std::sync::LazyLock> = std::sync::LazyLock::new(|| { - std::sync::Arc::new($CTOR()) + std::sync::Arc::new(datafusion_expr::HigherOrderUDF::new_from_impl($CTOR())) }); std::sync::Arc::clone(&INSTANCE) } diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index e61e6467930e6..e1a394ad975b3 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -882,7 +882,7 @@ impl ContextProvider for MyContextProvider { fn get_higher_order_meta( &self, _name: &str, - ) -> Option> { + ) -> Option> { None } diff --git a/datafusion/physical-expr/src/higher_order_function.rs b/datafusion/physical-expr/src/higher_order_function.rs index 801e69ea8fb69..7390eb33a0922 100644 --- a/datafusion/physical-expr/src/higher_order_function.rs +++ b/datafusion/physical-expr/src/higher_order_function.rs @@ -69,7 +69,7 @@ enum ArgSlot { /// Physical expression of a higher order function pub struct HigherOrderFunctionExpr { /// A shared instance of the higher-order function - fun: Arc, + fun: Arc, /// The name of the higher-order function name: String, /// List of expressions to feed to the function as 
arguments @@ -125,7 +125,7 @@ impl HigherOrderFunctionExpr { /// Note that lambda arguments must be present directly in args as [LambdaExpr], /// and not as a wrapped child of any arg pub fn try_new_with_schema( - fun: Arc, + fun: Arc, args: Vec>, schema: &Schema, config_options: Arc, @@ -172,7 +172,7 @@ impl HigherOrderFunctionExpr { } /// Get the higher order function implementation - pub fn fun(&self) -> &dyn HigherOrderUDF { + pub fn fun(&self) -> &HigherOrderUDF { self.fun.as_ref() } @@ -200,7 +200,7 @@ impl HigherOrderFunctionExpr { } /// Resolve every lambda's parameter list. Returns an empty `Vec` when - /// there are no lambdas, avoiding the [`HigherOrderUDF::lambda_parameters`] + /// there are no lambdas, avoiding the [`datafusion_expr::HigherOrderUDFImpl::lambda_parameters`] /// virtual call entirely. fn resolve_lambda_parameters( &self, @@ -519,7 +519,7 @@ mod tests { use datafusion_common::Result; use datafusion_common::assert_contains; use datafusion_expr::{ - HigherOrderFunctionArgs, HigherOrderSignature, HigherOrderUDF, + HigherOrderFunctionArgs, HigherOrderSignature, HigherOrderUDF, HigherOrderUDFImpl, }; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -531,7 +531,7 @@ mod tests { signature: HigherOrderSignature, } - impl HigherOrderUDF for MockHigherOrderUDF { + impl HigherOrderUDFImpl for MockHigherOrderUDF { fn name(&self) -> &str { "mock_function" } @@ -578,14 +578,14 @@ mod tests { #[test] fn test_higher_order_function_volatile_node() { // Create a volatile UDF - let volatile_udf = Arc::new(MockHigherOrderUDF { + let volatile_udf = Arc::new(HigherOrderUDF::new_from_impl(MockHigherOrderUDF { signature: HigherOrderSignature::variadic_any(Volatility::Volatile), - }); + })); // Create a non-volatile UDF - let stable_udf = Arc::new(MockHigherOrderUDF { + let stable_udf = Arc::new(HigherOrderUDF::new_from_impl(MockHigherOrderUDF { signature: 
HigherOrderSignature::variadic_any(Volatility::Stable), - }); + })); let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); let args = vec![Arc::new(Column::new("a", 0)) as Arc]; @@ -620,9 +620,9 @@ mod tests { #[test] fn test_higher_order_function_wrapped_lambda() { - let fun = Arc::new(MockHigherOrderUDF { + let fun = Arc::new(HigherOrderUDF::new_from_impl(MockHigherOrderUDF { signature: HigherOrderSignature::variadic_any(Volatility::Stable), - }); + })); let expected = ScalarValue::Int32(Some(42)); @@ -657,9 +657,9 @@ mod tests { #[test] fn test_higher_order_function_badly_wrapped_lambda() { - let fun = Arc::new(MockHigherOrderUDF { + let fun = Arc::new(HigherOrderUDF::new_from_impl(MockHigherOrderUDF { signature: HigherOrderSignature::variadic_any(Volatility::Stable), - }); + })); let hof = HigherOrderFunctionExpr::try_new_with_schema( fun, @@ -694,9 +694,9 @@ mod tests { #[test] fn test_higher_order_function_unexpected_lambda() { - let fun = Arc::new(MockHigherOrderUDF { + let fun = Arc::new(HigherOrderUDF::new_from_impl(MockHigherOrderUDF { signature: HigherOrderSignature::variadic_any(Volatility::Stable), - }); + })); let hof = HigherOrderFunctionExpr::try_new_with_schema( fun, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 8228e8e6f2ff0..7f3843d802691 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -162,7 +162,7 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync + std::any::Any { &self, name: &str, _buf: &[u8], - ) -> Result> { + ) -> Result> { not_impl_err!( "LogicalExtensionCodec is not provided for higher order function {name}" ) @@ -170,7 +170,7 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync + std::any::Any { fn try_encode_higher_order_function( &self, - _node: &dyn HigherOrderUDF, + _node: &HigherOrderUDF, _buf: &mut Vec, ) -> Result<()> { Ok(()) diff --git a/datafusion/session/src/session.rs 
b/datafusion/session/src/session.rs index 82dda6655f8e2..15ad543cf0ffb 100644 --- a/datafusion/session/src/session.rs +++ b/datafusion/session/src/session.rs @@ -114,7 +114,7 @@ pub trait Session: Send + Sync { fn scalar_functions(&self) -> &HashMap>; /// Return reference to higher_order_functions - fn higher_order_functions(&self) -> &HashMap>; + fn higher_order_functions(&self) -> &HashMap>; /// Return reference to aggregate_functions fn aggregate_functions(&self) -> &HashMap>; diff --git a/datafusion/spark/src/lib.rs b/datafusion/spark/src/lib.rs index 2eee94c52ef78..6cd4678da7560 100644 --- a/datafusion/spark/src/lib.rs +++ b/datafusion/spark/src/lib.rs @@ -59,7 +59,7 @@ //! # fn udafs(&self) -> HashSet { unimplemented!() } //! # fn udwfs(&self) -> HashSet { unimplemented!() } //! # fn udf(&self, _name: &str) -> Result> { unimplemented!() } -//! # fn higher_order_function(&self, name: &str) -> Result> { unimplemented!() } +//! # fn higher_order_function(&self, name: &str) -> Result> { unimplemented!() } //! # fn udaf(&self, name: &str) -> Result> {unimplemented!() } //! # fn udwf(&self, name: &str) -> Result> { unimplemented!() } //! 
# fn expr_planners(&self) -> Vec> { unimplemented!() } diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index dc49b4460fec5..883439fbf1e09 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -138,7 +138,7 @@ impl ContextProvider for MyContextProvider { None } - fn get_higher_order_meta(&self, _name: &str) -> Option> { + fn get_higher_order_meta(&self, _name: &str) -> Option> { None } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 67abb8b822063..701485eee733c 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -370,7 +370,7 @@ impl SqlToRel<'_, S> { if let Some(fm) = self.context_provider.get_higher_order_meta(&name) { // plan non-lambda arguments first so we can get theirs datatype and call - // HigherOrderUDF::lambda_parameters to then plan the lambda arguments with + // HigherOrderUDFImpl::lambda_parameters to then plan the lambda arguments with // resolved lambda variables enum ExprOrLambda { Expr(Expr), diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index ba7811acd8f3c..440ced09f227a 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -1415,7 +1415,7 @@ mod tests { None } - fn get_higher_order_meta(&self, _name: &str) -> Option> { + fn get_higher_order_meta(&self, _name: &str) -> Option> { None } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index d7b1c6a3bb6de..d83c6b6e13bb7 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1908,11 +1908,12 @@ mod tests { use datafusion_common::{Spans, TableReference}; use datafusion_expr::expr::WildcardOptions; use datafusion_expr::{ - ColumnarValue, HigherOrderUDF, LambdaParametersProgress, ScalarFunctionArgs, - ScalarUDF, ScalarUDFImpl, Signature, ValueOrLambda, Volatility, WindowFrame, - WindowFunctionDefinition, case, cast, col, 
cube, exists, grouping_set, - interval_datetime_lit, interval_year_month_lit, lambda, lambda_var, lit, not, - not_exists, out_ref_col, placeholder, rollup, table_scan, try_cast, when, + ColumnarValue, HigherOrderUDF, HigherOrderUDFImpl, LambdaParametersProgress, + ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, ValueOrLambda, + Volatility, WindowFrame, WindowFunctionDefinition, case, cast, col, cube, exists, + grouping_set, interval_datetime_lit, interval_year_month_lit, lambda, lambda_var, + lit, not, not_exists, out_ref_col, placeholder, rollup, table_scan, try_cast, + when, }; use datafusion_expr::{ExprFunctionExt, interval_month_day_nano_lit}; use datafusion_functions::datetime::from_unixtime::FromUnixtimeFunc; @@ -1969,7 +1970,7 @@ mod tests { #[derive(Debug, Hash, Eq, PartialEq)] struct DummyHigherOrderUDF; - impl HigherOrderUDF for DummyHigherOrderUDF { + impl HigherOrderUDFImpl for DummyHigherOrderUDF { fn name(&self) -> &str { "dummy_higher_order_function" } @@ -2087,7 +2088,7 @@ mod tests { ), ( Expr::HigherOrderFunction(HigherOrderFunction::new( - Arc::new(DummyHigherOrderUDF), + Arc::new(HigherOrderUDF::new_from_impl(DummyHigherOrderUDF)), vec![col("a"), lambda(["v"], -lambda_var("v"))], )), r#"dummy_higher_order_function(a, (v) -> -v)"#, diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 71e864d2a733d..e7c819bbf64a6 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -56,7 +56,7 @@ impl Display for MockCsvType { #[derive(Default)] pub(crate) struct MockSessionState { scalar_functions: HashMap>, - higher_order_functions: HashMap>, + higher_order_functions: HashMap>, aggregate_functions: HashMap>, expr_planners: Vec>, type_planner: Option>, @@ -101,7 +101,7 @@ impl MockSessionState { pub fn with_higher_order_function( mut self, - higher_order_function: Arc, + higher_order_function: Arc, ) -> Self { self.higher_order_functions.insert( 
higher_order_function.name().to_string(), @@ -291,7 +291,7 @@ impl ContextProvider for MockContextProvider { self.state.scalar_functions.get(name).cloned() } - fn get_higher_order_meta(&self, name: &str) -> Option> { + fn get_higher_order_meta(&self, name: &str) -> Option> { self.state.higher_order_functions.get(name).cloned() } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 64763e33d93f7..6885711531c5f 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -29,7 +29,7 @@ use common::MockContextProvider; use datafusion_common::{DFSchema, DataFusionError, Result, assert_contains}; use datafusion_expr::{ ColumnarValue, CreateIndex, DdlStatement, Expr, HigherOrderFunctionArgs, - HigherOrderReturnFieldArgs, HigherOrderSignature, HigherOrderUDF, + HigherOrderReturnFieldArgs, HigherOrderSignature, HigherOrderUDF, HigherOrderUDFImpl, LambdaParametersProgress, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, ValueOrLambda, Volatility, col, expr::{HigherOrderFunction, LambdaVariable, ScalarFunction}, @@ -3510,7 +3510,9 @@ fn logical_plan_with_options(sql: &str, options: ParserOptions) -> Result Result { let state = MockSessionState::default() .with_aggregate_function(sum_udaf()) - .with_higher_order_function(Arc::new(MockArrayReduce::new())) + .with_higher_order_function(Arc::new(HigherOrderUDF::new_from_impl( + MockArrayReduce::new(), + ))) .with_scalar_function(make_array_udf()) .with_expr_planner(Arc::new(CustomExprPlanner {})); // plan array literal let context = MockContextProvider { state }; @@ -5358,7 +5360,7 @@ fn test_progressive_lambda_parameters() { assert_eq!( expr, Expr::HigherOrderFunction(HigherOrderFunction::new( - Arc::new(MockArrayReduce::new()), + Arc::new(HigherOrderUDF::new_from_impl(MockArrayReduce::new())), vec![ Expr::ScalarFunction(ScalarFunction::new_udf( make_array_udf(), @@ -5402,7 +5404,7 @@ impl MockArrayReduce { } } -impl HigherOrderUDF 
for MockArrayReduce { +impl HigherOrderUDFImpl for MockArrayReduce { fn name(&self) -> &str { "array_reduce" }