Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 232 additions & 26 deletions encodings/runend/src/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,22 @@ use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::dict::TakeExecute;
use vortex_array::dtype::UnsignedPType;
use vortex_array::match_each_integer_ptype;
use vortex_array::search_sorted::SearchResult;
use vortex_array::search_sorted::SearchSorted;
use vortex_array::search_sorted::SearchSortedSide;
use vortex_array::match_each_unsigned_integer_ptype;
use vortex_array::validity::Validity;
use vortex_buffer::Buffer;
use vortex_buffer::BufferMut;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_mask::Mask;

use crate::RunEnd;
use crate::array::RunEndArrayExt;

/// Runs-per-valid-index budget for the linear scan over already sorted indices.
///
/// `physical_indices` passes this to `prefer_linear_scan`, which accepts the linear walk
/// when `ends.len() <= valid_count * 16`.
const SORTED_LINEAR_RUNS_PER_INDEX_THRESHOLD: usize = 16;
/// Runs-per-valid-index budget for the sort-then-scan path.
///
/// Consulted by `physical_indices` when the sorted linear scan was not selected; the
/// sort-then-scan path is accepted when `ends.len() <= valid_count * 4`.
const UNSORTED_LINEAR_RUNS_PER_INDEX_THRESHOLD: usize = 4;

