diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs index 0a4048650fa2b..5aea6c809b701 100644 --- a/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs +++ b/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs @@ -20,6 +20,7 @@ use crate::logical_plan::consumer::utils::NameTracker; use async_recursion::async_recursion; use datafusion::common::{Column, not_impl_err}; use datafusion::logical_expr::builder::project; +use datafusion::logical_expr::utils::find_window_exprs; use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; use std::collections::HashSet; use std::sync::Arc; @@ -57,13 +58,9 @@ pub async fn from_project_rel( let e = consumer .consume_expression(expr, input.clone().schema()) .await?; - // if the expression is WindowFunction, wrap in a Window relation - if let Expr::WindowFunction(_) = &e { - // Adding the same expression here and in the project below - // works because the project's builder uses columnize_expr(..) - // to transform it into a column reference - window_exprs.insert(e.clone()); - } + // The project's builder uses columnize_expr(..) to transform + // nested window expressions into column references. + window_exprs.extend(find_window_exprs([&e])); explicit_exprs.push(name_tracker.get_uniquely_named_expr(e)?); } diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index d4ac01462c879..522381de6efdf 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -91,6 +91,31 @@ mod tests { Ok(()) } + #[tokio::test] + async fn nested_window_function_in_expression() -> Result<()> { + // The Substrait Project expression represents: + // SELECT 1 + count(*) OVER () FROM DATA + let proto_plan = read_json( + "tests/testdata/test_plans/nested_window_expression.substrait.json", + ); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; + + assert_snapshot!( + plan, + @r" + Projection: Int64(1) + count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS EXPR$0 + WindowAggr: windowExpr=[[count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + TableScan: DATA + " + ); + + // Trigger execution to ensure the nested window is physically plannable + DataFrame::new(ctx.state(), plan).show().await?; + + Ok(()) + } + #[tokio::test] async fn double_window_function() -> Result<()> { // Confirms a WindowExpr can be repeated in the same project. diff --git a/datafusion/substrait/tests/testdata/test_plans/nested_window_expression.substrait.json b/datafusion/substrait/tests/testdata/test_plans/nested_window_expression.substrait.json new file mode 100644 index 0000000000000..f4dc73a9ca672 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/nested_window_expression.substrait.json @@ -0,0 +1,131 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + }, + { + "extensionUriAnchor": 2, + "uri": "/functions_aggregate_generic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "add:i64_i64" + } + }, + { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "count:any" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "A" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "DATA" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i64": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "windowFunction": { + "functionReference": 1, + "partitions": [], + "sorts": [], + "upperBound": { + "unbounded": {} + }, + "lowerBound": { + "unbounded": {} + }, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "args": [], + "arguments": [], + "invocation": "AGGREGATION_INVOCATION_ALL", + "options": [], + "boundsType": "BOUNDS_TYPE_ROWS" + } + } + } + ], + "options": [] + } + } + ] + } + }, + "names": [ + "EXPR$0" + ] + } + } + ], + "expectedTypeUrls": [] +}