From 292d9076be200ef4bc240eabed52fb625d9d0e63 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Thu, 9 Apr 2026 16:07:41 +0100 Subject: [PATCH 1/3] TakeExecute for FilterArray Signed-off-by: Robert Kruszewski --- vortex-array/public-api.lock | 12 + vortex-array/src/arrays/filter/mod.rs | 1 + vortex-array/src/arrays/filter/take.rs | 131 +++++ vortex-array/src/arrays/filter/vtable.rs | 10 + vortex-buffer/benches/vortex_bitbuffer.rs | 22 + vortex-buffer/public-api.lock | 4 + vortex-buffer/src/bit/buf.rs | 33 ++ vortex-buffer/src/bit/count_ones.rs | 6 +- vortex-buffer/src/bit/mod.rs | 1 + vortex-buffer/src/bit/select.rs | 686 ++++++++++++++++++++++ vortex-mask/Cargo.toml | 4 + vortex-mask/benches/rank.rs | 94 +++ vortex-mask/public-api.lock | 2 + vortex-mask/src/intersect_by_rank.rs | 18 +- vortex-mask/src/lib.rs | 35 +- 15 files changed, 1043 insertions(+), 16 deletions(-) create mode 100644 vortex-array/src/arrays/filter/take.rs create mode 100644 vortex-buffer/src/bit/select.rs create mode 100644 vortex-mask/benches/rank.rs diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 25736c6df84..8dae2ab23c8 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -2228,6 +2228,10 @@ impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Extension pub fn vortex_array::arrays::Extension::take(array: vortex_array::ArrayView<'_, vortex_array::arrays::Extension>, indices: &vortex_array::ArrayRef, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Filter + +pub fn vortex_array::arrays::Filter::take(array: vortex_array::ArrayView<'_, vortex_array::arrays::Filter>, indices: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::FixedSizeList pub fn vortex_array::arrays::FixedSizeList::take(array: 
vortex_array::ArrayView<'_, vortex_array::arrays::FixedSizeList>, indices: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -2470,6 +2474,10 @@ impl vortex_array::ValidityVTable for vortex_array pub fn vortex_array::arrays::Filter::validity(array: vortex_array::ArrayView<'_, vortex_array::arrays::Filter>) -> vortex_error::VortexResult +impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Filter + +pub fn vortex_array::arrays::Filter::take(array: vortex_array::ArrayView<'_, vortex_array::arrays::Filter>, indices: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub struct vortex_array::arrays::filter::FilterData impl vortex_array::arrays::filter::FilterData @@ -5706,6 +5714,10 @@ impl vortex_array::ValidityVTable for vortex_array pub fn vortex_array::arrays::Filter::validity(array: vortex_array::ArrayView<'_, vortex_array::arrays::Filter>) -> vortex_error::VortexResult +impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Filter + +pub fn vortex_array::arrays::Filter::take(array: vortex_array::ArrayView<'_, vortex_array::arrays::Filter>, indices: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub struct vortex_array::arrays::FixedSizeList impl core::clone::Clone for vortex_array::arrays::FixedSizeList diff --git a/vortex-array/src/arrays/filter/mod.rs b/vortex-array/src/arrays/filter/mod.rs index 39859bc5b25..15a987e8f20 100644 --- a/vortex-array/src/arrays/filter/mod.rs +++ b/vortex-array/src/arrays/filter/mod.rs @@ -16,6 +16,7 @@ pub use kernel::FilterReduce; pub use kernel::FilterReduceAdaptor; mod rules; +mod take; mod vtable; pub use vtable::Filter; diff --git a/vortex-array/src/arrays/filter/take.rs b/vortex-array/src/arrays/filter/take.rs new file mode 100644 index 00000000000..258788761d2 --- /dev/null +++ b/vortex-array/src/arrays/filter/take.rs @@ -0,0 +1,131 @@ +// 
SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+use num_traits::ToPrimitive;
+use vortex_buffer::BufferMut;
+use vortex_error::VortexResult;
+use vortex_error::vortex_bail;
+
+use super::Filter;
+use crate::ArrayRef;
+use crate::IntoArray;
+use crate::array::ArrayView;
+use crate::arrays::PrimitiveArray;
+use crate::arrays::dict::TakeExecute;
+use crate::arrays::dict::TakeExecuteAdaptor;
+use crate::arrays::filter::FilterArrayExt;
+use crate::builtins::ArrayBuiltins;
+use crate::dtype::DType;
+use crate::executor::ExecutionCtx;
+use crate::kernel::ParentKernelSet;
+use crate::match_each_integer_ptype;
+use crate::validity::Validity;
+
+pub(super) const PARENT_KERNELS: ParentKernelSet<Filter> =
+    ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(Filter))]);
+
+fn take_impl(array: ArrayView<'_, Filter>, indices: &PrimitiveArray) -> VortexResult<ArrayRef> {
+    let indices_validity = indices.validity_mask()?;
+    let mut translated = BufferMut::<u64>::with_capacity(indices.len());
+    translated.push_n(0u64, indices.len()); // null slots keep this placeholder 0
+
+    match_each_integer_ptype!(indices.ptype(), |P| {
+        // Collect valid (output_idx, rank) pairs — validates bounds up front.
+        let mut valid_indices = Vec::new();
+        let mut valid_ranks = Vec::new();
+        for (idx, rank) in indices.as_slice::<P>().iter().enumerate() {
+            if !indices_validity.value(idx) {
+                continue;
+            }
+            let Some(rank) = rank.to_usize() else {
+                vortex_bail!(OutOfBounds: usize::MAX, 0, array.len()); // index does not fit usize
+            };
+            if rank >= array.len() {
+                vortex_bail!(OutOfBounds: rank, 0, array.len());
+            }
+            valid_indices.push(idx);
+            valid_ranks.push(rank);
+        }
+
+        // Batch rank: single-pass over the bitmap instead of per-element scan.
+        let positions = array.filter_mask().rank_batch(&valid_ranks);
+        for (&idx, pos) in valid_indices.iter().zip(positions.iter()) {
+            translated[idx] = u64::try_from(*pos)?;
+        }
+
+        Ok::<(), vortex_error::VortexError>(())
+    })?;
+
+    let translated_indices = PrimitiveArray::new(
+        translated.freeze(),
+        Validity::from_mask(indices_validity, indices.dtype().nullability()),
+    )
+    .into_array();
+
+    array.child().take(translated_indices) // positions now index the unfiltered child
+}
+
+impl TakeExecute for Filter {
+    fn take(
+        array: ArrayView<'_, Filter>,
+        indices: &ArrayRef,
+        ctx: &mut ExecutionCtx,
+    ) -> VortexResult<Option<ArrayRef>> {
+        let DType::Primitive(ptype, nullability) = indices.dtype() else {
+            vortex_bail!("Invalid indices dtype: {}", indices.dtype())
+        };
+
+        let unsigned_indices = if ptype.is_unsigned_int() {
+            indices.clone().execute::<PrimitiveArray>(ctx)?
+        } else {
+            indices
+                .clone()
+                .cast(DType::Primitive(ptype.to_unsigned(), *nullability))?
+                .execute::<PrimitiveArray>(ctx)?
+        };
+
+        take_impl(array, &unsigned_indices).map(Some)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use vortex_buffer::buffer;
+    use vortex_error::VortexResult;
+    use vortex_mask::Mask;
+    use vortex_session::VortexSession;
+
+    use crate::IntoArray;
+    use crate::arrays::Dict;
+    use crate::arrays::DictArray;
+    use crate::arrays::FilterArray;
+    use crate::arrays::PrimitiveArray;
+    use crate::assert_arrays_eq;
+    use crate::executor::ExecutionCtx;
+
+    #[test]
+    fn test_take_execute_kernel_maps_indices_through_filter() -> VortexResult<()> {
+        let filter = FilterArray::new(
+            buffer![10i32, 20, 30, 40, 50].into_array(),
+            Mask::from_iter([true, false, true, true, false]),
+        )
+        .into_array();
+        let parent = DictArray::try_new(
+            PrimitiveArray::from_option_iter([Some(2u64), None, Some(0)]).into_array(),
+            filter.clone(),
+        )?
+        .into_array();
+        let mut ctx = ExecutionCtx::new(VortexSession::empty());
+
+        let result = filter
+            .execute_parent(&parent, 1, &mut ctx)?
+            .expect("filter child should execute its take parent");
+
+        assert!(result.as_opt::<Dict>().is_some());
+        assert_arrays_eq!(
+            result.to_canonical()?.into_array(),
+            PrimitiveArray::from_option_iter([Some(40i32), None, Some(10)]).into_array()
+        );
+        Ok(())
+    }
+}
diff --git a/vortex-array/src/arrays/filter/vtable.rs b/vortex-array/src/arrays/filter/vtable.rs
index 05f49b5c941..52e9bcb7438 100644
--- a/vortex-array/src/arrays/filter/vtable.rs
+++ b/vortex-array/src/arrays/filter/vtable.rs
@@ -34,6 +34,7 @@ use crate::arrays::filter::execute::execute_filter;
 use crate::arrays::filter::execute::execute_filter_fast_paths;
 use crate::arrays::filter::rules::PARENT_RULES;
 use crate::arrays::filter::rules::RULES;
+use crate::arrays::filter::take::PARENT_KERNELS;
 use crate::buffer::BufferHandle;
 use crate::dtype::DType;
 use crate::executor::ExecutionCtx;
@@ -170,6 +171,15 @@ impl VTable for Filter {
         PARENT_RULES.evaluate(array, parent, child_idx)
     }
 
+    fn execute_parent(
+        array: ArrayView<'_, Self>,
+        parent: &ArrayRef,
+        child_idx: usize,
+        ctx: &mut ExecutionCtx,
+    ) -> VortexResult<Option<ArrayRef>> {
+        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
+    }
+
     fn reduce(array: ArrayView<'_, Self>) -> VortexResult<Option<ArrayRef>> {
         RULES.evaluate(array)
     }
diff --git a/vortex-buffer/benches/vortex_bitbuffer.rs b/vortex-buffer/benches/vortex_bitbuffer.rs
index 67ce88da889..f2b04474592 100644
--- a/vortex-buffer/benches/vortex_bitbuffer.rs
+++ b/vortex-buffer/benches/vortex_bitbuffer.rs
@@ -288,3 +288,25 @@ fn set_indices_arrow_buffer(bencher: Bencher, length: usize) {
         }
     });
 }
