Skip to content

Commit 99e1884

Browse files
committed
module-lattice, ml-kem: introduce FixedWidthInt for compressed values
Per #26, the codomain of FIPS 203's Compress_d is the integer ring Z_{2^d}, not the prime field Z_q. Reusing Elem<F> for both lets the Barrett-reduced Mul on Elem be applied to compressed values where it is meaningless. Adds FixedWidthInt<F, D>, FixedWidthPolynomial<F, D>, and FixedWidthVector<F, K, D> in module-lattice, plus a PrimeField: Field marker trait that gates Mul on Elem/Polynomial/Vector/NttPolynomial. ml-kem's mutating Compress trait is replaced with consuming Compress<D>/Decompress<D> traits that move between the prime-field types and the new fixed-width types, threading the typed boundary through K-PKE.Encrypt and K-PKE.Decrypt. Closes #26.
1 parent 584f581 commit 99e1884

5 files changed

Lines changed: 224 additions & 69 deletions

File tree

ml-kem/src/compress.rs

Lines changed: 54 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
use crate::algebra::{BaseField, Elem, Int, Polynomial, Vector};
2-
use array::ArraySize;
3-
use module_lattice::EncodingSize;
4-
use module_lattice::{Field, Truncate};
2+
use module_lattice::{
3+
ArraySize, EncodingSize, Field, FixedWidthInt, FixedWidthPolynomial, FixedWidthVector,
4+
Truncate,
5+
};
56

67
// A convenience trait to allow us to associate some constants with a typenum
78
pub(crate) trait CompressionFactor: EncodingSize {
@@ -22,68 +23,75 @@ where
2223
const DIV_MUL: u64 = (1 << T::DIV_SHIFT) / BaseField::QLL;
2324
}
2425

