Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 64 additions & 4 deletions benchmarks/src/tpcds/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::sync::Arc;

use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats};

use arrow::datatypes::Schema;
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::{self, pretty_format_batches};
use datafusion::datasource::file_format::parquet::ParquetFormat;
Expand All @@ -34,7 +35,7 @@ use datafusion::physical_plan::{collect, displayable};
use datafusion::prelude::*;
use datafusion_common::instant::Instant;
use datafusion_common::utils::get_available_parallelism;
use datafusion_common::{DEFAULT_PARQUET_EXTENSION, plan_err};
use datafusion_common::{Constraint, Constraints, DEFAULT_PARQUET_EXTENSION, plan_err};

use clap::Args;
use log::info;
Expand Down Expand Up @@ -71,6 +72,61 @@ pub const TPCDS_TABLES: &[&str] = &[
"web_site",
];

/// Primary-key definitions for the TPC-DS tables.
///
/// Each entry is a `(table name, key column names)` pair; tables with a
/// composite key list more than one column. `table_constraints` looks a table
/// up here by name, and the column names are resolved to column indices
/// against the table's schema before the constraint is built.
static TPCDS_PRIMARY_KEYS: &[(&str, &[&str])] = &[
("call_center", &["cc_call_center_sk"]),
("catalog_page", &["cp_catalog_page_sk"]),
("catalog_returns", &["cr_item_sk", "cr_order_number"]),
("catalog_sales", &["cs_item_sk", "cs_order_number"]),
("customer", &["c_customer_sk"]),
("customer_address", &["ca_address_sk"]),
("customer_demographics", &["cd_demo_sk"]),
("date_dim", &["d_date_sk"]),
("household_demographics", &["hd_demo_sk"]),
("income_band", &["ib_income_band_sk"]),
(
"inventory",
&["inv_date_sk", "inv_item_sk", "inv_warehouse_sk"],
),
("item", &["i_item_sk"]),
("promotion", &["p_promo_sk"]),
("reason", &["r_reason_sk"]),
("ship_mode", &["sm_ship_mode_sk"]),
("store", &["s_store_sk"]),
("store_returns", &["sr_item_sk", "sr_ticket_number"]),
("store_sales", &["ss_item_sk", "ss_ticket_number"]),
("time_dim", &["t_time_sk"]),
("warehouse", &["w_warehouse_sk"]),
("web_page", &["wp_web_page_sk"]),
("web_returns", &["wr_item_sk", "wr_order_number"]),
("web_sales", &["ws_item_sk", "ws_order_number"]),
("web_site", &["web_site_sk"]),
];

/// Build the [`Constraints`] of the TPC-DS table named `table`.
///
/// Only the primary key is reported. TPC-DS also defines foreign keys, but
/// those are currently unsupported. The key column names come from
/// `TPCDS_PRIMARY_KEYS` and are resolved to indices against `schema`.
///
/// # Panics
///
/// Panics if `table` has no entry in `TPCDS_PRIMARY_KEYS`, or if one of its
/// key columns is missing from `schema`.
fn table_constraints(table: &str, schema: &Schema) -> Constraints {
    let key_columns = match TPCDS_PRIMARY_KEYS.iter().find(|entry| entry.0 == table) {
        Some(entry) => entry.1,
        None => unimplemented!("unknown TPC-DS table: {table}"),
    };

    let key = primary_key(schema, key_columns);
    Constraints::new_unverified(vec![key])
}

/// Build a primary-key [`Constraint`] over the columns named in
/// `column_names`, translating each name into its index within `schema`.
///
/// The indices appear in the same order as `column_names`.
///
/// # Panics
///
/// Panics if any of the named columns is not present in `schema`.
fn primary_key(schema: &Schema, column_names: &[&str]) -> Constraint {
    let mut indices = Vec::with_capacity(column_names.len());

    for column_name in column_names {
        match schema.index_of(column_name) {
            Ok(index) => indices.push(index),
            Err(_) => panic!("primary key column '{column_name}' not found in schema"),
        }
    }

    Constraint::PrimaryKey(indices)
}