+
+// ── select benchmarks ───────────────────────────────────────────────────
+
+#[divan::bench(args = INPUT_SIZE)]
+fn select_mid_vortex_buffer(bencher: Bencher, length: usize) {
+    let buffer = BitBuffer::from_iter((0..length).map(true_count_pattern));
+    let mid = buffer.true_count() / 2;
+    bencher
+        .with_inputs(|| (&buffer, mid))
+        .bench_refs(|(buffer, mid)| buffer.select(*mid));
+}
+
+#[divan::bench(args = INPUT_SIZE)]
+fn select_all_vortex_buffer(bencher: Bencher, length: usize) {
+    let buffer = BitBuffer::from_iter((0..length).map(true_count_pattern));
+    let tc = buffer.true_count();
+    bencher.with_inputs(|| &buffer).bench_refs(|buffer| {
+        for nth in 0..tc {
+            divan::black_box(buffer.select(nth));
+        }
+    });
+}
diff --git a/vortex-buffer/public-api.lock b/vortex-buffer/public-api.lock
index c89264da253..f72d82db4fc 100644
--- a/vortex-buffer/public-api.lock
+++ b/vortex-buffer/public-api.lock
@@ -280,6 +280,10 @@ pub fn vortex_buffer::BitBuffer::new_with_offset(buffer: vortex_buffer::ByteBuff
 
 pub fn vortex_buffer::BitBuffer::offset(&self) -> usize
 
+pub fn vortex_buffer::BitBuffer::select(&self, nth: usize) -> usize
+
+pub fn vortex_buffer::BitBuffer::select_sorted_batch(&self, sorted_ranks: &[usize]) -> alloc::vec::Vec<usize>
+
 pub fn vortex_buffer::BitBuffer::set_indices(&self) -> arrow_buffer::util::bit_iterator::BitIndexIterator<'_>
 
 pub fn vortex_buffer::BitBuffer::set_slices(&self) -> arrow_buffer::util::bit_iterator::BitSliceIterator<'_>
diff --git a/vortex-buffer/src/bit/buf.rs b/vortex-buffer/src/bit/buf.rs
index 61cc5e30044..1c341f1fbaa 100644
--- a/vortex-buffer/src/bit/buf.rs
+++ b/vortex-buffer/src/bit/buf.rs
@@ -25,6 +25,8 @@ use crate::bit::count_ones::count_ones;
 use crate::bit::get_bit_unchecked;
 use crate::bit::ops::bitwise_binary_op;
 use crate::bit::ops::bitwise_unary_op;
+use crate::bit::select::bit_select;
+use crate::bit::select::bit_select_sorted_batch;
 use crate::buffer;
 
 /// An immutable bitset stored as a packed byte buffer.
@@ -319,6 +321,37 @@ impl BitBuffer {
         count_ones(self.buffer.as_slice(), self.offset, self.len)
     }
 
+    /// Returns the position of the `nth` set bit (0-indexed).
+    ///
+    /// This is the "select" operation on a bitmap: given a rank `nth`, find
+    /// which logical bit position holds that rank.
+    ///
+    /// # Panics
+    ///
+    /// May panic, or (in release builds) return an unspecified position, if `nth` is
+    /// greater than or equal to [`true_count`](Self::true_count).
+    pub fn select(&self, nth: usize) -> usize {
+        bit_select(self.buffer.as_slice(), self.offset, self.len, nth)
+    }
+
+    /// Select positions for multiple ranks in a single pass over the bitmap.
+    ///
+    /// `sorted_ranks` must be sorted in non-decreasing order, with each value
+    /// less than [`true_count`](Self::true_count). This is O(L/64 + N) where
+    /// L = bitmap length and N = number of ranks, compared to O(N × L/64) for
+    /// individual [`select`](Self::select) calls.
+    pub fn select_sorted_batch(&self, sorted_ranks: &[usize]) -> Vec<usize> {
+        let mut out = vec![0; sorted_ranks.len()];
+        bit_select_sorted_batch(
+            self.buffer.as_slice(),
+            self.offset,
+            self.len,
+            sorted_ranks,
+            &mut out,
+        );
+        out
+    }
+
     /// Get the number of unset bits in the buffer.
     pub fn false_count(&self) -> usize {
         self.len - self.true_count()
diff --git a/vortex-buffer/src/bit/count_ones.rs b/vortex-buffer/src/bit/count_ones.rs
index 6d70d47cfa7..df5844a2914 100644
--- a/vortex-buffer/src/bit/count_ones.rs
+++ b/vortex-buffer/src/bit/count_ones.rs
@@ -22,7 +22,11 @@ pub fn count_ones(bytes: &[u8], offset: usize, len: usize) -> usize {
 }
 
 #[inline]
-fn align_offset_len(bytes: &[u8], offset: usize, len: usize) -> (Option<u8>, &[u8], Option<u8>) {
+pub(super) fn align_offset_len(
+    bytes: &[u8],
+    offset: usize,
+    len: usize,
+) -> (Option<u8>, &[u8], Option<u8>) {
     let start_byte = offset / 8;
     let start_bit = offset % 8;
     let end_bit = offset + len;
diff --git a/vortex-buffer/src/bit/mod.rs b/vortex-buffer/src/bit/mod.rs
index 034be84a18c..37930d788b7 100644
--- a/vortex-buffer/src/bit/mod.rs
+++ b/vortex-buffer/src/bit/mod.rs
@@ -13,6 +13,7 @@ mod buf_mut;
 mod count_ones;
 mod macros;
 mod ops;
+mod select;
 
 pub use arrow_buffer::bit_chunk_iterator::BitChunkIterator;
 pub use arrow_buffer::bit_chunk_iterator::BitChunks;
diff --git a/vortex-buffer/src/bit/select.rs b/vortex-buffer/src/bit/select.rs
new file mode 100644
index 00000000000..c070ae80e86
--- /dev/null
+++ b/vortex-buffer/src/bit/select.rs
@@ -0,0 +1,686 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+use super::count_ones::align_offset_len;
+
+/// Returns the position of the `nth` set bit (0-indexed) within the logical range
+/// `[offset, offset + len)` of the given byte slice.
+///
+/// The returned position is relative to the logical start (i.e., 0-indexed from `offset`).
+///
+/// Uses architecture-specific optimizations:
+/// - **aarch64**: NEON `vcnt`-based popcount for the word-level scan.
+/// - **x86_64 + BMI2**: `pdep` + `tzcnt` for the final in-word select.
+/// - **Scalar fallback**: 4× unrolled word scan with `count_ones`, byte-level narrowing.
+#[inline]
+pub fn bit_select(bytes: &[u8], offset: usize, len: usize, nth: usize) -> usize {
+    let (head, middle, tail) = align_offset_len(bytes, offset, len);
+    let mut remaining = nth; // rank still to be skipped before reaching the target bit
+    let mut pos = 0usize; // logical bit position of the chunk currently being examined
+
+    // ── partial first byte ──────────────────────────────────────────────
+    if let Some(head) = head {
+        let count = head.count_ones() as usize;
+        if remaining < count {
+            return select_in_byte(head, remaining); // `pos` is still 0 here
+        }
+        remaining -= count;
+        let start_bit = offset % 8;
+        pos = (8 - start_bit).min(len); // logical bits covered by the head byte
+    }
+
+    // ── aligned middle bytes ────────────────────────────────────────────
+    if !middle.is_empty() {
+        let (words, tail_bytes) = middle.as_chunks::<8>();
+
+        let (rem, new_pos, word_idx) = scan_words(words, remaining, pos);
+        remaining = rem;
+        pos = new_pos;
+
+        if word_idx < words.len() {
+            let word = u64::from_le_bytes(words[word_idx]);
+            return pos + select_in_word(word, remaining); // `remaining` is the rank within `word`
+        }
+
+        // Remaining aligned bytes that don't fill a full u64.
+        for &byte in tail_bytes {
+            let count = byte.count_ones() as usize;
+            if remaining < count {
+                return pos + select_in_byte(byte, remaining);
+            }
+            remaining -= count;
+            pos += 8;
+        }
+    }
+
+    // ── partial last byte ───────────────────────────────────────────────
+    if let Some(tail) = tail {
+        debug_assert!( // debug-only: release builds skip this bounds check
+            remaining < tail.count_ones() as usize,
+            "bit_select: nth={nth} out of bounds"
+        );
+        return pos + select_in_byte(tail, remaining);
+    }
+
+    unreachable!("bit_select: nth={nth} exceeds set bit count")
+}
+
+// ── Batch select (sorted ranks, single pass) ───────────────────────────
+
+/// Select positions for multiple ranks in a single pass over the bitmap.
+///
+/// `sorted_ranks` must be sorted in non-decreasing order; each value must be
+/// less than the total number of set bits in `[offset, offset+len)`.
+/// Results are written to `out[0..sorted_ranks.len()]`.
+pub fn bit_select_sorted_batch(
+    bytes: &[u8],
+    offset: usize,
+    len: usize,
+    sorted_ranks: &[usize],
+    out: &mut [usize],
+) {
+    debug_assert!(out.len() >= sorted_ranks.len()); // release builds fall back to the `out[ri]` bounds check
+    if sorted_ranks.is_empty() {
+        return;
+    }
+
+    let (head, middle, tail) = align_offset_len(bytes, offset, len);
+    let mut cumul = 0usize; // set bits counted in all chunks before the current one
+    let mut pos = 0usize; // logical bit position where the current chunk starts
+    let mut ri = 0usize; // index into sorted_ranks / out
+
+    // ── head byte ───────────────────────────────────────────────────
+    if let Some(head) = head {
+        let count = head.count_ones() as usize;
+        while ri < sorted_ranks.len() && sorted_ranks[ri] < cumul + count {
+            out[ri] = select_in_byte(head, sorted_ranks[ri] - cumul);
+            ri += 1;
+        }
+        cumul += count;
+        let start_bit = offset % 8;
+        pos = (8 - start_bit).min(len); // logical bits covered by the head byte
+        if ri >= sorted_ranks.len() {
+            return;
+        }
+    }
+
+    // ── middle bytes ────────────────────────────────────────────────
+    if !middle.is_empty() {
+        let (words, tail_bytes) = middle.as_chunks::<8>();
+
+        scan_words_batch_impl(words, sorted_ranks, out, &mut cumul, &mut pos, &mut ri);
+        if ri >= sorted_ranks.len() {
+            return;
+        }
+
+        for &byte in tail_bytes {
+            let count = byte.count_ones() as usize;
+            while ri < sorted_ranks.len() && sorted_ranks[ri] < cumul + count {
+                out[ri] = pos + select_in_byte(byte, sorted_ranks[ri] - cumul);
+                ri += 1;
+            }
+            cumul += count;
+            pos += 8;
+            if ri >= sorted_ranks.len() {
+                return;
+            }
+        }
+    }
+
+    // ── tail byte ───────────────────────────────────────────────────
+    if let Some(tail) = tail {
+        let count = tail.count_ones() as usize;
+        while ri < sorted_ranks.len() && sorted_ranks[ri] < cumul + count {
+            out[ri] = pos + select_in_byte(tail, sorted_ranks[ri] - cumul);
+            ri += 1;
+        }
+    }
+} // ranks that were never matched above leave their `out` slot unwritten
+
+// ── aarch64 NEON batch scan ─────────────────────────────────────────────
+
+#[cfg(target_arch = "aarch64")]
+#[allow(clippy::cast_possible_truncation)]
+fn scan_words_batch_impl(
+    words: &[[u8; 8]],
+    sorted_ranks: &[usize],
+    out: &mut [usize],
+    cumul: &mut usize,
+    pos: &mut usize,
+    ri: &mut usize,
+) {
+    use std::arch::aarch64::vcntq_u8;
+    use std::arch::aarch64::vgetq_lane_u64;
+    use std::arch::aarch64::vld1q_u8;
+    use std::arch::aarch64::vpaddlq_u8;
+    use std::arch::aarch64::vpaddlq_u16;
+    use std::arch::aarch64::vpaddlq_u32;
+
+    let mut wi = 0; // index of the next unprocessed word
+
+    // 4-word NEON blocks — skip entire blocks when no targets inside.
+    while wi + 4 <= words.len() && *ri < sorted_ranks.len() {
+        let ptr = words[wi].as_ptr();
+        // SAFETY: wi + 4 <= words.len() guarantees 32 contiguous bytes.
+        let (c0, c1, c2, c3) = unsafe {
+            let pop_lo = vcntq_u8(vld1q_u8(ptr));
+            let pop_hi = vcntq_u8(vld1q_u8(ptr.add(16)));
+            let sums_lo = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(pop_lo)));
+            let sums_hi = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(pop_hi)));
+            (
+                vgetq_lane_u64::<0>(sums_lo) as usize, // popcount of words[wi]
+                vgetq_lane_u64::<1>(sums_lo) as usize, // popcount of words[wi + 1]
+                vgetq_lane_u64::<0>(sums_hi) as usize, // popcount of words[wi + 2]
+                vgetq_lane_u64::<1>(sums_hi) as usize, // popcount of words[wi + 3]
+            )
+        };
+        let total = c0 + c1 + c2 + c3;
+
+        if sorted_ranks[*ri] >= *cumul + total {
+            *cumul += total;
+            *pos += 256; // 4 words × 64 bits
+            wi += 4;
+            continue;
+        }
+
+        // At least one target in this block — emit per word.
+        let counts = [c0, c1, c2, c3];
+        for (j, &count) in counts.iter().enumerate() {
+            while *ri < sorted_ranks.len() && sorted_ranks[*ri] < *cumul + count {
+                let word = u64::from_le_bytes(words[wi + j]);
+                out[*ri] = *pos + select_in_word(word, sorted_ranks[*ri] - *cumul);
+                *ri += 1;
+            }
+            *cumul += count;
+            *pos += 64;
+        }
+        wi += 4;
+    }
+
+    // Remaining words, scalar.
+    while wi < words.len() && *ri < sorted_ranks.len() {
+        let word = u64::from_le_bytes(words[wi]);
+        let count = word.count_ones() as usize;
+        while *ri < sorted_ranks.len() && sorted_ranks[*ri] < *cumul + count {
+            out[*ri] = *pos + select_in_word(word, sorted_ranks[*ri] - *cumul);
+            *ri += 1;
+        }
+        *cumul += count;
+        *pos += 64;
+        wi += 1;
+    }
+}
+
+// ── Scalar batch scan ───────────────────────────────────────────────────
+
+#[cfg(not(target_arch = "aarch64"))]
+fn scan_words_batch_impl(
+    words: &[[u8; 8]],
+    sorted_ranks: &[usize],
+    out: &mut [usize],
+    cumul: &mut usize,
+    pos: &mut usize,
+    ri: &mut usize,
+) {
+    let mut wi = 0; // index of the next unprocessed word
+
+    while wi + 4 <= words.len() && *ri < sorted_ranks.len() {
+        let c0 = u64::from_le_bytes(words[wi]).count_ones() as usize;
+        let c1 = u64::from_le_bytes(words[wi + 1]).count_ones() as usize;
+        let c2 = u64::from_le_bytes(words[wi + 2]).count_ones() as usize;
+        let c3 = u64::from_le_bytes(words[wi + 3]).count_ones() as usize;
+        let total = c0 + c1 + c2 + c3;
+
+        if sorted_ranks[*ri] >= *cumul + total {
+            *cumul += total;
+            *pos += 256; // 4 words × 64 bits: next rank is beyond this block
+            wi += 4;
+            continue;
+        }
+
+        let counts = [c0, c1, c2, c3];
+        for (j, &count) in counts.iter().enumerate() {
+            while *ri < sorted_ranks.len() && sorted_ranks[*ri] < *cumul + count {
+                let word = u64::from_le_bytes(words[wi + j]);
+                out[*ri] = *pos + select_in_word(word, sorted_ranks[*ri] - *cumul);
+                *ri += 1;
+            }
+            *cumul += count;
+            *pos += 64;
+        }
+        wi += 4;
+    }
+
+    while wi < words.len() && *ri < sorted_ranks.len() {
+        let word = u64::from_le_bytes(words[wi]);
+        let count = word.count_ones() as usize;
+        while *ri < sorted_ranks.len() && sorted_ranks[*ri] < *cumul + count {
+            out[*ri] = *pos + select_in_word(word, sorted_ranks[*ri] - *cumul);
+            *ri += 1;
+        }
+        *cumul += count;
+        *pos += 64;
+        wi += 1;
+    }
+}
+
+// ── Word-level scan ─────────────────────────────────────────────────────
+
+/// Scan `words` accumulating popcounts. Returns `(remaining, position, word_index)`.
+///
+/// If `word_index < words.len()`, the target bit is inside that word and `remaining`
+/// is the rank *within* that word. Otherwise all words were consumed.
+#[inline]
+fn scan_words(words: &[[u8; 8]], remaining: usize, pos: usize) -> (usize, usize, usize) {
+    scan_words_impl(words, remaining, pos) // resolved per target arch via `cfg`
+}
+
+// ── aarch64 NEON scan ───────────────────────────────────────────────────
+
+#[cfg(target_arch = "aarch64")]
+#[allow(clippy::cast_possible_truncation)] // u64 → usize is lossless on aarch64 (64-bit)
+#[inline]
+fn scan_words_impl(
+    words: &[[u8; 8]],
+    mut remaining: usize,
+    mut pos: usize,
+) -> (usize, usize, usize) {
+    use std::arch::aarch64::vcntq_u8;
+    use std::arch::aarch64::vgetq_lane_u64;
+    use std::arch::aarch64::vld1q_u8;
+    use std::arch::aarch64::vpaddlq_u8;
+    use std::arch::aarch64::vpaddlq_u16;
+    use std::arch::aarch64::vpaddlq_u32;
+
+    let mut idx = 0; // index of the next unprocessed word
+
+    // Process 4 u64 words at a time using two 128-bit NEON registers.
+    while idx + 4 <= words.len() {
+        let ptr = words[idx].as_ptr();
+        // SAFETY: idx + 4 <= words.len() guarantees 32 contiguous bytes from ptr.
+        // NEON vld1q_u8 supports unaligned access.
+        let (count_0, count_1, count_2, count_3) = unsafe {
+            let pop_lo = vcntq_u8(vld1q_u8(ptr));
+            let pop_hi = vcntq_u8(vld1q_u8(ptr.add(16)));
+            let sums_lo = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(pop_lo)));
+            let sums_hi = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(pop_hi)));
+            (
+                vgetq_lane_u64::<0>(sums_lo) as usize, // popcount of words[idx]
+                vgetq_lane_u64::<1>(sums_lo) as usize, // popcount of words[idx + 1]
+                vgetq_lane_u64::<0>(sums_hi) as usize, // popcount of words[idx + 2]
+                vgetq_lane_u64::<1>(sums_hi) as usize, // popcount of words[idx + 3]
+            )
+        };
+
+        let total = count_0 + count_1 + count_2 + count_3;
+        if remaining >= total {
+            remaining -= total;
+            pos += 256; // 4 words × 64 bits
+            idx += 4;
+            continue;
+        }
+
+        // Narrow down to the exact word.
+        if remaining < count_0 {
+            return (remaining, pos, idx);
+        }
+        remaining -= count_0;
+        pos += 64;
+        if remaining < count_1 {
+            return (remaining, pos, idx + 1);
+        }
+        remaining -= count_1;
+        pos += 64;
+        if remaining < count_2 {
+            return (remaining, pos, idx + 2);
+        }
+        remaining -= count_2;
+        pos += 64;
+        return (remaining, pos, idx + 3); // remaining < total, so it must be in the 4th word
+    }
+
+    // Process pairs.
+    while idx + 2 <= words.len() {
+        let ptr = words[idx].as_ptr();
+        // SAFETY: idx + 2 <= words.len() guarantees 16 contiguous bytes.
+        let (count_0, count_1) = unsafe {
+            let pop = vcntq_u8(vld1q_u8(ptr));
+            let sums = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(pop)));
+            (
+                vgetq_lane_u64::<0>(sums) as usize,
+                vgetq_lane_u64::<1>(sums) as usize,
+            )
+        };
+        let total = count_0 + count_1;
+        if remaining < total {
+            if remaining < count_0 {
+                return (remaining, pos, idx);
+            }
+            return (remaining - count_0, pos + 64, idx + 1);
+        }
+        remaining -= total;
+        pos += 128; // 2 words × 64 bits
+        idx += 2;
+    }
+
+    // Single trailing word.
+    if idx < words.len() {
+        let word = u64::from_le_bytes(words[idx]);
+        let count = word.count_ones() as usize;
+        if remaining < count {
+            return (remaining, pos, idx);
+        }
+        remaining -= count;
+        pos += 64;
+        idx += 1;
+    }
+
+    (remaining, pos, idx) // idx == words.len(): target not found in `words`
+}
+
+// ── Scalar scan (x86_64 / generic) ─────────────────────────────────────
+
+#[cfg(not(target_arch = "aarch64"))]
+#[inline]
+fn scan_words_impl(
+    words: &[[u8; 8]],
+    mut remaining: usize,
+    mut pos: usize,
+) -> (usize, usize, usize) {
+    let mut idx = 0; // index of the next unprocessed word
+
+    // 4× unrolled: the four independent `count_ones` calls pipeline well.
+    while idx + 4 <= words.len() {
+        let count_0 = u64::from_le_bytes(words[idx]).count_ones() as usize;
+        let count_1 = u64::from_le_bytes(words[idx + 1]).count_ones() as usize;
+        let count_2 = u64::from_le_bytes(words[idx + 2]).count_ones() as usize;
+        let count_3 = u64::from_le_bytes(words[idx + 3]).count_ones() as usize;
+        let total = count_0 + count_1 + count_2 + count_3;
+
+        if remaining >= total {
+            remaining -= total;
+            pos += 256; // 4 words × 64 bits
+            idx += 4;
+            continue;
+        }
+
+        if remaining < count_0 {
+            return (remaining, pos, idx);
+        }
+        remaining -= count_0;
+        pos += 64;
+        if remaining < count_1 {
+            return (remaining, pos, idx + 1);
+        }
+        remaining -= count_1;
+        pos += 64;
+        if remaining < count_2 {
+            return (remaining, pos, idx + 2);
+        }
+        remaining -= count_2;
+        pos += 64;
+        return (remaining, pos, idx + 3); // remaining < total, so it must be in the 4th word
+    }
+
+    while idx < words.len() {
+        let word = u64::from_le_bytes(words[idx]);
+        let count = word.count_ones() as usize;
+        if remaining < count {
+            return (remaining, pos, idx);
+        }
+        remaining -= count;
+        pos += 64;
+        idx += 1;
+    }
+
+    (remaining, pos, idx) // idx == words.len(): target not found in `words`
+}
+
+// ── In-word select ──────────────────────────────────────────────────────
+
+/// Position of the `nth` set bit inside a u64 (0-indexed, little-endian bit order).
+#[inline]
+fn select_in_word(word: u64, nth: usize) -> usize {
+    #[cfg(target_arch = "x86_64")]
+    {
+        if is_x86_feature_detected!("bmi2") { // checked at runtime on every call
+            // SAFETY: runtime detection guarantees the required target feature.
+            return unsafe { select_in_word_bmi2(word, nth) };
+        }
+    }
+    select_in_word_scalar(word, nth) // non-x86_64 targets and CPUs without BMI2
+}
+
+/// BMI2: deposit a single bit at the nth set-bit position, then count trailing zeros.
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "bmi2")]
+unsafe fn select_in_word_bmi2(word: u64, nth: usize) -> usize {
+    use std::arch::x86_64::_pdep_u64;
+    // `trailing_zeros` rather than `_tzcnt_u64`: that intrinsic is gated on BMI1, not BMI2.
+    _pdep_u64(1u64 << nth, word).trailing_zeros() as usize
+}
+
+/// Scalar: skip whole bytes by popcount, then select inside the byte that holds the bit.
+#[inline]
+fn select_in_word_scalar(word: u64, mut nth: usize) -> usize {
+    let bytes = word.to_le_bytes(); // little-endian: bytes[0] holds bits 0..8
+    let mut bit_offset = 0usize;
+    for &byte in &bytes {
+        let count = byte.count_ones() as usize;
+        if nth < count {
+            return bit_offset + select_in_byte(byte, nth);
+        }
+        nth -= count;
+        bit_offset += 8;
+    }
+    unreachable!("select_in_word: nth exceeds popcount")
+}
+
+// ── In-byte select ──────────────────────────────────────────────────────
+
+/// Position of the `nth` set bit inside a byte (0-indexed, LSB-first).
+///
+/// Clears the lowest `nth` set bits, then uses `trailing_zeros`.
+#[inline]
+fn select_in_byte(byte: u8, nth: usize) -> usize {
+    debug_assert!(nth < byte.count_ones() as usize); // debug-only precondition check
+    let mut bits = u32::from(byte);
+    for _ in 0..nth {
+        bits &= bits - 1; // clear lowest set bit
+    }
+    bits.trailing_zeros() as usize
+}
+
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::cast_possible_truncation)]
+
+    use rstest::rstest;
+
+    use super::*;
+
+    #[test]
+    fn test_select_all_set() {
+        // Every bit is set — select(n) == n.
+        let buf = [0xFFu8; 16]; // 128 bits, all set
+        for nth in 0..128 {
+            assert_eq!(bit_select(&buf, 0, 128, nth), nth, "nth={nth}");
+        }
+    }
+
+    #[test]
+    fn test_select_every_other() {
+        // 0b01010101 repeated: bits 0,2,4,6 of each byte are set.
+        let buf = [0x55u8; 16]; // 128 bits, 64 set
+        for nth in 0..64 {
+            assert_eq!(bit_select(&buf, 0, 128, nth), nth * 2, "nth={nth}");
+        }
+    }
+
+    #[test]
+    fn test_select_single_bit() {
+        // Only bit 42 is set.
+        let mut buf = [0u8; 16];
+        buf[42 / 8] |= 1 << (42 % 8);
+        assert_eq!(bit_select(&buf, 0, 128, 0), 42);
+    }
+
+    #[rstest]
+    #[case(0, 128)]
+    #[case(3, 100)]
+    #[case(7, 50)]
+    #[case(1, 7)]
+    #[case(5, 5)]
+    #[case(0, 1)]
+    #[case(0, 64)]
+    #[case(1, 64)]
+    #[case(0, 65)]
+    #[case(3, 256)]
+    fn test_select_agrees_with_naive(#[case] offset: usize, #[case] len: usize) {
+        let total_bits = offset + len;
+        let total_bytes = total_bits.div_ceil(8);
+        // Deterministic pattern with moderate density.
+        let buf: Vec<u8> = (0..total_bytes)
+            .map(|i| ((i.wrapping_mul(0x9E) ^ 0xA5) & 0xFF) as u8)
+            .collect();
+
+        // Collect set-bit positions naively.
+        let expected: Vec<usize> = (0..len)
+            .filter(|&i| {
+                let phys = offset + i;
+                (buf[phys / 8] >> (phys % 8)) & 1 == 1
+            })
+            .collect();
+
+        for (nth, &expected_pos) in expected.iter().enumerate() {
+            assert_eq!(
+                bit_select(&buf, offset, len, nth),
+                expected_pos,
+                "offset={offset} len={len} nth={nth}"
+            );
+        }
+    }
+
+    #[test]
+    fn test_select_large_buffer() {
+        // ~64 KB buffer, ~50% density.
+        let len = 65_536 * 8;
+        let buf: Vec<u8> = (0u32..65_536)
+            .map(|i| ((i.wrapping_mul(0x37) ^ 0xBC) & 0xFF) as u8)
+            .collect();
+
+        let true_count = buf.iter().map(|b| b.count_ones() as usize).sum::<usize>();
+
+        // Spot-check a few positions.
+        let first = bit_select(&buf, 0, len, 0);
+        let last = bit_select(&buf, 0, len, true_count - 1);
+        assert!(first < len);
+        assert!(last < len);
+        assert!(first <= last);
+
+        // Verify the found positions are actually set.
+        assert_ne!(buf[first / 8] & (1 << (first % 8)), 0);
+        assert_ne!(buf[last / 8] & (1 << (last % 8)), 0);
+    }
+
+    #[test]
+    fn test_select_in_word_basic() {
+        // 0b1010_1010 = 0xAA — bits 1,3,5,7 are set.
+        let word = 0x00000000_000000AAu64;
+        assert_eq!(select_in_word(word, 0), 1);
+        assert_eq!(select_in_word(word, 1), 3);
+        assert_eq!(select_in_word(word, 2), 5);
+        assert_eq!(select_in_word(word, 3), 7);
+    }
+
+    #[test]
+    fn test_select_in_word_all_set() {
+        let word = u64::MAX;
+        for nth in 0..64 {
+            assert_eq!(select_in_word(word, nth), nth, "nth={nth}");
+        }
+    }
+
+    #[test]
+    fn test_select_in_byte_basic() {
+        assert_eq!(select_in_byte(0b1010_1010, 0), 1);
+        assert_eq!(select_in_byte(0b1010_1010, 1), 3);
+        assert_eq!(select_in_byte(0b1010_1010, 2), 5);
+        assert_eq!(select_in_byte(0b1010_1010, 3), 7);
+        assert_eq!(select_in_byte(0b0000_0001, 0), 0);
+        assert_eq!(select_in_byte(0b1000_0000, 0), 7);
+        assert_eq!(select_in_byte(0xFF, 7), 7);
+    }
+
+    // ── batch select tests ──────────────────────────────────────────
+
+    #[test]
+    fn test_batch_select_all_set() {
+        let buf = [0xFFu8; 16]; // 128 bits, all set
+        let ranks: Vec<usize> = (0..128).collect();
+        let mut out = vec![0usize; 128];
+        bit_select_sorted_batch(&buf, 0, 128, &ranks, &mut out);
+        for (nth, &pos) in out.iter().enumerate() {
+            assert_eq!(pos, nth, "nth={nth}");
+        }
+    }
+
+    #[test]
+    fn test_batch_select_every_other() {
+        let buf = [0x55u8; 16]; // 128 bits, 64 set
+        let ranks: Vec<usize> = (0..64).collect();
+        let mut out = vec![0usize; 64];
+        bit_select_sorted_batch(&buf, 0, 128, &ranks, &mut out);
+        for (nth, &pos) in out.iter().enumerate() {
+            assert_eq!(pos, nth * 2, "nth={nth}");
+        }
+    }
+
+    #[test]
+    fn test_batch_select_sparse_ranks() {
+        let buf = [0xFFu8; 16]; // 128 set
+        let ranks = [0, 10, 50, 100, 127];
+        let mut out = [0usize; 5];
+        bit_select_sorted_batch(&buf, 0, 128, &ranks, &mut out);
+        assert_eq!(out, [0, 10, 50, 100, 127]);
+    }
+
+    #[test]
+    fn test_batch_select_empty() {
+        let buf = [0xFFu8; 4];
+        let mut out = [];
+        bit_select_sorted_batch(&buf, 0, 32, &[], &mut out);
+    }
+
+    #[rstest]
+    #[case(0, 128)]
+    #[case(3, 100)]
+    #[case(7, 50)]
+    #[case(0, 65)]
+    #[case(3, 256)]
+    fn test_batch_select_agrees_with_individual(#[case] offset: usize, #[case] len: usize) {
+        let total_bytes = (offset + len).div_ceil(8);
+        let buf: Vec<u8> = (0..total_bytes)
+            .map(|i| ((i.wrapping_mul(0x9E) ^ 0xA5) & 0xFF) as u8)
+            .collect();
+
+        // Get individual results.
+        let true_count = (0..len)
+            .filter(|&i| {
+                let phys = offset + i;
+                (buf[phys / 8] >> (phys % 8)) & 1 == 1
+            })
+            .count();
+
+        let all_ranks: Vec<usize> = (0..true_count).collect();
+        let individual: Vec<usize> = all_ranks
+            .iter()
+            .map(|&r| bit_select(&buf, offset, len, r))
+            .collect();
+
+        let mut batch = vec![0usize; true_count];
+        bit_select_sorted_batch(&buf, offset, len, &all_ranks, &mut batch);
+
+        assert_eq!(batch, individual, "offset={offset} len={len}");
+    }
+}
diff --git a/vortex-mask/Cargo.toml b/vortex-mask/Cargo.toml
index 037ffdb0d9d..a2af3d27cb0 100644
--- a/vortex-mask/Cargo.toml
+++ b/vortex-mask/Cargo.toml
@@ -35,5 +35,9 @@ rstest = { workspace = true }
 name = "intersect_by_rank"
 harness = false
 
