Skip to content

Commit

Permalink
[eval] add eval.ntt*
Browse files Browse the repository at this point in the history
Co-Authored-By: Lucas-Wye <[email protected]>
  • Loading branch information
2 people authored and sequencer committed Jan 26, 2025
1 parent 94cbbc8 commit 804da11
Show file tree
Hide file tree
Showing 13 changed files with 1,281 additions and 1 deletion.
44 changes: 44 additions & 0 deletions tests/eval/_ntt/default.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
{ linkerScript
, makeBuilder
, t1main
}:

let
builder = makeBuilder { casePrefix = "eval"; };
build_ntt = caseName /* must be consistent with attr name */ : main_src: kernel_src:
builder {
caseName = caseName;

src = ./.;

passthru.featuresRequired = { };

buildPhase = ''
runHook preBuild
$CC -T${linkerScript} \
${main_src} ${kernel_src} \
${t1main} \
-o $pname.elf
runHook postBuild
'';

meta.description = "test case 'ntt'";
};

in {
ntt_64 = build_ntt "ntt_64" ./ntt.c ./ntt_64_main.c;
ntt_128 = build_ntt "ntt_128" ./ntt.c ./ntt_128_main.c;
ntt_256 = build_ntt "ntt_256" ./ntt.c ./ntt_256_main.c;
ntt_512 = build_ntt "ntt_512" ./ntt.c ./ntt_512_main.c;
ntt_1024 = build_ntt "ntt_1024" ./ntt.c ./ntt_1024_main.c;
ntt_4096 = build_ntt "ntt_4096" ./ntt.c ./ntt_4096_main.c;

ntt_mem_64 = build_ntt "ntt_mem_64" ./ntt_mem.c ./ntt_64_main.c;
ntt_mem_128 = build_ntt "ntt_mem_128" ./ntt_mem.c ./ntt_128_main.c;
ntt_mem_256 = build_ntt "ntt_mem_256" ./ntt_mem.c ./ntt_256_main.c;
ntt_mem_512 = build_ntt "ntt_mem_512" ./ntt_mem.c ./ntt_512_main.c;
ntt_mem_1024 = build_ntt "ntt_mem_1024" ./ntt_mem.c ./ntt_1024_main.c;
ntt_mem_4096 = build_ntt "ntt_mem_4096" ./ntt_mem.c ./ntt_4096_main.c;
}
24 changes: 24 additions & 0 deletions tests/eval/_ntt/gen_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import random

