diff --git a/ml-kem/src/compress.rs b/ml-kem/src/compress.rs index dc6cbd9..756b7d6 100644 --- a/ml-kem/src/compress.rs +++ b/ml-kem/src/compress.rs @@ -1,7 +1,8 @@ use crate::algebra::{BaseField, Elem, Int, Polynomial, Vector}; -use array::ArraySize; -use module_lattice::EncodingSize; -use module_lattice::{Field, Truncate}; +use module_lattice::{ + ArraySize, EncodingSize, Field, FixedWidthInt, FixedWidthPolynomial, FixedWidthVector, + Truncate, +}; // A convenience trait to allow us to associate some constants with a typenum pub(crate) trait CompressionFactor: EncodingSize { @@ -22,68 +23,75 @@ where const DIV_MUL: u64 = (1 << T::DIV_SHIFT) / BaseField::QLL; } -// Traits for objects that allow compression / decompression -pub(crate) trait Compress { - fn compress(&mut self) -> &Self; - fn decompress(&mut self) -> &Self; +/// Compress a prime-field representation into its `Z_{2^D}` fixed-width form. +pub(crate) trait Compress { + type Output; + fn compress(self) -> Self::Output; } -impl Compress for Elem { +/// Decompress a `Z_{2^D}` fixed-width representation back into the prime field. +pub(crate) trait Decompress { + type Output; + fn decompress(self) -> Self::Output; +} + +impl Compress for Elem { + type Output = FixedWidthInt; + // Equation 4.5: Compress_d(x) = round((2^d / q) x) // // Here and in decompression, we leverage the following facts: // // round(a / b) = floor((a + b/2) / b) // a / q ~= (a * x) >> s where x >> s ~= 1/q - fn compress(&mut self) -> &Self { + fn compress(self) -> FixedWidthInt { const Q_HALF: u64 = (BaseField::QLL + 1) >> 1; let x = u64::from(self.0); let y = (((x << D::USIZE) + Q_HALF) * D::DIV_MUL) >> D::DIV_SHIFT; - self.0 = u16::truncate(y) & D::MASK; - self + FixedWidthInt::new(u16::truncate(y) & D::MASK) } +} + +impl Decompress for FixedWidthInt { + type Output = Elem; // Equation 4.6: Decompress_d(x) = round((q / 2^d) x) - fn decompress(&mut self) -> &Self { - let x = u32::from(self.0); + fn decompress(self) -> Elem { + let x = u32::from(self.value()); let y = ((x * BaseField::QL) + D::POW2_HALF) >> D::USIZE; - self.0 = Truncate::truncate(y); - self + Elem::new(Truncate::truncate(y)) } } -impl Compress for Polynomial { - fn compress(&mut self) -> &Self { - for x in &mut self.0 { - x.compress::(); - } - self +impl Compress for Polynomial { + type Output = FixedWidthPolynomial; + + fn compress(self) -> FixedWidthPolynomial { + FixedWidthPolynomial::new(self.0.into_iter().map(Compress::::compress).collect()) } +} - fn decompress(&mut self) -> &Self { - for x in &mut self.0 { - x.decompress::(); - } +impl Decompress for FixedWidthPolynomial { + type Output = Polynomial; - self + fn decompress(self) -> Polynomial { + Polynomial::new(self.0.into_iter().map(Decompress::::decompress).collect()) } } -impl Compress for Vector { - fn compress(&mut self) -> &Self { - for x in &mut self.0 { - x.compress::(); - } +impl Compress for Vector { + type Output = FixedWidthVector; - self + fn compress(self) -> FixedWidthVector { + FixedWidthVector::new(self.0.into_iter().map(Compress::::compress).collect()) } +} - fn decompress(&mut self) -> &Self { - for x in &mut self.0 { - x.decompress::(); - } +impl Decompress for FixedWidthVector { + type Output = Vector; - self + fn decompress(self) -> Vector { + Vector::new(self.0.into_iter().map(Decompress::::decompress).collect()) } } @@ -111,11 +119,10 @@ pub(crate) mod tests { let error_threshold = i32::from(Ratio::new(BaseField::Q, 1 << D::USIZE).to_integer()); for x in 0..BaseField::Q { - let mut y = Elem::new(x); - y.compress::(); - y.decompress::(); + let compressed = Compress::::compress(Elem::new(x)); + let decompressed = Decompress::::decompress(compressed); - let mut error = i32::from(y.0) - i32::from(x) + QI32; + let mut error = i32::from(decompressed.0) - i32::from(x) + QI32; if error > (QI32 - 1) / 2 { error -= QI32; } @@ -131,19 +138,17 @@ pub(crate) mod tests { fn decompression_compression_equality() { for x in 0..(1 << D::USIZE) { - let mut y = Elem::new(x); - y.decompress::(); - y.compress::(); + let decompressed = Decompress::::decompress(FixedWidthInt::::new(x)); + let recompressed = Compress::::compress(decompressed); - assert_eq!(y.0, x, "failed for x: {}, D: {}", x, D::USIZE); + assert_eq!(recompressed.value(), x, "failed for x: {}, D: {}", x, D::USIZE); } } fn decompress_KAT() { for y in 0..(1 << D::USIZE) { let x_expected = rational_decompress::(y); - let mut x_actual = Elem::new(y); - x_actual.decompress::(); + let x_actual = Decompress::::decompress(FixedWidthInt::::new(y)); assert_eq!(x_expected, x_actual.0); } @@ -152,10 +157,9 @@ pub(crate) mod tests { fn compress_KAT() { for x in 0..BaseField::Q { let y_expected = rational_compress::(x); - let mut y_actual = Elem::new(x); - y_actual.compress::(); + let y_actual = Compress::::compress(Elem::new(x)); - assert_eq!(y_expected, y_actual.0, "for x: {}, D: {}", x, D::USIZE); + assert_eq!(y_expected, y_actual.value(), "for x: {}, D: {}", x, D::USIZE); } } diff --git a/ml-kem/src/pke.rs b/ml-kem/src/pke.rs index d936078..3e6f0cf 100644 --- a/ml-kem/src/pke.rs +++ b/ml-kem/src/pke.rs @@ -3,13 +3,13 @@ use crate::algebra::{ Ntt, NttInverse, NttMatrix, NttVector, Polynomial, Vector, matrix_sample_ntt, sample_poly_cbd, sample_poly_vec_cbd, }; -use crate::compress::Compress; +use crate::compress::{Compress, Decompress}; use crate::crypto::{G, PRF}; use crate::param::{EncodedDecryptionKey, EncodedEncryptionKey, PkeParams}; use array::typenum::{U1, Unsigned}; use kem::{Ciphertext, InvalidKey}; use module_lattice::{ - Encode, + Encode, FixedWidthPolynomial, FixedWidthVector, ctutils::{Choice, CtEq}, }; @@ -90,16 +90,16 @@ where pub(crate) fn decrypt(&self, ciphertext: &Ciphertext

) -> B32 { let (c1, c2) = P::split_ct(ciphertext); - let mut u: Vector = Encode::::decode(c1); - u.decompress::(); + let u_compressed: FixedWidthVector<_, P::K, P::Du> = Encode::::decode(c1); + let u: Vector = u_compressed.decompress(); - let mut v: Polynomial = Encode::::decode(c2); - v.decompress::(); + let v_compressed: FixedWidthPolynomial<_, P::Dv> = Encode::::decode(c2); + let v: Polynomial = v_compressed.decompress(); let u_hat = u.ntt(); let sTu = (&self.s_hat * &u_hat).ntt_inverse(); - let mut w = &v - &sTu; - Encode::::encode(w.compress::()) + let w = &v - &sTu; + Encode::::encode(&Compress::::compress(w)) } /// Represent this decryption key as a byte array `(s_hat)` @@ -141,16 +141,16 @@ where let A_hat_t: NttMatrix = matrix_sample_ntt(&self.rho, true); let r_hat: NttVector = r.ntt(); let ATr: Vector = (&A_hat_t * &r_hat).ntt_inverse(); - let mut u = ATr + e1; + let u = ATr + e1; - let mut mu: Polynomial = Encode::::decode(message); - mu.decompress::(); + let mu_compressed: FixedWidthPolynomial<_, U1> = Encode::::decode(message); + let mu: Polynomial = mu_compressed.decompress(); let tTr: Polynomial = (&self.t_hat * &r_hat).ntt_inverse(); - let mut v = &(&tTr + &e2) + μ + let v = &(&tTr + &e2) + μ - let c1 = Encode::::encode(u.compress::()); - let c2 = Encode::::encode(v.compress::()); + let c1 = Encode::::encode(&Compress::::compress(u)); + let c2 = Encode::::encode(&Compress::::compress(v)); P::concat_ct(c1, c2) } diff --git a/module-lattice/src/algebra.rs b/module-lattice/src/algebra.rs index 113eb72..dd7e057 100644 --- a/module-lattice/src/algebra.rs +++ b/module-lattice/src/algebra.rs @@ -37,6 +37,15 @@ pub trait Field: Copy + Default + Debug + PartialEq { fn barrett_reduce(x: Self::Long) -> Self::Int; } +/// Marker trait for a [`Field`] whose modulus is prime. +/// +/// Multiplication on [`Elem`] is gated on `F: PrimeField` because the +/// reduction-based arithmetic in this crate (Barrett reduction, NTT) is +/// only valid for prime-order fields. A non-prime-order representation +/// such as Z_{2^d} can still impl [`Field`] for storage purposes (see +/// [`FixedWidthInt`]) without claiming the multiplicative group structure. +pub trait PrimeField: Field {} + /// The `define_field` macro creates a zero-sized struct and an implementation of the [`Field`] /// trait for that struct. The caller must specify: /// @@ -89,6 +98,8 @@ macro_rules! define_field { Self::small_reduce($crate::Truncate::truncate(remainder)) } } + + impl $crate::PrimeField for $field {} }; } @@ -157,7 +168,7 @@ impl Sub> for Elem { } } -impl Mul> for Elem { +impl Mul> for Elem { type Output = Elem; fn mul(self, rhs: Elem) -> Elem { @@ -220,7 +231,7 @@ impl Sub<&Polynomial> for &Polynomial { } } -impl Mul<&Polynomial> for Elem { +impl Mul<&Polynomial> for Elem { type Output = Polynomial; fn mul(self, rhs: &Polynomial) -> Polynomial { @@ -306,7 +317,7 @@ impl Sub<&Vector> for &Vector { } } -impl Mul<&Vector> for Elem { +impl Mul<&Vector> for Elem { type Output = Vector; fn mul(self, rhs: &Vector) -> Vector { @@ -382,7 +393,7 @@ impl Sub<&NttPolynomial> for &NttPolynomial { } } -impl Mul<&NttPolynomial> for Elem { +impl Mul<&NttPolynomial> for Elem { type Output = NttPolynomial; fn mul(self, rhs: &NttPolynomial) -> NttPolynomial { diff --git a/module-lattice/src/fixed_width.rs b/module-lattice/src/fixed_width.rs new file mode 100644 index 0000000..971139f --- /dev/null +++ b/module-lattice/src/fixed_width.rs @@ -0,0 +1,135 @@ +use crate::algebra::{Elem, Field, Polynomial, Vector}; +use crate::encoding::{ArraySize, Encode, EncodingSize, VectorEncodingSize}; +use array::{Array, typenum::U256}; +use core::marker::PhantomData; + +/// A value of width `D` bits, stored in `F::Int` for compatibility with the +/// rest of the lattice algebra plumbing. +/// +/// Despite carrying an `F: Field` parameter, a [`FixedWidthInt`] is *not* a +/// member of `F`; it is an element of `Z_{2^D}`. The type exists so that +/// compressed values (i.e., the codomain of `Compress_d` in FIPS 203) can be +/// distinguished from field elements at the type level. +/// +/// Multiplication is intentionally not provided: `Z_{2^D}` is not a prime +/// field and the Barrett-reduced [`Mul`] on [`Elem`] would be wrong here. +#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] +pub struct FixedWidthInt { + val: F::Int, + _phantom: PhantomData, +} + +impl FixedWidthInt { + /// Create a new fixed-width value. The caller is responsible for + /// ensuring `val < 2^D`; the type does not enforce this. + pub const fn new(val: F::Int) -> Self { + Self { + val, + _phantom: PhantomData, + } + } + + /// Access the underlying integer. + pub fn value(&self) -> F::Int { + self.val + } +} + +impl From> for FixedWidthInt { + fn from(elem: Elem) -> Self { + Self::new(elem.0) + } +} + +impl From> for Elem { + fn from(fwi: FixedWidthInt) -> Self { + Elem(fwi.val) + } +} + +/// A polynomial whose coefficients are [`FixedWidthInt`] values. +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] +pub struct FixedWidthPolynomial( + pub Array, U256>, +); + +impl FixedWidthPolynomial { + /// Create a new polynomial. + pub const fn new(coeffs: Array, U256>) -> Self { + Self(coeffs) + } +} + +impl From> for FixedWidthPolynomial { + fn from(poly: Polynomial) -> Self { + Self(poly.0.iter().map(|&elem| elem.into()).collect()) + } +} + +impl From> for Polynomial { + fn from(poly: FixedWidthPolynomial) -> Self { + Polynomial::new(poly.0.iter().map(|&fwi| fwi.into()).collect()) + } +} + +impl Encode for FixedWidthPolynomial +where + Polynomial: Encode, +{ + type EncodedSize = as Encode>::EncodedSize; + + fn encode(&self) -> Array { + Encode::::encode(&Polynomial::::from(*self)) + } + + fn decode(enc: &Array) -> Self { + as Encode>::decode(enc).into() + } +} + +/// A vector of [`FixedWidthPolynomial`]. +#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub struct FixedWidthVector( + pub Array, K>, +); + +impl FixedWidthVector { + /// Create a new vector. + pub const fn new(polys: Array, K>) -> Self { + Self(polys) + } +} + +impl From> + for FixedWidthVector +{ + fn from(vec: Vector) -> Self { + Self(vec.0.into_iter().map(FixedWidthPolynomial::from).collect()) + } +} + +impl From> + for Vector +{ + fn from(vec: FixedWidthVector) -> Self { + Vector::new(vec.0.into_iter().map(Polynomial::from).collect()) + } +} + +impl Encode for FixedWidthVector +where + F: Field, + K: ArraySize, + D: VectorEncodingSize, + Vector: Encode, +{ + type EncodedSize = as Encode>::EncodedSize; + + fn encode(&self) -> Array { + Encode::::encode(&Vector::::from(self.clone())) + } + + fn decode(enc: &Array) -> Self { + as Encode>::decode(enc).into() + } +} diff --git a/module-lattice/src/lib.rs b/module-lattice/src/lib.rs index 9bdbc35..b4a43db 100644 --- a/module-lattice/src/lib.rs +++ b/module-lattice/src/lib.rs @@ -13,16 +13,21 @@ mod algebra; /// Packing of polynomials into coefficients with a specified number of bits. mod encoding; +/// Fixed-width integer values in `Z_{2^d}`, used to represent the codomain of +/// the FIPS 203 `Compress_d` operation distinctly from prime-field elements. +mod fixed_width; + /// Integer truncation support. mod truncate; pub use algebra::{ - Elem, Field, MultiplyNtt, NttMatrix, NttPolynomial, NttVector, Polynomial, Vector, + Elem, Field, MultiplyNtt, NttMatrix, NttPolynomial, NttVector, Polynomial, PrimeField, Vector, }; pub use encoding::{ ArraySize, DecodedValue, Encode, EncodedPolynomial, EncodedPolynomialSize, EncodedVector, EncodedVectorSize, EncodingSize, VectorEncodingSize, byte_decode, byte_encode, }; +pub use fixed_width::{FixedWidthInt, FixedWidthPolynomial, FixedWidthVector}; pub use truncate::Truncate; #[cfg(feature = "ctutils")]