Skip to content

Commit

Permalink
tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
Lucas-Wye committed Feb 7, 2025
1 parent 9ecdf6f commit 6722f89
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 14 deletions.
34 changes: 20 additions & 14 deletions tests/eval/_ntt/default.nix
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
{ linkerScript
, makeBuilder
, python3
, t1main
}:

let
builder = makeBuilder { casePrefix = "eval"; };
build_ntt = caseName /* must be consistent with attr name */ : main_src: kernel_src:
build_ntt = caseName /* must be consistent with attr name */ : main_src: kernel_src: extra_flag:
builder {
caseName = caseName;

Expand All @@ -16,7 +17,11 @@ let
buildPhase = ''
runHook preBuild
${python3}/bin/python3 ./gen_data.py ${caseName}
$CC -T${linkerScript} \
-DCASE=${caseName} \
${extra_flag} \
${main_src} ${kernel_src} \
${t1main} \
-o $pname.elf
Expand All @@ -28,17 +33,18 @@ let
};

in {
ntt_64 = build_ntt "ntt_64" ./ntt.c ./ntt_64_main.c;
ntt_128 = build_ntt "ntt_128" ./ntt.c ./ntt_128_main.c;
ntt_256 = build_ntt "ntt_256" ./ntt.c ./ntt_256_main.c;
ntt_512 = build_ntt "ntt_512" ./ntt.c ./ntt_512_main.c;
ntt_1024 = build_ntt "ntt_1024" ./ntt.c ./ntt_1024_main.c;
ntt_4096 = build_ntt "ntt_4096" ./ntt.c ./ntt_4096_main.c;

ntt_mem_64 = build_ntt "ntt_mem_64" ./ntt_mem.c ./ntt_64_main.c;
ntt_mem_128 = build_ntt "ntt_mem_128" ./ntt_mem.c ./ntt_128_main.c;
ntt_mem_256 = build_ntt "ntt_mem_256" ./ntt_mem.c ./ntt_256_main.c;
ntt_mem_512 = build_ntt "ntt_mem_512" ./ntt_mem.c ./ntt_512_main.c;
ntt_mem_1024 = build_ntt "ntt_mem_1024" ./ntt_mem.c ./ntt_1024_main.c;
ntt_mem_4096 = build_ntt "ntt_mem_4096" ./ntt_mem.c ./ntt_4096_main.c;
ntt_test = build_ntt "ntt_64" ./ntt.c ./ntt_main.c "";
ntt_64 = build_ntt "ntt_64" ./ntt.c ./ntt_main.c "";
ntt_128 = build_ntt "ntt_128" ./ntt.c ./ntt_main.c "";
ntt_256 = build_ntt "ntt_256" ./ntt.c ./ntt_main.c "";
ntt_512 = build_ntt "ntt_512" ./ntt.c ./ntt_main.c "";
ntt_1024 = build_ntt "ntt_1024" ./ntt.c ./ntt_main.c "";
ntt_4096 = build_ntt "ntt_4096" ./ntt.c ./ntt_main.c "";

ntt_mem_64 = build_ntt "ntt_mem_64" ./ntt_mem.c ./ntt_main.c "-DUSE_SCALAR";
ntt_mem_128 = build_ntt "ntt_mem_128" ./ntt_mem.c ./ntt_main.c "-DUSE_SCALAR";
ntt_mem_256 = build_ntt "ntt_mem_256" ./ntt_mem.c ./ntt_main.c "-DUSE_SCALAR";
ntt_mem_512 = build_ntt "ntt_mem_512" ./ntt_mem.c ./ntt_main.c "-DUSE_SCALAR";
ntt_mem_1024 = build_ntt "ntt_mem_1024" ./ntt_mem.c ./ntt_main.c "-DUSE_SCALAR";
ntt_mem_4096 = build_ntt "ntt_mem_4096" ./ntt_mem.c ./ntt_main.c "-DUSE_SCALAR";
}
93 changes: 93 additions & 0 deletions tests/eval/_ntt/gen_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import json
import random

def genRandomPoly(l, p):
n = 1 << l
a = [random.randrange(p) for _ in range(n)]
return a

