diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 046ccf0b1c..3d7926662f 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -448,6 +448,16 @@ object CometConf extends ShimCometConf { .intConf .createWithDefault(1) + val COMET_SHUFFLE_BATCH_STASH_ENABLED: ConfigEntry[Boolean] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.batchStash.enabled") + .category(CATEGORY_SHUFFLE) + .doc( + "When enabled, batches passed between a native child plan and a native shuffle " + + "writer are transferred via an opaque handle instead of Arrow FFI, avoiding " + + "unnecessary serialization overhead.") + .booleanConf + .createWithDefault(true) + val COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.columnar.shuffle.async.enabled") .category(CATEGORY_SHUFFLE) diff --git a/native/Cargo.lock b/native/Cargo.lock index 480f7ad06d..b5c7f2b0c7 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -2070,9 +2070,7 @@ dependencies = [ "num", "rand 0.10.0", "regex", - "serde", "serde_json", - "thiserror 2.0.18", "tokio", "twox-hash", ] @@ -5885,7 +5883,7 @@ version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" dependencies = [ - "indexmap 2.13.0", + "indexmap 2.13.1", "itoa", "memchr", "serde", diff --git a/native/core/src/execution/batch_stash.rs b/native/core/src/execution/batch_stash.rs new file mode 100644 index 0000000000..c655069c65 --- /dev/null +++ b/native/core/src/execution/batch_stash.rs @@ -0,0 +1,124 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Global registry for passing RecordBatch values between native execution contexts +//! via opaque u64 handles, without Arrow FFI serialization. + +use arrow::record_batch::RecordBatch; +use once_cell::sync::Lazy; +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Mutex; + +/// Counter for generating unique handles. +static NEXT_HANDLE: AtomicU64 = AtomicU64::new(1); + +/// Global stash mapping handles to RecordBatch values. +/// Entries are removed by `take()` when the downstream ScanExec consumes them, +/// so there is no leak under normal operation. The stash lives for the process +/// lifetime but is effectively empty between query executions. +static STASH: Lazy>> = Lazy::new(|| Mutex::new(HashMap::new())); + +/// Store a RecordBatch in the global stash and return a unique handle. +pub(crate) fn stash(batch: RecordBatch) -> u64 { + let handle = NEXT_HANDLE.fetch_add(1, Ordering::Relaxed); + STASH + .lock() + .unwrap_or_else(|e| e.into_inner()) + .insert(handle, batch); + handle +} + +/// Remove and return the RecordBatch associated with the given handle. +/// +/// Returns `None` if the handle does not exist in the stash. +pub(crate) fn take(handle: u64) -> Option { + STASH + .lock() + .unwrap_or_else(|e| e.into_inner()) + .remove(&handle) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use std::sync::Arc; + + fn make_batch(values: Vec) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let array = Arc::new(Int32Array::from(values)); + RecordBatch::try_new(schema, vec![array]).unwrap() + } + + #[test] + fn test_stash_and_take() { + let batch = make_batch(vec![1, 2, 3]); + let num_rows = batch.num_rows(); + + let handle = stash(batch); + let retrieved = take(handle).expect("expected batch to be present"); + + assert_eq!(retrieved.num_rows(), num_rows); + let col = retrieved + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.values(), &[1, 2, 3]); + } + + #[test] + fn test_take_removes_entry() { + let batch = make_batch(vec![10, 20]); + let handle = stash(batch); + + // First take returns the batch. + assert!(take(handle).is_some()); + // Second take finds nothing. + assert!(take(handle).is_none()); + } + + #[test] + fn test_take_unknown_handle() { + // Handle 0 is never issued (counter starts at 1). + assert!(take(0).is_none()); + // A large handle that was never issued. + assert!(take(u64::MAX).is_none()); + } + + #[test] + fn test_handles_are_unique() { + let batch1 = make_batch(vec![1]); + let batch2 = make_batch(vec![2]); + let batch3 = make_batch(vec![3]); + + let h1 = stash(batch1); + let h2 = stash(batch2); + let h3 = stash(batch3); + + assert_ne!(h1, h2); + assert_ne!(h2, h3); + assert_ne!(h1, h3); + + // Clean up. + take(h1); + take(h2); + take(h3); + } +} diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 93f75bae96..f24e4c7f33 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -617,6 +617,201 @@ fn prepare_output( Ok(num_rows as jlong) } +/// Stash the output RecordBatch in the BatchStash and return the handle. +fn stash_output(output_batch: RecordBatch) -> CometResult { + let handle = crate::execution::batch_stash::stash(output_batch); + Ok(handle as jlong) +} + +/// How to handle output batches from the execution plan. +enum OutputMode<'a> { + /// Export via Arrow FFI to the provided addresses. + Ffi { + array_addrs: JLongArray<'a>, + schema_addrs: JLongArray<'a>, + validate: bool, + }, + /// Stash in BatchStash and return handle. + Stash, +} + +impl OutputMode<'_> { + fn handle_batch(&self, env: &mut Env, batch: RecordBatch) -> CometResult { + match self { + OutputMode::Ffi { + array_addrs, + schema_addrs, + validate, + } => { + // Safety: JLongArray is a raw JNI reference that remains valid for the + // duration of the JNI call. We reborrow it here since prepare_output + // only reads from it. + let array_addrs = unsafe { JLongArray::from_raw(env, array_addrs.as_raw()) }; + let schema_addrs = unsafe { JLongArray::from_raw(env, schema_addrs.as_raw()) }; + prepare_output(env, array_addrs, schema_addrs, batch, *validate) + } + OutputMode::Stash => stash_output(batch), + } + } +} + +/// Shared execution logic for `executePlan` and `executePlanBatchHandle`. +fn execute_plan_impl( + env: &mut Env, + stage_id: jint, + partition: jint, + exec_context: &mut ExecutionContext, + output_mode: &OutputMode, +) -> CometResult { + let tracing_enabled = exec_context.tracing_enabled; + let owned_label; + let tracing_label = if tracing_enabled { + owned_label = exec_context.tracing_event_name.clone(); + owned_label.as_str() + } else { + "" + }; + + let result = with_trace(tracing_label, tracing_enabled, || { + let exec_context_id = exec_context.id; + + // Initialize the execution stream. + // Because we don't know if input arrays are dictionary-encoded when we create + // query plan, we need to defer stream initialization to first time execution. + if exec_context.root_op.is_none() { + let start = Instant::now(); + let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition) + .with_exec_id(exec_context_id); + let (scans, shuffle_scans, root_op) = planner.create_plan( + &exec_context.spark_plan, + &mut exec_context.input_sources.clone(), + exec_context.partition_count, + )?; + let physical_plan_time = start.elapsed(); + + exec_context.plan_creation_time += physical_plan_time; + exec_context.scans = scans; + exec_context.shuffle_scans = shuffle_scans; + + if exec_context.explain_native { + let formatted_plan_str = + DisplayableExecutionPlan::new(root_op.native_plan.as_ref()).indent(true); + info!("Comet native query plan:\n{formatted_plan_str:}"); + } + + let task_ctx = exec_context.session_ctx.task_ctx(); + // Each Comet native execution corresponds to a single Spark partition, + // so we should always execute partition 0. + let stream = root_op.native_plan.execute(0, task_ctx)?; + + if exec_context.scans.is_empty() && exec_context.shuffle_scans.is_empty() { + // No JVM data sources -- spawn onto tokio so the executor + // thread parks in blocking_recv instead of busy-polling. + let (tx, rx) = mpsc::channel(2); + let mut stream = stream; + get_runtime().spawn(async move { + let result = std::panic::AssertUnwindSafe(async { + while let Some(batch) = stream.next().await { + if tx.send(batch).await.is_err() { + break; + } + } + }) + .catch_unwind() + .await; + + if let Err(panic) = result { + let msg = match panic.downcast_ref::<&str>() { + Some(s) => s.to_string(), + None => match panic.downcast_ref::() { + Some(s) => s.clone(), + None => "unknown panic".to_string(), + }, + }; + let _ = tx + .send(Err(DataFusionError::Execution(format!( + "native panic: {msg}" + )))) + .await; + } + }); + exec_context.batch_receiver = Some(rx); + } else { + exec_context.stream = Some(stream); + } + exec_context.root_op = Some(root_op); + } else { + pull_input_batches(exec_context)?; + } + + if let Some(rx) = &mut exec_context.batch_receiver { + match rx.blocking_recv() { + Some(Ok(batch)) => { + update_metrics(env, exec_context)?; + return output_mode.handle_batch(env, batch); + } + Some(Err(e)) => { + return Err(e.into()); + } + None => { + log_plan_metrics(exec_context, stage_id, partition); + return Ok(-1); + } + } + } + + // ScanExec path: busy-poll to interleave JVM batch pulls with stream polling + get_runtime().block_on(async { + loop { + let next_item = exec_context.stream.as_mut().unwrap().next(); + let poll_output = poll!(next_item); + + exec_context.poll_count_since_metrics_check += 1; + if exec_context.poll_count_since_metrics_check >= 100 { + exec_context.poll_count_since_metrics_check = 0; + if let Some(interval) = exec_context.metrics_update_interval { + let now = Instant::now(); + if now - exec_context.metrics_last_update_time >= interval { + update_metrics(env, exec_context)?; + exec_context.metrics_last_update_time = now; + } + } + if exec_context.tracing_enabled { + log_memory_usage( + &exec_context.tracing_memory_metric_name, + total_reserved_for_thread(exec_context.rust_thread_id) as u64, + ); + } + } + + match poll_output { + Poll::Ready(Some(output)) => { + return output_mode.handle_batch(env, output?); + } + Poll::Ready(None) => { + log_plan_metrics(exec_context, stage_id, partition); + return Ok(-1); + } + Poll::Pending => { + tokio::task::block_in_place(|| pull_input_batches(exec_context))?; + } + } + } + }) + }); + + if exec_context.tracing_enabled { + #[cfg(feature = "jemalloc")] + log_jemalloc_usage(); + log_memory_usage( + &exec_context.tracing_memory_metric_name, + total_reserved_for_thread(exec_context.rust_thread_id) as u64, + ); + } + + result +} + /// Pull the next input from JVM. Note that we cannot pull input batches in /// `ScanStream.poll_next` when the execution stream is polled for output. /// Because the input source could be another native execution stream, which @@ -650,182 +845,31 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( schema_addrs: JLongArray, ) -> jlong { try_unwrap_or_throw(&e, |env| { - // Retrieve the query let exec_context = get_execution_context(exec_context); - - let tracing_enabled = exec_context.tracing_enabled; - // Clone the label only when tracing is enabled. The clone is needed - // because the closure below mutably borrows exec_context. - let owned_label; - let tracing_label = if tracing_enabled { - owned_label = exec_context.tracing_event_name.clone(); - owned_label.as_str() - } else { - "" + let output_mode = OutputMode::Ffi { + array_addrs, + schema_addrs, + validate: exec_context.debug_native, }; + execute_plan_impl(env, stage_id, partition, exec_context, &output_mode) + }) +} - let result = with_trace(tracing_label, tracing_enabled, || { - let exec_context_id = exec_context.id; - - // Initialize the execution stream. - // Because we don't know if input arrays are dictionary-encoded when we create - // query plan, we need to defer stream initialization to first time execution. - if exec_context.root_op.is_none() { - let start = Instant::now(); - let planner = - PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition) - .with_exec_id(exec_context_id); - let (scans, shuffle_scans, root_op) = planner.create_plan( - &exec_context.spark_plan, - &mut exec_context.input_sources.clone(), - exec_context.partition_count, - )?; - let physical_plan_time = start.elapsed(); - - exec_context.plan_creation_time += physical_plan_time; - exec_context.scans = scans; - exec_context.shuffle_scans = shuffle_scans; - - if exec_context.explain_native { - let formatted_plan_str = - DisplayableExecutionPlan::new(root_op.native_plan.as_ref()).indent(true); - info!("Comet native query plan:\n{formatted_plan_str:}"); - } - - let task_ctx = exec_context.session_ctx.task_ctx(); - // Each Comet native execution corresponds to a single Spark partition, - // so we should always execute partition 0. - let stream = root_op.native_plan.execute(0, task_ctx)?; - - if exec_context.scans.is_empty() && exec_context.shuffle_scans.is_empty() { - // No JVM data sources — spawn onto tokio so the executor - // thread parks in blocking_recv instead of busy-polling. - // - // Channel capacity of 2 allows the producer to work one batch - // ahead while the consumer processes the current one via JNI, - // without buffering excessive memory. Increasing this would - // trade memory for latency hiding if JNI/FFI overhead dominates; - // decreasing to 1 would serialize production and consumption. - let (tx, rx) = mpsc::channel(2); - let mut stream = stream; - get_runtime().spawn(async move { - let result = std::panic::AssertUnwindSafe(async { - while let Some(batch) = stream.next().await { - if tx.send(batch).await.is_err() { - break; - } - } - }) - .catch_unwind() - .await; - - if let Err(panic) = result { - let msg = match panic.downcast_ref::<&str>() { - Some(s) => s.to_string(), - None => match panic.downcast_ref::() { - Some(s) => s.clone(), - None => "unknown panic".to_string(), - }, - }; - let _ = tx - .send(Err(DataFusionError::Execution(format!( - "native panic: {msg}" - )))) - .await; - } - }); - exec_context.batch_receiver = Some(rx); - } else { - exec_context.stream = Some(stream); - } - exec_context.root_op = Some(root_op); - } else { - // Pull input batches - pull_input_batches(exec_context)?; - } - - if let Some(rx) = &mut exec_context.batch_receiver { - match rx.blocking_recv() { - Some(Ok(batch)) => { - update_metrics(env, exec_context)?; - return prepare_output( - env, - array_addrs, - schema_addrs, - batch, - exec_context.debug_native, - ); - } - Some(Err(e)) => { - return Err(e.into()); - } - None => { - log_plan_metrics(exec_context, stage_id, partition); - return Ok(-1); - } - } - } - - // ScanExec path: busy-poll to interleave JVM batch pulls with stream polling - get_runtime().block_on(async { - loop { - let next_item = exec_context.stream.as_mut().unwrap().next(); - let poll_output = poll!(next_item); - - // Only check time/tracing every 100 polls to reduce overhead - exec_context.poll_count_since_metrics_check += 1; - if exec_context.poll_count_since_metrics_check >= 100 { - exec_context.poll_count_since_metrics_check = 0; - if let Some(interval) = exec_context.metrics_update_interval { - let now = Instant::now(); - if now - exec_context.metrics_last_update_time >= interval { - update_metrics(env, exec_context)?; - exec_context.metrics_last_update_time = now; - } - } - if exec_context.tracing_enabled { - log_memory_usage( - &exec_context.tracing_memory_metric_name, - total_reserved_for_thread(exec_context.rust_thread_id) as u64, - ); - } - } - - match poll_output { - Poll::Ready(Some(output)) => { - return prepare_output( - env, - array_addrs, - schema_addrs, - output?, - exec_context.debug_native, - ); - } - Poll::Ready(None) => { - log_plan_metrics(exec_context, stage_id, partition); - return Ok(-1); - } - Poll::Pending => { - // JNI call to pull batches from JVM into ScanExec operators. - // block_in_place lets tokio move other tasks off this worker - // while we wait for JVM data. - tokio::task::block_in_place(|| pull_input_batches(exec_context))?; - } - } - } - }) - }); - - if exec_context.tracing_enabled { - #[cfg(feature = "jemalloc")] - log_jemalloc_usage(); - log_memory_usage( - &exec_context.tracing_memory_metric_name, - total_reserved_for_thread(exec_context.rust_thread_id) as u64, - ); - } - - result +/// Like executePlan but stashes the output RecordBatch and returns a handle instead of +/// exporting via Arrow FFI. Used when output feeds directly into another native plan. +/// # Safety +/// This function is inherently unsafe since it deals with raw pointers passed from JNI. +#[no_mangle] +pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlanBatchHandle( + e: EnvUnowned, + _class: JClass, + stage_id: jint, + partition: jint, + exec_context: jlong, +) -> jlong { + try_unwrap_or_throw(&e, |env| { + let exec_context = get_execution_context(exec_context); + execute_plan_impl(env, stage_id, partition, exec_context, &OutputMode::Stash) }) } diff --git a/native/core/src/execution/mod.rs b/native/core/src/execution/mod.rs index f556fce41c..b25fe277ff 100644 --- a/native/core/src/execution/mod.rs +++ b/native/core/src/execution/mod.rs @@ -16,6 +16,7 @@ // under the License. //! PoC of vectorization execution through JNI to Rust. +pub(crate) mod batch_stash; pub mod columnar_to_row; pub mod expressions; pub mod jni_api; diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index 90bb741b5e..87db054a85 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -77,6 +77,9 @@ pub struct ScanExec { baseline_metrics: BaselineMetrics, /// Whether native code can assume ownership of batches that it receives arrow_ffi_safe: bool, + /// When true, input comes from a CometHandleBatchIterator and batches are + /// retrieved from the BatchStash instead of via Arrow FFI import. + pub handle_mode: bool, } impl ScanExec { @@ -113,6 +116,7 @@ impl ScanExec { baseline_metrics, schema, arrow_ffi_safe, + handle_mode: false, }) } @@ -141,12 +145,19 @@ impl ScanExec { let mut current_batch = self.batch.try_lock().unwrap(); if current_batch.is_none() { - let next_batch = ScanExec::get_next( - self.exec_context_id, - self.input_source.as_ref().unwrap().as_obj(), - self.data_types.len(), - self.arrow_ffi_safe, - )?; + let next_batch = if self.handle_mode { + ScanExec::get_next_handle( + self.exec_context_id, + self.input_source.as_ref().unwrap().as_obj(), + )? + } else { + ScanExec::get_next( + self.exec_context_id, + self.input_source.as_ref().unwrap().as_obj(), + self.data_types.len(), + self.arrow_ffi_safe, + )? + }; *current_batch = Some(next_batch); } @@ -257,6 +268,37 @@ impl ScanExec { }) } + /// Pull next input batch from a CometHandleBatchIterator via batch stash handle. + fn get_next_handle(exec_context_id: i64, iter: &JObject) -> Result { + if exec_context_id == TEST_EXEC_CONTEXT_ID { + return Ok(InputBatch::EOF); + } + + if iter.is_null() { + return Err(CometError::from(ExecutionError::GeneralError(format!( + "Null handle batch iterator object. Plan id: {exec_context_id}" + )))); + } + + JVMClasses::with_env(|env| { + let handle: i64 = unsafe { + jni_call!(env, + comet_handle_batch_iterator(iter).next_handle() -> i64)? + }; + + if handle == -1 { + return Ok(InputBatch::EOF); + } + + match crate::execution::batch_stash::take(handle as u64) { + Some(batch) => Ok(InputBatch::Complete(batch)), + None => Err(CometError::from(ExecutionError::GeneralError(format!( + "Batch stash handle {handle} not found" + )))), + } + }) + } + /// Allocates Arrow FFI structures and calls JNI to get the next batch data. /// Returns the number of rows and the allocated array/schema addresses. fn allocate_and_fetch_batch( @@ -517,8 +559,7 @@ impl Stream for ScanStream<'_> { let mut timer = self.baseline_metrics.elapsed_compute().timer(); let mut scan_batch = self.scan.batch.try_lock().unwrap(); - let input_batch = &*scan_batch; - let input_batch = if let Some(batch) = input_batch { + let input_batch = if let Some(batch) = scan_batch.take() { batch } else { timer.stop(); @@ -527,15 +568,28 @@ impl Stream for ScanStream<'_> { let result = match input_batch { InputBatch::EOF => Poll::Ready(None), - InputBatch::Batch(columns, num_rows) => { - self.baseline_metrics.record_output(*num_rows); - let maybe_batch = self.build_record_batch(columns, *num_rows); + InputBatch::Batch(ref columns, num_rows) => { + self.baseline_metrics.record_output(num_rows); + let maybe_batch = self.build_record_batch(columns, num_rows); Poll::Ready(Some(maybe_batch)) } + InputBatch::Complete(batch) => { + self.baseline_metrics.record_output(batch.num_rows()); + let columns = batch.columns(); + let num_rows = batch.num_rows(); + if columns.len() == self.schema.fields().len() { + // Column counts match. Use build_record_batch to handle any + // type differences (e.g., timestamp timezone casting). + let maybe_batch = self.build_record_batch(columns, num_rows); + Poll::Ready(Some(maybe_batch)) + } else { + // Column count mismatch (e.g., empty schema scan). + // Return the stashed batch as-is since it's already valid. + Poll::Ready(Some(Ok(batch))) + } + } }; - *scan_batch = None; - timer.stop(); result @@ -558,6 +612,11 @@ pub enum InputBatch { /// It is possible to have a zero-column batch with a non-zero number of rows, /// i.e. reading empty schema from scan. Batch(Vec, usize), + + /// A complete RecordBatch retrieved from the BatchStash. May still + /// go through `build_record_batch` for schema reconciliation (e.g., + /// timestamp timezone casting) when column counts match. + Complete(RecordBatch), } impl InputBatch { diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs index 92c4dc8780..53c839ec97 100644 --- a/native/core/src/execution/operators/shuffle_scan.rs +++ b/native/core/src/execution/operators/shuffle_scan.rs @@ -311,8 +311,7 @@ impl Stream for ShuffleScanStream { let mut timer = self.baseline_metrics.elapsed_compute().timer(); let mut scan_batch = self.shuffle_scan.batch.try_lock().unwrap(); - let input_batch = &*scan_batch; - let input_batch = if let Some(batch) = input_batch { + let input_batch = if let Some(batch) = scan_batch.take() { batch } else { timer.stop(); @@ -321,10 +320,10 @@ impl Stream for ShuffleScanStream { let result = match input_batch { InputBatch::EOF => Poll::Ready(None), - InputBatch::Batch(columns, num_rows) => { - self.baseline_metrics.record_output(*num_rows); + InputBatch::Batch(ref columns, num_rows) => { + self.baseline_metrics.record_output(num_rows); let options = - arrow::array::RecordBatchOptions::new().with_row_count(Some(*num_rows)); + arrow::array::RecordBatchOptions::new().with_row_count(Some(num_rows)); let maybe_batch = arrow::array::RecordBatch::try_new_with_options( self.shuffle_scan.schema(), columns.clone(), @@ -333,10 +332,12 @@ impl Stream for ShuffleScanStream { .map_err(|e| arrow_datafusion_err!(e)); Poll::Ready(Some(maybe_batch)) } + InputBatch::Complete(batch) => { + self.baseline_metrics.record_output(batch.num_rows()); + Poll::Ready(Some(Ok(batch))) + } }; - *scan_batch = None; - timer.stop(); result diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index ac35925ace..4117a69fd3 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1295,8 +1295,10 @@ impl PhysicalPlanner { Some(inputs.remove(0)) }; + let use_batch_stash = scan.batch_stash_handle; + // The `ScanExec` operator will take actual arrays from Spark during execution - let scan = ScanExec::new( + let mut scan = ScanExec::new( self.exec_context_id, input_source, &scan.source, @@ -1304,6 +1306,10 @@ impl PhysicalPlanner { scan.arrow_ffi_safe, )?; + if use_batch_stash { + scan.handle_mode = true; + } + Ok(( vec![scan.clone()], vec![], @@ -3772,6 +3778,7 @@ mod tests { }], source: "".to_string(), arrow_ffi_safe: false, + batch_stash_handle: false, })), }; @@ -3838,6 +3845,7 @@ mod tests { }], source: "".to_string(), arrow_ffi_safe: false, + batch_stash_handle: false, })), }; @@ -4047,6 +4055,7 @@ mod tests { fields: vec![create_proto_datatype()], source: "".to_string(), arrow_ffi_safe: false, + batch_stash_handle: false, })), } } @@ -4090,6 +4099,7 @@ mod tests { ], source: "".to_string(), arrow_ffi_safe: false, + batch_stash_handle: false, })), }; @@ -4213,6 +4223,7 @@ mod tests { ], source: "".to_string(), arrow_ffi_safe: false, + batch_stash_handle: false, })), }; @@ -4696,6 +4707,7 @@ mod tests { ], source: "".to_string(), arrow_ffi_safe: false, + batch_stash_handle: false, })), }; diff --git a/native/jni-bridge/src/handle_batch_iterator.rs b/native/jni-bridge/src/handle_batch_iterator.rs new file mode 100644 index 0000000000..b37637ce0b --- /dev/null +++ b/native/jni-bridge/src/handle_batch_iterator.rs @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use jni::signature::Primitive; +use jni::{ + errors::Result as JniResult, + objects::{JClass, JMethodID}, + signature::ReturnType, + strings::JNIString, + Env, +}; + +/// A struct that holds JNI methods for the JVM `CometHandleBatchIterator` class. +#[allow(dead_code)] // we need to keep references to Java items to prevent GC +pub struct CometHandleBatchIterator<'a> { + pub class: JClass<'a>, + pub method_next_handle: JMethodID, + pub method_next_handle_ret: ReturnType, +} + +impl<'a> CometHandleBatchIterator<'a> { + pub const JVM_CLASS: &'static str = "org/apache/comet/CometHandleBatchIterator"; + + pub fn new(env: &mut Env<'a>) -> JniResult> { + let class = env.find_class(JNIString::new(Self::JVM_CLASS))?; + + Ok(CometHandleBatchIterator { + class, + method_next_handle: env.get_method_id( + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("nextHandle"), + jni::jni_sig!("()J"), + )?, + method_next_handle_ret: ReturnType::Primitive(Primitive::Long), + }) + } +} diff --git a/native/jni-bridge/src/lib.rs b/native/jni-bridge/src/lib.rs index 5b0c0a4a56..82b85d4be4 100644 --- a/native/jni-bridge/src/lib.rs +++ b/native/jni-bridge/src/lib.rs @@ -181,11 +181,13 @@ pub use comet_exec::*; mod batch_iterator; mod comet_metric_node; mod comet_task_memory_manager; +mod handle_batch_iterator; mod shuffle_block_iterator; use batch_iterator::CometBatchIterator; pub use comet_metric_node::*; pub use comet_task_memory_manager::*; +use handle_batch_iterator::CometHandleBatchIterator; use shuffle_block_iterator::CometShuffleBlockIterator; /// The JVM classes that are used in the JNI calls. @@ -214,6 +216,8 @@ pub struct JVMClasses<'a> { pub comet_batch_iterator: CometBatchIterator<'a>, /// The CometShuffleBlockIterator class. Used for iterating over shuffle blocks. pub comet_shuffle_block_iterator: CometShuffleBlockIterator<'a>, + /// The CometHandleBatchIterator class. Used for passing batch handles between native contexts. + pub comet_handle_batch_iterator: CometHandleBatchIterator<'a>, /// The CometTaskMemoryManager used for interacting with JVM side to /// acquire & release native memory. pub comet_task_memory_manager: CometTaskMemoryManager<'a>, @@ -286,6 +290,7 @@ impl JVMClasses<'_> { comet_exec: CometExec::new(env).unwrap(), comet_batch_iterator: CometBatchIterator::new(env).unwrap(), comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(), + comet_handle_batch_iterator: CometHandleBatchIterator::new(env).unwrap(), comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), } }); diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index fb438b26a4..df5d5b308a 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -84,6 +84,9 @@ message Scan { string source = 2; // Whether native code can assume ownership of batches that it receives bool arrow_ffi_safe = 3; + // When true, the input is a CometHandleBatchIterator and batches should be + // retrieved from the BatchStash instead of via Arrow FFI import. + bool batch_stash_handle = 4; } message ShuffleScan { diff --git a/spark/src/main/java/org/apache/comet/CometHandleBatchIterator.java b/spark/src/main/java/org/apache/comet/CometHandleBatchIterator.java new file mode 100644 index 0000000000..98cef1d5e8 --- /dev/null +++ b/spark/src/main/java/org/apache/comet/CometHandleBatchIterator.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet; + +/** + * Iterator that passes opaque native batch handles between two native execution contexts through + * the JVM. Used when a native child plan feeds directly into a native ShuffleWriter, avoiding Arrow + * FFI export/import overhead. + * + *

