Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 163 additions & 153 deletions zstd/src/fse/fse_encoder.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::bit_io::BitWriter;
use alloc::collections::BTreeSet;
use alloc::vec::Vec;

pub(crate) struct FSEEncoder<'output, V: AsMut<Vec<u8>>> {
Expand Down Expand Up @@ -225,11 +224,20 @@ impl FSETable {
/// return shape so callers can store `State` directly without
/// juggling lifetimes.
pub(crate) fn start_state(&self, symbol: u8) -> State {
let states = &self.states[symbol as usize];
let slot = states
.start_state_slot
let index = self.states[symbol as usize]
.start_state
.expect("symbol must be present in the FSE table");
states.states[slot]
// Callers consume only `index` (audited across the encoder + sequence
// emit paths). Donor `FSE_initCState2` likewise stores just the
// start state value; `num_bits` / `baseline` are properties of
// transitions, not of the initial state, so they have no
// meaningful values here and are zeroed.
State {
num_bits: 0,
baseline: 0,
last_index: 0,
index,
}
}

pub fn acc_log(&self) -> u8 {
Expand All @@ -242,11 +250,7 @@ impl FSETable {
}

pub(crate) fn max_num_bits_for_symbol(&self, symbol: u8) -> Option<u8> {
let states = &self.states[symbol as usize];
if states.probability == 0 {
return None;
}
states.states.iter().map(|state| state.num_bits).max()
self.states[symbol as usize].max_num_bits
}

/// Compute the exact serialized size (in bits) of the FSE table header,
Expand Down Expand Up @@ -359,18 +363,35 @@ impl FSETable {
}
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
pub(super) struct SymbolStates {
/// Sorted by baseline to allow easy lookup using an index
pub(super) states: Vec<State>,
/// Donor-encoder start-state slot (`FSE_initCState2` result, in
/// `0..table_size`). `None` when `probability == 0`. Callers
/// consume only the `index` field of the [`State`] returned from
/// [`FSETable::start_state`] (see line-by-line audit on the
/// `encode` / `encode_interleaved` / compressed-block sequence
/// paths) — so the other [`State`] fields stay zeroed on the
/// returned value.
pub(super) start_state: Option<usize>,
/// Probability assigned to this symbol (`0` absent, `-1` less-than-one).
pub(super) probability: i32,
start_state_slot: Option<usize>,
/// Max `num_bits` emitted by [`FSETable::next_state`] across all input
/// states for this symbol. `None` when `probability == 0`. Computed via
/// donor arithmetic at build time (`(2*table_size-1 + delta_nb_bits) >> 16`)
/// so [`FSETable::max_num_bits_for_symbol`] is a single array load
/// instead of a `Vec<State>` scan.
pub(super) max_num_bits: Option<u8>,
}

// SymbolStates::get (the old linear-scan next-state lookup) was
// replaced by [`FSETable::next_state`]'s O(1) donor arithmetic in
// #164. The Vec<State> storage remains for `start_state`,
// `symbol_probability`, `max_num_bits_for_symbol`, and `write_table`.
// #164. The legacy per-symbol `Vec<State>` storage was dropped in
// #110: production no longer materializes any per-state vector; donor
// `FSE_buildCTable_wksp` only ever stores `nextStateTable` (here
// [`FSETable::state_table_flat`]) + `symbolTT` (here
// [`FSETable::symbol_tt`]). Everything else — start state,
// max-nb-bits, probability — is precomputed once per symbol via the
// donor arithmetic and held in [`SymbolStates`].

#[derive(Debug, Clone, Copy)]
pub(crate) struct State {
Expand Down Expand Up @@ -626,181 +647,170 @@ fn donor_normalize_m2(
}

pub(super) fn build_table_from_probabilities(probs: &[i32], acc_log: u8) -> FSETable {
let mut states = core::array::from_fn::<SymbolStates, 256, _>(|_| SymbolStates {
states: Vec::new(),
probability: 0,
start_state_slot: None,
});
let mut symbol_positions = core::array::from_fn::<Vec<usize>, 256, _>(|_| Vec::new());

// distribute -1 symbols
let mut negative_idx = (1 << acc_log) - 1;
for (symbol, _prob) in probs
.iter()
.copied()
.enumerate()
.filter(|prob| prob.1 == -1)
{
states[symbol].states.push(State {
num_bits: acc_log,
baseline: 0,
last_index: (1 << acc_log) - 1,
index: negative_idx,
});
symbol_positions[symbol].push(negative_idx);
states[symbol].probability = -1;
negative_idx -= 1;
let table_size: usize = 1 << acc_log;
let mut symbol_states: [SymbolStates; 256] = core::array::from_fn(|_| SymbolStates::default());

// Donor `FSE_buildCTable_wksp` (lib/compress/fse_compress.c) — build
// `nextStateTable` (== `state_table_flat`) once via cumul + spread +
// sorted-by-symbol sweep, without ever materializing per-symbol
// `Vec<State>`. Previous implementation paid an O(num_symbols ×
// table_size) enumeration with a BTreeSet dedup per symbol —
// ~18% self time on small-input encode profiles. Drop it; donor
// only ever needs symbolTT + nextStateTable, and we precompute
// `start_state` / `max_num_bits` for callers via the same
// arithmetic [`FSETable::next_state`] uses on the hot path.
//
// Phase 1 — distribute `-1` (low-probability) symbols at the top
// of the table; bump the high-threshold cursor down. Build the
// `cumul` prefix-sum table that maps each symbol to its first
// `nextStateTable` slot in sorted-by-symbol layout.
let mut table_symbol = alloc::vec![0u8; table_size];
let mut high_threshold = (table_size - 1) as isize;
// `cumul` / running prefix-sum holds slot counts up to `table_size`.
// Decoder accepts `accuracy_log` up to `ENTRY_MAX_ACCURACY_LOG = 16`
// and `fse_decoder::FSETable::to_encoder_table` round-trips through
// this builder; at `acc_log == 16` the prefix sum reaches 65 536
// which overflows `u16` (max 65 535). Keep `cumul` / `cursor` at
// `u32` so the cumulative count is representable for every valid
// `acc_log`. Slot indices written into `state_table_flat` stay in
// `0..table_size` (largest value `table_size - 1` ≤ `u16::MAX`) and
// remain `u16` — only the running cursor needs the wider type.
let mut cumul = [0u32; 257];
for (symbol, &prob) in probs.iter().enumerate() {
let bump: u32 = match prob {
-1 => {
table_symbol[high_threshold as usize] = symbol as u8;
high_threshold -= 1;
1
}
p if p > 0 => p as u32,
_ => 0,
};
cumul[symbol + 1] = cumul[symbol] + bump;
}

// distribute other symbols

// Setup all needed states per symbol with their respective index
let mut idx = 0;
for (symbol, prob) in probs.iter().copied().enumerate() {
// Phase 2 — spread positive-probability symbols across the
// remaining slots, donor `step`-walk with low-prob area skip.
let step = (table_size >> 1) + (table_size >> 3) + 3;
let table_mask = table_size - 1;
let mut position: usize = 0;
for (symbol, &prob) in probs.iter().enumerate() {
if prob <= 0 {
continue;
}
states[symbol].probability = prob;
let states = &mut states[symbol].states;
let positions = &mut symbol_positions[symbol];
for _ in 0..prob {
states.push(State {
num_bits: 0,
baseline: 0,
last_index: 0,
index: idx,
});
positions.push(idx);

idx = next_position(idx, 1 << acc_log);
while idx > negative_idx {
idx = next_position(idx, 1 << acc_log);
table_symbol[position] = symbol as u8;
position = (position + step) & table_mask;
while (position as isize) > high_threshold {
position = (position + step) & table_mask;
}
}
assert_eq!(states.len(), prob as usize);
}

// Materialize the C `stateTable`: symbols are grouped by symbol and, within
// each symbol, ordered by table position.
let mut state_table = Vec::with_capacity(1 << acc_log);
for positions in &mut symbol_positions {
positions.sort_unstable();
state_table.extend(positions.iter().copied());
}
debug_assert_eq!(
position, 0,
"FSE spread must cycle exactly once through tableSize positions"
);

// Donor `FSE_encodeSymbol` driver tables (`fse_compress.c`). Built
// alongside the legacy `Vec<State>` storage. The flat
// `state_table_flat` is the donor `nextStateTable` (u16 entries
// for direct memory parity); `symbol_tt[s]` holds the per-symbol
// `{delta_nb_bits, delta_find_state}`. Once populated these drive
// the O(1) `FSETable::next_state` arithmetic; the `Vec<State>` per
// symbol stays alive for `start_state` / `symbol_probability` /
// `max_num_bits_for_symbol` / `write_table` and the existing test
// suite, but is no longer touched on the encode hot path. See #164.
let mut state_table_flat: alloc::vec::Vec<u16> = alloc::vec::Vec::with_capacity(1 << acc_log);
for &slot in &state_table {
// `slot` originates from `idx` values bounded by
// `(1 << acc_log) - 1` (see the negative_idx / next_position
// distribution loops above), so `u16` is wide enough for
// every supported `acc_log` (max 12 → table_size 4096).
state_table_flat.push(slot as u16);
// Phase 3 — emit `state_table_flat` (donor `nextStateTable`)
// ordered by `(symbol, slot)`. Walk every table slot `u`, look up
// its owning symbol via `table_symbol[u]`, and write the raw slot
// `u` into that symbol's running cumul cursor. The Rust convention
// stores raw slots (`0..table_size`); donor stores `table_size + u`
// pre-shifted and recovers `u` on read by subtracting `table_size`.
// Both representations encode the same `(symbol → next_slot)`
// mapping; [`FSETable::next_state`] is written against the raw-slot
// convention so the pre-shift is intentionally skipped here.
let mut state_table_flat: alloc::vec::Vec<u16> = alloc::vec![0u16; table_size];
let mut cursor = cumul;
for (u, &symbol_at_slot) in table_symbol.iter().enumerate() {
let s = symbol_at_slot as usize;
// The Rust convention here keeps `state_table_flat[i]` as the
// raw slot (`0..table_size`); donor stores `table_size + u`
// and subtracts on read. `next_state` arithmetic ([`FSETable::next_state`])
// matches the Rust convention — store the slot directly.
state_table_flat[cursor[s] as usize] = u as u16;
cursor[s] += 1;
}
let state_table_flat: alloc::boxed::Box<[u16]> = state_table_flat.into_boxed_slice();
let mut symbol_tt = [SymbolTT::default(); 256];

// Build encoder transitions directly from C `FSE_encodeSymbol()` formulas.
let mut symbol_transform_total = 0usize;
for (symbol, probability) in probs.iter().copied().enumerate() {
if probability == 0 {
// Phase 4 — `symbolTT[]` (delta_nb_bits, delta_find_state) plus
// precomputed `start_state` and `max_num_bits` per symbol. All via
// donor 16.16 fixed-point arithmetic; no per-state enumeration.
let mut symbol_tt = [SymbolTT::default(); 256];
let mut total: usize = 0;
for (symbol, &prob) in probs.iter().enumerate() {
symbol_states[symbol].probability = prob;
if prob == 0 {
// Donor fills `symbolTT` for prob==0 too, so `FSE_getMaxNbBits`
// still works (returns `acc_log + 1` for absent symbols).
// We don't expose that path, but mirror the value for parity.
symbol_tt[symbol] = SymbolTT {
delta_nb_bits: ((acc_log as u32 + 1) << 16).saturating_sub(1u32 << acc_log),
delta_find_state: 0,
};
continue;
}
let probability_abs = probability.unsigned_abs() as usize;
// Donor 16.16 fixed-point arithmetic — performed in `u32` so a
// 16-bit `usize` target (AVR / MSP430 / no-atomic Cortex-M0)
// can't silently overflow the `<<16` shift. Donor
// `fse_compress.c` keeps the same width.
let (delta_nb_bits, delta_find_state): (u32, isize) = match probability {
let (delta_nb_bits, delta_find_state): (u32, isize) = match prob {
-1 | 1 => (
((acc_log as u32) << 16).saturating_sub(1u32 << acc_log),
symbol_transform_total as isize - 1,
total as isize - 1,
),
probability if probability > 1 => {
let probability = probability as u32;
let max_bits_out = (acc_log as u32) - (probability - 1).ilog2();
let min_state_plus = probability << max_bits_out;
p if p > 1 => {
let p_u32 = p as u32;
let max_bits_out = (acc_log as u32) - (p_u32 - 1).ilog2();
let min_state_plus = p_u32 << max_bits_out;
(
(max_bits_out << 16).saturating_sub(min_state_plus),
symbol_transform_total as isize - probability as isize,
total as isize - p_u32 as isize,
)
}
_ => unreachable!(),
_ => unreachable!("probability is one of {{-1, 1+}} after the prob==0 gate above"),
};
symbol_tt[symbol] = SymbolTT {
delta_nb_bits,
delta_find_state,
};
let state = &mut states[symbol];
total += prob.unsigned_abs() as usize;

// Donor `FSE_initCState2`: start_state =
// stateTable[(((nbBitsOut<<16) - deltaNbBits) >> nbBitsOut) + deltaFindState]
// where `nbBitsOut = (deltaNbBits + (1<<15)) >> 16`.
let init_nb_bits_out = (delta_nb_bits + (1 << 15)) >> 16;
let init_value = (init_nb_bits_out << 16).saturating_sub(delta_nb_bits);
let state_table_index = (init_value >> init_nb_bits_out) as isize + delta_find_state;
let start_index = state_table[state_table_index as usize];
symbol_transform_total += probability_abs;
state.states = Vec::with_capacity(probability_abs.max(1));
// Dedup via `BTreeSet<(baseline, last_index, index, num_bits)>` instead
// of the previous `state.states.iter().any(...)` linear scan. With
// `acc_log` up to 12 (table_size = 4096) the inner scan was O(N²) per
// symbol and dominated encoder FSE table construction (~6% exclusive
// on level22 profile). BTreeSet keeps insertion order on the
// accepted-entries Vec untouched — important because the downstream
// `start_state_slot = states.iter().position(|e| e.index == start_index)`
// lookup depends on which duplicate-keyed entry shows up first.
let mut seen: BTreeSet<(usize, usize, usize, u8)> = BTreeSet::new();
for current_index in 0..(1usize << acc_log) {
// Same 16.16 fixed-point arithmetic as `next_state` —
// keep in `u32` for 16-bit-target safety, then cast back
// to `usize` at indexing sites.
let current_value = (1u32 << acc_log) + (current_index as u32);
let num_bits = ((current_value + delta_nb_bits) >> 16) as usize;
let next_state_idx = (current_value >> num_bits) as isize + delta_find_state;
let next_index = state_table[next_state_idx as usize];
let mask = (1usize << num_bits) - 1;
let baseline = current_index & !mask;
let last_index = baseline + mask;
if !seen.insert((baseline, last_index, next_index, num_bits as u8)) {
continue;
}
state.states.push(State {
num_bits: num_bits as u8,
baseline,
last_index,
index: next_index,
});
}

// For encoding we use the states ordered by the indexes they target
state.states.sort_by_key(|l| l.baseline);
state.start_state_slot = state
.states
.iter()
.position(|entry| entry.index == start_index);
// Donor `FSE_initCState2` guarantees this index is in
// `0..table_size` by construction (`delta_find_state` is bounded
// by `total - probability`, and `(value >> nb_bits_out)` is
// bounded by `2 * probability - 1`). The `debug_assert` makes
// the invariant explicit so a future regression in the donor
// arithmetic surfaces in dev builds before the silent
// `as usize` wraparound.
debug_assert!(
state_table_index >= 0,
"FSE start_state index must be non-negative (got {state_table_index} for symbol {symbol})"
);
let start_index = state_table_flat[state_table_index as usize] as usize;
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// Max nb_bits across all input states `0..table_size`. Donor
// `next_state` arithmetic: `nb_bits = (value + delta_nb_bits) >> 16`
// with `value = table_size + idx`, `idx ∈ 0..table_size`. That
// expression is non-decreasing in `idx`, so the maximum is at
// `idx = table_size - 1`. Single op vs the prior
// `Vec<State>::iter().map(|s|s.num_bits).max()` linear scan.
let max_value = (2 * table_size as u32 - 1) + delta_nb_bits;
let max_num_bits = (max_value >> 16) as u8;

symbol_states[symbol].start_state = Some(start_index);
symbol_states[symbol].max_num_bits = Some(max_num_bits);
}

FSETable {
table_size: 1 << acc_log,
states,
table_size,
states: symbol_states,
state_table_flat,
symbol_tt,
}
}

/// Calculate the position of the next entry of the table given the current
/// position and size of the table.
fn next_position(mut p: usize, table_size: usize) -> usize {
p += (table_size >> 1) + (table_size >> 3) + 3;
p &= table_size - 1;
p
}

const ML_DIST: &[i32] = &[
1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1,
Expand Down
Loading
Loading