Skip to content

Commit 2707520

Browse files
committed
negacyclic ntt works
1 parent 20e633e commit 2707520

2 files changed

Lines changed: 145 additions & 25 deletions

File tree

src/ntt.rs

Lines changed: 101 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,4 @@
1-
use core::panic;
2-
3-
use super::zq::Zq;
4-
/*
5-
src/
6-
poly.rs ← Poly, Rq, RqNtt
7-
ntt.rs ← ntt(), intt() methods
8-
9-
Rq <-> RqNTT with methods
10-
11-
impl Rq<Q, D> {
12-
fn ntt(self) -> RqNtt<Q, D> { ... }
13-
}
14-
impl RqNtt<Q, D> {
15-
fn intt(self) -> Rq<Q, D> { ... }
16-
}
17-
*/
1+
use super::{zq::Zq, poly::Poly};
182

193
pub fn prime_factors(mut n: u64) -> Vec<u64> {
204
let mut factors = Vec::new();
@@ -34,11 +18,13 @@ pub fn prime_factors(mut n: u64) -> Vec<u64> {
3418
factors
3519
}
3620

21+
22+
/// R_q = Z_q[X]/(X^d+1). X^d + 1 = 0 -> X^d = -1 \mod q
23+
/// -> X^{2d} = 1. Assume q is prime, Z_q^* is a cyclic group with order q-1
24+
/// i.e. \forall g \in Z_q^*, g^{q-1} = 1. Since g^{(q-1)/(2d)}^{2d} = 1,
25+
/// for g to exist 2d must divide (q-1).
3726
pub fn find_primitive_2d_root_of_unity<const Q: u64>(d: u64) -> Zq<Q> {
38-
// R_q = Z_q[X]/(X^d+1). X^d + 1 = 0 -> X^d = -1 \mod q
39-
// -> X^{2d} = 1. Assume q is prime, Z_q^* is a cyclic group with order q-1
40-
// i.e. \forall g \in Z_q^*, g^{q-1} = 1. Since g^{(q-1)/(2d)}^{2d} = 1,
41-
// for g to exist 2d must divide (q-1).
27+
4228
let order = Q - 1;
4329
assert_eq!(
4430
order % (2 * d),
@@ -54,7 +40,6 @@ pub fn find_primitive_2d_root_of_unity<const Q: u64>(d: u64) -> Zq<Q> {
5440
let factors = prime_factors(order);
5541
for i in 2..order {
5642
let g = Zq::<Q>::new(i);
57-
// o(g) =
5843
let is_generator = factors.iter().all(|&p| g.pow(order / p) != Zq::<Q>::one());
5944
if is_generator {
6045
// if g is a generator -> o(g) = q-1 -> o(g^{(q-1)/2d}) = 2d
@@ -64,17 +49,108 @@ pub fn find_primitive_2d_root_of_unity<const Q: u64>(d: u64) -> Zq<Q> {
6449
panic!("no multiplicative generator found for Q={Q}")
6550
}
6651

52+
53+
/// NTT: here we assume coeffs can be split *completely* for simplicity and efficiency
54+
/// in the split fields.
55+
/// This is implemented according to this great article https://electricdusk.com/ntt.html
56+
/// This requires {d} to be a power of two.
57+
pub fn ntt<const Q: u64>(coeffs: Vec<Zq<Q>>, psi: Zq<Q>, psi_power: u64) -> Vec<Zq<Q>> {
58+
let d = coeffs.len();
59+
assert!(d.is_power_of_two(), "d should be power of two to split completely: d={d}");
60+
assert!((Q-1).is_multiple_of(2*d as u64));
61+
62+
// Terminal condition: when d = 1, it's the last split. Just returns
63+
// the constant term.
64+
if d == 1 {
65+
return vec![coeffs[0]];
66+
}
67+
68+
// E.g. d=256, root here is \psi^{128} since X^{256}+1 = (X^{128} - 1)(X^{128} + 1)
69+
let root= psi.pow(psi_power);
70+
// Here is the "butterfly" part
71+
// E.g. we're at a \in Z_q[X] / (X^256+1) and we're gonna split to
72+
// a_l \in Z_q[X]/(X^128 - \psi^128), a_r \in Z_q[X]/(X^128 + \psi^128).
73+
// We just let replace all X^128=\psi^128 in a to become a_l,
74+
// X^128=-\psi^128 in a to become a_r.
75+
// Then,
76+
// a_l[0] = a[0] + psi^{128} * a[128]
77+
// a_r[0] = a[0] - psi^{128} * a[128]
78+
// Since `a[0]` and `psi^{128} * a[128]` are reused for a_l and a_r, just different
79+
// operator before the latter term.
80+
// We can draw it as a butterfly.
81+
82+
let mut a_l: Vec<Zq<Q>> = Vec::new();
83+
let mut a_r: Vec<Zq<Q>> = Vec::new();
84+
85+
for i in 0..(d/2) {
86+
a_l.push(coeffs[i] + root * coeffs[i + d/2]);
87+
a_r.push(coeffs[i] - root * coeffs[i + d/2]);
88+
}
89+
90+
// Split the left/right poly all the way down and get the results.
91+
let a_l_coeffs = ntt(a_l, psi, psi_power / 2);
92+
let a_r_coeffs = ntt(a_r, psi, psi_power / 2 + (d/2) as u64);
93+
a_l_coeffs.into_iter().chain(a_r_coeffs).collect()
94+
}
95+
96+
pub fn intt<const Q: u64>(evals: Vec<Zq<Q>>) -> Vec<Zq<Q>> {
97+
todo!()
98+
}
99+
67100
#[cfg(test)]
68101
mod tests {
102+
69103
use super::*;
70104

71105
const Q: u64 = 17;
72106
const D: u64 = 4;
107+
type F = Zq<Q>;
108+
109+
fn setup() -> Zq<Q> {
110+
let psi = find_primitive_2d_root_of_unity::<Q>(D);
111+
println!("psi={:?}", psi);
112+
psi
113+
}
73114

74115
#[test]
75116
fn test_primitive_2d_root_of_unity() {
76-
let omega = find_primitive_2d_root_of_unity::<Q>(D);
77-
assert_eq!(omega.pow(2 * D), Zq::<Q>::one()); // w^{2d} = 1
78-
assert_eq!(omega.pow(D), -Zq::<Q>::one()); // w^d = -1
117+
let psi = setup();
118+
assert_eq!(psi.pow(2 * D), F::one()); // w^{2d} = 1
119+
assert_eq!(psi.pow(D), -F::one()); // w^d = -1
120+
}
121+
122+
// Sage test vectors: q=17, d=4, negacyclic NTT (X^d+1)
123+
// coeffs [16, 3, 0, 14] <-> evals [15, 0, 0, 15]
124+
#[test]
125+
fn test_ntt_forward() {
126+
let psi = setup();
127+
let d = 4;
128+
let coeffs = vec![F::new(16), F::new(3), F::new(0), F::new(14)];
129+
let expected_evals = vec![F::new(15), F::new(0), F::new(0), F::new(15)];
130+
131+
let odd_powers: Vec<_> = (0..d).map(|k| psi.pow(2*k as u64 + 1)).collect();
132+
println!("roots: {:?}", odd_powers);
133+
134+
let a = Poly::new(coeffs.clone());
135+
let evals: Vec<_> = odd_powers.iter().map(|w| a.clone().eval(w.value())).collect();
136+
println!("evals: {:?}", evals);
137+
assert_eq!(ntt::<Q>(coeffs, psi, d/2), expected_evals);
138+
}
139+
140+
141+
#[test]
142+
fn test_intt_backward() {
143+
let evals = vec![F::new(15), F::new(0), F::new(0), F::new(15)];
144+
let expected_coeffs = vec![F::new(16), F::new(3), F::new(0), F::new(14)];
145+
assert_eq!(intt::<Q>(evals), expected_coeffs);
146+
}
147+
148+
#[test]
149+
fn test_ntt_intt_roundtrip() {
150+
type F = Zq<Q>;
151+
let psi = setup();
152+
let d = 4;
153+
let coeffs = vec![F::new(16), F::new(3), F::new(0), F::new(14)];
154+
// assert_eq!(intt::<Q>(ntt::<Q>(coeffs, psi, d)), coeffs);
79155
}
80156
}

src/poly.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ impl<const Q: u64> Poly<Q> {
2424
coeffs: vec![Zq::one()],
2525
}
2626
}
27+
28+
pub fn eval(&self, x: u64) -> Zq<Q> {
29+
let mut s = Zq::<Q>::zero();
30+
for (i, &c) in self.coeffs.iter().enumerate() {
31+
s = s + c * Zq::<Q>::new(x).pow(i as u64);
32+
}
33+
s
34+
}
2735
}
2836

2937
impl<const Q: u64> Add for Poly<Q> {
@@ -353,6 +361,42 @@ mod tests {
353361
assert_eq!(p(&[1, 2, 3]) * p(&[2]), p(&[2, 4, 6]));
354362
}
355363

364+
// ─── Poly::eval tests ───
365+
366+
#[test]
367+
fn test_eval_constant() {
368+
// f = 5, f(x) = 5 for all x
369+
assert_eq!(p(&[5]).eval(0), F::new(5));
370+
assert_eq!(p(&[5]).eval(3), F::new(5));
371+
}
372+
373+
#[test]
374+
fn test_eval_linear() {
375+
// f = 3 + 5x
376+
// f(0) = 3, f(1) = 8, f(2) = 13
377+
assert_eq!(p(&[3, 5]).eval(0), F::new(3));
378+
assert_eq!(p(&[3, 5]).eval(1), F::new(8));
379+
assert_eq!(p(&[3, 5]).eval(2), F::new(13));
380+
}
381+
382+
#[test]
383+
fn test_eval_quadratic() {
384+
// f = 3 + 5x + 2x^2 (mod 17)
385+
// f(0) = 3
386+
// f(1) = 3 + 5 + 2 = 10
387+
// f(2) = 3 + 10 + 8 = 21 mod 17 = 4
388+
// f(16) = f(-1 mod 17) = 3 - 5 + 2 = 0
389+
assert_eq!(p(&[3, 5, 2]).eval(0), F::new(3));
390+
assert_eq!(p(&[3, 5, 2]).eval(1), F::new(10));
391+
assert_eq!(p(&[3, 5, 2]).eval(2), F::new(4));
392+
assert_eq!(p(&[3, 5, 2]).eval(16), F::new(0));
393+
}
394+
395+
#[test]
396+
fn test_eval_zero_poly() {
397+
assert_eq!(R::zero().eval(5), F::zero());
398+
}
399+
356400
// ─── Rq tests: R_q = Z_q[X]/(X^4 + 1), q=17, d=4 ───
357401

358402
const D: usize = 4;

0 commit comments

Comments
 (0)