Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vortex-array/src/aggregate_fn/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
// `Combined::try_accumulate` always returns true, so a later kernel check would be
// unreachable.
{
let kernels_r = kernels.read();
let kernels_r = kernels.load();
let batch_id = batch.encoding_id();
let kernel = kernels_r
.get(&(batch_id, Some(self.aggregate_fn.id())))
Expand Down Expand Up @@ -187,7 +187,7 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
break;
}

let kernels_r = kernels.read();
let kernels_r = kernels.load();
let batch_id = batch.encoding_id();
let kernel = kernels_r
.get(&(batch_id, Some(self.aggregate_fn.id())))
Expand Down
225 changes: 223 additions & 2 deletions vortex-array/src/aggregate_fn/accumulator_grouped.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use arrow_buffer::ArrowNativeType;
use num_traits::AsPrimitive;
use num_traits::ToPrimitive;
use vortex_buffer::Buffer;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;
use vortex_error::vortex_panic;
use vortex_mask::AllOr;
use vortex_mask::Mask;

use crate::AnyCanonical;
Expand All @@ -26,14 +29,18 @@ use crate::aggregate_fn::session::AggregateFnSessionExt;
use crate::arrays::ChunkedArray;
use crate::arrays::FixedSizeListArray;
use crate::arrays::ListViewArray;
use crate::arrays::Primitive;
use crate::arrays::PrimitiveArray;
use crate::arrays::fixed_size_list::FixedSizeListArrayExt;
use crate::arrays::listview::ListViewArrayExt;
use crate::builders::builder_with_capacity;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::dtype::IntegerPType;
use crate::dtype::NativePType;
use crate::executor::max_iterations;
use crate::match_each_integer_ptype;
use crate::match_each_native_ptype;

/// Reference-counted type-erased grouped accumulator.
pub type GroupedAccumulatorRef = Box<dyn DynGroupedAccumulator>;
Expand Down Expand Up @@ -170,7 +177,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
break;
}

let kernels_r = kernels.read();
let kernels_r = kernels.load();
if let Some(result) = kernels_r
.get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
.or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
Expand Down Expand Up @@ -224,6 +231,32 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
validity: &Mask,
ctx: &mut ExecutionCtx,
) -> VortexResult<()> {
// Fast path: summing a primitive column. The generic loop below creates a slice array and
// does a kernel-dispatching scalar `accumulate` plus scalar boxing per group, which is
// catastrophic for many small groups. A direct typed per-group slice sum is orders of
// magnitude faster. Null elements are handled inside the fast path via the element validity
// mask; it falls through for non-primitive elements, a partial dtype mismatch, or other
// aggregates.
if self.aggregate_fn.id().as_str() == "vortex.sum"
&& let Some(result) = try_sum_primitive_groups(
elements,
offsets,
sizes,
validity,
&self.partial_dtype,
ctx,
)?
{
return self.push_result(result);
}

// Fast path: `count` over any element type is just the number of valid elements per group
// (= group size when the elements are all-valid). Avoids the per-group scalar accumulator.
if self.aggregate_fn.id().as_str() == "vortex.count"
&& let Some(result) =
try_count_groups(elements, offsets, sizes, validity, &self.partial_dtype, ctx)?
{
return self.push_result(result);
}

let mut accumulator = Accumulator::try_new(
self.vtable.clone(),
self.options.clone(),
Expand Down Expand Up @@ -262,7 +295,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
break;
}

let kernels_r = kernels.read();
let kernels_r = kernels.load();
if let Some(result) = kernels_r
.get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
.or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
Expand Down Expand Up @@ -332,3 +365,191 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
Ok(())
}
}

