@@ -49,13 +49,46 @@ pub fn find_primitive_2d_root_of_unity<const Q: u64>(d: u64) -> Zq<Q> {
4949 panic ! ( "no multiplicative generator found for Q={Q}" )
5050}
5151
/// Reorders `v` in place by the bit-reversal permutation of its indices.
///
/// The element at index `i` trades places with the element at the index obtained
/// by reversing the low `log2(v.len())` bits of `i`.
///
/// Why the NTT needs this: every layer of the recursive split puts the `+root`
/// half on the left and the `-root` half on the right, so for d=4 the raw output
/// is ordered (w^1, w^5, w^3, w^7), whereas the expected order is
/// (w^1, w^3, w^5, w^7). Bit-reversing the indices maps one order onto the other.
///
/// The length is expected to be a power of two; this function does not check it.
/// Slices with zero or one element are returned unchanged.
fn _bit_reverse_permutation<T>(v: &mut [T]) {
    let n = v.len();
    // Nothing to permute. Returning here also avoids the n == 1 case below, where
    // `log_n` is 0 and the shift amount would be `usize::BITS` itself: an
    // overflowing shift that panics in builds with overflow checks enabled.
    if n < 2 {
        return;
    }
    let log_n = n.trailing_zeros();
    // Reversing all `usize::BITS` bits leaves the interesting bits at the top of
    // the word; shifting right moves them back down to the low `log_n` positions.
    let shift = usize::BITS - log_n;
    for i in 0..n {
        let j = i.reverse_bits() >> shift;
        // Each pair (i, j) shows up twice; swap only once, and skip fixed points.
        if i < j {
            v.swap(i, j);
        }
    }
}
69+
70+ /// NTT: split polynomial Z_q[X]/(X^d+1) into their remainders in irreducibles Z_q[X]/(X-\zeta^i).
71+ /// For negacyclic (X^d+1), to fully split X^d+1, we need {d} to be a power of two.
72+ /// Otherwise the last layer wouldn't be degree 1 poly, might be deg-2 or something else.
73+ /// Here we only deal with the ones can be split *completely* for simplicity and efficiency
74+ /// in the split fields.
75+ pub fn ntt < const Q : u64 , const D : usize > ( coeffs : Vec < Zq < Q > > ) -> Vec < Zq < Q > > {
76+ assert ! (
77+ D . is_power_of_two( ) ,
78+ "d should be power of two to split completely: d={D}"
79+ ) ;
80+ assert ! ( ( Q - 1 ) . is_multiple_of( 2 * D as u64 ) ) ;
81+
82+ let psi = find_primitive_2d_root_of_unity :: < Q > ( D as u64 ) ;
83+
84+ let mut result = _ntt :: < Q , D > ( coeffs, psi, D as u64 ) ;
85+ _bit_reverse_permutation ( & mut result) ;
86+ result
87+ }
88+
5289/// This is implemented according to this great article https://electricdusk.com/ntt.html
5390/// zeta means current level is Z_q[X]/(X^d - \psi^{zeta_exp})
54- pub fn _ntt < const Q : u64 , const D : u64 > (
55- coeffs : Vec < Zq < Q > > ,
56- psi : Zq < Q > ,
57- zeta_exp : u64 ,
58- ) -> Vec < Zq < Q > > {
91+ fn _ntt < const Q : u64 , const D : usize > ( coeffs : Vec < Zq < Q > > , psi : Zq < Q > , zeta_exp : u64 ) -> Vec < Zq < Q > > {
5992 let d = coeffs. len ( ) ;
6093 assert ! ( ( Q - 1 ) . is_multiple_of( 2 * d as u64 ) ) ;
6194
@@ -64,8 +97,17 @@ pub fn _ntt<const Q: u64, const D: u64>(
6497 if d == 1 {
6598 return vec ! [ coeffs[ 0 ] ] ;
6699 }
100+
101+ // Find the term \zeta^{d/2} for this split, which is used to replace X^{d/2} with `root`
102+ // to reduce the polynomial to a_l and a_r.
103+ // We pass `zeta_exp` instead of \zeta^{d/2} directly because doing square root of field is expensive.
104+ // This is required in later recursion.
105+ // - left: (\zeta^{d/2})^{1/2}
106+ // - right: -(\zeta^{d/2})^{1/2}
107+ // Instead, we track the current exponent of zeta and we can calculate the term.
108+ // Replace X^{d/2} with zeta^ X^{d/2}..X^d
67109 // psi_power = d/2 first.
68- // E.g. d=256, root here is \psi^{128} since X^{256}+1 = (X^{128} - 1 )(X^{128} + 1 )
110+ // E.g. d=256, root here is \psi^{128} since X^{256}+1 = (X^{128} - \zeta^{128} )(X^{128} + \zeta^{128} )
69111 let root = psi. pow ( zeta_exp / 2 ) ;
70112 // Here is the "butterfly" part
71113 // E.g. we're at a \in Z_q[X] / (X^256+1) and we're gonna split to
@@ -94,63 +136,32 @@ pub fn _ntt<const Q: u64, const D: u64>(
94136 // = X^{128} - \psi^{128+D}, where D=256 and \psi^D = -1.
95137 // TODO: we can actually derive the correct root with a precalculated table \psi...\psi^{511}
96138 let a_l_coeffs = _ntt :: < Q , D > ( a_l, psi, zeta_exp / 2 ) ;
97- let a_r_coeffs = _ntt :: < Q , D > ( a_r, psi, zeta_exp / 2 + D ) ;
139+ let a_r_coeffs = _ntt :: < Q , D > ( a_r, psi, zeta_exp / 2 + D as u64 ) ;
98140 a_l_coeffs. into_iter ( ) . chain ( a_r_coeffs) . collect ( )
99141}
100142
101- /// Reverse order of the result.
102- /// Since the result of our NTT would be (w^1, w^5, w^3, w^7) for d=4,
103- /// but we expect it to be (w^1, w^3, w^5, w^7).
104- /// Intuition:
105- /// So it's actually dividing elements k s.t. \psi^{2k+1} is a root of x^d+1
106- /// every layer we put w^{n/2} (even) to the left and -w^{n/2} to the right.
107- /// So we just map the result from NTT back to (w^1, w^3, w^5, w^7) with bit-reverse permutation
108- fn _bit_reverse_permutation < T > ( v : & mut [ T ] ) {
109- let n = v. len ( ) ;
110- let log_n = n. trailing_zeros ( ) ;
111- for i in 0 ..n {
112- let j = i. reverse_bits ( ) >> ( usize:: BITS - log_n) ;
113- if i < j {
114- v. swap ( i, j) ;
115- }
116- }
117- }
118-
119- /// NTT: split polynomials X^d+1 into irreducibles. For negacyclic (X^d+1), to fully split the
120- /// polynomial, we need {d} to be a power of two. Otherwise the last layer wouldn't be degree 1 poly, might be
121- /// degree 2 or something else.
122- /// Here we only deal with the ones can be split *completely* for simplicity and efficiency
123- /// in the split fields.
124- pub fn ntt < const Q : u64 , const D : u64 > ( coeffs : Vec < Zq < Q > > , psi : Zq < Q > ) -> Vec < Zq < Q > > {
143+ /// Inverse NTT: recover evaluations (remainders) in irreducible polynomials Z_q[X]/(X-\zeta^i) back
144+ /// to the single polynomial in Z_q[X]/(X^d+1).
145+ /// Assumption is the same as NTT:
146+ /// 1. 2d | q-1 so primitive 2d-th roots exist.
147+ /// 2. d should be a power of two so the polynomial can be fully split into deg-1.
148+ pub fn intt < const Q : u64 , const D : usize > ( mut evals : Vec < Zq < Q > > ) -> Vec < Zq < Q > > {
125149 assert ! (
126150 D . is_power_of_two( ) ,
127151 "d should be power of two to split completely: d={D}"
128152 ) ;
129- assert ! ( ( Q - 1 ) . is_multiple_of( 2 * D ) ) ;
130-
131- let mut result = _ntt :: < Q , D > ( coeffs, psi, D ) ;
132- _bit_reverse_permutation ( & mut result) ;
133- result
134- }
153+ assert ! ( ( Q - 1 ) . is_multiple_of( 2 * D as u64 ) ) ;
135154
136- pub fn intt < const Q : u64 , const D : u64 > ( mut evals : Vec < Zq < Q > > , psi : Zq < Q > ) -> Vec < Zq < Q > > {
137- assert ! (
138- D . is_power_of_two( ) ,
139- "d should be power of two to split completely: d={D}"
140- ) ;
141- assert ! ( ( Q - 1 ) . is_multiple_of( 2 * D ) ) ;
155+ let psi = find_primitive_2d_root_of_unity :: < Q > ( D as u64 ) ;
142156
143157 // since we need to run iNTT on the original order of the output from NTT
144158 _bit_reverse_permutation ( & mut evals) ;
145159
146- _intt :: < Q , D > ( evals, psi, D )
160+ _intt :: < Q , D > ( evals, psi, D as u64 )
147161}
148162
149- pub fn _intt < const Q : u64 , const D : u64 > (
150- evals : Vec < Zq < Q > > ,
151- psi : Zq < Q > ,
152- zeta_exp : u64 ,
153- ) -> Vec < Zq < Q > > {
163+ /// Inverse NTT: recover polynomials Z_q[X]/(X^d+1) from irreducible polynomials.
164+ fn _intt < const Q : u64 , const D : usize > ( evals : Vec < Zq < Q > > , psi : Zq < Q > , zeta_exp : u64 ) -> Vec < Zq < Q > > {
154165 // return coefficient form
155166 let d = evals. len ( ) ;
156167 assert ! ( ( Q - 1 ) . is_multiple_of( 2 * d as u64 ) ) ;
@@ -162,12 +173,21 @@ pub fn _intt<const Q: u64, const D: u64>(
162173 }
163174 let ( evals_l, evals_r) = evals. split_at ( d / 2 ) ;
164175
165- let a_l = _intt :: < Q , D > ( evals_l. to_vec ( ) , psi, zeta_exp / 2 ) ;
166- let a_r = _intt :: < Q , D > ( evals_r. to_vec ( ) , psi, zeta_exp / 2 + D ) ;
167-
168176 // Inverse butterfly: recover a[i] and a[i+d/2] from a_l[i] and a_r[i]
169- let mut a: Vec < Zq < Q > > = vec ! [ Zq :: <Q >:: zero( ) ; d] ;
177+ // It's just the inverse of NTT butterfly. Observing the first term of a_l(x) and a_r(x)
178+ // - a_l0 = a_0 + \zeta^{d/2} a_{128}
179+ // - a_r0 = a_0 - \zeta^{d/2} a_{128}
180+ // Adding them we get a_0 = 2^{-1} * (a_l0 + a_r0)
181+ // Subtracting them we get a_{128} = 2^{-1} * (a_l0 - a_r0) * \zeta^{-128}
182+ // So we recover a_i and a_{i+d/2} from a_li and a_ri with 2^{-1} and \zeta^{-d/2}
183+
184+ // We use the same approach to calculate \zeta^{128}
170185 let root = psi. pow ( zeta_exp / 2 ) ;
186+ // Recursively prepare a_l and a_r
187+ let a_l = _intt :: < Q , D > ( evals_l. to_vec ( ) , psi, zeta_exp / 2 ) ;
188+ let a_r = _intt :: < Q , D > ( evals_r. to_vec ( ) , psi, zeta_exp / 2 + D as u64 ) ;
189+ // Actual inverse butterfly as described above
190+ let mut a: Vec < Zq < Q > > = vec ! [ Zq :: <Q >:: zero( ) ; d] ;
171191 let two_inv = Zq :: new ( 2 ) . inv ( ) ;
172192 for i in 0 ..( d / 2 ) {
173193 a[ i] = two_inv * ( a_l[ i] + a_r[ i] ) ;
@@ -182,20 +202,20 @@ mod tests {
182202 use super :: * ;
183203
184204 const Q : u64 = 17 ;
185- const D : u64 = 4 ;
205+ const D : usize = 4 ;
186206 type F = Zq < Q > ;
187207
188208 fn setup ( ) -> Zq < Q > {
189- let psi = find_primitive_2d_root_of_unity :: < Q > ( D ) ;
209+ let psi = find_primitive_2d_root_of_unity :: < Q > ( D as u64 ) ;
190210 println ! ( "psi={:?}" , psi) ;
191211 psi
192212 }
193213
/// The root returned by `setup` must satisfy psi^{2d} = 1 and psi^d = -1.
#[test]
fn test_primitive_2d_root_of_unity() {
    let psi = setup();
    let d = D as u64;
    // w^{2d} = 1
    assert_eq!(psi.pow(2 * d), F::one());
    // w^d = -1
    assert_eq!(psi.pow(d), -F::one());
}
200220
201221 // Sage test vectors: q=17, d=4, negacyclic NTT (X^d+1)
@@ -218,36 +238,33 @@ mod tests {
218238 evals
219239 } ;
220240 let evals = get_evals ( ) ;
221- assert_eq ! ( ntt:: <Q , D >( coeffs, psi ) , evals) ;
241+ assert_eq ! ( ntt:: <Q , D >( coeffs) , evals) ;
222242 }
223243
/// Inverse transform of a fixed evaluation vector must give the known coefficients.
#[test]
fn test_intt_backward() {
    let evals_raw: [u64; 4] = [14, 0, 10, 16];
    let coeffs_raw: [u64; 4] = [10, 4, 8, 0];

    let evals: Vec<F> = evals_raw.iter().map(|&e| F::new(e)).collect();
    let expected_coeffs: Vec<F> = coeffs_raw.iter().map(|&c| F::new(c)).collect();

    let recovered = intt::<Q, D>(evals);
    assert_eq!(recovered, expected_coeffs);
}
232251
/// Running the forward transform and then the inverse must return the input.
#[test]
fn test_ntt_intt_roundtrip() {
    type F = Zq<Q>;
    let raw: [u64; 4] = [16, 3, 0, 14];
    let original: Vec<F> = raw.iter().map(|&c| F::new(c)).collect();

    // `ntt` consumes its argument, so give it a copy and keep the original to compare.
    let evals = ntt::<Q, D>(original.clone());
    let recovered = intt::<Q, D>(evals);
    assert_eq!(recovered, original);
}
241259
242260 // ─── q=12289, d=1024 (Falcon params) ───
243261
244262 const Q2 : u64 = 12289 ;
245- const D2 : u64 = 1024 ;
263+ const D2 : usize = 1024 ;
246264 type F2 = Zq < Q2 > ;
247265
248266 #[ test]
249267 fn test_ntt_falcon ( ) {
250- let psi = find_primitive_2d_root_of_unity :: < Q2 > ( D2 ) ;
251268 let coeffs_raw: [ u64 ; 1024 ] = [
252269 8633 , 1504 , 11298 , 8147 , 6951 , 5539 , 3291 , 334 , 7732 , 376 , 3099 , 4879 , 9978 , 7512 ,
253270 3274 , 6114 , 4942 , 8255 , 8730 , 758 , 1334 , 5361 , 3507 , 10969 , 5079 , 9882 , 6516 , 4586 ,
@@ -402,9 +419,9 @@ mod tests {
402419 let coeffs: Vec < F2 > = coeffs_raw. iter ( ) . map ( |& c| F2 :: new ( c) ) . collect ( ) ;
403420 let expected_evals: Vec < F2 > = evals_raw. iter ( ) . map ( |& e| F2 :: new ( e) ) . collect ( ) ;
404421
405- let actual_evals = ntt :: < Q2 , D2 > ( coeffs. clone ( ) , psi ) ;
422+ let actual_evals = ntt :: < Q2 , D2 > ( coeffs. clone ( ) ) ;
406423 assert_eq ! ( actual_evals, expected_evals) ;
407- let coeffs_roundtrip = intt :: < Q2 , D2 > ( actual_evals, psi ) ;
424+ let coeffs_roundtrip = intt :: < Q2 , D2 > ( actual_evals) ;
408425 assert_eq ! ( coeffs, coeffs_roundtrip) ;
409426 }
410427}
0 commit comments