diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 0d3b084ba3..b32eec5686 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -101,6 +101,11 @@ hdfs = ["datafusion-comet-objectstore-hdfs"] hdfs-opendal = ["opendal", "object_store_opendal", "hdfs-sys"] jemalloc = ["tikv-jemallocator", "tikv-jemalloc-ctl"] +# Allocator-level OOM circuit breaker. When enabled, the global allocator is +# wrapped to track real allocated bytes and panic an over-budget query-worker +# thread (caught at the task boundary). Off by default; zero overhead when off. +oom-guard = [] + # exclude optional packages from cargo machete verifications [package.metadata.cargo-machete] ignored = ["hdfs-sys", "paste"] diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 0dcd78ba0f..94c676cd92 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -112,9 +112,13 @@ use crate::execution::spark_config::{ SparkConfig, COMET_DEBUG_ENABLED, COMET_DEBUG_MEMORY, COMET_EXPLAIN_NATIVE_ENABLED, COMET_MAX_TEMP_DIRECTORY_SIZE, COMET_TRACING_ENABLED, SPARK_EXECUTOR_CORES, }; +#[cfg(feature = "oom-guard")] +use crate::execution::spark_config::{COMET_MEMORY_GUARD_ENABLED, COMET_MEMORY_GUARD_SIZE}; use crate::parquet::encryption_support::{CometEncryptionFactory, ENCRYPTION_FACTORY_ID}; use datafusion_comet_proto::spark_operator::operator::OpStruct; use log::info; +#[cfg(feature = "oom-guard")] +use log::warn; use std::sync::OnceLock; #[cfg(feature = "jemalloc")] use tikv_jemalloc_ctl::{epoch, stats}; @@ -192,6 +196,8 @@ fn parse_usize_env_var(name: &str) -> Option { fn build_runtime(default_worker_threads: Option) -> Runtime { let mut builder = tokio::runtime::Builder::new_multi_thread(); + #[cfg(feature = "oom-guard")] + builder.on_thread_start(|| crate::execution::memory_pools::oom_guard::stamp_current_thread()); if let Some(n) = parse_usize_env_var("COMET_WORKER_THREADS") { info!("Comet tokio runtime: using COMET_WORKER_THREADS={n}"); builder.worker_threads(n); @@ -369,6 +375,24 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( spark_config.get_u64(COMET_MAX_TEMP_DIRECTORY_SIZE, 100 * 1024 * 1024 * 1024); let logging_memory_pool = spark_config.get_bool(COMET_DEBUG_MEMORY); + #[cfg(feature = "oom-guard")] + { + if spark_config.get_bool(COMET_MEMORY_GUARD_ENABLED) { + // Default to the executor off-heap memory limit (`memory_limit`); + // allow an explicit override. + let default_limit = memory_limit.max(0) as u64; + let limit = spark_config.get_u64(COMET_MEMORY_GUARD_SIZE, default_limit); + if limit == 0 { + warn!( + "spark.comet.exec.memoryGuard.enabled is true but the effective limit \ + is 0 (memory_limit={memory_limit}); the guard will not trip. Set \ + spark.comet.exec.memoryGuard.size explicitly." + ); + } + crate::execution::memory_pools::oom_guard::arm(limit as usize); + } + } + with_trace("createPlan", tracing_enabled, || { // Init JVM classes JVMClasses::init(env); @@ -715,6 +739,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( schema_addrs: JLongArray, ) -> jlong { try_unwrap_or_throw(&e, |env| { + #[cfg(feature = "oom-guard")] + crate::execution::memory_pools::oom_guard::stamp_current_thread(); // Retrieve the query let exec_context = get_execution_context(exec_context); @@ -786,6 +812,17 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( .await; if let Err(panic) = result { + #[cfg(feature = "oom-guard")] + if let Some(e) = + crate::execution::memory_pools::oom_guard::map_panic_to_error( + panic.as_ref(), + ) + { + // Runs on the tokio worker thread that panicked, so this clears + // that worker's UNWINDING flag (not the blocked JNI caller thread's). + let _ = tx.send(Err(e)).await; + return; + } let msg = match panic.downcast_ref::<&str>() { Some(s) => s.to_string(), None => match panic.downcast_ref::() { @@ -810,76 +847,120 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( pull_input_batches(exec_context)?; } - if let Some(rx) = &mut exec_context.batch_receiver { - match rx.blocking_recv() { - Some(Ok(batch)) => { - update_metrics(env, exec_context)?; - return prepare_output( - env, - array_addrs, - schema_addrs, - batch, - exec_context.debug_native, - ); - } - Some(Err(e)) => { - return Err(e.into()); - } - None => { - log_plan_metrics(exec_context, stage_id, partition); - return Ok(-1); + if exec_context.batch_receiver.is_some() { + let recv_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe( + || -> CometResult { + // Scope the rx borrow to just the blocking_recv call so that + // exec_context is free for update_metrics / prepare_output below. + let recv = exec_context + .batch_receiver + .as_mut() + .unwrap() + .blocking_recv(); + match recv { + Some(Ok(batch)) => { + update_metrics(env, exec_context)?; + prepare_output( + env, + array_addrs, + schema_addrs, + batch, + exec_context.debug_native, + ) + } + Some(Err(e)) => Err(e.into()), + None => { + log_plan_metrics(exec_context, stage_id, partition); + Ok(-1) + } + } + }, + )); + + match recv_result { + Ok(r) => return r, + Err(_panic) => { + #[cfg(feature = "oom-guard")] + if let Some(e) = + crate::execution::memory_pools::oom_guard::map_panic_to_error( + _panic.as_ref(), + ) + { + // Drop the receiver so any re-entry re-initializes. + exec_context.batch_receiver = None; + return Err(e.into()); + } + std::panic::resume_unwind(_panic); } } } // ScanExec path: busy-poll to interleave JVM batch pulls with stream polling - get_runtime().block_on(async { - loop { - let next_item = exec_context.stream.as_mut().unwrap().next(); - let poll_output = poll!(next_item); - - // Only check time/tracing every 100 polls to reduce overhead - exec_context.poll_count_since_metrics_check += 1; - if exec_context.poll_count_since_metrics_check >= 100 { - exec_context.poll_count_since_metrics_check = 0; - if let Some(interval) = exec_context.metrics_update_interval { - let now = Instant::now(); - if now - exec_context.metrics_last_update_time >= interval { - update_metrics(env, exec_context)?; - exec_context.metrics_last_update_time = now; + let poll_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + get_runtime().block_on(async { + loop { + let next_item = exec_context.stream.as_mut().unwrap().next(); + let poll_output = poll!(next_item); + + // Only check time/tracing every 100 polls to reduce overhead + exec_context.poll_count_since_metrics_check += 1; + if exec_context.poll_count_since_metrics_check >= 100 { + exec_context.poll_count_since_metrics_check = 0; + if let Some(interval) = exec_context.metrics_update_interval { + let now = Instant::now(); + if now - exec_context.metrics_last_update_time >= interval { + update_metrics(env, exec_context)?; + exec_context.metrics_last_update_time = now; + } + } + if exec_context.tracing_enabled { + log_memory_usage( + &exec_context.tracing_memory_metric_name, + total_reserved_for_thread(exec_context.rust_thread_id) as u64, + ); } } - if exec_context.tracing_enabled { - log_memory_usage( - &exec_context.tracing_memory_metric_name, - total_reserved_for_thread(exec_context.rust_thread_id) as u64, - ); - } - } - match poll_output { - Poll::Ready(Some(output)) => { - return prepare_output( - env, - array_addrs, - schema_addrs, - output?, - exec_context.debug_native, - ); - } - Poll::Ready(None) => { - log_plan_metrics(exec_context, stage_id, partition); - return Ok(-1); - } - Poll::Pending => { - // JNI call to pull batches from JVM into ScanExec operators. - // block_in_place lets tokio move other tasks off this worker - // while we wait for JVM data. - tokio::task::block_in_place(|| pull_input_batches(exec_context))?; + match poll_output { + Poll::Ready(Some(output)) => { + return prepare_output( + env, + array_addrs, + schema_addrs, + output?, + exec_context.debug_native, + ); + } + Poll::Ready(None) => { + log_plan_metrics(exec_context, stage_id, partition); + return Ok(-1); + } + Poll::Pending => { + // JNI call to pull batches from JVM into ScanExec operators. + // block_in_place lets tokio move other tasks off this worker + // while we wait for JVM data. + tokio::task::block_in_place(|| pull_input_batches(exec_context))?; + } } } + }) + })); + + match poll_result { + Ok(r) => r, + Err(_panic) => { + #[cfg(feature = "oom-guard")] + if let Some(e) = crate::execution::memory_pools::oom_guard::map_panic_to_error( + _panic.as_ref(), + ) { + // The block_on future was dropped mid-poll; null the stream so any + // inadvertent re-entry re-initializes rather than polling a half-consumed one. + exec_context.stream = None; + return Err(e.into()); + } + std::panic::resume_unwind(_panic); } - }) + } }); if exec_context.tracing_enabled { diff --git a/native/core/src/execution/memory_pools/mod.rs b/native/core/src/execution/memory_pools/mod.rs index 389e348990..47709baad0 100644 --- a/native/core/src/execution/memory_pools/mod.rs +++ b/native/core/src/execution/memory_pools/mod.rs @@ -18,6 +18,8 @@ mod config; mod fair_pool; pub mod logging_pool; +#[cfg(feature = "oom-guard")] +pub mod oom_guard; mod task_shared; mod unified_pool; diff --git a/native/core/src/execution/memory_pools/oom_guard.rs b/native/core/src/execution/memory_pools/oom_guard.rs new file mode 100644 index 0000000000..8a7a4d21e3 --- /dev/null +++ b/native/core/src/execution/memory_pools/oom_guard.rs @@ -0,0 +1,377 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::DataFusionError; +use std::alloc::{GlobalAlloc, Layout}; +use std::cell::Cell; +use std::sync::atomic::{AtomicBool, AtomicIsize, AtomicUsize, Ordering}; + +/// Per-thread drift is flushed into the shared balance once it crosses this. +const SETTLE_THRESHOLD: isize = 64 * 1024; + +/// Process-wide outstanding bytes (signed so transient under-settle is fine). +static BALANCE: AtomicIsize = AtomicIsize::new(0); +/// Enforcement limit in bytes; 0 means unset. +static LIMIT: AtomicUsize = AtomicUsize::new(0); +/// Master enforcement gate (single relaxed load on the hot path). +static ARMED: AtomicBool = AtomicBool::new(false); + +thread_local! { + /// Un-flushed per-thread delta. + static LOCAL_DRIFT: Cell = const { Cell::new(0) }; + /// Is this a query-worker thread eligible for enforcement? + static STAMPED: Cell = const { Cell::new(false) }; + /// Set while a guard panic is unwinding this thread, to avoid double-faults. + static UNWINDING: Cell = const { Cell::new(false) }; +} + +/// Payload of the panic raised when an armed, stamped thread exceeds the limit. +#[derive(Debug)] +pub struct OomGuardPanic { + pub balance: usize, + pub limit: usize, +} + +/// Arm the guard with a byte limit. Idempotent. +pub fn arm(limit_bytes: usize) { + LIMIT.store(limit_bytes, Ordering::Relaxed); + ARMED.store(true, Ordering::Relaxed); +} + +/// Disarm the guard (enforcement off; tracking continues cheaply). +#[allow(dead_code)] // used only by tests +pub fn disarm() { + ARMED.store(false, Ordering::Relaxed); +} + +/// Mark the current thread as a query-worker thread eligible for enforcement. +pub fn stamp_current_thread() { + STAMPED.with(|s| s.set(true)); +} + +/// Reset the per-thread unwinding guard after a guard panic has been caught on +/// this thread. Safe to call when not unwinding. The JNI caller thread is +/// reused across tasks, so this must run after catching an OomGuardPanic. +pub fn clear_unwinding() { + UNWINDING.with(|u| u.set(false)); +} + +/// If `panic` is an `OomGuardPanic`, clear this thread's unwinding guard and +/// return the mapped retriable error. Returns `None` for any other panic. +/// Centralizes the downcast + unwinding-reset + error mapping for all catch sites. +pub fn map_panic_to_error(panic: &(dyn std::any::Any + Send)) -> Option { + let g = panic.downcast_ref::()?; + clear_unwinding(); + Some(DataFusionError::ResourcesExhausted(format!( + "Comet OomGuard: native allocation pushed usage to {} bytes, over the limit of {} \ + bytes; failing this task", + g.balance, g.limit + ))) +} + +/// Current process-wide balance in bytes (never reported negative). +#[allow(dead_code)] // used only by tests +pub fn current_balance() -> usize { + BALANCE.load(Ordering::Relaxed).max(0) as usize +} + +/// Record an allocation of `size` bytes; may trip the breaker. +#[inline] +fn record_alloc(size: usize) { + track(size as isize); +} + +/// Record a deallocation of `size` bytes; never trips (credit only). +#[inline] +fn record_dealloc(size: usize) { + track(-(size as isize)); +} + +/// Core tracking + enforcement. Flushes drift; on a debit flush that crosses the +/// limit on an armed, stamped, non-unwinding thread, panics with `OomGuardPanic`. +#[inline] +fn track(delta: isize) { + let new_balance = LOCAL_DRIFT.with(|d| { + let mut drift = d.get(); + let flushed = settle(&mut drift, delta, &BALANCE); + d.set(drift); + flushed + }); + + if delta <= 0 { + return; // credits never enforce + } + let Some(balance) = new_balance else { return }; + if !ARMED.load(Ordering::Relaxed) { + return; + } + if !STAMPED.with(|s| s.get()) { + return; + } + if UNWINDING.with(|u| u.get()) { + return; + } + let limit = LIMIT.load(Ordering::Relaxed); + if should_trip(balance, limit) { + // At most one thread may fire the guard panic per arm cycle. CAS the + // master gate true->false; threads that lose the race bail before + // panic_any. The relaxed load above (line ~121) is not a serialization + // point: several threads can all read ARMED=true and reach here in the + // same tight window. If each then dispatches a panic, Rust's unwind ABI + // can abort the process with "failed to initiate panic" instead of + // unwinding cleanly (observed on the 5-concurrent repro: ~4 threads + // firing within ~10 ms -> exit 133). The guard re-arms on the next + // createPlan. + if ARMED + .compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed) + .is_err() + { + return; + } + // panic_any boxes the payload, which re-enters this allocator and calls + // track() again. ARMED is now false so the re-entrant call short-circuits + // at the ARMED check above; setting UNWINDING adds defense in depth in + // case a concurrent createPlan re-arms mid-unwind. + UNWINDING.with(|u| u.set(true)); + std::panic::panic_any(OomGuardPanic { + balance: balance.max(0) as usize, + limit, + }); + } +} + +/// Pure helper: given the current shared balance and a limit, decide whether an +/// armed+stamped thread should trip the breaker. `limit == 0` means "unset". +fn should_trip(balance: isize, limit: usize) -> bool { + limit != 0 && balance > limit.try_into().unwrap_or(isize::MAX) +} + +/// Pure helper: add `delta` to `local_drift`; if it reaches or exceeds `SETTLE_THRESHOLD` +/// in magnitude, flush it into `shared` and return the new shared balance. +/// Otherwise return `None` (nothing flushed). +fn settle(local_drift: &mut isize, delta: isize, shared: &AtomicIsize) -> Option { + *local_drift = local_drift.wrapping_add(delta); + if local_drift.unsigned_abs() >= SETTLE_THRESHOLD as usize { + let flushed = *local_drift; + *local_drift = 0; + let prev = shared.fetch_add(flushed, Ordering::Relaxed); + Some(prev.wrapping_add(flushed)) + } else { + None + } +} + +/// Wraps an inner global allocator, tracking layout bytes for the OomGuard. +pub struct AccountingAllocator { + inner: A, +} + +impl AccountingAllocator { + pub const fn new(inner: A) -> Self { + Self { inner } + } +} + +unsafe impl GlobalAlloc for AccountingAllocator { + unsafe fn alloc(&self, layout: Layout) -> *mut u8 { + let ptr = self.inner.alloc(layout); + if !ptr.is_null() { + record_alloc(layout.size()); + } + ptr + } + + unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { + self.inner.dealloc(ptr, layout); + record_dealloc(layout.size()); + } + + unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 { + let ptr = self.inner.alloc_zeroed(layout); + if !ptr.is_null() { + record_alloc(layout.size()); + } + ptr + } + + unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 { + // Account for and enforce the size delta BEFORE delegating to the inner + // realloc. If this trips the breaker it panics here, while `ptr` is still + // valid, so the unwind frees it correctly. Panicking *after* inner.realloc + // would be unsound: realloc may have already freed/moved the old block, and + // the caller (which never received the new pointer) would free the dangling + // old pointer on unwind and segfault. Only growth can trip; over-counting on + // a (rare) realloc failure errs on the conservative side for an OOM guard. + // + // Casts and subtraction are safe in practice: a single allocation cannot + // exceed isize::MAX on any real platform, so no wrapping or overflow occurs. + let old = layout.size() as isize; + let new = new_size as isize; + track(new - old); + self.inner.realloc(ptr, layout, new_size) + } +} + +#[cfg(test)] +fn reset_for_test() { + BALANCE.store(0, Ordering::Relaxed); + LIMIT.store(0, Ordering::Relaxed); + ARMED.store(false, Ordering::Relaxed); + LOCAL_DRIFT.with(|d| d.set(0)); + STAMPED.with(|s| s.set(false)); + UNWINDING.with(|u| u.set(false)); +} + +#[cfg(test)] +fn clear_unwinding_for_test() { + UNWINDING.with(|u| u.set(false)); +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Mutex; + + // Serializes tests that mutate the process-global guard state. + static GUARD: Mutex<()> = Mutex::new(()); + + #[test] + fn test_should_trip() { + assert!(!should_trip(100, 0)); // unset limit never trips + assert!(!should_trip(100, 200)); // under limit + assert!(!should_trip(200, 200)); // at limit (strictly greater required) + assert!(should_trip(201, 200)); // over limit + } + + #[test] + fn test_settle_accumulates_then_flushes() { + let shared = AtomicIsize::new(0); + let mut drift = 0isize; + // small allocs below threshold do not flush + assert_eq!(settle(&mut drift, 1024, &shared), None); + assert_eq!(shared.load(Ordering::Relaxed), 0); + // crossing the threshold flushes the accumulated drift + let new_balance = settle(&mut drift, SETTLE_THRESHOLD, &shared); + assert_eq!(new_balance, Some(1024 + SETTLE_THRESHOLD)); + assert_eq!(shared.load(Ordering::Relaxed), 1024 + SETTLE_THRESHOLD); + assert_eq!(drift, 0); // drift reset after flush + } + + #[test] + fn test_settle_flushes_negative_drift() { + let shared = AtomicIsize::new(1_000_000); + let mut drift = 0isize; + assert_eq!( + settle(&mut drift, -SETTLE_THRESHOLD, &shared), + Some(1_000_000 - SETTLE_THRESHOLD) + ); + assert_eq!(drift, 0); + } + + #[test] + fn test_settle_flushes_at_exact_threshold() { + let shared = AtomicIsize::new(0); + let mut drift = 0isize; + assert_eq!( + settle(&mut drift, SETTLE_THRESHOLD, &shared), + Some(SETTLE_THRESHOLD) + ); + assert_eq!(drift, 0); + } + + #[test] + fn test_disarmed_never_trips() { + let _g = GUARD.lock().unwrap_or_else(|e| e.into_inner()); + reset_for_test(); + stamp_current_thread(); + // not armed -> record_alloc must never panic regardless of size + record_alloc(usize::MAX / 2); + record_alloc(usize::MAX / 2); + } + + #[test] + fn test_unstamped_thread_never_trips() { + let _g = GUARD.lock().unwrap_or_else(|e| e.into_inner()); + reset_for_test(); + // arm with a tiny limit relative to current balance, but DO NOT stamp + let limit = current_balance() + 1; + arm(limit); + record_alloc(SETTLE_THRESHOLD as usize * 4); // big enough to flush + disarm(); + } + + #[test] + fn test_stamped_over_budget_trips() { + let _g = GUARD.lock().unwrap_or_else(|e| e.into_inner()); + reset_for_test(); + stamp_current_thread(); + let limit = current_balance() + SETTLE_THRESHOLD as usize; // headroom + arm(limit); + let result = std::panic::catch_unwind(|| { + // exceed the headroom in one flush + record_alloc(SETTLE_THRESHOLD as usize * 4); + }); + disarm(); + clear_unwinding_for_test(); + assert!(result.is_err(), "expected OomGuardPanic"); + let panic = result.unwrap_err(); + assert!( + panic.downcast_ref::().is_some(), + "panic payload should be OomGuardPanic" + ); + } + + // Drives a real heap allocation through the installed AccountingAllocator (only + // wrapped under the `oom-guard` feature) and confirms the guard trips. + #[test] + #[cfg(feature = "oom-guard")] + fn test_real_allocation_trips_guard() { + let _g = GUARD.lock().unwrap_or_else(|e| e.into_inner()); + reset_for_test(); + stamp_current_thread(); + // 8 MiB headroom over the current (noisy) baseline. + let headroom = 8 * 1024 * 1024; + arm(current_balance() + headroom); + + let result = std::panic::catch_unwind(|| { + // Allocate well past the headroom in 1 MiB chunks so a flush crosses the limit. + let mut held: Vec> = Vec::new(); + for _ in 0..64 { + held.push(vec![0u8; 1024 * 1024]); + } + // Touch the data so the allocation cannot be optimized away. + held.iter().map(|v| v.len()).sum::() + }); + + // Disarm BEFORE clearing UNWINDING so no post-catch allocation on this still-armed, + // still-stamped thread can re-trip outside the catch. + disarm(); + clear_unwinding_for_test(); + + assert!( + result.is_err(), + "large allocation on a stamped, armed thread should trip the guard" + ); + assert!( + result + .unwrap_err() + .downcast_ref::() + .is_some(), + "panic payload should be OomGuardPanic" + ); + } +} diff --git a/native/core/src/execution/mod.rs b/native/core/src/execution/mod.rs index ec247f72b7..185b191f49 100644 --- a/native/core/src/execution/mod.rs +++ b/native/core/src/execution/mod.rs @@ -28,7 +28,7 @@ pub use datafusion_comet_shuffle as shuffle; pub(crate) mod sort; pub(crate) mod spark_plan; pub use datafusion_comet_spark_expr::timezone; -mod memory_pools; +pub(crate) mod memory_pools; pub(crate) mod spark_config; pub(crate) mod tracing; pub(crate) mod utils; diff --git a/native/core/src/execution/spark_config.rs b/native/core/src/execution/spark_config.rs index 277c0eb43b..13e60b00a7 100644 --- a/native/core/src/execution/spark_config.rs +++ b/native/core/src/execution/spark_config.rs @@ -23,6 +23,10 @@ pub(crate) const COMET_EXPLAIN_NATIVE_ENABLED: &str = "spark.comet.explain.nativ pub(crate) const COMET_MAX_TEMP_DIRECTORY_SIZE: &str = "spark.comet.maxTempDirectorySize"; pub(crate) const COMET_DEBUG_MEMORY: &str = "spark.comet.debug.memory"; pub(crate) const SPARK_EXECUTOR_CORES: &str = "spark.executor.cores"; +#[cfg(feature = "oom-guard")] +pub(crate) const COMET_MEMORY_GUARD_ENABLED: &str = "spark.comet.exec.memoryGuard.enabled"; +#[cfg(feature = "oom-guard")] +pub(crate) const COMET_MEMORY_GUARD_SIZE: &str = "spark.comet.exec.memoryGuard.size"; pub(crate) trait SparkConfig { fn get_bool(&self, name: &str) -> bool; diff --git a/native/core/src/lib.rs b/native/core/src/lib.rs index 19a2d774a0..5b87f26ba3 100644 --- a/native/core/src/lib.rs +++ b/native/core/src/lib.rs @@ -75,18 +75,49 @@ pub mod debug; #[cfg(all( not(target_env = "msvc"), feature = "jemalloc", - not(feature = "mimalloc") + not(feature = "mimalloc"), + not(feature = "oom-guard") ))] #[global_allocator] static GLOBAL: Jemalloc = Jemalloc; #[cfg(all( feature = "mimalloc", - not(all(not(target_env = "msvc"), feature = "jemalloc")) + not(all(not(target_env = "msvc"), feature = "jemalloc")), + not(feature = "oom-guard") ))] #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; +#[cfg(all( + not(target_env = "msvc"), + feature = "jemalloc", + not(feature = "mimalloc"), + feature = "oom-guard" +))] +#[global_allocator] +static GLOBAL: crate::execution::memory_pools::oom_guard::AccountingAllocator = + crate::execution::memory_pools::oom_guard::AccountingAllocator::new(Jemalloc); + +#[cfg(all( + feature = "mimalloc", + not(all(not(target_env = "msvc"), feature = "jemalloc")), + feature = "oom-guard" +))] +#[global_allocator] +static GLOBAL: crate::execution::memory_pools::oom_guard::AccountingAllocator = + crate::execution::memory_pools::oom_guard::AccountingAllocator::new(MiMalloc); + +// oom-guard enabled with system allocator (no mimalloc, and no jemalloc or on MSVC). +#[cfg(all( + feature = "oom-guard", + not(feature = "mimalloc"), + any(target_env = "msvc", not(feature = "jemalloc")) +))] +#[global_allocator] +static GLOBAL: crate::execution::memory_pools::oom_guard::AccountingAllocator = + crate::execution::memory_pools::oom_guard::AccountingAllocator::new(std::alloc::System); + #[no_mangle] pub extern "system" fn Java_org_apache_comet_NativeBase_init( e: EnvUnowned, diff --git a/spark/src/main/scala/org/apache/comet/CometConf.scala b/spark/src/main/scala/org/apache/comet/CometConf.scala index 82700a939e..08338866f8 100644 --- a/spark/src/main/scala/org/apache/comet/CometConf.scala +++ b/spark/src/main/scala/org/apache/comet/CometConf.scala @@ -828,6 +828,26 @@ object CometConf extends ShimCometConf { .bytesConf(ByteUnit.BYTE) .createWithDefault(100L * 1024 * 1024 * 1024) // 100 GB + val COMET_EXEC_MEMORY_GUARD_ENABLED: ConfigEntry[Boolean] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.memoryGuard.enabled") + .category(CATEGORY_EXEC) + .doc( + "Experimental. When enabled, Comet tracks real native memory allocations and aborts " + + "an over-budget task with a retriable error instead of risking an executor-wide OOM " + + "kill. Requires a Comet build with the 'oom-guard' native feature; has no effect " + + "on builds without it.") + .booleanConf + .createWithDefault(false) + + val COMET_EXEC_MEMORY_GUARD_SIZE: OptionalConfigEntry[Long] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.memoryGuard.size") + .category(CATEGORY_EXEC) + .doc( + "Experimental. Memory budget for the Comet native OOM guard (accepts sizes like '4g'). " + + "Defaults to the executor off-heap memory size (spark.memory.offHeap.size) when unset.") + .bytesConf(ByteUnit.BYTE) + .createOptional + val COMET_RESPECT_DATAFUSION_CONFIGS: ConfigEntry[Boolean] = conf(s"$COMET_EXEC_CONFIG_PREFIX.respectDataFusionConfigs") .category(CATEGORY_TESTING)