+[[bench]]
+name = "rank"
+harness = false
+
 [lints]
 workspace = true
diff --git a/vortex-mask/benches/rank.rs b/vortex-mask/benches/rank.rs
new file mode 100644
index 00000000000..7dac0a991f9
--- /dev/null
+++ b/vortex-mask/benches/rank.rs
@@ -0,0 +1,94 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+//! Benchmarks for `Mask::rank`.
+
+#![allow(clippy::unwrap_used, clippy::cast_possible_truncation)]
+
+use divan::Bencher;
+use vortex_buffer::BitBuffer;
+use vortex_mask::Mask;
+
+fn main() {
+    divan::main();
+}
+
+const BENCH_SIZES: &[usize] = &[1_024, 16_384, 65_536, 262_144];
+
+fn create_mask(len: usize, density: f64) -> Mask {
+    Mask::from_buffer(BitBuffer::from_iter((0..len).map(|i| {
+        let threshold = (density * 1000.0) as usize;
+        (i * 7 + 13) % 1000 < threshold
+    })))
+}
+
+/// Single rank lookup at the midpoint.
+#[divan::bench(args = BENCH_SIZES)] +fn rank_single_mid(bencher: Bencher, len: usize) { + let mask = create_mask(len, 0.5); + let mid = mask.true_count() / 2; + bencher + .with_inputs(|| (&mask, mid)) + .bench_refs(|(mask, mid)| mask.rank(*mid)); +} + +/// Rank every set bit sequentially (worst-case total scan). +#[divan::bench(args = BENCH_SIZES)] +fn rank_all_sequential(bencher: Bencher, len: usize) { + let mask = create_mask(len, 0.5); + let tc = mask.true_count(); + bencher.with_inputs(|| &mask).bench_refs(|mask| { + for nth in 0..tc { + divan::black_box(mask.rank(nth)); + } + }); +} + +/// Single rank on a sparse mask (1% density). +#[divan::bench(args = BENCH_SIZES)] +fn rank_single_sparse(bencher: Bencher, len: usize) { + let mask = create_mask(len, 0.01); + let mid = mask.true_count() / 2; + bencher + .with_inputs(|| (&mask, mid)) + .bench_refs(|(mask, mid)| mask.rank(*mid)); +} + +/// Single rank on a dense mask (90% density). +#[divan::bench(args = BENCH_SIZES)] +fn rank_single_dense(bencher: Bencher, len: usize) { + let mask = create_mask(len, 0.9); + let mid = mask.true_count() / 2; + bencher + .with_inputs(|| (&mask, mid)) + .bench_refs(|(mask, mid)| mask.rank(*mid)); +} + +// ── Batch rank benchmarks ─────────────────────────────────────────── + +/// rank_batch: all ranks at once (single-pass). +#[divan::bench(args = BENCH_SIZES)] +fn rank_batch_all(bencher: Bencher, len: usize) { + let mask = create_mask(len, 0.5); + let tc = mask.true_count(); + let ranks: Vec = (0..tc).collect(); + bencher + .with_inputs(|| (&mask, &ranks)) + .bench_refs(|(mask, ranks)| mask.rank_batch(ranks)); +} + +/// rank_batch with scrambled (unsorted) input. +#[divan::bench(args = BENCH_SIZES)] +fn rank_batch_scrambled(bencher: Bencher, len: usize) { + let mask = create_mask(len, 0.5); + let tc = mask.true_count(); + // Scramble: reverse the ascending ranks so the input arrives in descending order.
+ let mut ranks: Vec = (0..tc).collect(); + let mid = tc / 2; + for i in 0..mid { + ranks.swap(i, tc - 1 - i); + } + bencher + .with_inputs(|| (&mask, &ranks)) + .bench_refs(|(mask, ranks)| mask.rank_batch(ranks)); +} diff --git a/vortex-mask/public-api.lock b/vortex-mask/public-api.lock index 941c422c4f8..8f5cd22d9d0 100644 --- a/vortex-mask/public-api.lock +++ b/vortex-mask/public-api.lock @@ -82,6 +82,8 @@ pub fn vortex_mask::Mask::new_true(length: usize) -> Self pub fn vortex_mask::Mask::rank(&self, n: usize) -> usize +pub fn vortex_mask::Mask::rank_batch(&self, ranks: &[usize]) -> alloc::vec::Vec + pub fn vortex_mask::Mask::slice(&self, range: impl core::ops::range::RangeBounds) -> Self pub fn vortex_mask::Mask::slices(&self) -> vortex_mask::AllOr<&[(usize, usize)]> diff --git a/vortex-mask/src/intersect_by_rank.rs b/vortex-mask/src/intersect_by_rank.rs index efce6f17dbd..0739349e2e3 100644 --- a/vortex-mask/src/intersect_by_rank.rs +++ b/vortex-mask/src/intersect_by_rank.rs @@ -29,23 +29,15 @@ impl Mask { pub fn intersect_by_rank(&self, mask: &Mask) -> Mask { assert_eq!(self.true_count(), mask.len()); - match (self.indices(), mask.indices()) { + match (self.bit_buffer(), mask.indices()) { (AllOr::All, _) => mask.clone(), (_, AllOr::All) => self.clone(), (AllOr::None, _) | (_, AllOr::None) => Self::new_false(self.len()), - (AllOr::Some(self_indices), AllOr::Some(mask_indices)) => { - Self::from_indices( - self.len(), - mask_indices - .iter() - .map(|idx| - // This is verified as safe because we know that the indices are less than the - // mask.len() and we known mask.len() <= self.len(), - // implied by `self.true_count() == mask.len()`. - unsafe{*self_indices.get_unchecked(*idx)}) - .collect(), - ) + (AllOr::Some(self_buffer), AllOr::Some(mask_indices)) => { + // mask_indices are already sorted — single-pass batch select + // avoids materializing the full self.indices() vector. 
+ Self::from_indices(self.len(), self_buffer.select_sorted_batch(mask_indices)) } } } diff --git a/vortex-mask/src/lib.rs b/vortex-mask/src/lib.rs index da2a9f6a6d9..793a3b16710 100644 --- a/vortex-mask/src/lib.rs +++ b/vortex-mask/src/lib.rs @@ -472,8 +472,39 @@ impl Mask { match &self { Self::AllTrue(_) => n, Self::AllFalse(_) => unreachable!("no true values in all-false mask"), - // TODO(joe): optimize this function - Self::Values(values) => values.indices()[n], + Self::Values(values) => values.buffer.select(n), + } + } + + /// Translate multiple positions through the mask in batch. + /// + /// For each `ranks[i]`, computes the position of the `ranks[i]`-th set bit, + /// equivalent to calling [`rank`](Self::rank) for each element. + /// + /// This is O(N log N + L/64) vs O(N × L/64) for individual calls, where + /// N = `ranks.len()` and L = mask length. + pub fn rank_batch(&self, ranks: &[usize]) -> Vec { + if ranks.is_empty() { + return vec![]; + } + match &self { + Self::AllTrue(_) => ranks.to_vec(), + Self::AllFalse(_) => unreachable!("no true values in all-false mask"), + Self::Values(values) => { + // Sort an index permutation by rank value. + let mut perm: Vec = (0..ranks.len()).collect(); + perm.sort_unstable_by_key(|&i| ranks[i]); + + let sorted_ranks: Vec = perm.iter().map(|&i| ranks[i]).collect(); + let sorted_results = values.buffer.select_sorted_batch(&sorted_ranks); + + // Scatter back to original order. 
+ let mut results = vec![0usize; ranks.len()]; + for (perm_idx, &orig_idx) in perm.iter().enumerate() { + results[orig_idx] = sorted_results[perm_idx]; + } + results + } } } From 44d12f007620d108b32f7c816efc6d077bd5d75c Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Fri, 10 Apr 2026 23:41:27 +0100 Subject: [PATCH 2/3] lint Signed-off-by: Robert Kruszewski --- vortex-buffer/src/bit/select.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vortex-buffer/src/bit/select.rs b/vortex-buffer/src/bit/select.rs index c070ae80e86..40eeb2a2b24 100644 --- a/vortex-buffer/src/bit/select.rs +++ b/vortex-buffer/src/bit/select.rs @@ -457,7 +457,9 @@ fn select_in_word(word: u64, nth: usize) -> usize { unsafe fn select_in_word_bmi2(word: u64, nth: usize) -> usize { use std::arch::x86_64::_pdep_u64; use std::arch::x86_64::_tzcnt_u64; - unsafe { _tzcnt_u64(_pdep_u64(1u64 << nth, word)) as usize } + use vortex_error::VortexExpect; + + usize::try_from(unsafe { _tzcnt_u64(_pdep_u64(1u64 << nth, word)) }).vortex_expect("safe to convert tzcnt result to usize") } /// Scalar: narrow to the correct byte, then clear `nth` lowest set bits and trailing-zeros. 
From 0ba188e7bca0ed45cbde2ab7f65a3bcdd0104196 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Fri, 10 Apr 2026 23:54:16 +0100 Subject: [PATCH 3/3] format Signed-off-by: Robert Kruszewski --- vortex-buffer/src/bit/select.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vortex-buffer/src/bit/select.rs b/vortex-buffer/src/bit/select.rs index 40eeb2a2b24..1ac6336971d 100644 --- a/vortex-buffer/src/bit/select.rs +++ b/vortex-buffer/src/bit/select.rs @@ -457,9 +457,11 @@ fn select_in_word(word: u64, nth: usize) -> usize { unsafe fn select_in_word_bmi2(word: u64, nth: usize) -> usize { use std::arch::x86_64::_pdep_u64; use std::arch::x86_64::_tzcnt_u64; + use vortex_error::VortexExpect; - usize::try_from(unsafe { _tzcnt_u64(_pdep_u64(1u64 << nth, word)) }).vortex_expect("safe to convert tzcnt result to usize") + usize::try_from(unsafe { _tzcnt_u64(_pdep_u64(1u64 << nth, word)) }) + .vortex_expect("safe to convert tzcnt result to usize") } /// Scalar: narrow to the correct byte, then clear `nth` lowest set bits and trailing-zeros.