1- use core:: panic;
2-
3- use super :: zq:: Zq ;
4- /*
5- src/
6- poly.rs ← Poly, Rq, RqNtt
7- ntt.rs ← ntt(), intt() methods
8-
9- Rq <-> RqNTT with methods
10-
11- impl Rq<Q, D> {
12- fn ntt(self) -> RqNtt<Q, D> { ... }
13- }
14- impl RqNtt<Q, D> {
15- fn intt(self) -> Rq<Q, D> { ... }
16- }
17- */
1+ use super :: { zq:: Zq , poly:: Poly } ;
182
193pub fn prime_factors ( mut n : u64 ) -> Vec < u64 > {
204 let mut factors = Vec :: new ( ) ;
@@ -34,11 +18,13 @@ pub fn prime_factors(mut n: u64) -> Vec<u64> {
3418 factors
3519}
3620
21+
22+ /// R_q = Z_q[X]/(X^d+1). X^d + 1 = 0 -> X^d = -1 \mod q
23+ /// -> X^{2d} = 1. Assume q is prime, Z_q^* is a cyclic group with order q-1
24+ /// i.e. \forall g \in Z_q^*, g^{q-1} = 1. Since g^{(q-1)/(2d)}^{2d} = 1,
25+ /// for g to exist 2d must divide (q-1).
3726pub fn find_primitive_2d_root_of_unity < const Q : u64 > ( d : u64 ) -> Zq < Q > {
38- // R_q = Z_q[X]/(X^d+1). X^d + 1 = 0 -> X^d = -1 \mod q
39- // -> X^{2d} = 1. Assume q is prime, Z_q^* is a cyclic group with order q-1
40- // i.e. \forall g \in Z_q^*, g^{q-1} = 1. Since g^{(q-1)/(2d)}^{2d} = 1,
41- // for g to exist 2d must divide (q-1).
27+
4228 let order = Q - 1 ;
4329 assert_eq ! (
4430 order % ( 2 * d) ,
@@ -54,7 +40,6 @@ pub fn find_primitive_2d_root_of_unity<const Q: u64>(d: u64) -> Zq<Q> {
5440 let factors = prime_factors ( order) ;
5541 for i in 2 ..order {
5642 let g = Zq :: < Q > :: new ( i) ;
57- // o(g) =
5843 let is_generator = factors. iter ( ) . all ( |& p| g. pow ( order / p) != Zq :: < Q > :: one ( ) ) ;
5944 if is_generator {
6045 // if g is a generator -> o(g) = q-1 -> o(g^{(q-1)/2d}) = 2d
@@ -64,17 +49,108 @@ pub fn find_primitive_2d_root_of_unity<const Q: u64>(d: u64) -> Zq<Q> {
6449 panic ! ( "no multiplicative generator found for Q={Q}" )
6550}
6651
52+
53+ /// NTT: here we assume coeffs can be split *completely* for simplicity and efficiency
54+ /// in the split fields.
55+ /// This is implemented according to this great article https://electricdusk.com/ntt.html
56+ /// This requires {d} to be a power of two.
57+ pub fn ntt < const Q : u64 > ( coeffs : Vec < Zq < Q > > , psi : Zq < Q > , psi_power : u64 ) -> Vec < Zq < Q > > {
58+ let d = coeffs. len ( ) ;
59+ assert ! ( d. is_power_of_two( ) , "d should be power of two to split completely: d={d}" ) ;
60+ assert ! ( ( Q -1 ) . is_multiple_of( 2 * d as u64 ) ) ;
61+
62+ // Terminal condition: when d = 1, it's the last split. Just returns
63+ // the constant term.
64+ if d == 1 {
65+ return vec ! [ coeffs[ 0 ] ] ;
66+ }
67+
68+ // E.g. d=256, root here is \psi^{128} since X^{256}+1 = (X^{128} - 1)(X^{128} + 1)
69+ let root= psi. pow ( psi_power) ;
70+ // Here is the "butterfly" part
71+ // E.g. we're at a \in Z_q[X] / (X^256+1) and we're gonna split to
72+ // a_l \in Z_q[X]/(X^128 - \psi^128), a_r \in Z_q[X]/(X^128 + \psi^128).
73+ // We just let replace all X^128=\psi^128 in a to become a_l,
74+ // X^128=-\psi^128 in a to become a_r.
75+ // Then,
76+ // a_l[0] = a[0] + psi^{128} * a[128]
77+ // a_r[0] = a[0] - psi^{128} * a[128]
78+ // Since `a[0]` and `psi^{128} * a[128]` are reused for a_l and a_r, just different
79+ // operator before the latter term.
80+ // We can draw it as a butterfly.
81+
82+ let mut a_l: Vec < Zq < Q > > = Vec :: new ( ) ;
83+ let mut a_r: Vec < Zq < Q > > = Vec :: new ( ) ;
84+
85+ for i in 0 ..( d/2 ) {
86+ a_l. push ( coeffs[ i] + root * coeffs[ i + d/2 ] ) ;
87+ a_r. push ( coeffs[ i] - root * coeffs[ i + d/2 ] ) ;
88+ }
89+
90+ // Split the left/right poly all the way down and get the results.
91+ let a_l_coeffs = ntt ( a_l, psi, psi_power / 2 ) ;
92+ let a_r_coeffs = ntt ( a_r, psi, psi_power / 2 + ( d/2 ) as u64 ) ;
93+ a_l_coeffs. into_iter ( ) . chain ( a_r_coeffs) . collect ( )
94+ }
95+
96+ pub fn intt < const Q : u64 > ( evals : Vec < Zq < Q > > ) -> Vec < Zq < Q > > {
97+ todo ! ( )
98+ }
99+
67100#[ cfg( test) ]
68101mod tests {
102+
69103 use super :: * ;
70104
71105 const Q : u64 = 17 ;
72106 const D : u64 = 4 ;
107+ type F = Zq < Q > ;
108+
109+ fn setup ( ) -> Zq < Q > {
110+ let psi = find_primitive_2d_root_of_unity :: < Q > ( D ) ;
111+ println ! ( "psi={:?}" , psi) ;
112+ psi
113+ }
73114
74115 #[ test]
75116 fn test_primitive_2d_root_of_unity ( ) {
76- let omega = find_primitive_2d_root_of_unity :: < Q > ( D ) ;
77- assert_eq ! ( omega. pow( 2 * D ) , Zq :: <Q >:: one( ) ) ; // w^{2d} = 1
78- assert_eq ! ( omega. pow( D ) , -Zq :: <Q >:: one( ) ) ; // w^d = -1
117+ let psi = setup ( ) ;
118+ assert_eq ! ( psi. pow( 2 * D ) , F :: one( ) ) ; // w^{2d} = 1
119+ assert_eq ! ( psi. pow( D ) , -F :: one( ) ) ; // w^d = -1
120+ }
121+
122+ // Sage test vectors: q=17, d=4, negacyclic NTT (X^d+1)
123+ // coeffs [16, 3, 0, 14] <-> evals [15, 0, 0, 15]
124+ #[ test]
125+ fn test_ntt_forward ( ) {
126+ let psi = setup ( ) ;
127+ let d = 4 ;
128+ let coeffs = vec ! [ F :: new( 16 ) , F :: new( 3 ) , F :: new( 0 ) , F :: new( 14 ) ] ;
129+ let expected_evals = vec ! [ F :: new( 15 ) , F :: new( 0 ) , F :: new( 0 ) , F :: new( 15 ) ] ;
130+
131+ let odd_powers: Vec < _ > = ( 0 ..d) . map ( |k| psi. pow ( 2 * k as u64 + 1 ) ) . collect ( ) ;
132+ println ! ( "roots: {:?}" , odd_powers) ;
133+
134+ let a = Poly :: new ( coeffs. clone ( ) ) ;
135+ let evals: Vec < _ > = odd_powers. iter ( ) . map ( |w| a. clone ( ) . eval ( w. value ( ) ) ) . collect ( ) ;
136+ println ! ( "evals: {:?}" , evals) ;
137+ assert_eq ! ( ntt:: <Q >( coeffs, psi, d/2 ) , expected_evals) ;
138+ }
139+
140+
141+ #[ test]
142+ fn test_intt_backward ( ) {
143+ let evals = vec ! [ F :: new( 15 ) , F :: new( 0 ) , F :: new( 0 ) , F :: new( 15 ) ] ;
144+ let expected_coeffs = vec ! [ F :: new( 16 ) , F :: new( 3 ) , F :: new( 0 ) , F :: new( 14 ) ] ;
145+ assert_eq ! ( intt:: <Q >( evals) , expected_coeffs) ;
146+ }
147+
148+ #[ test]
149+ fn test_ntt_intt_roundtrip ( ) {
150+ type F = Zq < Q > ;
151+ let psi = setup ( ) ;
152+ let d = 4 ;
153+ let coeffs = vec ! [ F :: new( 16 ) , F :: new( 3 ) , F :: new( 0 ) , F :: new( 14 ) ] ;
154+ // assert_eq!(intt::<Q>(ntt::<Q>(coeffs, psi, d)), coeffs);
79155 }
80156}
0 commit comments