Called from native ScanExec via JNI. The source CometExecIterator must be in stash mode. + */ +public class CometHandleBatchIterator { + private final CometExecIterator source; + + public CometHandleBatchIterator(CometExecIterator source) { + this.source = source; + } + + /** + * Get the next batch handle from the source iterator. + * + * @return a native batch handle (positive long), or -1 if no more batches. + */ + public long nextHandle() { + return source.nextHandle(); + } +} diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index f0c6373149..1916436600 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -68,7 +68,8 @@ class CometExecIterator( partitionIndex: Int, broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, encryptedFilePaths: Seq[String] = Seq.empty, - shuffleBlockIterators: Map[Int, CometShuffleBlockIterator] = Map.empty) + shuffleBlockIterators: Map[Int, CometShuffleBlockIterator] = Map.empty, + handleInputs: Array[Object] = Array.empty) extends Iterator[ColumnarBatch] with Logging { @@ -79,14 +80,22 @@ class CometExecIterator( private val taskAttemptId = TaskContext.get().taskAttemptId private val taskCPUs = TaskContext.get().cpus() private val cometTaskMemoryManager = new CometTaskMemoryManager(id, taskAttemptId) + // When true, executePlan stashes output batches natively and returns handles + // instead of exporting via Arrow FFI. Used when output feeds a native ShuffleWriter. + private var stashMode: Boolean = false + private var pendingHandle: Long = -1L // Build a mixed array of iterators: CometShuffleBlockIterator for shuffle // scan indices, CometBatchIterator for regular scan indices. - private val inputIterators: Array[Object] = inputs.zipWithIndex.map { - case (_, idx) if shuffleBlockIterators.contains(idx) => - shuffleBlockIterators(idx).asInstanceOf[Object] - case (iterator, _) => - new CometBatchIterator(iterator, nativeUtil).asInstanceOf[Object] - }.toArray + private val inputIterators: Array[Object] = if (handleInputs.nonEmpty) { + handleInputs + } else { + inputs.zipWithIndex.map { + case (_, idx) if shuffleBlockIterators.contains(idx) => + shuffleBlockIterators(idx).asInstanceOf[Object] + case (iterator, _) => + new CometBatchIterator(iterator, nativeUtil).asInstanceOf[Object] + }.toArray + } private val plan = { val conf = SparkEnv.get.conf @@ -189,31 +198,49 @@ class CometExecIterator( override def hasNext: Boolean = { if (closed) return false - if (nextBatch.isDefined) { - return true - } + if (stashMode) { + if (pendingHandle >= 0) return true + val ctx = TaskContext.get() + pendingHandle = nativeLib.executePlanBatchHandle(ctx.stageId(), partitionIndex, plan) + if (pendingHandle == -1L) { + close() + false + } else { + true + } + } else { + if (nextBatch.isDefined) { + return true + } - // Close previous batch if any. - // This is to guarantee safety at the native side before we overwrite the buffer memory - // shared across batches in the native side. - if (prevBatch != null) { - prevBatch.close() - prevBatch = null - } + // Close previous batch if any. + // This is to guarantee safety at the native side before we overwrite the buffer memory + // shared across batches in the native side. + if (prevBatch != null) { + prevBatch.close() + prevBatch = null + } - nextBatch = getNextBatch + nextBatch = getNextBatch - logTrace(s"Task $taskAttemptId memory pool usage is ${cometTaskMemoryManager.getUsed} bytes") + logTrace( + s"Task $taskAttemptId memory pool usage is ${cometTaskMemoryManager.getUsed} bytes") - if (nextBatch.isEmpty) { - close() - false - } else { - true + if (nextBatch.isEmpty) { + close() + false + } else { + true + } } } override def next(): ColumnarBatch = { + if (stashMode) { + throw new UnsupportedOperationException( + "next() should not be called in stash mode. Use nextHandle() instead.") + } + if (currentBatch != null) { // Eagerly release Arrow Arrays in the previous batch currentBatch.close() @@ -230,6 +257,32 @@ class CometExecIterator( currentBatch } + /** Enable stash mode. Must be called before iteration begins. */ + def enableStashMode(): Unit = { + stashMode = true + } + + /** + * In stash mode, advance the native plan and return the batch handle. Returns a positive + * handle, or -1 for EOF. + */ + def nextHandle(): Long = { + if (closed) return -1L + + if (pendingHandle >= 0) { + val h = pendingHandle + pendingHandle = -1L + return h + } + + val ctx = TaskContext.get() + val handle = nativeLib.executePlanBatchHandle(ctx.stageId(), partitionIndex, plan) + if (handle == -1L) { + close() + } + handle + } + def close(): Unit = synchronized { if (!closed) { if (currentBatch != null) { diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index c003bcd138..ff709d2b23 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -95,6 +95,23 @@ class Native extends NativeBase { arrayAddrs: Array[Long], schemaAddrs: Array[Long]): Long + /** + * Execute one step of the native query plan, stashing the output RecordBatch in the native + * BatchStash instead of exporting via Arrow FFI. Returns the stash handle (positive long) or -1 + * for EOF. Used when the output feeds directly into another native plan (e.g., native + * ShuffleWriter) to avoid unnecessary FFI round-trips. + * + * @param stage + * the stage ID, for informational purposes + * @param partition + * the partition ID, for informational purposes + * @param plan + * the address to native query plan. + * @return + * a batch stash handle (positive), or -1 for EOF. + */ + @native def executePlanBatchHandle(stage: Int, partition: Int, plan: Long): Long + /** * Release and drop the native query plan object and context object. * diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index f27d021ac4..cef08b36b2 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.comet.{CometExec, CometMetricNode} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.comet.CometConf +import org.apache.comet.{CometConf, CometExecIterator, CometHandleBatchIterator} import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, QueryPlanSerde} import org.apache.comet.serde.OperatorOuterClass.{CompressionCodec, Operator} import org.apache.comet.serde.QueryPlanSerde.serializeDataType @@ -72,8 +72,20 @@ class CometNativeShuffleWriter[K, V]( val tempDataFilePath = Paths.get(tempDataFilename) val tempIndexFilePath = Paths.get(tempIndexFilename) + // Detect if input comes from a native plan (CometExecIterator) + val nativeIter: Option[CometExecIterator] = + if (CometConf.COMET_SHUFFLE_BATCH_STASH_ENABLED.get()) { + inputs match { + case swi: CometShuffleWriterInputIterator => swi.nativeIterator + case _ => None + } + } else { + None + } + val useHandleMode = nativeIter.isDefined + // Call native shuffle write - val nativePlan = getNativePlan(tempDataFilename, tempIndexFilename) + val nativePlan = getNativePlan(tempDataFilename, tempIndexFilename, useHandleMode) val detailedMetrics = Seq( "elapsed_compute", @@ -93,18 +105,31 @@ class CometNativeShuffleWriter[K, V]( metrics.filterKeys(detailedMetrics.contains) val nativeMetrics = CometMetricNode(nativeSQLMetrics) - // Getting rid of the fake partitionId - val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2) - - val cometIter = CometExec.getCometIterator( - Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]), - outputAttributes.length, - nativePlan, - nativeMetrics, - numParts, - context.partitionId(), - broadcastedHadoopConfForEncryption = None, - encryptedFilePaths = Seq.empty) + val cometIter = nativeIter match { + case Some(childIter) => + // Stash mode: child plan stashes batches, shuffle writer retrieves via handles + childIter.enableStashMode() + val handleIter = new CometHandleBatchIterator(childIter) + CometExec.getCometIteratorWithHandleInputs( + Array(handleIter.asInstanceOf[Object]), + outputAttributes.length, + nativePlan, + nativeMetrics, + numParts, + context.partitionId()) + case None => + // Normal FFI mode: wrap input in CometBatchIterator as before + val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2) + CometExec.getCometIterator( + Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]), + outputAttributes.length, + nativePlan, + nativeMetrics, + numParts, + context.partitionId(), + broadcastedHadoopConfForEncryption = None, + encryptedFilePaths = Seq.empty) + } while (cometIter.hasNext) { cometIter.next() @@ -162,8 +187,14 @@ class CometNativeShuffleWriter[K, V]( case _ => false } - private def getNativePlan(dataFile: String, indexFile: String): Operator = { - val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput") + private def getNativePlan( + dataFile: String, + indexFile: String, + useHandleMode: Boolean = false): Operator = { + val scanBuilder = OperatorOuterClass.Scan + .newBuilder() + .setSource("ShuffleWriterInput") + .setBatchStashHandle(useHandleMode) val opBuilder = OperatorOuterClass.Operator.newBuilder() val scanTypes = outputAttributes.flatten { attr => diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index df2dca0331..40b231e090 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -50,6 +50,7 @@ import com.google.common.base.Objects import org.apache.comet.CometConf import org.apache.comet.CometConf.{COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MODE} +import org.apache.comet.CometExecIterator import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleManagerEnabled, withInfo} import org.apache.comet.serde.{Compatible, OperatorOuterClass, QueryPlanSerde, SupportLevel, Unsupported} import org.apache.comet.serde.operator.CometSink @@ -640,10 +641,16 @@ object CometShuffleExchangeExec None) } + val wrappedRDD = rdd.mapPartitions { iter => + val nativeIter = iter match { + case cei: CometExecIterator => Some(cei) + case _ => None + } + new CometShuffleWriterInputIterator(iter, nativeIter) + } + val dependency = new CometShuffleDependency[Int, ColumnarBatch, ColumnarBatch]( - rdd.map( - (0, _) - ), // adding fake partitionId that is always 0 because ShuffleDependency requires it + wrappedRDD, serializer = serializer, shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(metrics), shuffleType = CometNativeShuffle, @@ -858,3 +865,16 @@ object CometShuffleExchangeExec dependency } } + +/** + * An iterator wrapper that preserves access to the underlying CometExecIterator when present. + * Used by CometNativeShuffleWriter to detect native child plans and enable the batch stash + * optimization. + */ +private[shuffle] class CometShuffleWriterInputIterator( + underlying: Iterator[ColumnarBatch], + val nativeIterator: Option[CometExecIterator]) + extends Iterator[Product2[Int, ColumnarBatch]] { + override def hasNext: Boolean = underlying.hasNext + override def next(): Product2[Int, ColumnarBatch] = (0, underlying.next()) +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 21cbdab974..dec2e7fbd1 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -370,6 +370,29 @@ object CometExec { encryptedFilePaths) } + /** + * Create a CometExecIterator with pre-built input iterators (e.g., CometHandleBatchIterator). + * Bypasses the normal CometBatchIterator wrapping. + */ + def getCometIteratorWithHandleInputs( + handleInputs: Array[Object], + numOutputCols: Int, + nativePlan: Operator, + nativeMetrics: CometMetricNode, + numParts: Int, + partitionIdx: Int): CometExecIterator = { + val bytes = serializeNativePlan(nativePlan) + new CometExecIterator( + newIterId, + Seq.empty, + numOutputCols, + bytes, + nativeMetrics, + numParts, + partitionIdx, + handleInputs = handleInputs) + } + /** * Executes this Comet operator and serialized output ColumnarBatch into bytes. */