diff --git a/encodings/runend/src/compute/take.rs b/encodings/runend/src/compute/take.rs index 7100faf9eac..ecc98071693 100644 --- a/encodings/runend/src/compute/take.rs +++ b/encodings/runend/src/compute/take.rs @@ -9,18 +9,22 @@ use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::dict::TakeExecute; +use vortex_array::dtype::UnsignedPType; use vortex_array::match_each_integer_ptype; -use vortex_array::search_sorted::SearchResult; -use vortex_array::search_sorted::SearchSorted; -use vortex_array::search_sorted::SearchSortedSide; +use vortex_array::match_each_unsigned_integer_ptype; use vortex_array::validity::Validity; use vortex_buffer::Buffer; +use vortex_buffer::BufferMut; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_mask::Mask; use crate::RunEnd; use crate::array::RunEndArrayExt; +const SORTED_LINEAR_RUNS_PER_INDEX_THRESHOLD: usize = 16; +const UNSORTED_LINEAR_RUNS_PER_INDEX_THRESHOLD: usize = 4; + impl TakeExecute for RunEnd { #[expect( clippy::cast_possible_truncation, @@ -32,13 +36,20 @@ impl TakeExecute for RunEnd { ctx: &mut ExecutionCtx, ) -> VortexResult> { let primitive_indices = indices.clone().execute::(ctx)?; + let indices_validity = primitive_indices.validity()?; + let indices_mask = indices_validity.execute_mask(primitive_indices.len(), ctx)?; let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| { primitive_indices .as_slice::

() .iter() .copied() - .map(|idx| { + .enumerate() + .map(|(idx_pos, idx)| { + if !indices_mask.value(idx_pos) { + return Ok(0); + } + let usize_idx = idx as usize; if usize_idx >= array.len() { vortex_bail!(OutOfBounds: usize_idx, 0, array.len()); @@ -48,46 +59,218 @@ impl TakeExecute for RunEnd { .collect::>>()? }); - let indices_validity = primitive_indices.validity()?; - take_indices_unchecked(array, &checked_indices, &indices_validity, ctx).map(Some) + take_indices_unchecked_with_mask( + array, + &checked_indices, + &indices_validity, + &indices_mask, + ctx, + ) + .map(Some) } } -/// Perform a take operation on a RunEndArray by binary searching for each of the indices. +/// Perform a take operation on a RunEndArray. pub fn take_indices_unchecked>( array: ArrayView<'_, RunEnd>, indices: &[T], validity: &Validity, ctx: &mut ExecutionCtx, +) -> VortexResult { + let validity_mask = validity.execute_mask(indices.len(), ctx)?; + take_indices_unchecked_with_mask(array, indices, validity, &validity_mask, ctx) +} + +fn take_indices_unchecked_with_mask>( + array: ArrayView<'_, RunEnd>, + indices: &[T], + validity: &Validity, + validity_mask: &Mask, + ctx: &mut ExecutionCtx, ) -> VortexResult { let ends = array.ends().clone().execute::(ctx)?; - let ends_len = ends.len(); - // TODO(joe): use the validity mask to skip search sorted. - let physical_indices = match_each_integer_ptype!(ends.ptype(), |I| { + let physical_indices = match_each_unsigned_integer_ptype!(ends.ptype(), |I| { let end_slices = ends.as_slice::(); - let physical_indices_vec: Vec = indices - .iter() - .map(|idx| idx.as_() + array.offset()) - .map(|idx| { - match ::from(idx) { - Some(idx) => end_slices.search_sorted(&idx, SearchSortedSide::Right), - None => { - // The idx is too large for I, therefore it's out of bounds. - Ok(SearchResult::NotFound(ends_len)) - } - } - }) - .map(|result| result.map(|r| r.to_ends_index(ends_len) as u64)) - .collect::>>()?; - let buffer = Buffer::from(physical_indices_vec); + let physical_indices = physical_indices(end_slices, array.offset(), indices, validity_mask); - PrimitiveArray::new(buffer, validity.clone()) + PrimitiveArray::new(physical_indices, validity.clone()) }); array.values().take(physical_indices.into_array()) } +fn physical_indices( + ends: &[I], + offset: usize, + indices: &[T], + validity_mask: &Mask, +) -> Buffer +where + I: UnsignedPType, + T: AsPrimitive, +{ + let (valid_count, valid_indices_sorted) = valid_indices_stats(indices, validity_mask); + + if valid_count == 0 { + return Buffer::zeroed(indices.len()); + } + + if valid_indices_sorted + && prefer_linear_scan( + ends.len(), + valid_count, + SORTED_LINEAR_RUNS_PER_INDEX_THRESHOLD, + ) + { + return physical_indices_linear_sorted(ends, offset, indices, validity_mask); + } + + if prefer_linear_scan( + ends.len(), + valid_count, + UNSORTED_LINEAR_RUNS_PER_INDEX_THRESHOLD, + ) { + return physical_indices_linear_unsorted(ends, offset, indices, validity_mask, valid_count); + } + + physical_indices_binary(ends, offset, indices, validity_mask) +} + +fn valid_indices_stats>( + indices: &[T], + validity_mask: &Mask, +) -> (usize, bool) { + let mut valid_count = 0; + let mut previous_idx = None; + let mut sorted = true; + + for (idx_pos, idx) in indices.iter().enumerate() { + if !validity_mask.value(idx_pos) { + continue; + } + + let idx = idx.as_(); + if previous_idx.is_some_and(|previous_idx| previous_idx > idx) { + sorted = false; + } + previous_idx = Some(idx); + valid_count += 1; + } + + (valid_count, sorted) +} + +fn prefer_linear_scan( + ends_len: usize, + valid_count: usize, + runs_per_index_threshold: usize, +) -> bool { + ends_len <= valid_count.saturating_mul(runs_per_index_threshold) +} + +fn physical_indices_linear_sorted( + ends: &[I], + offset: usize, + indices: &[T], + validity_mask: &Mask, +) -> Buffer +where + I: UnsignedPType, + T: AsPrimitive, +{ + let mut physical_indices = BufferMut::zeroed(indices.len()); + let mut run_idx = 0; + + for (idx_pos, idx) in indices.iter().enumerate() { + if !validity_mask.value(idx_pos) { + continue; + } + + let logical_idx = idx.as_() + offset; + advance_run(ends, &mut run_idx, logical_idx); + physical_indices[idx_pos] = run_idx as u64; + } + + physical_indices.freeze() +} + +fn physical_indices_linear_unsorted( + ends: &[I], + offset: usize, + indices: &[T], + validity_mask: &Mask, + valid_count: usize, +) -> Buffer +where + I: UnsignedPType, + T: AsPrimitive, +{ + let mut pairs = Vec::with_capacity(valid_count); + for (idx_pos, idx) in indices.iter().enumerate() { + if validity_mask.value(idx_pos) { + pairs.push((idx.as_(), idx_pos)); + } + } + pairs.sort_unstable(); + + let mut physical_indices = BufferMut::zeroed(indices.len()); + let mut run_idx = 0; + + for (idx, idx_pos) in pairs { + let logical_idx = idx + offset; + advance_run(ends, &mut run_idx, logical_idx); + physical_indices[idx_pos] = run_idx as u64; + } + + physical_indices.freeze() +} + +fn physical_indices_binary( + ends: &[I], + offset: usize, + indices: &[T], + validity_mask: &Mask, +) -> Buffer +where + I: UnsignedPType, + T: AsPrimitive, +{ + let mut physical_indices = BufferMut::zeroed(indices.len()); + + for (idx_pos, idx) in indices.iter().enumerate() { + if !validity_mask.value(idx_pos) { + continue; + } + + let logical_idx = idx.as_() + offset; + physical_indices[idx_pos] = physical_index_binary(ends, logical_idx) as u64; + } + + physical_indices.freeze() +} + +fn physical_index_binary(ends: &[I], logical_idx: usize) -> usize { + let index = match ::from(logical_idx) { + Some(logical_idx) => ends.partition_point(|end| *end <= logical_idx), + None => ends.len(), + }; + index.min(ends.len() - 1) +} + +fn advance_run(ends: &[I], run_idx: &mut usize, logical_idx: usize) { + while *run_idx + 1 < ends.len() && run_end_le_logical_idx(ends[*run_idx], logical_idx) { + *run_idx += 1; + } +} + +fn run_end_le_logical_idx(run_end: I, logical_idx: usize) -> bool { + match ::from(logical_idx) { + Some(logical_idx) => run_end <= logical_idx, + None => true, + } +} + #[cfg(test)] mod tests { use rstest::rstest; @@ -96,9 +279,11 @@ mod tests { use vortex_array::IntoArray; use vortex_array::LEGACY_SESSION; use vortex_array::VortexSessionExecute; + use vortex_array::arrays::BoolArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::assert_arrays_eq; use vortex_array::compute::conformance::take::test_take_conformance; + use vortex_array::validity::Validity; use vortex_buffer::buffer; use crate::RunEnd; @@ -126,6 +311,15 @@ mod tests { assert_arrays_eq!(taken, expected); } + #[test] + fn ree_take_sorted_boundaries() { + let taken = ree_array() + .take(buffer![0, 2, 3, 6, 8, 11].into_array()) + .unwrap(); + let expected = PrimitiveArray::from_iter(vec![1i32, 1, 4, 2, 5, 5]).into_array(); + assert_arrays_eq!(taken, expected); + } + #[test] #[should_panic] fn ree_take_out_of_bounds() { @@ -155,6 +349,18 @@ mod tests { assert_arrays_eq!(taken, expected.into_array()); } + #[test] + fn ree_take_null_index_skips_out_of_bounds_value() { + let indices = PrimitiveArray::new( + buffer![1u64, 12], + Validity::Array(BoolArray::from_iter([true, false]).into_array()), + ); + let taken = ree_array().take(indices.into_array()).unwrap(); + + let expected = PrimitiveArray::from_option_iter([Some(1i32), None]); + assert_arrays_eq!(taken, expected.into_array()); + } + #[rstest] #[case(ree_array())] #[case(RunEnd::encode(