diff --git a/rust/spark-rs/src/sort.rs b/rust/spark-rs/src/sort.rs index 52130c99..8ee1423a 100644 --- a/rust/spark-rs/src/sort.rs +++ b/rust/spark-rs/src/sort.rs @@ -65,7 +65,10 @@ pub fn sort_internal(buffers: &mut SortBuffers, num_splats: usize) -> Result, /// bucket counts / offsets (length == RADIX_BASE) pub buckets16hi: Vec, - /// scratch space for indices - pub scratch: Vec, + /// scratch space for (key, index) + pub scratch: Vec, } impl Sort32Buffers { @@ -102,8 +105,18 @@ impl Sort32Buffers { } } +fn prefix_sum_exclusive(buckets: &mut [u32]) -> u32 { + let mut sum = 0u32; + for b in buckets.iter_mut() { + let tmp = *b; + *b = sum; + sum = sum.wrapping_add(tmp); + } + sum +} + /// Two‑pass radix sort (base 2¹⁶) of 32‑bit float bit‑patterns, -/// descending order (largest keys first). Mirrors the JS `sort32Splats`. +/// descending order (largest keys first). pub fn sort32_internal( buffers: &mut Sort32Buffers, max_splats: usize, @@ -115,52 +128,118 @@ pub fn sort32_internal( let Sort32Buffers { readback, ordering, buckets16lo, buckets16hi, scratch } = buffers; let keys = &readback[..num_splats]; - // tally low and high buckets + // tally low and high buckets (branchless) buckets16lo.fill(0); buckets16hi.fill(0); - for &key in keys.iter() { - if key < DEPTH_INFINITY_F32 { - let inv = !key; - buckets16lo[(inv & 0xFFFF) as usize] += 1; - buckets16hi[(inv >> 16) as usize] += 1; - } + + macro_rules! tick { + ($key:expr) => {{ + let valid = ($key < DEPTH_INFINITY_F32) as u32; + let inv = !$key; + let lo = inv & RADIX_MASK; + let hi = inv >> RADIX_BITS; + + // by mask above: lo < 65536 == buckets16lo.len() == RADIX_BASE + unsafe { *buckets16lo.get_unchecked_mut(lo as usize) += valid; } + // by shift above: hi < 65536 == buckets16hi.len() == RADIX_BASE + unsafe { *buckets16hi.get_unchecked_mut(hi as usize) += valid; } + }}; } - // ——— Pass #1: bucket by inv(low 16 bits) ——— - // exclusive prefix‑sum → starting offsets - let mut total: u32 = 0; - for slot in buckets16lo.iter_mut() { - let cnt = *slot; - *slot = total; - total = total.wrapping_add(cnt); + let mut chunks = keys.chunks_exact(8); + + for chunk in chunks.by_ref() { + tick!(chunk[0]); + tick!(chunk[1]); + tick!(chunk[2]); + tick!(chunk[3]); + tick!(chunk[4]); + tick!(chunk[5]); + tick!(chunk[6]); + tick!(chunk[7]); } - let active_splats = total; + + for &k in chunks.remainder() { + tick!(k); + } + + // exclusive prefix‑sum → starting offsets + let active_splats = prefix_sum_exclusive(buckets16lo); + prefix_sum_exclusive(buckets16hi); + + // ——— Pass #1: bucket by inv(low 16 bits) ——— // scatter into scratch by low bits of inv - for (i, &key) in keys.iter().enumerate() { - if key < DEPTH_INFINITY_F32 { - let inv = !key; - let lo = (inv & 0xFFFF) as usize; - scratch[buckets16lo[lo] as usize] = i as u32; - buckets16lo[lo] += 1; - } + macro_rules! place { + ($key:expr, $idx:expr) => {{ + if $key < DEPTH_INFINITY_F32 { + let inv = !$key; + let lo = (inv & RADIX_MASK) as usize; + // by mask above: lo < 65536 == buckets16lo.len() == RADIX_BASE + let pos = unsafe { *buckets16lo.get_unchecked(lo) } as usize; + let inv_idx = ((inv as u64) << 32) | ($idx as u64); + + // by design we have pos < active_splats <= max_splats <= scratch.len() + unsafe { *scratch.get_unchecked_mut(pos) = inv_idx; } + // by mask above: lo < 65536 == buckets16lo.len() == RADIX_BASE + unsafe { *buckets16lo.get_unchecked_mut(lo) += 1; } + } + }}; } - // ——— Pass #2: bucket by inv(high 16 bits) ——— - // exclusive prefix‑sum again - let mut sum: u32 = 0; - for slot in buckets16hi.iter_mut() { - let cnt = *slot; - *slot = sum; - sum = sum.wrapping_add(cnt); + let mut chunks = keys.chunks_exact(8); + let mut i = 0; + + for chunk in chunks.by_ref() { + place!(chunk[0], i); + place!(chunk[1], i + 1); + place!(chunk[2], i + 2); + place!(chunk[3], i + 3); + place!(chunk[4], i + 4); + place!(chunk[5], i + 5); + place!(chunk[6], i + 6); + place!(chunk[7], i + 7); + + i += 8; } + + for &k in chunks.remainder() { + place!(k, i); + i += 1; + } + + // ——— Pass #2: bucket by inv(high 16 bits) ——— + // scatter into final ordering by high bits of inv - for &idx in scratch.iter().take(active_splats as usize) { - let key = keys[idx as usize]; - let inv = !key; - let hi = (inv >> 16) as usize; - ordering[buckets16hi[hi] as usize] = idx; - buckets16hi[hi] += 1; + macro_rules! place2 { + ($inv_idx:expr) => {{ + let idx = $inv_idx as u32; + let hi = (($inv_idx >> 48) & RADIX_MASK as u64) as usize; + // by mask above: hi < 65536 == buckets16hi.len() == RADIX_BASE + let pos = unsafe { *buckets16hi.get_unchecked(hi) } as usize; + + // by design we have pos < active_splats <= max_splats <= ordering.len() + unsafe { *ordering.get_unchecked_mut(pos) = idx; } + // by mask above: hi < 65536 == buckets16hi.len() == RADIX_BASE + unsafe { *buckets16hi.get_unchecked_mut(hi) += 1; } + }}; + } + + let mut chunks = scratch[..active_splats as usize].chunks_exact(8); + + for chunk in chunks.by_ref() { + place2!(chunk[0]); + place2!(chunk[1]); + place2!(chunk[2]); + place2!(chunk[3]); + place2!(chunk[4]); + place2!(chunk[5]); + place2!(chunk[6]); + place2!(chunk[7]); + } + + for &inv_idx in chunks.remainder() { + place2!(inv_idx); } // sanity‑check: last bucket should have consumed all entries @@ -173,4 +252,4 @@ pub fn sort32_internal( } Ok(active_splats) -} \ No newline at end of file +}