/// Get the SQL statements from the specified query file
pub fn get_query_sql(base_query_path: &str, query: usize) -> Result<Vec<String>> {
if query > 0 && query < 100 {
Expand Down Expand Up @@ -327,7 +383,9 @@ impl RunOpt {
.with_file_extension(DEFAULT_PARQUET_EXTENSION)
.with_target_partitions(target_partitions)
.with_collect_stat(state.config().collect_statistics());

let schema = options.infer_schema(&state, &table_path).await?;
let constraints = table_constraints(table, schema.as_ref());

if self.common.debug {
println!(
Expand All @@ -347,9 +405,11 @@ impl RunOpt {
.with_listing_options(options)
.with_schema(schema);

Ok(Arc::new(ListingTable::try_new(config)?.with_cache(
ctx.runtime_env().cache_manager.get_file_statistic_cache(),
)))
let provider = ListingTable::try_new(config)?
.with_constraints(constraints)
.with_cache(ctx.runtime_env().cache_manager.get_file_statistic_cache());

Ok(Arc::new(provider))
}

fn iterations(&self) -> usize {
Expand Down
38 changes: 37 additions & 1 deletion benchmarks/src/tpch/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
use arrow::datatypes::SchemaBuilder;
use datafusion::{
arrow::datatypes::{DataType, Field, Schema},
common::plan_err,
common::{Constraint, Constraints, plan_err},
error::Result,
};
use std::fs;
Expand Down Expand Up @@ -138,6 +138,42 @@ pub fn get_tpch_table_schema(table: &str) -> Schema {
}
}

/// Primary-key definitions for the TPC-H tables.
///
/// Each entry is a `(table name, key column names)` pair; tables with a
/// composite key list more than one column. `table_constraints` looks a table
/// up here by name, and the column names are resolved to column indices
/// against the table's schema before the constraint is built.
static TPCH_PRIMARY_KEYS: &[(&str, &[&str])] = &[
("region", &["r_regionkey"]),
("nation", &["n_nationkey"]),
("part", &["p_partkey"]),
("supplier", &["s_suppkey"]),
("partsupp", &["ps_partkey", "ps_suppkey"]),
("customer", &["c_custkey"]),
("orders", &["o_orderkey"]),
("lineitem", &["l_orderkey", "l_linenumber"]),
];

/// Build the [`Constraints`] of the TPC-H table named `table`.
///
/// Only the primary key is reported. TPC-H also defines foreign keys, but
/// those are currently unsupported. The key column names come from
/// `TPCH_PRIMARY_KEYS` and are resolved to indices against `schema`.
///
/// # Panics
///
/// Panics if `table` has no entry in `TPCH_PRIMARY_KEYS`, or if one of its
/// key columns is missing from `schema`.
fn table_constraints(table: &str, schema: &Schema) -> Constraints {
    let key_columns = match TPCH_PRIMARY_KEYS.iter().find(|entry| entry.0 == table) {
        Some(entry) => entry.1,
        None => unimplemented!("unknown TPC-H table: {table}"),
    };

    let key = primary_key(schema, key_columns);
    Constraints::new_unverified(vec![key])
}

/// Build a primary-key [`Constraint`] over the columns named in
/// `column_names`, translating each name into its index within `schema`.
///
/// The indices appear in the same order as `column_names`.
///
/// # Panics
///
/// Panics if any of the named columns is not present in `schema`.
fn primary_key(schema: &Schema, column_names: &[&str]) -> Constraint {
    let mut indices = Vec::with_capacity(column_names.len());

    for column_name in column_names {
        match schema.index_of(column_name) {
            Ok(index) => indices.push(index),
            Err(_) => panic!("primary key column '{column_name}' not found in schema"),
        }
    }

    Constraint::PrimaryKey(indices)
}

/// Get the SQL statements from the specified query file
pub fn get_query_sql(query: usize) -> Result<Vec<String>> {
get_query_sql_for_scale_factor(query, 1.0)
Expand Down
13 changes: 9 additions & 4 deletions benchmarks/src/tpch/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::sync::Arc;

use super::{
TPCH_QUERY_END_ID, TPCH_QUERY_START_ID, TPCH_TABLES, get_query_sql_for_scale_factor,
get_tbl_tpch_table_schema, get_tpch_table_schema,
get_tbl_tpch_table_schema, get_tpch_table_schema, table_constraints,
};
use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats};

Expand Down Expand Up @@ -324,12 +324,15 @@ impl RunOpt {
.with_file_extension(extension)
.with_target_partitions(target_partitions)
.with_collect_stat(state.config().collect_statistics());

let schema = match table_format {
"parquet" => options.infer_schema(&state, &table_path).await?,
"tbl" => Arc::new(get_tbl_tpch_table_schema(table)),
"csv" => Arc::new(get_tpch_table_schema(table)),
_ => unreachable!(),
};
let constraints = table_constraints(table, schema.as_ref());

let options = if self.sorted {
let key_column_name = schema.fields()[0].name();
options
Expand All @@ -342,9 +345,11 @@ impl RunOpt {
.with_listing_options(options)
.with_schema(schema);

Ok(Arc::new(ListingTable::try_new(config)?.with_cache(
ctx.runtime_env().cache_manager.get_file_statistic_cache(),
)))
let provider = ListingTable::try_new(config)?
.with_constraints(constraints)
.with_cache(ctx.runtime_env().cache_manager.get_file_statistic_cache());

Ok(Arc::new(provider))
}

fn iterations(&self) -> usize {
Expand Down
80 changes: 19 additions & 61 deletions datafusion/catalog/src/memory/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::error::Result;
use datafusion_common::{Constraints, DFSchema, SchemaExt, not_impl_err, plan_err};
use datafusion_common_runtime::JoinSet;
use datafusion_datasource::memory::{MemSink, MemorySourceConfig};
use datafusion_datasource::sink::DataSinkExec;
use datafusion_datasource::source::DataSourceExec;
Expand All @@ -44,13 +43,12 @@ use datafusion_physical_expr::{
use datafusion_physical_plan::repartition::RepartitionExec;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning,
PlanProperties, common,
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
collect_partitioned,
};
use datafusion_session::Session;

use async_trait::async_trait;
use futures::StreamExt;
use log::debug;
use parking_lot::Mutex;
use tokio::sync::RwLock;
Expand Down Expand Up @@ -145,68 +143,28 @@ impl MemTable {
state: &dyn Session,
) -> Result<Self> {
let schema = t.schema();
let constraints = t.constraints();
let exec = t.scan(state, None, &[], None).await?;
let partition_count = exec.output_partitioning().partition_count();

let mut join_set = JoinSet::new();

for part_idx in 0..partition_count {
let task = state.task_ctx();
let exec = Arc::clone(&exec);
join_set.spawn(async move {
let stream = exec.execute(part_idx, task)?;
common::collect(stream).await
});
}

let mut data: Vec<Vec<RecordBatch>> =
Vec::with_capacity(exec.output_partitioning().partition_count());

while let Some(result) = join_set.join_next().await {
match result {
Ok(res) => data.push(res?),
Err(e) => {
if e.is_panic() {
std::panic::resume_unwind(e.into_panic());
} else {
unreachable!();
}
}
}
}
let constraints = t.constraints().cloned().unwrap_or_default();

let mut exec = DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
&data,
Arc::clone(&schema),
None,
)?));
if let Some(cons) = constraints {
exec = exec.with_constraints(cons.clone());
}

if let Some(num_partitions) = output_partitions {
let exec = t.scan(state, None, &[], None).await?;
let data = collect_partitioned(exec, state.task_ctx()).await?;

// Optionally repartition the collected batches.
let data = if let Some(num_partitions) = output_partitions {
let source = DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
&data,
Arc::clone(&schema),
None,
)?));
let exec = RepartitionExec::try_new(
Arc::new(exec),
Arc::new(source),
Partitioning::RoundRobinBatch(num_partitions),
)?;
collect_partitioned(Arc::new(exec), state.task_ctx()).await?
} else {
data
};

// execute and collect results
let mut output_partitions = vec![];
for i in 0..exec.properties().output_partitioning().partition_count() {
// execute this *output* partition and collect all batches
let task_ctx = state.task_ctx();
let mut stream = exec.execute(i, task_ctx)?;
let mut batches = vec![];
while let Some(result) = stream.next().await {
batches.push(result?);
}
output_partitions.push(batches);
}

return MemTable::try_new(Arc::clone(&schema), output_partitions);
}
MemTable::try_new(Arc::clone(&schema), data)
MemTable::try_new(schema, data).map(|table| table.with_constraints(constraints))
}
}

Expand Down
18 changes: 13 additions & 5 deletions datafusion/core/benches/sql_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,23 @@ fn create_context() -> SessionContext {

/// Register the table definitions as a MemTable with the context and return the
/// context
#[expect(clippy::needless_pass_by_value)]
fn register_defs(ctx: SessionContext, defs: Vec<TableDef>) -> SessionContext {
defs.iter().for_each(|TableDef { name, schema }| {
for TableDef {
name,
schema,
constraints,
} in defs
{
ctx.register_table(
name,
Arc::new(MemTable::try_new(Arc::new(schema.clone()), vec![vec![]]).unwrap()),
&name,
Arc::new(
MemTable::try_new(Arc::new(schema), vec![vec![]])
.unwrap()
.with_constraints(constraints),
),
)
.unwrap();
});
}
ctx
}

Expand Down
Loading
Loading