impl TakeExecute for RunEnd {
#[expect(
clippy::cast_possible_truncation,
Expand All @@ -32,13 +36,20 @@ impl TakeExecute for RunEnd {
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
let primitive_indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
let indices_validity = primitive_indices.validity()?;
let indices_mask = indices_validity.execute_mask(primitive_indices.len(), ctx)?;

let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| {
primitive_indices
.as_slice::<P>()
.iter()
.copied()
.map(|idx| {
.enumerate()
.map(|(idx_pos, idx)| {
if !indices_mask.value(idx_pos) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that looks very slow?

return Ok(0);
}

let usize_idx = idx as usize;
if usize_idx >= array.len() {
vortex_bail!(OutOfBounds: usize_idx, 0, array.len());
Expand All @@ -48,46 +59,218 @@ impl TakeExecute for RunEnd {
.collect::<VortexResult<Vec<_>>>()?
});

let indices_validity = primitive_indices.validity()?;
take_indices_unchecked(array, &checked_indices, &indices_validity, ctx).map(Some)
take_indices_unchecked_with_mask(
array,
&checked_indices,
&indices_validity,
&indices_mask,
ctx,
)
.map(Some)
}
}

/// Perform a take operation on a RunEndArray.
///
/// The indices are not validated here. `validity` is executed into a mask covering
/// `indices.len()` entries, and the work is then delegated to
/// `take_indices_unchecked_with_mask`.
///
/// # Errors
///
/// Returns any error raised while executing the validity mask or by the delegated take.
pub fn take_indices_unchecked<T: AsPrimitive<usize>>(
    array: ArrayView<'_, RunEnd>,
    indices: &[T],
    validity: &Validity,
    ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
    let index_count = indices.len();
    let mask = validity.execute_mask(index_count, ctx)?;
    take_indices_unchecked_with_mask(array, indices, validity, &mask, ctx)
}

fn take_indices_unchecked_with_mask<T: AsPrimitive<usize>>(
array: ArrayView<'_, RunEnd>,
indices: &[T],
validity: &Validity,
validity_mask: &Mask,
ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
let ends = array.ends().clone().execute::<PrimitiveArray>(ctx)?;
let ends_len = ends.len();

// TODO(joe): use the validity mask to skip search sorted.
let physical_indices = match_each_integer_ptype!(ends.ptype(), |I| {
let physical_indices = match_each_unsigned_integer_ptype!(ends.ptype(), |I| {
let end_slices = ends.as_slice::<I>();
let physical_indices_vec: Vec<u64> = indices
.iter()
.map(|idx| idx.as_() + array.offset())
.map(|idx| {
match <I as NumCast>::from(idx) {
Some(idx) => end_slices.search_sorted(&idx, SearchSortedSide::Right),
None => {
// The idx is too large for I, therefore it's out of bounds.
Ok(SearchResult::NotFound(ends_len))
}
}
})
.map(|result| result.map(|r| r.to_ends_index(ends_len) as u64))
.collect::<VortexResult<Vec<_>>>()?;
let buffer = Buffer::from(physical_indices_vec);
let physical_indices = physical_indices(end_slices, array.offset(), indices, validity_mask);

PrimitiveArray::new(buffer, validity.clone())
PrimitiveArray::new(physical_indices, validity.clone())
});

array.values().take(physical_indices.into_array())
}

/// Translate logical `indices` into positions within `ends`, choosing a lookup strategy.
///
/// Strategy selection, in order:
/// 1. no valid index at all: return an all-zero buffer of `indices.len()` entries;
/// 2. valid indices already sorted and the sorted threshold admits a linear scan:
///    a single forward walk over `ends`;
/// 3. the unsorted threshold admits a linear scan: sort the valid indices, then walk;
/// 4. otherwise: an independent binary search per valid index.
fn physical_indices<I, T>(
    ends: &[I],
    offset: usize,
    indices: &[T],
    validity_mask: &Mask,
) -> Buffer<u64>
where
    I: UnsignedPType,
    T: AsPrimitive<usize>,
{
    let (valid_count, valid_indices_sorted) = valid_indices_stats(indices, validity_mask);
    if valid_count == 0 {
        return Buffer::zeroed(indices.len());
    }

    let run_count = ends.len();
    let scan_sorted = valid_indices_sorted
        && prefer_linear_scan(
            run_count,
            valid_count,
            SORTED_LINEAR_RUNS_PER_INDEX_THRESHOLD,
        );

    if scan_sorted {
        physical_indices_linear_sorted(ends, offset, indices, validity_mask)
    } else if prefer_linear_scan(
        run_count,
        valid_count,
        UNSORTED_LINEAR_RUNS_PER_INDEX_THRESHOLD,
    ) {
        physical_indices_linear_unsorted(ends, offset, indices, validity_mask, valid_count)
    } else {
        physical_indices_binary(ends, offset, indices, validity_mask)
    }
}

/// Count the valid entries of `indices` and report whether they are in non-decreasing order.
///
/// Positions for which `validity_mask.value(pos)` is false are ignored entirely: they do
/// not count and do not take part in the ordering check. With no valid entries the result
/// is `(0, true)`.
fn valid_indices_stats<T: AsPrimitive<usize>>(
    indices: &[T],
    validity_mask: &Mask,
) -> (usize, bool) {
    let mut last_seen: Option<usize> = None;
    let mut is_sorted = true;
    let mut count = 0usize;

    let valid_values = indices
        .iter()
        .enumerate()
        .filter(|(pos, _)| validity_mask.value(*pos))
        .map(|(_, raw)| raw.as_());

    for current in valid_values {
        // A single descending step is enough to mark the whole sequence as unsorted.
        if matches!(last_seen, Some(prev) if prev > current) {
            is_sorted = false;
        }
        last_seen = Some(current);
        count += 1;
    }

    (count, is_sorted)
}

/// Decide whether walking the run ends linearly is acceptable.
///
/// Returns `true` when `ends_len` does not exceed `valid_count * runs_per_index_threshold`.
/// The product saturates instead of overflowing.
fn prefer_linear_scan(
    ends_len: usize,
    valid_count: usize,
    runs_per_index_threshold: usize,
) -> bool {
    let scan_budget = valid_count.saturating_mul(runs_per_index_threshold);
    scan_budget >= ends_len
}

/// Resolve run positions with one forward walk over `ends`.
///
/// The caller (`physical_indices`) only selects this path when the valid entries of
/// `indices` are in non-decreasing order, so the run cursor never has to move backwards.
/// Slots whose mask bit is unset keep the zero the buffer was created with.
fn physical_indices_linear_sorted<I, T>(
    ends: &[I],
    offset: usize,
    indices: &[T],
    validity_mask: &Mask,
) -> Buffer<u64>
where
    I: UnsignedPType,
    T: AsPrimitive<usize>,
{
    let mut resolved = BufferMut::zeroed(indices.len());
    let mut cursor = 0;

    let valid_entries = indices
        .iter()
        .enumerate()
        .filter(|(pos, _)| validity_mask.value(*pos));

    for (pos, raw_idx) in valid_entries {
        // Shift by the array offset to get the position in the underlying run ends.
        advance_run(ends, &mut cursor, raw_idx.as_() + offset);
        resolved[pos] = cursor as u64;
    }

    resolved.freeze()
}

/// Resolve run positions for unsorted indices by sorting first and then walking `ends` once.
///
/// Each valid index is paired with its original position, the pairs are sorted by index
/// value, and the resolved run is written back to the original position. `valid_count` is
/// used only to pre-size the pair list. Slots whose mask bit is unset keep the zero the
/// buffer was created with.
fn physical_indices_linear_unsorted<I, T>(
    ends: &[I],
    offset: usize,
    indices: &[T],
    validity_mask: &Mask,
    valid_count: usize,
) -> Buffer<u64>
where
    I: UnsignedPType,
    T: AsPrimitive<usize>,
{
    let mut by_logical_idx = Vec::with_capacity(valid_count);
    by_logical_idx.extend(
        indices
            .iter()
            .enumerate()
            .filter(|(pos, _)| validity_mask.value(*pos))
            .map(|(pos, raw_idx)| (raw_idx.as_(), pos)),
    );
    by_logical_idx.sort_unstable();

    let mut resolved = BufferMut::zeroed(indices.len());
    let mut cursor = 0;

    for &(logical, pos) in &by_logical_idx {
        advance_run(ends, &mut cursor, logical + offset);
        resolved[pos] = cursor as u64;
    }

    resolved.freeze()
}

/// Resolve run positions with an independent binary search per valid index.
///
/// Slots whose mask bit is unset are never searched and keep the zero the buffer was
/// created with.
fn physical_indices_binary<I, T>(
    ends: &[I],
    offset: usize,
    indices: &[T],
    validity_mask: &Mask,
) -> Buffer<u64>
where
    I: UnsignedPType,
    T: AsPrimitive<usize>,
{
    let mut resolved = BufferMut::zeroed(indices.len());

    for (pos, raw_idx) in indices.iter().enumerate() {
        if validity_mask.value(pos) {
            let run = physical_index_binary(ends, raw_idx.as_() + offset);
            resolved[pos] = run as u64;
        }
    }

    resolved.freeze()
}

/// Find the position in `ends` of the run containing `logical_idx`.
///
/// This is the first position whose end is strictly greater than `logical_idx`. If no such
/// run exists, or `logical_idx` cannot be represented in `I`, the result is clamped to the
/// last position.
///
/// For an empty `ends` slice the result is `0` rather than an arithmetic underflow; the
/// caller is still responsible for not using that position against an empty values array.
fn physical_index_binary<I: UnsignedPType>(ends: &[I], logical_idx: usize) -> usize {
    // `saturating_sub` keeps the clamp well defined when there are no runs at all.
    let last_run = ends.len().saturating_sub(1);
    let first_run_past_idx = match <I as NumCast>::from(logical_idx) {
        Some(target) => ends.partition_point(|end| *end <= target),
        // Larger than anything `I` can hold, hence past every run end.
        None => ends.len(),
    };
    first_run_past_idx.min(last_run)
}

/// Move `run_idx` forward until it points at the run containing `logical_idx`.
///
/// The cursor only ever moves forward and stops at the last position of `ends` at the
/// latest; with an empty `ends` it is left untouched.
fn advance_run<I: UnsignedPType>(ends: &[I], run_idx: &mut usize, logical_idx: usize) {
    let run_count = ends.len();
    let mut cursor = *run_idx;

    // Skip every run that ends at or before the requested logical index.
    while cursor + 1 < run_count && run_end_le_logical_idx(ends[cursor], logical_idx) {
        cursor += 1;
    }

    *run_idx = cursor;
}

/// Report whether `run_end <= logical_idx`, comparing in the run-end type `I`.
///
/// A `logical_idx` that cannot be represented in `I` is larger than every possible run
/// end, so the answer is `true` in that case.
fn run_end_le_logical_idx<I: UnsignedPType>(run_end: I, logical_idx: usize) -> bool {
    if let Some(target) = <I as NumCast>::from(logical_idx) {
        run_end <= target
    } else {
        true
    }
}

#[cfg(test)]
mod tests {
use rstest::rstest;
Expand All @@ -96,9 +279,11 @@ mod tests {
use vortex_array::IntoArray;
use vortex_array::LEGACY_SESSION;
use vortex_array::VortexSessionExecute;
use vortex_array::arrays::BoolArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::assert_arrays_eq;
use vortex_array::compute::conformance::take::test_take_conformance;
use vortex_array::validity::Validity;
use vortex_buffer::buffer;

use crate::RunEnd;
Expand Down Expand Up @@ -126,6 +311,15 @@ mod tests {
assert_arrays_eq!(taken, expected);
}

#[test]
fn ree_take_sorted_boundaries() {
    // Ascending indices; presumably chosen to land on run boundaries of the `ree_array()`
    // fixture, which is defined outside this view.
    let sorted_indices = buffer![0, 2, 3, 6, 8, 11].into_array();
    let expected = PrimitiveArray::from_iter(vec![1i32, 1, 4, 2, 5, 5]).into_array();

    let taken = ree_array().take(sorted_indices).unwrap();

    assert_arrays_eq!(taken, expected);
}

#[test]
#[should_panic]
fn ree_take_out_of_bounds() {
Expand Down Expand Up @@ -155,6 +349,18 @@ mod tests {
assert_arrays_eq!(taken, expected.into_array());
}

#[test]
fn ree_take_null_index_skips_out_of_bounds_value() {
    // The second slot is null but carries the raw value 12, which per the test name is out
    // of bounds for the `ree_array()` fixture (defined outside this view). The take must
    // succeed and yield a null there.
    let index_validity = Validity::Array(BoolArray::from_iter([true, false]).into_array());
    let indices = PrimitiveArray::new(buffer![1u64, 12], index_validity).into_array();
    let expected = PrimitiveArray::from_option_iter([Some(1i32), None]).into_array();

    let taken = ree_array().take(indices).unwrap();

    assert_arrays_eq!(taken, expected);
}

#[rstest]
#[case(ree_array())]
#[case(RunEnd::encode(
Expand Down
Loading