def genGoldPoly(l, p, g, poly):
n = 1 << l
poly_out = []
for i in range(n):
tmp = 0
for j in range(n):
tmp += poly[j] * pow(g, i * j, p)
tmp = tmp % p
poly_out.append(tmp)
return poly_out

def genScalarTW(l, p, g):
w = g

twiddle_list = []
for _ in range(l):
twiddle_list.append(w)
w = (w * w) % p

return twiddle_list

def genVectorTW(l, p, g):
n = 1 << l
m = 2
layerIndex = 0

outTW = []
while m <= n:
# print(f"// layer #{layerIndex}")
layerIndex+=1
wPower = 0

for j in range(m//2):
k = 0
while k < n:
currentW = pow(g, wPower, p)
k += m
outTW.append(currentW)
# print(currentW, end =", ")
wPower += n // m
m *= 2
# print("\n")
return outTW

def main(l, p, g):
poly_in = genRandomPoly(l, p)
poly_out = genGoldPoly(l, p, g, poly_in)
scalar_tw = genScalarTW(l, p, g)
vector_tw = genVectorTW(l, p, g)
n = 1 << l
data = {
"l": l,
"n": n,
"p": p,
"input": poly_in,
"output": poly_out,
"vector_tw": vector_tw,
"scalar_tw": scalar_tw,
}
# json_name = "ntt_" + str(n) + ".json"
# with open(json_name, "w") as out:
# json.dump(data, out)
header_file = "ntt_" + str(n) + ".h"

# with open(json_name, "r") as json_in:
# data = json.load(json_in)
header_str = "#define macroL " + str(data["l"]) + "\n"
header_str += "#define macroN " + str(data["n"]) + "\n"
header_str += "#define macroP " + str(data["p"]) + "\n"
header_str += "#define macroIn " + ','.join(str(e) for e in data["input"]) + "\n"
header_str += "#define macroOut " + ','.join(str(e) for e in data["output"]) + "\n"
header_str += "#define macroScalarTW " + ','.join(str(e) for e in data["scalar_tw"]) + "\n"
header_str += "#define macroVectorTW " + ','.join(str(e) for e in data["vector_tw"]) + "\n"
with open(header_file, "w") as header_out:
header_out.write(header_str)

if __name__ == '__main__':
p = 12289
main(6, p, 7311)
main(7, p, 12149)
main(8, p, 8340)
main(9, p, 3400)
main(10, p, 10302)
main(12, p, 1331)

71 changes: 71 additions & 0 deletions tests/eval/_ntt/ntt_main.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#include <stdio.h>
#include <assert.h>

// #define USE_SCALAR
#define DEBUG
// #define WITHMAIN

#if(CASE==ntt_64)
#include "ntt_64.h"
#elif(CASE==ntt_128)
#include "ntt_128.h"
#elif(CASE==ntt_256)
#include "ntt_256.h"
#elif(CASE==ntt_512)
#include "ntt_512.h"
#elif(CASE==ntt_1024)
#include "ntt_1024.h"
#elif(CASE==ntt_4096)
#include "ntt_4096.h"
#else
#error "undefined ntt case"
#endif


void ntt(const int *array, int l, const int *twiddle, int p, int *dst);

void test() {
const int l = macroL;
const int n = macroN;
static const int arr[macroN] = {
macroIn
};
#ifdef USE_SCALAR
static const int twiddle[] = {
macroScalarTW
};
#else
static const int twiddle[] = {
macroVectorTW
};
#endif
const int p = macroP;
int dst[macroN];
ntt(arr, l, twiddle, p, dst);

#ifdef DEBUG
const int gold[macroN] = {
macroOut
};
printf("n = %d\n", n);
for (int i = 0; i < n; i++) {
// dst[i] = dst[i] % p;
if(dst[i] < 0) dst[i] += p;
if(dst[i] != gold[i]) {
printf("(%d %d, i)", dst[i], gold[i], i);
if ((i + 1) % 8 == 0) {
printf("\n");
} else {
printf(" ");
}
}
}
#endif
}

#ifdef WITHMAIN
int main(void) {
test();
}
#endif

0 comments on commit 6722f89

Please sign in to comment.