From 6dd3950fbff364b999a1c0e7603c8bd5a195d99b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 13 Apr 2026 14:03:24 -0600 Subject: [PATCH 1/2] fix: share unified memory pools across native execution contexts within a task (#3924) --- docs/source/user-guide/latest/tuning.md | 7 +- native/core/src/execution/jni_api.rs | 67 +++++++++++++++++++ .../core/src/execution/memory_pools/config.rs | 5 +- native/core/src/execution/memory_pools/mod.rs | 42 ++++++++---- 4 files changed, 105 insertions(+), 16 deletions(-) diff --git a/docs/source/user-guide/latest/tuning.md b/docs/source/user-guide/latest/tuning.md index 5939e89ef3..ff9acee1f4 100644 --- a/docs/source/user-guide/latest/tuning.md +++ b/docs/source/user-guide/latest/tuning.md @@ -61,7 +61,12 @@ The valid pool types are: - `fair_unified` (default when `spark.memory.offHeap.enabled=true` is set) - `greedy_unified` -The `fair_unified` pool types prevents operators from using more than an even fraction of the available memory +Both pool types are shared across all native execution contexts within the same Spark task. When +Comet executes a shuffle, it runs two native execution contexts concurrently (e.g. one for +pre-shuffle operators and one for the shuffle writer). The shared pool ensures that the combined +memory usage stays within the per-task limit. + +The `fair_unified` pool prevents operators from using more than an even fraction of the available memory (i.e. `pool_size / num_reservations`). This pool works best when you know beforehand the query has multiple operators that will likely all need to spill. Sometimes it will cause spills even when there is sufficient memory in order to leave enough memory for other operators. diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 361deae182..77e9578bbf 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -26,6 +26,9 @@ use crate::{ }, jvm_bridge::{jni_new_global_ref, JVMClasses}, }; +use parking_lot::Mutex; +use std::collections::HashSet; + use arrow::array::{Array, RecordBatch, UInt32Array}; use arrow::compute::{take, TakeOptions}; use arrow::datatypes::DataType as ArrowDataType; @@ -102,6 +105,70 @@ use tikv_jemalloc_ctl::{epoch, stats}; static TOKIO_RUNTIME: OnceLock = OnceLock::new(); +#[cfg(feature = "jemalloc")] +fn log_jemalloc_usage() { + let e = epoch::mib().unwrap(); + let allocated = stats::allocated::mib().unwrap(); + e.advance().unwrap(); + log_memory_usage("jemalloc_allocated", allocated.read().unwrap() as u64); +} + +/// Registry of active memory pools per Rust thread ID. +/// Used to sum memory reservations across all contexts on the same thread for tracing. +type ThreadPoolMap = HashMap>>; + +static THREAD_MEMORY_POOLS: OnceLock> = OnceLock::new(); + +fn get_thread_memory_pools() -> &'static Mutex { + THREAD_MEMORY_POOLS.get_or_init(|| Mutex::new(HashMap::new())) +} + +fn register_memory_pool(thread_id: u64, context_id: i64, pool: Arc) { + get_thread_memory_pools() + .lock() + .entry(thread_id) + .or_default() + .insert(context_id, pool); +} + +/// Unregister a context's pool and return the remaining total reserved for the thread. +fn unregister_and_total(thread_id: u64, context_id: i64) -> usize { + let mut map = get_thread_memory_pools().lock(); + if let Some(pools) = map.get_mut(&thread_id) { + pools.remove(&context_id); + if pools.is_empty() { + map.remove(&thread_id); + return 0; + } + let mut seen = HashSet::new(); + return pools + .values() + .filter_map(|p| { + let ptr = Arc::as_ptr(p) as *const (); + seen.insert(ptr).then(|| p.reserved()) + }) + .sum::(); + } + 0 +} + +fn total_reserved_for_thread(thread_id: u64) -> usize { + let map = get_thread_memory_pools().lock(); + map.get(&thread_id) + .map(|pools| { + // Deduplicate pools that share the same underlying allocation + // (e.g. task-shared pools registered by multiple execution contexts) + let mut seen = HashSet::new(); + pools + .values() + .filter_map(|p| { + let ptr = Arc::as_ptr(p) as *const (); + seen.insert(ptr).then(|| p.reserved()) + }) + .sum::() + }) + .unwrap_or(0) +} fn parse_usize_env_var(name: &str) -> Option { std::env::var_os(name).and_then(|n| n.to_str().and_then(|s| s.parse::().ok())) } diff --git a/native/core/src/execution/memory_pools/config.rs b/native/core/src/execution/memory_pools/config.rs index d30126a99a..83d6c14a36 100644 --- a/native/core/src/execution/memory_pools/config.rs +++ b/native/core/src/execution/memory_pools/config.rs @@ -34,7 +34,10 @@ impl MemoryPoolType { pub(crate) fn is_task_shared(&self) -> bool { matches!( self, - MemoryPoolType::GreedyTaskShared | MemoryPoolType::FairSpillTaskShared + MemoryPoolType::GreedyTaskShared + | MemoryPoolType::FairSpillTaskShared + | MemoryPoolType::FairUnified + | MemoryPoolType::GreedyUnified ) } } diff --git a/native/core/src/execution/memory_pools/mod.rs b/native/core/src/execution/memory_pools/mod.rs index d8b3473353..34f0587537 100644 --- a/native/core/src/execution/memory_pools/mod.rs +++ b/native/core/src/execution/memory_pools/mod.rs @@ -42,22 +42,36 @@ pub(crate) fn create_memory_pool( const NUM_TRACKED_CONSUMERS: usize = 10; match memory_pool_config.pool_type { MemoryPoolType::GreedyUnified => { - // Set Comet memory pool for native - let memory_pool = - CometUnifiedMemoryPool::new(comet_task_memory_manager, task_attempt_id); - Arc::new(TrackConsumersPool::new( - memory_pool, - NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(), - )) + let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap(); + let per_task_memory_pool = + memory_pool_map.entry(task_attempt_id).or_insert_with(|| { + let pool: Arc = Arc::new(TrackConsumersPool::new( + CometUnifiedMemoryPool::new( + Arc::clone(&comet_task_memory_manager), + task_attempt_id, + ), + NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(), + )); + PerTaskMemoryPool::new(pool) + }); + per_task_memory_pool.num_plans += 1; + Arc::clone(&per_task_memory_pool.memory_pool) } MemoryPoolType::FairUnified => { - // Set Comet fair memory pool for native - let memory_pool = - CometFairMemoryPool::new(comet_task_memory_manager, memory_pool_config.pool_size); - Arc::new(TrackConsumersPool::new( - memory_pool, - NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(), - )) + let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap(); + let per_task_memory_pool = + memory_pool_map.entry(task_attempt_id).or_insert_with(|| { + let pool: Arc = Arc::new(TrackConsumersPool::new( + CometFairMemoryPool::new( + Arc::clone(&comet_task_memory_manager), + memory_pool_config.pool_size, + ), + NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(), + )); + PerTaskMemoryPool::new(pool) + }); + per_task_memory_pool.num_plans += 1; + Arc::clone(&per_task_memory_pool.memory_pool) } MemoryPoolType::Greedy => Arc::new(TrackConsumersPool::new( GreedyMemoryPool::new(memory_pool_config.pool_size), From a8529ed8c3561fc7dfe7867c72a0ab99ca0155d6 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 13 Apr 2026 16:27:12 -0600 Subject: [PATCH 2/2] fix: remove uncalled helper functions from backport Remove tracing helper functions (register_memory_pool, unregister_and_total, total_reserved_for_thread, log_jemalloc_usage) whose call sites do not exist on branch-0.14, along with their now-unused imports. --- native/core/src/execution/jni_api.rs | 67 ---------------------------- 1 file changed, 67 deletions(-) diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 77e9578bbf..361deae182 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -26,9 +26,6 @@ use crate::{ }, jvm_bridge::{jni_new_global_ref, JVMClasses}, }; -use parking_lot::Mutex; -use std::collections::HashSet; - use arrow::array::{Array, RecordBatch, UInt32Array}; use arrow::compute::{take, TakeOptions}; use arrow::datatypes::DataType as ArrowDataType; @@ -105,70 +102,6 @@ use tikv_jemalloc_ctl::{epoch, stats}; static TOKIO_RUNTIME: OnceLock = OnceLock::new(); -#[cfg(feature = "jemalloc")] -fn log_jemalloc_usage() { - let e = epoch::mib().unwrap(); - let allocated = stats::allocated::mib().unwrap(); - e.advance().unwrap(); - log_memory_usage("jemalloc_allocated", allocated.read().unwrap() as u64); -} - -/// Registry of active memory pools per Rust thread ID. -/// Used to sum memory reservations across all contexts on the same thread for tracing. -type ThreadPoolMap = HashMap>>; - -static THREAD_MEMORY_POOLS: OnceLock> = OnceLock::new(); - -fn get_thread_memory_pools() -> &'static Mutex { - THREAD_MEMORY_POOLS.get_or_init(|| Mutex::new(HashMap::new())) -} - -fn register_memory_pool(thread_id: u64, context_id: i64, pool: Arc) { - get_thread_memory_pools() - .lock() - .entry(thread_id) - .or_default() - .insert(context_id, pool); -} - -/// Unregister a context's pool and return the remaining total reserved for the thread. -fn unregister_and_total(thread_id: u64, context_id: i64) -> usize { - let mut map = get_thread_memory_pools().lock(); - if let Some(pools) = map.get_mut(&thread_id) { - pools.remove(&context_id); - if pools.is_empty() { - map.remove(&thread_id); - return 0; - } - let mut seen = HashSet::new(); - return pools - .values() - .filter_map(|p| { - let ptr = Arc::as_ptr(p) as *const (); - seen.insert(ptr).then(|| p.reserved()) - }) - .sum::(); - } - 0 -} - -fn total_reserved_for_thread(thread_id: u64) -> usize { - let map = get_thread_memory_pools().lock(); - map.get(&thread_id) - .map(|pools| { - // Deduplicate pools that share the same underlying allocation - // (e.g. task-shared pools registered by multiple execution contexts) - let mut seen = HashSet::new(); - pools - .values() - .filter_map(|p| { - let ptr = Arc::as_ptr(p) as *const (); - seen.insert(ptr).then(|| p.reserved()) - }) - .sum::() - }) - .unwrap_or(0) -} fn parse_usize_env_var(name: &str) -> Option { std::env::var_os(name).and_then(|n| n.to_str().and_then(|s| s.parse::().ok())) }