/// Fast per-group `Sum` over a primitive elements array.
///
/// Group `i` covers the elements `offsets[i]..offsets[i] + sizes[i]`. The element validity is
/// materialized once into a mask; when every element is valid the per-type helpers take a tight
/// slice-sum loop, otherwise they skip elements whose mask bit is unset.
///
/// # Returns
///
/// * `Ok(Some(array))` with one sum per group when the fast path applies.
/// * `Ok(None)` when `elements` is not a `Primitive` array, or when the dtype of the computed
///   array differs from `partial_dtype`. The caller is expected to fall back to the generic
///   per-group path in that case.
///
/// # Errors
///
/// Propagates any error from reading the element validity or executing it into a mask.
fn try_sum_primitive_groups<O: IntegerPType>(
    elements: &ArrayRef,
    offsets: &[O],
    sizes: &[O],
    group_validity: &Mask,
    partial_dtype: &DType,
    ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
    let prim = match elements.as_opt::<Primitive>() {
        Some(prim) => prim,
        None => return Ok(None),
    };

    // Build the element mask a single time up front. `AllOr::All` selects the unmasked loop in
    // the helpers; anything else goes through the masked loop, where a group with no valid
    // elements sums to zero.
    let elem_mask = prim.validity()?.execute_mask(prim.len(), ctx)?;
    let all_valid = matches!(elem_mask.slices(), AllOr::All);

    // Dispatch on the physical type: unsigned widens to `u64`, signed to `i64`, floats to `f64`.
    let summed = match_each_native_ptype!(prim.ptype(),
        unsigned: |T| {
            sum_groups_unsigned::<T, O>(
                prim.as_slice::<T>(),
                offsets,
                sizes,
                group_validity,
                &elem_mask,
                all_valid,
            )
        },
        signed: |T| {
            sum_groups_signed::<T, O>(
                prim.as_slice::<T>(),
                offsets,
                sizes,
                group_validity,
                &elem_mask,
                all_valid,
            )
        },
        floating: |T| {
            sum_groups_float::<T, O>(
                prim.as_slice::<T>(),
                offsets,
                sizes,
                group_validity,
                &elem_mask,
                all_valid,
            )
        }
    );

    // Defensive: only hand the array back when its dtype is exactly the aggregate's partial
    // dtype; otherwise let the generic path run rather than emit a mistyped array downstream.
    let dtype_matches = summed.dtype() == partial_dtype;
    Ok(dtype_matches.then_some(summed))
}

/// Fast per-group `count`: the number of non-null elements in each group.
///
/// Group `i` covers the elements `offsets[i]..offsets[i] + sizes[i]`. When every element is
/// valid the count is simply the group size; otherwise the set bits of the element validity
/// mask inside the group's range are counted.
///
/// # Returns
///
/// * `Ok(Some(array))` holding one `u64` count per group.
/// * `Ok(None)` when `group_validity` is anything other than all-valid, or when the dtype of the
///   computed array differs from `partial_dtype`. The caller is expected to fall back to the
///   generic per-group path in that case.
///
/// # Errors
///
/// Propagates any error from reading the element validity or executing it into a mask.
///
/// # Panics
///
/// Panics if an offset or size does not convert to `usize`.
fn try_count_groups<O: IntegerPType>(
    elements: &ArrayRef,
    offsets: &[O],
    sizes: &[O],
    group_validity: &Mask,
    partial_dtype: &DType,
    ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
    // The count partial dtype is non-nullable, so a null group cannot be represented here.
    let every_group_valid = matches!(group_validity.slices(), AllOr::All);
    if !every_group_valid {
        return Ok(None);
    }

    let elem_mask = elements.validity()?.execute_mask(elements.len(), ctx)?;
    let all_valid = matches!(elem_mask.slices(), AllOr::All);

    let per_group = offsets.iter().zip(sizes).map(|(offset, size)| {
        let start = offset.to_usize().vortex_expect("offset usize");
        let len = size.to_usize().vortex_expect("size usize");
        if all_valid {
            // No nulls anywhere: every element of the group counts.
            return len as u64;
        }
        (start..start + len)
            .filter(|&idx| elem_mask.value(idx))
            .count() as u64
    });
    let counted = PrimitiveArray::from_iter(per_group).into_array();

    // Only return the array when it is typed exactly as the aggregate's partial dtype.
    let dtype_matches = counted.dtype() == partial_dtype;
    Ok(dtype_matches.then_some(counted))
}

