From 68e8de629cfbca27527529da5dfed2c5ad8522be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Fri, 29 May 2026 11:01:09 +0200 Subject: [PATCH 1/2] Add row-number late materialization for TopK --- Cargo.lock | 2 + datafusion/common/src/config.rs | 6 + .../core/src/optimizer_rule_reference.md | 9 +- .../tests/parquet/external_access_plan.rs | 40 +- .../late_materialization.rs | 175 ++++ .../core/tests/physical_optimizer/mod.rs | 1 + .../datasource-parquet/src/opener/mod.rs | 129 ++- datafusion/datasource/src/mod.rs | 27 + datafusion/physical-optimizer/Cargo.toml | 2 + .../src/late_materialization.rs | 851 ++++++++++++++++++ datafusion/physical-optimizer/src/lib.rs | 1 + .../physical-optimizer/src/optimizer.rs | 5 + docs/source/user-guide/configs.md | 1 + 13 files changed, 1238 insertions(+), 11 deletions(-) create mode 100644 datafusion/core/tests/physical_optimizer/late_materialization.rs create mode 100644 datafusion/physical-optimizer/src/late_materialization.rs diff --git a/Cargo.lock b/Cargo.lock index 4d5b15075ecef..66728b16c2376 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2393,6 +2393,7 @@ version = "53.1.0" dependencies = [ "arrow", "datafusion-common", + "datafusion-datasource", "datafusion-execution", "datafusion-expr", "datafusion-expr-common", @@ -2402,6 +2403,7 @@ dependencies = [ "datafusion-physical-expr-common", "datafusion-physical-plan", "datafusion-pruning", + "futures", "insta", "itertools 0.14.0", "recursive", diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index e6d1ebbbbe746..20c9f564cc8eb 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -1124,6 +1124,12 @@ config_namespace! { /// into the file scan phase. pub enable_topk_dynamic_filter_pushdown: bool, default = true + /// When set to true, the physical optimizer will try to evaluate + /// simple ORDER BY ... 
LIMIT queries over a narrow key-only input first, + /// carrying hidden row numbers and materializing the full rows only + /// after the TopK has been computed. + pub enable_row_number_topk_late_materialization: bool, default = true + /// When set to true, the optimizer will attempt to push down Join dynamic filters /// into the file scan phase. pub enable_join_dynamic_filter_pushdown: bool, default = true diff --git a/datafusion/core/src/optimizer_rule_reference.md b/datafusion/core/src/optimizer_rule_reference.md index 1f9f37f530557..0ae9c6a042b41 100644 --- a/datafusion/core/src/optimizer_rule_reference.md +++ b/datafusion/core/src/optimizer_rule_reference.md @@ -88,7 +88,8 @@ in multiple phases. | 16 | `LimitPushdown` | - | Moves physical limits into child operators or fetch-enabled variants to cut data early. | | 17 | `TopKRepartition` | - | Pushes TopK below hash repartition when the partition key is a prefix of the sort key. | | 18 | `ProjectionPushdown` | late pass | Runs projection pushdown again after limit and TopK rewrites expose new pruning opportunities. | -| 19 | `PushdownSort` | - | Pushes sort requirements into data sources that can already return sorted output. | -| 20 | `EnsureCooperative` | - | Wraps non-cooperative plan parts so long-running tasks yield fairly. | -| 21 | `FilterPushdown(Post)` | post-optimization phase | Pushes dynamic filters at the end of optimization, after plan references stop moving. | -| 22 | `SanityCheckPlan` | - | Validates that the final physical plan meets ordering, distribution, and infinite-input safety requirements. | +| 19 | `LateMaterialization` | - | Rewrites simple TopK plans to sort a narrow key-only input before materializing full rows by row number. | +| 20 | `PushdownSort` | - | Pushes sort requirements into data sources that can already return sorted output. | +| 21 | `EnsureCooperative` | - | Wraps non-cooperative plan parts so long-running tasks yield fairly. 
| +| 22 | `FilterPushdown(Post)` | post-optimization phase | Pushes dynamic filters at the end of optimization, after plan references stop moving. | +| 23 | `SanityCheckPlan` | - | Validates that the final physical plan meets ordering, distribution, and infinite-input safety requirements. | diff --git a/datafusion/core/tests/parquet/external_access_plan.rs b/datafusion/core/tests/parquet/external_access_plan.rs index 31be6fd979fd6..24df223716380 100644 --- a/datafusion/core/tests/parquet/external_access_plan.rs +++ b/datafusion/core/tests/parquet/external_access_plan.rs @@ -36,8 +36,8 @@ use datafusion_expr::{Expr, col, lit}; use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::metrics::{MetricValue, MetricsSet}; -use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; +use datafusion_datasource::{FileRowsSelection, file_scan_config::FileScanConfigBuilder}; use parquet::arrow::ArrowWriter; use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; use parquet::file::properties::WriterProperties; @@ -128,6 +128,35 @@ async fn selection_scan() { } } +#[tokio::test] +async fn file_rows_selection() { + TestFull { + access_plan: None, + file_rows_selection: Some(FileRowsSelection::new(vec![1, 5, 6, 9])), + expected_rows: 4, + predicate: None, + } + .run() + .await + .unwrap(); +} + +#[tokio::test] +async fn file_rows_selection_intersects_access_plan() { + TestFull { + access_plan: Some(ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, + RowGroupAccess::Skip, + ])), + file_rows_selection: Some(FileRowsSelection::new(vec![1, 5, 6, 9])), + expected_rows: 1, + predicate: None, + } + .run() + .await + .unwrap(); +} + #[tokio::test] async fn skip_scan() { let plans = vec![ @@ -170,6 +199,7 @@ async fn plan_and_filter() { // initial let parquet_metrics = TestFull { access_plan, + file_rows_selection: None, expected_rows: 0, predicate: Some(predicate), } @@ -227,6 +257,7 @@ async fn 
bad_row_groups() { RowGroupAccess::Skip, RowGroupAccess::Scan, ])), + file_rows_selection: None, expected_rows: 0, predicate: None, } @@ -249,6 +280,7 @@ async fn bad_selection() { ])), RowGroupAccess::Skip, ])), + file_rows_selection: None, // expects that we hit an error, this should not be run expected_rows: 10000, predicate: None, @@ -300,6 +332,7 @@ impl Test { } = self; TestFull { access_plan, + file_rows_selection: None, expected_rows, predicate: None, } @@ -317,6 +350,7 @@ impl Test { /// 4. Returns the statistics from running the plan struct TestFull { access_plan: Option, + file_rows_selection: Option, expected_rows: usize, predicate: Option, } @@ -327,6 +361,7 @@ impl TestFull { let Self { access_plan, + file_rows_selection, expected_rows, predicate, } = self; @@ -351,6 +386,9 @@ impl TestFull { if let Some(access_plan) = access_plan { partitioned_file = partitioned_file.with_extension(access_plan); } + if let Some(file_rows_selection) = file_rows_selection { + partitioned_file = partitioned_file.with_extension(file_rows_selection); + } // Create a DataSourceExec to read the file let object_store_url = ObjectStoreUrl::local_filesystem(); diff --git a/datafusion/core/tests/physical_optimizer/late_materialization.rs b/datafusion/core/tests/physical_optimizer/late_materialization.rs new file mode 100644 index 0000000000000..52de3bbdd757f --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/late_materialization.rs @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
/// With default options the rule must replace the TopK sort over the Parquet
/// scan with a `LateTopKMaterializationExec` that has a key-only scan
/// (`projection=[c]`) and a full-width scan (`projection=[a, b, c, d, e]`).
#[test]
fn rewrites_simple_parquet_topk_by_default() -> Result<()> {
    let default_options = ConfigOptions::new();
    let rule = LateMaterialization::new();

    let optimized = rule.optimize(topk_plan()?, &default_options)?;
    let display = displayable(optimized.as_ref()).indent(false).to_string();

    // Every fragment must appear somewhere in the rendered plan.
    for expected_fragment in [
        "LateTopKMaterializationExec: fetch=10",
        "projection=[c]",
        "projection=[a, b, c, d, e]",
    ] {
        assert_contains!(&display, expected_fragment);
    }
    Ok(())
}
config.optimizer.enable_row_number_topk_late_materialization = false; + + let optimized = LateMaterialization::new().optimize(plan, &config)?; + let display = displayable(optimized.as_ref()).indent(false).to_string(); + + assert_contains!(&display, "SortExec: TopK(fetch=10)"); + assert!(!display.contains("LateTopKMaterializationExec")); + Ok(()) +} + +#[tokio::test] +async fn sql_parquet_select_star_order_by_limit_uses_late_materialization() -> Result<()> +{ + let test_file = make_topk_parquet()?; + let ctx = topk_context(&test_file, true).await?; + + let dataframe = ctx.sql("SELECT * FROM t ORDER BY c LIMIT 10").await?; + let plan = dataframe.create_physical_plan().await?; + let display = displayable(plan.as_ref()).indent(false).to_string(); + + assert_contains!(&display, "LateTopKMaterializationExec: fetch=10"); + assert_contains!(&display, "projection=[c]"); + + let batches = dataframe.collect().await?; + assert_batches_eq!( + [ + "+----+---+---------+", + "| id | c | payload |", + "+----+---+---------+", + "| 19 | 0 | p19 |", + "| 18 | 1 | p18 |", + "| 17 | 2 | p17 |", + "| 16 | 3 | p16 |", + "| 15 | 4 | p15 |", + "| 14 | 5 | p14 |", + "| 13 | 6 | p13 |", + "| 12 | 7 | p12 |", + "| 11 | 8 | p11 |", + "| 10 | 9 | p10 |", + "+----+---+---------+", + ], + &batches + ); + + Ok(()) +} + +#[tokio::test] +async fn sql_parquet_select_star_order_by_limit_respects_disabled_config() -> Result<()> { + let test_file = make_topk_parquet()?; + let ctx = topk_context(&test_file, false).await?; + + let dataframe = ctx.sql("SELECT * FROM t ORDER BY c LIMIT 10").await?; + let plan = dataframe.create_physical_plan().await?; + let display = displayable(plan.as_ref()).indent(false).to_string(); + + assert_contains!(&display, "SortExec: TopK(fetch=10)"); + assert!(!display.contains("LateTopKMaterializationExec")); + + Ok(()) +} + +fn topk_plan() -> Result> { + let schema = schema(); + let scan = parquet_exec(Arc::clone(&schema)); + let sort_exprs = 
LexOrdering::new(vec![PhysicalSortExpr { + expr: col("c", &schema)?, + options: SortOptions::default(), + }]) + .unwrap(); + + Ok(Arc::new( + SortExec::new(sort_exprs, scan).with_fetch(Some(10)), + )) +} + +async fn topk_context( + test_file: &NamedTempFile, + late_materialization_enabled: bool, +) -> Result { + let config = SessionConfig::new().with_collect_statistics(true).set_bool( + "datafusion.optimizer.enable_row_number_topk_late_materialization", + late_materialization_enabled, + ); + let ctx = SessionContext::new_with_config(config); + ctx.register_parquet( + "t", + test_file.path().to_str().unwrap(), + ParquetReadOptions::default(), + ) + .await?; + Ok(ctx) +} + +fn make_topk_parquet() -> Result { + let mut test_file = tempfile::Builder::new().suffix(".parquet").tempfile()?; + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("c", DataType::Int64, false), + Field::new("payload", DataType::Utf8, false), + ])); + let ids = (0..20).collect::>(); + let sort_values = (0..20).rev().collect::>(); + let payloads = (0..20).map(|value| format!("p{value}")).collect::>(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(ids)), + Arc::new(Int64Array::from(sort_values)), + Arc::new(StringArray::from(payloads)), + ], + )?; + let props = WriterProperties::builder() + .set_max_row_group_row_count(Some(5)) + .build(); + let mut writer = ArrowWriter::try_new(&mut test_file, schema, Some(props))?; + writer.write(&batch)?; + writer.close()?; + Ok(test_file) +} diff --git a/datafusion/core/tests/physical_optimizer/mod.rs b/datafusion/core/tests/physical_optimizer/mod.rs index b7ba661d2343a..52b30891fd8b1 100644 --- a/datafusion/core/tests/physical_optimizer/mod.rs +++ b/datafusion/core/tests/physical_optimizer/mod.rs @@ -26,6 +26,7 @@ mod enforce_sorting; mod enforce_sorting_monotonicity; mod filter_pushdown; mod join_selection; +mod late_materialization; 
#[expect(clippy::needless_pass_by_value)] mod limit_pushdown; mod limited_distinct_aggregation; diff --git a/datafusion/datasource-parquet/src/opener/mod.rs b/datafusion/datasource-parquet/src/opener/mod.rs index 95e0516e8bc27..01d0e74222515 100644 --- a/datafusion/datasource-parquet/src/opener/mod.rs +++ b/datafusion/datasource-parquet/src/opener/mod.rs @@ -49,7 +49,7 @@ use arrow::datatypes::{SchemaRef, TimeUnit}; use datafusion_common::encryption::FileDecryptionProperties; use datafusion_common::stats::Precision; use datafusion_common::{ColumnStatistics, Result, ScalarValue, Statistics, exec_err}; -use datafusion_datasource::{PartitionedFile, TableSchema}; +use datafusion_datasource::{FileRowsSelection, PartitionedFile, TableSchema}; use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; use datafusion_physical_expr_common::physical_expr::{ @@ -70,11 +70,12 @@ use log::debug; use parquet::arrow::ParquetRecordBatchStreamBuilder; use parquet::arrow::arrow_reader::metrics::ArrowReaderMetrics; use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions}; +use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; use parquet::arrow::async_reader::AsyncFileReader; use parquet::arrow::parquet_column; use parquet::basic::Type; use parquet::bloom_filter::Sbbf; -use parquet::file::metadata::{PageIndexPolicy, ParquetMetaDataReader}; +use parquet::file::metadata::{PageIndexPolicy, ParquetMetaDataReader, RowGroupMetaData}; /// Stateless Parquet morselizer implementation. 
/// @@ -896,7 +897,7 @@ impl FiltersPreparedParquetOpen { let mut row_groups = RowGroupAccessPlanFilter::new(create_initial_plan( &prepared.file_name, &prepared.extensions, - rg_metadata.len(), + rg_metadata, )?); // If there is a range restricting what parts of the file to read @@ -1343,8 +1344,9 @@ fn constant_value_from_stats( fn create_initial_plan( file_name: &str, extensions: &datafusion_datasource::FileExtensions, - row_group_count: usize, + row_group_meta_data: &[RowGroupMetaData], ) -> Result { + let row_group_count = row_group_meta_data.len(); if let Some(access_plan) = extensions.get::() { let plan_len = access_plan.len(); if plan_len != row_group_count { @@ -1352,11 +1354,126 @@ fn create_initial_plan( "Invalid ParquetAccessPlan for {file_name}. Specified {plan_len} row groups, but file has {row_group_count}" ); } - return Ok(access_plan.clone()); + let mut access_plan = access_plan.clone(); + if let Some(file_rows_selection) = extensions.get::() { + apply_file_rows_selection( + &mut access_plan, + file_rows_selection, + row_group_meta_data, + )?; + } + return Ok(access_plan); } // default to scanning all row groups - Ok(ParquetAccessPlan::new_all(row_group_count)) + let mut access_plan = ParquetAccessPlan::new_all(row_group_count); + if let Some(file_rows_selection) = extensions.get::() { + apply_file_rows_selection( + &mut access_plan, + file_rows_selection, + row_group_meta_data, + )?; + } + Ok(access_plan) +} + +fn apply_file_rows_selection( + access_plan: &mut ParquetAccessPlan, + file_rows_selection: &FileRowsSelection, + row_group_meta_data: &[RowGroupMetaData], +) -> Result<()> { + let mut selected_rows = file_rows_selection.row_indices().iter().copied().peekable(); + let mut row_group_start = 0_u64; + let mut last_selected_row = None; + + for (row_group_index, row_group_metadata) in row_group_meta_data.iter().enumerate() { + let row_group_rows = + u64::try_from(row_group_metadata.num_rows()).map_err(|_| { + 
datafusion_common::DataFusionError::Internal( + "Parquet row group row count overflowed u64".to_string(), + ) + })?; + let row_group_end = row_group_start + row_group_rows; + let mut row_group_offsets = Vec::new(); + + while let Some(row_index) = selected_rows.peek().copied() { + if row_index < row_group_start { + return exec_err!( + "Invalid FileRowsSelection. Row indices must be sorted in ascending order" + ); + } + if row_index >= row_group_end { + break; + } + if last_selected_row.is_some_and(|last| row_index <= last) { + return exec_err!( + "Invalid FileRowsSelection. Row indices must be sorted in ascending order and unique" + ); + } + last_selected_row = Some(row_index); + row_group_offsets.push( + usize::try_from(row_index - row_group_start).map_err(|_| { + datafusion_common::DataFusionError::Internal( + "selected Parquet row offset overflowed usize".to_string(), + ) + })?, + ); + selected_rows.next(); + } + + if row_group_offsets.is_empty() { + access_plan.skip(row_group_index); + } else { + let row_selection = row_selection_from_offsets( + &row_group_offsets, + usize::try_from(row_group_rows).map_err(|_| { + datafusion_common::DataFusionError::Internal( + "Parquet row group row count overflowed usize".to_string(), + ) + })?, + ); + access_plan.scan_selection(row_group_index, row_selection); + } + + row_group_start = row_group_end; + } + + if let Some(row_index) = selected_rows.next() { + return exec_err!( + "Invalid FileRowsSelection. 
/// Row indices to scan from a single [`PartitionedFile`].
///
/// File sources that support row-level pruning can use this extension to avoid
/// decoding rows that are known to be unnecessary.
///
/// # Contract
///
/// Indices are zero-based within the file and must be sorted in strictly
/// ascending order, i.e. sorted and free of duplicates. This type stores the
/// indices as given and performs no validation itself; checking is left to the
/// consumer. For example, the Parquet opener returns an execution error for
/// unsorted or duplicate indices and for an index at or beyond the file's row
/// count.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileRowsSelection {
    /// Selected file-level row indices, in the order supplied by the caller.
    row_indices: Vec<u64>,
}

impl FileRowsSelection {
    /// Create a new file row selection from `row_indices`.
    ///
    /// The indices are stored unchanged; see the type-level documentation for
    /// the ordering and uniqueness requirements callers must uphold.
    pub fn new(row_indices: Vec<u64>) -> Self {
        Self { row_indices }
    }

    /// Return the selected row indices.
    pub fn row_indices(&self) -> &[u64] {
        &self.row_indices
    }

    /// Return true if no rows are selected.
    pub fn is_empty(&self) -> bool {
        self.row_indices.is_empty()
    }
}
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Row-number based late materialization for simple TopK plans. + +use std::collections::{BTreeMap, HashMap, HashSet}; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +use arrow::array::{ArrayRef, UInt32Array, UInt64Array}; +use arrow::compute::{concat_batches, take}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use datafusion_common::config::ConfigOptions; +use datafusion_common::stats::Precision; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; +use datafusion_common::{DataFusionError, Result, internal_err}; +use datafusion_datasource::file_groups::FileGroup; +use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; +use datafusion_datasource::source::DataSourceExec; +use datafusion_datasource::{FileRowsSelection, PartitionedFile}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_plan::execution_plan::{ + Boundedness, EmissionType, collect, reset_plan_states, +}; +use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use datafusion_physical_plan::projection::ProjectionExec; +use 
datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + Statistics, +}; +use futures::{StreamExt, TryStreamExt, stream}; + +use crate::PhysicalOptimizerRule; + +const PARTITION_COLUMN: &str = "__datafusion_late_materialization_partition"; +const ROW_NUMBER_COLUMN: &str = "__datafusion_late_materialization_row_number"; + +/// Rewrites simple TopK plans to sort a narrow key-only stream and materialize +/// the full rows after the winning row numbers are known. +#[derive(Default, Debug)] +pub struct LateMaterialization {} + +impl LateMaterialization { + #[expect(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for LateMaterialization { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + if !config.optimizer.enable_row_number_topk_late_materialization { + return Ok(plan); + } + + plan.transform_down(|plan| { + if let Some(spm) = plan.downcast_ref::() + && let Some(fetch) = spm.fetch() + && let Some(sort_child) = spm.input().downcast_ref::() + && sort_child.preserve_partitioning() + && let Some(exec) = LateTopKMaterializationExec::try_new( + sort_child.input(), + spm.expr().clone(), + fetch, + )? + { + return Ok(Transformed::yes(Arc::new(exec) as Arc)); + } + + let Some(sort) = plan.downcast_ref::() else { + return Ok(Transformed::no(plan)); + }; + let Some(fetch) = sort.fetch() else { + return Ok(Transformed::no(plan)); + }; + if sort.preserve_partitioning() { + return Ok(Transformed::no(plan)); + } + + if let Some(exec) = LateTopKMaterializationExec::try_new( + sort.input(), + sort.expr().clone(), + fetch, + )? 
{ + Ok(Transformed::yes(Arc::new(exec) as Arc)) + } else { + Ok(Transformed::no(plan)) + } + }) + .data() + } + + fn name(&self) -> &str { + "LateMaterialization" + } + + fn schema_check(&self) -> bool { + true + } +} + +#[derive(Debug, Clone)] +struct LateTopKMaterializationExec { + key_input: Arc, + full_input: Arc, + sort_exprs: LexOrdering, + key_sort_exprs: LexOrdering, + fetch: usize, + key_width: usize, + cache: Arc, + metrics: ExecutionPlanMetricsSet, +} + +impl LateTopKMaterializationExec { + fn try_new( + input: &Arc, + sort_exprs: LexOrdering, + fetch: usize, + ) -> Result> { + if fetch == 0 { + return Ok(None); + } + + let Some(full_input) = input.with_preserve_order(true) else { + return Ok(None); + }; + + let key_columns = simple_key_columns(&sort_exprs)?; + if key_columns.is_empty() { + return Ok(None); + } + let input_schema = full_input.schema(); + if key_columns.len() >= input_schema.fields().len() { + return Ok(None); + } + + let projection_exprs = key_columns + .iter() + .map(|column| { + ( + Arc::new(Column::new(column.name(), column.index())) + as Arc, + column.name().to_string(), + ) + }) + .collect::>(); + let projection = + ProjectionExec::try_new(projection_exprs, Arc::clone(&full_input))?; + + let Some(key_input) = full_input.try_swapping_with_projection(&projection)? 
+ else { + return Ok(None); + }; + let Some(key_input) = key_input.with_preserve_order(true) else { + return Ok(None); + }; + + let key_sort_exprs = remap_sort_exprs(&sort_exprs, &key_columns)?; + let cache = Self::compute_properties(&full_input, &sort_exprs); + + Ok(Some(Self { + key_input, + full_input, + sort_exprs, + key_sort_exprs, + fetch, + key_width: key_columns.len(), + cache: Arc::new(cache), + metrics: ExecutionPlanMetricsSet::new(), + })) + } + + fn compute_properties( + full_input: &Arc, + sort_exprs: &LexOrdering, + ) -> PlanProperties { + let schema = full_input.schema(); + let eq_properties = + EquivalenceProperties::new_with_orderings(schema, vec![sort_exprs.to_vec()]); + PlanProperties::new( + eq_properties, + Partitioning::UnknownPartitioning(1), + EmissionType::Final, + Boundedness::Bounded, + ) + } + + fn key_topk_plan(&self) -> Arc { + let key_input = Arc::new(RowNumberExec::new(Arc::clone(&self.key_input))) + as Arc; + let key_input = if key_input.output_partitioning().partition_count() > 1 { + Arc::new(CoalescePartitionsExec::new(key_input)) as Arc + } else { + key_input + }; + Arc::new( + SortExec::new(self.key_sort_exprs.clone(), key_input) + .with_fetch(Some(self.fetch)), + ) + } +} + +impl DisplayAs for LateTopKMaterializationExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "LateTopKMaterializationExec: fetch={}, expr=[{}]", + self.fetch, self.sort_exprs + ) + } + DisplayFormatType::TreeRender => { + writeln!(f, "fetch={}", self.fetch)?; + write!(f, "expr=[{}]", self.sort_exprs) + } + } + } +} + +impl ExecutionPlan for LateTopKMaterializationExec { + fn name(&self) -> &str { + "LateTopKMaterializationExec" + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.key_input, &self.full_input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> 
Result> { + if children.len() != 2 { + return internal_err!( + "LateTopKMaterializationExec requires exactly two children" + ); + } + let mut new_exec = Arc::unwrap_or_clone(self); + new_exec.key_input = Arc::clone(&children[0]); + new_exec.full_input = Arc::clone(&children[1]); + new_exec.cache = Arc::new(Self::compute_properties( + &new_exec.full_input, + &new_exec.sort_exprs, + )); + Ok(Arc::new(new_exec)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + if partition != 0 { + return internal_err!( + "LateTopKMaterializationExec invalid partition {partition}" + ); + } + + let key_topk_plan = reset_plan_states(self.key_topk_plan())?; + let full_input = reset_plan_states(Arc::clone(&self.full_input))?; + let output_schema = self.schema(); + let key_width = self.key_width; + let context = Arc::clone(&context); + + let batches = stream::once(async move { + let selected_rows = + collect_selected_rows(key_topk_plan, Arc::clone(&context), key_width) + .await?; + if let Some(batches) = materialize_with_pushed_down_file_rows( + &full_input, + Arc::clone(&context), + Arc::clone(&output_schema), + &selected_rows, + ) + .await? + { + return Ok(batches); + } + materialize_selected_rows(full_input, context, output_schema, &selected_rows) + .await + }) + .map_ok(|batches| stream::iter(batches.into_iter().map(Ok))) + .try_flatten(); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + batches, + ))) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn partition_statistics(&self, partition: Option) -> Result> { + if partition.is_some() { + return Ok(Arc::new(Statistics::new_unknown(&self.schema()))); + } + Ok(Arc::new( + self.full_input + .partition_statistics(None)? 
+ .as_ref() + .clone() + .with_fetch(Some(self.fetch), 0, 1)?, + )) + } + + fn apply_expressions( + &self, + f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; + for sort_expr in &self.sort_exprs { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + Ok(tnr) + } +} + +#[derive(Debug, Clone)] +struct RowNumberExec { + input: Arc, + cache: Arc, +} + +impl RowNumberExec { + fn new(input: Arc) -> Self { + let input_schema = input.schema(); + let mut fields = input_schema.fields().to_vec(); + fields.push(Arc::new(Field::new( + PARTITION_COLUMN, + DataType::UInt64, + false, + ))); + fields.push(Arc::new(Field::new( + ROW_NUMBER_COLUMN, + DataType::UInt64, + false, + ))); + let schema = Arc::new(Schema::new_with_metadata( + fields, + input_schema.metadata().clone(), + )); + let cache = PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&schema)), + input.output_partitioning().clone(), + input.pipeline_behavior(), + input.boundedness(), + ); + Self { + input, + cache: Arc::new(cache), + } + } +} + +impl DisplayAs for RowNumberExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "RowNumberExec") + } + DisplayFormatType::TreeRender => write!(f, "RowNumberExec"), + } + } +} + +impl ExecutionPlan for RowNumberExec { + fn name(&self) -> &str { + "RowNumberExec" + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.len() != 1 { + return internal_err!("RowNumberExec requires exactly one child"); + } + Ok(Arc::new(Self::new(Arc::clone(&children[0])))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let input = self.input.execute(partition, context)?; + let schema = self.schema(); + let mut 
next_row_number = 0_u64; + let stream = input.map(move |batch| { + let batch = batch?; + let row_count = batch.num_rows(); + let end_row_number = next_row_number + row_count as u64; + let partition_values = UInt64Array::from_value(partition as u64, row_count); + let row_numbers = + UInt64Array::from_iter_values(next_row_number..end_row_number); + next_row_number = end_row_number; + + let mut columns = batch.columns().to_vec(); + columns.push(Arc::new(partition_values) as ArrayRef); + columns.push(Arc::new(row_numbers) as ArrayRef); + Ok(RecordBatch::try_new(Arc::clone(&schema), columns)?) + }); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + stream, + ))) + } + + fn partition_statistics(&self, partition: Option) -> Result> { + if partition.is_some() { + return Ok(Arc::new(Statistics::new_unknown(&self.schema()))); + } + Ok(Arc::new(Statistics::new_unknown(&self.schema()))) + } + + fn apply_expressions( + &self, + _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } +} + +#[derive(Debug)] +struct SelectedRows { + by_partition: HashMap>, + row_count: usize, +} + +async fn collect_selected_rows( + key_topk_plan: Arc, + context: Arc, + key_width: usize, +) -> Result { + let batches = collect(key_topk_plan, context).await?; + let partition_index = key_width; + let row_number_index = key_width + 1; + + let mut by_partition: HashMap> = HashMap::new(); + let mut rank = 0_usize; + for batch in batches { + let partition_array = batch + .column(partition_index) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "late materialization partition column had wrong type".to_string(), + ) + })?; + let row_number_array = batch + .column(row_number_index) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "late materialization row number column had wrong type".to_string(), + ) + })?; + + for row in 0..batch.num_rows() { + let partition = + 
usize::try_from(partition_array.value(row)).map_err(|_| { + DataFusionError::Internal( + "late materialization partition value overflowed usize" + .to_string(), + ) + })?; + let row_number = row_number_array.value(row); + by_partition + .entry(partition) + .or_default() + .insert(row_number, rank); + rank += 1; + } + } + + Ok(SelectedRows { + by_partition, + row_count: rank, + }) +} + +async fn materialize_selected_rows( + full_input: Arc, + context: Arc, + schema: SchemaRef, + selected_rows: &SelectedRows, +) -> Result> { + if selected_rows.row_count == 0 { + return Ok(vec![]); + } + + let partition_count = full_input.output_partitioning().partition_count(); + let mut selected_batches = Vec::new(); + let mut selected_ranks = Vec::with_capacity(selected_rows.row_count); + + for (partition, rows) in &selected_rows.by_partition { + if *partition >= partition_count { + return Err(DataFusionError::Internal(format!( + "late materialization selected partition {partition}, but full input has {partition_count} partitions" + ))); + } + let mut stream = full_input.execute(*partition, Arc::clone(&context))?; + let mut next_row_number = 0_u64; + let mut found_in_partition = 0_usize; + let rows_to_find = rows.len(); + let max_row_number = rows.keys().next_back().copied(); + + while let Some(batch) = stream.next().await { + let batch = batch?; + let batch_start = next_row_number; + let batch_end = batch_start + batch.num_rows() as u64; + + let mut indices = Vec::new(); + let mut ranks = Vec::new(); + for (row_number, rank) in rows.range(batch_start..batch_end) { + indices.push(u32::try_from(*row_number - batch_start).map_err(|_| { + DataFusionError::Internal( + "late materialization batch row index overflowed u32".to_string(), + ) + })?); + ranks.push(*rank); + } + + if !indices.is_empty() { + let indices = UInt32Array::from(indices); + let columns = batch + .columns() + .iter() + .map(|column| take(column.as_ref(), &indices, None)) + .collect::, _>>()?; + selected_batches + 
.push(RecordBatch::try_new(Arc::clone(&schema), columns)?); + found_in_partition += ranks.len(); + selected_ranks.extend(ranks); + } + + next_row_number = batch_end; + if found_in_partition >= rows_to_find + || max_row_number.is_some_and(|max| next_row_number > max) + { + break; + } + } + } + + if selected_ranks.len() != selected_rows.row_count { + return Err(DataFusionError::Internal(format!( + "late materialization found {} of {} selected rows", + selected_ranks.len(), + selected_rows.row_count + ))); + } + + reorder_selected_batches( + schema, + &selected_batches, + selected_ranks, + selected_rows.row_count, + ) +} + +async fn materialize_with_pushed_down_file_rows( + full_input: &Arc, + context: Arc, + schema: SchemaRef, + selected_rows: &SelectedRows, +) -> Result>> { + if selected_rows.row_count == 0 { + return Ok(Some(vec![])); + } + + let Some(selected_input) = selected_file_scan(full_input, selected_rows)? else { + return Ok(None); + }; + + let partition_count = selected_input.output_partitioning().partition_count(); + let mut selected_batches = Vec::new(); + let mut selected_ranks = Vec::with_capacity(selected_rows.row_count); + + for partition in 0..partition_count { + let Some(rows) = selected_rows.by_partition.get(&partition) else { + continue; + }; + + let mut stream = selected_input.execute(partition, Arc::clone(&context))?; + let mut partition_row_count = 0_usize; + while let Some(batch) = stream.next().await { + let batch = batch?; + partition_row_count += batch.num_rows(); + selected_batches.push(batch); + } + + if partition_row_count != rows.len() { + return Err(DataFusionError::Internal(format!( + "late materialization pushed-down file row selection returned {partition_row_count} rows for partition {partition}, expected {}", + rows.len() + ))); + } + selected_ranks.extend(rows.values().copied()); + } + + reorder_selected_batches( + schema, + &selected_batches, + selected_ranks, + selected_rows.row_count, + ) + .map(Some) +} + +fn 
reorder_selected_batches( + schema: SchemaRef, + selected_batches: &[RecordBatch], + selected_ranks: Vec, + row_count: usize, +) -> Result> { + if row_count == 0 { + return Ok(vec![]); + } + + let concatenated = concat_batches(&schema, selected_batches)?; + let mut rank_to_row = selected_ranks + .into_iter() + .enumerate() + .map(|(row_index, rank)| (rank, row_index)) + .collect::>(); + rank_to_row.sort_by_key(|(rank, _)| *rank); + let take_indices = rank_to_row + .into_iter() + .map(|(_, row_index)| { + u32::try_from(row_index).map_err(|_| { + DataFusionError::Internal( + "late materialization output row index overflowed u32".to_string(), + ) + }) + }) + .collect::>>()?; + let take_indices = UInt32Array::from(take_indices); + let columns = concatenated + .columns() + .iter() + .map(|column| take(column.as_ref(), &take_indices, None)) + .collect::, _>>()?; + Ok(vec![RecordBatch::try_new(schema, columns)?]) +} + +fn selected_file_scan( + full_input: &Arc, + selected_rows: &SelectedRows, +) -> Result>> { + let Some(data_source_exec) = full_input.downcast_ref::() else { + return Ok(None); + }; + let Some(file_scan_config) = data_source_exec + .data_source() + .downcast_ref::() + else { + return Ok(None); + }; + + if file_scan_config.file_source().file_type() != "parquet" + || file_scan_config.file_source().filter().is_some() + || file_scan_config.limit.is_some() + { + return Ok(None); + } + + let mut selected_file_groups = Vec::with_capacity(file_scan_config.file_groups.len()); + for (partition, file_group) in file_scan_config.file_groups.iter().enumerate() { + let selected_files = match selected_rows.by_partition.get(&partition) { + Some(rows) => { + let Some(selected_files) = + selected_files_for_partition(file_group, rows)? 
+ else { + return Ok(None); + }; + selected_files + } + None => Vec::new(), + }; + selected_file_groups.push(FileGroup::new(selected_files)); + } + + let selected_config = FileScanConfigBuilder::from(file_scan_config.clone()) + .with_file_groups(selected_file_groups) + .build(); + Ok(Some(Arc::new(DataSourceExec::new(Arc::new( + selected_config, + ))))) +} + +fn selected_files_for_partition( + file_group: &FileGroup, + rows: &BTreeMap, +) -> Result>> { + if rows.is_empty() { + return Ok(Some(vec![])); + } + + let mut selected_files = Vec::new(); + let mut file_start = 0_u64; + let max_row_number = *rows.keys().next_back().expect("rows is not empty"); + + for file in file_group.iter() { + if file_start > max_row_number { + break; + } + + if file.range.is_some() || !file.extensions.is_empty() { + return Ok(None); + } + + let Some(file_row_count) = exact_file_row_count(file)? else { + return Ok(None); + }; + let file_end = file_start.checked_add(file_row_count).ok_or_else(|| { + DataFusionError::Internal( + "late materialization file row count overflowed u64".to_string(), + ) + })?; + + let selected_file_rows = rows + .range(file_start..file_end) + .map(|(row_number, _)| row_number - file_start) + .collect::>(); + if !selected_file_rows.is_empty() { + selected_files.push( + file.clone() + .with_extension(FileRowsSelection::new(selected_file_rows)), + ); + } + + file_start = file_end; + } + + if max_row_number >= file_start { + return Err(DataFusionError::Internal(format!( + "late materialization selected row {max_row_number}, but partition has {file_start} known file rows" + ))); + } + + Ok(Some(selected_files)) +} + +fn exact_file_row_count(file: &PartitionedFile) -> Result> { + let Some(statistics) = &file.statistics else { + return Ok(None); + }; + let Precision::Exact(row_count) = &statistics.num_rows else { + return Ok(None); + }; + u64::try_from(*row_count).map(Some).map_err(|_| { + DataFusionError::Internal( + "late materialization exact file row count 
overflowed u64".to_string(), + ) + }) +} + +fn simple_key_columns(sort_exprs: &LexOrdering) -> Result> { + let mut seen = HashSet::new(); + let mut columns = Vec::new(); + for sort_expr in sort_exprs { + let Some(column) = sort_expr.expr.downcast_ref::() else { + return Ok(vec![]); + }; + if seen.insert(column.index()) { + columns.push(column.clone()); + } + } + Ok(columns) +} + +fn remap_sort_exprs( + sort_exprs: &LexOrdering, + key_columns: &[Column], +) -> Result { + let index_map = key_columns + .iter() + .enumerate() + .map(|(key_index, column)| { + (column.index(), (key_index, column.name().to_string())) + }) + .collect::>(); + + let sort_exprs = sort_exprs + .iter() + .map(|sort_expr| { + let column = sort_expr + .expr + .downcast_ref::() + .expect("validated by simple_key_columns"); + let (key_index, name) = index_map.get(&column.index()).ok_or_else(|| { + DataFusionError::Internal(format!( + "sort column {column} was missing from late materialization key projection" + )) + })?; + Ok(PhysicalSortExpr { + expr: Arc::new(Column::new(name, *key_index)), + options: sort_expr.options, + }) + }) + .collect::>>()?; + Ok(LexOrdering::new(sort_exprs).expect("sort exprs are not empty")) +} diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs index 5fac8948b7f04..1702ce44a53e6 100644 --- a/datafusion/physical-optimizer/src/lib.rs +++ b/datafusion/physical-optimizer/src/lib.rs @@ -32,6 +32,7 @@ pub mod enforce_sorting; pub mod ensure_coop; pub mod filter_pushdown; pub mod join_selection; +pub mod late_materialization; pub mod limit_pushdown; pub mod limit_pushdown_past_window; pub mod limited_distinct_aggregation; diff --git a/datafusion/physical-optimizer/src/optimizer.rs b/datafusion/physical-optimizer/src/optimizer.rs index 05df642f8446b..22db76517db01 100644 --- a/datafusion/physical-optimizer/src/optimizer.rs +++ b/datafusion/physical-optimizer/src/optimizer.rs @@ -27,6 +27,7 @@ use 
crate::enforce_sorting::EnforceSorting; use crate::ensure_coop::EnsureCooperative; use crate::filter_pushdown::FilterPushdown; use crate::join_selection::JoinSelection; +use crate::late_materialization::LateMaterialization; use crate::limit_pushdown::LimitPushdown; use crate::limited_distinct_aggregation::LimitedDistinctAggregation; use crate::output_requirements::OutputRequirements; @@ -223,6 +224,10 @@ impl PhysicalOptimizer { // are not present, the load of executors such as join or union will be // reduced by narrowing their input tables. Arc::new(ProjectionPushdown::new()), + // Late materialization for simple TopK scans. Run before sort + // pushdown so hidden row numbers continue to reference the scan's + // original file order. + Arc::new(LateMaterialization::new()), // PushdownSort: Detect sorts that can be pushed down to data sources. Arc::new(PushdownSort::new()), Arc::new(EnsureCooperative::new()), diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 576137bda29d1..e68cdfc63116d 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -144,6 +144,7 @@ The following configuration settings are available: | datafusion.optimizer.enable_window_topn | false | When set to true, the optimizer will replace Filter(rn<=K) → Window(ROW_NUMBER) → Sort patterns with a PartitionedTopKExec that maintains per-partition heaps, avoiding a full sort of the input. When the window partition key has low cardinality, enabling this optimization can improve performance. However, for high cardinality keys, it may cause regressions in both memory usage and runtime. | | datafusion.optimizer.enable_topk_repartition | true | When set to true, the optimizer will push TopK (Sort with fetch) below hash repartition when the partition key is a prefix of the sort key, reducing data volume before the shuffle. 
| | datafusion.optimizer.enable_topk_dynamic_filter_pushdown | true | When set to true, the optimizer will attempt to push down TopK dynamic filters into the file scan phase. | +| datafusion.optimizer.enable_row_number_topk_late_materialization | true | When set to true, the physical optimizer will try to evaluate simple ORDER BY ... LIMIT queries over a narrow key-only input first, carrying hidden row numbers and materializing the full rows only after the TopK has been computed. | | datafusion.optimizer.enable_join_dynamic_filter_pushdown | true | When set to true, the optimizer will attempt to push down Join dynamic filters into the file scan phase. | | datafusion.optimizer.enable_aggregate_dynamic_filter_pushdown | true | When set to true, the optimizer will attempt to push down Aggregate dynamic filters into the file scan phase. | | datafusion.optimizer.enable_dynamic_filter_pushdown | true | When set to true attempts to push down dynamic filters generated by operators (TopK, Join & Aggregate) into the file scan phase. For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. This means that if we already have 10 timestamps in the year 2025 any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. The config will suppress `enable_join_dynamic_filter_pushdown`, `enable_topk_dynamic_filter_pushdown` & `enable_aggregate_dynamic_filter_pushdown` So if you disable `enable_topk_dynamic_filter_pushdown`, then enable `enable_dynamic_filter_pushdown`, the `enable_topk_dynamic_filter_pushdown` will be overridden. 
| From c5da79763b3c6f0ba786e43358bbc49af734370f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Fri, 29 May 2026 13:37:34 +0200 Subject: [PATCH 2/2] Support filtered TopK late materialization --- .../late_materialization.rs | 208 ++++- datafusion/datasource-parquet/src/source.rs | 64 +- datafusion/datasource/src/file.rs | 27 + .../src/late_materialization.rs | 744 ++++++++++++++++-- 4 files changed, 994 insertions(+), 49 deletions(-) diff --git a/datafusion/core/tests/physical_optimizer/late_materialization.rs b/datafusion/core/tests/physical_optimizer/late_materialization.rs index 52de3bbdd757f..6138ccc47fc28 100644 --- a/datafusion/core/tests/physical_optimizer/late_materialization.rs +++ b/datafusion/core/tests/physical_optimizer/late_materialization.rs @@ -19,7 +19,7 @@ use std::sync::Arc; use crate::physical_optimizer::test_utils::{parquet_exec, schema}; -use arrow::array::{Int64Array, StringArray}; +use arrow::array::{Int64Array, StringArray, UInt16Array}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; @@ -31,6 +31,7 @@ use datafusion_physical_expr::expressions::col; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_optimizer::late_materialization::LateMaterialization; +use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::{ExecutionPlan, displayable}; use parquet::arrow::ArrowWriter; @@ -63,6 +64,18 @@ fn skips_rewrite_when_disabled() -> Result<()> { Ok(()) } +#[test] +fn rewrites_filtered_topk_when_output_is_wide() -> Result<()> { + let plan = filtered_topk_plan()?; + let optimized = LateMaterialization::new().optimize(plan, &ConfigOptions::new())?; + let display = displayable(optimized.as_ref()).indent(false).to_string(); + + assert_contains!(&display, 
"LateTopKMaterializationExec: fetch=10"); + assert_contains!(&display, "FilterExec: e@1"); + assert_contains!(&display, "RowNumberExec"); + Ok(()) +} + #[tokio::test] async fn sql_parquet_select_star_order_by_limit_uses_late_materialization() -> Result<()> { @@ -100,6 +113,118 @@ async fn sql_parquet_select_star_order_by_limit_uses_late_materialization() -> R Ok(()) } +#[tokio::test] +async fn sql_parquet_filtered_select_star_order_by_limit_uses_late_materialization() +-> Result<()> { + let test_file = make_topk_parquet()?; + let ctx = topk_context(&test_file, true).await?; + + let dataframe = ctx + .sql("SELECT * FROM t WHERE payload <> 'p19' ORDER BY c LIMIT 10") + .await?; + let plan = dataframe.create_physical_plan().await?; + let display = displayable(plan.as_ref()).indent(false).to_string(); + + assert_contains!(&display, "LateTopKMaterializationExec: fetch=10"); + assert_contains!(&display, "FilterExec:"); + + let batches = dataframe.collect().await?; + assert_filtered_topk_batches(&batches); + + Ok(()) +} + +#[tokio::test] +async fn sql_parquet_filtered_byte_range_select_star_order_by_limit_uses_late_materialization() +-> Result<()> { + let test_file = make_topk_parquet()?; + let config = topk_config(true) + .with_target_partitions(2) + .with_repartition_file_min_size(1); + let ctx = topk_context_with_config(&test_file, config).await?; + + let dataframe = ctx + .sql("SELECT * FROM t WHERE payload <> 'p19' ORDER BY c LIMIT 10") + .await?; + let plan = dataframe.create_physical_plan().await?; + let display = displayable(plan.as_ref()).indent(false).to_string(); + + assert_contains!(&display, "LateTopKMaterializationExec: fetch=10"); + assert_contains!(&display, "PartitionColumnExec"); + assert!(!display.contains("RowNumberExec")); + + let batches = dataframe.collect().await?; + assert_filtered_topk_batches(&batches); + + Ok(()) +} + +#[tokio::test] +async fn sql_parquet_filtered_projection_rebinds_to_full_schema() -> Result<()> { + let test_file = 
make_topk_parquet_with_date()?; + let ctx = topk_context(&test_file, true).await?; + + let dataframe = ctx + .sql("SELECT payload, id, c FROM t WHERE payload <> 'p19' ORDER BY c LIMIT 10") + .await?; + let plan = dataframe.create_physical_plan().await?; + let display = displayable(plan.as_ref()).indent(false).to_string(); + + assert_contains!(&display, "LateTopKMaterializationExec: fetch=10"); + + let batches = dataframe.collect().await?; + assert_batches_eq!( + [ + "+---------+----+----+", + "| payload | id | c |", + "+---------+----+----+", + "| p18 | 18 | 1 |", + "| p17 | 17 | 2 |", + "| p16 | 16 | 3 |", + "| p15 | 15 | 4 |", + "| p14 | 14 | 5 |", + "| p13 | 13 | 6 |", + "| p12 | 12 | 7 |", + "| p11 | 11 | 8 |", + "| p10 | 10 | 9 |", + "| p9 | 9 | 10 |", + "+---------+----+----+", + ], + &batches + ); + + Ok(()) +} + +#[tokio::test] +async fn sql_parquet_filtered_view_cast_order_by_limit_uses_late_materialization() +-> Result<()> { + let test_file = make_topk_parquet_with_date()?; + let ctx = topk_context(&test_file, true).await?; + ctx.sql( + "CREATE VIEW v AS \ + SELECT * EXCEPT(event_date), \ + CAST(CAST(event_date AS INTEGER) AS DATE) AS event_date FROM t", + ) + .await? 
+ .collect() + .await?; + + let dataframe = ctx + .sql("SELECT * FROM v WHERE payload <> 'p19' ORDER BY c LIMIT 10") + .await?; + let plan = dataframe.create_physical_plan().await?; + let display = displayable(plan.as_ref()).indent(false).to_string(); + + assert_contains!(&display, "LateTopKMaterializationExec: fetch=10"); + + let batches = dataframe.collect().await?; + let row_count = batches.iter().map(|batch| batch.num_rows()).sum::(); + assert_eq!(row_count, 10); + + Ok(()) +} + #[tokio::test] async fn sql_parquet_select_star_order_by_limit_respects_disabled_config() -> Result<()> { let test_file = make_topk_parquet()?; @@ -129,14 +254,39 @@ fn topk_plan() -> Result> { )) } +fn filtered_topk_plan() -> Result> { + let schema = schema(); + let scan = parquet_exec(Arc::clone(&schema)); + let filter = Arc::new(FilterExec::try_new(col("e", &schema)?, scan)?); + let sort_exprs = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("c", &schema)?, + options: SortOptions::default(), + }]) + .unwrap(); + + Ok(Arc::new( + SortExec::new(sort_exprs, filter).with_fetch(Some(10)), + )) +} + async fn topk_context( test_file: &NamedTempFile, late_materialization_enabled: bool, ) -> Result { - let config = SessionConfig::new().with_collect_statistics(true).set_bool( + topk_context_with_config(test_file, topk_config(late_materialization_enabled)).await +} + +fn topk_config(late_materialization_enabled: bool) -> SessionConfig { + SessionConfig::new().with_collect_statistics(true).set_bool( "datafusion.optimizer.enable_row_number_topk_late_materialization", late_materialization_enabled, - ); + ) +} + +async fn topk_context_with_config( + test_file: &NamedTempFile, + config: SessionConfig, +) -> Result { let ctx = SessionContext::new_with_config(config); ctx.register_parquet( "t", @@ -147,6 +297,28 @@ async fn topk_context( Ok(ctx) } +fn assert_filtered_topk_batches(batches: &[RecordBatch]) { + assert_batches_eq!( + [ + "+----+----+---------+", + "| id | c | payload |", + 
"+----+----+---------+", + "| 18 | 1 | p18 |", + "| 17 | 2 | p17 |", + "| 16 | 3 | p16 |", + "| 15 | 4 | p15 |", + "| 14 | 5 | p14 |", + "| 13 | 6 | p13 |", + "| 12 | 7 | p12 |", + "| 11 | 8 | p11 |", + "| 10 | 9 | p10 |", + "| 9 | 10 | p9 |", + "+----+----+---------+", + ], + batches + ); +} + fn make_topk_parquet() -> Result { let mut test_file = tempfile::Builder::new().suffix(".parquet").tempfile()?; let schema = Arc::new(Schema::new(vec![ @@ -173,3 +345,33 @@ fn make_topk_parquet() -> Result { writer.close()?; Ok(test_file) } + +fn make_topk_parquet_with_date() -> Result { + let mut test_file = tempfile::Builder::new().suffix(".parquet").tempfile()?; + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("c", DataType::Int64, false), + Field::new("event_date", DataType::UInt16, false), + Field::new("payload", DataType::Utf8, false), + ])); + let ids = (0..20).collect::>(); + let sort_values = (0..20).rev().collect::>(); + let dates = (0..20).map(|value| value as u16).collect::>(); + let payloads = (0..20).map(|value| format!("p{value}")).collect::>(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(ids)), + Arc::new(Int64Array::from(sort_values)), + Arc::new(UInt16Array::from(dates)), + Arc::new(StringArray::from(payloads)), + ], + )?; + let props = WriterProperties::builder() + .set_max_row_group_row_count(Some(5)) + .build(); + let mut writer = ArrowWriter::try_new(&mut test_file, schema, Some(props))?; + writer.write(&batch)?; + writer.close()?; + Ok(test_file) +} diff --git a/datafusion/datasource-parquet/src/source.rs b/datafusion/datasource-parquet/src/source.rs index 8228cd273eae6..481f88fa265db 100644 --- a/datafusion/datasource-parquet/src/source.rs +++ b/datafusion/datasource-parquet/src/source.rs @@ -34,7 +34,7 @@ use datafusion_datasource::file_stream::FileOpener; use datafusion_datasource::morsel::Morselizer; use arrow::array::timezone::Tz; -use 
arrow::datatypes::TimeUnit; +use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; use datafusion_common::DataFusionError; use datafusion_common::config::TableParquetOptions; use datafusion_datasource::TableSchema; @@ -60,6 +60,7 @@ use datafusion_execution::parquet_encryption::EncryptionFactory; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use itertools::Itertools; use object_store::ObjectStore; +use parquet::arrow::RowNumber; #[cfg(feature = "parquet_encryption")] use parquet::encryption::decrypt::FileDecryptionProperties; @@ -658,6 +659,67 @@ impl FileSource for ParquetSource { self.predicate.clone() } + fn without_filter(&self) -> Option> { + let mut source = self.clone(); + source.predicate = None; + source = source.with_pushdown_filters(false); + Some(Arc::new(source)) + } + + fn without_filter_and_projection(&self) -> Option> { + let mut source = self.clone(); + source.predicate = None; + source = source.with_pushdown_filters(false); + let full_schema = source.table_schema.table_schema(); + let indices = (0..full_schema.fields().len()).collect::>(); + source.projection = ProjectionExprs::from_indices(&indices, full_schema); + Some(Arc::new(source)) + } + + fn with_row_number_column( + &self, + column_name: &str, + ) -> datafusion_common::Result, usize)>> { + if self + .table_schema + .table_schema() + .fields() + .iter() + .any(|field| field.name() == column_name) + { + return Ok(None); + } + + let mut source = self.clone(); + source.predicate = None; + source = source.with_pushdown_filters(false); + + let row_number_field: FieldRef = Arc::new( + Field::new(column_name, DataType::Int64, false) + .with_extension_type(RowNumber), + ); + let mut virtual_columns = source + .table_schema + .virtual_columns() + .iter() + .cloned() + .collect::>(); + virtual_columns.push(row_number_field); + source.table_schema = + TableSchema::builder(Arc::clone(source.table_schema.file_schema())) + .with_table_partition_cols( + 
source.table_schema.table_partition_cols().clone(), + ) + .with_virtual_columns(virtual_columns) + .build(); + + let full_schema = source.table_schema.table_schema(); + let row_number_index = full_schema.fields().len() - 1; + let indices = (0..full_schema.fields().len()).collect::>(); + source.projection = ProjectionExprs::from_indices(&indices, full_schema); + Ok(Some((Arc::new(source), row_number_index))) + } + fn with_batch_size(&self, batch_size: usize) -> Arc { let mut conf = self.clone(); conf.batch_size = Some(batch_size); diff --git a/datafusion/datasource/src/file.rs b/datafusion/datasource/src/file.rs index 07460b23694b7..fb422a63cd86d 100644 --- a/datafusion/datasource/src/file.rs +++ b/datafusion/datasource/src/file.rs @@ -113,6 +113,33 @@ pub trait FileSource: Any + Send + Sync { None } + /// Return a copy of this [`FileSource`] without scan-time filters. + /// + /// Sources that expose [`Self::filter`] should override this if callers can + /// safely rebuild the same scan without predicate pruning or row filtering. + fn without_filter(&self) -> Option> { + None + } + + /// Return a copy of this [`FileSource`] without scan-time filters and with + /// its projection reset to the full table schema. + fn without_filter_and_projection(&self) -> Option> { + None + } + + /// Return a copy of this [`FileSource`] with an absolute file row-number + /// column appended to the table schema and included in its projection. + /// + /// The returned `usize` is the row-number column index in the returned + /// source's unprojected table schema. The row numbers must identify rows + /// within each physical file, not within an execution partition. + fn with_row_number_column( + &self, + _column_name: &str, + ) -> Result, usize)>> { + Ok(None) + } + /// Return the projection that will be applied to the output stream on top /// of [`Self::table_schema`]. 
/// diff --git a/datafusion/physical-optimizer/src/late_materialization.rs b/datafusion/physical-optimizer/src/late_materialization.rs index 3342bd1d0c01c..1c7a785960c8d 100644 --- a/datafusion/physical-optimizer/src/late_materialization.rs +++ b/datafusion/physical-optimizer/src/late_materialization.rs @@ -21,8 +21,8 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt::{Debug, Formatter}; use std::sync::Arc; -use arrow::array::{ArrayRef, UInt32Array, UInt64Array}; -use arrow::compute::{concat_batches, take}; +use arrow::array::{ArrayRef, Int64Array, UInt32Array, UInt64Array}; +use arrow::compute::{cast, concat_batches, take}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::config::ConfigOptions; @@ -35,6 +35,7 @@ use datafusion_datasource::source::DataSourceExec; use datafusion_datasource::{FileRowsSelection, PartitionedFile}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::utils::{collect_columns, reassign_expr_columns}; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; @@ -42,8 +43,12 @@ use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::execution_plan::{ Boundedness, EmissionType, collect, reset_plan_states, }; +use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::projection::{ + ProjectionExec, ProjectionExpr, ProjectionExprs, +}; +use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort::SortExec; use 
datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; @@ -130,6 +135,8 @@ impl PhysicalOptimizerRule for LateMaterialization { struct LateTopKMaterializationExec { key_input: Arc, full_input: Arc, + output_projection: Option>, + row_number_mode: RowNumberMode, sort_exprs: LexOrdering, key_sort_exprs: LexOrdering, fetch: usize, @@ -138,6 +145,21 @@ struct LateTopKMaterializationExec { metrics: ExecutionPlanMetricsSet, } +#[derive(Debug, Clone, Copy)] +enum RowNumberMode { + /// Row numbers are contiguous within the file-group execution partition. + Partition, + /// Row numbers are absolute within the physical file. + File, +} + +type PlanPair = (Arc, Arc); +type NumberedKeyInput = ( + Arc, + Arc, + RowNumberMode, +); + impl LateTopKMaterializationExec { fn try_new( input: &Arc, @@ -148,9 +170,16 @@ impl LateTopKMaterializationExec { return Ok(None); } + if let Some(exec) = Self::try_new_filtered(input, sort_exprs.clone(), fetch)? 
{ + return Ok(Some(exec)); + } + let Some(full_input) = input.with_preserve_order(true) else { return Ok(None); }; + if !supports_pushed_down_file_rows(&full_input) { + return Ok(None); + } let key_columns = simple_key_columns(&sort_exprs)?; if key_columns.is_empty() { @@ -181,13 +210,16 @@ impl LateTopKMaterializationExec { let Some(key_input) = key_input.with_preserve_order(true) else { return Ok(None); }; + let key_input = Arc::new(RowNumberExec::new(key_input)) as Arc; let key_sort_exprs = remap_sort_exprs(&sort_exprs, &key_columns)?; - let cache = Self::compute_properties(&full_input, &sort_exprs); + let cache = Self::compute_properties(&full_input, None, &sort_exprs)?; Ok(Some(Self { key_input, full_input, + output_projection: None, + row_number_mode: RowNumberMode::Partition, sort_exprs, key_sort_exprs, fetch, @@ -197,24 +229,146 @@ impl LateTopKMaterializationExec { })) } + fn try_new_filtered( + input: &Arc, + sort_exprs: LexOrdering, + fetch: usize, + ) -> Result> { + let (mut output_projection, filter_input, mut filter_sort_exprs) = + if let Some(projection) = input.downcast_ref::() { + let Some(sort_exprs) = + unproject_ordering(sort_exprs.clone(), projection.expr())? + else { + return Ok(None); + }; + ( + Some(projection.expr().to_vec()), + projection.input(), + sort_exprs, + ) + } else { + (None, input, sort_exprs.clone()) + }; + + let Some(filter) = filter_input.downcast_ref::() else { + return Ok(None); + }; + if filter.fetch().is_some() { + return Ok(None); + } + if let Some(projection) = filter.projection() { + let filter_projection = + projection_exprs_from_indices(projection, &filter.input().schema()); + let Some(sort_exprs) = + unproject_ordering(filter_sort_exprs.clone(), &filter_projection)? 
+ else { + return Ok(None); + }; + filter_sort_exprs = sort_exprs; + output_projection.get_or_insert(filter_projection); + } + + let key_columns = simple_key_columns(&filter_sort_exprs)?; + if key_columns.is_empty() { + return Ok(None); + } + + let mut required_columns = key_columns.clone(); + required_columns.extend(collect_columns(filter.predicate())); + required_columns.sort_by_key(|column| column.index()); + required_columns.dedup_by_key(|column| column.index()); + + // The rewrite only helps when the first pass is meaningfully narrower + // than the final materialized rows. This includes ClickBench Q23 and + // excludes narrow projections such as Q24. + if input.schema().fields().len() <= required_columns.len() { + return Ok(None); + } + + let Some((numbered_input, full_input, row_number_mode)) = + numbered_key_input(filter.input(), &required_columns)? + else { + return Ok(None); + }; + let full_schema = full_input.schema(); + let output_projection = output_projection + .map(|projection| reassign_projection_exprs(projection, full_schema.as_ref())) + .transpose()?; + + let remapped_predicate = reassign_expr_columns( + Arc::clone(filter.predicate()), + &numbered_input.schema(), + )?; + let filtered = Arc::new(FilterExec::try_new(remapped_predicate, numbered_input)?) 
+ as Arc; + + let key_width = key_columns.len(); + let numbered_width = required_columns.len(); + let mut key_projection_exprs = key_columns + .iter() + .map(|column| { + let index = required_columns + .iter() + .position(|required| required.index() == column.index()) + .expect("sort key column is required"); + ( + Arc::new(Column::new(column.name(), index)) as Arc, + column.name().to_string(), + ) + }) + .collect::>(); + key_projection_exprs.push(( + Arc::new(Column::new(PARTITION_COLUMN, numbered_width)) + as Arc, + PARTITION_COLUMN.to_string(), + )); + key_projection_exprs.push(( + Arc::new(Column::new(ROW_NUMBER_COLUMN, numbered_width + 1)) + as Arc, + ROW_NUMBER_COLUMN.to_string(), + )); + let key_input = Arc::new(ProjectionExec::try_new(key_projection_exprs, filtered)?) + as Arc; + + let key_sort_exprs = remap_sort_exprs(&filter_sort_exprs, &key_columns)?; + let cache = Self::compute_properties( + &full_input, + output_projection.as_deref(), + &sort_exprs, + )?; + + Ok(Some(Self { + key_input, + full_input, + output_projection, + row_number_mode, + sort_exprs, + key_sort_exprs, + fetch, + key_width, + cache: Arc::new(cache), + metrics: ExecutionPlanMetricsSet::new(), + })) + } + fn compute_properties( full_input: &Arc, + output_projection: Option<&[ProjectionExpr]>, sort_exprs: &LexOrdering, - ) -> PlanProperties { - let schema = full_input.schema(); + ) -> Result { + let schema = output_schema(full_input, output_projection)?; let eq_properties = EquivalenceProperties::new_with_orderings(schema, vec![sort_exprs.to_vec()]); - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, Partitioning::UnknownPartitioning(1), EmissionType::Final, Boundedness::Bounded, - ) + )) } fn key_topk_plan(&self) -> Arc { - let key_input = Arc::new(RowNumberExec::new(Arc::clone(&self.key_input))) - as Arc; + let key_input = Arc::clone(&self.key_input); let key_input = if key_input.output_partitioning().partition_count() > 1 { 
Arc::new(CoalescePartitionsExec::new(key_input)) as Arc } else { @@ -272,8 +426,9 @@ impl ExecutionPlan for LateTopKMaterializationExec { new_exec.full_input = Arc::clone(&children[1]); new_exec.cache = Arc::new(Self::compute_properties( &new_exec.full_input, + new_exec.output_projection.as_deref(), &new_exec.sort_exprs, - )); + )?); Ok(Arc::new(new_exec)) } @@ -292,6 +447,8 @@ impl ExecutionPlan for LateTopKMaterializationExec { let full_input = reset_plan_states(Arc::clone(&self.full_input))?; let output_schema = self.schema(); let key_width = self.key_width; + let output_projection = self.output_projection.clone(); + let row_number_mode = self.row_number_mode; let context = Arc::clone(&context); let batches = stream::once(async move { @@ -302,14 +459,23 @@ impl ExecutionPlan for LateTopKMaterializationExec { &full_input, Arc::clone(&context), Arc::clone(&output_schema), + output_projection.as_deref(), + row_number_mode, &selected_rows, ) .await? { return Ok(batches); } - materialize_selected_rows(full_input, context, output_schema, &selected_rows) - .await + materialize_selected_rows( + full_input, + context, + output_schema, + output_projection.as_deref(), + row_number_mode, + &selected_rows, + ) + .await }) .map_ok(|batches| stream::iter(batches.into_iter().map(Ok))) .try_flatten(); @@ -329,11 +495,7 @@ impl ExecutionPlan for LateTopKMaterializationExec { return Ok(Arc::new(Statistics::new_unknown(&self.schema()))); } Ok(Arc::new( - self.full_input - .partition_statistics(None)? 
- .as_ref() - .clone() - .with_fetch(Some(self.fetch), 0, 1)?, + Statistics::new_unknown(&self.schema()).with_fetch(Some(self.fetch), 0, 1)?, )) } } @@ -446,6 +608,374 @@ impl ExecutionPlan for RowNumberExec { } } +#[derive(Debug, Clone)] +struct PartitionColumnExec { + input: Arc, + cache: Arc, +} + +impl PartitionColumnExec { + fn new(input: Arc) -> Self { + let input_schema = input.schema(); + let mut fields = input_schema.fields().to_vec(); + fields.push(Arc::new(Field::new( + PARTITION_COLUMN, + DataType::UInt64, + false, + ))); + let schema = Arc::new(Schema::new_with_metadata( + fields, + input_schema.metadata().clone(), + )); + let cache = PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&schema)), + input.output_partitioning().clone(), + input.pipeline_behavior(), + input.boundedness(), + ); + Self { + input, + cache: Arc::new(cache), + } + } +} + +impl DisplayAs for PartitionColumnExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "PartitionColumnExec") + } + DisplayFormatType::TreeRender => write!(f, "PartitionColumnExec"), + } + } +} + +impl ExecutionPlan for PartitionColumnExec { + fn name(&self) -> &str { + "PartitionColumnExec" + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.len() != 1 { + return internal_err!("PartitionColumnExec requires exactly one child"); + } + Ok(Arc::new(Self::new(Arc::clone(&children[0])))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let input = self.input.execute(partition, context)?; + let schema = self.schema(); + let stream = input.map(move |batch| { + let batch = batch?; + let row_count = batch.num_rows(); + let partition_values = UInt64Array::from_value(partition as u64, row_count); + + let 
mut columns = batch.columns().to_vec(); + columns.push(Arc::new(partition_values) as ArrayRef); + Ok(RecordBatch::try_new(Arc::clone(&schema), columns)?) + }); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + stream, + ))) + } + + fn partition_statistics(&self, partition: Option) -> Result> { + if partition.is_some() { + return Ok(Arc::new(Statistics::new_unknown(&self.schema()))); + } + Ok(Arc::new(Statistics::new_unknown(&self.schema()))) + } +} + +fn output_schema( + full_input: &Arc, + output_projection: Option<&[ProjectionExpr]>, +) -> Result { + let Some(output_projection) = output_projection else { + return Ok(full_input.schema()); + }; + Ok( + ProjectionExec::try_new(output_projection.to_vec(), Arc::clone(full_input))? + .schema(), + ) +} + +fn reassign_projection_exprs( + projection: Vec, + schema: &Schema, +) -> Result> { + projection + .into_iter() + .map(|expr| { + Ok(ProjectionExpr { + expr: reassign_expr_columns(expr.expr, schema)?, + alias: expr.alias, + }) + }) + .collect() +} + +fn unproject_ordering( + ordering: LexOrdering, + projection: &[ProjectionExpr], +) -> Result> { + let projection = ProjectionExprs::new(projection.to_vec()); + let mut unprojected_exprs = Vec::with_capacity(ordering.len()); + for mut sort_expr in ordering { + let Ok(expr) = projection.unproject_expr(&sort_expr.expr) else { + return Ok(None); + }; + sort_expr.expr = expr; + unprojected_exprs.push(sort_expr); + } + Ok(LexOrdering::new(unprojected_exprs)) +} + +fn projection_exprs_from_indices( + indices: &[usize], + schema: &SchemaRef, +) -> Vec { + indices + .iter() + .map(|index| { + let field = schema.field(*index); + ProjectionExpr { + expr: Arc::new(Column::new(field.name(), *index)), + alias: field.name().to_string(), + } + }) + .collect() +} + +fn column_projection_expr( + column: &Column, + schema: &SchemaRef, +) -> Result<(Arc, String)> { + let index = schema.index_of(column.name())?; + Ok(( + Arc::new(Column::new(column.name(), index)) as Arc, + 
column.name().to_string(), + )) +} + +fn numbered_key_input( + input: &Arc, + required_columns: &[Column], +) -> Result> { + if let Some(file_scan_config) = file_scan_config(input) + && file_scan_config_has_ranges(file_scan_config) + { + let Some((key_input, full_input)) = + absolute_row_number_key_input(input, required_columns)? + else { + return Ok(None); + }; + return Ok(Some((key_input, full_input, RowNumberMode::File))); + } + + if let Some(source) = raw_file_scan(input)? { + let source_schema = source.schema(); + let projection_exprs = required_columns + .iter() + .map(|column| column_projection_expr(column, &source_schema)) + .collect::>>()?; + let projection = ProjectionExec::try_new(projection_exprs, Arc::clone(&source))?; + let Some(key_source) = source.try_swapping_with_projection(&projection)? else { + return Ok(None); + }; + let key_source = + Arc::new(RowNumberExec::new(key_source)) as Arc; + return Ok(Some((key_source, source, RowNumberMode::Partition))); + } + + if let Some(repartition) = input.downcast_ref::() { + let Partitioning::RoundRobinBatch(partition_count) = repartition.partitioning() + else { + return Ok(None); + }; + let Some((numbered_child, full_input, row_number_mode)) = + numbered_key_input(repartition.input(), required_columns)? + else { + return Ok(None); + }; + + let mut repartitioned = RepartitionExec::try_new( + numbered_child, + Partitioning::RoundRobinBatch(*partition_count), + )?; + if repartition.preserve_order() { + repartitioned = repartitioned.with_preserve_order(); + } + return Ok(Some((Arc::new(repartitioned), full_input, row_number_mode))); + } + + Ok(None) +} + +fn absolute_row_number_key_input( + input: &Arc, + required_columns: &[Column], +) -> Result> { + let Some(file_scan_config) = file_scan_config(input) else { + return Ok(None); + }; + if !file_scan_config_supports_file_row_numbers(file_scan_config) { + return Ok(None); + } + + let Some(full_input) = raw_file_scan(input)? 
else { + return Ok(None); + }; + let Some((row_number_source, row_number_index)) = + raw_file_scan_with_row_number(input)? + else { + return Ok(None); + }; + + let row_number_source_schema = row_number_source.schema(); + let mut projection_exprs = required_columns + .iter() + .map(|column| column_projection_expr(column, &row_number_source_schema)) + .collect::>>()?; + projection_exprs.push(( + Arc::new(Column::new(ROW_NUMBER_COLUMN, row_number_index)) + as Arc, + ROW_NUMBER_COLUMN.to_string(), + )); + let projection = + ProjectionExec::try_new(projection_exprs, Arc::clone(&row_number_source))?; + let Some(key_source) = row_number_source.try_swapping_with_projection(&projection)? + else { + return Ok(None); + }; + + let key_source = + Arc::new(PartitionColumnExec::new(key_source)) as Arc; + let numbered_width = required_columns.len(); + let mut reorder_exprs = required_columns + .iter() + .enumerate() + .map(|(index, column)| { + ( + Arc::new(Column::new(column.name(), index)) as Arc, + column.name().to_string(), + ) + }) + .collect::>(); + reorder_exprs.push(( + Arc::new(Column::new(PARTITION_COLUMN, numbered_width + 1)) + as Arc, + PARTITION_COLUMN.to_string(), + )); + reorder_exprs.push(( + Arc::new(Column::new(ROW_NUMBER_COLUMN, numbered_width)) as Arc, + ROW_NUMBER_COLUMN.to_string(), + )); + let key_input = Arc::new(ProjectionExec::try_new(reorder_exprs, key_source)?) 
as _; + + Ok(Some((key_input, full_input))) +} + +fn raw_file_scan( + input: &Arc, +) -> Result>> { + let Some(file_scan_config) = file_scan_config(input) else { + return Ok(None); + }; + if file_scan_config.file_source().file_type() != "parquet" + || file_scan_config.limit.is_some() + { + return Ok(None); + } + + let file_source = match file_scan_config + .file_source() + .without_filter_and_projection() + { + Some(source) => source, + None if file_scan_config.file_source().filter().is_none() => { + Arc::clone(file_scan_config.file_source()) + } + None => { + return Ok(None); + } + }; + let raw_config = FileScanConfigBuilder::from(file_scan_config.clone()) + .with_source(file_source) + .with_preserve_order(true) + .build(); + Ok(Some(Arc::new(DataSourceExec::new(Arc::new(raw_config))))) +} + +fn raw_file_scan_with_row_number( + input: &Arc, +) -> Result, usize)>> { + let Some(file_scan_config) = file_scan_config(input) else { + return Ok(None); + }; + if file_scan_config.file_source().file_type() != "parquet" + || file_scan_config.limit.is_some() + { + return Ok(None); + } + + let Some((file_source, row_number_index)) = file_scan_config + .file_source() + .with_row_number_column(ROW_NUMBER_COLUMN)? 
+ else { + return Ok(None); + }; + let statistics = Statistics::new_unknown(file_source.table_schema().table_schema()); + let raw_config = FileScanConfigBuilder::from(file_scan_config.clone()) + .with_source(file_source) + .with_statistics(statistics) + .with_preserve_order(true) + .build(); + Ok(Some(( + Arc::new(DataSourceExec::new(Arc::new(raw_config))), + row_number_index, + ))) +} + +fn file_scan_config(input: &Arc) -> Option<&FileScanConfig> { + let data_source_exec = input.downcast_ref::()?; + data_source_exec + .data_source() + .downcast_ref::() +} + +fn file_scan_config_has_ranges(file_scan_config: &FileScanConfig) -> bool { + file_scan_config + .file_groups + .iter() + .flat_map(FileGroup::iter) + .any(|file| file.range.is_some()) +} + +fn file_scan_config_supports_file_row_numbers(file_scan_config: &FileScanConfig) -> bool { + file_scan_config.file_groups.iter().all(|file_group| { + file_group.len() <= 1 && file_group.iter().all(|file| file.extensions.is_empty()) + }) +} + #[derive(Debug)] struct SelectedRows { by_partition: HashMap>, @@ -473,15 +1003,15 @@ async fn collect_selected_rows( "late materialization partition column had wrong type".to_string(), ) })?; - let row_number_array = batch - .column(row_number_index) - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "late materialization row number column had wrong type".to_string(), - ) - })?; + let row_number_array = batch.column(row_number_index); + let row_number_u64_array = + row_number_array.as_any().downcast_ref::(); + let row_number_i64_array = row_number_array.as_any().downcast_ref::(); + if row_number_u64_array.is_none() && row_number_i64_array.is_none() { + return Err(DataFusionError::Internal( + "late materialization row number column had wrong type".to_string(), + )); + } for row in 0..batch.num_rows() { let partition = @@ -491,7 +1021,15 @@ async fn collect_selected_rows( .to_string(), ) })?; - let row_number = row_number_array.value(row); + let row_number 
= match (row_number_u64_array, row_number_i64_array) { + (Some(array), _) => array.value(row), + (_, Some(array)) => u64::try_from(array.value(row)).map_err(|_| { + DataFusionError::Internal( + "late materialization row number was negative".to_string(), + ) + })?, + (None, None) => unreachable!("validated above"), + }; by_partition .entry(partition) .or_default() @@ -510,11 +1048,18 @@ async fn materialize_selected_rows( full_input: Arc, context: Arc, schema: SchemaRef, + output_projection: Option<&[ProjectionExpr]>, + row_number_mode: RowNumberMode, selected_rows: &SelectedRows, ) -> Result> { if selected_rows.row_count == 0 { return Ok(vec![]); } + if matches!(row_number_mode, RowNumberMode::File) { + return internal_err!( + "late materialization requires pushed-down file row selection for file row numbers" + ); + } let partition_count = full_input.output_partitioning().partition_count(); let mut selected_batches = Vec::new(); @@ -555,8 +1100,7 @@ async fn materialize_selected_rows( .iter() .map(|column| take(column.as_ref(), &indices, None)) .collect::, _>>()?; - selected_batches - .push(RecordBatch::try_new(Arc::clone(&schema), columns)?); + selected_batches.push(RecordBatch::try_new(batch.schema(), columns)?); found_in_partition += ranks.len(); selected_ranks.extend(ranks); } @@ -578,25 +1122,34 @@ async fn materialize_selected_rows( ))); } - reorder_selected_batches( - schema, + let materialized_schema = selected_batches + .first() + .map(|batch| batch.schema()) + .unwrap_or_else(|| full_input.schema()); + let batches = reorder_selected_batches( + materialized_schema, &selected_batches, selected_ranks, selected_rows.row_count, - ) + )?; + project_batches(&schema, batches, output_projection) } async fn materialize_with_pushed_down_file_rows( full_input: &Arc, context: Arc, schema: SchemaRef, + output_projection: Option<&[ProjectionExpr]>, + row_number_mode: RowNumberMode, selected_rows: &SelectedRows, ) -> Result>> { if selected_rows.row_count == 0 { return 
Ok(Some(vec![])); } - let Some(selected_input) = selected_file_scan(full_input, selected_rows)? else { + let Some(selected_input) = + selected_file_scan(full_input, selected_rows, row_number_mode)? + else { return Ok(None); }; @@ -626,13 +1179,18 @@ async fn materialize_with_pushed_down_file_rows( selected_ranks.extend(rows.values().copied()); } - reorder_selected_batches( - schema, + let materialized_schema = selected_batches + .first() + .map(|batch| batch.schema()) + .unwrap_or_else(|| selected_input.schema()); + let batches = reorder_selected_batches( + materialized_schema, &selected_batches, selected_ranks, selected_rows.row_count, ) - .map(Some) + .and_then(|batches| project_batches(&schema, batches, output_projection))?; + Ok(Some(batches)) } fn reorder_selected_batches( @@ -671,10 +1229,73 @@ fn reorder_selected_batches( Ok(vec![RecordBatch::try_new(schema, columns)?]) } +fn project_batches( + schema: &SchemaRef, + batches: Vec, + output_projection: Option<&[ProjectionExpr]>, +) -> Result> { + let Some(output_projection) = output_projection else { + return batches + .into_iter() + .map(|batch| align_batch_to_schema(Arc::clone(schema), batch)) + .collect(); + }; + + batches + .into_iter() + .map(|batch| { + let columns = output_projection + .iter() + .map(|expr| expr.expr.evaluate(&batch)?.into_array(batch.num_rows())) + .collect::>>()?; + make_batch_with_schema(Arc::clone(schema), columns) + }) + .collect() +} + +fn align_batch_to_schema(schema: SchemaRef, batch: RecordBatch) -> Result { + if batch.schema().as_ref() == schema.as_ref() { + return Ok(batch); + } + make_batch_with_schema(schema, batch.columns().to_vec()) +} + +fn make_batch_with_schema( + schema: SchemaRef, + columns: Vec, +) -> Result { + let fields = schema.fields(); + if fields.len() != columns.len() { + return internal_err!( + "late materialization projected {} columns for {} output fields", + columns.len(), + fields.len() + ); + } + + let columns = columns + .into_iter() + 
.zip(fields.iter()) + .map(|(column, field)| { + if column.data_type() == field.data_type() { + Ok(column) + } else { + cast(column.as_ref(), field.data_type()).map_err(DataFusionError::from) + } + }) + .collect::>>()?; + Ok(RecordBatch::try_new(schema, columns)?) +} + fn selected_file_scan( full_input: &Arc, selected_rows: &SelectedRows, + row_number_mode: RowNumberMode, ) -> Result>> { + if !supports_pushed_down_file_rows(full_input) { + return Ok(None); + } + let Some(data_source_exec) = full_input.downcast_ref::() else { return Ok(None); }; @@ -685,19 +1306,12 @@ fn selected_file_scan( return Ok(None); }; - if file_scan_config.file_source().file_type() != "parquet" - || file_scan_config.file_source().filter().is_some() - || file_scan_config.limit.is_some() - { - return Ok(None); - } - let mut selected_file_groups = Vec::with_capacity(file_scan_config.file_groups.len()); for (partition, file_group) in file_scan_config.file_groups.iter().enumerate() { let selected_files = match selected_rows.by_partition.get(&partition) { Some(rows) => { let Some(selected_files) = - selected_files_for_partition(file_group, rows)? + selected_files_for_partition(file_group, rows, row_number_mode)? 
else { return Ok(None); }; @@ -716,14 +1330,35 @@ fn selected_file_scan( ))))) } +fn supports_pushed_down_file_rows(input: &Arc) -> bool { + let Some(data_source_exec) = input.downcast_ref::() else { + return false; + }; + let Some(file_scan_config) = data_source_exec + .data_source() + .downcast_ref::() + else { + return false; + }; + + file_scan_config.file_source().file_type() == "parquet" + && file_scan_config.file_source().filter().is_none() + && file_scan_config.limit.is_none() +} + fn selected_files_for_partition( file_group: &FileGroup, rows: &BTreeMap, + row_number_mode: RowNumberMode, ) -> Result>> { if rows.is_empty() { return Ok(Some(vec![])); } + if matches!(row_number_mode, RowNumberMode::File) { + return selected_files_for_file_row_numbers(file_group, rows); + } + let mut selected_files = Vec::new(); let mut file_start = 0_u64; let max_row_number = *rows.keys().next_back().expect("rows is not empty"); @@ -769,6 +1404,25 @@ fn selected_files_for_partition( Ok(Some(selected_files)) } +fn selected_files_for_file_row_numbers( + file_group: &FileGroup, + rows: &BTreeMap, +) -> Result>> { + if file_group.len() != 1 { + return Ok(None); + } + + let file = &file_group.files()[0]; + if !file.extensions.is_empty() { + return Ok(None); + } + + let selected_file_rows = rows.keys().copied().collect::>(); + Ok(Some(vec![file.clone().with_extension( + FileRowsSelection::new(selected_file_rows), + )])) +} + fn exact_file_row_count(file: &PartitionedFile) -> Result> { let Some(statistics) = &file.statistics else { return Ok(None);