Skip to content

Commit 5727b8a

Browse files
committed
ntt/intt finished
1 parent 9e0d915 commit 5727b8a

3 files changed

Lines changed: 146 additions & 66 deletions

File tree

Makefile

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
.PHONY: check fix fmt clippy test
2+
3+
check: fmt clippy test
4+
5+
fix: fmt
6+
cargo clippy --fix --allow-dirty --allow-staged
7+
8+
fmt:
9+
cargo fmt
10+
11+
clippy:
12+
cargo clippy -- -D warnings
13+
14+
test:
15+
cargo test

src/ntt.rs

Lines changed: 83 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,46 @@ pub fn find_primitive_2d_root_of_unity<const Q: u64>(d: u64) -> Zq<Q> {
4949
panic!("no multiplicative generator found for Q={Q}")
5050
}
5151

52+
/// Reverse order of the result.
53+
/// Since the result of our NTT would be (w^1, w^5, w^3, w^7) for d=4,
54+
/// but we expect it to be (w^1, w^3, w^5, w^7).
55+
/// Intuition:
56+
/// So it's actually dividing elements k s.t. \psi^{2k+1} is a root of x^d+1
57+
/// every layer we put w^{n/2} (even) to the left and -w^{n/2} to the right.
58+
/// So we just map the result from NTT back to (w^1, w^3, w^5, w^7) with bit-reverse permutation
59+
fn _bit_reverse_permutation<T>(v: &mut [T]) {
60+
let n = v.len();
61+
let log_n = n.trailing_zeros();
62+
for i in 0..n {
63+
let j = i.reverse_bits() >> (usize::BITS - log_n);
64+
if i < j {
65+
v.swap(i, j);
66+
}
67+
}
68+
}
69+
70+
/// NTT: split polynomial Z_q[X]/(X^d+1) into their remainders in irreducibles Z_q[X]/(X-\zeta^i).
71+
/// For negacyclic (X^d+1), to fully split X^d+1, we need {d} to be a power of two.
72+
/// Otherwise the last layer wouldn't be degree 1 poly, might be deg-2 or something else.
73+
/// Here we only deal with the ones can be split *completely* for simplicity and efficiency
74+
/// in the split fields.
75+
pub fn ntt<const Q: u64, const D: usize>(coeffs: Vec<Zq<Q>>) -> Vec<Zq<Q>> {
76+
assert!(
77+
D.is_power_of_two(),
78+
"d should be power of two to split completely: d={D}"
79+
);
80+
assert!((Q - 1).is_multiple_of(2 * D as u64));
81+
82+
let psi = find_primitive_2d_root_of_unity::<Q>(D as u64);
83+
84+
let mut result = _ntt::<Q, D>(coeffs, psi, D as u64);
85+
_bit_reverse_permutation(&mut result);
86+
result
87+
}
88+
5289
/// This is implemented according to this great article https://electricdusk.com/ntt.html
5390
/// zeta means current level is Z_q[X]/(X^d - \psi^{zeta_exp})
54-
pub fn _ntt<const Q: u64, const D: u64>(
55-
coeffs: Vec<Zq<Q>>,
56-
psi: Zq<Q>,
57-
zeta_exp: u64,
58-
) -> Vec<Zq<Q>> {
91+
fn _ntt<const Q: u64, const D: usize>(coeffs: Vec<Zq<Q>>, psi: Zq<Q>, zeta_exp: u64) -> Vec<Zq<Q>> {
5992
let d = coeffs.len();
6093
assert!((Q - 1).is_multiple_of(2 * d as u64));
6194

@@ -64,8 +97,17 @@ pub fn _ntt<const Q: u64, const D: u64>(
6497
if d == 1 {
6598
return vec![coeffs[0]];
6699
}
100+
101+
// Find the term \zeta^{d/2} for this split, which is used to replace X^{d/2} with `root`
102+
// to reduce the polynomial to a_l and a_r.
103+
// We pass `zeta_exp` instead of \zeta^{d/2} directly because doing square root of field is expensive.
104+
// This is required in later recursion.
105+
// - left: (\zeta^{d/2})^{1/2}
106+
// - right: -(\zeta^{d/2})^{1/2}
107+
// Instead, we track the current exponent of zeta and we can calculate the term.
108+
// Replace X^{d/2} with zeta^ X^{d/2}..X^d
67109
// psi_power = d/2 first.
68-
// E.g. d=256, root here is \psi^{128} since X^{256}+1 = (X^{128} - 1)(X^{128} + 1)
110+
// E.g. d=256, root here is \psi^{128} since X^{256}+1 = (X^{128} - \zeta^{128})(X^{128} + \zeta^{128})
69111
let root = psi.pow(zeta_exp / 2);
70112
// Here is the "butterfly" part
71113
// E.g. we're at a \in Z_q[X] / (X^256+1) and we're gonna split to
@@ -94,63 +136,32 @@ pub fn _ntt<const Q: u64, const D: u64>(
94136
// = X^{128} - \psi^{128+D}, where D=256 and \psi^D = -1.
95137
// TODO: we can actually derive the correct root with a precalculated table \psi...\psi^{511}
96138
let a_l_coeffs = _ntt::<Q, D>(a_l, psi, zeta_exp / 2);
97-
let a_r_coeffs = _ntt::<Q, D>(a_r, psi, zeta_exp / 2 + D);
139+
let a_r_coeffs = _ntt::<Q, D>(a_r, psi, zeta_exp / 2 + D as u64);
98140
a_l_coeffs.into_iter().chain(a_r_coeffs).collect()
99141
}
100142

101-
/// Reverse order of the result.
102-
/// Since the result of our NTT would be (w^1, w^5, w^3, w^7) for d=4,
103-
/// but we expect it to be (w^1, w^3, w^5, w^7).
104-
/// Intuition:
105-
/// So it's actually dividing elements k s.t. \psi^{2k+1} is a root of x^d+1
106-
/// every layer we put w^{n/2} (even) to the left and -w^{n/2} to the right.
107-
/// So we just map the result from NTT back to (w^1, w^3, w^5, w^7) with bit-reverse permutation
108-
fn _bit_reverse_permutation<T>(v: &mut [T]) {
109-
let n = v.len();
110-
let log_n = n.trailing_zeros();
111-
for i in 0..n {
112-
let j = i.reverse_bits() >> (usize::BITS - log_n);
113-
if i < j {
114-
v.swap(i, j);
115-
}
116-
}
117-
}
118-
119-
/// NTT: split polynomials X^d+1 into irreducibles. For negacyclic (X^d+1), to fully split the
120-
/// polynomial, we need {d} to be a power of two. Otherwise the last layer wouldn't be degree 1 poly, might be
121-
/// degree 2 or something else.
122-
/// Here we only deal with the ones can be split *completely* for simplicity and efficiency
123-
/// in the split fields.
124-
pub fn ntt<const Q: u64, const D: u64>(coeffs: Vec<Zq<Q>>, psi: Zq<Q>) -> Vec<Zq<Q>> {
143+
/// Inverse NTT: recover evaluations (remainders) in irreducible polynomials Z_q[X]/(X-\zeta^i) back
144+
/// to the single polynomial in Z_q[X]/(X^d+1).
145+
/// Assumption is the same as NTT:
146+
/// 1. 2d | q-1 so primitive 2d-th roots exist.
147+
/// 2. d should be a power of two so the polynomial can be fully split into deg-1.
148+
pub fn intt<const Q: u64, const D: usize>(mut evals: Vec<Zq<Q>>) -> Vec<Zq<Q>> {
125149
assert!(
126150
D.is_power_of_two(),
127151
"d should be power of two to split completely: d={D}"
128152
);
129-
assert!((Q - 1).is_multiple_of(2 * D));
130-
131-
let mut result = _ntt::<Q, D>(coeffs, psi, D);
132-
_bit_reverse_permutation(&mut result);
133-
result
134-
}
153+
assert!((Q - 1).is_multiple_of(2 * D as u64));
135154

136-
pub fn intt<const Q: u64, const D: u64>(mut evals: Vec<Zq<Q>>, psi: Zq<Q>) -> Vec<Zq<Q>> {
137-
assert!(
138-
D.is_power_of_two(),
139-
"d should be power of two to split completely: d={D}"
140-
);
141-
assert!((Q - 1).is_multiple_of(2 * D));
155+
let psi = find_primitive_2d_root_of_unity::<Q>(D as u64);
142156

143157
// since we need to run iNTT on the original order of the output from NTT
144158
_bit_reverse_permutation(&mut evals);
145159

146-
_intt::<Q, D>(evals, psi, D)
160+
_intt::<Q, D>(evals, psi, D as u64)
147161
}
148162

149-
pub fn _intt<const Q: u64, const D: u64>(
150-
evals: Vec<Zq<Q>>,
151-
psi: Zq<Q>,
152-
zeta_exp: u64,
153-
) -> Vec<Zq<Q>> {
163+
/// Inverse NTT: recover polynomials Z_q[X]/(X^d+1) from irreducible polynomials.
164+
fn _intt<const Q: u64, const D: usize>(evals: Vec<Zq<Q>>, psi: Zq<Q>, zeta_exp: u64) -> Vec<Zq<Q>> {
154165
// return coefficient form
155166
let d = evals.len();
156167
assert!((Q - 1).is_multiple_of(2 * d as u64));
@@ -162,12 +173,21 @@ pub fn _intt<const Q: u64, const D: u64>(
162173
}
163174
let (evals_l, evals_r) = evals.split_at(d / 2);
164175

165-
let a_l = _intt::<Q, D>(evals_l.to_vec(), psi, zeta_exp / 2);
166-
let a_r = _intt::<Q, D>(evals_r.to_vec(), psi, zeta_exp / 2 + D);
167-
168176
// Inverse butterfly: recover a[i] and a[i+d/2] from a_l[i] and a_r[i]
169-
let mut a: Vec<Zq<Q>> = vec![Zq::<Q>::zero(); d];
177+
// It's just the inverse of NTT butterfly. Observing the first term of a_l(x) and a_r(x)
178+
// - a_l0 = a_0 + \zeta^{d/2} a_{128}
179+
// - a_r0 = a_0 - \zeta^{d/2} a_{128}
180+
// Adding them we get a_0 = 2^{-1} * (a_l0 + a_r0)
181+
// Subtracting them we get a_{128} = 2^{-1} * (a_l0 - a_r0) * \zeta^{-128}
182+
// So we recover a_i and a_{i+d/2} from a_li and a_ri with 2^{-1} and \zeta^{-d/2}
183+
184+
// We use the same approach to calculate \zeta^{128}
170185
let root = psi.pow(zeta_exp / 2);
186+
// Recursively prepare a_l and a_r
187+
let a_l = _intt::<Q, D>(evals_l.to_vec(), psi, zeta_exp / 2);
188+
let a_r = _intt::<Q, D>(evals_r.to_vec(), psi, zeta_exp / 2 + D as u64);
189+
// Actual inverse butterfly as described above
190+
let mut a: Vec<Zq<Q>> = vec![Zq::<Q>::zero(); d];
171191
let two_inv = Zq::new(2).inv();
172192
for i in 0..(d / 2) {
173193
a[i] = two_inv * (a_l[i] + a_r[i]);
@@ -182,20 +202,20 @@ mod tests {
182202
use super::*;
183203

184204
const Q: u64 = 17;
185-
const D: u64 = 4;
205+
const D: usize = 4;
186206
type F = Zq<Q>;
187207

188208
fn setup() -> Zq<Q> {
189-
let psi = find_primitive_2d_root_of_unity::<Q>(D);
209+
let psi = find_primitive_2d_root_of_unity::<Q>(D as u64);
190210
println!("psi={:?}", psi);
191211
psi
192212
}
193213

194214
#[test]
195215
fn test_primitive_2d_root_of_unity() {
196216
let psi = setup();
197-
assert_eq!(psi.pow(2 * D), F::one()); // w^{2d} = 1
198-
assert_eq!(psi.pow(D), -F::one()); // w^d = -1
217+
assert_eq!(psi.pow(2 * D as u64), F::one()); // w^{2d} = 1
218+
assert_eq!(psi.pow(D as u64), -F::one()); // w^d = -1
199219
}
200220

201221
// Sage test vectors: q=17, d=4, negacyclic NTT (X^d+1)
@@ -218,36 +238,33 @@ mod tests {
218238
evals
219239
};
220240
let evals = get_evals();
221-
assert_eq!(ntt::<Q, D>(coeffs, psi), evals);
241+
assert_eq!(ntt::<Q, D>(coeffs), evals);
222242
}
223243

224244
#[test]
225245
fn test_intt_backward() {
226-
let psi = setup();
227246
let evals = vec![F::new(14), F::new(0), F::new(10), F::new(16)];
228247
let expected_coeffs = vec![F::new(10), F::new(4), F::new(8), F::new(0)];
229248

230-
assert_eq!(intt::<Q, D>(evals, psi), expected_coeffs);
249+
assert_eq!(intt::<Q, D>(evals), expected_coeffs);
231250
}
232251

233252
#[test]
234253
fn test_ntt_intt_roundtrip() {
235254
type F = Zq<Q>;
236-
let psi = setup();
237255
let coeffs = vec![F::new(16), F::new(3), F::new(0), F::new(14)];
238256
let coeffs_clone = coeffs.clone();
239-
assert_eq!(intt::<Q, D>(ntt::<Q, D>(coeffs, psi), psi), coeffs_clone);
257+
assert_eq!(intt::<Q, D>(ntt::<Q, D>(coeffs)), coeffs_clone);
240258
}
241259

242260
// ─── q=12289, d=1024 (Falcon params) ───
243261

244262
const Q2: u64 = 12289;
245-
const D2: u64 = 1024;
263+
const D2: usize = 1024;
246264
type F2 = Zq<Q2>;
247265

248266
#[test]
249267
fn test_ntt_falcon() {
250-
let psi = find_primitive_2d_root_of_unity::<Q2>(D2);
251268
let coeffs_raw: [u64; 1024] = [
252269
8633, 1504, 11298, 8147, 6951, 5539, 3291, 334, 7732, 376, 3099, 4879, 9978, 7512,
253270
3274, 6114, 4942, 8255, 8730, 758, 1334, 5361, 3507, 10969, 5079, 9882, 6516, 4586,
@@ -402,9 +419,9 @@ mod tests {
402419
let coeffs: Vec<F2> = coeffs_raw.iter().map(|&c| F2::new(c)).collect();
403420
let expected_evals: Vec<F2> = evals_raw.iter().map(|&e| F2::new(e)).collect();
404421

405-
let actual_evals = ntt::<Q2, D2>(coeffs.clone(), psi);
422+
let actual_evals = ntt::<Q2, D2>(coeffs.clone());
406423
assert_eq!(actual_evals, expected_evals);
407-
let coeffs_roundtrip = intt::<Q2, D2>(actual_evals, psi);
424+
let coeffs_roundtrip = intt::<Q2, D2>(actual_evals);
408425
assert_eq!(coeffs, coeffs_roundtrip);
409426
}
410427
}

src/poly.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use super::ntt;
12
use super::zq::Zq;
23
use std::ops::{Add, Mul, Neg, Sub};
34

@@ -129,6 +130,14 @@ impl<const Q: u64, const D: usize> Rq<Q, D> {
129130
&self.coeffs
130131
}
131132

133+
/// Convert to NTT (evaluation) form.
134+
pub fn ntt(self) -> RqNtt<Q, D> {
135+
let evals_vec = ntt::ntt::<Q, D>(self.coeffs.to_vec());
136+
RqNtt {
137+
evals: evals_vec.try_into().unwrap(),
138+
}
139+
}
140+
132141
/// Reduce a polynomial (with up to 2D-1 coefficients) mod X^D + 1.
133142
fn reduce(full: &[Zq<Q>]) -> [Zq<Q>; D] {
134143
assert!(full.len() < 2 * D);
@@ -218,6 +227,14 @@ impl<const Q: u64, const D: usize> RqNtt<Q, D> {
218227
pub fn evals(&self) -> &[Zq<Q>; D] {
219228
&self.evals
220229
}
230+
231+
/// Convert back to coefficient form.
232+
pub fn intt(self) -> Rq<Q, D> {
233+
let coeffs_vec = ntt::intt::<Q, D>(self.evals.to_vec());
234+
Rq {
235+
coeffs: coeffs_vec.try_into().unwrap(),
236+
}
237+
}
221238
}
222239

223240
impl<const Q: u64, const D: usize> Add for RqNtt<Q, D> {
@@ -581,4 +598,35 @@ mod tests {
581598
let c = ntt_from([4, 9, 2, 6]);
582599
assert_eq!(a.clone() * (b.clone() + c.clone()), a.clone() * b + a * c);
583600
}
601+
602+
// ─── Rq <-> RqNtt conversion tests ───
603+
604+
#[test]
605+
fn test_rq_ntt_roundtrip() {
606+
let a = rp([10, 4, 8, 0]);
607+
assert_eq!(a.clone().ntt().intt(), a);
608+
}
609+
610+
#[test]
611+
fn test_rq_ntt_roundtrip_ones() {
612+
let a = Ring::one();
613+
assert_eq!(a.clone().ntt().intt(), a);
614+
}
615+
616+
#[test]
617+
fn test_rq_ntt_mul_matches_schoolbook() {
618+
// NTT mul should give same result as schoolbook mul
619+
let a = rp([1, 0, 0, 1]); // 1 + x^3
620+
let b = rp([1, 0, 1, 0]); // 1 + x^2
621+
let schoolbook = a.clone() * b.clone();
622+
let ntt_result = (a.ntt() * b.ntt()).intt();
623+
assert_eq!(ntt_result, schoolbook);
624+
}
625+
626+
#[test]
627+
fn test_rq_ntt_mul_by_one() {
628+
let a = rp([3, 5, 7, 11]);
629+
let one = Ring::one();
630+
assert_eq!((a.clone().ntt() * one.ntt()).intt(), a);
631+
}
584632
}

0 commit comments

Comments
 (0)