diff --git a/zstd/src/huff0/huff0_encoder.rs b/zstd/src/huff0/huff0_encoder.rs index 1070ccd4..69610de2 100644 --- a/zstd/src/huff0/huff0_encoder.rs +++ b/zstd/src/huff0/huff0_encoder.rs @@ -1,6 +1,20 @@ use alloc::vec::Vec; use core::cmp::Ordering; +/// Cache primitive for `HuffmanTable::cached_encoded_weight_description`, +/// std-only. `std::sync::OnceLock` is `Sync` (atomic-init), so wrapping +/// it inside `pub struct HuffmanTable` keeps the type's auto-traits +/// intact for downstream consumers that share encoder tables across +/// threads. **The cache is entirely absent in no_std builds**: the +/// `cached_encoded_weight_description` field is `#[cfg(feature = "std")]`, +/// so `HuffmanTable` retains `Sync` unconditionally regardless of which +/// feature set the consumer builds with. no_std embedded targets that +/// might run `HuffmanTable` across threads (e.g. via `Arc`) lose the +/// per-table FSE-encode cache as a trade-off — they get the +/// recompute-every-time path that existed before the cache landed. +#[cfg(feature = "std")] +type CachedDescription = std::sync::OnceLock>>; + use crate::{ bit_io::BitWriter, fse::fse_encoder::{self, FSEEncoder}, @@ -109,13 +123,53 @@ impl>> HuffmanEncoder<'_, '_, V> { } fn write_table(&mut self) { - let weights = self.weights(); - let weights = &weights[..weights.len() - 1]; // don't encode last weight - if let Some(fse_description) = Self::encode_weight_description(weights) { - self.writer.write_bits(fse_description.len() as u8, 8); - self.writer.append_bytes(&fse_description); - } else { - Self::write_raw_weight_description(self.writer, weights); + #[cfg(feature = "std")] + { + // Cached path: cache hit → emit FSE bytes directly OR the + // cached `None` sentinel → emit raw (one `weights()` recompute, + // unavoidable since the cache stores only the FSE encoding, + // not the raw nibbles). + if let Some(cached) = self.table.cached_encoded_weight_description.get() { + if let Some(fse_description) = cached.as_deref() { + self.writer.write_bits(fse_description.len() as u8, 8); + self.writer.append_bytes(fse_description); + return; + } + let weights = self.weights(); + let weights = &weights[..weights.len() - 1]; + Self::write_raw_weight_description(self.writer, weights); + return; + } + // Cold path: compute `weights` once and share it between the + // cache initializer (which uses it to FSE-encode) and the raw + // fallback (which uses it directly to write nibbles). Without + // this, the raw fallback would call back into `weights()` and + // recompute the slice — a measurable hotspot for small / + // low-cardinality tables (#170 review thread). + let weights = self.weights(); + let weights = &weights[..weights.len() - 1]; + if let Some(fse_description) = self + .table + .cached_encoded_weight_description_with_weights(weights) + { + self.writer.write_bits(fse_description.len() as u8, 8); + self.writer.append_bytes(fse_description); + } else { + Self::write_raw_weight_description(self.writer, weights); + } + } + #[cfg(not(feature = "std"))] + { + // no_std: no cache field, no shared state — single `weights()` + // compute, branch on FSE-vs-raw based on direct encoder call. + let weights = self.weights(); + let weights = &weights[..weights.len() - 1]; + if let Some(fse_description) = Self::encode_weight_description(weights) { + self.writer.write_bits(fse_description.len() as u8, 8); + self.writer.append_bytes(&fse_description); + } else { + Self::write_raw_weight_description(self.writer, weights); + } } } @@ -210,6 +264,15 @@ impl>> HuffmanEncoder<'_, '_, V> { pub struct HuffmanTable { /// Index is the symbol, values are the bitstring in the lower bits of the u32 and the amount of bits in the u8 codes: Vec<(u32, u8)>, + /// Lazy cache of the FSE-encoded weight description. Avoids re-running + /// `encode_weight_description` across `try_table_description_size` and + /// `write_table` for the same table instance. **std-only** — + /// `core::cell::OnceCell` is `!Sync` and would break the `Sync` + /// auto-trait for `pub HuffmanTable` in no_std builds; no_std users + /// keep the original recompute-every-time semantics. See the + /// `CachedDescription` type-alias doc above for full rationale. + #[cfg(feature = "std")] + cached_encoded_weight_description: CachedDescription, } impl HuffmanTable { @@ -312,17 +375,37 @@ impl HuffmanTable { } /// Returns exact writable table-description size when representable. + /// std build path: consults the lazy cache to avoid re-encoding the + /// weight stream when both planner and emitter call this for the + /// same table. no_std build path: recomputes via the direct encoder + /// every call (cache field absent — preserves `Sync`). pub(crate) fn try_table_description_size(&self) -> Option { - let weights = self.weights(); - let weights = &weights[..weights.len() - 1]; - if let Some(fse_description) = HuffmanEncoder::>::encode_weight_description(weights) + #[cfg(feature = "std")] { - return Some(fse_description.len() + 1); + if let Some(fse_description) = self.cached_encoded_weight_description() { + return Some(fse_description.len() + 1); + } + let raw_weights_len = self.codes.len().saturating_sub(1); + if raw_weights_len <= 128 { + Some(raw_weights_len.div_ceil(2) + 1) + } else { + None + } } - if weights.len() <= 128 { - Some(weights.len().div_ceil(2) + 1) - } else { - None + #[cfg(not(feature = "std"))] + { + let weights = self.weights(); + let weights = &weights[..weights.len() - 1]; + if let Some(fse_description) = + HuffmanEncoder::>::encode_weight_description(weights) + { + return Some(fse_description.len() + 1); + } + if weights.len() <= 128 { + Some(weights.len().div_ceil(2) + 1) + } else { + None + } } } @@ -340,6 +423,23 @@ impl HuffmanTable { .collect::>() } + #[cfg(feature = "std")] + fn cached_encoded_weight_description(&self) -> Option<&[u8]> { + if let Some(cached) = self.cached_encoded_weight_description.get() { + return cached.as_deref(); + } + let weights = self.weights(); + let weights = &weights[..weights.len() - 1]; + self.cached_encoded_weight_description_with_weights(weights) + } + + #[cfg(feature = "std")] + fn cached_encoded_weight_description_with_weights(&self, weights: &[u8]) -> Option<&[u8]> { + self.cached_encoded_weight_description + .get_or_init(|| HuffmanEncoder::>::encode_weight_description(weights)) + .as_deref() + } + /// Estimates encoded payload size in bytes directly from per-symbol counts. pub(crate) fn estimate_compressed_size_from_counts(&self, counts: &[usize]) -> usize { let bits = self @@ -364,6 +464,8 @@ impl HuffmanTable { let table_log = highest_bit_set(weight_sum) - 1; let mut table = HuffmanTable { codes: alloc::vec![(0, 0); weights.len()], + #[cfg(feature = "std")] + cached_encoded_weight_description: CachedDescription::new(), }; let mut nb_per_rank = [0u16; 13]; for &weight in weights { @@ -1236,6 +1338,72 @@ fn large_alphabet_weight_description_uses_fse_when_raw_is_unrepresentable() { )); } +#[cfg(feature = "std")] +#[test] +fn cached_encoded_weight_description_is_reused_for_write_table() { + let mut data = Vec::new(); + for symbol in 0u8..=255 { + data.extend(core::iter::repeat_n(symbol, usize::from(symbol) + 1)); + } + let table = HuffmanTable::build_from_data(&data); + let desc_size = table + .writeable_table_description_size() + .expect("table description must be writable"); + let cached = table + .cached_encoded_weight_description + .get() + .and_then(Option::as_ref) + .expect("large alphabet fixture must cache FSE description") + .clone(); + assert_eq!(desc_size, cached.len() + 1); + + let mut encoded = Vec::new(); + { + let mut writer = BitWriter::from(&mut encoded); + let mut encoder = HuffmanEncoder::new(&table, &mut writer); + encoder.write_table(); + writer.flush(); + } + assert_eq!(encoded[0] as usize, cached.len()); + assert_eq!(&encoded[1..], cached.as_slice()); +} + +#[cfg(feature = "std")] +#[test] +fn write_table_raw_path_initializes_none_cache() { + let table = HuffmanTable::build_from_weights(&[1, 1]); + assert!(table.cached_encoded_weight_description.get().is_none()); + + let mut expected = Vec::new(); + let weights = { + let mut out = Vec::new(); + let mut writer = BitWriter::from(&mut out); + let encoder = HuffmanEncoder::new(&table, &mut writer); + encoder.weights() + }; + { + let mut writer = BitWriter::from(&mut expected); + HuffmanEncoder::>::write_raw_weight_description( + &mut writer, + &weights[..weights.len() - 1], + ); + writer.flush(); + } + + let mut encoded = Vec::new(); + { + let mut writer = BitWriter::from(&mut encoded); + let mut encoder = HuffmanEncoder::new(&table, &mut writer); + encoder.write_table(); + writer.flush(); + } + assert_eq!(encoded, expected); + assert!(matches!( + table.cached_encoded_weight_description.get(), + Some(None) + )); +} + #[test] fn encoded_weight_description_is_accepted_by_donor_huf_reader() { use zstd::zstd_safe::zstd_sys;