25-
// Traits for objects that allow compression / decompression
26-
pub(crate) trait Compress {
27-
fn compress<D: CompressionFactor>(&mut self) -> &Self;
28-
fn decompress<D: CompressionFactor>(&mut self) -> &Self;
26+
/// Compress a prime-field representation into its `Z_{2^D}` fixed-width form.
27+
pub(crate) trait Compress<D: CompressionFactor> {
28+
type Output;
29+
fn compress(self) -> Self::Output;
2930
}
3031

31-
impl Compress for Elem {
32+
/// Decompress a `Z_{2^D}` fixed-width representation back into the prime field.
33+
pub(crate) trait Decompress<D: CompressionFactor> {
34+
type Output;
35+
fn decompress(self) -> Self::Output;
36+
}
37+
38+
impl<D: CompressionFactor> Compress<D> for Elem {
39+
type Output = FixedWidthInt<BaseField, D>;
40+
3241
// Equation 4.5: Compress_d(x) = round((2^d / q) x)
3342
//
3443
// Here and in decompression, we leverage the following facts:
3544
//
3645
// round(a / b) = floor((a + b/2) / b)
3746
// a / q ~= (a * x) >> s where x >> s ~= 1/q
38-
fn compress<D: CompressionFactor>(&mut self) -> &Self {
47+
fn compress(self) -> FixedWidthInt<BaseField, D> {
3948
const Q_HALF: u64 = (BaseField::QLL + 1) >> 1;
4049
let x = u64::from(self.0);
4150
let y = (((x << D::USIZE) + Q_HALF) * D::DIV_MUL) >> D::DIV_SHIFT;
42-
self.0 = u16::truncate(y) & D::MASK;
43-
self
51+
FixedWidthInt::new(u16::truncate(y) & D::MASK)
4452
}
53+
}
54+
55+
impl<D: CompressionFactor> Decompress<D> for FixedWidthInt<BaseField, D> {
56+
type Output = Elem;
4557

4658
// Equation 4.6: Decompress_d(x) = round((q / 2^d) x)
47-
fn decompress<D: CompressionFactor>(&mut self) -> &Self {
48-
let x = u32::from(self.0);
59+
fn decompress(self) -> Elem {
60+
let x = u32::from(self.value());
4961
let y = ((x * BaseField::QL) + D::POW2_HALF) >> D::USIZE;
50-
self.0 = Truncate::truncate(y);
51-
self
62+
Elem::new(Truncate::truncate(y))
5263
}
5364
}
54-
impl Compress for Polynomial {
55-
fn compress<D: CompressionFactor>(&mut self) -> &Self {
56-
for x in &mut self.0 {
57-
x.compress::<D>();
58-
}
5965

60-
self
66+
impl<D: CompressionFactor> Compress<D> for Polynomial {
67+
type Output = FixedWidthPolynomial<BaseField, D>;
68+
69+
fn compress(self) -> FixedWidthPolynomial<BaseField, D> {
70+
FixedWidthPolynomial::new(self.0.into_iter().map(Compress::<D>::compress).collect())
6171
}
72+
}
6273

63-
fn decompress<D: CompressionFactor>(&mut self) -> &Self {
64-
for x in &mut self.0 {
65-
x.decompress::<D>();
66-
}
74+
impl<D: CompressionFactor> Decompress<D> for FixedWidthPolynomial<BaseField, D> {
75+
type Output = Polynomial;
6776

68-
self
77+
fn decompress(self) -> Polynomial {
78+
Polynomial::new(self.0.into_iter().map(Decompress::<D>::decompress).collect())
6979
}
7080
}
7181

72-
impl<K: ArraySize> Compress for Vector<K> {
73-
fn compress<D: CompressionFactor>(&mut self) -> &Self {
74-
for x in &mut self.0 {
75-
x.compress::<D>();
76-
}
82+
impl<K: ArraySize, D: CompressionFactor> Compress<D> for Vector<K> {
83+
type Output = FixedWidthVector<BaseField, K, D>;
7784

78-
self
85+
fn compress(self) -> FixedWidthVector<BaseField, K, D> {
86+
FixedWidthVector::new(self.0.into_iter().map(Compress::<D>::compress).collect())
7987
}
88+
}
8089

81-
fn decompress<D: CompressionFactor>(&mut self) -> &Self {
82-
for x in &mut self.0 {
83-
x.decompress::<D>();
84-
}
90+
impl<K: ArraySize, D: CompressionFactor> Decompress<D> for FixedWidthVector<BaseField, K, D> {
91+
type Output = Vector<K>;
8592

86-
self
93+
fn decompress(self) -> Vector<K> {
94+
Vector::new(self.0.into_iter().map(Decompress::<D>::decompress).collect())
8795
}
8896
}
8997

@@ -111,11 +119,10 @@ pub(crate) mod tests {
111119
let error_threshold = i32::from(Ratio::new(BaseField::Q, 1 << D::USIZE).to_integer());
112120

113121
for x in 0..BaseField::Q {
114-
let mut y = Elem::new(x);
115-
y.compress::<D>();
116-
y.decompress::<D>();
122+
let compressed = Compress::<D>::compress(Elem::new(x));
123+
let decompressed = Decompress::<D>::decompress(compressed);
117124

118-
let mut error = i32::from(y.0) - i32::from(x) + QI32;
125+
let mut error = i32::from(decompressed.0) - i32::from(x) + QI32;
119126
if error > (QI32 - 1) / 2 {
120127
error -= QI32;
121128
}
@@ -131,19 +138,17 @@ pub(crate) mod tests {
131138

132139
fn decompression_compression_equality<D: CompressionFactor>() {
133140
for x in 0..(1 << D::USIZE) {
134-
let mut y = Elem::new(x);
135-
y.decompress::<D>();
136-
y.compress::<D>();
141+
let decompressed = Decompress::<D>::decompress(FixedWidthInt::<BaseField, D>::new(x));
142+
let recompressed = Compress::<D>::compress(decompressed);
137143

138-
assert_eq!(y.0, x, "failed for x: {}, D: {}", x, D::USIZE);
144+
assert_eq!(recompressed.value(), x, "failed for x: {}, D: {}", x, D::USIZE);
139145
}
140146
}
141147

142148
fn decompress_KAT<D: CompressionFactor>() {
143149
for y in 0..(1 << D::USIZE) {
144150
let x_expected = rational_decompress::<D>(y);
145-
let mut x_actual = Elem::new(y);
146-
x_actual.decompress::<D>();
151+
let x_actual = Decompress::<D>::decompress(FixedWidthInt::<BaseField, D>::new(y));
147152

148153
assert_eq!(x_expected, x_actual.0);
149154
}
@@ -152,10 +157,9 @@ pub(crate) mod tests {
152157
fn compress_KAT<D: CompressionFactor>() {
153158
for x in 0..BaseField::Q {
154159
let y_expected = rational_compress::<D>(x);
155-
let mut y_actual = Elem::new(x);
156-
y_actual.compress::<D>();
160+
let y_actual = Compress::<D>::compress(Elem::new(x));
157161

158-
assert_eq!(y_expected, y_actual.0, "for x: {}, D: {}", x, D::USIZE);
162+
assert_eq!(y_expected, y_actual.value(), "for x: {}, D: {}", x, D::USIZE);
159163
}
160164
}
161165

ml-kem/src/pke.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ use crate::algebra::{
33
Ntt, NttInverse, NttMatrix, NttVector, Polynomial, Vector, matrix_sample_ntt, sample_poly_cbd,
44
sample_poly_vec_cbd,
55
};
6-
use crate::compress::Compress;
6+
use crate::compress::{Compress, Decompress};
77
use crate::crypto::{G, PRF};
88
use crate::param::{EncodedDecryptionKey, EncodedEncryptionKey, PkeParams};
99
use array::typenum::{U1, Unsigned};
1010
use kem::{Ciphertext, InvalidKey};
1111
use module_lattice::{
12-
Encode,
12+
Encode, FixedWidthPolynomial, FixedWidthVector,
1313
ctutils::{Choice, CtEq},
1414
};
1515

@@ -90,16 +90,16 @@ where
9090
pub(crate) fn decrypt(&self, ciphertext: &Ciphertext<P>) -> B32 {
9191
let (c1, c2) = P::split_ct(ciphertext);
9292

93-
let mut u: Vector<P::K> = Encode::<P::Du>::decode(c1);
94-
u.decompress::<P::Du>();
93+
let u_compressed: FixedWidthVector<_, P::K, P::Du> = Encode::<P::Du>::decode(c1);
94+
let u: Vector<P::K> = u_compressed.decompress();
9595

96-
let mut v: Polynomial = Encode::<P::Dv>::decode(c2);
97-
v.decompress::<P::Dv>();
96+
let v_compressed: FixedWidthPolynomial<_, P::Dv> = Encode::<P::Dv>::decode(c2);
97+
let v: Polynomial = v_compressed.decompress();
9898

9999
let u_hat = u.ntt();
100100
let sTu = (&self.s_hat * &u_hat).ntt_inverse();
101-
let mut w = &v - &sTu;
102-
Encode::<U1>::encode(w.compress::<U1>())
101+
let w = &v - &sTu;
102+
Encode::<U1>::encode(&Compress::<U1>::compress(w))
103103
}
104104

105105
/// Represent this decryption key as a byte array `(s_hat)`
@@ -141,16 +141,16 @@ where
141141
let A_hat_t: NttMatrix<P::K> = matrix_sample_ntt(&self.rho, true);
142142
let r_hat: NttVector<P::K> = r.ntt();
143143
let ATr: Vector<P::K> = (&A_hat_t * &r_hat).ntt_inverse();
144-
let mut u = ATr + e1;
144+
let u = ATr + e1;
145145

146-
let mut mu: Polynomial = Encode::<U1>::decode(message);
147-
mu.decompress::<U1>();
146+
let mu_compressed: FixedWidthPolynomial<_, U1> = Encode::<U1>::decode(message);
147+
let mu: Polynomial = mu_compressed.decompress();
148148

149149
let tTr: Polynomial = (&self.t_hat * &r_hat).ntt_inverse();
150-
let mut v = &(&tTr + &e2) + &mu;
150+
let v = &(&tTr + &e2) + &mu;
151151

152-
let c1 = Encode::<P::Du>::encode(u.compress::<P::Du>());
153-
let c2 = Encode::<P::Dv>::encode(v.compress::<P::Dv>());
152+
let c1 = Encode::<P::Du>::encode(&Compress::<P::Du>::compress(u));
153+
let c2 = Encode::<P::Dv>::encode(&Compress::<P::Dv>::compress(v));
154154
P::concat_ct(c1, c2)
155155
}
156156

module-lattice/src/algebra.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,15 @@ pub trait Field: Copy + Default + Debug + PartialEq {
3737
fn barrett_reduce(x: Self::Long) -> Self::Int;
3838
}
3939

40+
/// Marker trait for a [`Field`] whose modulus is prime.
41+
///
42+
/// Multiplication on [`Elem<F>`] is gated on `F: PrimeField` because the
43+
/// reduction-based arithmetic in this crate (Barrett reduction, NTT) is
44+
/// only valid for prime-order fields. A non-prime-order representation
45+
/// such as Z_{2^d} can still impl [`Field`] for storage purposes (see
46+
/// [`FixedWidthInt`]) without claiming the multiplicative group structure.
47+
pub trait PrimeField: Field {}
48+
4049
/// The `define_field` macro creates a zero-sized struct and an implementation of the [`Field`]
4150
/// trait for that struct. The caller must specify:
4251
///
@@ -89,6 +98,8 @@ macro_rules! define_field {
8998
Self::small_reduce($crate::Truncate::truncate(remainder))
9099
}
91100
}
101+
102+
impl $crate::PrimeField for $field {}
92103
};
93104
}
94105

@@ -157,7 +168,7 @@ impl<F: Field> Sub<Elem<F>> for Elem<F> {
157168
}
158169
}
159170

160-
impl<F: Field> Mul<Elem<F>> for Elem<F> {
171+
impl<F: PrimeField> Mul<Elem<F>> for Elem<F> {
161172
type Output = Elem<F>;
162173

163174
fn mul(self, rhs: Elem<F>) -> Elem<F> {
@@ -220,7 +231,7 @@ impl<F: Field> Sub<&Polynomial<F>> for &Polynomial<F> {
220231
}
221232
}
222233

223-
impl<F: Field> Mul<&Polynomial<F>> for Elem<F> {
234+
impl<F: PrimeField> Mul<&Polynomial<F>> for Elem<F> {
224235
type Output = Polynomial<F>;
225236

226237
fn mul(self, rhs: &Polynomial<F>) -> Polynomial<F> {
@@ -306,7 +317,7 @@ impl<F: Field, K: ArraySize> Sub<&Vector<F, K>> for &Vector<F, K> {
306317
}
307318
}
308319

309-
impl<F: Field, K: ArraySize> Mul<&Vector<F, K>> for Elem<F> {
320+
impl<F: PrimeField, K: ArraySize> Mul<&Vector<F, K>> for Elem<F> {
310321
type Output = Vector<F, K>;
311322

312323
fn mul(self, rhs: &Vector<F, K>) -> Vector<F, K> {
@@ -382,7 +393,7 @@ impl<F: Field> Sub<&NttPolynomial<F>> for &NttPolynomial<F> {
382393
}
383394
}
384395

385-
impl<F: Field> Mul<&NttPolynomial<F>> for Elem<F> {
396+
impl<F: PrimeField> Mul<&NttPolynomial<F>> for Elem<F> {
386397
type Output = NttPolynomial<F>;
387398

388399
fn mul(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {

0 commit comments

Comments
 (0)