Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 54 additions & 50 deletions ml-kem/src/compress.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::algebra::{BaseField, Elem, Int, Polynomial, Vector};
use array::ArraySize;
use module_lattice::EncodingSize;
use module_lattice::{Field, Truncate};
use module_lattice::{
ArraySize, EncodingSize, Field, FixedWidthInt, FixedWidthPolynomial, FixedWidthVector,
Truncate,
};

// A convenience trait to allow us to associate some constants with a typenum
pub(crate) trait CompressionFactor: EncodingSize {
Expand All @@ -22,68 +23,75 @@ where
const DIV_MUL: u64 = (1 << T::DIV_SHIFT) / BaseField::QLL;
}

// Traits for objects that allow compression / decompression
pub(crate) trait Compress {
fn compress<D: CompressionFactor>(&mut self) -> &Self;
fn decompress<D: CompressionFactor>(&mut self) -> &Self;
/// Compress a prime-field representation into its `Z_{2^D}` fixed-width form.
pub(crate) trait Compress<D: CompressionFactor> {
type Output;
fn compress(self) -> Self::Output;
}

impl Compress for Elem {
/// Decompress a `Z_{2^D}` fixed-width representation back into the prime field.
pub(crate) trait Decompress<D: CompressionFactor> {
type Output;
fn decompress(self) -> Self::Output;
}

impl<D: CompressionFactor> Compress<D> for Elem {
type Output = FixedWidthInt<BaseField, D>;

// Equation 4.5: Compress_d(x) = round((2^d / q) x)
//
// Here and in decompression, we leverage the following facts:
//
// round(a / b) = floor((a + b/2) / b)
// a / q ~= (a * x) >> s where x >> s ~= 1/q
fn compress<D: CompressionFactor>(&mut self) -> &Self {
fn compress(self) -> FixedWidthInt<BaseField, D> {
const Q_HALF: u64 = (BaseField::QLL + 1) >> 1;
let x = u64::from(self.0);
let y = (((x << D::USIZE) + Q_HALF) * D::DIV_MUL) >> D::DIV_SHIFT;
self.0 = u16::truncate(y) & D::MASK;
self
FixedWidthInt::new(u16::truncate(y) & D::MASK)
}
}

impl<D: CompressionFactor> Decompress<D> for FixedWidthInt<BaseField, D> {
type Output = Elem;

// Equation 4.6: Decompress_d(x) = round((q / 2^d) x)
fn decompress<D: CompressionFactor>(&mut self) -> &Self {
let x = u32::from(self.0);
fn decompress(self) -> Elem {
let x = u32::from(self.value());
let y = ((x * BaseField::QL) + D::POW2_HALF) >> D::USIZE;
self.0 = Truncate::truncate(y);
self
Elem::new(Truncate::truncate(y))
}
}
impl Compress for Polynomial {
fn compress<D: CompressionFactor>(&mut self) -> &Self {
for x in &mut self.0 {
x.compress::<D>();
}

self
impl<D: CompressionFactor> Compress<D> for Polynomial {
type Output = FixedWidthPolynomial<BaseField, D>;

fn compress(self) -> FixedWidthPolynomial<BaseField, D> {
FixedWidthPolynomial::new(self.0.into_iter().map(Compress::<D>::compress).collect())
}
}

fn decompress<D: CompressionFactor>(&mut self) -> &Self {
for x in &mut self.0 {
x.decompress::<D>();
}
impl<D: CompressionFactor> Decompress<D> for FixedWidthPolynomial<BaseField, D> {
type Output = Polynomial;

self
fn decompress(self) -> Polynomial {
Polynomial::new(self.0.into_iter().map(Decompress::<D>::decompress).collect())
}
}

impl<K: ArraySize> Compress for Vector<K> {
fn compress<D: CompressionFactor>(&mut self) -> &Self {
for x in &mut self.0 {
x.compress::<D>();
}
impl<K: ArraySize, D: CompressionFactor> Compress<D> for Vector<K> {
type Output = FixedWidthVector<BaseField, K, D>;

self
fn compress(self) -> FixedWidthVector<BaseField, K, D> {
FixedWidthVector::new(self.0.into_iter().map(Compress::<D>::compress).collect())
}
}

fn decompress<D: CompressionFactor>(&mut self) -> &Self {
for x in &mut self.0 {
x.decompress::<D>();
}
impl<K: ArraySize, D: CompressionFactor> Decompress<D> for FixedWidthVector<BaseField, K, D> {
type Output = Vector<K>;

self
fn decompress(self) -> Vector<K> {
Vector::new(self.0.into_iter().map(Decompress::<D>::decompress).collect())
}
}

Expand Down Expand Up @@ -111,11 +119,10 @@ pub(crate) mod tests {
let error_threshold = i32::from(Ratio::new(BaseField::Q, 1 << D::USIZE).to_integer());

for x in 0..BaseField::Q {
let mut y = Elem::new(x);
y.compress::<D>();
y.decompress::<D>();
let compressed = Compress::<D>::compress(Elem::new(x));
let decompressed = Decompress::<D>::decompress(compressed);

let mut error = i32::from(y.0) - i32::from(x) + QI32;
let mut error = i32::from(decompressed.0) - i32::from(x) + QI32;
if error > (QI32 - 1) / 2 {
error -= QI32;
}
Expand All @@ -131,19 +138,17 @@ pub(crate) mod tests {

fn decompression_compression_equality<D: CompressionFactor>() {
for x in 0..(1 << D::USIZE) {
let mut y = Elem::new(x);
y.decompress::<D>();
y.compress::<D>();
let decompressed = Decompress::<D>::decompress(FixedWidthInt::<BaseField, D>::new(x));
let recompressed = Compress::<D>::compress(decompressed);

assert_eq!(y.0, x, "failed for x: {}, D: {}", x, D::USIZE);
assert_eq!(recompressed.value(), x, "failed for x: {}, D: {}", x, D::USIZE);
}
}

fn decompress_KAT<D: CompressionFactor>() {
for y in 0..(1 << D::USIZE) {
let x_expected = rational_decompress::<D>(y);
let mut x_actual = Elem::new(y);
x_actual.decompress::<D>();
let x_actual = Decompress::<D>::decompress(FixedWidthInt::<BaseField, D>::new(y));

assert_eq!(x_expected, x_actual.0);
}
Expand All @@ -152,10 +157,9 @@ pub(crate) mod tests {
fn compress_KAT<D: CompressionFactor>() {
for x in 0..BaseField::Q {
let y_expected = rational_compress::<D>(x);
let mut y_actual = Elem::new(x);
y_actual.compress::<D>();
let y_actual = Compress::<D>::compress(Elem::new(x));

assert_eq!(y_expected, y_actual.0, "for x: {}, D: {}", x, D::USIZE);
assert_eq!(y_expected, y_actual.value(), "for x: {}, D: {}", x, D::USIZE);
}
}

Expand Down
28 changes: 14 additions & 14 deletions ml-kem/src/pke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ use crate::algebra::{
Ntt, NttInverse, NttMatrix, NttVector, Polynomial, Vector, matrix_sample_ntt, sample_poly_cbd,
sample_poly_vec_cbd,
};
use crate::compress::Compress;
use crate::compress::{Compress, Decompress};
use crate::crypto::{G, PRF};
use crate::param::{EncodedDecryptionKey, EncodedEncryptionKey, PkeParams};
use array::typenum::{U1, Unsigned};
use kem::{Ciphertext, InvalidKey};
use module_lattice::{
Encode,
Encode, FixedWidthPolynomial, FixedWidthVector,
ctutils::{Choice, CtEq},
};

Expand Down Expand Up @@ -90,16 +90,16 @@ where
pub(crate) fn decrypt(&self, ciphertext: &Ciphertext<P>) -> B32 {
let (c1, c2) = P::split_ct(ciphertext);

let mut u: Vector<P::K> = Encode::<P::Du>::decode(c1);
u.decompress::<P::Du>();
let u_compressed: FixedWidthVector<_, P::K, P::Du> = Encode::<P::Du>::decode(c1);
let u: Vector<P::K> = u_compressed.decompress();

let mut v: Polynomial = Encode::<P::Dv>::decode(c2);
v.decompress::<P::Dv>();
let v_compressed: FixedWidthPolynomial<_, P::Dv> = Encode::<P::Dv>::decode(c2);
let v: Polynomial = v_compressed.decompress();

let u_hat = u.ntt();
let sTu = (&self.s_hat * &u_hat).ntt_inverse();
let mut w = &v - &sTu;
Encode::<U1>::encode(w.compress::<U1>())
let w = &v - &sTu;
Encode::<U1>::encode(&Compress::<U1>::compress(w))
}

/// Represent this decryption key as a byte array `(s_hat)`
Expand Down Expand Up @@ -141,16 +141,16 @@ where
let A_hat_t: NttMatrix<P::K> = matrix_sample_ntt(&self.rho, true);
let r_hat: NttVector<P::K> = r.ntt();
let ATr: Vector<P::K> = (&A_hat_t * &r_hat).ntt_inverse();
let mut u = ATr + e1;
let u = ATr + e1;

let mut mu: Polynomial = Encode::<U1>::decode(message);
mu.decompress::<U1>();
let mu_compressed: FixedWidthPolynomial<_, U1> = Encode::<U1>::decode(message);
let mu: Polynomial = mu_compressed.decompress();

let tTr: Polynomial = (&self.t_hat * &r_hat).ntt_inverse();
let mut v = &(&tTr + &e2) + &mu;
let v = &(&tTr + &e2) + &mu;

let c1 = Encode::<P::Du>::encode(u.compress::<P::Du>());
let c2 = Encode::<P::Dv>::encode(v.compress::<P::Dv>());
let c1 = Encode::<P::Du>::encode(&Compress::<P::Du>::compress(u));
let c2 = Encode::<P::Dv>::encode(&Compress::<P::Dv>::compress(v));
P::concat_ct(c1, c2)
}

Expand Down
19 changes: 15 additions & 4 deletions module-lattice/src/algebra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ pub trait Field: Copy + Default + Debug + PartialEq {
fn barrett_reduce(x: Self::Long) -> Self::Int;
}

/// Marker trait for a [`Field`] whose modulus is prime.
///
/// Multiplication on [`Elem<F>`] is gated on `F: PrimeField` because the
/// reduction-based arithmetic in this crate (Barrett reduction, NTT) is
/// only valid for prime-order fields. A non-prime-order representation
/// such as Z_{2^d} can still impl [`Field`] for storage purposes (see
/// [`FixedWidthInt`]) without claiming the multiplicative group structure.
pub trait PrimeField: Field {}

/// The `define_field` macro creates a zero-sized struct and an implementation of the [`Field`]
/// trait for that struct. The caller must specify:
///
Expand Down Expand Up @@ -89,6 +98,8 @@ macro_rules! define_field {
Self::small_reduce($crate::Truncate::truncate(remainder))
}
}

impl $crate::PrimeField for $field {}
};
}

Expand Down Expand Up @@ -157,7 +168,7 @@ impl<F: Field> Sub<Elem<F>> for Elem<F> {
}
}

impl<F: Field> Mul<Elem<F>> for Elem<F> {
impl<F: PrimeField> Mul<Elem<F>> for Elem<F> {
type Output = Elem<F>;

fn mul(self, rhs: Elem<F>) -> Elem<F> {
Expand Down Expand Up @@ -220,7 +231,7 @@ impl<F: Field> Sub<&Polynomial<F>> for &Polynomial<F> {
}
}

impl<F: Field> Mul<&Polynomial<F>> for Elem<F> {
impl<F: PrimeField> Mul<&Polynomial<F>> for Elem<F> {
type Output = Polynomial<F>;

fn mul(self, rhs: &Polynomial<F>) -> Polynomial<F> {
Expand Down Expand Up @@ -306,7 +317,7 @@ impl<F: Field, K: ArraySize> Sub<&Vector<F, K>> for &Vector<F, K> {
}
}

impl<F: Field, K: ArraySize> Mul<&Vector<F, K>> for Elem<F> {
impl<F: PrimeField, K: ArraySize> Mul<&Vector<F, K>> for Elem<F> {
type Output = Vector<F, K>;

fn mul(self, rhs: &Vector<F, K>) -> Vector<F, K> {
Expand Down Expand Up @@ -382,7 +393,7 @@ impl<F: Field> Sub<&NttPolynomial<F>> for &NttPolynomial<F> {
}
}

impl<F: Field> Mul<&NttPolynomial<F>> for Elem<F> {
impl<F: PrimeField> Mul<&NttPolynomial<F>> for Elem<F> {
type Output = NttPolynomial<F>;

fn mul(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
Expand Down
Loading
Loading