Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 183 additions & 15 deletions zstd/src/huff0/huff0_encoder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
use alloc::vec::Vec;
use core::cmp::Ordering;

/// Cache primitive backing `HuffmanTable::cached_encoded_weight_description`
/// (std builds only).
///
/// `std::sync::OnceLock` is `Sync` (atomic one-time initialization), so
/// embedding it in `pub struct HuffmanTable` keeps the type's auto-traits
/// intact for downstream consumers that share encoder tables across threads.
///
/// **The cache is entirely absent in no_std builds**: the
/// `cached_encoded_weight_description` field is `#[cfg(feature = "std")]`,
/// so `HuffmanTable` retains `Sync` regardless of which feature set the
/// consumer builds with. The trade-off is that every no_std build loses the
/// per-table FSE-encode cache and falls back to the recompute-every-time
/// path that existed before the cache landed.
///
/// Stored value: `Some(bytes)` holds the FSE-encoded weight description;
/// `None` records that the encoder produced no FSE description, so callers
/// fall back to the raw weight description without re-running the encoder.
#[cfg(feature = "std")]
type CachedDescription = std::sync::OnceLock<Option<Vec<u8>>>;

use crate::{
bit_io::BitWriter,
fse::fse_encoder::{self, FSEEncoder},
Expand Down Expand Up @@ -109,13 +123,53 @@ impl<V: AsMut<Vec<u8>>> HuffmanEncoder<'_, '_, V> {
}

fn write_table(&mut self) {
let weights = self.weights();
let weights = &weights[..weights.len() - 1]; // don't encode last weight
if let Some(fse_description) = Self::encode_weight_description(weights) {
self.writer.write_bits(fse_description.len() as u8, 8);
self.writer.append_bytes(&fse_description);
} else {
Self::write_raw_weight_description(self.writer, weights);
#[cfg(feature = "std")]
{
// Cached path: cache hit → emit FSE bytes directly OR the
// cached `None` sentinel → emit raw (one `weights()` recompute,
// unavoidable since the cache stores only the FSE encoding,
// not the raw nibbles).
if let Some(cached) = self.table.cached_encoded_weight_description.get() {
if let Some(fse_description) = cached.as_deref() {
self.writer.write_bits(fse_description.len() as u8, 8);
self.writer.append_bytes(fse_description);
return;
}
let weights = self.weights();
let weights = &weights[..weights.len() - 1];
Self::write_raw_weight_description(self.writer, weights);
return;
}
// Cold path: compute `weights` once and share it between the
// cache initializer (which uses it to FSE-encode) and the raw
// fallback (which uses it directly to write nibbles). Without
// this, the raw fallback would call back into `weights()` and
// recompute the slice — a measurable hotspot for small /
// low-cardinality tables (#170 review thread).
let weights = self.weights();
let weights = &weights[..weights.len() - 1];
if let Some(fse_description) = self
.table
.cached_encoded_weight_description_with_weights(weights)
{
self.writer.write_bits(fse_description.len() as u8, 8);
self.writer.append_bytes(fse_description);
} else {
Self::write_raw_weight_description(self.writer, weights);
}
}
#[cfg(not(feature = "std"))]
{
// no_std: no cache field, no shared state — single `weights()`
// compute, branch on FSE-vs-raw based on direct encoder call.
let weights = self.weights();
let weights = &weights[..weights.len() - 1];
if let Some(fse_description) = Self::encode_weight_description(weights) {
self.writer.write_bits(fse_description.len() as u8, 8);
self.writer.append_bytes(&fse_description);
} else {
Self::write_raw_weight_description(self.writer, weights);
}
}
Comment thread
polaz marked this conversation as resolved.
}

Expand Down Expand Up @@ -210,6 +264,15 @@ impl<V: AsMut<Vec<u8>>> HuffmanEncoder<'_, '_, V> {
pub struct HuffmanTable {
    /// Index is the symbol, values are the bitstring in the lower bits of the u32 and the amount of bits in the u8
    codes: Vec<(u32, u8)>,
    /// Lazy cache of the FSE-encoded weight description. Avoids re-running
    /// `encode_weight_description` across `try_table_description_size` and
    /// `write_table` for the same table instance.
    ///
    /// States: unset until the first lookup; `Some(bytes)` once an FSE
    /// description has been produced; `None` once the encoder has been run
    /// and produced no description (callers then use the raw form).
    ///
    /// **std-only** — `core::cell::OnceCell` is `!Sync` and would break the
    /// `Sync` auto-trait for `pub HuffmanTable` in no_std builds; no_std
    /// users keep the original recompute-every-time semantics. See the
    /// `CachedDescription` type-alias doc above for full rationale.
    #[cfg(feature = "std")]
    cached_encoded_weight_description: CachedDescription,
}

impl HuffmanTable {
Expand Down Expand Up @@ -312,17 +375,37 @@ impl HuffmanTable {
}

/// Returns exact writable table-description size when representable.
/// std build path: consults the lazy cache to avoid re-encoding the
/// weight stream when both planner and emitter call this for the
/// same table. no_std build path: recomputes via the direct encoder
/// every call (cache field absent — preserves `Sync`).
pub(crate) fn try_table_description_size(&self) -> Option<usize> {
let weights = self.weights();
let weights = &weights[..weights.len() - 1];
if let Some(fse_description) = HuffmanEncoder::<Vec<u8>>::encode_weight_description(weights)
#[cfg(feature = "std")]
{
return Some(fse_description.len() + 1);
if let Some(fse_description) = self.cached_encoded_weight_description() {
return Some(fse_description.len() + 1);
}
let raw_weights_len = self.codes.len().saturating_sub(1);
if raw_weights_len <= 128 {
Some(raw_weights_len.div_ceil(2) + 1)
} else {
None
}
}
if weights.len() <= 128 {
Some(weights.len().div_ceil(2) + 1)
} else {
None
#[cfg(not(feature = "std"))]
{
let weights = self.weights();
let weights = &weights[..weights.len() - 1];
if let Some(fse_description) =
HuffmanEncoder::<Vec<u8>>::encode_weight_description(weights)
{
return Some(fse_description.len() + 1);
}
if weights.len() <= 128 {
Some(weights.len().div_ceil(2) + 1)
} else {
None
}
}
}

Expand All @@ -340,6 +423,23 @@ impl HuffmanTable {
.collect::<Vec<u8>>()
}

/// Returns the cached FSE-encoded weight description, computing and
/// caching it on first use.
///
/// Yields `None` when the encoder produced no FSE description for this
/// table. The weights are only materialized when the cache is still unset.
#[cfg(feature = "std")]
fn cached_encoded_weight_description(&self) -> Option<&[u8]> {
    match self.cached_encoded_weight_description.get() {
        // Already resolved (either outcome): hand back the stored result.
        Some(entry) => entry.as_deref(),
        None => {
            let all_weights = self.weights();
            // The last weight is not part of the encoded description.
            let encodable = &all_weights[..all_weights.len() - 1];
            self.cached_encoded_weight_description_with_weights(encodable)
        }
    }
}

/// Returns the cached FSE-encoded weight description, initializing the
/// cache from `weights` if it has not been set yet.
///
/// `weights` is only read when the cache is unset; once a result is
/// stored, later calls return it regardless of the slice passed in.
/// Callers are expected to pass this table's own weights (without the
/// last one).
#[cfg(feature = "std")]
fn cached_encoded_weight_description_with_weights(&self, weights: &[u8]) -> Option<&[u8]> {
    let slot = &self.cached_encoded_weight_description;
    let entry = slot.get_or_init(|| HuffmanEncoder::<Vec<u8>>::encode_weight_description(weights));
    entry.as_deref()
}

/// Estimates encoded payload size in bytes directly from per-symbol counts.
pub(crate) fn estimate_compressed_size_from_counts(&self, counts: &[usize]) -> usize {
let bits = self
Expand All @@ -364,6 +464,8 @@ impl HuffmanTable {
let table_log = highest_bit_set(weight_sum) - 1;
let mut table = HuffmanTable {
codes: alloc::vec![(0, 0); weights.len()],
#[cfg(feature = "std")]
cached_encoded_weight_description: CachedDescription::new(),
};
let mut nb_per_rank = [0u16; 13];
for &weight in weights {
Expand Down Expand Up @@ -1236,6 +1338,72 @@ fn large_alphabet_weight_description_uses_fse_when_raw_is_unrepresentable() {
));
}

/// Sizing a table must populate the FSE cache, and `write_table` must then
/// emit exactly those cached bytes behind a one-byte length prefix.
#[cfg(feature = "std")]
#[test]
fn cached_encoded_weight_description_is_reused_for_write_table() {
    // Every byte value occurs; symbol `s` appears `s + 1` times.
    let input: Vec<u8> = (0u8..=255)
        .flat_map(|symbol| core::iter::repeat_n(symbol, usize::from(symbol) + 1))
        .collect();
    let table = HuffmanTable::build_from_data(&input);

    let description_size = table
        .writeable_table_description_size()
        .expect("table description must be writable");
    let fse_bytes: Vec<u8> = table
        .cached_encoded_weight_description
        .get()
        .and_then(|slot| slot.as_ref())
        .cloned()
        .expect("large alphabet fixture must cache FSE description");
    // Reported size = cached FSE bytes + the length byte.
    assert_eq!(description_size, fse_bytes.len() + 1);

    let mut output = Vec::new();
    {
        let mut writer = BitWriter::from(&mut output);
        HuffmanEncoder::new(&table, &mut writer).write_table();
        writer.flush();
    }
    assert_eq!(usize::from(output[0]), fse_bytes.len());
    assert_eq!(&output[1..], &fse_bytes[..]);
}

/// For a table built from weights `[1, 1]`, `write_table` is expected to
/// emit the raw weight description and to leave the cache initialized with
/// the `None` sentinel.
#[cfg(feature = "std")]
#[test]
fn write_table_raw_path_initializes_none_cache() {
    let table = HuffmanTable::build_from_weights(&[1, 1]);
    // Nothing has touched the cache yet.
    assert!(table.cached_encoded_weight_description.get().is_none());

    // Pull the table's weights through a throwaway encoder.
    let all_weights = {
        let mut scratch = Vec::new();
        let mut scratch_writer = BitWriter::from(&mut scratch);
        let probe = HuffmanEncoder::new(&table, &mut scratch_writer);
        probe.weights()
    };

    // Reference output: the raw description of all weights but the last.
    let mut reference = Vec::new();
    {
        let mut writer = BitWriter::from(&mut reference);
        let encodable = &all_weights[..all_weights.len() - 1];
        HuffmanEncoder::<Vec<u8>>::write_raw_weight_description(&mut writer, encodable);
        writer.flush();
    }

    let mut actual = Vec::new();
    {
        let mut writer = BitWriter::from(&mut actual);
        HuffmanEncoder::new(&table, &mut writer).write_table();
        writer.flush();
    }
    assert_eq!(actual, reference);
    assert!(matches!(
        table.cached_encoded_weight_description.get(),
        Some(None)
    ));
}

#[test]
fn encoded_weight_description_is_accepted_by_donor_huf_reader() {
use zstd::zstd_safe::zstd_sys;
Expand Down
Loading