/// Sums each group of an unsigned primitive slice into a widened `u64`.
///
/// Group `i` covers `values[offsets[i]..offsets[i] + sizes[i]]`. The output holds one entry per
/// group and is null for a group when:
///
/// * `group_validity` marks the group as invalid, or
/// * the running `u64` total overflows (checked addition).
///
/// When `all_valid` is `false`, only elements whose `elem_mask` bit is set contribute; when it
/// is `true`, `elem_mask` is not consulted.
///
/// # Panics
///
/// Panics if an offset or size does not convert to `usize`, or if a group's range falls outside
/// `values`.
fn sum_groups_unsigned<T, O>(
    values: &[T],
    offsets: &[O],
    sizes: &[O],
    group_validity: &Mask,
    elem_mask: &Mask,
    all_valid: bool,
) -> ArrayRef
where
    T: NativePType + AsPrimitive<u64>,
    O: IntegerPType,
{
    let per_group = offsets
        .iter()
        .zip(sizes)
        .enumerate()
        .map(|(group_idx, (offset, size))| {
            if !group_validity.value(group_idx) {
                return None;
            }
            let start = offset.to_usize().vortex_expect("offset usize");
            let len = size.to_usize().vortex_expect("size usize");

            // `try_fold` stops at the first overflow and yields `None`, i.e. a null sum.
            if all_valid {
                values[start..start + len]
                    .iter()
                    .try_fold(0u64, |total, &v| total.checked_add(v.as_()))
            } else {
                (0..len)
                    .map(|rel| start + rel)
                    .filter(|&idx| elem_mask.value(idx))
                    .try_fold(0u64, |total, idx| total.checked_add(values[idx].as_()))
            }
        });
    PrimitiveArray::from_option_iter(per_group).into_array()
}

/// Sums each group of a signed primitive slice into a widened `i64`.
///
/// Group `i` covers `values[offsets[i]..offsets[i] + sizes[i]]`. The output holds one entry per
/// group and is null for a group when:
///
/// * `group_validity` marks the group as invalid, or
/// * the running `i64` total overflows in either direction (checked addition).
///
/// When `all_valid` is `false`, only elements whose `elem_mask` bit is set contribute; when it
/// is `true`, `elem_mask` is not consulted.
///
/// # Panics
///
/// Panics if an offset or size does not convert to `usize`, or if a group's range falls outside
/// `values`.
fn sum_groups_signed<T, O>(
    values: &[T],
    offsets: &[O],
    sizes: &[O],
    group_validity: &Mask,
    elem_mask: &Mask,
    all_valid: bool,
) -> ArrayRef
where
    T: NativePType + AsPrimitive<i64>,
    O: IntegerPType,
{
    let per_group = offsets
        .iter()
        .zip(sizes)
        .enumerate()
        .map(|(group_idx, (offset, size))| {
            if !group_validity.value(group_idx) {
                return None;
            }
            let start = offset.to_usize().vortex_expect("offset usize");
            let len = size.to_usize().vortex_expect("size usize");

            // `try_fold` stops at the first overflow and yields `None`, i.e. a null sum.
            if all_valid {
                values[start..start + len]
                    .iter()
                    .try_fold(0i64, |total, &v| total.checked_add(v.as_()))
            } else {
                (0..len)
                    .map(|rel| start + rel)
                    .filter(|&idx| elem_mask.value(idx))
                    .try_fold(0i64, |total, idx| total.checked_add(values[idx].as_()))
            }
        });
    PrimitiveArray::from_option_iter(per_group).into_array()
}

