diff --git a/tests/eval/_ntt/default.nix b/tests/eval/_ntt/default.nix index 77a1a7a0b..1a525732e 100644 --- a/tests/eval/_ntt/default.nix +++ b/tests/eval/_ntt/default.nix @@ -1,11 +1,12 @@ { linkerScript , makeBuilder +, python3 , t1main }: let builder = makeBuilder { casePrefix = "eval"; }; - build_ntt = caseName /* must be consistent with attr name */ : main_src: kernel_src: + build_ntt = caseName /* must be consistent with attr name */ : main_src: kernel_src: extra_flag: builder { caseName = caseName; @@ -16,7 +17,11 @@ let buildPhase = '' runHook preBuild + ${python3}/bin/python3 ./gen_data.py ${caseName} + $CC -T${linkerScript} \ + -DCASE=${caseName} \ + ${extra_flag} \ ${main_src} ${kernel_src} \ ${t1main} \ -o $pname.elf @@ -28,17 +33,18 @@ let }; in { - ntt_64 = build_ntt "ntt_64" ./ntt.c ./ntt_64_main.c; - ntt_128 = build_ntt "ntt_128" ./ntt.c ./ntt_128_main.c; - ntt_256 = build_ntt "ntt_256" ./ntt.c ./ntt_256_main.c; - ntt_512 = build_ntt "ntt_512" ./ntt.c ./ntt_512_main.c; - ntt_1024 = build_ntt "ntt_1024" ./ntt.c ./ntt_1024_main.c; - ntt_4096 = build_ntt "ntt_4096" ./ntt.c ./ntt_4096_main.c; - - ntt_mem_64 = build_ntt "ntt_mem_64" ./ntt_mem.c ./ntt_64_main.c; - ntt_mem_128 = build_ntt "ntt_mem_128" ./ntt_mem.c ./ntt_128_main.c; - ntt_mem_256 = build_ntt "ntt_mem_256" ./ntt_mem.c ./ntt_256_main.c; - ntt_mem_512 = build_ntt "ntt_mem_512" ./ntt_mem.c ./ntt_512_main.c; - ntt_mem_1024 = build_ntt "ntt_mem_1024" ./ntt_mem.c ./ntt_1024_main.c; - ntt_mem_4096 = build_ntt "ntt_mem_4096" ./ntt_mem.c ./ntt_4096_main.c; + ntt_test = build_ntt "ntt_64" ./ntt.c ./ntt_main.c ""; + ntt_64 = build_ntt "ntt_64" ./ntt.c ./ntt_main.c ""; + ntt_128 = build_ntt "ntt_128" ./ntt.c ./ntt_main.c ""; + ntt_256 = build_ntt "ntt_256" ./ntt.c ./ntt_main.c ""; + ntt_512 = build_ntt "ntt_512" ./ntt.c ./ntt_main.c ""; + ntt_1024 = build_ntt "ntt_1024" ./ntt.c ./ntt_main.c ""; + ntt_4096 = build_ntt "ntt_4096" ./ntt.c ./ntt_main.c ""; + + ntt_mem_64 = build_ntt "ntt_mem_64" ./ntt_mem.c ./ntt_main.c "-DUSE_SCALAR"; + ntt_mem_128 = build_ntt "ntt_mem_128" ./ntt_mem.c ./ntt_main.c "-DUSE_SCALAR"; + ntt_mem_256 = build_ntt "ntt_mem_256" ./ntt_mem.c ./ntt_main.c "-DUSE_SCALAR"; + ntt_mem_512 = build_ntt "ntt_mem_512" ./ntt_mem.c ./ntt_main.c "-DUSE_SCALAR"; + ntt_mem_1024 = build_ntt "ntt_mem_1024" ./ntt_mem.c ./ntt_main.c "-DUSE_SCALAR"; + ntt_mem_4096 = build_ntt "ntt_mem_4096" ./ntt_mem.c ./ntt_main.c "-DUSE_SCALAR"; } diff --git a/tests/eval/_ntt/gen_data.py b/tests/eval/_ntt/gen_data.py new file mode 100644 index 000000000..220ad87b8 --- /dev/null +++ b/tests/eval/_ntt/gen_data.py @@ -0,0 +1,93 @@ +import json +import random + +def genRandomPoly(l, p): + n = 1 << l + a = [random.randrange(p) for _ in range(n)] + return a + +def genGoldPoly(l, p, g, poly): + n = 1 << l + poly_out = [] + for i in range(n): + tmp = 0 + for j in range(n): + tmp += poly[j] * pow(g, i * j, p) + tmp = tmp % p + poly_out.append(tmp) + return poly_out + +def genScalarTW(l, p, g): + w = g + + twiddle_list = [] + for _ in range(l): + twiddle_list.append(w) + w = (w * w) % p + + return twiddle_list + +def genVectorTW(l, p, g): + n = 1 << l + m = 2 + layerIndex = 0 + + outTW = [] + while m <= n: + # print(f"// layer #{layerIndex}") + layerIndex+=1 + wPower = 0 + + for j in range(m//2): + k = 0 + while k < n: + currentW = pow(g, wPower, p) + k += m + outTW.append(currentW) + # print(currentW, end =", ") + wPower += n // m + m *= 2 + # print("\n") + return outTW + +def main(l, p, g): + poly_in = genRandomPoly(l, p) + poly_out = genGoldPoly(l, p, g, poly_in) + scalar_tw = genScalarTW(l, p, g) + vector_tw = genVectorTW(l, p, g) + n = 1 << l + data = { + "l": l, + "n": n, + "p": p, + "input": poly_in, + "output": poly_out, + "vector_tw": vector_tw, + "scalar_tw": scalar_tw, + } + # json_name = "ntt_" + str(n) + ".json" + # with open(json_name, "w") as out: + # json.dump(data, out) + header_file = "ntt_" + str(n) + ".h" + + # with open(json_name, "r") as json_in: + # data = json.load(json_in) + header_str = "#define macroL " + str(data["l"]) + "\n" + header_str += "#define macroN " + str(data["n"]) + "\n" + header_str += "#define macroP " + str(data["p"]) + "\n" + header_str += "#define macroIn " + ','.join(str(e) for e in data["input"]) + "\n" + header_str += "#define macroOut " + ','.join(str(e) for e in data["output"]) + "\n" + header_str += "#define macroScalarTW " + ','.join(str(e) for e in data["scalar_tw"]) + "\n" + header_str += "#define macroVectorTW " + ','.join(str(e) for e in data["vector_tw"]) + "\n" + with open(header_file, "w") as header_out: + header_out.write(header_str) + +if __name__ == '__main__': + p = 12289 + main(6, p, 7311) + main(7, p, 12149) + main(8, p, 8340) + main(9, p, 3400) + main(10, p, 10302) + main(12, p, 1331) + diff --git a/tests/eval/_ntt/ntt_main.c b/tests/eval/_ntt/ntt_main.c new file mode 100644 index 000000000..45f37f1e9 --- /dev/null +++ b/tests/eval/_ntt/ntt_main.c @@ -0,0 +1,71 @@ +#include +#include + +// #define USE_SCALAR +#define DEBUG +// #define WITHMAIN + +#if(CASE==ntt_64) +#include "ntt_64.h" +#elif(CASE==ntt_128) +#include "ntt_128.h" +#elif(CASE==ntt_256) +#include "ntt_256.h" +#elif(CASE==ntt_512) +#include "ntt_512.h" +#elif(CASE==ntt_1024) +#include "ntt_1024.h" +#elif(CASE==ntt_4096) +#include "ntt_4096.h" +#else +#error "undefined ntt case" +#endif + + +void ntt(const int *array, int l, const int *twiddle, int p, int *dst); + +void test() { + const int l = macroL; + const int n = macroN; + static const int arr[macroN] = { + macroIn + }; +#ifdef USE_SCALAR + static const int twiddle[] = { + macroScalarTW + }; +#else + static const int twiddle[] = { + macroVectorTW + }; +#endif + const int p = macroP; + int dst[macroN]; + ntt(arr, l, twiddle, p, dst); + +#ifdef DEBUG + const int gold[macroN] = { + macroOut + }; + printf("n = %d\n", n); + for (int i = 0; i < n; i++) { + // dst[i] = dst[i] % p; + if(dst[i] < 0) dst[i] += p; + if(dst[i] != gold[i]) { + printf("(%d %d, i)", dst[i], gold[i], i); + if ((i + 1) % 8 == 0) { + printf("\n"); + } else { + printf(" "); + } + } + } +#endif +} + +#ifdef WITHMAIN +int main(void) { + test(); +} +#endif +