def main():
vlen = 4096
l = 12
n = 1 << l
# assert n <= vlen // 4
p = 12289 # p is prime and n | p - 1
g = 11 # primitive root of p
assert (p - 1) % n == 0
w = (g ** ((p - 1) // n)) % p # now w^n == 1 mod p by Fermat's little theorem
print(w)

twindle_list = []
for _ in range(l):
twindle_list.append(w)
w = (w * w) % p
print(twindle_list)

a = [random.randrange(p) for _ in range(n)]
print(a)

if __name__ == '__main__':
main()
53 changes: 53 additions & 0 deletions tests/eval/_ntt/gen_vector_ntt_tw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
def gen_tw_for_vector_ntt(l, w_one, prime_p):
n = pow(2, l)
w_power_list = []
m = 2
while m <= n:
w_power = 0
w = 1
w_power_dict = {}
for j in range(m // 2):
k = 0
while k < n:
i_u = k + j
i_t = k + j + m //2
k += m
w_power_dict[i_u] = (i_t, w_power)
w_power += n//m
m = 2 * m
w_power_list.append(w_power_dict)

# print(w_power_list)
perm_each = { }
for i in range(n//2):
perm_each[i] = i
perm_each[i+n//2] = i + n//2
# print("(coe 0, 1), w_power, (permu 0, 1)\n")
print(f"\nfor ntt {n}")
layer_index = 0
for w_power_dict in w_power_list:
print(f"// layer #{layer_index}")
layer_index += 1

# sort_keys = sorted(w_power_dict.keys())
sort_keys = w_power_dict.keys()
index = 0
for w_key in sort_keys:
# print(f"({w_key}, {w_power_dict[w_key][0]}), {w_power_dict[w_key][1]}, ", end = "")
# print(f"({perm_each[w_key]}, {perm_each[w_power_dict[w_key][0]]})")
current_w = pow(w_one, w_power_dict[w_key][1], prime_p)
print(current_w, end = ", ")
perm_each[w_key] = index
perm_each[w_power_dict[w_key][0]] = index + n//2
index += 1

print("\n")

if __name__ == '__main__':
gen_tw_for_vector_ntt(6, 7311, 12289)
gen_tw_for_vector_ntt(7, 12149, 12289)
gen_tw_for_vector_ntt(8, 8340, 12289)
gen_tw_for_vector_ntt(9, 3400, 12289)
gen_tw_for_vector_ntt(10, 10302, 12289)
gen_tw_for_vector_ntt(12, 1331, 12289)

135 changes: 135 additions & 0 deletions tests/eval/_ntt/ntt.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#include <assert.h>
#include <stdio.h>

// #define USERN 32
// #define DEBUG

// array is of length n=2^l, p is a prime number
// roots is of length l, where g = roots[0] satisfies that
// g^(2^l) == 1 mod p and g^(2^(l-1)) == -1 mod p
// roots[i] = g^(2^i) (hence roots[l - 1] = -1)
//
// 32bit * n <= VLEN * 8 => n <= VLEN / 4
void ntt(const int *array, int l, const int *twiddle, int p, int *dst) {
// prepare an array of permutation indices
assert(l <= 16);

int n = 1 << l;

// registers:
// v8-15: array
// v16-24: loaded elements (until vrgather)
// v4-7: permutation index (until vrgather)
// v16-24: coefficients
int vlenb;
asm("csrr %0, vlenb" : "=r"(vlenb));
int elements_in_vreg = vlenb * 2;
assert(elements_in_vreg >= n);

asm("vsetvli zero, %0, e16, m4, tu, mu\n"
"vid.v v4\n"
:
: "r"(n));

// prepare the bit-reversal permutation list
for (int k = 0; 2 * k < l; k++) {
asm("vand.vx v8, v4, %0\n"
"vsub.vv v4, v4, v8\n"
"vsll.vx v8, v8, %1\n" // get the k-th digit and shift left

"vand.vx v12, v4, %2\n"
"vsub.vv v4, v4, v12\n"
"vsrl.vx v12, v12, %1\n" // get the (l-k-1)-th digit and shift right

"vor.vv v4, v4, v8\n"
"vor.vv v4, v4, v12\n"

:
: "r"(1 << k), "r"(l - 1 - 2 * k), "r"(1 << (l - k - 1)));
}

// perform bit-reversal for input coefficients
asm("vsetvli zero, %0, e32, m8, tu, mu\n"
"vle32.v v16, 0(%1)\n"
"vrgatherei16.vv v8, v16, v4\n"
"vse32.v v8, 0(%2)\n"

:
: "r"(n), "r"(array), "r"(dst));

// generate permutation list (0, 2, 4, ..., 1, 3, 5, ...)
asm("vsetvli zero, %0, e16, m4, tu, mu\n"
"vid.v v4\n"
"vsrl.vx v0, v4, %1\n" // (0, 0, 0, 0, ..., 1, 1, 1, 1, ...)
"vand.vx v4, v4, %2\n" // (0, 1, 2, 3, ..., 0, 1, 2, 3, ...)
"vsll.vi v4, v4, 1\n"
"vadd.vv v4, v4, v0\n"

:
: "r"(n), "r"(l-1), "r"((n / 2 - 1)), "r"(n / 2));

#ifdef DEBUG
int tmp1[USERN];// c
int tmp2[USERN];// c
int tmp3[USERN];// c
#endif

for (int k = 0; k < l; k++) {
asm(
// "n" mode
"vsetvli zero, %0, e32, m8, tu, mu\n"
// load coefficients
"vle32.v v16, 0(%4)\n"
// perform permutation for coefficient
"vrgatherei16.vv v8, v16, v4\n"
// save coefficients
"vse32.v v8, 0(%4)\n"

// "n/2" mode
"vsetvli zero, %1, e32, m4, tu, mu\n"
// load twiddle factors
"vle32.v v16, 0(%2)\n"
// load half coefficients
"vle32.v v8, 0(%4)\n"
"vle32.v v12, 0(%5)\n"

#ifdef DEBUG
"vse32.v v8, 0(%6)\n"// c
"vse32.v v12, 0(%7)\n"// c
"vse32.v v16, 0(%8)\n"// c
#endif

// butterfly operation
"vmul.vv v12, v12, v16\n"
"vrem.vx v12, v12, %3\n"
"vadd.vv v16, v8, v12\n" // TODO: will it overflow?
"vsub.vv v20, v8, v12\n"
// save half coefficients
"vse32.v v16, 0(%4)\n"
"vse32.v v20, 0(%5)\n"
:
: /* %0 */ "r"(n),
/* %1 */ "r"(n / 2),
/* %2 */ "r"(twiddle + k * (n / 2)),
/* %3 */ "r"(p),
"r"(dst),
"r"(dst + (n / 2))
#ifdef DEBUG
, "r"(tmp1), "r"(tmp2), "r"(tmp3)
#endif
);
#ifdef DEBUG
for(int k = 0; k < USERN; k++) {
printf("(%x, %x, %x)\n", tmp1[k], tmp2[k], tmp3[k]);
}
#endif
}
// deal with modular
asm("vsetvli zero, %0, e32, m8, tu, mu\n"
"vle32.v v16, 0(%1)\n"
"vrem.vx v8, v16, %2\n"
"vse32.v v8, 0(%1)\n"

:
: "r"(n), "r"(dst), "r"(p));
}
Loading

0 comments on commit 804da11

Please sign in to comment.