/// Sums each group of a floating-point primitive slice into a widened `f64`.
///
/// Group `i` covers `values[offsets[i]..offsets[i] + sizes[i]]`. The output holds one entry per
/// group and is null only when `group_validity` marks the group as invalid. Elements are added
/// left to right with plain `f64` addition, so a NaN element makes the group's sum NaN.
///
/// When `all_valid` is `false`, only elements whose `elem_mask` bit is set contribute; when it
/// is `true`, `elem_mask` is not consulted.
///
/// # Panics
///
/// Panics if an offset or size does not convert to `usize`, if an element does not convert to
/// `f64`, or if a group's range falls outside `values`.
fn sum_groups_float<T, O>(
    values: &[T],
    offsets: &[O],
    sizes: &[O],
    group_validity: &Mask,
    elem_mask: &Mask,
    all_valid: bool,
) -> ArrayRef
where
    T: NativePType + ToPrimitive,
    O: IntegerPType,
{
    let per_group = offsets
        .iter()
        .zip(sizes)
        .enumerate()
        .map(|(group_idx, (offset, size))| {
            if !group_validity.value(group_idx) {
                return None;
            }
            let start = offset.to_usize().vortex_expect("offset usize");
            let len = size.to_usize().vortex_expect("size usize");

            // Sequential fold keeps the same left-to-right addition order as a running total.
            let total = if all_valid {
                values[start..start + len].iter().fold(0.0f64, |total, &v| {
                    total + v.to_f64().vortex_expect("float to f64")
                })
            } else {
                (0..len)
                    .map(|rel| start + rel)
                    .filter(|&idx| elem_mask.value(idx))
                    .fold(0.0f64, |total, idx| {
                        total + values[idx].to_f64().vortex_expect("float to f64")
                    })
            };
            Some(total)
        });
    PrimitiveArray::from_option_iter(per_group).into_array()
}
22 changes: 14 additions & 8 deletions vortex-array/src/aggregate_fn/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use std::any::Any;
use std::sync::Arc;

use parking_lot::RwLock;
use arc_swap::ArcSwap;
use vortex_session::Ref;
use vortex_session::SessionExt;
use vortex_session::SessionVar;
Expand Down Expand Up @@ -51,8 +51,11 @@ pub type AggregateFnRegistry = Registry<AggregateFnPluginRef>;
pub struct AggregateFnSession {
registry: AggregateFnRegistry,

pub(super) kernels: RwLock<HashMap<KernelKey, &'static dyn DynAggregateKernel>>,
pub(super) grouped_kernels: RwLock<HashMap<KernelKey, &'static dyn DynGroupedAggregateKernel>>,
// `ArcSwap` rather than `RwLock`: kernels are registered once at session construction and read
// on every accumulate call. Under parallel reduce, a `RwLock` read-lock here was the dominant
// cost (`lock_shared_slow`); `ArcSwap` makes the per-accumulate lookup a lock-free atomic load.
pub(super) kernels: ArcSwap<HashMap<KernelKey, &'static dyn DynAggregateKernel>>,
pub(super) grouped_kernels: ArcSwap<HashMap<KernelKey, &'static dyn DynGroupedAggregateKernel>>,
}

impl SessionVar for AggregateFnSession {
Expand All @@ -71,8 +74,8 @@ impl Default for AggregateFnSession {
fn default() -> Self {
let this = Self {
registry: AggregateFnRegistry::default(),
kernels: RwLock::new(HashMap::default()),
grouped_kernels: RwLock::new(HashMap::default()),
kernels: ArcSwap::from_pointee(HashMap::default()),
grouped_kernels: ArcSwap::from_pointee(HashMap::default()),
};

// Register the built-in aggregate functions
Expand Down Expand Up @@ -125,9 +128,12 @@ impl AggregateFnSession {
agg_fn_id: Option<impl Into<AggregateFnId>>,
kernel: &'static dyn DynAggregateKernel,
) {
self.kernels
.write()
.insert((array_id.into(), agg_fn_id.map(|id| id.into())), kernel);
let key = (array_id.into(), agg_fn_id.map(|id| id.into()));
self.kernels.rcu(|current| {
let mut new = (**current).clone();
new.insert(key.clone(), kernel);
new
});
}
}

Expand Down
Loading