diff --git a/ecc/ecc.go b/ecc/ecc.go index 4ce9bd2d10..70925ca1b1 100644 --- a/ecc/ecc.go +++ b/ecc/ecc.go @@ -36,11 +36,12 @@ const ( BW6_633 STARK_CURVE SECP256K1 + GRUMPKIN ) // Implemented return the list of curves fully implemented in gnark-crypto func Implemented() []ID { - return []ID{BN254, BLS12_377, BLS12_381, BW6_761, BLS24_315, BW6_633, BLS24_317, STARK_CURVE, SECP256K1} + return []ID{BN254, BLS12_377, BLS12_381, BW6_761, BLS24_315, BW6_633, BLS24_317, STARK_CURVE, SECP256K1, GRUMPKIN} } func IDFromString(s string) (ID, error) { @@ -91,6 +92,8 @@ func (id ID) config() *config.Curve { return &config.STARK_CURVE case SECP256K1: return &config.SECP256K1 + case GRUMPKIN: + return &config.GRUMPKIN default: panic("unimplemented ecc ID") } diff --git a/ecc/ecc.md b/ecc/ecc.md index ef795dd006..b0be48033a 100644 --- a/ecc/ecc.md +++ b/ecc/ecc.md @@ -2,6 +2,7 @@ * BLS12-381 (Zcash) * BN254 (Ethereum) +* GRUMPKIN (2-cycle with BN254) * BLS12-377 (ZEXE) * BW6-761 (2-chain with BLS12-377) * BLS24-315 diff --git a/ecc/grumpkin/ecdsa/doc.go b/ecc/grumpkin/ecdsa/doc.go new file mode 100644 index 0000000000..94971c427d --- /dev/null +++ b/ecc/grumpkin/ecdsa/doc.go @@ -0,0 +1,17 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// Package ecdsa provides ECDSA signature scheme on the grumpkin curve. +// +// The implementation is adapted from https://pkg.go.dev/crypto/ecdsa. +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+// +// Documentation: +// - Wikipedia: https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm +// - FIPS 186-4: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-4.pdf +// - SEC 1, v-2: https://www.secg.org/sec1-v2.pdf +package ecdsa diff --git a/ecc/grumpkin/ecdsa/ecdsa.go b/ecc/grumpkin/ecdsa/ecdsa.go new file mode 100644 index 0000000000..3911c0ad9c --- /dev/null +++ b/ecc/grumpkin/ecdsa/ecdsa.go @@ -0,0 +1,294 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package ecdsa + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha512" + "crypto/subtle" + "hash" + "io" + "math/big" + + "github.com/consensys/gnark-crypto/ecc/grumpkin" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fp" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" + "github.com/consensys/gnark-crypto/signature" +) + +const ( + sizeFr = fr.Bytes + sizeFrBits = fr.Bits + sizeFp = fp.Bytes + sizePublicKey = sizeFp + sizePrivateKey = sizeFr + sizePublicKey + sizeSignature = 2 * sizeFr +) + +var order = fr.Modulus() + +// PublicKey represents an ECDSA public key +type PublicKey struct { + A grumpkin.G1Affine +} + +// PrivateKey represents an ECDSA private key +type PrivateKey struct { + PublicKey PublicKey + scalar [sizeFr]byte // secret scalar, in big Endian +} + +// Signature represents an ECDSA signature +type Signature struct { + R, S [sizeFr]byte +} + +var one = new(big.Int).SetInt64(1) + +// randFieldElement returns a random element of the order of the given +// curve using the procedure given in FIPS 186-4, Appendix B.5.1. 
+func randFieldElement(rand io.Reader) (k *big.Int, err error) { + b := make([]byte, fr.Bits/8+8) + _, err = io.ReadFull(rand, b) + if err != nil { + return + } + + k = new(big.Int).SetBytes(b) + n := new(big.Int).Sub(order, one) + k.Mod(k, n) + k.Add(k, one) + return +} + +// GenerateKey generates a public and private key pair. +func GenerateKey(rand io.Reader) (*PrivateKey, error) { + + k, err := randFieldElement(rand) + if err != nil { + return nil, err + + } + _, g := grumpkin.Generators() + + privateKey := new(PrivateKey) + k.FillBytes(privateKey.scalar[:sizeFr]) + privateKey.PublicKey.A.ScalarMultiplication(&g, k) + return privateKey, nil +} + +// HashToInt converts a hash value to an integer. Per FIPS 186-4, Section 6.4, +// we use the left-most bits of the hash to match the bit-length of the order of +// the curve. This also performs Step 5 of SEC 1, Version 2.0, Section 4.1.3. +func HashToInt(hash []byte) *big.Int { + if len(hash) > sizeFr { + hash = hash[:sizeFr] + } + ret := new(big.Int).SetBytes(hash) + excess := ret.BitLen() - sizeFrBits + if excess > 0 { + ret.Rsh(ret, uint(excess)) + } + return ret +} + +type zr struct{} + +// Read replaces the contents of dst with zeros. It is safe for concurrent use. +func (zr) Read(dst []byte) (n int, err error) { + for i := range dst { + dst[i] = 0 + } + return len(dst), nil +} + +var zeroReader = zr{} + +const ( + aesIV = "gnark-crypto IV." // must be 16 chars (equal block size) +) + +func nonce(privateKey *PrivateKey, hash []byte) (csprng *cipher.StreamReader, err error) { + // This implementation derives the nonce from an AES-CTR CSPRNG keyed by: + // + // SHA2-512(privateKey.scalar ∥ entropy ∥ hash)[:32] + // + // The CSPRNG key is indifferentiable from a random oracle as shown in + // [Coron], the AES-CTR stream is indifferentiable from a random oracle + // under standard cryptographic assumptions (see [Larsson] for examples). 
+ // + // [Coron]: https://cs.nyu.edu/~dodis/ps/merkle.pdf + // [Larsson]: https://web.archive.org/web/20040719170906/https://www.nada.kth.se/kurser/kth/2D1441/semteo03/lecturenotes/assump.pdf + + // Get 256 bits of entropy from rand. + entropy := make([]byte, 32) + _, err = io.ReadFull(rand.Reader, entropy) + if err != nil { + return + + } + + // Initialize an SHA-512 hash context; digest... + md := sha512.New() + md.Write(privateKey.scalar[:sizeFr]) // the private key, + md.Write(entropy) // the entropy, + md.Write(hash) // and the input hash; + key := md.Sum(nil)[:32] // and compute ChopMD-256(SHA-512), + // which is an indifferentiable MAC. + + // Create an AES-CTR instance to use as a CSPRNG. + block, _ := aes.NewCipher(key) + + // Create a CSPRNG that xors a stream of zeros with + // the output of the AES-CTR instance. + csprng = &cipher.StreamReader{ + R: zeroReader, + S: cipher.NewCTR(block, []byte(aesIV)), + } + + return csprng, err +} + +// Equal compares 2 public keys +func (pub *PublicKey) Equal(x signature.PublicKey) bool { + xx, ok := x.(*PublicKey) + if !ok { + return false + } + bpk := pub.Bytes() + bxx := xx.Bytes() + return subtle.ConstantTimeCompare(bpk, bxx) == 1 +} + +// Public returns the public key associated to the private key. +func (privKey *PrivateKey) Public() signature.PublicKey { + var pub PublicKey + pub.A.Set(&privKey.PublicKey.A) + return &pub +} + +// Sign performs the ECDSA signature +// +// k ← 𝔽r (random) +// P = k ⋅ g1Gen +// r = x_P (mod order) +// s = k⁻¹ . 
(m + sk ⋅ r) +// signature = {r, s} +// +// SEC 1, Version 2.0, Section 4.1.3 +func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) { + scalar, r, s, kInv := new(big.Int), new(big.Int), new(big.Int), new(big.Int) + scalar.SetBytes(privKey.scalar[:sizeFr]) + for { + for { + csprng, err := nonce(privKey, message) + if err != nil { + return nil, err + } + k, err := randFieldElement(csprng) + if err != nil { + return nil, err + } + + var P grumpkin.G1Affine + P.ScalarMultiplicationBase(k) + kInv.ModInverse(k, order) + + P.X.BigInt(r) + + r.Mod(r, order) + if r.Sign() != 0 { + break + } + } + s.Mul(r, scalar) + + var m *big.Int + if hFunc != nil { + // compute the hash of the message as an integer + dataToHash := make([]byte, len(message)) + copy(dataToHash[:], message[:]) + hFunc.Reset() + _, err := hFunc.Write(dataToHash[:]) + if err != nil { + return nil, err + } + hramBin := hFunc.Sum(nil) + m = HashToInt(hramBin) + } else { + m = HashToInt(message) + } + + s.Add(m, s). + Mul(kInv, s). 
+ Mod(s, order) // order != 0 + if s.Sign() != 0 { + break + } + } + + var sig Signature + r.FillBytes(sig.R[:sizeFr]) + s.FillBytes(sig.S[:sizeFr]) + + return sig.Bytes(), nil +} + +// Verify validates the ECDSA signature +// +// R ?= (s⁻¹ ⋅ m ⋅ Base + s⁻¹ ⋅ R ⋅ publiKey)_x +// +// SEC 1, Version 2.0, Section 4.1.4 +func (publicKey *PublicKey) Verify(sigBin, message []byte, hFunc hash.Hash) (bool, error) { + + // Deserialize the signature + var sig Signature + if _, err := sig.SetBytes(sigBin); err != nil { + return false, err + } + + r, s := new(big.Int), new(big.Int) + r.SetBytes(sig.R[:sizeFr]) + s.SetBytes(sig.S[:sizeFr]) + + sInv := new(big.Int).ModInverse(s, order) + + var m *big.Int + if hFunc != nil { + // compute the hash of the message as an integer + dataToHash := make([]byte, len(message)) + copy(dataToHash[:], message[:]) + hFunc.Reset() + _, err := hFunc.Write(dataToHash[:]) + if err != nil { + return false, err + } + hramBin := hFunc.Sum(nil) + m = HashToInt(hramBin) + } else { + m = HashToInt(message) + } + + u1 := new(big.Int).Mul(m, sInv) + u1.Mod(u1, order) + u2 := new(big.Int).Mul(r, sInv) + u2.Mod(u2, order) + var U grumpkin.G1Jac + U.JointScalarMultiplicationBase(&publicKey.A, u1, u2) + + var z big.Int + U.Z.Square(&U.Z). + Inverse(&U.Z). + Mul(&U.Z, &U.X). + BigInt(&z) + + z.Mod(&z, order) + + return z.Cmp(r) == 0, nil + +} diff --git a/ecc/grumpkin/ecdsa/ecdsa_test.go b/ecc/grumpkin/ecdsa/ecdsa_test.go new file mode 100644 index 0000000000..c7660b4786 --- /dev/null +++ b/ecc/grumpkin/ecdsa/ecdsa_test.go @@ -0,0 +1,155 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package ecdsa + +import ( + "crypto/rand" + "crypto/sha256" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" + "math/big" + "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" +) + +func TestECDSA(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + properties := gopter.NewProperties(parameters) + + properties.Property("[GRUMPKIN] test the signing and verification", prop.ForAll( + func() bool { + + privKey, _ := GenerateKey(rand.Reader) + publicKey := privKey.PublicKey + + msg := []byte("testing ECDSA") + hFunc := sha256.New() + sig, _ := privKey.Sign(msg, hFunc) + flag, _ := publicKey.Verify(sig, msg, hFunc) + + return flag + }, + )) + + properties.Property("[GRUMPKIN] test the signing and verification (pre-hashed)", prop.ForAll( + func() bool { + + privKey, _ := GenerateKey(rand.Reader) + publicKey := privKey.PublicKey + + msg := []byte("testing ECDSA") + sig, _ := privKey.Sign(msg, nil) + flag, _ := publicKey.Verify(sig, msg, nil) + + return flag + }, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestNonMalleability(t *testing.T) { + + // buffer too big + t.Run("buffer_overflow", func(t *testing.T) { + bsig := make([]byte, 2*sizeFr+1) + var sig Signature + _, err := sig.SetBytes(bsig) + if err != errWrongSize { + t.Fatal("should raise wrong size error") + } + }) + + // R overflows p_mod + t.Run("R_overflow", func(t *testing.T) { + bsig := make([]byte, 2*sizeFr) + r := big.NewInt(1) + frMod := fr.Modulus() + r.Add(r, frMod) + buf := r.Bytes() + copy(bsig, buf[:]) + + var sig Signature + _, err := sig.SetBytes(bsig) + if err != errRBiggerThanRMod { + t.Fatal("should raise error r >= r_mod") + } + }) + + // S overflows p_mod + t.Run("S_overflow", func(t *testing.T) { + bsig := make([]byte, 2*sizeFr) + r := big.NewInt(1) + frMod := fr.Modulus() + r.Add(r, frMod) + buf := r.Bytes() + copy(bsig[sizeFr:], buf[:]) + 
big.NewInt(1).FillBytes(bsig[:sizeFr]) + + var sig Signature + _, err := sig.SetBytes(bsig) + if err != errSBiggerThanRMod { + t.Fatal("should raise error s >= r_mod") + } + }) + +} + +func TestNoZeros(t *testing.T) { + t.Run("R=0", func(t *testing.T) { + // R is 0 + var sig Signature + big.NewInt(0).FillBytes(sig.R[:]) + big.NewInt(1).FillBytes(sig.S[:]) + bts := sig.Bytes() + var newSig Signature + _, err := newSig.SetBytes(bts) + if err != errZero { + t.Fatal("expected error for zero R") + } + }) + t.Run("S=0", func(t *testing.T) { + // S is 0 + var sig Signature + big.NewInt(1).FillBytes(sig.R[:]) + big.NewInt(0).FillBytes(sig.S[:]) + bts := sig.Bytes() + var newSig Signature + _, err := newSig.SetBytes(bts) + if err != errZero { + t.Fatal("expected error for zero S") + } + }) +} + +// ------------------------------------------------------------ +// benches + +func BenchmarkSignECDSA(b *testing.B) { + + privKey, _ := GenerateKey(rand.Reader) + + msg := []byte("benchmarking ECDSA sign()") + b.ResetTimer() + for i := 0; i < b.N; i++ { + privKey.Sign(msg, nil) + } +} + +func BenchmarkVerifyECDSA(b *testing.B) { + + privKey, _ := GenerateKey(rand.Reader) + msg := []byte("benchmarking ECDSA sign()") + sig, _ := privKey.Sign(msg, nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + privKey.PublicKey.Verify(sig, msg, nil) + } +} diff --git a/ecc/grumpkin/ecdsa/marshal.go b/ecc/grumpkin/ecdsa/marshal.go new file mode 100644 index 0000000000..fabbc8a287 --- /dev/null +++ b/ecc/grumpkin/ecdsa/marshal.go @@ -0,0 +1,125 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package ecdsa + +import ( + "crypto/subtle" + "errors" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" + "io" + "math/big" +) + +var errWrongSize = errors.New("wrong size buffer") +var errRBiggerThanRMod = errors.New("r >= r_mod") +var errSBiggerThanRMod = errors.New("s >= r_mod") +var errZero = errors.New("zero value") + +// Bytes returns the binary representation of the public key +// follows https://tools.ietf.org/html/rfc8032#section-3.1 +// and returns a compressed representation of the point (x,y) +// +// x, y are the coordinates of the point +// on the curve as big endian integers. +// compressed representation store x with a parity bit to recompute y +func (pk *PublicKey) Bytes() []byte { + var res [sizePublicKey]byte + pkBin := pk.A.Bytes() + subtle.ConstantTimeCopy(1, res[:sizePublicKey], pkBin[:]) + return res[:] +} + +// SetBytes sets p from binary representation in buf. +// buf represents a public key as x||y where x, y are +// interpreted as big endian binary numbers corresponding +// to the coordinates of a point on the curve. +// It returns the number of bytes read from the buffer. +func (pk *PublicKey) SetBytes(buf []byte) (int, error) { + n := 0 + if len(buf) < sizePublicKey { + return n, io.ErrShortBuffer + } + if _, err := pk.A.SetBytes(buf[:sizePublicKey]); err != nil { + return 0, err + } + n += sizeFp + return n, nil +} + +// Bytes returns the binary representation of pk, +// as byte array publicKey||scalar +// where publicKey is as publicKey.Bytes(), and +// scalar is in big endian, of size sizeFr. 
+func (privKey *PrivateKey) Bytes() []byte { + var res [sizePrivateKey]byte + pubkBin := privKey.PublicKey.A.Bytes() + subtle.ConstantTimeCopy(1, res[:sizePublicKey], pubkBin[:]) + subtle.ConstantTimeCopy(1, res[sizePublicKey:sizePrivateKey], privKey.scalar[:]) + return res[:] +} + +// SetBytes sets pk from buf, where buf is interpreted +// as publicKey||scalar +// where publicKey is as publicKey.Bytes(), and +// scalar is in big endian, of size sizeFr. +// It returns the number byte read. +func (privKey *PrivateKey) SetBytes(buf []byte) (int, error) { + n := 0 + if len(buf) < sizePrivateKey { + return n, io.ErrShortBuffer + } + if _, err := privKey.PublicKey.A.SetBytes(buf[:sizePublicKey]); err != nil { + return 0, err + } + n += sizePublicKey + subtle.ConstantTimeCopy(1, privKey.scalar[:], buf[sizePublicKey:sizePrivateKey]) + n += sizeFr + return n, nil +} + +// Bytes returns the binary representation of sig +// as a byte array of size 2*sizeFr r||s +func (sig *Signature) Bytes() []byte { + var res [sizeSignature]byte + subtle.ConstantTimeCopy(1, res[:sizeFr], sig.R[:]) + subtle.ConstantTimeCopy(1, res[sizeFr:], sig.S[:]) + return res[:] +} + +// SetBytes sets sig from a buffer in binary. +// buf is read interpreted as r||s +// It returns the number of bytes read from buf. 
+func (sig *Signature) SetBytes(buf []byte) (int, error) { + n := 0 + if len(buf) != sizeSignature { + return n, errWrongSize + } + + // S, R < R_mod (to avoid malleability) + frMod := fr.Modulus() + zero := big.NewInt(0) + bufBigInt := new(big.Int) + bufBigInt.SetBytes(buf[:sizeFr]) + if bufBigInt.Cmp(zero) == 0 { + return 0, errZero + } + if bufBigInt.Cmp(frMod) != -1 { + return 0, errRBiggerThanRMod + } + bufBigInt.SetBytes(buf[sizeFr : 2*sizeFr]) + if bufBigInt.Cmp(zero) == 0 { + return 0, errZero + } + if bufBigInt.Cmp(frMod) != -1 { + return 0, errSBiggerThanRMod + } + + subtle.ConstantTimeCopy(1, sig.R[:], buf[:sizeFr]) + n += sizeFr + subtle.ConstantTimeCopy(1, sig.S[:], buf[sizeFr:2*sizeFr]) + n += sizeFr + return n, nil +} diff --git a/ecc/grumpkin/ecdsa/marshal_test.go b/ecc/grumpkin/ecdsa/marshal_test.go new file mode 100644 index 0000000000..bbf82f6b87 --- /dev/null +++ b/ecc/grumpkin/ecdsa/marshal_test.go @@ -0,0 +1,53 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package ecdsa + +import ( + "crypto/rand" + "crypto/subtle" + "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" +) + +const ( + nbFuzzShort = 10 + nbFuzz = 100 +) + +func TestSerialization(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + properties.Property("[GRUMPKIN] ECDSA serialization: SetBytes(Bytes()) should stay the same", prop.ForAll( + func() bool { + privKey, _ := GenerateKey(rand.Reader) + + var end PrivateKey + buf := privKey.Bytes() + n, err := end.SetBytes(buf[:]) + if err != nil { + return false + } + if n != sizePrivateKey { + return false + } + + return end.PublicKey.Equal(&privKey.PublicKey) && subtle.ConstantTimeCompare(end.scalar[:], privKey.scalar[:]) == 1 + + }, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} diff --git a/ecc/grumpkin/fp/arith.go b/ecc/grumpkin/fp/arith.go new file mode 100644 index 0000000000..5c9905de80 --- /dev/null +++ b/ecc/grumpkin/fp/arith.go @@ -0,0 +1,49 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import ( + "math/bits" +) + +// madd0 hi = a*b + c (discards lo bits) +func madd0(a, b, c uint64) (hi uint64) { + var carry, lo uint64 + hi, lo = bits.Mul64(a, b) + _, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +// madd1 hi, lo = a*b + c +func madd1(a, b, c uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +// madd2 hi, lo = a*b + c + d +func madd2(a, b, c, d uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, e, carry) + return +} diff --git a/ecc/grumpkin/fp/asm_adx.go b/ecc/grumpkin/fp/asm_adx.go new file mode 100644 index 0000000000..8d85a11345 --- /dev/null +++ b/ecc/grumpkin/fp/asm_adx.go @@ -0,0 +1,15 @@ +//go:build !noadx + +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import "golang.org/x/sys/cpu" + +var ( + supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 + _ = supportAdx +) diff --git a/ecc/grumpkin/fp/asm_avx.go b/ecc/grumpkin/fp/asm_avx.go new file mode 100644 index 0000000000..45e1ab3f0d --- /dev/null +++ b/ecc/grumpkin/fp/asm_avx.go @@ -0,0 +1,15 @@ +//go:build !noavx + +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import "golang.org/x/sys/cpu" + +var ( + supportAvx512 = supportAdx && cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 +) diff --git a/ecc/grumpkin/fp/asm_noadx.go b/ecc/grumpkin/fp/asm_noadx.go new file mode 100644 index 0000000000..75ca96d775 --- /dev/null +++ b/ecc/grumpkin/fp/asm_noadx.go @@ -0,0 +1,16 @@ +//go:build noadx + +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +// note: this is needed for test purposes, as dynamically changing supportAdx doesn't flag +// certain errors (like fatal error: missing stackmap) +// this ensures we test all asm path. +var ( + supportAdx = false + _ = supportAdx +) diff --git a/ecc/grumpkin/fp/asm_noavx.go b/ecc/grumpkin/fp/asm_noavx.go new file mode 100644 index 0000000000..01f2011925 --- /dev/null +++ b/ecc/grumpkin/fp/asm_noavx.go @@ -0,0 +1,10 @@ +//go:build noavx + +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +const supportAvx512 = false diff --git a/ecc/grumpkin/fp/doc.go b/ecc/grumpkin/fp/doc.go new file mode 100644 index 0000000000..894847e095 --- /dev/null +++ b/ecc/grumpkin/fp/doc.go @@ -0,0 +1,46 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// Package fp contains field arithmetic operations for modulus = 0x30644e...000001. +// +// The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x). +// +// Additionally fp.Vector offers an API to manipulate []Element using AVX512 instructions if available. 
+// +// The modulus is hardcoded in all the operations. +// +// Field elements are represented as an array, and assumed to be in Montgomery form in all methods: +// +// type Element [4]uint64 +// +// # Usage +// +// Example API signature: +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element +// +// and can be used like so: +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) +// +// Modulus q = +// +// q[base10] = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +// q[base16] = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +// +// # Warning +// +// There is no security guarantees such as constant time implementation or side-channel attack resistance. +// This code is provided as-is. Partially audited, see https://github.com/Consensys/gnark/tree/master/audits +// for more details. +package fp diff --git a/ecc/grumpkin/fp/element.go b/ecc/grumpkin/fp/element.go new file mode 100644 index 0000000000..257d62236e --- /dev/null +++ b/ecc/grumpkin/fp/element.go @@ -0,0 +1,1590 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import ( + "crypto/rand" + "encoding/binary" + "errors" + "io" + "math/big" + "math/bits" + "reflect" + "strconv" + "strings" + + "github.com/bits-and-blooms/bitset" + "github.com/consensys/gnark-crypto/field/hash" + "github.com/consensys/gnark-crypto/field/pool" +) + +// Element represents a field element stored on 4 words (uint64) +// +// Element are assumed to be in Montgomery form in all methods. 
+// +// Modulus q = +// +// q[base10] = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +// q[base16] = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +// +// # Warning +// +// This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. +type Element [4]uint64 + +const ( + Limbs = 4 // number of 64 bits words needed to represent a Element + Bits = 254 // number of bits needed to represent a Element + Bytes = 32 // number of bytes needed to represent a Element +) + +// Field modulus q +const ( + q0 = 4891460686036598785 + q1 = 2896914383306846353 + q2 = 13281191951274694749 + q3 = 3486998266802970665 +) + +var qElement = Element{ + q0, + q1, + q2, + q3, +} + +var _modulus big.Int // q stored as big.Int + +// Modulus returns q as a big.Int +// +// q[base10] = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +// q[base16] = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +func Modulus() *big.Int { + return new(big.Int).Set(&_modulus) +} + +// q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +// used for Montgomery reduction +const qInvNeg = 14042775128853446655 + +// mu = 2^288 / q needed for partial Barrett reduction +const mu uint64 = 22721021478 + +func init() { + _modulus.SetString("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", 16) +} + +// NewElement returns a new Element from a uint64 value +// +// it is equivalent to +// +// var v Element +// v.SetUint64(...) 
+func NewElement(v uint64) Element { + z := Element{v} + z.Mul(&z, &rSquare) + return z +} + +// SetUint64 sets z to v and returns z +func (z *Element) SetUint64(v uint64) *Element { + // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form + *z = Element{v} + return z.Mul(z, &rSquare) // z.toMont() +} + +// SetInt64 sets z to v and returns z +func (z *Element) SetInt64(v int64) *Element { + + // absolute value of v + m := v >> 63 + z.SetUint64(uint64((v ^ m) - m)) + + if m != 0 { + // v is negative + z.Neg(z) + } + + return z +} + +// Set z = x and returns z +func (z *Element) Set(x *Element) *Element { + z[0] = x[0] + z[1] = x[1] + z[2] = x[2] + z[3] = x[3] + return z +} + +// SetInterface converts provided interface into Element +// returns an error if provided type is not supported +// supported types: +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte +func (z *Element) SetInterface(i1 interface{}) (*Element, error) { + if i1 == nil { + return nil, errors.New("can't set fp.Element with ") + } + + switch c1 := i1.(type) { + case Element: + return z.Set(&c1), nil + case *Element: + if c1 == nil { + return nil, errors.New("can't set fp.Element with ") + } + return z.Set(c1), nil + case uint8: + return z.SetUint64(uint64(c1)), nil + case uint16: + return z.SetUint64(uint64(c1)), nil + case uint32: + return z.SetUint64(uint64(c1)), nil + case uint: + return z.SetUint64(uint64(c1)), nil + case uint64: + return z.SetUint64(c1), nil + case int8: + return z.SetInt64(int64(c1)), nil + case int16: + return z.SetInt64(int64(c1)), nil + case int32: + return z.SetInt64(int64(c1)), nil + case int64: + return z.SetInt64(c1), nil + case int: + return z.SetInt64(int64(c1)), nil + case string: + return z.SetString(c1) + case *big.Int: + if c1 == nil { + return nil, errors.New("can't set fp.Element with ") + } + return z.SetBigInt(c1), nil + case big.Int: + return z.SetBigInt(&c1), nil 
+ case []byte: + return z.SetBytes(c1), nil + default: + return nil, errors.New("can't set fp.Element from type " + reflect.TypeOf(i1).String()) + } +} + +// SetZero z = 0 +func (z *Element) SetZero() *Element { + z[0] = 0 + z[1] = 0 + z[2] = 0 + z[3] = 0 + return z +} + +// SetOne z = 1 (in Montgomery form) +func (z *Element) SetOne() *Element { + z[0] = 12436184717236109307 + z[1] = 3962172157175319849 + z[2] = 7381016538464732718 + z[3] = 1011752739694698287 + return z +} + +// Div z = x*y⁻¹ (mod q) +func (z *Element) Div(x, y *Element) *Element { + var yInv Element + yInv.Inverse(y) + z.Mul(x, &yInv) + return z +} + +// Equal returns z == x; constant-time +func (z *Element) Equal(x *Element) bool { + return z.NotEqual(x) == 0 +} + +// NotEqual returns 0 if and only if z == x; constant-time +func (z *Element) NotEqual(x *Element) uint64 { + return (z[3] ^ x[3]) | (z[2] ^ x[2]) | (z[1] ^ x[1]) | (z[0] ^ x[0]) +} + +// IsZero returns z == 0 +func (z *Element) IsZero() bool { + return (z[3] | z[2] | z[1] | z[0]) == 0 +} + +// IsOne returns z == 1 +func (z *Element) IsOne() bool { + return ((z[3] ^ 1011752739694698287) | (z[2] ^ 7381016538464732718) | (z[1] ^ 3962172157175319849) | (z[0] ^ 12436184717236109307)) == 0 +} + +// IsUint64 reports whether z can be represented as an uint64. +func (z *Element) IsUint64() bool { + zz := *z + zz.fromMont() + return zz.FitsOnOneWord() +} + +// Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. +func (z *Element) Uint64() uint64 { + return z.Bits()[0] +} + +// FitsOnOneWord reports whether z words (except the least significant word) are 0 +// +// It is the responsibility of the caller to convert from Montgomery to Regular form if needed. 
+func (z *Element) FitsOnOneWord() bool { + return (z[3] | z[2] | z[1]) == 0 +} + +// Cmp compares (lexicographic order) z and x and returns: +// +// -1 if z < x +// 0 if z == x +// +1 if z > x +func (z *Element) Cmp(x *Element) int { + _z := z.Bits() + _x := x.Bits() + if _z[3] > _x[3] { + return 1 + } else if _z[3] < _x[3] { + return -1 + } + if _z[2] > _x[2] { + return 1 + } else if _z[2] < _x[2] { + return -1 + } + if _z[1] > _x[1] { + return 1 + } else if _z[1] < _x[1] { + return -1 + } + if _z[0] > _x[0] { + return 1 + } else if _z[0] < _x[0] { + return -1 + } + return 0 +} + +// LexicographicallyLargest returns true if this element is strictly lexicographically +// larger than its negation, false otherwise +func (z *Element) LexicographicallyLargest() bool { + // adapted from github.com/zkcrypto/bls12_381 + // we check if the element is larger than (q-1) / 2 + // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 + + _z := z.Bits() + + var b uint64 + _, b = bits.Sub64(_z[0], 11669102379873075201, 0) + _, b = bits.Sub64(_z[1], 10671829228508198984, b) + _, b = bits.Sub64(_z[2], 15863968012492123182, b) + _, b = bits.Sub64(_z[3], 1743499133401485332, b) + + return b == 0 +} + +// SetRandom sets z to a uniform random value in [0, q). +// +// This might error only if reading from crypto/rand.Reader errors, +// in which case, value of z is undefined. +func (z *Element) SetRandom() (*Element, error) { + // this code is generated for all modulus + // and derived from go/src/crypto/rand/util.go + + // l is number of limbs * 8; the number of bytes needed to reconstruct 4 uint64 + const l = 32 + + // bitLen is the maximum bit length needed to encode a value < q. + const bitLen = 254 + + // k is the maximum byte length needed to encode a value < q. + const k = (bitLen + 7) / 8 + + // b is the number of bits in the most significant byte of q-1. 
+ b := uint(bitLen % 8) + if b == 0 { + b = 8 + } + + var bytes [l]byte + + for { + // note that bytes[k:l] is always 0 + if _, err := io.ReadFull(rand.Reader, bytes[:k]); err != nil { + return nil, err + } + + // Clear unused bits in in the most significant byte to increase probability + // that the candidate is < q. + bytes[k-1] &= uint8(int(1<> 1 + z[0] = z[0]>>1 | z[1]<<63 + z[1] = z[1]>>1 | z[2]<<63 + z[2] = z[2]>>1 | z[3]<<63 + z[3] >>= 1 + +} + +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// sets and returns z = z * 1 +func (z *Element) fromMont() *Element { + fromMont(z) + return z +} + +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], _ = bits.Add64(x[3], y[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], _ = bits.Add64(x[3], x[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = 
bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], _ = bits.Add64(z[3], q3, c) + } + return z +} + +// Neg z = q - x +func (z *Element) Neg(x *Element) *Element { + if x.IsZero() { + z.SetZero() + return z + } + var borrow uint64 + z[0], borrow = bits.Sub64(q0, x[0], 0) + z[1], borrow = bits.Sub64(q1, x[1], borrow) + z[2], borrow = bits.Sub64(q2, x[2], borrow) + z[3], _ = bits.Sub64(q3, x[3], borrow) + return z +} + +// Select is a constant-time conditional move. +// If c=0, z = x0. Else z = x1 +func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { + cC := uint64((int64(c) | -int64(c)) >> 63) // "canonicized" into: 0 if c=0, -1 otherwise + z[0] = x0[0] ^ cC&(x0[0]^x1[0]) + z[1] = x0[1] ^ cC&(x0[1]^x1[1]) + z[2] = x0[2] ^ cC&(x0[2]^x1[2]) + z[3] = x0[3] ^ cC&(x0[3]^x1[3]) + return z +} + +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. +func _mulGeneric(z, x, y *Element) { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t [5]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + 
+ // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + + if t[4] != 0 { + // we need to reduce, we have a result on 5 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], _ = bits.Sub64(t[3], q3, b) + return + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } +} + +func _fromMontGeneric(z *Element) { + // the following lines implement z = z * 1 + // with a modified CIOS montgomery multiplication + // see Mul for algorithm documentation + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = 
bits.Sub64(z[3], q3, b) + } +} + +func _reduceGeneric(z *Element) { + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } +} + +// BatchInvert returns a new slice with every element inverted. +// Uses Montgomery batch inversion trick +func BatchInvert(a []Element) []Element { + res := make([]Element, len(a)) + if len(a) == 0 { + return res + } + + zeroes := bitset.New(uint(len(a))) + accumulator := One() + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + zeroes.Set(uint(i)) + continue + } + res[i] = accumulator + accumulator.Mul(&accumulator, &a[i]) + } + + accumulator.Inverse(&accumulator) + + for i := len(a) - 1; i >= 0; i-- { + if zeroes.Test(uint(i)) { + continue + } + res[i].Mul(&res[i], &accumulator) + accumulator.Mul(&accumulator, &a[i]) + } + + return res +} + +func _butterflyGeneric(a, b *Element) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// BitLen returns the minimum number of bits needed to represent z +// returns 0 if z == 0 +func (z *Element) BitLen() int { + if z[3] != 0 { + return 192 + bits.Len64(z[3]) + } + if z[2] != 0 { + return 128 + bits.Len64(z[2]) + } + if z[1] != 0 { + return 64 + bits.Len64(z[1]) + } + return bits.Len64(z[0]) +} + +// Hash msg to count prime field elements. 
+// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := hash.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := pool.BigInt.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + pool.BigInt.Put(vv) + + return res, nil +} + +// Exp z = xᵏ (mod q) +func (z *Element) Exp(x Element, k *big.Int) *Element { + if k.IsUint64() && k.Uint64() == 0 { + return z.SetOne() + } + + e := k + if k.Sign() == -1 { + // negative k, we invert + // if k < 0: xᵏ (mod q) == (x⁻¹)ᵏ (mod q) + x.Inverse(&x) + + // we negate k in a temp big.Int since + // Int.Bit(_) of k and -k is different + e = pool.BigInt.Get() + defer pool.BigInt.Put(e) + e.Neg(k) + } + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z +} + +// rSquare where r is the Montgommery constant +// see section 2.3.2 of Tolga Acar's thesis +// https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf +var rSquare = Element{ + 1997599621687373223, + 6052339484930628067, + 10108755138030829701, + 150537098327114917, +} + +// toMont converts z to Montgomery form +// sets and returns z = z * r² +func (z *Element) toMont() *Element { + return z.Mul(z, &rSquare) +} + +// String returns the decimal representation of z as generated by +// z.Text(10). 
+func (z *Element) String() string { + return z.Text(10) +} + +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[24:32], z[0]) + binary.BigEndian.PutUint64(b[16:24], z[1]) + binary.BigEndian.PutUint64(b[8:16], z[2]) + binary.BigEndian.PutUint64(b[0:8], z[3]) + + return res.SetBytes(b[:]) +} + +// Text returns the string representation of z in the given base. +// Base must be between 2 and 36, inclusive. The result uses the +// lower-case letters 'a' to 'z' for digit values 10 to 35. +// No prefix (such as "0x") is added to the string. If z is a nil +// pointer it returns "". +// If base == 10 and -z fits in a uint16 prefix "-" is added to the string. +func (z *Element) Text(base int) string { + if base < 2 || base > 36 { + panic("invalid base") + } + if z == nil { + return "" + } + + const maxUint16 = 65535 + if base == 10 { + var zzNeg Element + zzNeg.Neg(z) + zzNeg.fromMont() + if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { + return "-" + strconv.FormatUint(zzNeg[0], base) + } + } + zz := *z + zz.fromMont() + if zz.FitsOnOneWord() { + return strconv.FormatUint(zz[0], base) + } + vv := pool.BigInt.Get() + r := zz.toBigInt(vv).Text(base) + pool.BigInt.Put(vv) + return r +} + +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) +} + +// ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead +func (z Element) ToBigIntRegular(res *big.Int) *big.Int { + z.fromMont() + return z.toBigInt(res) +} + +// Bits provides access to z by returning its value as a little-endian [4]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. 
+func (z *Element) Bits() [4]uint64 { + _z := *z + fromMont(&_z) + return _z +} + +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) + return +} + +// Marshal returns the value of z as a big-endian byte slice +func (z *Element) Marshal() []byte { + b := z.Bytes() + return b[:] +} + +// Unmarshal is an alias for SetBytes, it sets z to the value of e. +func (z *Element) Unmarshal(e []byte) { + z.SetBytes(e) +} + +// SetBytes interprets e as the bytes of a big-endian unsigned integer, +// sets z to that value, and returns z. +func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. + // get a big int from our pool + vv := pool.BigInt.Get() + vv.SetBytes(e) + + // set big int + z.SetBigInt(vv) + + // put temporary object back in pool + pool.BigInt.Put(vv) + + return z +} + +// SetBytesCanonical interprets e as the bytes of a big-endian 32-byte integer. +// If e is not a 32-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. 
+func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fp.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + +// SetBigInt sets z to v and returns z +func (z *Element) SetBigInt(v *big.Int) *Element { + z.SetZero() + + var zero big.Int + + // fast path + c := v.Cmp(&_modulus) + if c == 0 { + // v == 0 + return z + } else if c != 1 && v.Cmp(&zero) != -1 { + // 0 <= v < q + return z.setBigInt(v) + } + + // get temporary big int from the pool + vv := pool.BigInt.Get() + + // copy input + modular reduction + vv.Mod(v, &_modulus) + + // set big int byte value + z.setBigInt(vv) + + // release object into pool + pool.BigInt.Put(vv) + return z +} + +// setBigInt assumes 0 ⩽ v < q +func (z *Element) setBigInt(v *big.Int) *Element { + vBits := v.Bits() + + if bits.UintSize == 64 { + for i := 0; i < len(vBits); i++ { + z[i] = uint64(vBits[i]) + } + } else { + for i := 0; i < len(vBits); i++ { + if i%2 == 0 { + z[i/2] = uint64(vBits[i]) + } else { + z[i/2] |= uint64(vBits[i]) << 32 + } + } + } + + return z.toMont() +} + +// SetString creates a big.Int with number and calls SetBigInt on z +// +// The number prefix determines the actual base: A prefix of +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 +// and no prefix is accepted. +// +// For base 16, lower and upper case letters are considered the same: +// The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. +// +// An underscore character ”_” may appear between a base +// prefix and an adjacent digit, and between successive digits; such +// underscores do not change the value of the number. +// Incorrect placement of underscores is reported as a panic if there +// are no other errors. +// +// If the number is invalid this method leaves z unchanged and returns nil, error. 
+func (z *Element) SetString(number string) (*Element, error) { + // get temporary big int from the pool + vv := pool.BigInt.Get() + + if _, ok := vv.SetString(number, 0); !ok { + return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) + } + + z.SetBigInt(vv) + + // release object into pool + pool.BigInt.Put(vv) + + return z, nil +} + +// MarshalJSON returns json encoding of z (z.Text(10)) +// If z == nil, returns null +func (z *Element) MarshalJSON() ([]byte, error) { + if z == nil { + return []byte("null"), nil + } + const maxSafeBound = 15 // we encode it as number if it's small + s := z.Text(10) + if len(s) <= maxSafeBound { + return []byte(s), nil + } + var sbb strings.Builder + sbb.WriteByte('"') + sbb.WriteString(s) + sbb.WriteByte('"') + return []byte(sbb.String()), nil +} + +// UnmarshalJSON accepts numbers and strings as input +// See Element.SetString for valid prefixes (0x, 0b, ...) +func (z *Element) UnmarshalJSON(data []byte) error { + s := string(data) + if len(s) > Bits*3 { + return errors.New("value too large (max = Element.Bits * 3)") + } + + // we accept numbers and strings, remove leading and trailing quotes if any + if len(s) > 0 && s[0] == '"' { + s = s[1:] + } + if len(s) > 0 && s[len(s)-1] == '"' { + s = s[:len(s)-1] + } + + // get temporary big int from the pool + vv := pool.BigInt.Get() + + if _, ok := vv.SetString(s, 0); !ok { + return errors.New("can't parse into a big.Int: " + s) + } + + z.SetBigInt(vv) + + // release object into pool + pool.BigInt.Put(vv) + return nil +} + +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +var errInvalidEncoding = errors.New("invalid fp.Element encoding") + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. 
+var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 32-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[24:32]) + z[1] = binary.BigEndian.Uint64((*b)[16:24]) + z[2] = binary.BigEndian.Uint64((*b)[8:16]) + z[3] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errInvalidEncoding + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[24:32], e[0]) + binary.BigEndian.PutUint64((*b)[16:24], e[1]) + binary.BigEndian.PutUint64((*b)[8:16], e[2]) + binary.BigEndian.PutUint64((*b)[0:8], e[3]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + + if !z.smallerThanModulus() { + return Element{}, errInvalidEncoding + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) +} + +func (littleEndian) String() string { return "LittleEndian" } + +// Legendre returns the Legendre symbol of z (either +1, -1, or 0.) 
+func (z *Element) Legendre() int { + var l Element + // z^((q-1)/2) + l.expByLegendreExp(*z) + + if l.IsZero() { + return 0 + } + + // if l == 1 + if l.IsOne() { + return 1 + } + return -1 +} + +// Sqrt z = √x (mod q) +// if the square root doesn't exist (x is not a square mod q) +// Sqrt leaves z unchanged and returns nil +func (z *Element) Sqrt(x *Element) *Element { + // q ≡ 1 (mod 4) + // see modSqrtTonelliShanks in math/big/int.go + // using https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf + + var y, b, t, w Element + // w = x^((s-1)/2)) + w.expBySqrtExp(*x) + + // y = x^((s+1)/2)) = w * x + y.Mul(x, &w) + + // b = xˢ = w * w * x = y * x + b.Mul(&w, &y) + + // g = nonResidue ^ s + var g = Element{ + 7164790868263648668, + 11685701338293206998, + 6216421865291908056, + 1756667274303109607, + } + r := uint64(28) + + // compute legendre symbol + // t = x^((q-1)/2) = r-1 squaring of xˢ + t = b + for i := uint64(0); i < r-1; i++ { + t.Square(&t) + } + if t.IsZero() { + return z.SetZero() + } + if !t.IsOne() { + // t != 1, we don't have a square root + return nil + } + for { + var m uint64 + t = b + + // for t != 1 + for !t.IsOne() { + t.Square(&t) + m++ + } + + if m == 0 { + return z.Set(&y) + } + // t = g^(2^(r-m-1)) (mod q) + ge := int(r - m - 1) + t = g + for ge > 0 { + t.Square(&t) + ge-- + } + + g.Square(&t) + y.Mul(&y, &t) + b.Mul(&b, &g) + r = m + } +} + +const ( + k = 32 // word size / 2 + signBitSelector = uint64(1) << 63 + approxLowBitsN = k - 1 + approxHighBitsN = k + 1 +) + +const ( + inversionCorrectionFactorWord0 = 13488105295233737379 + inversionCorrectionFactorWord1 = 17373395488625725466 + inversionCorrectionFactorWord2 = 6831692495576925776 + inversionCorrectionFactorWord3 = 3282329835997625403 + invIterationsN = 18 +) + +// Inverse z = x⁻¹ (mod q) +// +// if x == 0, sets and returns z = x +func (z *Element) Inverse(x *Element) *Element { + // Implements "Optimized Binary GCD for Modular 
Inversion" + // https://github.com/pornin/bingcd/blob/main/doc/bingcd.pdf + + a := *x + b := Element{ + q0, + q1, + q2, + q3, + } // b := q + + u := Element{1} + + // Update factors: we get [u; v] ← [f₀ g₀; f₁ g₁] [u; v] + // cᵢ = fᵢ + 2³¹ - 1 + 2³² * (gᵢ + 2³¹ - 1) + var c0, c1 int64 + + // Saved update factors to reduce the number of field multiplications + var pf0, pf1, pg0, pg1 int64 + + var i uint + + var v, s Element + + // Since u,v are updated every other iteration, we must make sure we terminate after evenly many iterations + // This also lets us get away with half as many updates to u,v + // To make this constant-time-ish, replace the condition with i < invIterationsN + for i = 0; i&1 == 1 || !a.IsZero(); i++ { + n := max(a.BitLen(), b.BitLen()) + aApprox, bApprox := approximate(&a, n), approximate(&b, n) + + // f₀, g₀, f₁, g₁ = 1, 0, 0, 1 + c0, c1 = updateFactorIdentityMatrixRow0, updateFactorIdentityMatrixRow1 + + for j := 0; j < approxLowBitsN; j++ { + + // -2ʲ < f₀, f₁ ≤ 2ʲ + // |f₀| + |f₁| < 2ʲ⁺¹ + + if aApprox&1 == 0 { + aApprox /= 2 + } else { + s, borrow := bits.Sub64(aApprox, bApprox, 0) + if borrow == 1 { + s = bApprox - aApprox + bApprox = aApprox + c0, c1 = c1, c0 + // invariants unchanged + } + + aApprox = s / 2 + c0 = c0 - c1 + + // Now |f₀| < 2ʲ⁺¹ ≤ 2ʲ⁺¹ (only the weaker inequality is needed, strictly speaking) + // Started with f₀ > -2ʲ and f₁ ≤ 2ʲ, so f₀ - f₁ > -2ʲ⁺¹ + // Invariants unchanged for f₁ + } + + c1 *= 2 + // -2ʲ⁺¹ < f₁ ≤ 2ʲ⁺¹ + // So now |f₀| + |f₁| < 2ʲ⁺² + } + + s = a + + var g0 int64 + // from this point on c0 aliases for f0 + c0, g0 = updateFactorsDecompose(c0) + aHi := a.linearCombNonModular(&s, c0, &b, g0) + if aHi&signBitSelector != 0 { + // if aHi < 0 + c0, g0 = -c0, -g0 + aHi = negL(&a, aHi) + } + // right-shift a by k-1 bits + a[0] = (a[0] >> approxLowBitsN) | ((a[1]) << approxHighBitsN) + a[1] = (a[1] >> approxLowBitsN) | ((a[2]) << approxHighBitsN) + a[2] = (a[2] >> approxLowBitsN) | ((a[3]) << approxHighBitsN) + 
a[3] = (a[3] >> approxLowBitsN) | (aHi << approxHighBitsN) + + var f1 int64 + // from this point on c1 aliases for g0 + f1, c1 = updateFactorsDecompose(c1) + bHi := b.linearCombNonModular(&s, f1, &b, c1) + if bHi&signBitSelector != 0 { + // if bHi < 0 + f1, c1 = -f1, -c1 + bHi = negL(&b, bHi) + } + // right-shift b by k-1 bits + b[0] = (b[0] >> approxLowBitsN) | ((b[1]) << approxHighBitsN) + b[1] = (b[1] >> approxLowBitsN) | ((b[2]) << approxHighBitsN) + b[2] = (b[2] >> approxLowBitsN) | ((b[3]) << approxHighBitsN) + b[3] = (b[3] >> approxLowBitsN) | (bHi << approxHighBitsN) + + if i&1 == 1 { + // Combine current update factors with previously stored ones + // [F₀, G₀; F₁, G₁] ← [f₀, g₀; f₁, g₁] [pf₀, pg₀; pf₁, pg₁], with capital letters denoting new combined values + // We get |F₀| = | f₀pf₀ + g₀pf₁ | ≤ |f₀pf₀| + |g₀pf₁| = |f₀| |pf₀| + |g₀| |pf₁| ≤ 2ᵏ⁻¹|pf₀| + 2ᵏ⁻¹|pf₁| + // = 2ᵏ⁻¹ (|pf₀| + |pf₁|) < 2ᵏ⁻¹ 2ᵏ = 2²ᵏ⁻¹ + // So |F₀| < 2²ᵏ⁻¹ meaning it fits in a 2k-bit signed register + + // c₀ aliases f₀, c₁ aliases g₁ + c0, g0, f1, c1 = c0*pf0+g0*pf1, + c0*pg0+g0*pg1, + f1*pf0+c1*pf1, + f1*pg0+c1*pg1 + + s = u + + // 0 ≤ u, v < 2²⁵⁵ + // |F₀|, |G₀| < 2⁶³ + u.linearComb(&u, c0, &v, g0) + // |F₁|, |G₁| < 2⁶³ + v.linearComb(&s, f1, &v, c1) + + } else { + // Save update factors + pf0, pg0, pf1, pg1 = c0, g0, f1, c1 + } + } + + // For every iteration that we miss, v is not being multiplied by 2ᵏ⁻² + const pSq uint64 = 1 << (2 * (k - 1)) + a = Element{pSq} + // If the function is constant-time ish, this loop will not run (no need to take it out explicitly) + for ; i < invIterationsN; i += 2 { + // could optimize further with mul by word routine or by pre-computing a table since with k=26, + // we would multiply by pSq up to 13times; + // on x86, the assembly routine outperforms generic code for mul by word + // on arm64, we may loose up to ~5% for 6 limbs + v.Mul(&v, &a) + } + + u.Set(x) // for correctness check + + z.Mul(&v, &Element{ + inversionCorrectionFactorWord0, + 
inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + }) + + // correctness check + v.Mul(&u, z) + if !v.IsOne() && !u.IsZero() { + return z.inverseExp(u) + } + + return z +} + +// inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z +} + +// approximate a big number x into a single 64 bit word using its uppermost and lowermost bits +// if x fits in a word as is, no approximation necessary +func approximate(x *Element, nBits int) uint64 { + + if nBits <= 64 { + return x[0] + } + + const mask = (uint64(1) << (k - 1)) - 1 // k-1 ones + lo := mask & x[0] + + hiWordIndex := (nBits - 1) / 64 + + hiWordBitsAvailable := nBits - hiWordIndex*64 + hiWordBitsUsed := min(hiWordBitsAvailable, approxHighBitsN) + + mask_ := uint64(^((1 << (hiWordBitsAvailable - hiWordBitsUsed)) - 1)) + hi := (x[hiWordIndex] & mask_) << (64 - hiWordBitsAvailable) + + mask_ = ^(1<<(approxLowBitsN+hiWordBitsUsed) - 1) + mid := (mask_ & x[hiWordIndex-1]) >> hiWordBitsUsed + + return lo | mid | hi +} + +// linearComb z = xC * x + yC * y; +// 0 ≤ x, y < 2²⁵⁴ +// |xC|, |yC| < 2⁶³ +func (z *Element) linearComb(x *Element, xC int64, y *Element, yC int64) { + // | (hi, z) | < 2 * 2⁶³ * 2²⁵⁴ = 2³¹⁸ + // therefore | hi | < 2⁶² ≤ 2⁶³ + hi := z.linearCombNonModular(x, xC, y, yC) + z.montReduceSigned(z, hi) +} + +// montReduceSigned z = (xHi * r + x) * r⁻¹ using the SOS algorithm +// Requires |xHi| < 2⁶³. Most significant bit of xHi is the sign bit. 
+func (z *Element) montReduceSigned(x *Element, xHi uint64) { + const signBitRemover = ^signBitSelector + mustNeg := xHi&signBitSelector != 0 + // the SOS implementation requires that most significant bit is 0 + // Let X be xHi*r + x + // If X is negative we would have initially stored it as 2⁶⁴ r + X (à la 2's complement) + xHi &= signBitRemover + // with this a negative X is now represented as 2⁶³ r + X + + var t [2*Limbs - 1]uint64 + var C uint64 + + m := x[0] * qInvNeg + + C = madd0(m, q0, x[0]) + C, t[1] = madd2(m, q1, x[1], C) + C, t[2] = madd2(m, q2, x[2], C) + C, t[3] = madd2(m, q3, x[3], C) + + // m * qElement[3] ≤ (2⁶⁴ - 1) * (2⁶³ - 1) = 2¹²⁷ - 2⁶⁴ - 2⁶³ + 1 + // x[3] + C ≤ 2*(2⁶⁴ - 1) = 2⁶⁵ - 2 + // On LHS, (C, t[3]) ≤ 2¹²⁷ - 2⁶⁴ - 2⁶³ + 1 + 2⁶⁵ - 2 = 2¹²⁷ + 2⁶³ - 1 + // So on LHS, C ≤ 2⁶³ + t[4] = xHi + C + // xHi + C < 2⁶³ + 2⁶³ = 2⁶⁴ + + // + { + const i = 1 + m = t[i] * qInvNeg + + C = madd0(m, q0, t[i+0]) + C, t[i+1] = madd2(m, q1, t[i+1], C) + C, t[i+2] = madd2(m, q2, t[i+2], C) + C, t[i+3] = madd2(m, q3, t[i+3], C) + + t[i+Limbs] += C + } + { + const i = 2 + m = t[i] * qInvNeg + + C = madd0(m, q0, t[i+0]) + C, t[i+1] = madd2(m, q1, t[i+1], C) + C, t[i+2] = madd2(m, q2, t[i+2], C) + C, t[i+3] = madd2(m, q3, t[i+3], C) + + t[i+Limbs] += C + } + { + const i = 3 + m := t[i] * qInvNeg + + C = madd0(m, q0, t[i+0]) + C, z[0] = madd2(m, q1, t[i+1], C) + C, z[1] = madd2(m, q2, t[i+2], C) + z[3], z[2] = madd2(m, q3, t[i+3], C) + } + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + // + + if mustNeg { + // We have computed ( 2⁶³ r + X ) r⁻¹ = 2⁶³ + X r⁻¹ instead + var b uint64 + z[0], b = bits.Sub64(z[0], signBitSelector, 0) + z[1], b = bits.Sub64(z[1], 0, b) + z[2], b = bits.Sub64(z[2], 0, b) + z[3], b = bits.Sub64(z[3], 0, b) + + // Occurs iff x == 0 && xHi < 0, i.e. 
X = rX' for -2⁶³ ≤ X' < 0 + + if b != 0 { + // z[3] = -1 + // negative: add q + const neg1 = 0xFFFFFFFFFFFFFFFF + + var carry uint64 + + z[0], carry = bits.Add64(z[0], q0, 0) + z[1], carry = bits.Add64(z[1], q1, carry) + z[2], carry = bits.Add64(z[2], q2, carry) + z[3], _ = bits.Add64(neg1, q3, carry) + } + } +} + +const ( + updateFactorsConversionBias int64 = 0x7fffffff7fffffff // (2³¹ - 1)(2³² + 1) + updateFactorIdentityMatrixRow0 = 1 + updateFactorIdentityMatrixRow1 = 1 << 32 +) + +func updateFactorsDecompose(c int64) (int64, int64) { + c += updateFactorsConversionBias + const low32BitsFilter int64 = 0xFFFFFFFF + f := c&low32BitsFilter - 0x7FFFFFFF + g := c>>32&low32BitsFilter - 0x7FFFFFFF + return f, g +} + +// negL negates in place [x | xHi] and return the new most significant word xHi +func negL(x *Element, xHi uint64) uint64 { + var b uint64 + + x[0], b = bits.Sub64(0, x[0], 0) + x[1], b = bits.Sub64(0, x[1], b) + x[2], b = bits.Sub64(0, x[2], b) + x[3], b = bits.Sub64(0, x[3], b) + xHi, _ = bits.Sub64(0, xHi, b) + + return xHi +} + +// mulWNonModular multiplies by one word in non-montgomery, without reducing +func (z *Element) mulWNonModular(x *Element, y int64) uint64 { + + // w := abs(y) + m := y >> 63 + w := uint64((y ^ m) - m) + + var c uint64 + c, z[0] = bits.Mul64(x[0], w) + c, z[1] = madd1(x[1], w, c) + c, z[2] = madd1(x[2], w, c) + c, z[3] = madd1(x[3], w, c) + + if y < 0 { + c = negL(z, c) + } + + return c +} + +// linearCombNonModular computes a linear combination without modular reduction +func (z *Element) linearCombNonModular(x *Element, xC int64, y *Element, yC int64) uint64 { + var yTimes Element + + yHi := yTimes.mulWNonModular(y, yC) + xHi := z.mulWNonModular(x, xC) + + var carry uint64 + z[0], carry = bits.Add64(z[0], yTimes[0], 0) + z[1], carry = bits.Add64(z[1], yTimes[1], carry) + z[2], carry = bits.Add64(z[2], yTimes[2], carry) + z[3], carry = bits.Add64(z[3], yTimes[3], carry) + + yHi, _ = bits.Add64(xHi, yHi, carry) + + return yHi +} 
diff --git a/ecc/grumpkin/fp/element_amd64.go b/ecc/grumpkin/fp/element_amd64.go new file mode 100644 index 0000000000..bb861d49b3 --- /dev/null +++ b/ecc/grumpkin/fp/element_amd64.go @@ -0,0 +1,59 @@ +//go:build !purego + +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import ( + _ "github.com/consensys/gnark-crypto/field/asm/element_4w" +) + +//go:noescape +func MulBy3(x *Element) + +//go:noescape +func MulBy5(x *Element) + +//go:noescape +func MulBy13(x *Element) + +//go:noescape +func mul(res, x, y *Element) + +//go:noescape +func fromMont(res *Element) + +//go:noescape +func reduce(res *Element) + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/grumpkin/fp/element_amd64.s b/ecc/grumpkin/fp/element_amd64.s new file mode 100644 index 0000000000..b45615aa36 --- /dev/null +++ b/ecc/grumpkin/fp/element_amd64.s @@ -0,0 +1,10 @@ +//go:build !purego + +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 14652627197992229521 +#include "../../../field/asm/element_4w/element_4w_amd64.s" + diff --git a/ecc/grumpkin/fp/element_arm64.go b/ecc/grumpkin/fp/element_arm64.go new file mode 100644 index 0000000000..6a41dba9de --- /dev/null +++ b/ecc/grumpkin/fp/element_arm64.go @@ -0,0 +1,70 @@ +//go:build !purego + +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import ( + _ "github.com/consensys/gnark-crypto/field/asm/element_4w" +) + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 17868810749992763324, + 5924006745939515753, + 769406925088786241, + 2691790815622165739, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +//go:noescape +func reduce(res *Element) diff --git a/ecc/grumpkin/fp/element_arm64.s b/ecc/grumpkin/fp/element_arm64.s new file mode 100644 index 0000000000..c8df07e345 --- /dev/null +++ b/ecc/grumpkin/fp/element_arm64.s @@ -0,0 +1,10 @@ +//go:build !purego + +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. 
See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 1501560133179981797 +#include "../../../field/asm/element_4w/element_4w_arm64.s" + diff --git a/ecc/grumpkin/fp/element_exp.go b/ecc/grumpkin/fp/element_exp.go new file mode 100644 index 0000000000..0d26383385 --- /dev/null +++ b/ecc/grumpkin/fp/element_exp.go @@ -0,0 +1,808 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +// expBySqrtExp is equivalent to z.Exp(x, 183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f) +// +// uses github.com/mmcloughlin/addchain v0.4.0 to generate a shorter addition chain +func (z *Element) expBySqrtExp(x Element) *Element { + // addition chain: + // + // _10 = 2*1 + // _11 = 1 + _10 + // _101 = _10 + _11 + // _111 = _10 + _101 + // _1001 = _10 + _111 + // _1011 = _10 + _1001 + // _1101 = _10 + _1011 + // _1111 = _10 + _1101 + // _11000 = _1001 + _1111 + // _11111 = _111 + _11000 + // i26 = ((_11000 << 4 + _11) << 3 + 1) << 7 + // i36 = ((_1001 + i26) << 2 + _11) << 5 + _111 + // i53 = (2*(i36 << 6 + _1011) + 1) << 8 + // i64 = (2*(_1001 + i53) + 1) << 7 + _1101 + // i84 = ((i64 << 10 + _101) << 6 + _1101) << 2 + // i100 = ((_11 + i84) << 7 + _101) << 6 + 1 + // i117 = ((i100 << 7 + _1011) << 5 + _1101) << 3 + // i137 = ((_101 + i117) << 8 + _11) << 9 + _101 + // i153 = ((i137 << 3 + _11) << 8 + _1011) << 3 + // i168 = ((_101 + i153) << 5 + _101) << 7 + _11 + // i187 = ((i168 << 7 + _11111) << 2 + 1) << 8 + // i204 = ((_1001 + i187) << 8 + _1111) << 6 + _1101 + // i215 = 2*((i204 << 2 + _11) << 6 + _1011) + // i232 = ((1 + i215) << 8 + _1001) << 6 + _101 + // i257 = ((i232 << 9 + _11111) << 9 + _11111) << 5 + // return ((_1011 + i257) << 3 + 1) << 7 + _11111 + // + // Operations: 221 squares 49 multiplies + + // 
Allocate Temporaries. + var ( + t0 = new(Element) + t1 = new(Element) + t2 = new(Element) + t3 = new(Element) + t4 = new(Element) + t5 = new(Element) + t6 = new(Element) + t7 = new(Element) + ) + + // var t0,t1,t2,t3,t4,t5,t6,t7 Element + // Step 1: z = x^0x2 + z.Square(&x) + + // Step 2: t3 = x^0x3 + t3.Mul(&x, z) + + // Step 3: t1 = x^0x5 + t1.Mul(z, t3) + + // Step 4: t6 = x^0x7 + t6.Mul(z, t1) + + // Step 5: t2 = x^0x9 + t2.Mul(z, t6) + + // Step 6: t0 = x^0xb + t0.Mul(z, t2) + + // Step 7: t4 = x^0xd + t4.Mul(z, t0) + + // Step 8: t5 = x^0xf + t5.Mul(z, t4) + + // Step 9: t7 = x^0x18 + t7.Mul(t2, t5) + + // Step 10: z = x^0x1f + z.Mul(t6, t7) + + // Step 14: t7 = x^0x180 + for s := 0; s < 4; s++ { + t7.Square(t7) + } + + // Step 15: t7 = x^0x183 + t7.Mul(t3, t7) + + // Step 18: t7 = x^0xc18 + for s := 0; s < 3; s++ { + t7.Square(t7) + } + + // Step 19: t7 = x^0xc19 + t7.Mul(&x, t7) + + // Step 26: t7 = x^0x60c80 + for s := 0; s < 7; s++ { + t7.Square(t7) + } + + // Step 27: t7 = x^0x60c89 + t7.Mul(t2, t7) + + // Step 29: t7 = x^0x183224 + for s := 0; s < 2; s++ { + t7.Square(t7) + } + + // Step 30: t7 = x^0x183227 + t7.Mul(t3, t7) + + // Step 35: t7 = x^0x30644e0 + for s := 0; s < 5; s++ { + t7.Square(t7) + } + + // Step 36: t6 = x^0x30644e7 + t6.Mul(t6, t7) + + // Step 42: t6 = x^0xc19139c0 + for s := 0; s < 6; s++ { + t6.Square(t6) + } + + // Step 43: t6 = x^0xc19139cb + t6.Mul(t0, t6) + + // Step 44: t6 = x^0x183227396 + t6.Square(t6) + + // Step 45: t6 = x^0x183227397 + t6.Mul(&x, t6) + + // Step 53: t6 = x^0x18322739700 + for s := 0; s < 8; s++ { + t6.Square(t6) + } + + // Step 54: t6 = x^0x18322739709 + t6.Mul(t2, t6) + + // Step 55: t6 = x^0x30644e72e12 + t6.Square(t6) + + // Step 56: t6 = x^0x30644e72e13 + t6.Mul(&x, t6) + + // Step 63: t6 = x^0x1832273970980 + for s := 0; s < 7; s++ { + t6.Square(t6) + } + + // Step 64: t6 = x^0x183227397098d + t6.Mul(t4, t6) + + // Step 74: t6 = x^0x60c89ce5c263400 + for s := 0; s < 10; s++ { + t6.Square(t6) + } + + 
// Step 75: t6 = x^0x60c89ce5c263405 + t6.Mul(t1, t6) + + // Step 81: t6 = x^0x183227397098d0140 + for s := 0; s < 6; s++ { + t6.Square(t6) + } + + // Step 82: t6 = x^0x183227397098d014d + t6.Mul(t4, t6) + + // Step 84: t6 = x^0x60c89ce5c26340534 + for s := 0; s < 2; s++ { + t6.Square(t6) + } + + // Step 85: t6 = x^0x60c89ce5c26340537 + t6.Mul(t3, t6) + + // Step 92: t6 = x^0x30644e72e131a029b80 + for s := 0; s < 7; s++ { + t6.Square(t6) + } + + // Step 93: t6 = x^0x30644e72e131a029b85 + t6.Mul(t1, t6) + + // Step 99: t6 = x^0xc19139cb84c680a6e140 + for s := 0; s < 6; s++ { + t6.Square(t6) + } + + // Step 100: t6 = x^0xc19139cb84c680a6e141 + t6.Mul(&x, t6) + + // Step 107: t6 = x^0x60c89ce5c263405370a080 + for s := 0; s < 7; s++ { + t6.Square(t6) + } + + // Step 108: t6 = x^0x60c89ce5c263405370a08b + t6.Mul(t0, t6) + + // Step 113: t6 = x^0xc19139cb84c680a6e141160 + for s := 0; s < 5; s++ { + t6.Square(t6) + } + + // Step 114: t6 = x^0xc19139cb84c680a6e14116d + t6.Mul(t4, t6) + + // Step 117: t6 = x^0x60c89ce5c263405370a08b68 + for s := 0; s < 3; s++ { + t6.Square(t6) + } + + // Step 118: t6 = x^0x60c89ce5c263405370a08b6d + t6.Mul(t1, t6) + + // Step 126: t6 = x^0x60c89ce5c263405370a08b6d00 + for s := 0; s < 8; s++ { + t6.Square(t6) + } + + // Step 127: t6 = x^0x60c89ce5c263405370a08b6d03 + t6.Mul(t3, t6) + + // Step 136: t6 = x^0xc19139cb84c680a6e14116da0600 + for s := 0; s < 9; s++ { + t6.Square(t6) + } + + // Step 137: t6 = x^0xc19139cb84c680a6e14116da0605 + t6.Mul(t1, t6) + + // Step 140: t6 = x^0x60c89ce5c263405370a08b6d03028 + for s := 0; s < 3; s++ { + t6.Square(t6) + } + + // Step 141: t6 = x^0x60c89ce5c263405370a08b6d0302b + t6.Mul(t3, t6) + + // Step 149: t6 = x^0x60c89ce5c263405370a08b6d0302b00 + for s := 0; s < 8; s++ { + t6.Square(t6) + } + + // Step 150: t6 = x^0x60c89ce5c263405370a08b6d0302b0b + t6.Mul(t0, t6) + + // Step 153: t6 = x^0x30644e72e131a029b85045b681815858 + for s := 0; s < 3; s++ { + t6.Square(t6) + } + + // Step 154: t6 = 
x^0x30644e72e131a029b85045b68181585d + t6.Mul(t1, t6) + + // Step 159: t6 = x^0x60c89ce5c263405370a08b6d0302b0ba0 + for s := 0; s < 5; s++ { + t6.Square(t6) + } + + // Step 160: t6 = x^0x60c89ce5c263405370a08b6d0302b0ba5 + t6.Mul(t1, t6) + + // Step 167: t6 = x^0x30644e72e131a029b85045b68181585d280 + for s := 0; s < 7; s++ { + t6.Square(t6) + } + + // Step 168: t6 = x^0x30644e72e131a029b85045b68181585d283 + t6.Mul(t3, t6) + + // Step 175: t6 = x^0x183227397098d014dc2822db40c0ac2e94180 + for s := 0; s < 7; s++ { + t6.Square(t6) + } + + // Step 176: t6 = x^0x183227397098d014dc2822db40c0ac2e9419f + t6.Mul(z, t6) + + // Step 178: t6 = x^0x60c89ce5c263405370a08b6d0302b0ba5067c + for s := 0; s < 2; s++ { + t6.Square(t6) + } + + // Step 179: t6 = x^0x60c89ce5c263405370a08b6d0302b0ba5067d + t6.Mul(&x, t6) + + // Step 187: t6 = x^0x60c89ce5c263405370a08b6d0302b0ba5067d00 + for s := 0; s < 8; s++ { + t6.Square(t6) + } + + // Step 188: t6 = x^0x60c89ce5c263405370a08b6d0302b0ba5067d09 + t6.Mul(t2, t6) + + // Step 196: t6 = x^0x60c89ce5c263405370a08b6d0302b0ba5067d0900 + for s := 0; s < 8; s++ { + t6.Square(t6) + } + + // Step 197: t5 = x^0x60c89ce5c263405370a08b6d0302b0ba5067d090f + t5.Mul(t5, t6) + + // Step 203: t5 = x^0x183227397098d014dc2822db40c0ac2e9419f4243c0 + for s := 0; s < 6; s++ { + t5.Square(t5) + } + + // Step 204: t4 = x^0x183227397098d014dc2822db40c0ac2e9419f4243cd + t4.Mul(t4, t5) + + // Step 206: t4 = x^0x60c89ce5c263405370a08b6d0302b0ba5067d090f34 + for s := 0; s < 2; s++ { + t4.Square(t4) + } + + // Step 207: t3 = x^0x60c89ce5c263405370a08b6d0302b0ba5067d090f37 + t3.Mul(t3, t4) + + // Step 213: t3 = x^0x183227397098d014dc2822db40c0ac2e9419f4243cdc0 + for s := 0; s < 6; s++ { + t3.Square(t3) + } + + // Step 214: t3 = x^0x183227397098d014dc2822db40c0ac2e9419f4243cdcb + t3.Mul(t0, t3) + + // Step 215: t3 = x^0x30644e72e131a029b85045b68181585d2833e84879b96 + t3.Square(t3) + + // Step 216: t3 = x^0x30644e72e131a029b85045b68181585d2833e84879b97 + t3.Mul(&x, t3) + 
+ // Step 224: t3 = x^0x30644e72e131a029b85045b68181585d2833e84879b9700 + for s := 0; s < 8; s++ { + t3.Square(t3) + } + + // Step 225: t2 = x^0x30644e72e131a029b85045b68181585d2833e84879b9709 + t2.Mul(t2, t3) + + // Step 231: t2 = x^0xc19139cb84c680a6e14116da06056174a0cfa121e6e5c240 + for s := 0; s < 6; s++ { + t2.Square(t2) + } + + // Step 232: t1 = x^0xc19139cb84c680a6e14116da06056174a0cfa121e6e5c245 + t1.Mul(t1, t2) + + // Step 241: t1 = x^0x183227397098d014dc2822db40c0ac2e9419f4243cdcb848a00 + for s := 0; s < 9; s++ { + t1.Square(t1) + } + + // Step 242: t1 = x^0x183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f + t1.Mul(z, t1) + + // Step 251: t1 = x^0x30644e72e131a029b85045b68181585d2833e84879b9709143e00 + for s := 0; s < 9; s++ { + t1.Square(t1) + } + + // Step 252: t1 = x^0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f + t1.Mul(z, t1) + + // Step 257: t1 = x^0x60c89ce5c263405370a08b6d0302b0ba5067d090f372e12287c3e0 + for s := 0; s < 5; s++ { + t1.Square(t1) + } + + // Step 258: t0 = x^0x60c89ce5c263405370a08b6d0302b0ba5067d090f372e12287c3eb + t0.Mul(t0, t1) + + // Step 261: t0 = x^0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f58 + for s := 0; s < 3; s++ { + t0.Square(t0) + } + + // Step 262: t0 = x^0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f59 + t0.Mul(&x, t0) + + // Step 269: t0 = x^0x183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac80 + for s := 0; s < 7; s++ { + t0.Square(t0) + } + + // Step 270: z = x^0x183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f + z.Mul(z, t0) + + return z +} + +// expByLegendreExp is equivalent to z.Exp(x, 183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f8000000) +// +// uses github.com/mmcloughlin/addchain v0.4.0 to generate a shorter addition chain +func (z *Element) expByLegendreExp(x Element) *Element { + // addition chain: + // + // _10 = 2*1 + // _11 = 1 + _10 + // _101 = _10 + _11 + // _111 = _10 + _101 + // _1001 = _10 + _111 + // _1011 = _10 + _1001 + // _1101 = _10 + 
_1011 + // _1111 = _10 + _1101 + // _11000 = _1001 + _1111 + // _11111 = _111 + _11000 + // i26 = ((_11000 << 4 + _11) << 3 + 1) << 7 + // i36 = ((_1001 + i26) << 2 + _11) << 5 + _111 + // i53 = (2*(i36 << 6 + _1011) + 1) << 8 + // i64 = (2*(_1001 + i53) + 1) << 7 + _1101 + // i84 = ((i64 << 10 + _101) << 6 + _1101) << 2 + // i100 = ((_11 + i84) << 7 + _101) << 6 + 1 + // i117 = ((i100 << 7 + _1011) << 5 + _1101) << 3 + // i137 = ((_101 + i117) << 8 + _11) << 9 + _101 + // i153 = ((i137 << 3 + _11) << 8 + _1011) << 3 + // i168 = ((_101 + i153) << 5 + _101) << 7 + _11 + // i187 = ((i168 << 7 + _11111) << 2 + 1) << 8 + // i204 = ((_1001 + i187) << 8 + _1111) << 6 + _1101 + // i215 = 2*((i204 << 2 + _11) << 6 + _1011) + // i232 = ((1 + i215) << 8 + _1001) << 6 + _101 + // i257 = ((i232 << 9 + _11111) << 9 + _11111) << 5 + // i270 = ((_1011 + i257) << 3 + 1) << 7 + _11111 + // return (2*i270 + 1) << 27 + // + // Operations: 249 squares 50 multiplies + + // Allocate Temporaries. + var ( + t0 = new(Element) + t1 = new(Element) + t2 = new(Element) + t3 = new(Element) + t4 = new(Element) + t5 = new(Element) + t6 = new(Element) + t7 = new(Element) + ) + + // var t0,t1,t2,t3,t4,t5,t6,t7 Element + // Step 1: z = x^0x2 + z.Square(&x) + + // Step 2: t3 = x^0x3 + t3.Mul(&x, z) + + // Step 3: t1 = x^0x5 + t1.Mul(z, t3) + + // Step 4: t6 = x^0x7 + t6.Mul(z, t1) + + // Step 5: t2 = x^0x9 + t2.Mul(z, t6) + + // Step 6: t0 = x^0xb + t0.Mul(z, t2) + + // Step 7: t4 = x^0xd + t4.Mul(z, t0) + + // Step 8: t5 = x^0xf + t5.Mul(z, t4) + + // Step 9: t7 = x^0x18 + t7.Mul(t2, t5) + + // Step 10: z = x^0x1f + z.Mul(t6, t7) + + // Step 14: t7 = x^0x180 + for s := 0; s < 4; s++ { + t7.Square(t7) + } + + // Step 15: t7 = x^0x183 + t7.Mul(t3, t7) + + // Step 18: t7 = x^0xc18 + for s := 0; s < 3; s++ { + t7.Square(t7) + } + + // Step 19: t7 = x^0xc19 + t7.Mul(&x, t7) + + // Step 26: t7 = x^0x60c80 + for s := 0; s < 7; s++ { + t7.Square(t7) + } + + // Step 27: t7 = x^0x60c89 + t7.Mul(t2, t7) + + // 
Step 29: t7 = x^0x183224 + for s := 0; s < 2; s++ { + t7.Square(t7) + } + + // Step 30: t7 = x^0x183227 + t7.Mul(t3, t7) + + // Step 35: t7 = x^0x30644e0 + for s := 0; s < 5; s++ { + t7.Square(t7) + } + + // Step 36: t6 = x^0x30644e7 + t6.Mul(t6, t7) + + // Step 42: t6 = x^0xc19139c0 + for s := 0; s < 6; s++ { + t6.Square(t6) + } + + // Step 43: t6 = x^0xc19139cb + t6.Mul(t0, t6) + + // Step 44: t6 = x^0x183227396 + t6.Square(t6) + + // Step 45: t6 = x^0x183227397 + t6.Mul(&x, t6) + + // Step 53: t6 = x^0x18322739700 + for s := 0; s < 8; s++ { + t6.Square(t6) + } + + // Step 54: t6 = x^0x18322739709 + t6.Mul(t2, t6) + + // Step 55: t6 = x^0x30644e72e12 + t6.Square(t6) + + // Step 56: t6 = x^0x30644e72e13 + t6.Mul(&x, t6) + + // Step 63: t6 = x^0x1832273970980 + for s := 0; s < 7; s++ { + t6.Square(t6) + } + + // Step 64: t6 = x^0x183227397098d + t6.Mul(t4, t6) + + // Step 74: t6 = x^0x60c89ce5c263400 + for s := 0; s < 10; s++ { + t6.Square(t6) + } + + // Step 75: t6 = x^0x60c89ce5c263405 + t6.Mul(t1, t6) + + // Step 81: t6 = x^0x183227397098d0140 + for s := 0; s < 6; s++ { + t6.Square(t6) + } + + // Step 82: t6 = x^0x183227397098d014d + t6.Mul(t4, t6) + + // Step 84: t6 = x^0x60c89ce5c26340534 + for s := 0; s < 2; s++ { + t6.Square(t6) + } + + // Step 85: t6 = x^0x60c89ce5c26340537 + t6.Mul(t3, t6) + + // Step 92: t6 = x^0x30644e72e131a029b80 + for s := 0; s < 7; s++ { + t6.Square(t6) + } + + // Step 93: t6 = x^0x30644e72e131a029b85 + t6.Mul(t1, t6) + + // Step 99: t6 = x^0xc19139cb84c680a6e140 + for s := 0; s < 6; s++ { + t6.Square(t6) + } + + // Step 100: t6 = x^0xc19139cb84c680a6e141 + t6.Mul(&x, t6) + + // Step 107: t6 = x^0x60c89ce5c263405370a080 + for s := 0; s < 7; s++ { + t6.Square(t6) + } + + // Step 108: t6 = x^0x60c89ce5c263405370a08b + t6.Mul(t0, t6) + + // Step 113: t6 = x^0xc19139cb84c680a6e141160 + for s := 0; s < 5; s++ { + t6.Square(t6) + } + + // Step 114: t6 = x^0xc19139cb84c680a6e14116d + t6.Mul(t4, t6) + + // Step 117: t6 = 
x^0x60c89ce5c263405370a08b68 + for s := 0; s < 3; s++ { + t6.Square(t6) + } + + // Step 118: t6 = x^0x60c89ce5c263405370a08b6d + t6.Mul(t1, t6) + + // Step 126: t6 = x^0x60c89ce5c263405370a08b6d00 + for s := 0; s < 8; s++ { + t6.Square(t6) + } + + // Step 127: t6 = x^0x60c89ce5c263405370a08b6d03 + t6.Mul(t3, t6) + + // Step 136: t6 = x^0xc19139cb84c680a6e14116da0600 + for s := 0; s < 9; s++ { + t6.Square(t6) + } + + // Step 137: t6 = x^0xc19139cb84c680a6e14116da0605 + t6.Mul(t1, t6) + + // Step 140: t6 = x^0x60c89ce5c263405370a08b6d03028 + for s := 0; s < 3; s++ { + t6.Square(t6) + } + + // Step 141: t6 = x^0x60c89ce5c263405370a08b6d0302b + t6.Mul(t3, t6) + + // Step 149: t6 = x^0x60c89ce5c263405370a08b6d0302b00 + for s := 0; s < 8; s++ { + t6.Square(t6) + } + + // Step 150: t6 = x^0x60c89ce5c263405370a08b6d0302b0b + t6.Mul(t0, t6) + + // Step 153: t6 = x^0x30644e72e131a029b85045b681815858 + for s := 0; s < 3; s++ { + t6.Square(t6) + } + + // Step 154: t6 = x^0x30644e72e131a029b85045b68181585d + t6.Mul(t1, t6) + + // Step 159: t6 = x^0x60c89ce5c263405370a08b6d0302b0ba0 + for s := 0; s < 5; s++ { + t6.Square(t6) + } + + // Step 160: t6 = x^0x60c89ce5c263405370a08b6d0302b0ba5 + t6.Mul(t1, t6) + + // Step 167: t6 = x^0x30644e72e131a029b85045b68181585d280 + for s := 0; s < 7; s++ { + t6.Square(t6) + } + + // Step 168: t6 = x^0x30644e72e131a029b85045b68181585d283 + t6.Mul(t3, t6) + + // Step 175: t6 = x^0x183227397098d014dc2822db40c0ac2e94180 + for s := 0; s < 7; s++ { + t6.Square(t6) + } + + // Step 176: t6 = x^0x183227397098d014dc2822db40c0ac2e9419f + t6.Mul(z, t6) + + // Step 178: t6 = x^0x60c89ce5c263405370a08b6d0302b0ba5067c + for s := 0; s < 2; s++ { + t6.Square(t6) + } + + // Step 179: t6 = x^0x60c89ce5c263405370a08b6d0302b0ba5067d + t6.Mul(&x, t6) + + // Step 187: t6 = x^0x60c89ce5c263405370a08b6d0302b0ba5067d00 + for s := 0; s < 8; s++ { + t6.Square(t6) + } + + // Step 188: t6 = x^0x60c89ce5c263405370a08b6d0302b0ba5067d09 + t6.Mul(t2, t6) + + // Step 196: t6 = 
x^0x60c89ce5c263405370a08b6d0302b0ba5067d0900 + for s := 0; s < 8; s++ { + t6.Square(t6) + } + + // Step 197: t5 = x^0x60c89ce5c263405370a08b6d0302b0ba5067d090f + t5.Mul(t5, t6) + + // Step 203: t5 = x^0x183227397098d014dc2822db40c0ac2e9419f4243c0 + for s := 0; s < 6; s++ { + t5.Square(t5) + } + + // Step 204: t4 = x^0x183227397098d014dc2822db40c0ac2e9419f4243cd + t4.Mul(t4, t5) + + // Step 206: t4 = x^0x60c89ce5c263405370a08b6d0302b0ba5067d090f34 + for s := 0; s < 2; s++ { + t4.Square(t4) + } + + // Step 207: t3 = x^0x60c89ce5c263405370a08b6d0302b0ba5067d090f37 + t3.Mul(t3, t4) + + // Step 213: t3 = x^0x183227397098d014dc2822db40c0ac2e9419f4243cdc0 + for s := 0; s < 6; s++ { + t3.Square(t3) + } + + // Step 214: t3 = x^0x183227397098d014dc2822db40c0ac2e9419f4243cdcb + t3.Mul(t0, t3) + + // Step 215: t3 = x^0x30644e72e131a029b85045b68181585d2833e84879b96 + t3.Square(t3) + + // Step 216: t3 = x^0x30644e72e131a029b85045b68181585d2833e84879b97 + t3.Mul(&x, t3) + + // Step 224: t3 = x^0x30644e72e131a029b85045b68181585d2833e84879b9700 + for s := 0; s < 8; s++ { + t3.Square(t3) + } + + // Step 225: t2 = x^0x30644e72e131a029b85045b68181585d2833e84879b9709 + t2.Mul(t2, t3) + + // Step 231: t2 = x^0xc19139cb84c680a6e14116da06056174a0cfa121e6e5c240 + for s := 0; s < 6; s++ { + t2.Square(t2) + } + + // Step 232: t1 = x^0xc19139cb84c680a6e14116da06056174a0cfa121e6e5c245 + t1.Mul(t1, t2) + + // Step 241: t1 = x^0x183227397098d014dc2822db40c0ac2e9419f4243cdcb848a00 + for s := 0; s < 9; s++ { + t1.Square(t1) + } + + // Step 242: t1 = x^0x183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f + t1.Mul(z, t1) + + // Step 251: t1 = x^0x30644e72e131a029b85045b68181585d2833e84879b9709143e00 + for s := 0; s < 9; s++ { + t1.Square(t1) + } + + // Step 252: t1 = x^0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f + t1.Mul(z, t1) + + // Step 257: t1 = x^0x60c89ce5c263405370a08b6d0302b0ba5067d090f372e12287c3e0 + for s := 0; s < 5; s++ { + t1.Square(t1) + } + + // Step 258: t0 = 
x^0x60c89ce5c263405370a08b6d0302b0ba5067d090f372e12287c3eb + t0.Mul(t0, t1) + + // Step 261: t0 = x^0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f58 + for s := 0; s < 3; s++ { + t0.Square(t0) + } + + // Step 262: t0 = x^0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f59 + t0.Mul(&x, t0) + + // Step 269: t0 = x^0x183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac80 + for s := 0; s < 7; s++ { + t0.Square(t0) + } + + // Step 270: z = x^0x183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f + z.Mul(z, t0) + + // Step 271: z = x^0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593e + z.Square(z) + + // Step 272: z = x^0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f + z.Mul(&x, z) + + // Step 299: z = x^0x183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f8000000 + for s := 0; s < 27; s++ { + z.Square(z) + } + + return z +} diff --git a/ecc/grumpkin/fp/element_purego.go b/ecc/grumpkin/fp/element_purego.go new file mode 100644 index 0000000000..247946286b --- /dev/null +++ b/ecc/grumpkin/fp/element_purego.go @@ -0,0 +1,391 @@ +//go:build purego || (!amd64 && !arm64) + +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 17868810749992763324, + 5924006745939515753, + 769406925088786241, + 2691790815622165739, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 
= bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = 
bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + 
} + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + 
u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/ecc/grumpkin/fp/element_test.go b/ecc/grumpkin/fp/element_test.go new file mode 100644 index 0000000000..b80a36d933 --- /dev/null +++ b/ecc/grumpkin/fp/element_test.go @@ -0,0 +1,2885 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import ( + "crypto/rand" + "encoding/json" + "fmt" + "math/big" + "math/bits" + + mrand "math/rand" + + "testing" + + "github.com/leanovate/gopter" + ggen "github.com/leanovate/gopter/gen" + "github.com/leanovate/gopter/prop" + + "github.com/stretchr/testify/require" +) + +// ------------------------------------------------------------------------------------------------- +// benchmarks +// most benchmarks are rudimentary and should sample a large number of random inputs +// or be run multiple times to ensure it didn't measure the fastest path of the function + +var benchResElement Element + +func BenchmarkElementSelect(b *testing.B) { + var x, y Element + x.SetRandom() + y.SetRandom() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Select(i%3, &x, &y) + } +} + +func BenchmarkElementSetRandom(b *testing.B) { + var x Element + x.SetRandom() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = x.SetRandom() + } +} + +func BenchmarkElementSetBytes(b *testing.B) { + var x Element + x.SetRandom() + bb := x.Bytes() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchResElement.SetBytes(bb[:]) + } + +} + +func BenchmarkElementMulByConstants(b *testing.B) { + b.Run("mulBy3", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy3(&benchResElement) + } + }) + b.Run("mulBy5", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy5(&benchResElement) + } + }) + b.Run("mulBy13", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy13(&benchResElement) + } + }) +} + +func BenchmarkElementInverse(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchResElement.Inverse(&x) + } + +} + +func BenchmarkElementButterfly(b *testing.B) { + var x Element + 
x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + Butterfly(&x, &benchResElement) + } +} + +func BenchmarkElementExp(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b1, _ := rand.Int(rand.Reader, Modulus()) + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Exp(x, b1) + } +} + +func BenchmarkElementDouble(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Double(&benchResElement) + } +} + +func BenchmarkElementAdd(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Add(&x, &benchResElement) + } +} + +func BenchmarkElementSub(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Sub(&x, &benchResElement) + } +} + +func BenchmarkElementNeg(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Neg(&benchResElement) + } +} + +func BenchmarkElementDiv(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Div(&x, &benchResElement) + } +} + +func BenchmarkElementFromMont(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.fromMont() + } +} + +func BenchmarkElementSquare(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Square(&benchResElement) + } +} + +func BenchmarkElementSqrt(b *testing.B) { + var a Element + a.SetUint64(4) + a.Neg(&a) + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Sqrt(&a) + } +} + +func BenchmarkElementMul(b *testing.B) { + x := Element{ + 1997599621687373223, + 6052339484930628067, + 10108755138030829701, + 150537098327114917, + } + benchResElement.SetOne() 
+ b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Mul(&benchResElement, &x) + } +} + +func BenchmarkElementCmp(b *testing.B) { + x := Element{ + 1997599621687373223, + 6052339484930628067, + 10108755138030829701, + 150537098327114917, + } + benchResElement = x + benchResElement[0] = 0 + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Cmp(&x) + } +} + +func TestElementCmp(t *testing.T) { + var x, y Element + + if x.Cmp(&y) != 0 { + t.Fatal("x == y") + } + + one := One() + y.Sub(&y, &one) + + if x.Cmp(&y) != -1 { + t.Fatal("x < y") + } + if y.Cmp(&x) != 1 { + t.Fatal("x < y") + } + + x = y + if x.Cmp(&y) != 0 { + t.Fatal("x == y") + } + + x.Sub(&x, &one) + if x.Cmp(&y) != -1 { + t.Fatal("x < y") + } + if y.Cmp(&x) != 1 { + t.Fatal("x < y") + } +} +func TestElementIsRandom(t *testing.T) { + for i := 0; i < 50; i++ { + var x, y Element + x.SetRandom() + y.SetRandom() + if x.Equal(&y) { + t.Fatal("2 random numbers are unlikely to be equal") + } + } +} + +func TestElementIsUint64(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + properties.Property("reduce should output a result smaller than modulus", prop.ForAll( + func(v uint64) bool { + var e Element + e.SetUint64(v) + + if !e.IsUint64() { + return false + } + + return e.Uint64() == v + }, + ggen.UInt64(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementNegZero(t *testing.T) { + var a, b Element + b.SetZero() + for a.IsZero() { + a.SetRandom() + } + a.Neg(&b) + if !a.IsZero() { + t.Fatal("neg(0) != 0") + } +} + +// ------------------------------------------------------------------------------------------------- +// Gopter tests +// most of them are generated with a template + +const ( + nbFuzzShort = 200 + nbFuzz = 1000 +) + +// special values to 
be used in tests +var staticTestValues []Element + +func init() { + staticTestValues = append(staticTestValues, Element{}) // zero + staticTestValues = append(staticTestValues, One()) // one + staticTestValues = append(staticTestValues, rSquare) // r² + var e, one Element + one.SetOne() + e.Sub(&qElement, &one) + staticTestValues = append(staticTestValues, e) // q - 1 + e.Double(&one) + staticTestValues = append(staticTestValues, e) // 2 + + { + a := qElement + a[0]-- + staticTestValues = append(staticTestValues, a) + } + staticTestValues = append(staticTestValues, Element{0}) + staticTestValues = append(staticTestValues, Element{0, 0}) + staticTestValues = append(staticTestValues, Element{1}) + staticTestValues = append(staticTestValues, Element{0, 1}) + staticTestValues = append(staticTestValues, Element{2}) + staticTestValues = append(staticTestValues, Element{0, 2}) + + { + a := qElement + a[3]-- + staticTestValues = append(staticTestValues, a) + } + { + a := qElement + a[3]-- + a[0]++ + staticTestValues = append(staticTestValues, a) + } + + { + a := qElement + a[3] = 0 + staticTestValues = append(staticTestValues, a) + } + +} + +func TestElementReduce(t *testing.T) { + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + s := testValues[i] + expected := s + reduce(&s) + _reduceGeneric(&expected) + if !s.Equal(&expected) { + t.Fatal("reduce failed: asm and generic impl don't match") + } + } + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := genFull() + + properties.Property("reduce should output a result smaller than modulus", prop.ForAll( + func(a Element) bool { + b := a + reduce(&a) + _reduceGeneric(&b) + return a.smallerThanModulus() && a.Equal(&b) + }, + genA, + )) + + 
properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementEqual(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("x.Equal(&y) iff x == y; likely false for random pairs", prop.ForAll( + func(a testPairElement, b testPairElement) bool { + return a.element.Equal(&b.element) == (a.element == b.element) + }, + genA, + genB, + )) + + properties.Property("x.Equal(&y) if x == y", prop.ForAll( + func(a testPairElement) bool { + b := a.element + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementBytes(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("SetBytes(Bytes()) should stay constant", prop.ForAll( + func(a testPairElement) bool { + var b Element + bytes := a.element.Bytes() + b.SetBytes(bytes[:]) + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementInverseExp(t *testing.T) { + // inverse must be equal to exp^-2 + exp := Modulus() + exp.Sub(exp, new(big.Int).SetUint64(2)) + + invMatchExp := func(a testPairElement) bool { + var b Element + b.Set(&a.element) + a.element.Inverse(&a.element) + b.Exp(b, exp) + + return a.element.Equal(&b) + } + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + properties := gopter.NewProperties(parameters) + genA := gen() + 
properties.Property("inv == exp^-2", prop.ForAll(invMatchExp, genA)) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + + parameters.MinSuccessfulTests = 1 + properties = gopter.NewProperties(parameters) + properties.Property("inv(0) == 0", prop.ForAll(invMatchExp, ggen.OneConstOf(testPairElement{}))) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func mulByConstant(z *Element, c uint8) { + var y Element + y.SetUint64(uint64(c)) + z.Mul(z, &y) +} + +func TestElementMulByConstants(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + implemented := []uint8{0, 1, 2, 3, 5, 13} + properties.Property("mulByConstant", prop.ForAll( + func(a testPairElement) bool { + for _, c := range implemented { + var constant Element + constant.SetUint64(uint64(c)) + + b := a.element + b.Mul(&b, &constant) + + aa := a.element + mulByConstant(&aa, c) + + if !aa.Equal(&b) { + return false + } + } + + return true + }, + genA, + )) + + properties.Property("MulBy3(x) == Mul(x, 3)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(3) + + b := a.element + b.Mul(&b, &constant) + + MulBy3(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("MulBy5(x) == Mul(x, 5)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(5) + + b := a.element + b.Mul(&b, &constant) + + MulBy5(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("MulBy13(x) == Mul(x, 13)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(13) + + b := a.element + b.Mul(&b, &constant) + + MulBy13(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, 
gopter.ConsoleReporter(false)) + +} + +func TestElementLegendre(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("legendre should output same result than big.Int.Jacobi", prop.ForAll( + func(a testPairElement) bool { + return a.element.Legendre() == big.Jacobi(&a.bigint, Modulus()) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementBitLen(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( + func(a testPairElement) bool { + return a.element.fromMont().BitLen() == a.bigint.BitLen() + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementButterflies(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("butterfly0 == a -b; a +b", prop.ForAll( + func(a, b testPairElement) bool { + a0, b0 := a.element, b.element + + _butterflyGeneric(&a.element, &b.element) + Butterfly(&a0, &b0) + + return a.element.Equal(&a0) && b.element.Equal(&b0) + }, + genA, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementLexicographicallyLargest(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + 
parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("element.Cmp should match LexicographicallyLargest output", prop.ForAll( + func(a testPairElement) bool { + var negA Element + negA.Neg(&a.element) + + cmpResult := a.element.Cmp(&negA) + lResult := a.element.LexicographicallyLargest() + + if lResult && cmpResult == 1 { + return true + } + if !lResult && cmpResult != 1 { + return true + } + return false + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementAdd(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Add: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Add(&a.element, &b.element) + a.element.Add(&a.element, &b.element) + b.element.Add(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Add: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Add(&a.element, &b.element) + + var d, e big.Int + d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Add(&a.element, &r) + d.Add(&a.bigint, &rb).Mod(&d, Modulus()) + + if 
c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Add: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Add(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Add(&a, &b) + d.Add(&aBig, &bBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Add failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSub(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Sub: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Sub(&a.element, &b.element) + a.element.Sub(&a.element, &b.element) + b.element.Sub(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Sub: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Sub(&a.element, &b.element) + + var d, e big.Int + d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // 
r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Sub(&a.element, &r) + d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Sub: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Sub(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Sub(&a, &b) + d.Sub(&aBig, &bBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Sub failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementMul(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Mul: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Mul(&a.element, &b.element) + a.element.Mul(&a.element, &b.element) + b.element.Mul(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Mul: operation result 
must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Mul(&a.element, &b.element) + + var d, e big.Int + d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Mul(&a.element, &r) + d.Mul(&a.bigint, &rb).Mod(&d, Modulus()) + + // checking generic impl against asm path + var cGeneric Element + _mulGeneric(&cGeneric, &a.element, &r) + if !cGeneric.Equal(&c) { + // need to give context to failing error. + return false + } + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Mul: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Mul(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + properties.Property("Mul: assembly implementation must be consistent with generic one", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + c.Mul(&a.element, &b.element) + _mulGeneric(&d, &a.element, &b.element) + return c.Equal(&d) + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Mul(&a, &b) + d.Mul(&aBig, &bBig).Mod(&d, Modulus()) + + // checking asm against generic impl + var cGeneric Element + _mulGeneric(&cGeneric, &a, &b) + if !cGeneric.Equal(&c) { + t.Fatal("Mul failed special 
test values: asm and generic impl don't match") + } + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Mul failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementDiv(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Div: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Div(&a.element, &b.element) + a.element.Div(&a.element, &b.element) + b.element.Div(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Div: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Div(&a.element, &b.element) + + var d, e big.Int + d.ModInverse(&b.bigint, Modulus()) + d.Mul(&d, &a.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Div(&a.element, &r) + d.ModInverse(&rb, Modulus()) + d.Mul(&d, &a.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Div: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Div(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest 
:= func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Div(&a, &b) + d.ModInverse(&bBig, Modulus()) + d.Mul(&d, &aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Div failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementExp(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Exp: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Exp(a.element, &b.bigint) + a.element.Exp(a.element, &b.bigint) + b.element.Exp(d, &b.bigint) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Exp: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Exp(a.element, &b.bigint) + + var d, e big.Int + d.Exp(&a.bigint, &b.bigint, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Exp(a.element, &rb) + d.Exp(&a.bigint, &rb, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false 
+ } + } + return true + }, + genA, + genB, + )) + + properties.Property("Exp: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Exp(a.element, &b.bigint) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Exp(a, &bBig) + d.Exp(&aBig, &bBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Exp failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSquare(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Square: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Square(&a.element) + a.element.Square(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Square: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Square(&a.element) + + var d, e big.Int + d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Square: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Square(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test 
special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Square(&a) + + var d, e big.Int + d.Mul(&aBig, &aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Square failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementInverse(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Inverse: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Inverse(&a.element) + a.element.Inverse(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Inverse: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Inverse(&a.element) + + var d, e big.Int + d.ModInverse(&a.bigint, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Inverse: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Inverse(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Inverse(&a) + + var d, e big.Int + d.ModInverse(&aBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Inverse failed special test values") + } + } + } + + properties.TestingRun(t, 
gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSqrt(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Sqrt: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + b := a.element + + b.Sqrt(&a.element) + a.element.Sqrt(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Sqrt: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Sqrt(&a.element) + + var d, e big.Int + d.ModSqrt(&a.bigint, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Sqrt: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Sqrt(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Sqrt(&a) + + var d, e big.Int + d.ModSqrt(&aBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Sqrt failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementDouble(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Double: having the receiver as operand should output the same result", prop.ForAll( + 
func(a testPairElement) bool { + + var b Element + + b.Double(&a.element) + a.element.Double(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Double: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + + var d, e big.Int + d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Double: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Double(&a) + + var d, e big.Int + d.Lsh(&aBig, 1).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Double failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementNeg(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Neg: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Neg(&a.element) + a.element.Neg(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Neg: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Neg(&a.element) + + var d, e big.Int + d.Neg(&a.bigint).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + 
properties.Property("Neg: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Neg(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Neg(&a) + + var d, e big.Int + d.Neg(&aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Neg failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementFixedExp(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + var ( + _bLegendreExponentElement *big.Int + _bSqrtExponentElement *big.Int + ) + + _bLegendreExponentElement, _ = new(big.Int).SetString("183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f8000000", 16) + const sqrtExponentElement = "183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f" + _bSqrtExponentElement, _ = new(big.Int).SetString(sqrtExponentElement, 16) + + genA := gen() + + properties.Property(fmt.Sprintf("expBySqrtExp must match Exp(%s)", sqrtExponentElement), prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.expBySqrtExp(c) + d.Exp(d, _bSqrtExponentElement) + return c.Equal(&d) + }, + genA, + )) + + properties.Property("expByLegendreExp must match Exp(183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f8000000)", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.expByLegendreExp(c) + d.Exp(d, _bLegendreExponentElement) + return c.Equal(&d) + }, + genA, + )) + + properties.TestingRun(t, 
gopter.ConsoleReporter(false)) +} + +func TestElementHalve(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + var twoInv Element + twoInv.SetUint64(2) + twoInv.Inverse(&twoInv) + + properties.Property("z.Halve must match z / 2", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.Halve() + d.Mul(&d, &twoInv) + return c.Equal(&d) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func combineSelectionArguments(c int64, z int8) int { + if z%3 == 0 { + return 0 + } + return int(c) +} + +func TestElementSelect(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := genFull() + genB := genFull() + genC := ggen.Int64() //the condition + genZ := ggen.Int8() //to make zeros artificially more likely + + properties.Property("Select: must select correctly", prop.ForAll( + func(a, b Element, cond int64, z int8) bool { + condC := combineSelectionArguments(cond, z) + + var c Element + c.Select(condC, &a, &b) + + if condC == 0 { + return c.Equal(&a) + } + return c.Equal(&b) + }, + genA, + genB, + genC, + genZ, + )) + + properties.Property("Select: having the receiver as operand should output the same result", prop.ForAll( + func(a, b Element, cond int64, z int8) bool { + condC := combineSelectionArguments(cond, z) + + var c, d Element + d.Set(&a) + c.Select(condC, &a, &b) + a.Select(condC, &a, &b) + b.Select(condC, &d, &b) + return a.Equal(&b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + genC, + genZ, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func 
TestElementSetInt64(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("z.SetInt64 must match z.SetString", prop.ForAll( + func(a testPairElement, v int64) bool { + c := a.element + d := a.element + + c.SetInt64(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, ggen.Int64(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementSetInterface(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genInt := ggen.Int + genInt8 := ggen.Int8 + genInt16 := ggen.Int16 + genInt32 := ggen.Int32 + genInt64 := ggen.Int64 + + genUint := ggen.UInt + genUint8 := ggen.UInt8 + genUint16 := ggen.UInt16 + genUint32 := ggen.UInt32 + genUint64 := ggen.UInt64 + + properties.Property("z.SetInterface must match z.SetString with int8", prop.ForAll( + func(a testPairElement, v int8) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt8(), + )) + + properties.Property("z.SetInterface must match z.SetString with int16", prop.ForAll( + func(a testPairElement, v int16) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt16(), + )) + + properties.Property("z.SetInterface must match z.SetString with int32", prop.ForAll( + func(a testPairElement, v int32) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt32(), + )) + 
+ properties.Property("z.SetInterface must match z.SetString with int64", prop.ForAll( + func(a testPairElement, v int64) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt64(), + )) + + properties.Property("z.SetInterface must match z.SetString with int", prop.ForAll( + func(a testPairElement, v int) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint8", prop.ForAll( + func(a testPairElement, v uint8) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint8(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint16", prop.ForAll( + func(a testPairElement, v uint16) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint16(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint32", prop.ForAll( + func(a testPairElement, v uint32) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint32(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint64", prop.ForAll( + func(a testPairElement, v uint64) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint64(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint", prop.ForAll( + func(a testPairElement, v uint) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + + { + 
assert := require.New(t) + var e Element + r, err := e.SetInterface(nil) + assert.Nil(r) + assert.Error(err) + + var ptE *Element + var ptB *big.Int + + r, err = e.SetInterface(ptE) + assert.Nil(r) + assert.Error(err) + ptE = new(Element).SetOne() + r, err = e.SetInterface(ptE) + assert.NoError(err) + assert.True(r.IsOne()) + + r, err = e.SetInterface(ptB) + assert.Nil(r) + assert.Error(err) + + } +} + +func TestElementNegativeExp(t *testing.T) { + t.Parallel() + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("x⁻ᵏ == 1/xᵏ", prop.ForAll( + func(a, b testPairElement) bool { + + var nb, d, e big.Int + nb.Neg(&b.bigint) + + var c Element + c.Exp(a.element, &nb) + + d.Exp(&a.bigint, &nb, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementNewElement(t *testing.T) { + assert := require.New(t) + + t.Parallel() + + e := NewElement(1) + assert.True(e.IsOne()) + + e = NewElement(0) + assert.True(e.IsZero()) +} + +func TestElementBatchInvert(t *testing.T) { + assert := require.New(t) + + t.Parallel() + + // ensure batchInvert([x]) == invert(x) + for i := int64(-1); i <= 2; i++ { + var e, eInv Element + e.SetInt64(i) + eInv.Inverse(&e) + + a := []Element{e} + aInv := BatchInvert(a) + + assert.True(aInv[0].Equal(&eInv), "batchInvert != invert") + + } + + // test x * x⁻¹ == 1 + tData := [][]int64{ + {-1, 1, 2, 3}, + {0, -1, 1, 2, 3, 0}, + {0, -1, 1, 0, 2, 3, 0}, + {-1, 1, 0, 2, 3}, + {0, 0, 1}, + {1, 0, 0}, + {0, 0, 0}, + } + + for _, t := range tData { + a := make([]Element, len(t)) + for i := 0; i < len(a); i++ { + a[i].SetInt64(t[i]) + } + + aInv := BatchInvert(a) + + assert.True(len(aInv) == len(a)) + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + 
assert.True(aInv[i].IsZero(), "0⁻¹ != 0") + } else { + assert.True(a[i].Mul(&a[i], &aInv[i]).IsOne(), "x * x⁻¹ != 1") + } + } + } + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("batchInvert --> x * x⁻¹ == 1", prop.ForAll( + func(tp testPairElement, r uint8) bool { + + a := make([]Element, r) + if r != 0 { + a[0] = tp.element + + } + one := One() + for i := 1; i < len(a); i++ { + a[i].Add(&a[i-1], &one) + } + + aInv := BatchInvert(a) + + assert.True(len(aInv) == len(a)) + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + if !aInv[i].IsZero() { + return false + } + } else { + if !a[i].Mul(&a[i], &aInv[i]).IsOne() { + return false + } + } + } + return true + }, + genA, ggen.UInt8(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementFromMont(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Assembly implementation must be consistent with generic one", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.fromMont() + _fromMontGeneric(&d) + return c.Equal(&d) + }, + genA, + )) + + properties.Property("x.fromMont().toMont() == x", prop.ForAll( + func(a testPairElement) bool { + c := a.element + c.fromMont().toMont() + return c.Equal(&a.element) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementJSON(t *testing.T) { + assert := require.New(t) + + type S struct { + A Element + B [3]Element + C *Element + D *Element + } + + // encode to JSON + var s S + s.A.SetString("-1") + s.B[2].SetUint64(42) 
+ s.D = new(Element).SetUint64(8000) + + encoded, err := json.Marshal(&s) + assert.NoError(err) + // we may need to adjust "42" and "8000" values for some moduli; see Text() method for more details. + formatValue := func(v int64) string { + var a big.Int + a.SetInt64(v) + a.Mod(&a, Modulus()) + const maxUint16 = 65535 + var aNeg big.Int + aNeg.Neg(&a).Mod(&aNeg, Modulus()) + if aNeg.Uint64() != 0 && aNeg.Uint64() <= maxUint16 { + return "-" + aNeg.Text(10) + } + return a.Text(10) + } + expected := fmt.Sprintf("{\"A\":%s,\"B\":[0,0,%s],\"C\":null,\"D\":%s}", formatValue(-1), formatValue(42), formatValue(8000)) + assert.Equal(expected, string(encoded)) + + // decode valid + var decoded S + err = json.Unmarshal([]byte(expected), &decoded) + assert.NoError(err) + + assert.Equal(s, decoded, "element -> json -> element round trip failed") + + // decode hex and string values + withHexValues := "{\"A\":\"-1\",\"B\":[0,\"0x00000\",\"0x2A\"],\"C\":null,\"D\":\"8000\"}" + + var decodedS S + err = json.Unmarshal([]byte(withHexValues), &decodedS) + assert.NoError(err) + + assert.Equal(s, decodedS, " json with strings -> element failed") + +} + +type testPairElement struct { + element Element + bigint big.Int +} + +func gen() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var g testPairElement + + g.element = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g.element[3] %= (qElement[3] + 1) + } + + for !g.element.smallerThanModulus() { + g.element = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g.element[3] %= (qElement[3] + 1) + } + } + + g.element.BigInt(&g.bigint) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element + + g = 
Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + } + + return g +} + +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + + var carry uint64 + a[0], carry = bits.Add64(a[0], qElement[0], carry) + a[1], carry = bits.Add64(a[1], qElement[1], carry) + a[2], carry = bits.Add64(a[2], qElement[2], carry) + a[3], _ = bits.Add64(a[3], qElement[3], carry) + + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + +func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { + var modulus big.Int + var aIntMod big.Int + modulus.SetInt64(1) + modulus.Lsh(&modulus, (Limbs+1)*64) + aIntMod.Mod(aInt, &modulus) + + slice := append(z[:], aHi) + + return bigIntMatchUint64Slice(&aIntMod, slice) +} + +// TODO: Phase out in favor of property based testing +func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { + + if err := z.matchVeryBigInt(aHi, aInt); err != nil { + t.Error(err) + } +} + +// bigIntMatchUint64Slice is a test helper to match big.Int words against a uint64 slice +func bigIntMatchUint64Slice(aInt *big.Int, a []uint64) error { + + words := aInt.Bits() + + const steps = 64 / bits.UintSize + const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) + for i := 0; i < len(a)*steps; i++ { + + var wI big.Word + + if i < len(words) { + wI = words[i] + } + + aI := 
a[i/steps] >> ((i * bits.UintSize) % 64) + aI &= filter + + if uint64(wI) != aI { + return fmt.Errorf("bignum mismatch: disagreement on word %d: %x ≠ %x; %d ≠ %d", i, uint64(wI), aI, uint64(wI), aI) + } + } + + return nil +} + +func TestElementInversionApproximation(t *testing.T) { + var x Element + for i := 0; i < 1000; i++ { + x.SetRandom() + + // Normally small elements are unlikely. Here we give them a higher chance + xZeros := mrand.Int() % Limbs //#nosec G404 weak rng is fine here + for j := 1; j < xZeros; j++ { + x[Limbs-j] = 0 + } + + a := approximate(&x, x.BitLen()) + aRef := approximateRef(&x) + + if a != aRef { + t.Error("Approximation mismatch") + } + } +} + +func TestElementInversionCorrectionFactorFormula(t *testing.T) { + const kLimbs = k * Limbs + const power = kLimbs*6 + invIterationsN*(kLimbs-k+1) + factorInt := big.NewInt(1) + factorInt.Lsh(factorInt, power) + factorInt.Mod(factorInt, Modulus()) + + var refFactorInt big.Int + inversionCorrectionFactor := Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + } + inversionCorrectionFactor.toBigInt(&refFactorInt) + + if refFactorInt.Cmp(factorInt) != 0 { + t.Error("mismatch") + } +} + +func TestElementLinearComb(t *testing.T) { + var x Element + var y Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + y.SetRandom() + testLinearComb(t, &x, mrand.Int63(), &y, mrand.Int63()) //#nosec G404 weak rng is fine here + } +} + +// Probably unnecessary post-dev. In case the output of inv is wrong, this checks whether it's only off by a constant factor. 
+func TestElementInversionCorrectionFactor(t *testing.T) { + + // (1/x)/inv(x) = (1/1)/inv(1) ⇔ inv(1) = x inv(x) + + var one Element + var oneInv Element + one.SetOne() + oneInv.Inverse(&one) + + for i := 0; i < 100; i++ { + var x Element + var xInv Element + x.SetRandom() + xInv.Inverse(&x) + + x.Mul(&x, &xInv) + if !x.Equal(&oneInv) { + t.Error("Correction factor is inconsistent") + } + } + + if !oneInv.Equal(&one) { + var i big.Int + oneInv.BigInt(&i) // no montgomery + i.ModInverse(&i, Modulus()) + var fac Element + fac.setBigInt(&i) // back to montgomery + + var facTimesFac Element + facTimesFac.Mul(&fac, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + }) + + t.Error("Correction factor is consistently off by", fac, "Should be", facTimesFac) + } +} + +func TestElementBigNumNeg(t *testing.T) { + var a Element + aHi := negL(&a, 0) + if !a.IsZero() || aHi != 0 { + t.Error("-0 != 0") + } +} + +func TestElementBigNumWMul(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + w := mrand.Int63() //#nosec G404 weak rng is fine here + testBigNumWMul(t, &x, w) + } +} + +func TestElementVeryBigIntConversion(t *testing.T) { + xHi := mrand.Uint64() //#nosec G404 weak rng is fine here + var x Element + x.SetRandom() + var xInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + x.assertMatchVeryBigInt(t, xHi, &xInt) +} + +type veryBigInt struct { + asInt big.Int + low Element + hi uint64 +} + +// genVeryBigIntSigned if sign == 0, no sign is forced +func genVeryBigIntSigned(sign int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var g veryBigInt + + g.low = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + + g.hi = genParams.NextUint64() + + if sign < 0 { + g.hi |= signBitSelector + } else if sign > 0 { + g.hi &= ^signBitSelector + } + + 
g.low.toVeryBigIntSigned(&g.asInt, g.hi) + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func TestElementMontReduce(t *testing.T) { + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + gen := genVeryBigIntSigned(0) + + properties.Property("Montgomery reduction is correct", prop.ForAll( + func(g veryBigInt) bool { + var res Element + var resInt big.Int + + montReduce(&resInt, &g.asInt) + res.montReduceSigned(&g.low, g.hi) + + return res.matchVeryBigInt(0, &resInt) == nil + }, + gen, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementMontReduceMultipleOfR(t *testing.T) { + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + gen := ggen.UInt64() + + properties.Property("Montgomery reduction is correct", prop.ForAll( + func(hi uint64) bool { + var zero, res Element + var asInt, resInt big.Int + + zero.toVeryBigIntSigned(&asInt, hi) + + montReduce(&resInt, &asInt) + res.montReduceSigned(&zero, hi) + + return res.matchVeryBigInt(0, &resInt) == nil + }, + gen, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElement0Inverse(t *testing.T) { + var x Element + x.Inverse(&x) + if !x.IsZero() { + t.Fail() + } +} + +// TODO: Tests like this (update factor related) are common to all fields. 
Move them to somewhere non-autogen +func TestUpdateFactorSubtraction(t *testing.T) { + for i := 0; i < 1000; i++ { + + f0, g0 := randomizeUpdateFactors() + f1, g1 := randomizeUpdateFactors() + + for f0-f1 > 1<<31 || f0-f1 <= -1<<31 { + f1 /= 2 + } + + for g0-g1 > 1<<31 || g0-g1 <= -1<<31 { + g1 /= 2 + } + + c0 := updateFactorsCompose(f0, g0) + c1 := updateFactorsCompose(f1, g1) + + cRes := c0 - c1 + fRes, gRes := updateFactorsDecompose(cRes) + + if fRes != f0-f1 || gRes != g0-g1 { + t.Error(i) + } + } +} + +func TestUpdateFactorsDouble(t *testing.T) { + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f > 1<<30 || f < (-1<<31+1)/2 { + f /= 2 + if g <= 1<<29 && g >= (-1<<31+1)/4 { + g *= 2 //g was kept small on f's account. Now that we're halving f, we can double g + } + } + + if g > 1<<30 || g < (-1<<31+1)/2 { + g /= 2 + + if f <= 1<<29 && f >= (-1<<31+1)/4 { + f *= 2 //f was kept small on g's account. Now that we're halving g, we can double f + } + } + + c := updateFactorsCompose(f, g) + cD := c * 2 + fD, gD := updateFactorsDecompose(cD) + + if fD != 2*f || gD != 2*g { + t.Error(i) + } + } +} + +func TestUpdateFactorsNeg(t *testing.T) { + var fMistake bool + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f == 0x80000000 || g == 0x80000000 { + // Update factors this large can only have been obtained after 31 iterations and will therefore never be negated + // We don't have capacity to store -2³¹ + // Repeat this iteration + i-- + continue + } + + c := updateFactorsCompose(f, g) + nc := -c + nf, ng := updateFactorsDecompose(nc) + fMistake = fMistake || nf != -f + if nf != -f || ng != -g { + t.Errorf("Mismatch iteration #%d:\n%d, %d ->\n %d -> %d ->\n %d, %d\n Inputs in hex: %X, %X", + i, f, g, c, nc, nf, ng, f, g) + } + } + if fMistake { + t.Error("Mistake with f detected") + } else { + t.Log("All good with f") + } +} + +func TestUpdateFactorsNeg0(t *testing.T) { + c := updateFactorsCompose(0, 0) + t.Logf("c(0,0) = %X", c) 
+ cn := -c + + if c != cn { + t.Error("Negation of zero update factors should yield the same result.") + } +} + +func TestUpdateFactorDecomposition(t *testing.T) { + var negSeen bool + + for i := 0; i < 1000; i++ { + + f, g := randomizeUpdateFactors() + + if f <= -(1<<31) || f > 1<<31 { + t.Fatal("f out of range") + } + + negSeen = negSeen || f < 0 + + c := updateFactorsCompose(f, g) + + fBack, gBack := updateFactorsDecompose(c) + + if f != fBack || g != gBack { + t.Errorf("(%d, %d) -> %d -> (%d, %d)\n", f, g, c, fBack, gBack) + } + } + + if !negSeen { + t.Fatal("No negative f factors") + } +} + +func TestUpdateFactorInitialValues(t *testing.T) { + + f0, g0 := updateFactorsDecompose(updateFactorIdentityMatrixRow0) + f1, g1 := updateFactorsDecompose(updateFactorIdentityMatrixRow1) + + if f0 != 1 || g0 != 0 || f1 != 0 || g1 != 1 { + t.Error("Update factor initial value constants are incorrect") + } +} + +func TestUpdateFactorsRandomization(t *testing.T) { + var maxLen int + + //t.Log("|f| + |g| is not to exceed", 1 << 31) + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + lf, lg := abs64T32(f), abs64T32(g) + absSum := lf + lg + if absSum >= 1<<31 { + + if absSum == 1<<31 { + maxLen++ + } else { + t.Error(i, "Sum of absolute values too large, f =", f, ",g =", g, ",|f| + |g| =", absSum) + } + } + } + + if maxLen == 0 { + t.Error("max len not observed") + } else { + t.Log(maxLen, "maxLens observed") + } +} + +func randomizeUpdateFactor(absLimit uint32) int64 { + const maxSizeLikelihood = 10 + maxSize := mrand.Intn(maxSizeLikelihood) //#nosec G404 weak rng is fine here + + absLimit64 := int64(absLimit) + var f int64 + switch maxSize { + case 0: + f = absLimit64 + case 1: + f = -absLimit64 + default: + f = int64(mrand.Uint64()%(2*uint64(absLimit64)+1)) - absLimit64 //#nosec G404 weak rng is fine here + } + + if f > 1<<31 { + return 1 << 31 + } else if f < -1<<31+1 { + return -1<<31 + 1 + } + + return f +} + +func abs64T32(f int64) uint32 { + if f >= 1<<32 
|| f < -1<<32 { + panic("f out of range") + } + + if f < 0 { + return uint32(-f) + } + return uint32(f) +} + +func randomizeUpdateFactors() (int64, int64) { + var f [2]int64 + b := mrand.Int() % 2 //#nosec G404 weak rng is fine here + + f[b] = randomizeUpdateFactor(1 << 31) + + //As per the paper, |f| + |g| \le 2³¹. + f[1-b] = randomizeUpdateFactor(1<<31 - abs64T32(f[b])) + + //Patching another edge case + if f[0]+f[1] == -1<<31 { + b = mrand.Int() % 2 //#nosec G404 weak rng is fine here + f[b]++ + } + + return f[0], f[1] +} + +func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { + + var p1 big.Int + x.toBigInt(&p1) + p1.Mul(&p1, big.NewInt(xC)) + + var p2 big.Int + y.toBigInt(&p2) + p2.Mul(&p2, big.NewInt(yC)) + + p1.Add(&p1, &p2) + p1.Mod(&p1, Modulus()) + montReduce(&p1, &p1) + + var z Element + z.linearComb(x, xC, y, yC) + z.assertMatchVeryBigInt(t, 0, &p1) +} + +func testBigNumWMul(t *testing.T, a *Element, c int64) { + var aHi uint64 + var aTimes Element + aHi = aTimes.mulWNonModular(a, c) + + assertMulProduct(t, a, c, &aTimes, aHi) +} + +func updateFactorsCompose(f int64, g int64) int64 { + return f + g<<32 +} + +var rInv big.Int + +func montReduce(res *big.Int, x *big.Int) { + if rInv.BitLen() == 0 { // initialization + rInv.SetUint64(1) + rInv.Lsh(&rInv, Limbs*64) + rInv.ModInverse(&rInv, Modulus()) + } + res.Mul(x, &rInv) + res.Mod(res, Modulus()) +} + +func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { + z.toBigInt(i) + var upperWord big.Int + upperWord.SetUint64(xHi) + upperWord.Lsh(&upperWord, Limbs*64) + i.Add(&upperWord, i) +} + +func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { + z.toVeryBigIntUnsigned(i, xHi) + if signBitSelector&xHi != 0 { + twosCompModulus := big.NewInt(1) + twosCompModulus.Lsh(twosCompModulus, (Limbs+1)*64) + i.Sub(i, twosCompModulus) + } +} + +func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { + var xInt big.Int + 
x.toBigInt(&xInt) + + xInt.Mul(&xInt, big.NewInt(c)) + + result.assertMatchVeryBigInt(t, resultHi, &xInt) + return xInt +} + +func approximateRef(x *Element) uint64 { + + var asInt big.Int + x.toBigInt(&asInt) + n := x.BitLen() + + if n <= 64 { + return asInt.Uint64() + } + + modulus := big.NewInt(1 << 31) + var lo big.Int + lo.Mod(&asInt, modulus) + + modulus.Lsh(modulus, uint(n-64)) + var hi big.Int + hi.Div(&asInt, modulus) + hi.Lsh(&hi, 31) + + hi.Add(&hi, &lo) + return hi.Uint64() +} diff --git a/ecc/grumpkin/fp/hash_to_field/doc.go b/ecc/grumpkin/fp/hash_to_field/doc.go new file mode 100644 index 0000000000..d202519123 --- /dev/null +++ b/ecc/grumpkin/fp/hash_to_field/doc.go @@ -0,0 +1,21 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// Package htf provides hasher based on RFC 9380 Section 5. +// +// The [RFC 9380] defines a method for hashing bytes to elliptic curves. Section +// 5 of the RFC describes a method for uniformly hashing bytes into a field +// using a domain separation. The hashing is implemented in [fp], but this +// package provides a wrapper for the method which implements [hash.Hash] for +// using the method recursively. +// +// [RFC 9380]: https://datatracker.ietf.org/doc/html/rfc9380 +package hash_to_field + +import ( + _ "hash" + + _ "github.com/consensys/gnark-crypto/ecc/grumpkin/fp" +) diff --git a/ecc/grumpkin/fp/hash_to_field/hash_to_field.go b/ecc/grumpkin/fp/hash_to_field/hash_to_field.go new file mode 100644 index 0000000000..64d14791d9 --- /dev/null +++ b/ecc/grumpkin/fp/hash_to_field/hash_to_field.go @@ -0,0 +1,55 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package hash_to_field + +import ( + "fmt" + "hash" + + "github.com/consensys/gnark-crypto/ecc/grumpkin/fp" +) + +type wrappedHashToField struct { + domain []byte + toHash []byte +} + +// New returns a new hasher instance which uses [fp.Hash] to hash all the +// written bytes to a field element, returning the byte representation of the +// field element. The domain separator is passed as-is to hashing method. +func New(domainSeparator []byte) hash.Hash { + return &wrappedHashToField{ + domain: append([]byte{}, domainSeparator...), // copy in case the argument is modified + } +} + +func (w *wrappedHashToField) Write(p []byte) (n int, err error) { + w.toHash = append(w.toHash, p...) + return len(p), nil +} + +func (w *wrappedHashToField) Sum(b []byte) []byte { + res, err := fp.Hash(w.toHash, w.domain, 1) + if err != nil { + // we want to follow the interface, cannot return error and have to panic + // but by default the method shouldn't return an error internally + panic(fmt.Sprintf("native field to hash: %v", err)) + } + bts := res[0].Bytes() + return append(b, bts[:]...) +} + +func (w *wrappedHashToField) Reset() { + w.toHash = nil +} + +func (w *wrappedHashToField) Size() int { + return fp.Bytes +} + +func (w *wrappedHashToField) BlockSize() int { + return fp.Bytes +} diff --git a/ecc/grumpkin/fp/hash_to_field/hash_to_field_test.go b/ecc/grumpkin/fp/hash_to_field/hash_to_field_test.go new file mode 100644 index 0000000000..0f6263d06b --- /dev/null +++ b/ecc/grumpkin/fp/hash_to_field/hash_to_field_test.go @@ -0,0 +1,30 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package hash_to_field + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc/grumpkin/fp" +) + +func TestHashInterface(t *testing.T) { + msg := []byte("test") + sep := []byte("separator") + res, err := fp.Hash(msg, sep, 1) + if err != nil { + t.Fatal("hash to field", err) + } + + htfFn := New(sep) + htfFn.Write(msg) + bts := htfFn.Sum(nil) + var res2 fp.Element + res2.SetBytes(bts[:fp.Bytes]) + if !res[0].Equal(&res2) { + t.Error("not equal") + } +} diff --git a/ecc/grumpkin/fp/vector.go b/ecc/grumpkin/fp/vector.go new file mode 100644 index 0000000000..5dfc17ebbc --- /dev/null +++ b/ecc/grumpkin/fp/vector.go @@ -0,0 +1,295 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "runtime" + "strings" + "sync" + "sync/atomic" + "unsafe" +) + +// Vector represents a slice of Element. +// +// It implements the following interfaces: +// - Stringer +// - io.WriterTo +// - io.ReaderFrom +// - encoding.BinaryMarshaler +// - encoding.BinaryUnmarshaler +// - sort.Interface +type Vector []Element + +// MarshalBinary implements encoding.BinaryMarshaler +func (vector *Vector) MarshalBinary() (data []byte, err error) { + var buf bytes.Buffer + + if _, err = vector.WriteTo(&buf); err != nil { + return + } + return buf.Bytes(), nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (vector *Vector) UnmarshalBinary(data []byte) error { + r := bytes.NewReader(data) + _, err := vector.ReadFrom(r) + return err +} + +// WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. +// Length of the vector is encoded as a uint32 on the first 4 bytes. 
+func (vector *Vector) WriteTo(w io.Writer) (int64, error) { + // encode slice length + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { + return 0, err + } + + n := int64(4) + + var buf [Bytes]byte + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) + m, err := w.Write(buf[:]) + n += int64(m) + if err != nil { + return n, err + } + } + return n, nil +} + +// AsyncReadFrom reads a vector of big endian encoded Element. +// Length of the vector must be encoded as a uint32 on the first 4 bytes. +// It consumes the needed bytes from the reader and returns the number of bytes read and an error if any. +// It also returns a channel that will be closed when the validation is done. +// The validation consist of checking that the elements are smaller than the modulus, and +// converting them to montgomery form. +func (vector *Vector) AsyncReadFrom(r io.Reader) (int64, error, chan error) { + chErr := make(chan error, 1) + var buf [Bytes]byte + if read, err := io.ReadFull(r, buf[:4]); err != nil { + close(chErr) + return int64(read), err, chErr + } + sliceLen := binary.BigEndian.Uint32(buf[:4]) + + n := int64(4) + (*vector) = make(Vector, sliceLen) + if sliceLen == 0 { + close(chErr) + return n, nil, chErr + } + + bSlice := unsafe.Slice((*byte)(unsafe.Pointer(&(*vector)[0])), sliceLen*Bytes) + read, err := io.ReadFull(r, bSlice) + n += int64(read) + if err != nil { + close(chErr) + return n, err, chErr + } + + go func() { + var cptErrors uint64 + // process the elements in parallel + execute(int(sliceLen), func(start, end int) { + + var z Element + for i := start; i < end; i++ { + // we have to set vector[i] + bstart := i * Bytes + bend := bstart + Bytes + b := bSlice[bstart:bend] + z[0] = binary.BigEndian.Uint64(b[24:32]) + z[1] = binary.BigEndian.Uint64(b[16:24]) + z[2] = binary.BigEndian.Uint64(b[8:16]) + z[3] = binary.BigEndian.Uint64(b[0:8]) + + if !z.smallerThanModulus() { + atomic.AddUint64(&cptErrors, 1) + 
return + } + z.toMont() + (*vector)[i] = z + } + }) + + if cptErrors > 0 { + chErr <- fmt.Errorf("async read: %d elements failed validation", cptErrors) + } + close(chErr) + }() + return n, nil, chErr +} + +// ReadFrom implements io.ReaderFrom and reads a vector of big endian encoded Element. +// Length of the vector must be encoded as a uint32 on the first 4 bytes. +func (vector *Vector) ReadFrom(r io.Reader) (int64, error) { + + var buf [Bytes]byte + if read, err := io.ReadFull(r, buf[:4]); err != nil { + return int64(read), err + } + sliceLen := binary.BigEndian.Uint32(buf[:4]) + + n := int64(4) + (*vector) = make(Vector, sliceLen) + + for i := 0; i < int(sliceLen); i++ { + read, err := io.ReadFull(r, buf[:]) + n += int64(read) + if err != nil { + return n, err + } + (*vector)[i], err = BigEndian.Element(&buf) + if err != nil { + return n, err + } + } + + return n, nil +} + +// String implements fmt.Stringer interface +func (vector Vector) String() string { + var sbb strings.Builder + sbb.WriteByte('[') + for i := 0; i < len(vector); i++ { + sbb.WriteString(vector[i].String()) + if i != len(vector)-1 { + sbb.WriteByte(',') + } + } + sbb.WriteByte(']') + return sbb.String() +} + +// Len is the number of elements in the collection. +func (vector Vector) Len() int { + return len(vector) +} + +// Less reports whether the element with +// index i should sort before the element with index j. +func (vector Vector) Less(i, j int) bool { + return vector[i].Cmp(&vector[j]) == -1 +} + +// Swap swaps the elements with indexes i and j. 
+func (vector Vector) Swap(i, j int) { + vector[i], vector[j] = vector[j], vector[i] +} + +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + +// TODO @gbotrel make a public package out of that. +// execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go
// as we don't want to generate code importing internal/
func execute(nbIterations int, work func(int, int), maxCpus ...int) {

	nbTasks := runtime.NumCPU()
	if len(maxCpus) == 1 {
		// caller-supplied parallelism, clamped to [1, 512]
		nbTasks = maxCpus[0]
		if nbTasks < 1 {
			nbTasks = 1
		} else if nbTasks > 512 {
			nbTasks = 512
		}
	}

	if nbTasks == 1 {
		// no go routines
		work(0, nbIterations)
		return
	}

	nbIterationsPerCpus := nbIterations / nbTasks

	// more CPUs than tasks: a CPU will work on exactly one iteration
	if nbIterationsPerCpus < 1 {
		nbIterationsPerCpus = 1
		nbTasks = nbIterations
	}

	var wg sync.WaitGroup

	// the first extraTasks workers take one extra iteration each, so the
	// [0, nbIterations) range is covered exactly once with no remainder
	extraTasks := nbIterations - (nbTasks * nbIterationsPerCpus)
	extraTasksOffset := 0

	for i := 0; i < nbTasks; i++ {
		wg.Add(1)
		// _start/_end are fresh per iteration, so capturing them in the
		// goroutine closure below is safe
		_start := i*nbIterationsPerCpus + extraTasksOffset
		_end := _start + nbIterationsPerCpus
		if extraTasks > 0 {
			_end++
			extraTasks--
			extraTasksOffset++
		}
		go func() {
			work(_start, _end)
			wg.Done()
		}()
	}

	wg.Wait()
}
diff --git a/ecc/grumpkin/fp/vector_amd64.go b/ecc/grumpkin/fp/vector_amd64.go
new file mode 100644
index 0000000000..4faa5cac55
--- /dev/null
+++ b/ecc/grumpkin/fp/vector_amd64.go
@@ -0,0 +1,153 @@
//go:build !purego

// Copyright 2020-2025 Consensys Software Inc.
// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.

// Code generated by consensys/gnark-crypto DO NOT EDIT

package fp

import (
	_ "github.com/consensys/gnark-crypto/field/asm/element_4w"
)

// Add adds two vectors element-wise and stores the result in self.
// It panics if the vectors don't have the same length.
+func (vector *Vector) Add(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + n := uint64(len(a)) + addVec(&(*vector)[0], &a[0], &b[0], n) +} + +//go:noescape +func addVec(res, a, b *Element, n uint64) + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + subVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func subVec(res, a, b *Element, n uint64) + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + if len(a) != len(*vector) { + panic("vector.ScalarMul: vectors don't have the same length") + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || uint64(len(a)) >= maxN { + // call scalarMulVecGeneric + scalarMulVecGeneric(*vector, a, b) + return + } + n := uint64(len(a)) + if n == 0 { + return + } + // the code for scalarMul is identical to mulVec; and it expects at least + // 2 elements in the vector to fill the Z registers + var bb [2]Element + bb[0] = *b + bb[1] = *b + const blockSize = 16 + scalarMulVec(&(*vector)[0], &a[0], &bb[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call scalarMulVecGeneric on the rest + start := n - n%blockSize + scalarMulVecGeneric((*vector)[start:], a[start:], b) + } +} + +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64, qInvNeg uint64) + +// Sum computes the sum of all elements in the vector. 
+func (vector *Vector) Sum() (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) - 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return +} + +//go:noescape +func sumVec(res *Element, a *Element, n uint64) + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + if n != uint64(len(other)) { + panic("vector.InnerProduct: vectors don't have the same length") + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call innerProductVecGeneric + // note; we could split the vector into smaller chunks and call innerProductVec + innerProductVecGeneric(&res, *vector, other) + return + } + innerProdVec(&res[0], &(*vector)[0], &other[0], uint64(len(*vector))) + + return +} + +//go:noescape +func innerProdVec(res *uint64, a, b *Element, n uint64) + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
func (vector *Vector) Mul(a, b Vector) {
	if len(a) != len(b) || len(a) != len(*vector) {
		panic("vector.Mul: vectors don't have the same length")
	}
	n := uint64(len(a))
	if n == 0 {
		return
	}
	// the assembly takes the length as a uint32-sized count
	const maxN = (1 << 32) - 1
	if !supportAvx512 || n >= maxN {
		// call mulVecGeneric
		mulVecGeneric(*vector, a, b)
		return
	}

	// AVX512 kernel processes full 16-element blocks only
	const blockSize = 16
	mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize, qInvNeg)
	if n%blockSize != 0 {
		// call mulVecGeneric on the rest
		start := n - n%blockSize
		mulVecGeneric((*vector)[start:], a[start:], b[start:])
	}

}

// Patterns use for transposing the vectors in mulVec
var (
	pattern1 = [8]uint64{0, 8, 1, 9, 2, 10, 3, 11}
	pattern2 = [8]uint64{12, 4, 13, 5, 14, 6, 15, 7}
	pattern3 = [8]uint64{0, 1, 8, 9, 2, 3, 10, 11}
	pattern4 = [8]uint64{12, 13, 4, 5, 14, 15, 6, 7}
)

//go:noescape
func mulVec(res, a, b *Element, n uint64, qInvNeg uint64)
diff --git a/ecc/grumpkin/fp/vector_purego.go b/ecc/grumpkin/fp/vector_purego.go
new file mode 100644
index 0000000000..98343052a2
--- /dev/null
+++ b/ecc/grumpkin/fp/vector_purego.go
@@ -0,0 +1,45 @@
//go:build purego || !amd64

// Copyright 2020-2025 Consensys Software Inc.
// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.

// Code generated by consensys/gnark-crypto DO NOT EDIT

package fp

// Pure-Go fallbacks: each method simply delegates to the generic helper in vector.go.

// Add adds two vectors element-wise and stores the result in self.
// It panics if the vectors don't have the same length.
func (vector *Vector) Add(a, b Vector) {
	addVecGeneric(*vector, a, b)
}

// Sub subtracts two vectors element-wise and stores the result in self.
// It panics if the vectors don't have the same length.
func (vector *Vector) Sub(a, b Vector) {
	subVecGeneric(*vector, a, b)
}

// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self.
// It panics if the vectors don't have the same length.
func (vector *Vector) ScalarMul(a Vector, b *Element) {
	scalarMulVecGeneric(*vector, a, b)
}

// Sum computes the sum of all elements in the vector.
func (vector *Vector) Sum() (res Element) {
	sumVecGeneric(&res, *vector)
	return
}

// InnerProduct computes the inner product of two vectors.
// It panics if the vectors don't have the same length.
func (vector *Vector) InnerProduct(other Vector) (res Element) {
	innerProductVecGeneric(&res, *vector, other)
	return
}

// Mul multiplies two vectors element-wise and stores the result in self.
// It panics if the vectors don't have the same length.
func (vector *Vector) Mul(a, b Vector) {
	mulVecGeneric(*vector, a, b)
}
diff --git a/ecc/grumpkin/fp/vector_test.go b/ecc/grumpkin/fp/vector_test.go
new file mode 100644
index 0000000000..5f78812dc6
--- /dev/null
+++ b/ecc/grumpkin/fp/vector_test.go
@@ -0,0 +1,360 @@
// Copyright 2020-2025 Consensys Software Inc.
// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.

// Code generated by consensys/gnark-crypto DO NOT EDIT

package fp

import (
	"bytes"
	"fmt"
	"github.com/stretchr/testify/require"
	"os"
	"reflect"
	"sort"
	"testing"

	"github.com/leanovate/gopter"
	"github.com/leanovate/gopter/prop"
)

func TestVectorSort(t *testing.T) {
	assert := require.New(t)

	v := make(Vector, 3)
	v[0].SetUint64(2)
	v[1].SetUint64(3)
	v[2].SetUint64(1)

	sort.Sort(v)

	assert.Equal("[1,2,3]", v.String())
}

func TestVectorRoundTrip(t *testing.T) {
	assert := require.New(t)

	v1 := make(Vector, 3)
	v1[0].SetUint64(2)
	v1[1].SetUint64(3)
	v1[2].SetUint64(1)

	b, err := v1.MarshalBinary()
	assert.NoError(err)

	var v2, v3 Vector

	err = v2.UnmarshalBinary(b)
	assert.NoError(err)

	err = v3.unmarshalBinaryAsync(b)
	assert.NoError(err)

	assert.True(reflect.DeepEqual(v1, v2))
	assert.True(reflect.DeepEqual(v3, v2))
}

func TestVectorEmptyRoundTrip(t *testing.T) {
	assert := require.New(t)

	v1 := make(Vector, 0)

	b, err := v1.MarshalBinary()
	assert.NoError(err)

	var v2, v3 Vector

	err = v2.UnmarshalBinary(b)
	assert.NoError(err)

	err = v3.unmarshalBinaryAsync(b)
	assert.NoError(err)

	assert.True(reflect.DeepEqual(v1, v2))
	assert.True(reflect.DeepEqual(v3, v2))
}

// unmarshalBinaryAsync is a test helper exercising the AsyncReadFrom path:
// it waits on the validation channel before reporting success.
func (vector *Vector) unmarshalBinaryAsync(data []byte) error {
	r := bytes.NewReader(data)
	_, err, chErr := vector.AsyncReadFrom(r)
	if err != nil {
		return err
	}
	return <-chErr
}

func TestVectorOps(t *testing.T) {
	parameters := gopter.DefaultTestParameters()
	if testing.Short() {
		parameters.MinSuccessfulTests = 2
	} else {
		parameters.MinSuccessfulTests = 10
	}
	properties := gopter.NewProperties(parameters)

	// each property checks a vector op against the element-by-element equivalent
	addVector := func(a, b Vector) bool {
		c := make(Vector, len(a))
		c.Add(a, b)

		for i := 0; i < len(a); i++ {
			var tmp Element
			tmp.Add(&a[i], &b[i])
			if !tmp.Equal(&c[i]) {
				return false
			}
		}
		return true
	}

	subVector := func(a, b Vector) bool {
		c := make(Vector, len(a))
		c.Sub(a, b)

		for i := 0; i < len(a); i++ {
			var tmp Element
			tmp.Sub(&a[i], &b[i])
			if !tmp.Equal(&c[i]) {
				return false
			}
		}
		return true
	}

	scalarMulVector := func(a Vector, b Element) bool {
		c := make(Vector, len(a))
		c.ScalarMul(a, &b)

		for i := 0; i < len(a); i++ {
			var tmp Element
			tmp.Mul(&a[i], &b)
			if !tmp.Equal(&c[i]) {
				return false
			}
		}
		return true
	}

	sumVector := func(a Vector) bool {
		var sum Element
		computed := a.Sum()
		for i := 0; i < len(a); i++ {
			sum.Add(&sum, &a[i])
		}

		return sum.Equal(&computed)
	}

	innerProductVector := func(a, b Vector) bool {
		computed := a.InnerProduct(b)
		var innerProduct Element
		for i := 0; i < len(a); i++ {
			var tmp Element
			tmp.Mul(&a[i], &b[i])
			innerProduct.Add(&innerProduct, &tmp)
		}

		return innerProduct.Equal(&computed)
	}

	mulVector := func(a, b Vector) bool {
		c := make(Vector, len(a))
		// deliberately pin the first operands so at least one non-generator
		// value is always exercised
		a[0].SetUint64(0x24)
		b[0].SetUint64(0x42)
		c.Mul(a, b)

		for i := 0; i < len(a); i++ {
			var tmp Element
			tmp.Mul(&a[i], &b[i])
			if !tmp.Equal(&c[i]) {
				return false
			}
		}
		return true
	}

	// sizes straddle the AVX512 blockSize (16) and larger power-of-two boundaries
	sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514}
	type genPair struct {
		g1, g2 gopter.Gen
		label  string
	}

	for _, size := range sizes {
		generators := []genPair{
			{genZeroVector(size), genZeroVector(size), "zero vectors"},
			{genMaxVector(size), genMaxVector(size), "max vectors"},
			{genVector(size), genVector(size), "random vectors"},
			{genVector(size), genZeroVector(size), "random and zero vectors"},
		}
		for _, gp := range generators {
			properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll(
				addVector,
				gp.g1,
				gp.g2,
			))

			properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll(
				subVector,
				gp.g1,
				gp.g2,
			))

			properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll(
				scalarMulVector,
				gp.g1,
				genElement(),
			))

			properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll(
				sumVector,
				gp.g1,
			))

			properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll(
				innerProductVector,
				gp.g1,
				gp.g2,
			))

			properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll(
				mulVector,
				gp.g1,
				gp.g2,
			))
		}
	}

	properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout))
}

func BenchmarkVectorOps(b *testing.B) {
	// note; to benchmark against "no asm" version, use the following
	// build tag: -tags purego
	const N = 1 << 24
	a1 := make(Vector, N)
	b1 := make(Vector, N)
	c1 := make(Vector, N)
	var mixer Element
	mixer.SetRandom()
	for i := 1; i < N; i++ {
		a1[i-1].SetUint64(uint64(i)).
			Mul(&a1[i-1], &mixer)
		b1[i-1].SetUint64(^uint64(i)).
			Mul(&b1[i-1], &mixer)
	}

	for n := 1 << 4; n <= N; n <<= 1 {
		b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) {
			_a := a1[:n]
			_b := b1[:n]
			_c := c1[:n]
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				_c.Add(_a, _b)
			}
		})

		b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) {
			_a := a1[:n]
			_b := b1[:n]
			_c := c1[:n]
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				_c.Sub(_a, _b)
			}
		})

		b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) {
			_a := a1[:n]
			_c := c1[:n]
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				_c.ScalarMul(_a, &mixer)
			}
		})

		b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) {
			_a := a1[:n]
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				_ = _a.Sum()
			}
		})

		b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) {
			_a := a1[:n]
			_b := b1[:n]
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				_ = _a.InnerProduct(_b)
			}
		})

		b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) {
			_a := a1[:n]
			_b := b1[:n]
			_c := c1[:n]
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				_c.Mul(_a, _b)
			}
		})
	}
}

// genZeroVector generates all-zero vectors of the given size.
func genZeroVector(size int) gopter.Gen {
	return func(genParams *gopter.GenParameters) *gopter.GenResult {
		g := make(Vector, size)
		genResult := gopter.NewGenResult(g, gopter.NoShrinker)
		return genResult
	}
}

// genMaxVector generates vectors whose every element is q-1.
func genMaxVector(size int) gopter.Gen {
	return func(genParams *gopter.GenParameters) *gopter.GenResult {
		g := make(Vector, size)

		// qElement is an array value, so this copy leaves the modulus untouched
		qMinusOne := qElement
		qMinusOne[0]--

		for i := 0; i < size; i++ {
			g[i] = qMinusOne
		}
		genResult := gopter.NewGenResult(g, gopter.NoShrinker)
		return genResult
	}
}

// genVector generates pseudo-random vectors: a rejection-sampled mixer < q
// multiplied by the element index.
func genVector(size int) gopter.Gen {
	return func(genParams *gopter.GenParameters) *gopter.GenResult {
		g := make(Vector, size)
		mixer := Element{
			genParams.NextUint64(),
			genParams.NextUint64(),
			genParams.NextUint64(),
			genParams.NextUint64(),
		}
		if qElement[3] != ^uint64(0) {
			mixer[3] %= (qElement[3] + 1)
		}

		for !mixer.smallerThanModulus() {
			mixer = Element{
				genParams.NextUint64(),
				genParams.NextUint64(),
				genParams.NextUint64(),
				genParams.NextUint64(),
			}
			if qElement[3] != ^uint64(0) {
				mixer[3] %= (qElement[3] + 1)
			}
		}

		for i := 1; i <= size; i++ {
			g[i-1].SetUint64(uint64(i)).
				Mul(&g[i-1], &mixer)
		}

		genResult := gopter.NewGenResult(g, gopter.NoShrinker)
		return genResult
	}
}
diff --git a/ecc/grumpkin/fr/arith.go b/ecc/grumpkin/fr/arith.go
new file mode 100644
index 0000000000..8dbae56e32
--- /dev/null
+++ b/ecc/grumpkin/fr/arith.go
@@ -0,0 +1,49 @@
// Copyright 2020-2025 Consensys Software Inc.
// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.

// Code generated by consensys/gnark-crypto DO NOT EDIT

package fr

import (
	"math/bits"
)

// madd0 hi = a*b + c (discards lo bits)
// The carry out of lo+c is still propagated into hi.
func madd0(a, b, c uint64) (hi uint64) {
	var carry, lo uint64
	hi, lo = bits.Mul64(a, b)
	_, carry = bits.Add64(lo, c, 0)
	hi, _ = bits.Add64(hi, 0, carry)
	return
}

// madd1 hi, lo = a*b + c
func madd1(a, b, c uint64) (hi uint64, lo uint64) {
	var carry uint64
	hi, lo = bits.Mul64(a, b)
	lo, carry = bits.Add64(lo, c, 0)
	hi, _ = bits.Add64(hi, 0, carry)
	return
}

// madd2 hi, lo = a*b + c + d
func madd2(a, b, c, d uint64) (hi uint64, lo uint64) {
	var carry uint64
	hi, lo = bits.Mul64(a, b)
	c, carry = bits.Add64(c, d, 0)
	hi, _ = bits.Add64(hi, 0, carry)
	lo, carry = bits.Add64(lo, c, 0)
	hi, _ = bits.Add64(hi, 0, carry)
	return
}

// madd3 hi, lo = a*b + c + d, then hi += e
// (e is a carry word folded directly into the high limb).
func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) {
	var carry uint64
	hi, lo = bits.Mul64(a, b)
	c, carry = bits.Add64(c, d, 0)
	hi, _ = bits.Add64(hi, 0, carry)
	lo, carry = bits.Add64(lo, c, 0)
	hi, _ = bits.Add64(hi, e, carry)
	return
}
diff --git a/ecc/grumpkin/fr/asm_adx.go b/ecc/grumpkin/fr/asm_adx.go
new file mode 100644
index 0000000000..52421cc159
--- /dev/null
+++ b/ecc/grumpkin/fr/asm_adx.go
@@ -0,0 +1,15 @@
//go:build !noadx

// Copyright 2020-2025 Consensys Software Inc.
// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.

// Code generated by consensys/gnark-crypto DO NOT EDIT

package fr

import "golang.org/x/sys/cpu"

// supportAdx gates the ADX/BMI2 assembly fast path at runtime.
var (
	supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2
	_          = supportAdx
)
diff --git a/ecc/grumpkin/fr/asm_avx.go b/ecc/grumpkin/fr/asm_avx.go
new file mode 100644
index 0000000000..f89a44926d
--- /dev/null
+++ b/ecc/grumpkin/fr/asm_avx.go
@@ -0,0 +1,15 @@
//go:build !noavx

// Copyright 2020-2025 Consensys Software Inc.
// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.

// Code generated by consensys/gnark-crypto DO NOT EDIT

package fr

import "golang.org/x/sys/cpu"

// supportAvx512 gates the AVX512 vector kernels; it also requires ADX/BMI2.
var (
	supportAvx512 = supportAdx && cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ
	_             = supportAvx512
)
diff --git a/ecc/grumpkin/fr/asm_noadx.go b/ecc/grumpkin/fr/asm_noadx.go
new file mode 100644
index 0000000000..9c9c9dab5d
--- /dev/null
+++ b/ecc/grumpkin/fr/asm_noadx.go
@@ -0,0 +1,16 @@
//go:build noadx

// Copyright 2020-2025 Consensys Software Inc.
// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.

// Code generated by consensys/gnark-crypto DO NOT EDIT

package fr

// note: this is needed for test purposes, as dynamically changing supportAdx doesn't flag
// certain errors (like fatal error: missing stackmap)
// this ensures we test all asm path.
var (
	supportAdx = false
	_          = supportAdx
)
diff --git a/ecc/grumpkin/fr/asm_noavx.go b/ecc/grumpkin/fr/asm_noavx.go
new file mode 100644
index 0000000000..ae86f75e7f
--- /dev/null
+++ b/ecc/grumpkin/fr/asm_noavx.go
@@ -0,0 +1,10 @@
//go:build noavx

// Copyright 2020-2025 Consensys Software Inc.
// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.

// Code generated by consensys/gnark-crypto DO NOT EDIT

package fr

const supportAvx512 = false
diff --git a/ecc/grumpkin/fr/doc.go b/ecc/grumpkin/fr/doc.go
new file mode 100644
index 0000000000..54bd9204ac
--- /dev/null
+++ b/ecc/grumpkin/fr/doc.go
@@ -0,0 +1,46 @@
// Copyright 2020-2025 Consensys Software Inc.
// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.

// Code generated by consensys/gnark-crypto DO NOT EDIT

// Package fr contains field arithmetic operations for modulus = 0x30644e...7cfd47.
//
// The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x).
//
// Additionally fr.Vector offers an API to manipulate []Element using AVX512 instructions if available.
+// +// The modulus is hardcoded in all the operations. +// +// Field elements are represented as an array, and assumed to be in Montgomery form in all methods: +// +// type Element [4]uint64 +// +// # Usage +// +// Example API signature: +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element +// +// and can be used like so: +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) +// +// Modulus q = +// +// q[base10] = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +// q[base16] = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47 +// +// # Warning +// +// There is no security guarantees such as constant time implementation or side-channel attack resistance. +// This code is provided as-is. Partially audited, see https://github.com/Consensys/gnark/tree/master/audits +// for more details. +package fr diff --git a/ecc/grumpkin/fr/element.go b/ecc/grumpkin/fr/element.go new file mode 100644 index 0000000000..7f2e5d9562 --- /dev/null +++ b/ecc/grumpkin/fr/element.go @@ -0,0 +1,1538 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import ( + "crypto/rand" + "encoding/binary" + "errors" + "io" + "math/big" + "math/bits" + "reflect" + "strconv" + "strings" + + "github.com/bits-and-blooms/bitset" + "github.com/consensys/gnark-crypto/field/hash" + "github.com/consensys/gnark-crypto/field/pool" +) + +// Element represents a field element stored on 4 words (uint64) +// +// Element are assumed to be in Montgomery form in all methods. 
//
// Modulus q =
//
//	q[base10] = 21888242871839275222246405745257275088696311157297823662689037894645226208583
//	q[base16] = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47
//
// # Warning
//
// This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance.
// Limbs are little-endian: z[0] is the least significant word.
type Element [4]uint64

const (
	Limbs = 4   // number of 64 bits words needed to represent a Element
	Bits  = 254 // number of bits needed to represent a Element
	Bytes = 32  // number of bytes needed to represent a Element
)

// Field modulus q (little-endian limbs of 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47)
const (
	q0 = 4332616871279656263
	q1 = 10917124144477883021
	q2 = 13281191951274694749
	q3 = 3486998266802970665
)

var qElement = Element{
	q0,
	q1,
	q2,
	q3,
}

var _modulus big.Int // q stored as big.Int

// Modulus returns q as a big.Int
//
//	q[base10] = 21888242871839275222246405745257275088696311157297823662689037894645226208583
//	q[base16] = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47
func Modulus() *big.Int {
	return new(big.Int).Set(&_modulus)
}

// q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r
// used for Montgomery reduction
const qInvNeg = 9786893198990664585

// mu = 2^288 / q needed for partial Barrett reduction
const mu uint64 = 22721021478

func init() {
	// hex form of q; parse failure is impossible for this constant
	_modulus.SetString("30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", 16)
}

// NewElement returns a new Element from a uint64 value
//
// it is equivalent to
//
//	var v Element
//	v.SetUint64(...)
+func NewElement(v uint64) Element { + z := Element{v} + z.Mul(&z, &rSquare) + return z +} + +// SetUint64 sets z to v and returns z +func (z *Element) SetUint64(v uint64) *Element { + // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form + *z = Element{v} + return z.Mul(z, &rSquare) // z.toMont() +} + +// SetInt64 sets z to v and returns z +func (z *Element) SetInt64(v int64) *Element { + + // absolute value of v + m := v >> 63 + z.SetUint64(uint64((v ^ m) - m)) + + if m != 0 { + // v is negative + z.Neg(z) + } + + return z +} + +// Set z = x and returns z +func (z *Element) Set(x *Element) *Element { + z[0] = x[0] + z[1] = x[1] + z[2] = x[2] + z[3] = x[3] + return z +} + +// SetInterface converts provided interface into Element +// returns an error if provided type is not supported +// supported types: +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte +func (z *Element) SetInterface(i1 interface{}) (*Element, error) { + if i1 == nil { + return nil, errors.New("can't set fr.Element with ") + } + + switch c1 := i1.(type) { + case Element: + return z.Set(&c1), nil + case *Element: + if c1 == nil { + return nil, errors.New("can't set fr.Element with ") + } + return z.Set(c1), nil + case uint8: + return z.SetUint64(uint64(c1)), nil + case uint16: + return z.SetUint64(uint64(c1)), nil + case uint32: + return z.SetUint64(uint64(c1)), nil + case uint: + return z.SetUint64(uint64(c1)), nil + case uint64: + return z.SetUint64(c1), nil + case int8: + return z.SetInt64(int64(c1)), nil + case int16: + return z.SetInt64(int64(c1)), nil + case int32: + return z.SetInt64(int64(c1)), nil + case int64: + return z.SetInt64(c1), nil + case int: + return z.SetInt64(int64(c1)), nil + case string: + return z.SetString(c1) + case *big.Int: + if c1 == nil { + return nil, errors.New("can't set fr.Element with ") + } + return z.SetBigInt(c1), nil + case big.Int: + return z.SetBigInt(&c1), nil 
+ case []byte: + return z.SetBytes(c1), nil + default: + return nil, errors.New("can't set fr.Element from type " + reflect.TypeOf(i1).String()) + } +} + +// SetZero z = 0 +func (z *Element) SetZero() *Element { + z[0] = 0 + z[1] = 0 + z[2] = 0 + z[3] = 0 + return z +} + +// SetOne z = 1 (in Montgomery form) +func (z *Element) SetOne() *Element { + z[0] = 15230403791020821917 + z[1] = 754611498739239741 + z[2] = 7381016538464732716 + z[3] = 1011752739694698287 + return z +} + +// Div z = x*y⁻¹ (mod q) +func (z *Element) Div(x, y *Element) *Element { + var yInv Element + yInv.Inverse(y) + z.Mul(x, &yInv) + return z +} + +// Equal returns z == x; constant-time +func (z *Element) Equal(x *Element) bool { + return z.NotEqual(x) == 0 +} + +// NotEqual returns 0 if and only if z == x; constant-time +func (z *Element) NotEqual(x *Element) uint64 { + return (z[3] ^ x[3]) | (z[2] ^ x[2]) | (z[1] ^ x[1]) | (z[0] ^ x[0]) +} + +// IsZero returns z == 0 +func (z *Element) IsZero() bool { + return (z[3] | z[2] | z[1] | z[0]) == 0 +} + +// IsOne returns z == 1 +func (z *Element) IsOne() bool { + return ((z[3] ^ 1011752739694698287) | (z[2] ^ 7381016538464732716) | (z[1] ^ 754611498739239741) | (z[0] ^ 15230403791020821917)) == 0 +} + +// IsUint64 reports whether z can be represented as an uint64. +func (z *Element) IsUint64() bool { + zz := *z + zz.fromMont() + return zz.FitsOnOneWord() +} + +// Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. +func (z *Element) Uint64() uint64 { + return z.Bits()[0] +} + +// FitsOnOneWord reports whether z words (except the least significant word) are 0 +// +// It is the responsibility of the caller to convert from Montgomery to Regular form if needed. 
+func (z *Element) FitsOnOneWord() bool { + return (z[3] | z[2] | z[1]) == 0 +} + +// Cmp compares (lexicographic order) z and x and returns: +// +// -1 if z < x +// 0 if z == x +// +1 if z > x +func (z *Element) Cmp(x *Element) int { + _z := z.Bits() + _x := x.Bits() + if _z[3] > _x[3] { + return 1 + } else if _z[3] < _x[3] { + return -1 + } + if _z[2] > _x[2] { + return 1 + } else if _z[2] < _x[2] { + return -1 + } + if _z[1] > _x[1] { + return 1 + } else if _z[1] < _x[1] { + return -1 + } + if _z[0] > _x[0] { + return 1 + } else if _z[0] < _x[0] { + return -1 + } + return 0 +} + +// LexicographicallyLargest returns true if this element is strictly lexicographically +// larger than its negation, false otherwise +func (z *Element) LexicographicallyLargest() bool { + // adapted from github.com/zkcrypto/bls12_381 + // we check if the element is larger than (q-1) / 2 + // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 + + _z := z.Bits() + + var b uint64 + _, b = bits.Sub64(_z[0], 11389680472494603940, 0) + _, b = bits.Sub64(_z[1], 14681934109093717318, b) + _, b = bits.Sub64(_z[2], 15863968012492123182, b) + _, b = bits.Sub64(_z[3], 1743499133401485332, b) + + return b == 0 +} + +// SetRandom sets z to a uniform random value in [0, q). +// +// This might error only if reading from crypto/rand.Reader errors, +// in which case, value of z is undefined. +func (z *Element) SetRandom() (*Element, error) { + // this code is generated for all modulus + // and derived from go/src/crypto/rand/util.go + + // l is number of limbs * 8; the number of bytes needed to reconstruct 4 uint64 + const l = 32 + + // bitLen is the maximum bit length needed to encode a value < q. + const bitLen = 254 + + // k is the maximum byte length needed to encode a value < q. + const k = (bitLen + 7) / 8 + + // b is the number of bits in the most significant byte of q-1. 
+ b := uint(bitLen % 8) + if b == 0 { + b = 8 + } + + var bytes [l]byte + + for { + // note that bytes[k:l] is always 0 + if _, err := io.ReadFull(rand.Reader, bytes[:k]); err != nil { + return nil, err + } + + // Clear unused bits in in the most significant byte to increase probability + // that the candidate is < q. + bytes[k-1] &= uint8(int(1<> 1 + z[0] = z[0]>>1 | z[1]<<63 + z[1] = z[1]>>1 | z[2]<<63 + z[2] = z[2]>>1 | z[3]<<63 + z[3] >>= 1 + +} + +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// sets and returns z = z * 1 +func (z *Element) fromMont() *Element { + fromMont(z) + return z +} + +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], _ = bits.Add64(x[3], y[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], _ = bits.Add64(x[3], x[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = 
bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], _ = bits.Add64(z[3], q3, c) + } + return z +} + +// Neg z = q - x +func (z *Element) Neg(x *Element) *Element { + if x.IsZero() { + z.SetZero() + return z + } + var borrow uint64 + z[0], borrow = bits.Sub64(q0, x[0], 0) + z[1], borrow = bits.Sub64(q1, x[1], borrow) + z[2], borrow = bits.Sub64(q2, x[2], borrow) + z[3], _ = bits.Sub64(q3, x[3], borrow) + return z +} + +// Select is a constant-time conditional move. +// If c=0, z = x0. Else z = x1 +func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { + cC := uint64((int64(c) | -int64(c)) >> 63) // "canonicized" into: 0 if c=0, -1 otherwise + z[0] = x0[0] ^ cC&(x0[0]^x1[0]) + z[1] = x0[1] ^ cC&(x0[1]^x1[1]) + z[2] = x0[2] ^ cC&(x0[2]^x1[2]) + z[3] = x0[3] ^ cC&(x0[3]^x1[3]) + return z +} + +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. +func _mulGeneric(z, x, y *Element) { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t [5]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + 
+ // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + + if t[4] != 0 { + // we need to reduce, we have a result on 5 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], _ = bits.Sub64(t[3], q3, b) + return + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } +} + +func _fromMontGeneric(z *Element) { + // the following lines implement z = z * 1 + // with a modified CIOS montgomery multiplication + // see Mul for algorithm documentation + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = 
bits.Sub64(z[3], q3, b) + } +} + +func _reduceGeneric(z *Element) { + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } +} + +// BatchInvert returns a new slice with every element inverted. +// Uses Montgomery batch inversion trick +func BatchInvert(a []Element) []Element { + res := make([]Element, len(a)) + if len(a) == 0 { + return res + } + + zeroes := bitset.New(uint(len(a))) + accumulator := One() + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + zeroes.Set(uint(i)) + continue + } + res[i] = accumulator + accumulator.Mul(&accumulator, &a[i]) + } + + accumulator.Inverse(&accumulator) + + for i := len(a) - 1; i >= 0; i-- { + if zeroes.Test(uint(i)) { + continue + } + res[i].Mul(&res[i], &accumulator) + accumulator.Mul(&accumulator, &a[i]) + } + + return res +} + +func _butterflyGeneric(a, b *Element) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// BitLen returns the minimum number of bits needed to represent z +// returns 0 if z == 0 +func (z *Element) BitLen() int { + if z[3] != 0 { + return 192 + bits.Len64(z[3]) + } + if z[2] != 0 { + return 128 + bits.Len64(z[2]) + } + if z[1] != 0 { + return 64 + bits.Len64(z[1]) + } + return bits.Len64(z[0]) +} + +// Hash msg to count prime field elements. 
+// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := hash.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := pool.BigInt.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + pool.BigInt.Put(vv) + + return res, nil +} + +// Exp z = xᵏ (mod q) +func (z *Element) Exp(x Element, k *big.Int) *Element { + if k.IsUint64() && k.Uint64() == 0 { + return z.SetOne() + } + + e := k + if k.Sign() == -1 { + // negative k, we invert + // if k < 0: xᵏ (mod q) == (x⁻¹)ᵏ (mod q) + x.Inverse(&x) + + // we negate k in a temp big.Int since + // Int.Bit(_) of k and -k is different + e = pool.BigInt.Get() + defer pool.BigInt.Put(e) + e.Neg(k) + } + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z +} + +// rSquare where r is the Montgommery constant +// see section 2.3.2 of Tolga Acar's thesis +// https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf +var rSquare = Element{ + 17522657719365597833, + 13107472804851548667, + 5164255478447964150, + 493319470278259999, +} + +// toMont converts z to Montgomery form +// sets and returns z = z * r² +func (z *Element) toMont() *Element { + return z.Mul(z, &rSquare) +} + +// String returns the decimal representation of z as generated by +// z.Text(10). 
+func (z *Element) String() string { + return z.Text(10) +} + +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[24:32], z[0]) + binary.BigEndian.PutUint64(b[16:24], z[1]) + binary.BigEndian.PutUint64(b[8:16], z[2]) + binary.BigEndian.PutUint64(b[0:8], z[3]) + + return res.SetBytes(b[:]) +} + +// Text returns the string representation of z in the given base. +// Base must be between 2 and 36, inclusive. The result uses the +// lower-case letters 'a' to 'z' for digit values 10 to 35. +// No prefix (such as "0x") is added to the string. If z is a nil +// pointer it returns "". +// If base == 10 and -z fits in a uint16 prefix "-" is added to the string. +func (z *Element) Text(base int) string { + if base < 2 || base > 36 { + panic("invalid base") + } + if z == nil { + return "" + } + + const maxUint16 = 65535 + if base == 10 { + var zzNeg Element + zzNeg.Neg(z) + zzNeg.fromMont() + if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { + return "-" + strconv.FormatUint(zzNeg[0], base) + } + } + zz := *z + zz.fromMont() + if zz.FitsOnOneWord() { + return strconv.FormatUint(zz[0], base) + } + vv := pool.BigInt.Get() + r := zz.toBigInt(vv).Text(base) + pool.BigInt.Put(vv) + return r +} + +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) +} + +// ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead +func (z Element) ToBigIntRegular(res *big.Int) *big.Int { + z.fromMont() + return z.toBigInt(res) +} + +// Bits provides access to z by returning its value as a little-endian [4]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. 
+func (z *Element) Bits() [4]uint64 { + _z := *z + fromMont(&_z) + return _z +} + +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) + return +} + +// Marshal returns the value of z as a big-endian byte slice +func (z *Element) Marshal() []byte { + b := z.Bytes() + return b[:] +} + +// Unmarshal is an alias for SetBytes, it sets z to the value of e. +func (z *Element) Unmarshal(e []byte) { + z.SetBytes(e) +} + +// SetBytes interprets e as the bytes of a big-endian unsigned integer, +// sets z to that value, and returns z. +func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. + // get a big int from our pool + vv := pool.BigInt.Get() + vv.SetBytes(e) + + // set big int + z.SetBigInt(vv) + + // put temporary object back in pool + pool.BigInt.Put(vv) + + return z +} + +// SetBytesCanonical interprets e as the bytes of a big-endian 32-byte integer. +// If e is not a 32-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. 
+func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fr.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + +// SetBigInt sets z to v and returns z +func (z *Element) SetBigInt(v *big.Int) *Element { + z.SetZero() + + var zero big.Int + + // fast path + c := v.Cmp(&_modulus) + if c == 0 { + // v == 0 + return z + } else if c != 1 && v.Cmp(&zero) != -1 { + // 0 <= v < q + return z.setBigInt(v) + } + + // get temporary big int from the pool + vv := pool.BigInt.Get() + + // copy input + modular reduction + vv.Mod(v, &_modulus) + + // set big int byte value + z.setBigInt(vv) + + // release object into pool + pool.BigInt.Put(vv) + return z +} + +// setBigInt assumes 0 ⩽ v < q +func (z *Element) setBigInt(v *big.Int) *Element { + vBits := v.Bits() + + if bits.UintSize == 64 { + for i := 0; i < len(vBits); i++ { + z[i] = uint64(vBits[i]) + } + } else { + for i := 0; i < len(vBits); i++ { + if i%2 == 0 { + z[i/2] = uint64(vBits[i]) + } else { + z[i/2] |= uint64(vBits[i]) << 32 + } + } + } + + return z.toMont() +} + +// SetString creates a big.Int with number and calls SetBigInt on z +// +// The number prefix determines the actual base: A prefix of +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 +// and no prefix is accepted. +// +// For base 16, lower and upper case letters are considered the same: +// The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. +// +// An underscore character ”_” may appear between a base +// prefix and an adjacent digit, and between successive digits; such +// underscores do not change the value of the number. +// Incorrect placement of underscores is reported as a panic if there +// are no other errors. +// +// If the number is invalid this method leaves z unchanged and returns nil, error. 
+func (z *Element) SetString(number string) (*Element, error) { + // get temporary big int from the pool + vv := pool.BigInt.Get() + + if _, ok := vv.SetString(number, 0); !ok { + return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) + } + + z.SetBigInt(vv) + + // release object into pool + pool.BigInt.Put(vv) + + return z, nil +} + +// MarshalJSON returns json encoding of z (z.Text(10)) +// If z == nil, returns null +func (z *Element) MarshalJSON() ([]byte, error) { + if z == nil { + return []byte("null"), nil + } + const maxSafeBound = 15 // we encode it as number if it's small + s := z.Text(10) + if len(s) <= maxSafeBound { + return []byte(s), nil + } + var sbb strings.Builder + sbb.WriteByte('"') + sbb.WriteString(s) + sbb.WriteByte('"') + return []byte(sbb.String()), nil +} + +// UnmarshalJSON accepts numbers and strings as input +// See Element.SetString for valid prefixes (0x, 0b, ...) +func (z *Element) UnmarshalJSON(data []byte) error { + s := string(data) + if len(s) > Bits*3 { + return errors.New("value too large (max = Element.Bits * 3)") + } + + // we accept numbers and strings, remove leading and trailing quotes if any + if len(s) > 0 && s[0] == '"' { + s = s[1:] + } + if len(s) > 0 && s[len(s)-1] == '"' { + s = s[:len(s)-1] + } + + // get temporary big int from the pool + vv := pool.BigInt.Get() + + if _, ok := vv.SetString(s, 0); !ok { + return errors.New("can't parse into a big.Int: " + s) + } + + z.SetBigInt(vv) + + // release object into pool + pool.BigInt.Put(vv) + return nil +} + +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +var errInvalidEncoding = errors.New("invalid fr.Element encoding") + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. 
+var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 32-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[24:32]) + z[1] = binary.BigEndian.Uint64((*b)[16:24]) + z[2] = binary.BigEndian.Uint64((*b)[8:16]) + z[3] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errInvalidEncoding + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[24:32], e[0]) + binary.BigEndian.PutUint64((*b)[16:24], e[1]) + binary.BigEndian.PutUint64((*b)[8:16], e[2]) + binary.BigEndian.PutUint64((*b)[0:8], e[3]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + + if !z.smallerThanModulus() { + return Element{}, errInvalidEncoding + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) +} + +func (littleEndian) String() string { return "LittleEndian" } + +// Legendre returns the Legendre symbol of z (either +1, -1, or 0.) 
+func (z *Element) Legendre() int { + var l Element + // z^((q-1)/2) + l.expByLegendreExp(*z) + + if l.IsZero() { + return 0 + } + + // if l == 1 + if l.IsOne() { + return 1 + } + return -1 +} + +// Sqrt z = √x (mod q) +// if the square root doesn't exist (x is not a square mod q) +// Sqrt leaves z unchanged and returns nil +func (z *Element) Sqrt(x *Element) *Element { + // q ≡ 3 (mod 4) + // using z ≡ ± x^((p+1)/4) (mod q) + var y, square Element + y.expBySqrtExp(*x) + // as we didn't compute the legendre symbol, ensure we found y such that y * y = x + square.Square(&y) + if square.Equal(x) { + return z.Set(&y) + } + return nil +} + +const ( + k = 32 // word size / 2 + signBitSelector = uint64(1) << 63 + approxLowBitsN = k - 1 + approxHighBitsN = k + 1 +) + +const ( + inversionCorrectionFactorWord0 = 11111708840330028223 + inversionCorrectionFactorWord1 = 3098618286181893933 + inversionCorrectionFactorWord2 = 756602578711705709 + inversionCorrectionFactorWord3 = 1041752015607019851 + invIterationsN = 18 +) + +// Inverse z = x⁻¹ (mod q) +// +// if x == 0, sets and returns z = x +func (z *Element) Inverse(x *Element) *Element { + // Implements "Optimized Binary GCD for Modular Inversion" + // https://github.com/pornin/bingcd/blob/main/doc/bingcd.pdf + + a := *x + b := Element{ + q0, + q1, + q2, + q3, + } // b := q + + u := Element{1} + + // Update factors: we get [u; v] ← [f₀ g₀; f₁ g₁] [u; v] + // cᵢ = fᵢ + 2³¹ - 1 + 2³² * (gᵢ + 2³¹ - 1) + var c0, c1 int64 + + // Saved update factors to reduce the number of field multiplications + var pf0, pf1, pg0, pg1 int64 + + var i uint + + var v, s Element + + // Since u,v are updated every other iteration, we must make sure we terminate after evenly many iterations + // This also lets us get away with half as many updates to u,v + // To make this constant-time-ish, replace the condition with i < invIterationsN + for i = 0; i&1 == 1 || !a.IsZero(); i++ { + n := max(a.BitLen(), b.BitLen()) + aApprox, bApprox := approximate(&a, 
n), approximate(&b, n) + + // f₀, g₀, f₁, g₁ = 1, 0, 0, 1 + c0, c1 = updateFactorIdentityMatrixRow0, updateFactorIdentityMatrixRow1 + + for j := 0; j < approxLowBitsN; j++ { + + // -2ʲ < f₀, f₁ ≤ 2ʲ + // |f₀| + |f₁| < 2ʲ⁺¹ + + if aApprox&1 == 0 { + aApprox /= 2 + } else { + s, borrow := bits.Sub64(aApprox, bApprox, 0) + if borrow == 1 { + s = bApprox - aApprox + bApprox = aApprox + c0, c1 = c1, c0 + // invariants unchanged + } + + aApprox = s / 2 + c0 = c0 - c1 + + // Now |f₀| < 2ʲ⁺¹ ≤ 2ʲ⁺¹ (only the weaker inequality is needed, strictly speaking) + // Started with f₀ > -2ʲ and f₁ ≤ 2ʲ, so f₀ - f₁ > -2ʲ⁺¹ + // Invariants unchanged for f₁ + } + + c1 *= 2 + // -2ʲ⁺¹ < f₁ ≤ 2ʲ⁺¹ + // So now |f₀| + |f₁| < 2ʲ⁺² + } + + s = a + + var g0 int64 + // from this point on c0 aliases for f0 + c0, g0 = updateFactorsDecompose(c0) + aHi := a.linearCombNonModular(&s, c0, &b, g0) + if aHi&signBitSelector != 0 { + // if aHi < 0 + c0, g0 = -c0, -g0 + aHi = negL(&a, aHi) + } + // right-shift a by k-1 bits + a[0] = (a[0] >> approxLowBitsN) | ((a[1]) << approxHighBitsN) + a[1] = (a[1] >> approxLowBitsN) | ((a[2]) << approxHighBitsN) + a[2] = (a[2] >> approxLowBitsN) | ((a[3]) << approxHighBitsN) + a[3] = (a[3] >> approxLowBitsN) | (aHi << approxHighBitsN) + + var f1 int64 + // from this point on c1 aliases for g0 + f1, c1 = updateFactorsDecompose(c1) + bHi := b.linearCombNonModular(&s, f1, &b, c1) + if bHi&signBitSelector != 0 { + // if bHi < 0 + f1, c1 = -f1, -c1 + bHi = negL(&b, bHi) + } + // right-shift b by k-1 bits + b[0] = (b[0] >> approxLowBitsN) | ((b[1]) << approxHighBitsN) + b[1] = (b[1] >> approxLowBitsN) | ((b[2]) << approxHighBitsN) + b[2] = (b[2] >> approxLowBitsN) | ((b[3]) << approxHighBitsN) + b[3] = (b[3] >> approxLowBitsN) | (bHi << approxHighBitsN) + + if i&1 == 1 { + // Combine current update factors with previously stored ones + // [F₀, G₀; F₁, G₁] ← [f₀, g₀; f₁, g₁] [pf₀, pg₀; pf₁, pg₁], with capital letters denoting new combined values + // We get |F₀| = | f₀pf₀ + 
g₀pf₁ | ≤ |f₀pf₀| + |g₀pf₁| = |f₀| |pf₀| + |g₀| |pf₁| ≤ 2ᵏ⁻¹|pf₀| + 2ᵏ⁻¹|pf₁| + // = 2ᵏ⁻¹ (|pf₀| + |pf₁|) < 2ᵏ⁻¹ 2ᵏ = 2²ᵏ⁻¹ + // So |F₀| < 2²ᵏ⁻¹ meaning it fits in a 2k-bit signed register + + // c₀ aliases f₀, c₁ aliases g₁ + c0, g0, f1, c1 = c0*pf0+g0*pf1, + c0*pg0+g0*pg1, + f1*pf0+c1*pf1, + f1*pg0+c1*pg1 + + s = u + + // 0 ≤ u, v < 2²⁵⁵ + // |F₀|, |G₀| < 2⁶³ + u.linearComb(&u, c0, &v, g0) + // |F₁|, |G₁| < 2⁶³ + v.linearComb(&s, f1, &v, c1) + + } else { + // Save update factors + pf0, pg0, pf1, pg1 = c0, g0, f1, c1 + } + } + + // For every iteration that we miss, v is not being multiplied by 2ᵏ⁻² + const pSq uint64 = 1 << (2 * (k - 1)) + a = Element{pSq} + // If the function is constant-time ish, this loop will not run (no need to take it out explicitly) + for ; i < invIterationsN; i += 2 { + // could optimize further with mul by word routine or by pre-computing a table since with k=26, + // we would multiply by pSq up to 13times; + // on x86, the assembly routine outperforms generic code for mul by word + // on arm64, we may loose up to ~5% for 6 limbs + v.Mul(&v, &a) + } + + u.Set(x) // for correctness check + + z.Mul(&v, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + }) + + // correctness check + v.Mul(&u, z) + if !v.IsOne() && !u.IsZero() { + return z.inverseExp(u) + } + + return z +} + +// inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z +} + +// approximate a big number x into a single 64 bit word using its uppermost and lowermost bits +// if x fits in a word as is, no approximation necessary +func approximate(x *Element, nBits int) uint64 { + + if nBits <= 64 { + return x[0] + } + + const mask = (uint64(1) << (k - 1)) 
- 1 // k-1 ones + lo := mask & x[0] + + hiWordIndex := (nBits - 1) / 64 + + hiWordBitsAvailable := nBits - hiWordIndex*64 + hiWordBitsUsed := min(hiWordBitsAvailable, approxHighBitsN) + + mask_ := uint64(^((1 << (hiWordBitsAvailable - hiWordBitsUsed)) - 1)) + hi := (x[hiWordIndex] & mask_) << (64 - hiWordBitsAvailable) + + mask_ = ^(1<<(approxLowBitsN+hiWordBitsUsed) - 1) + mid := (mask_ & x[hiWordIndex-1]) >> hiWordBitsUsed + + return lo | mid | hi +} + +// linearComb z = xC * x + yC * y; +// 0 ≤ x, y < 2²⁵⁴ +// |xC|, |yC| < 2⁶³ +func (z *Element) linearComb(x *Element, xC int64, y *Element, yC int64) { + // | (hi, z) | < 2 * 2⁶³ * 2²⁵⁴ = 2³¹⁸ + // therefore | hi | < 2⁶² ≤ 2⁶³ + hi := z.linearCombNonModular(x, xC, y, yC) + z.montReduceSigned(z, hi) +} + +// montReduceSigned z = (xHi * r + x) * r⁻¹ using the SOS algorithm +// Requires |xHi| < 2⁶³. Most significant bit of xHi is the sign bit. +func (z *Element) montReduceSigned(x *Element, xHi uint64) { + const signBitRemover = ^signBitSelector + mustNeg := xHi&signBitSelector != 0 + // the SOS implementation requires that most significant bit is 0 + // Let X be xHi*r + x + // If X is negative we would have initially stored it as 2⁶⁴ r + X (à la 2's complement) + xHi &= signBitRemover + // with this a negative X is now represented as 2⁶³ r + X + + var t [2*Limbs - 1]uint64 + var C uint64 + + m := x[0] * qInvNeg + + C = madd0(m, q0, x[0]) + C, t[1] = madd2(m, q1, x[1], C) + C, t[2] = madd2(m, q2, x[2], C) + C, t[3] = madd2(m, q3, x[3], C) + + // m * qElement[3] ≤ (2⁶⁴ - 1) * (2⁶³ - 1) = 2¹²⁷ - 2⁶⁴ - 2⁶³ + 1 + // x[3] + C ≤ 2*(2⁶⁴ - 1) = 2⁶⁵ - 2 + // On LHS, (C, t[3]) ≤ 2¹²⁷ - 2⁶⁴ - 2⁶³ + 1 + 2⁶⁵ - 2 = 2¹²⁷ + 2⁶³ - 1 + // So on LHS, C ≤ 2⁶³ + t[4] = xHi + C + // xHi + C < 2⁶³ + 2⁶³ = 2⁶⁴ + + // + { + const i = 1 + m = t[i] * qInvNeg + + C = madd0(m, q0, t[i+0]) + C, t[i+1] = madd2(m, q1, t[i+1], C) + C, t[i+2] = madd2(m, q2, t[i+2], C) + C, t[i+3] = madd2(m, q3, t[i+3], C) + + t[i+Limbs] += C + } + { + const i = 2 + m 
= t[i] * qInvNeg + + C = madd0(m, q0, t[i+0]) + C, t[i+1] = madd2(m, q1, t[i+1], C) + C, t[i+2] = madd2(m, q2, t[i+2], C) + C, t[i+3] = madd2(m, q3, t[i+3], C) + + t[i+Limbs] += C + } + { + const i = 3 + m := t[i] * qInvNeg + + C = madd0(m, q0, t[i+0]) + C, z[0] = madd2(m, q1, t[i+1], C) + C, z[1] = madd2(m, q2, t[i+2], C) + z[3], z[2] = madd2(m, q3, t[i+3], C) + } + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + // + + if mustNeg { + // We have computed ( 2⁶³ r + X ) r⁻¹ = 2⁶³ + X r⁻¹ instead + var b uint64 + z[0], b = bits.Sub64(z[0], signBitSelector, 0) + z[1], b = bits.Sub64(z[1], 0, b) + z[2], b = bits.Sub64(z[2], 0, b) + z[3], b = bits.Sub64(z[3], 0, b) + + // Occurs iff x == 0 && xHi < 0, i.e. X = rX' for -2⁶³ ≤ X' < 0 + + if b != 0 { + // z[3] = -1 + // negative: add q + const neg1 = 0xFFFFFFFFFFFFFFFF + + var carry uint64 + + z[0], carry = bits.Add64(z[0], q0, 0) + z[1], carry = bits.Add64(z[1], q1, carry) + z[2], carry = bits.Add64(z[2], q2, carry) + z[3], _ = bits.Add64(neg1, q3, carry) + } + } +} + +const ( + updateFactorsConversionBias int64 = 0x7fffffff7fffffff // (2³¹ - 1)(2³² + 1) + updateFactorIdentityMatrixRow0 = 1 + updateFactorIdentityMatrixRow1 = 1 << 32 +) + +func updateFactorsDecompose(c int64) (int64, int64) { + c += updateFactorsConversionBias + const low32BitsFilter int64 = 0xFFFFFFFF + f := c&low32BitsFilter - 0x7FFFFFFF + g := c>>32&low32BitsFilter - 0x7FFFFFFF + return f, g +} + +// negL negates in place [x | xHi] and return the new most significant word xHi +func negL(x *Element, xHi uint64) uint64 { + var b uint64 + + x[0], b = bits.Sub64(0, x[0], 0) + x[1], b = bits.Sub64(0, x[1], b) + x[2], b = bits.Sub64(0, x[2], b) + x[3], b = bits.Sub64(0, x[3], b) + xHi, _ = bits.Sub64(0, xHi, b) + + return xHi +} + +// mulWNonModular multiplies by one word in non-montgomery, 
without reducing +func (z *Element) mulWNonModular(x *Element, y int64) uint64 { + + // w := abs(y) + m := y >> 63 + w := uint64((y ^ m) - m) + + var c uint64 + c, z[0] = bits.Mul64(x[0], w) + c, z[1] = madd1(x[1], w, c) + c, z[2] = madd1(x[2], w, c) + c, z[3] = madd1(x[3], w, c) + + if y < 0 { + c = negL(z, c) + } + + return c +} + +// linearCombNonModular computes a linear combination without modular reduction +func (z *Element) linearCombNonModular(x *Element, xC int64, y *Element, yC int64) uint64 { + var yTimes Element + + yHi := yTimes.mulWNonModular(y, yC) + xHi := z.mulWNonModular(x, xC) + + var carry uint64 + z[0], carry = bits.Add64(z[0], yTimes[0], 0) + z[1], carry = bits.Add64(z[1], yTimes[1], carry) + z[2], carry = bits.Add64(z[2], yTimes[2], carry) + z[3], carry = bits.Add64(z[3], yTimes[3], carry) + + yHi, _ = bits.Add64(xHi, yHi, carry) + + return yHi +} diff --git a/ecc/grumpkin/fr/element_amd64.go b/ecc/grumpkin/fr/element_amd64.go new file mode 100644 index 0000000000..1916ed6fa8 --- /dev/null +++ b/ecc/grumpkin/fr/element_amd64.go @@ -0,0 +1,59 @@ +//go:build !purego + +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import ( + _ "github.com/consensys/gnark-crypto/field/asm/element_4w" +) + +//go:noescape +func MulBy3(x *Element) + +//go:noescape +func MulBy5(x *Element) + +//go:noescape +func MulBy13(x *Element) + +//go:noescape +func mul(res, x, y *Element) + +//go:noescape +func fromMont(res *Element) + +//go:noescape +func reduce(res *Element) + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. 
El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/grumpkin/fr/element_amd64.s b/ecc/grumpkin/fr/element_amd64.s new file mode 100644 index 0000000000..b45615aa36 --- /dev/null +++ b/ecc/grumpkin/fr/element_amd64.s @@ -0,0 +1,10 @@ +//go:build !purego + +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 14652627197992229521 +#include "../../../field/asm/element_4w/element_4w_amd64.s" + diff --git a/ecc/grumpkin/fr/element_arm64.go b/ecc/grumpkin/fr/element_arm64.go new file mode 100644 index 0000000000..03421559cc --- /dev/null +++ b/ecc/grumpkin/fr/element_arm64.go @@ -0,0 +1,70 @@ +//go:build !purego + +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import ( + _ "github.com/consensys/gnark-crypto/field/asm/element_4w" +) + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. 
+ mul(z, x, x) + return z +} + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 529957932336199972, + 13952065197595570812, + 769406925088786211, + 2691790815622165739, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +//go:noescape +func reduce(res *Element) diff --git a/ecc/grumpkin/fr/element_arm64.s b/ecc/grumpkin/fr/element_arm64.s new file mode 100644 index 0000000000..c8df07e345 --- /dev/null +++ b/ecc/grumpkin/fr/element_arm64.s @@ -0,0 +1,10 @@ +//go:build !purego + +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 1501560133179981797 +#include "../../../field/asm/element_4w/element_4w_arm64.s" + diff --git a/ecc/grumpkin/fr/element_exp.go b/ecc/grumpkin/fr/element_exp.go new file mode 100644 index 0000000000..c9716e3a7f --- /dev/null +++ b/ecc/grumpkin/fr/element_exp.go @@ -0,0 +1,791 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// expBySqrtExp is equivalent to z.Exp(x, c19139cb84c680a6e14116da060561765e05aa45a1c72a34f082305b61f3f52) +// +// uses github.com/mmcloughlin/addchain v0.4.0 to generate a shorter addition chain +func (z *Element) expBySqrtExp(x Element) *Element { + // addition chain: + // + // _10 = 2*1 + // _11 = 1 + _10 + // _101 = _10 + _11 + // _110 = 1 + _101 + // _111 = 1 + _110 + // _1011 = _101 + _110 + // _1100 = 1 + _1011 + // _1101 = 1 + _1100 + // _1111 = _10 + _1101 + // _10001 = _10 + _1111 + // _10011 = _10 + _10001 + // _10111 = _110 + _10001 + // _11001 = _10 + _10111 + // _11011 = _10 + _11001 + // _11111 = _110 + _11001 + // _100011 = _1100 + _10111 + // _100111 = _1100 + _11011 + // _101001 = _10 + _100111 + // _101011 = _10 + _101001 + // _101101 = _10 + _101011 + // _111001 = _1100 + _101101 + // _1100000 = _100111 + _111001 + // i46 = ((_1100000 << 5 + _11001) << 9 + _100111) << 8 + // i62 = ((_111001 + i46) << 4 + _111) << 9 + _10011 + // i89 = ((i62 << 7 + _1101) << 13 + _101001) << 5 + // i109 = ((_10111 + i89) << 7 + _101) << 10 + _10001 + // i130 = ((i109 << 6 + _11011) << 5 + _1101) << 8 + // i154 = ((_11 + i130) << 12 + _101011) << 9 + _10111 + // i179 = ((i154 << 6 + _11001) << 5 + _1111) << 12 + // i198 = ((_101101 + i179) << 7 + _101001) << 9 + _101101 + // i220 = ((i198 << 7 + _111) << 9 + _111001) << 4 + // i236 = ((_101 + i220) << 7 + _1101) << 6 + _1111 + // i265 = ((i236 << 5 + 1) << 11 + _100011) << 11 + // i281 = ((_101101 + i265) << 4 + _1011) << 9 + _11111 + // i299 = (i281 << 8 + _110 + _111001) << 7 + _101001 + // return 2*i299 + // + // Operations: 246 squares 54 multiplies + + // Allocate Temporaries. 
+ var ( + t0 = new(Element) + t1 = new(Element) + t2 = new(Element) + t3 = new(Element) + t4 = new(Element) + t5 = new(Element) + t6 = new(Element) + t7 = new(Element) + t8 = new(Element) + t9 = new(Element) + t10 = new(Element) + t11 = new(Element) + t12 = new(Element) + t13 = new(Element) + t14 = new(Element) + t15 = new(Element) + t16 = new(Element) + t17 = new(Element) + t18 = new(Element) + ) + + // var t0,t1,t2,t3,t4,t5,t6,t7,t8,t9,t10,t11,t12,t13,t14,t15,t16,t17,t18 Element + // Step 1: t4 = x^0x2 + t4.Square(&x) + + // Step 2: t13 = x^0x3 + t13.Mul(&x, t4) + + // Step 3: t8 = x^0x5 + t8.Mul(t4, t13) + + // Step 4: t1 = x^0x6 + t1.Mul(&x, t8) + + // Step 5: t9 = x^0x7 + t9.Mul(&x, t1) + + // Step 6: t3 = x^0xb + t3.Mul(t8, t1) + + // Step 7: t0 = x^0xc + t0.Mul(&x, t3) + + // Step 8: t7 = x^0xd + t7.Mul(&x, t0) + + // Step 9: t6 = x^0xf + t6.Mul(t4, t7) + + // Step 10: t15 = x^0x11 + t15.Mul(t4, t6) + + // Step 11: t16 = x^0x13 + t16.Mul(t4, t15) + + // Step 12: t11 = x^0x17 + t11.Mul(t1, t15) + + // Step 13: t10 = x^0x19 + t10.Mul(t4, t11) + + // Step 14: t14 = x^0x1b + t14.Mul(t4, t10) + + // Step 15: t2 = x^0x1f + t2.Mul(t1, t10) + + // Step 16: t5 = x^0x23 + t5.Mul(t0, t11) + + // Step 17: t17 = x^0x27 + t17.Mul(t0, t14) + + // Step 18: z = x^0x29 + z.Mul(t4, t17) + + // Step 19: t12 = x^0x2b + t12.Mul(t4, z) + + // Step 20: t4 = x^0x2d + t4.Mul(t4, t12) + + // Step 21: t0 = x^0x39 + t0.Mul(t0, t4) + + // Step 22: t18 = x^0x60 + t18.Mul(t17, t0) + + // Step 27: t18 = x^0xc00 + for s := 0; s < 5; s++ { + t18.Square(t18) + } + + // Step 28: t18 = x^0xc19 + t18.Mul(t10, t18) + + // Step 37: t18 = x^0x183200 + for s := 0; s < 9; s++ { + t18.Square(t18) + } + + // Step 38: t17 = x^0x183227 + t17.Mul(t17, t18) + + // Step 46: t17 = x^0x18322700 + for s := 0; s < 8; s++ { + t17.Square(t17) + } + + // Step 47: t17 = x^0x18322739 + t17.Mul(t0, t17) + + // Step 51: t17 = x^0x183227390 + for s := 0; s < 4; s++ { + t17.Square(t17) + } + + // Step 52: t17 = 
x^0x183227397 + t17.Mul(t9, t17) + + // Step 61: t17 = x^0x30644e72e00 + for s := 0; s < 9; s++ { + t17.Square(t17) + } + + // Step 62: t16 = x^0x30644e72e13 + t16.Mul(t16, t17) + + // Step 69: t16 = x^0x1832273970980 + for s := 0; s < 7; s++ { + t16.Square(t16) + } + + // Step 70: t16 = x^0x183227397098d + t16.Mul(t7, t16) + + // Step 83: t16 = x^0x30644e72e131a000 + for s := 0; s < 13; s++ { + t16.Square(t16) + } + + // Step 84: t16 = x^0x30644e72e131a029 + t16.Mul(z, t16) + + // Step 89: t16 = x^0x60c89ce5c26340520 + for s := 0; s < 5; s++ { + t16.Square(t16) + } + + // Step 90: t16 = x^0x60c89ce5c26340537 + t16.Mul(t11, t16) + + // Step 97: t16 = x^0x30644e72e131a029b80 + for s := 0; s < 7; s++ { + t16.Square(t16) + } + + // Step 98: t16 = x^0x30644e72e131a029b85 + t16.Mul(t8, t16) + + // Step 108: t16 = x^0xc19139cb84c680a6e1400 + for s := 0; s < 10; s++ { + t16.Square(t16) + } + + // Step 109: t15 = x^0xc19139cb84c680a6e1411 + t15.Mul(t15, t16) + + // Step 115: t15 = x^0x30644e72e131a029b850440 + for s := 0; s < 6; s++ { + t15.Square(t15) + } + + // Step 116: t14 = x^0x30644e72e131a029b85045b + t14.Mul(t14, t15) + + // Step 121: t14 = x^0x60c89ce5c263405370a08b60 + for s := 0; s < 5; s++ { + t14.Square(t14) + } + + // Step 122: t14 = x^0x60c89ce5c263405370a08b6d + t14.Mul(t7, t14) + + // Step 130: t14 = x^0x60c89ce5c263405370a08b6d00 + for s := 0; s < 8; s++ { + t14.Square(t14) + } + + // Step 131: t13 = x^0x60c89ce5c263405370a08b6d03 + t13.Mul(t13, t14) + + // Step 143: t13 = x^0x60c89ce5c263405370a08b6d03000 + for s := 0; s < 12; s++ { + t13.Square(t13) + } + + // Step 144: t12 = x^0x60c89ce5c263405370a08b6d0302b + t12.Mul(t12, t13) + + // Step 153: t12 = x^0xc19139cb84c680a6e14116da0605600 + for s := 0; s < 9; s++ { + t12.Square(t12) + } + + // Step 154: t11 = x^0xc19139cb84c680a6e14116da0605617 + t11.Mul(t11, t12) + + // Step 160: t11 = x^0x30644e72e131a029b85045b68181585c0 + for s := 0; s < 6; s++ { + t11.Square(t11) + } + + // Step 161: t10 = 
x^0x30644e72e131a029b85045b68181585d9 + t10.Mul(t10, t11) + + // Step 166: t10 = x^0x60c89ce5c263405370a08b6d0302b0bb20 + for s := 0; s < 5; s++ { + t10.Square(t10) + } + + // Step 167: t10 = x^0x60c89ce5c263405370a08b6d0302b0bb2f + t10.Mul(t6, t10) + + // Step 179: t10 = x^0x60c89ce5c263405370a08b6d0302b0bb2f000 + for s := 0; s < 12; s++ { + t10.Square(t10) + } + + // Step 180: t10 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d + t10.Mul(t4, t10) + + // Step 187: t10 = x^0x30644e72e131a029b85045b68181585d9781680 + for s := 0; s < 7; s++ { + t10.Square(t10) + } + + // Step 188: t10 = x^0x30644e72e131a029b85045b68181585d97816a9 + t10.Mul(z, t10) + + // Step 197: t10 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d5200 + for s := 0; s < 9; s++ { + t10.Square(t10) + } + + // Step 198: t10 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d522d + t10.Mul(t4, t10) + + // Step 205: t10 = x^0x30644e72e131a029b85045b68181585d97816a91680 + for s := 0; s < 7; s++ { + t10.Square(t10) + } + + // Step 206: t9 = x^0x30644e72e131a029b85045b68181585d97816a91687 + t9.Mul(t9, t10) + + // Step 215: t9 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d522d0e00 + for s := 0; s < 9; s++ { + t9.Square(t9) + } + + // Step 216: t9 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d522d0e39 + t9.Mul(t0, t9) + + // Step 220: t9 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d522d0e390 + for s := 0; s < 4; s++ { + t9.Square(t9) + } + + // Step 221: t8 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d522d0e395 + t8.Mul(t8, t9) + + // Step 228: t8 = x^0x30644e72e131a029b85045b68181585d97816a916871ca80 + for s := 0; s < 7; s++ { + t8.Square(t8) + } + + // Step 229: t7 = x^0x30644e72e131a029b85045b68181585d97816a916871ca8d + t7.Mul(t7, t8) + + // Step 235: t7 = x^0xc19139cb84c680a6e14116da060561765e05aa45a1c72a340 + for s := 0; s < 6; s++ { + t7.Square(t7) + } + + // Step 236: t6 = x^0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f + t6.Mul(t6, t7) + + // Step 241: t6 = x^0x183227397098d014dc2822db40c0ac2ecbc0b548b438e5469e0 + for s 
:= 0; s < 5; s++ { + t6.Square(t6) + } + + // Step 242: t6 = x^0x183227397098d014dc2822db40c0ac2ecbc0b548b438e5469e1 + t6.Mul(&x, t6) + + // Step 253: t6 = x^0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f0800 + for s := 0; s < 11; s++ { + t6.Square(t6) + } + + // Step 254: t5 = x^0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f0823 + t5.Mul(t5, t6) + + // Step 265: t5 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d522d0e3951a78411800 + for s := 0; s < 11; s++ { + t5.Square(t5) + } + + // Step 266: t4 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d522d0e3951a7841182d + t4.Mul(t4, t5) + + // Step 270: t4 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d522d0e3951a7841182d0 + for s := 0; s < 4; s++ { + t4.Square(t4) + } + + // Step 271: t3 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d522d0e3951a7841182db + t3.Mul(t3, t4) + + // Step 280: t3 = x^0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f082305b600 + for s := 0; s < 9; s++ { + t3.Square(t3) + } + + // Step 281: t2 = x^0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f082305b61f + t2.Mul(t2, t3) + + // Step 289: t2 = x^0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f082305b61f00 + for s := 0; s < 8; s++ { + t2.Square(t2) + } + + // Step 290: t1 = x^0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f082305b61f06 + t1.Mul(t1, t2) + + // Step 291: t0 = x^0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f082305b61f3f + t0.Mul(t0, t1) + + // Step 298: t0 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d522d0e3951a7841182db0f9f80 + for s := 0; s < 7; s++ { + t0.Square(t0) + } + + // Step 299: z = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d522d0e3951a7841182db0f9fa9 + z.Mul(z, t0) + + // Step 300: z = x^0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f082305b61f3f52 + z.Square(z) + + return z +} + +// expByLegendreExp is equivalent to z.Exp(x, 183227397098d014dc2822db40c0ac2ecbc0b548b438e5469e10460b6c3e7ea3) +// +// uses github.com/mmcloughlin/addchain v0.4.0 to generate a shorter addition chain +func (z *Element) 
expByLegendreExp(x Element) *Element { + // addition chain: + // + // _10 = 2*1 + // _11 = 1 + _10 + // _101 = _10 + _11 + // _110 = 1 + _101 + // _1000 = _10 + _110 + // _1101 = _101 + _1000 + // _10010 = _101 + _1101 + // _10011 = 1 + _10010 + // _10100 = 1 + _10011 + // _10111 = _11 + _10100 + // _11100 = _101 + _10111 + // _100000 = _1101 + _10011 + // _100011 = _11 + _100000 + // _101011 = _1000 + _100011 + // _101111 = _10011 + _11100 + // _1000001 = _10010 + _101111 + // _1010011 = _10010 + _1000001 + // _1011011 = _1000 + _1010011 + // _1100001 = _110 + _1011011 + // _1110101 = _10100 + _1100001 + // _10010001 = _11100 + _1110101 + // _10010101 = _100000 + _1110101 + // _10110101 = _100000 + _10010101 + // _10111011 = _110 + _10110101 + // _11000001 = _110 + _10111011 + // _11000011 = _10 + _11000001 + // _11010011 = _10010 + _11000001 + // _11100001 = _100000 + _11000001 + // _11100011 = _10 + _11100001 + // _11100111 = _110 + _11100001 + // i57 = ((_11000001 << 8 + _10010001) << 10 + _11100111) << 7 + // i76 = ((_10111 + i57) << 9 + _10011) << 7 + _1101 + // i109 = ((i76 << 14 + _1010011) << 9 + _11100001) << 8 + // i127 = ((_1000001 + i109) << 10 + _1011011) << 5 + _1101 + // i161 = ((i127 << 8 + _11) << 12 + _101011) << 12 + // i186 = ((_10111011 + i161) << 8 + _101111) << 14 + _10110101 + // i214 = ((i186 << 9 + _10010001) << 5 + _1101) << 12 + // i236 = ((_11100011 + i214) << 8 + _10010101) << 11 + _11010011 + // i268 = ((i236 << 7 + _1100001) << 11 + _100011) << 12 + // i288 = ((_1011011 + i268) << 9 + _11000011) << 8 + _11100111 + // return (i288 << 7 + _1110101) << 5 + _11 + // + // Operations: 246 squares 56 multiplies + + // Allocate Temporaries. 
+ var ( + t0 = new(Element) + t1 = new(Element) + t2 = new(Element) + t3 = new(Element) + t4 = new(Element) + t5 = new(Element) + t6 = new(Element) + t7 = new(Element) + t8 = new(Element) + t9 = new(Element) + t10 = new(Element) + t11 = new(Element) + t12 = new(Element) + t13 = new(Element) + t14 = new(Element) + t15 = new(Element) + t16 = new(Element) + t17 = new(Element) + t18 = new(Element) + t19 = new(Element) + t20 = new(Element) + ) + + // var t0,t1,t2,t3,t4,t5,t6,t7,t8,t9,t10,t11,t12,t13,t14,t15,t16,t17,t18,t19,t20 Element + // Step 1: t8 = x^0x2 + t8.Square(&x) + + // Step 2: z = x^0x3 + z.Mul(&x, t8) + + // Step 3: t2 = x^0x5 + t2.Mul(t8, z) + + // Step 4: t1 = x^0x6 + t1.Mul(&x, t2) + + // Step 5: t3 = x^0x8 + t3.Mul(t8, t1) + + // Step 6: t9 = x^0xd + t9.Mul(t2, t3) + + // Step 7: t6 = x^0x12 + t6.Mul(t2, t9) + + // Step 8: t18 = x^0x13 + t18.Mul(&x, t6) + + // Step 9: t0 = x^0x14 + t0.Mul(&x, t18) + + // Step 10: t19 = x^0x17 + t19.Mul(z, t0) + + // Step 11: t2 = x^0x1c + t2.Mul(t2, t19) + + // Step 12: t16 = x^0x20 + t16.Mul(t9, t18) + + // Step 13: t4 = x^0x23 + t4.Mul(z, t16) + + // Step 14: t14 = x^0x2b + t14.Mul(t3, t4) + + // Step 15: t12 = x^0x2f + t12.Mul(t18, t2) + + // Step 16: t15 = x^0x41 + t15.Mul(t6, t12) + + // Step 17: t17 = x^0x53 + t17.Mul(t6, t15) + + // Step 18: t3 = x^0x5b + t3.Mul(t3, t17) + + // Step 19: t5 = x^0x61 + t5.Mul(t1, t3) + + // Step 20: t0 = x^0x75 + t0.Mul(t0, t5) + + // Step 21: t10 = x^0x91 + t10.Mul(t2, t0) + + // Step 22: t7 = x^0x95 + t7.Mul(t16, t0) + + // Step 23: t11 = x^0xb5 + t11.Mul(t16, t7) + + // Step 24: t13 = x^0xbb + t13.Mul(t1, t11) + + // Step 25: t20 = x^0xc1 + t20.Mul(t1, t13) + + // Step 26: t2 = x^0xc3 + t2.Mul(t8, t20) + + // Step 27: t6 = x^0xd3 + t6.Mul(t6, t20) + + // Step 28: t16 = x^0xe1 + t16.Mul(t16, t20) + + // Step 29: t8 = x^0xe3 + t8.Mul(t8, t16) + + // Step 30: t1 = x^0xe7 + t1.Mul(t1, t16) + + // Step 38: t20 = x^0xc100 + for s := 0; s < 8; s++ { + t20.Square(t20) + } + + // Step 
39: t20 = x^0xc191 + t20.Mul(t10, t20) + + // Step 49: t20 = x^0x3064400 + for s := 0; s < 10; s++ { + t20.Square(t20) + } + + // Step 50: t20 = x^0x30644e7 + t20.Mul(t1, t20) + + // Step 57: t20 = x^0x183227380 + for s := 0; s < 7; s++ { + t20.Square(t20) + } + + // Step 58: t19 = x^0x183227397 + t19.Mul(t19, t20) + + // Step 67: t19 = x^0x30644e72e00 + for s := 0; s < 9; s++ { + t19.Square(t19) + } + + // Step 68: t18 = x^0x30644e72e13 + t18.Mul(t18, t19) + + // Step 75: t18 = x^0x1832273970980 + for s := 0; s < 7; s++ { + t18.Square(t18) + } + + // Step 76: t18 = x^0x183227397098d + t18.Mul(t9, t18) + + // Step 90: t18 = x^0x60c89ce5c2634000 + for s := 0; s < 14; s++ { + t18.Square(t18) + } + + // Step 91: t17 = x^0x60c89ce5c2634053 + t17.Mul(t17, t18) + + // Step 100: t17 = x^0xc19139cb84c680a600 + for s := 0; s < 9; s++ { + t17.Square(t17) + } + + // Step 101: t16 = x^0xc19139cb84c680a6e1 + t16.Mul(t16, t17) + + // Step 109: t16 = x^0xc19139cb84c680a6e100 + for s := 0; s < 8; s++ { + t16.Square(t16) + } + + // Step 110: t15 = x^0xc19139cb84c680a6e141 + t15.Mul(t15, t16) + + // Step 120: t15 = x^0x30644e72e131a029b850400 + for s := 0; s < 10; s++ { + t15.Square(t15) + } + + // Step 121: t15 = x^0x30644e72e131a029b85045b + t15.Mul(t3, t15) + + // Step 126: t15 = x^0x60c89ce5c263405370a08b60 + for s := 0; s < 5; s++ { + t15.Square(t15) + } + + // Step 127: t15 = x^0x60c89ce5c263405370a08b6d + t15.Mul(t9, t15) + + // Step 135: t15 = x^0x60c89ce5c263405370a08b6d00 + for s := 0; s < 8; s++ { + t15.Square(t15) + } + + // Step 136: t15 = x^0x60c89ce5c263405370a08b6d03 + t15.Mul(z, t15) + + // Step 148: t15 = x^0x60c89ce5c263405370a08b6d03000 + for s := 0; s < 12; s++ { + t15.Square(t15) + } + + // Step 149: t14 = x^0x60c89ce5c263405370a08b6d0302b + t14.Mul(t14, t15) + + // Step 161: t14 = x^0x60c89ce5c263405370a08b6d0302b000 + for s := 0; s < 12; s++ { + t14.Square(t14) + } + + // Step 162: t13 = x^0x60c89ce5c263405370a08b6d0302b0bb + t13.Mul(t13, t14) + + // Step 
170: t13 = x^0x60c89ce5c263405370a08b6d0302b0bb00 + for s := 0; s < 8; s++ { + t13.Square(t13) + } + + // Step 171: t12 = x^0x60c89ce5c263405370a08b6d0302b0bb2f + t12.Mul(t12, t13) + + // Step 185: t12 = x^0x183227397098d014dc2822db40c0ac2ecbc000 + for s := 0; s < 14; s++ { + t12.Square(t12) + } + + // Step 186: t11 = x^0x183227397098d014dc2822db40c0ac2ecbc0b5 + t11.Mul(t11, t12) + + // Step 195: t11 = x^0x30644e72e131a029b85045b68181585d97816a00 + for s := 0; s < 9; s++ { + t11.Square(t11) + } + + // Step 196: t10 = x^0x30644e72e131a029b85045b68181585d97816a91 + t10.Mul(t10, t11) + + // Step 201: t10 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d5220 + for s := 0; s < 5; s++ { + t10.Square(t10) + } + + // Step 202: t9 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d522d + t9.Mul(t9, t10) + + // Step 214: t9 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d522d000 + for s := 0; s < 12; s++ { + t9.Square(t9) + } + + // Step 215: t8 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d522d0e3 + t8.Mul(t8, t9) + + // Step 223: t8 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d522d0e300 + for s := 0; s < 8; s++ { + t8.Square(t8) + } + + // Step 224: t7 = x^0x60c89ce5c263405370a08b6d0302b0bb2f02d522d0e395 + t7.Mul(t7, t8) + + // Step 235: t7 = x^0x30644e72e131a029b85045b68181585d97816a916871ca800 + for s := 0; s < 11; s++ { + t7.Square(t7) + } + + // Step 236: t6 = x^0x30644e72e131a029b85045b68181585d97816a916871ca8d3 + t6.Mul(t6, t7) + + // Step 243: t6 = x^0x183227397098d014dc2822db40c0ac2ecbc0b548b438e546980 + for s := 0; s < 7; s++ { + t6.Square(t6) + } + + // Step 244: t5 = x^0x183227397098d014dc2822db40c0ac2ecbc0b548b438e5469e1 + t5.Mul(t5, t6) + + // Step 255: t5 = x^0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f0800 + for s := 0; s < 11; s++ { + t5.Square(t5) + } + + // Step 256: t4 = x^0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f0823 + t4.Mul(t4, t5) + + // Step 268: t4 = x^0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f0823000 + for s := 0; s < 12; s++ { + t4.Square(t4) + 
} + + // Step 269: t3 = x^0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f082305b + t3.Mul(t3, t4) + + // Step 278: t3 = x^0x183227397098d014dc2822db40c0ac2ecbc0b548b438e5469e10460b600 + for s := 0; s < 9; s++ { + t3.Square(t3) + } + + // Step 279: t2 = x^0x183227397098d014dc2822db40c0ac2ecbc0b548b438e5469e10460b6c3 + t2.Mul(t2, t3) + + // Step 287: t2 = x^0x183227397098d014dc2822db40c0ac2ecbc0b548b438e5469e10460b6c300 + for s := 0; s < 8; s++ { + t2.Square(t2) + } + + // Step 288: t1 = x^0x183227397098d014dc2822db40c0ac2ecbc0b548b438e5469e10460b6c3e7 + t1.Mul(t1, t2) + + // Step 295: t1 = x^0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f082305b61f380 + for s := 0; s < 7; s++ { + t1.Square(t1) + } + + // Step 296: t0 = x^0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f082305b61f3f5 + t0.Mul(t0, t1) + + // Step 301: t0 = x^0x183227397098d014dc2822db40c0ac2ecbc0b548b438e5469e10460b6c3e7ea0 + for s := 0; s < 5; s++ { + t0.Square(t0) + } + + // Step 302: z = x^0x183227397098d014dc2822db40c0ac2ecbc0b548b438e5469e10460b6c3e7ea3 + z.Mul(z, t0) + + return z +} diff --git a/ecc/grumpkin/fr/element_purego.go b/ecc/grumpkin/fr/element_purego.go new file mode 100644 index 0000000000..e4b9c87ac4 --- /dev/null +++ b/ecc/grumpkin/fr/element_purego.go @@ -0,0 +1,391 @@ +//go:build purego || (!amd64 && !arm64) + +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 529957932336199972, + 13952065197595570812, + 769406925088786211, + 2691790815622165739, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 
= bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = 
bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + 
} + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + 
u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/ecc/grumpkin/fr/element_test.go b/ecc/grumpkin/fr/element_test.go new file mode 100644 index 0000000000..9e823a5319 --- /dev/null +++ b/ecc/grumpkin/fr/element_test.go @@ -0,0 +1,2885 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import ( + "crypto/rand" + "encoding/json" + "fmt" + "math/big" + "math/bits" + + mrand "math/rand" + + "testing" + + "github.com/leanovate/gopter" + ggen "github.com/leanovate/gopter/gen" + "github.com/leanovate/gopter/prop" + + "github.com/stretchr/testify/require" +) + +// ------------------------------------------------------------------------------------------------- +// benchmarks +// most benchmarks are rudimentary and should sample a large number of random inputs +// or be run multiple times to ensure it didn't measure the fastest path of the function + +var benchResElement Element + +func BenchmarkElementSelect(b *testing.B) { + var x, y Element + x.SetRandom() + y.SetRandom() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Select(i%3, &x, &y) + } +} + +func BenchmarkElementSetRandom(b *testing.B) { + var x Element + x.SetRandom() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = x.SetRandom() + } +} + +func BenchmarkElementSetBytes(b *testing.B) { + var x Element + x.SetRandom() + bb := x.Bytes() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchResElement.SetBytes(bb[:]) + } + +} + +func BenchmarkElementMulByConstants(b *testing.B) { + b.Run("mulBy3", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy3(&benchResElement) + } + }) + b.Run("mulBy5", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy5(&benchResElement) + } + }) + b.Run("mulBy13", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy13(&benchResElement) + } + }) +} + +func BenchmarkElementInverse(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchResElement.Inverse(&x) + } + +} + +func BenchmarkElementButterfly(b *testing.B) { + var x Element + 
x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + Butterfly(&x, &benchResElement) + } +} + +func BenchmarkElementExp(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b1, _ := rand.Int(rand.Reader, Modulus()) + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Exp(x, b1) + } +} + +func BenchmarkElementDouble(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Double(&benchResElement) + } +} + +func BenchmarkElementAdd(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Add(&x, &benchResElement) + } +} + +func BenchmarkElementSub(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Sub(&x, &benchResElement) + } +} + +func BenchmarkElementNeg(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Neg(&benchResElement) + } +} + +func BenchmarkElementDiv(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Div(&x, &benchResElement) + } +} + +func BenchmarkElementFromMont(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.fromMont() + } +} + +func BenchmarkElementSquare(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Square(&benchResElement) + } +} + +func BenchmarkElementSqrt(b *testing.B) { + var a Element + a.SetUint64(4) + a.Neg(&a) + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Sqrt(&a) + } +} + +func BenchmarkElementMul(b *testing.B) { + x := Element{ + 17522657719365597833, + 13107472804851548667, + 5164255478447964150, + 493319470278259999, + } + benchResElement.SetOne() 
+ b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Mul(&benchResElement, &x) + } +} + +func BenchmarkElementCmp(b *testing.B) { + x := Element{ + 17522657719365597833, + 13107472804851548667, + 5164255478447964150, + 493319470278259999, + } + benchResElement = x + benchResElement[0] = 0 + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Cmp(&x) + } +} + +func TestElementCmp(t *testing.T) { + var x, y Element + + if x.Cmp(&y) != 0 { + t.Fatal("x == y") + } + + one := One() + y.Sub(&y, &one) + + if x.Cmp(&y) != -1 { + t.Fatal("x < y") + } + if y.Cmp(&x) != 1 { + t.Fatal("x < y") + } + + x = y + if x.Cmp(&y) != 0 { + t.Fatal("x == y") + } + + x.Sub(&x, &one) + if x.Cmp(&y) != -1 { + t.Fatal("x < y") + } + if y.Cmp(&x) != 1 { + t.Fatal("x < y") + } +} +func TestElementIsRandom(t *testing.T) { + for i := 0; i < 50; i++ { + var x, y Element + x.SetRandom() + y.SetRandom() + if x.Equal(&y) { + t.Fatal("2 random numbers are unlikely to be equal") + } + } +} + +func TestElementIsUint64(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + properties.Property("reduce should output a result smaller than modulus", prop.ForAll( + func(v uint64) bool { + var e Element + e.SetUint64(v) + + if !e.IsUint64() { + return false + } + + return e.Uint64() == v + }, + ggen.UInt64(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementNegZero(t *testing.T) { + var a, b Element + b.SetZero() + for a.IsZero() { + a.SetRandom() + } + a.Neg(&b) + if !a.IsZero() { + t.Fatal("neg(0) != 0") + } +} + +// ------------------------------------------------------------------------------------------------- +// Gopter tests +// most of them are generated with a template + +const ( + nbFuzzShort = 200 + nbFuzz = 1000 +) + +// special values to 
be used in tests +var staticTestValues []Element + +func init() { + staticTestValues = append(staticTestValues, Element{}) // zero + staticTestValues = append(staticTestValues, One()) // one + staticTestValues = append(staticTestValues, rSquare) // r² + var e, one Element + one.SetOne() + e.Sub(&qElement, &one) + staticTestValues = append(staticTestValues, e) // q - 1 + e.Double(&one) + staticTestValues = append(staticTestValues, e) // 2 + + { + a := qElement + a[0]-- + staticTestValues = append(staticTestValues, a) + } + staticTestValues = append(staticTestValues, Element{0}) + staticTestValues = append(staticTestValues, Element{0, 0}) + staticTestValues = append(staticTestValues, Element{1}) + staticTestValues = append(staticTestValues, Element{0, 1}) + staticTestValues = append(staticTestValues, Element{2}) + staticTestValues = append(staticTestValues, Element{0, 2}) + + { + a := qElement + a[3]-- + staticTestValues = append(staticTestValues, a) + } + { + a := qElement + a[3]-- + a[0]++ + staticTestValues = append(staticTestValues, a) + } + + { + a := qElement + a[3] = 0 + staticTestValues = append(staticTestValues, a) + } + +} + +func TestElementReduce(t *testing.T) { + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + s := testValues[i] + expected := s + reduce(&s) + _reduceGeneric(&expected) + if !s.Equal(&expected) { + t.Fatal("reduce failed: asm and generic impl don't match") + } + } + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := genFull() + + properties.Property("reduce should output a result smaller than modulus", prop.ForAll( + func(a Element) bool { + b := a + reduce(&a) + _reduceGeneric(&b) + return a.smallerThanModulus() && a.Equal(&b) + }, + genA, + )) + + 
properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementEqual(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("x.Equal(&y) iff x == y; likely false for random pairs", prop.ForAll( + func(a testPairElement, b testPairElement) bool { + return a.element.Equal(&b.element) == (a.element == b.element) + }, + genA, + genB, + )) + + properties.Property("x.Equal(&y) if x == y", prop.ForAll( + func(a testPairElement) bool { + b := a.element + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementBytes(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("SetBytes(Bytes()) should stay constant", prop.ForAll( + func(a testPairElement) bool { + var b Element + bytes := a.element.Bytes() + b.SetBytes(bytes[:]) + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementInverseExp(t *testing.T) { + // inverse must be equal to exp^-2 + exp := Modulus() + exp.Sub(exp, new(big.Int).SetUint64(2)) + + invMatchExp := func(a testPairElement) bool { + var b Element + b.Set(&a.element) + a.element.Inverse(&a.element) + b.Exp(b, exp) + + return a.element.Equal(&b) + } + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + properties := gopter.NewProperties(parameters) + genA := gen() + 
properties.Property("inv == exp^-2", prop.ForAll(invMatchExp, genA)) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + + parameters.MinSuccessfulTests = 1 + properties = gopter.NewProperties(parameters) + properties.Property("inv(0) == 0", prop.ForAll(invMatchExp, ggen.OneConstOf(testPairElement{}))) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func mulByConstant(z *Element, c uint8) { + var y Element + y.SetUint64(uint64(c)) + z.Mul(z, &y) +} + +func TestElementMulByConstants(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + implemented := []uint8{0, 1, 2, 3, 5, 13} + properties.Property("mulByConstant", prop.ForAll( + func(a testPairElement) bool { + for _, c := range implemented { + var constant Element + constant.SetUint64(uint64(c)) + + b := a.element + b.Mul(&b, &constant) + + aa := a.element + mulByConstant(&aa, c) + + if !aa.Equal(&b) { + return false + } + } + + return true + }, + genA, + )) + + properties.Property("MulBy3(x) == Mul(x, 3)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(3) + + b := a.element + b.Mul(&b, &constant) + + MulBy3(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("MulBy5(x) == Mul(x, 5)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(5) + + b := a.element + b.Mul(&b, &constant) + + MulBy5(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("MulBy13(x) == Mul(x, 13)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(13) + + b := a.element + b.Mul(&b, &constant) + + MulBy13(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, 
gopter.ConsoleReporter(false)) + +} + +func TestElementLegendre(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("legendre should output same result than big.Int.Jacobi", prop.ForAll( + func(a testPairElement) bool { + return a.element.Legendre() == big.Jacobi(&a.bigint, Modulus()) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementBitLen(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( + func(a testPairElement) bool { + return a.element.fromMont().BitLen() == a.bigint.BitLen() + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementButterflies(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("butterfly0 == a -b; a +b", prop.ForAll( + func(a, b testPairElement) bool { + a0, b0 := a.element, b.element + + _butterflyGeneric(&a.element, &b.element) + Butterfly(&a0, &b0) + + return a.element.Equal(&a0) && b.element.Equal(&b0) + }, + genA, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementLexicographicallyLargest(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + 
parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("element.Cmp should match LexicographicallyLargest output", prop.ForAll( + func(a testPairElement) bool { + var negA Element + negA.Neg(&a.element) + + cmpResult := a.element.Cmp(&negA) + lResult := a.element.LexicographicallyLargest() + + if lResult && cmpResult == 1 { + return true + } + if !lResult && cmpResult != 1 { + return true + } + return false + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementAdd(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Add: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Add(&a.element, &b.element) + a.element.Add(&a.element, &b.element) + b.element.Add(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Add: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Add(&a.element, &b.element) + + var d, e big.Int + d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Add(&a.element, &r) + d.Add(&a.bigint, &rb).Mod(&d, Modulus()) + + if 
c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Add: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Add(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Add(&a, &b) + d.Add(&aBig, &bBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Add failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSub(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Sub: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Sub(&a.element, &b.element) + a.element.Sub(&a.element, &b.element) + b.element.Sub(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Sub: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Sub(&a.element, &b.element) + + var d, e big.Int + d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // 
r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Sub(&a.element, &r) + d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Sub: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Sub(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Sub(&a, &b) + d.Sub(&aBig, &bBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Sub failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementMul(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Mul: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Mul(&a.element, &b.element) + a.element.Mul(&a.element, &b.element) + b.element.Mul(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Mul: operation result 
must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Mul(&a.element, &b.element) + + var d, e big.Int + d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Mul(&a.element, &r) + d.Mul(&a.bigint, &rb).Mod(&d, Modulus()) + + // checking generic impl against asm path + var cGeneric Element + _mulGeneric(&cGeneric, &a.element, &r) + if !cGeneric.Equal(&c) { + // need to give context to failing error. + return false + } + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Mul: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Mul(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + properties.Property("Mul: assembly implementation must be consistent with generic one", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + c.Mul(&a.element, &b.element) + _mulGeneric(&d, &a.element, &b.element) + return c.Equal(&d) + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Mul(&a, &b) + d.Mul(&aBig, &bBig).Mod(&d, Modulus()) + + // checking asm against generic impl + var cGeneric Element + _mulGeneric(&cGeneric, &a, &b) + if !cGeneric.Equal(&c) { + t.Fatal("Mul failed special 
test values: asm and generic impl don't match") + } + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Mul failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementDiv(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Div: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Div(&a.element, &b.element) + a.element.Div(&a.element, &b.element) + b.element.Div(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Div: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Div(&a.element, &b.element) + + var d, e big.Int + d.ModInverse(&b.bigint, Modulus()) + d.Mul(&d, &a.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Div(&a.element, &r) + d.ModInverse(&rb, Modulus()) + d.Mul(&d, &a.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Div: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Div(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest 
:= func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Div(&a, &b) + d.ModInverse(&bBig, Modulus()) + d.Mul(&d, &aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Div failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementExp(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Exp: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Exp(a.element, &b.bigint) + a.element.Exp(a.element, &b.bigint) + b.element.Exp(d, &b.bigint) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Exp: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Exp(a.element, &b.bigint) + + var d, e big.Int + d.Exp(&a.bigint, &b.bigint, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Exp(a.element, &rb) + d.Exp(&a.bigint, &rb, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false 
+ } + } + return true + }, + genA, + genB, + )) + + properties.Property("Exp: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Exp(a.element, &b.bigint) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Exp(a, &bBig) + d.Exp(&aBig, &bBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Exp failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSquare(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Square: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Square(&a.element) + a.element.Square(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Square: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Square(&a.element) + + var d, e big.Int + d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Square: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Square(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test 
special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Square(&a) + + var d, e big.Int + d.Mul(&aBig, &aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Square failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementInverse(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Inverse: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Inverse(&a.element) + a.element.Inverse(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Inverse: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Inverse(&a.element) + + var d, e big.Int + d.ModInverse(&a.bigint, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Inverse: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Inverse(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Inverse(&a) + + var d, e big.Int + d.ModInverse(&aBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Inverse failed special test values") + } + } + } + + properties.TestingRun(t, 
gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSqrt(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Sqrt: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + b := a.element + + b.Sqrt(&a.element) + a.element.Sqrt(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Sqrt: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Sqrt(&a.element) + + var d, e big.Int + d.ModSqrt(&a.bigint, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Sqrt: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Sqrt(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Sqrt(&a) + + var d, e big.Int + d.ModSqrt(&aBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Sqrt failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementDouble(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Double: having the receiver as operand should output the same result", prop.ForAll( + 
func(a testPairElement) bool { + + var b Element + + b.Double(&a.element) + a.element.Double(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Double: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + + var d, e big.Int + d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Double: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Double(&a) + + var d, e big.Int + d.Lsh(&aBig, 1).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Double failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementNeg(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Neg: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Neg(&a.element) + a.element.Neg(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Neg: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Neg(&a.element) + + var d, e big.Int + d.Neg(&a.bigint).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + 
properties.Property("Neg: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Neg(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Neg(&a) + + var d, e big.Int + d.Neg(&aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Neg failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementFixedExp(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + var ( + _bLegendreExponentElement *big.Int + _bSqrtExponentElement *big.Int + ) + + _bLegendreExponentElement, _ = new(big.Int).SetString("183227397098d014dc2822db40c0ac2ecbc0b548b438e5469e10460b6c3e7ea3", 16) + const sqrtExponentElement = "c19139cb84c680a6e14116da060561765e05aa45a1c72a34f082305b61f3f52" + _bSqrtExponentElement, _ = new(big.Int).SetString(sqrtExponentElement, 16) + + genA := gen() + + properties.Property(fmt.Sprintf("expBySqrtExp must match Exp(%s)", sqrtExponentElement), prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.expBySqrtExp(c) + d.Exp(d, _bSqrtExponentElement) + return c.Equal(&d) + }, + genA, + )) + + properties.Property("expByLegendreExp must match Exp(183227397098d014dc2822db40c0ac2ecbc0b548b438e5469e10460b6c3e7ea3)", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.expByLegendreExp(c) + d.Exp(d, _bLegendreExponentElement) + return c.Equal(&d) + }, + genA, + )) + + properties.TestingRun(t, 
gopter.ConsoleReporter(false)) +} + +func TestElementHalve(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + var twoInv Element + twoInv.SetUint64(2) + twoInv.Inverse(&twoInv) + + properties.Property("z.Halve must match z / 2", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.Halve() + d.Mul(&d, &twoInv) + return c.Equal(&d) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func combineSelectionArguments(c int64, z int8) int { + if z%3 == 0 { + return 0 + } + return int(c) +} + +func TestElementSelect(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := genFull() + genB := genFull() + genC := ggen.Int64() //the condition + genZ := ggen.Int8() //to make zeros artificially more likely + + properties.Property("Select: must select correctly", prop.ForAll( + func(a, b Element, cond int64, z int8) bool { + condC := combineSelectionArguments(cond, z) + + var c Element + c.Select(condC, &a, &b) + + if condC == 0 { + return c.Equal(&a) + } + return c.Equal(&b) + }, + genA, + genB, + genC, + genZ, + )) + + properties.Property("Select: having the receiver as operand should output the same result", prop.ForAll( + func(a, b Element, cond int64, z int8) bool { + condC := combineSelectionArguments(cond, z) + + var c, d Element + d.Set(&a) + c.Select(condC, &a, &b) + a.Select(condC, &a, &b) + b.Select(condC, &d, &b) + return a.Equal(&b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + genC, + genZ, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func 
TestElementSetInt64(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("z.SetInt64 must match z.SetString", prop.ForAll( + func(a testPairElement, v int64) bool { + c := a.element + d := a.element + + c.SetInt64(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, ggen.Int64(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementSetInterface(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genInt := ggen.Int + genInt8 := ggen.Int8 + genInt16 := ggen.Int16 + genInt32 := ggen.Int32 + genInt64 := ggen.Int64 + + genUint := ggen.UInt + genUint8 := ggen.UInt8 + genUint16 := ggen.UInt16 + genUint32 := ggen.UInt32 + genUint64 := ggen.UInt64 + + properties.Property("z.SetInterface must match z.SetString with int8", prop.ForAll( + func(a testPairElement, v int8) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt8(), + )) + + properties.Property("z.SetInterface must match z.SetString with int16", prop.ForAll( + func(a testPairElement, v int16) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt16(), + )) + + properties.Property("z.SetInterface must match z.SetString with int32", prop.ForAll( + func(a testPairElement, v int32) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt32(), + )) + 
+ properties.Property("z.SetInterface must match z.SetString with int64", prop.ForAll( + func(a testPairElement, v int64) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt64(), + )) + + properties.Property("z.SetInterface must match z.SetString with int", prop.ForAll( + func(a testPairElement, v int) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint8", prop.ForAll( + func(a testPairElement, v uint8) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint8(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint16", prop.ForAll( + func(a testPairElement, v uint16) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint16(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint32", prop.ForAll( + func(a testPairElement, v uint32) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint32(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint64", prop.ForAll( + func(a testPairElement, v uint64) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint64(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint", prop.ForAll( + func(a testPairElement, v uint) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + + { + 
assert := require.New(t) + var e Element + r, err := e.SetInterface(nil) + assert.Nil(r) + assert.Error(err) + + var ptE *Element + var ptB *big.Int + + r, err = e.SetInterface(ptE) + assert.Nil(r) + assert.Error(err) + ptE = new(Element).SetOne() + r, err = e.SetInterface(ptE) + assert.NoError(err) + assert.True(r.IsOne()) + + r, err = e.SetInterface(ptB) + assert.Nil(r) + assert.Error(err) + + } +} + +func TestElementNegativeExp(t *testing.T) { + t.Parallel() + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("x⁻ᵏ == 1/xᵏ", prop.ForAll( + func(a, b testPairElement) bool { + + var nb, d, e big.Int + nb.Neg(&b.bigint) + + var c Element + c.Exp(a.element, &nb) + + d.Exp(&a.bigint, &nb, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementNewElement(t *testing.T) { + assert := require.New(t) + + t.Parallel() + + e := NewElement(1) + assert.True(e.IsOne()) + + e = NewElement(0) + assert.True(e.IsZero()) +} + +func TestElementBatchInvert(t *testing.T) { + assert := require.New(t) + + t.Parallel() + + // ensure batchInvert([x]) == invert(x) + for i := int64(-1); i <= 2; i++ { + var e, eInv Element + e.SetInt64(i) + eInv.Inverse(&e) + + a := []Element{e} + aInv := BatchInvert(a) + + assert.True(aInv[0].Equal(&eInv), "batchInvert != invert") + + } + + // test x * x⁻¹ == 1 + tData := [][]int64{ + {-1, 1, 2, 3}, + {0, -1, 1, 2, 3, 0}, + {0, -1, 1, 0, 2, 3, 0}, + {-1, 1, 0, 2, 3}, + {0, 0, 1}, + {1, 0, 0}, + {0, 0, 0}, + } + + for _, t := range tData { + a := make([]Element, len(t)) + for i := 0; i < len(a); i++ { + a[i].SetInt64(t[i]) + } + + aInv := BatchInvert(a) + + assert.True(len(aInv) == len(a)) + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + 
assert.True(aInv[i].IsZero(), "0⁻¹ != 0") + } else { + assert.True(a[i].Mul(&a[i], &aInv[i]).IsOne(), "x * x⁻¹ != 1") + } + } + } + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("batchInvert --> x * x⁻¹ == 1", prop.ForAll( + func(tp testPairElement, r uint8) bool { + + a := make([]Element, r) + if r != 0 { + a[0] = tp.element + + } + one := One() + for i := 1; i < len(a); i++ { + a[i].Add(&a[i-1], &one) + } + + aInv := BatchInvert(a) + + assert.True(len(aInv) == len(a)) + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + if !aInv[i].IsZero() { + return false + } + } else { + if !a[i].Mul(&a[i], &aInv[i]).IsOne() { + return false + } + } + } + return true + }, + genA, ggen.UInt8(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementFromMont(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Assembly implementation must be consistent with generic one", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.fromMont() + _fromMontGeneric(&d) + return c.Equal(&d) + }, + genA, + )) + + properties.Property("x.fromMont().toMont() == x", prop.ForAll( + func(a testPairElement) bool { + c := a.element + c.fromMont().toMont() + return c.Equal(&a.element) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementJSON(t *testing.T) { + assert := require.New(t) + + type S struct { + A Element + B [3]Element + C *Element + D *Element + } + + // encode to JSON + var s S + s.A.SetString("-1") + s.B[2].SetUint64(42) 
+ s.D = new(Element).SetUint64(8000) + + encoded, err := json.Marshal(&s) + assert.NoError(err) + // we may need to adjust "42" and "8000" values for some moduli; see Text() method for more details. + formatValue := func(v int64) string { + var a big.Int + a.SetInt64(v) + a.Mod(&a, Modulus()) + const maxUint16 = 65535 + var aNeg big.Int + aNeg.Neg(&a).Mod(&aNeg, Modulus()) + if aNeg.Uint64() != 0 && aNeg.Uint64() <= maxUint16 { + return "-" + aNeg.Text(10) + } + return a.Text(10) + } + expected := fmt.Sprintf("{\"A\":%s,\"B\":[0,0,%s],\"C\":null,\"D\":%s}", formatValue(-1), formatValue(42), formatValue(8000)) + assert.Equal(expected, string(encoded)) + + // decode valid + var decoded S + err = json.Unmarshal([]byte(expected), &decoded) + assert.NoError(err) + + assert.Equal(s, decoded, "element -> json -> element round trip failed") + + // decode hex and string values + withHexValues := "{\"A\":\"-1\",\"B\":[0,\"0x00000\",\"0x2A\"],\"C\":null,\"D\":\"8000\"}" + + var decodedS S + err = json.Unmarshal([]byte(withHexValues), &decodedS) + assert.NoError(err) + + assert.Equal(s, decodedS, " json with strings -> element failed") + +} + +type testPairElement struct { + element Element + bigint big.Int +} + +func gen() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var g testPairElement + + g.element = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g.element[3] %= (qElement[3] + 1) + } + + for !g.element.smallerThanModulus() { + g.element = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g.element[3] %= (qElement[3] + 1) + } + } + + g.element.BigInt(&g.bigint) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element + + g = 
Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + } + + return g +} + +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + + var carry uint64 + a[0], carry = bits.Add64(a[0], qElement[0], carry) + a[1], carry = bits.Add64(a[1], qElement[1], carry) + a[2], carry = bits.Add64(a[2], qElement[2], carry) + a[3], _ = bits.Add64(a[3], qElement[3], carry) + + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + +func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { + var modulus big.Int + var aIntMod big.Int + modulus.SetInt64(1) + modulus.Lsh(&modulus, (Limbs+1)*64) + aIntMod.Mod(aInt, &modulus) + + slice := append(z[:], aHi) + + return bigIntMatchUint64Slice(&aIntMod, slice) +} + +// TODO: Phase out in favor of property based testing +func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { + + if err := z.matchVeryBigInt(aHi, aInt); err != nil { + t.Error(err) + } +} + +// bigIntMatchUint64Slice is a test helper to match big.Int words against a uint64 slice +func bigIntMatchUint64Slice(aInt *big.Int, a []uint64) error { + + words := aInt.Bits() + + const steps = 64 / bits.UintSize + const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) + for i := 0; i < len(a)*steps; i++ { + + var wI big.Word + + if i < len(words) { + wI = words[i] + } + + aI := 
a[i/steps] >> ((i * bits.UintSize) % 64) + aI &= filter + + if uint64(wI) != aI { + return fmt.Errorf("bignum mismatch: disagreement on word %d: %x ≠ %x; %d ≠ %d", i, uint64(wI), aI, uint64(wI), aI) + } + } + + return nil +} + +func TestElementInversionApproximation(t *testing.T) { + var x Element + for i := 0; i < 1000; i++ { + x.SetRandom() + + // Normally small elements are unlikely. Here we give them a higher chance + xZeros := mrand.Int() % Limbs //#nosec G404 weak rng is fine here + for j := 1; j < xZeros; j++ { + x[Limbs-j] = 0 + } + + a := approximate(&x, x.BitLen()) + aRef := approximateRef(&x) + + if a != aRef { + t.Error("Approximation mismatch") + } + } +} + +func TestElementInversionCorrectionFactorFormula(t *testing.T) { + const kLimbs = k * Limbs + const power = kLimbs*6 + invIterationsN*(kLimbs-k+1) + factorInt := big.NewInt(1) + factorInt.Lsh(factorInt, power) + factorInt.Mod(factorInt, Modulus()) + + var refFactorInt big.Int + inversionCorrectionFactor := Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + } + inversionCorrectionFactor.toBigInt(&refFactorInt) + + if refFactorInt.Cmp(factorInt) != 0 { + t.Error("mismatch") + } +} + +func TestElementLinearComb(t *testing.T) { + var x Element + var y Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + y.SetRandom() + testLinearComb(t, &x, mrand.Int63(), &y, mrand.Int63()) //#nosec G404 weak rng is fine here + } +} + +// Probably unnecessary post-dev. In case the output of inv is wrong, this checks whether it's only off by a constant factor. 
+func TestElementInversionCorrectionFactor(t *testing.T) { + + // (1/x)/inv(x) = (1/1)/inv(1) ⇔ inv(1) = x inv(x) + + var one Element + var oneInv Element + one.SetOne() + oneInv.Inverse(&one) + + for i := 0; i < 100; i++ { + var x Element + var xInv Element + x.SetRandom() + xInv.Inverse(&x) + + x.Mul(&x, &xInv) + if !x.Equal(&oneInv) { + t.Error("Correction factor is inconsistent") + } + } + + if !oneInv.Equal(&one) { + var i big.Int + oneInv.BigInt(&i) // no montgomery + i.ModInverse(&i, Modulus()) + var fac Element + fac.setBigInt(&i) // back to montgomery + + var facTimesFac Element + facTimesFac.Mul(&fac, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + }) + + t.Error("Correction factor is consistently off by", fac, "Should be", facTimesFac) + } +} + +func TestElementBigNumNeg(t *testing.T) { + var a Element + aHi := negL(&a, 0) + if !a.IsZero() || aHi != 0 { + t.Error("-0 != 0") + } +} + +func TestElementBigNumWMul(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + w := mrand.Int63() //#nosec G404 weak rng is fine here + testBigNumWMul(t, &x, w) + } +} + +func TestElementVeryBigIntConversion(t *testing.T) { + xHi := mrand.Uint64() //#nosec G404 weak rng is fine here + var x Element + x.SetRandom() + var xInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + x.assertMatchVeryBigInt(t, xHi, &xInt) +} + +type veryBigInt struct { + asInt big.Int + low Element + hi uint64 +} + +// genVeryBigIntSigned if sign == 0, no sign is forced +func genVeryBigIntSigned(sign int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var g veryBigInt + + g.low = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + + g.hi = genParams.NextUint64() + + if sign < 0 { + g.hi |= signBitSelector + } else if sign > 0 { + g.hi &= ^signBitSelector + } + + 
g.low.toVeryBigIntSigned(&g.asInt, g.hi) + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func TestElementMontReduce(t *testing.T) { + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + gen := genVeryBigIntSigned(0) + + properties.Property("Montgomery reduction is correct", prop.ForAll( + func(g veryBigInt) bool { + var res Element + var resInt big.Int + + montReduce(&resInt, &g.asInt) + res.montReduceSigned(&g.low, g.hi) + + return res.matchVeryBigInt(0, &resInt) == nil + }, + gen, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementMontReduceMultipleOfR(t *testing.T) { + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + gen := ggen.UInt64() + + properties.Property("Montgomery reduction is correct", prop.ForAll( + func(hi uint64) bool { + var zero, res Element + var asInt, resInt big.Int + + zero.toVeryBigIntSigned(&asInt, hi) + + montReduce(&resInt, &asInt) + res.montReduceSigned(&zero, hi) + + return res.matchVeryBigInt(0, &resInt) == nil + }, + gen, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElement0Inverse(t *testing.T) { + var x Element + x.Inverse(&x) + if !x.IsZero() { + t.Fail() + } +} + +// TODO: Tests like this (update factor related) are common to all fields. 
Move them to somewhere non-autogen +func TestUpdateFactorSubtraction(t *testing.T) { + for i := 0; i < 1000; i++ { + + f0, g0 := randomizeUpdateFactors() + f1, g1 := randomizeUpdateFactors() + + for f0-f1 > 1<<31 || f0-f1 <= -1<<31 { + f1 /= 2 + } + + for g0-g1 > 1<<31 || g0-g1 <= -1<<31 { + g1 /= 2 + } + + c0 := updateFactorsCompose(f0, g0) + c1 := updateFactorsCompose(f1, g1) + + cRes := c0 - c1 + fRes, gRes := updateFactorsDecompose(cRes) + + if fRes != f0-f1 || gRes != g0-g1 { + t.Error(i) + } + } +} + +func TestUpdateFactorsDouble(t *testing.T) { + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f > 1<<30 || f < (-1<<31+1)/2 { + f /= 2 + if g <= 1<<29 && g >= (-1<<31+1)/4 { + g *= 2 //g was kept small on f's account. Now that we're halving f, we can double g + } + } + + if g > 1<<30 || g < (-1<<31+1)/2 { + g /= 2 + + if f <= 1<<29 && f >= (-1<<31+1)/4 { + f *= 2 //f was kept small on g's account. Now that we're halving g, we can double f + } + } + + c := updateFactorsCompose(f, g) + cD := c * 2 + fD, gD := updateFactorsDecompose(cD) + + if fD != 2*f || gD != 2*g { + t.Error(i) + } + } +} + +func TestUpdateFactorsNeg(t *testing.T) { + var fMistake bool + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f == 0x80000000 || g == 0x80000000 { + // Update factors this large can only have been obtained after 31 iterations and will therefore never be negated + // We don't have capacity to store -2³¹ + // Repeat this iteration + i-- + continue + } + + c := updateFactorsCompose(f, g) + nc := -c + nf, ng := updateFactorsDecompose(nc) + fMistake = fMistake || nf != -f + if nf != -f || ng != -g { + t.Errorf("Mismatch iteration #%d:\n%d, %d ->\n %d -> %d ->\n %d, %d\n Inputs in hex: %X, %X", + i, f, g, c, nc, nf, ng, f, g) + } + } + if fMistake { + t.Error("Mistake with f detected") + } else { + t.Log("All good with f") + } +} + +func TestUpdateFactorsNeg0(t *testing.T) { + c := updateFactorsCompose(0, 0) + t.Logf("c(0,0) = %X", c) 
+ cn := -c + + if c != cn { + t.Error("Negation of zero update factors should yield the same result.") + } +} + +func TestUpdateFactorDecomposition(t *testing.T) { + var negSeen bool + + for i := 0; i < 1000; i++ { + + f, g := randomizeUpdateFactors() + + if f <= -(1<<31) || f > 1<<31 { + t.Fatal("f out of range") + } + + negSeen = negSeen || f < 0 + + c := updateFactorsCompose(f, g) + + fBack, gBack := updateFactorsDecompose(c) + + if f != fBack || g != gBack { + t.Errorf("(%d, %d) -> %d -> (%d, %d)\n", f, g, c, fBack, gBack) + } + } + + if !negSeen { + t.Fatal("No negative f factors") + } +} + +func TestUpdateFactorInitialValues(t *testing.T) { + + f0, g0 := updateFactorsDecompose(updateFactorIdentityMatrixRow0) + f1, g1 := updateFactorsDecompose(updateFactorIdentityMatrixRow1) + + if f0 != 1 || g0 != 0 || f1 != 0 || g1 != 1 { + t.Error("Update factor initial value constants are incorrect") + } +} + +func TestUpdateFactorsRandomization(t *testing.T) { + var maxLen int + + //t.Log("|f| + |g| is not to exceed", 1 << 31) + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + lf, lg := abs64T32(f), abs64T32(g) + absSum := lf + lg + if absSum >= 1<<31 { + + if absSum == 1<<31 { + maxLen++ + } else { + t.Error(i, "Sum of absolute values too large, f =", f, ",g =", g, ",|f| + |g| =", absSum) + } + } + } + + if maxLen == 0 { + t.Error("max len not observed") + } else { + t.Log(maxLen, "maxLens observed") + } +} + +func randomizeUpdateFactor(absLimit uint32) int64 { + const maxSizeLikelihood = 10 + maxSize := mrand.Intn(maxSizeLikelihood) //#nosec G404 weak rng is fine here + + absLimit64 := int64(absLimit) + var f int64 + switch maxSize { + case 0: + f = absLimit64 + case 1: + f = -absLimit64 + default: + f = int64(mrand.Uint64()%(2*uint64(absLimit64)+1)) - absLimit64 //#nosec G404 weak rng is fine here + } + + if f > 1<<31 { + return 1 << 31 + } else if f < -1<<31+1 { + return -1<<31 + 1 + } + + return f +} + +func abs64T32(f int64) uint32 { + if f >= 1<<32 
|| f < -1<<32 { + panic("f out of range") + } + + if f < 0 { + return uint32(-f) + } + return uint32(f) +} + +func randomizeUpdateFactors() (int64, int64) { + var f [2]int64 + b := mrand.Int() % 2 //#nosec G404 weak rng is fine here + + f[b] = randomizeUpdateFactor(1 << 31) + + //As per the paper, |f| + |g| \le 2³¹. + f[1-b] = randomizeUpdateFactor(1<<31 - abs64T32(f[b])) + + //Patching another edge case + if f[0]+f[1] == -1<<31 { + b = mrand.Int() % 2 //#nosec G404 weak rng is fine here + f[b]++ + } + + return f[0], f[1] +} + +func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { + + var p1 big.Int + x.toBigInt(&p1) + p1.Mul(&p1, big.NewInt(xC)) + + var p2 big.Int + y.toBigInt(&p2) + p2.Mul(&p2, big.NewInt(yC)) + + p1.Add(&p1, &p2) + p1.Mod(&p1, Modulus()) + montReduce(&p1, &p1) + + var z Element + z.linearComb(x, xC, y, yC) + z.assertMatchVeryBigInt(t, 0, &p1) +} + +func testBigNumWMul(t *testing.T, a *Element, c int64) { + var aHi uint64 + var aTimes Element + aHi = aTimes.mulWNonModular(a, c) + + assertMulProduct(t, a, c, &aTimes, aHi) +} + +func updateFactorsCompose(f int64, g int64) int64 { + return f + g<<32 +} + +var rInv big.Int + +func montReduce(res *big.Int, x *big.Int) { + if rInv.BitLen() == 0 { // initialization + rInv.SetUint64(1) + rInv.Lsh(&rInv, Limbs*64) + rInv.ModInverse(&rInv, Modulus()) + } + res.Mul(x, &rInv) + res.Mod(res, Modulus()) +} + +func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { + z.toBigInt(i) + var upperWord big.Int + upperWord.SetUint64(xHi) + upperWord.Lsh(&upperWord, Limbs*64) + i.Add(&upperWord, i) +} + +func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { + z.toVeryBigIntUnsigned(i, xHi) + if signBitSelector&xHi != 0 { + twosCompModulus := big.NewInt(1) + twosCompModulus.Lsh(twosCompModulus, (Limbs+1)*64) + i.Sub(i, twosCompModulus) + } +} + +func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { + var xInt big.Int + 
x.toBigInt(&xInt) + + xInt.Mul(&xInt, big.NewInt(c)) + + result.assertMatchVeryBigInt(t, resultHi, &xInt) + return xInt +} + +func approximateRef(x *Element) uint64 { + + var asInt big.Int + x.toBigInt(&asInt) + n := x.BitLen() + + if n <= 64 { + return asInt.Uint64() + } + + modulus := big.NewInt(1 << 31) + var lo big.Int + lo.Mod(&asInt, modulus) + + modulus.Lsh(modulus, uint(n-64)) + var hi big.Int + hi.Div(&asInt, modulus) + hi.Lsh(&hi, 31) + + hi.Add(&hi, &lo) + return hi.Uint64() +} diff --git a/ecc/grumpkin/fr/gkr/gkr.go b/ecc/grumpkin/fr/gkr/gkr.go new file mode 100644 index 0000000000..f4d1880a1a --- /dev/null +++ b/ecc/grumpkin/fr/gkr/gkr.go @@ -0,0 +1,934 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/sumcheck" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/internal/parallel" + "github.com/consensys/gnark-crypto/utils" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// Gate must be a low-degree polynomial +type Gate interface { + Evaluate(...fr.Element) fr.Element + Degree() int +} + +type Wire struct { + Gate Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. 
providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]fr.Element + claimedEvaluations []fr.Element + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]fr.Element) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + 
evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]fr.Element // x in the paper + claimedEvaluations []fr.Element // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? 
+ // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + // newEq.Eq(c.evaluationPoints[k]) + // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := 
start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() + + // e.Add(e, polynomial.Polynomial(m)) +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + nbGateIn := len(c.inputPreprocessors) + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + s := make([]polynomial.MultiLin, nbGateIn+1) + s[0] = c.eq + copy(s[1:], c.inputPreprocessors) + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + nbInner := len(s) // wrt output, which has high nbOuter and low nbInner + nbOuter := len(s[0]) / 2 + + gJ := make([]fr.Element, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { + var step fr.Element + + res := make([]fr.Element, degGJ) + operands := make([]fr.Element, degGJ*nbInner) + + for i := start; i < end; i++ { + + block := nbOuter + i + for j := 0; j < nbInner; j++ { + step.Set(&s[j][i]) + operands[j].Set(&s[j][block]) + step.Sub(&operands[j], &step) + for d := 1; d < degGJ; d++ { + operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + } + } + + _s := 0 + _e := nbInner + for d := 0; d < degGJ; d++ { + summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) 
+ summand.Mul(&summand, &operands[_s]) + res[d].Add(&res[d], &summand) + _s, _e = _e, _e+nbInner + } + } + mu.Lock() + for i := 0; i < len(gJ); i++ { + gJ[i].Add(&gJ[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + + if nbOuter < minBlockSize { + // no parallelization + computeAll(0, nbOuter) + } else { + c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + } + + // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though + + return gJ +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + c.eq.Fold(element) + } else { + wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) + for i := 0; i < len(c.inputPreprocessors); i++ { + wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + 
noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be 
deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." 
+ + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) 
+ if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) 
+ if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]fr.Element) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. 
+func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = IdentityGate{} + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. 
+// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]fr.Element, nbInstances) + } + } + + parallel.Execute(nbInstances, func(start, end int) { + ins := make([]fr.Element, maxNbIns) + for i := start; i < end; i++ { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + }) + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. 
Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + finalEvalProof := p[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} + +// Gates defined by name +var Gates = map[string]Gate{ + "identity": IdentityGate{}, + "add": AddGate{}, + "sub": SubGate{}, + "neg": NegGate{}, + "mul": MulGate(2), +} + +type IdentityGate struct{} +type AddGate struct{} +type MulGate int +type SubGate struct{} +type NegGate struct{} + +func (IdentityGate) Evaluate(input ...fr.Element) fr.Element { + return input[0] +} + +func (IdentityGate) Degree() int { + return 1 +} + +func (g AddGate) Evaluate(x ...fr.Element) (res fr.Element) { + switch len(x) { + case 0: + // set zero + case 1: + res.Set(&x[0]) + default: + res.Add(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Add(&res, &x[i]) + } + } + return +} + +func (g AddGate) Degree() int { + return 1 +} + +func (g MulGate) Evaluate(x ...fr.Element) (res fr.Element) { + if len(x) != int(g) { + panic("wrong input count") + } + switch len(x) { + case 0: + res.SetOne() + case 1: + res.Set(&x[0]) + default: + res.Mul(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Mul(&res, &x[i]) + } + } + return +} + +func (g MulGate) Degree() int { + return int(g) +} + +func (g SubGate) Evaluate(element ...fr.Element) (diff fr.Element) { + if len(element) > 2 { + panic("not implemented") //TODO + } + diff.Sub(&element[0], &element[1]) + return +} + +func (g SubGate) Degree() int { + return 1 +} + +func (g NegGate) Evaluate(element ...fr.Element) (neg fr.Element) { + if len(element) != 1 { + panic("univariate gate") + } + neg.Neg(&element[0]) + return +} + +func (g 
NegGate) Degree() int { + return 1 +} diff --git a/ecc/grumpkin/fr/gkr/gkr_test.go b/ecc/grumpkin/fr/gkr/gkr_test.go new file mode 100644 index 0000000000..2416572754 --- /dev/null +++ b/ecc/grumpkin/fr/gkr/gkr_test.go @@ -0,0 +1,739 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/sumcheck" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, 
testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) 
+ } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandom(fullAssignments[i]) + } + + inputAssignments := 
make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: Gates["add"], + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: Gates["mul"], + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, 
fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments 
...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... 
cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: Gates["mul"], + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandom(slice []fr.Element) { + for i := range slice { + slice[i].SetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + 
proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "../../../../internal/generator/gkr/test_vectors" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil 
{ + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandom(in0) + setRandom(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo 
struct { + Gate string `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = Gates[c[i].Gate] + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func init() { + Gates["mimc"] = mimcCipherGate{} //TODO: Add ark + Gates["select-input-3"] = _select(2) +} + +type mimcCipherGate struct { + ark fr.Element +} + +func (m mimcCipherGate) Evaluate(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]). 
+ Add(&sum, &m.ark) + + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +func (m mimcCipherGate) Degree() int { + return 7 +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = 
os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +type _select int + +func (g _select) Evaluate(in ...fr.Element) fr.Element { + return in[g] +} + +func (g _select) Degree() int { + return 1 +} diff --git a/ecc/grumpkin/fr/hash_to_field/doc.go 
b/ecc/grumpkin/fr/hash_to_field/doc.go new file mode 100644 index 0000000000..40ebcffa79 --- /dev/null +++ b/ecc/grumpkin/fr/hash_to_field/doc.go @@ -0,0 +1,21 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// Package hash_to_field provides a hasher based on RFC 9380 Section 5. +// +// The [RFC 9380] defines a method for hashing bytes to elliptic curves. Section +// 5 of the RFC describes a method for uniformly hashing bytes into a field +// using a domain separation. The hashing is implemented in [fr], but this +// package provides a wrapper for the method which implements [hash.Hash] for +// using the method recursively. +// +// [RFC 9380]: https://datatracker.ietf.org/doc/html/rfc9380 +package hash_to_field + +import ( + _ "hash" + + _ "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" +) diff --git a/ecc/grumpkin/fr/hash_to_field/hash_to_field.go b/ecc/grumpkin/fr/hash_to_field/hash_to_field.go new file mode 100644 index 0000000000..5787a3a0e5 --- /dev/null +++ b/ecc/grumpkin/fr/hash_to_field/hash_to_field.go @@ -0,0 +1,55 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package hash_to_field + +import ( + "fmt" + "hash" + + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" +) + +type wrappedHashToField struct { + domain []byte + toHash []byte +} + +// New returns a new hasher instance which uses [fr.Hash] to hash all the +// written bytes to a field element, returning the byte representation of the +// field element. The domain separator is passed as-is to the hashing method.
+func New(domainSeparator []byte) hash.Hash { + return &wrappedHashToField{ + domain: append([]byte{}, domainSeparator...), // copy in case the argument is modified + } +} + +func (w *wrappedHashToField) Write(p []byte) (n int, err error) { + w.toHash = append(w.toHash, p...) + return len(p), nil +} + +func (w *wrappedHashToField) Sum(b []byte) []byte { + res, err := fr.Hash(w.toHash, w.domain, 1) + if err != nil { + // we want to follow the interface, cannot return error and have to panic + // but by default the method shouldn't return an error internally + panic(fmt.Sprintf("native field to hash: %v", err)) + } + bts := res[0].Bytes() + return append(b, bts[:]...) +} + +func (w *wrappedHashToField) Reset() { + w.toHash = nil +} + +func (w *wrappedHashToField) Size() int { + return fr.Bytes +} + +func (w *wrappedHashToField) BlockSize() int { + return fr.Bytes +} diff --git a/ecc/grumpkin/fr/hash_to_field/hash_to_field_test.go b/ecc/grumpkin/fr/hash_to_field/hash_to_field_test.go new file mode 100644 index 0000000000..8c94b5b3c6 --- /dev/null +++ b/ecc/grumpkin/fr/hash_to_field/hash_to_field_test.go @@ -0,0 +1,30 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package hash_to_field + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" +) + +func TestHashInterface(t *testing.T) { + msg := []byte("test") + sep := []byte("separator") + res, err := fr.Hash(msg, sep, 1) + if err != nil { + t.Fatal("hash to field", err) + } + + htfFn := New(sep) + htfFn.Write(msg) + bts := htfFn.Sum(nil) + var res2 fr.Element + res2.SetBytes(bts[:fr.Bytes]) + if !res[0].Equal(&res2) { + t.Error("not equal") + } +} diff --git a/ecc/grumpkin/fr/mimc/doc.go b/ecc/grumpkin/fr/mimc/doc.go new file mode 100644 index 0000000000..946e0dafaa --- /dev/null +++ b/ecc/grumpkin/fr/mimc/doc.go @@ -0,0 +1,49 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// Package mimc provides MiMC hash function using Miyaguchi–Preneel construction. +// +// # Length extension attack +// +// The MiMC hash function is vulnerable to a length extension attack. For +// example when we have a hash +// +// h = MiMC(k || m) +// +// and we want to hash a new message +// +// m' = m || m2, +// +// we can compute +// +// h' = MiMC(k || m || m2) +// +// without knowing k by computing +// +// h' = MiMC(h || m2). +// +// This is because the MiMC hash function is a simple iterated cipher, and the +// hash value is the state of the cipher after encrypting the message. +// +// There are several ways to mitigate this attack: +// - use a random key for each hash +// - use a domain separation tag for different use cases: +// h = MiMC(k || tag || m) +// - use the secret input as last input: +// h = MiMC(m || k) +// +// In general, inside a circuit the length-extension attack is not a concern as +// due to the circuit definition the attacker can not append messages to +// existing hash. 
But the user has to consider the cases when using a secret key +// and MiMC in different contexts. +// +// # Hash input format +// +// The MiMC hash function is defined over a field. The input to the hash +// function is a byte slice. The byte slice is interpreted as a sequence of +// field elements. Due to this interpretation, the input byte slice length must +// be multiple of the field modulus size. And every sequence of byte slice for a +// single field element must be strictly less than the field modulus. +package mimc diff --git a/ecc/grumpkin/fr/mimc/mimc.go b/ecc/grumpkin/fr/mimc/mimc.go new file mode 100644 index 0000000000..a1fe21bbad --- /dev/null +++ b/ecc/grumpkin/fr/mimc/mimc.go @@ -0,0 +1,225 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "errors" + stdhash "hash" + "math/big" + "sync" + + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" + "github.com/consensys/gnark-crypto/hash" + + "golang.org/x/crypto/sha3" +) + +func init() { + hash.RegisterHash(hash.MIMC_GRUMPKIN, func() stdhash.Hash { + return NewMiMC() + }) +} + +const ( + mimcNbRounds = 110 + seed = "seed" // seed to derive the constants + BlockSize = fr.Bytes // BlockSize size that mimc consumes +) + +// Params constants for the mimc hash function +var ( + mimcConstants [mimcNbRounds]fr.Element + once sync.Once +) + +// digest represents the partial evaluation of the checksum +// along with the params of the mimc function +type digest struct { + h fr.Element + data []fr.Element // data to hash + byteOrder fr.ByteOrder +} + +// GetConstants exposed to be used in gnark +func GetConstants() []big.Int { + once.Do(initConstants) // init constants + res := make([]big.Int, mimcNbRounds) + for i := 0; i < mimcNbRounds; i++ { + mimcConstants[i].BigInt(&res[i]) + } + return res +} + +// NewMiMC returns a MiMC 
implementation, pure Go reference implementation. +func NewMiMC(opts ...Option) hash.StateStorer { + d := new(digest) + d.Reset() + cfg := mimcOptions(opts...) + d.byteOrder = cfg.byteOrder + return d +} + +// Reset resets the Hash to its initial state. +func (d *digest) Reset() { + d.data = d.data[:0] + d.h = fr.Element{0, 0, 0, 0} +} + +// Sum appends the current hash to b and returns the resulting slice. +// It does not change the underlying hash state. +func (d *digest) Sum(b []byte) []byte { + buffer := d.checksum() + d.data = nil // flush the data already hashed + hash := buffer.Bytes() + b = append(b, hash[:]...) + return b +} + +// Size returns the number of bytes Sum will return. +func (d *digest) Size() int { + return BlockSize +} + +// BlockSize returns the hash's underlying block size. +// The Write method must be able to accept any amount +// of data, but it may operate more efficiently if all writes +// are a multiple of the block size. +func (d *digest) BlockSize() int { + return BlockSize +} + +// Write (via the embedded io.Writer interface) adds more data to the running hash. +// +// Each []byte block of size BlockSize represents a big endian fr.Element. +// +// If len(p) is not a multiple of BlockSize and any of the []byte in p represent an integer +// larger than fr.Modulus, this function returns an error. +// +// To hash arbitrary data ([]byte not representing canonical field elements) use fr.Hash first +func (d *digest) Write(p []byte) (int, error) { + // we usually expect multiple of block size. But sometimes we hash short + // values (FS transcript). Instead of forcing to hash to field, we left-pad the + // input here.
+	if len(p) > 0 && len(p) < BlockSize { + pp := make([]byte, BlockSize) + copy(pp[len(pp)-len(p):], p) + p = pp + } + + var start int + for start = 0; start < len(p); start += BlockSize { + if elem, err := d.byteOrder.Element((*[BlockSize]byte)(p[start : start+BlockSize])); err == nil { + d.data = append(d.data, elem) + } else { + return 0, err + } + } + + if start != len(p) { + return 0, errors.New("invalid input length: must represent a list of field elements, expects a []byte of len m*BlockSize") + } + return len(p), nil +} + +// Hash hash using Miyaguchi-Preneel: +// https://en.wikipedia.org/wiki/One-way_compression_function +// The XOR operation is replaced by field addition, data is in Montgomery form +func (d *digest) checksum() fr.Element { + // Write guarantees len(data) % BlockSize == 0 + + // TODO @ThomasPiellard shouldn't Sum() return an error if there is no data? + // TODO: @Tabaie, @Thomas Piellard Not sure what to make of this + /*if len(d.data) == 0 { + d.data = make([]byte, BlockSize) + }*/ + + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + + return d.h +} + +// plain execution of a mimc run +// m: message +// k: encryption key +func (d *digest) encrypt(m fr.Element) fr.Element { + once.Do(initConstants) // init constants + + var tmp fr.Element + for i := 0; i < mimcNbRounds; i++ { + // m = (m+k+c)^5 + tmp.Add(&m, &d.h).Add(&tmp, &mimcConstants[i]) + m.Square(&tmp). + Square(&m).
+ Mul(&m, &tmp) + } + m.Add(&m, &d.h) + return m +} + +// Sum computes the mimc hash of msg from seed +func Sum(msg []byte) ([]byte, error) { + var d digest + if _, err := d.Write(msg); err != nil { + return nil, err + } + h := d.checksum() + bytes := h.Bytes() + return bytes[:], nil +} + +func initConstants() { + bseed := ([]byte)(seed) + + hash := sha3.NewLegacyKeccak256() + _, _ = hash.Write(bseed) + rnd := hash.Sum(nil) // pre hash before use + hash.Reset() + _, _ = hash.Write(rnd) + + for i := 0; i < mimcNbRounds; i++ { + rnd = hash.Sum(nil) + mimcConstants[i].SetBytes(rnd) + hash.Reset() + _, _ = hash.Write(rnd) + } +} + +// WriteString writes a string that doesn't necessarily consist of field elements +func (d *digest) WriteString(rawBytes []byte) error { + if elems, err := fr.Hash(rawBytes, []byte("string:"), 1); err != nil { + return err + } else { + d.data = append(d.data, elems[0]) + } + return nil +} + +// SetState manually sets the state of the hasher to an user-provided value. In +// the context of MiMC, the method expects a byte slice of 32 elements. +func (d *digest) SetState(newState []byte) error { + + if len(newState) != 32 { + return errors.New("the mimc state expects a state of 32 bytes") + } + + if err := d.h.SetBytesCanonical(newState); err != nil { + return errors.New("the provided newState does not represent a valid state") + } + + d.data = nil + + return nil +} + +// State returns the internal state of the hasher +func (d *digest) State() []byte { + _ = d.Sum(nil) // this flushes the hasher + b := d.h.Bytes() + return b[:] +} diff --git a/ecc/grumpkin/fr/mimc/mimc_test.go b/ecc/grumpkin/fr/mimc/mimc_test.go new file mode 100644 index 0000000000..a044b6adf9 --- /dev/null +++ b/ecc/grumpkin/fr/mimc/mimc_test.go @@ -0,0 +1,120 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc_test + +import ( + "bytes" + "testing" + + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/mimc" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMiMCFiatShamir(t *testing.T) { + fs := fiatshamir.NewTranscript(mimc.NewMiMC(), "c0") + zero := make([]byte, mimc.BlockSize) + err := fs.Bind("c0", zero) + assert.NoError(t, err) + _, err = fs.ComputeChallenge("c0") + assert.NoError(t, err) +} + +func TestByteOrder(t *testing.T) { + assert := require.New(t) + + var buf [fr.Bytes]byte + // if the 31 first bytes are FF, it's a valid FF in little endian, but not in big endian + for i := 0; i < fr.Bytes-1; i++ { + buf[i] = 0xFF + } + _, err := fr.BigEndian.Element(&buf) + assert.Error(err) + _, err = fr.LittleEndian.Element(&buf) + assert.NoError(err) + + { + // hashing buf with big endian should fail + mimcHash := mimc.NewMiMC(mimc.WithByteOrder(fr.BigEndian)) + _, err := mimcHash.Write(buf[:]) + assert.Error(err) + } + + { + // hashing buf with little endian should succeed + mimcHash := mimc.NewMiMC(mimc.WithByteOrder(fr.LittleEndian)) + _, err := mimcHash.Write(buf[:]) + assert.NoError(err) + } + + buf = [fr.Bytes]byte{} + // if the 31 bytes are FF, it's a valid FF in big endian, but not in little endian + for i := 1; i < fr.Bytes; i++ { + buf[i] = 0xFF + } + _, err = fr.BigEndian.Element(&buf) + assert.NoError(err) + _, err = fr.LittleEndian.Element(&buf) + assert.Error(err) + + { + // hashing buf with big endian should succeed + mimcHash := mimc.NewMiMC(mimc.WithByteOrder(fr.BigEndian)) + _, err := mimcHash.Write(buf[:]) + assert.NoError(err) + } + + { + // hashing buf with little endian should fail + mimcHash := mimc.NewMiMC(mimc.WithByteOrder(fr.LittleEndian)) + _, err := mimcHash.Write(buf[:]) + assert.Error(err) + } +} + +func 
TestSetState(t *testing.T) { + // we use for hashing and retrieving the state + h1 := mimc.NewMiMC() + // only hashing + h2 := mimc.NewMiMC() + // we use for restoring from state + h3 := mimc.NewMiMC() + + randInputs := make([]fr.Element, 10) + for i := range randInputs { + randInputs[i].SetRandom() + } + + storedStates := make([][]byte, len(randInputs)) + + for i := range randInputs { + storedStates[i] = h1.State() + + h1.Write(randInputs[i].Marshal()) + h2.Write(randInputs[i].Marshal()) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + if !bytes.Equal(dgst1, dgst2) { + t.Fatal("hashes do not match") + } + + for i := range storedStates { + if err := h3.SetState(storedStates[i]); err != nil { + t.Fatal(err) + } + for j := i; j < len(randInputs); j++ { + h3.Write(randInputs[j].Marshal()) + } + dgst3 := h3.Sum(nil) + if !bytes.Equal(dgst1, dgst3) { + t.Fatal("hashes do not match") + } + } +} diff --git a/ecc/grumpkin/fr/mimc/options.go b/ecc/grumpkin/fr/mimc/options.go new file mode 100644 index 0000000000..c7bc26de97 --- /dev/null +++ b/ecc/grumpkin/fr/mimc/options.go @@ -0,0 +1,39 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" +) + +// Option defines option for altering the behavior of the MiMC hasher. +// See the descriptions of functions returning instances of this type for +// particular options. +type Option func(*mimcConfig) + +type mimcConfig struct { + byteOrder fr.ByteOrder +} + +// default options +func mimcOptions(opts ...Option) mimcConfig { + // apply options + opt := mimcConfig{ + byteOrder: fr.BigEndian, + } + for _, option := range opts { + option(&opt) + } + return opt +} + +// WithByteOrder sets the byte order used to decode the input +// in the Write method. Default is BigEndian. 
+func WithByteOrder(byteOrder fr.ByteOrder) Option { + return func(opt *mimcConfig) { + opt.byteOrder = byteOrder + } +} diff --git a/ecc/grumpkin/fr/polynomial/doc.go b/ecc/grumpkin/fr/polynomial/doc.go new file mode 100644 index 0000000000..ead3b5cba5 --- /dev/null +++ b/ecc/grumpkin/fr/polynomial/doc.go @@ -0,0 +1,7 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// Package polynomial provides polynomial methods and commitment schemes. +package polynomial diff --git a/ecc/grumpkin/fr/polynomial/multilin.go b/ecc/grumpkin/fr/polynomial/multilin.go new file mode 100644 index 0000000000..c57983ce4c --- /dev/null +++ b/ecc/grumpkin/fr/polynomial/multilin.go @@ -0,0 +1,280 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package polynomial + +import ( + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" + "github.com/consensys/gnark-crypto/utils" + "math/bits" +) + +// MultiLin tracks the values of a (dense i.e. 
not sparse) multilinear polynomial +// The variables are X₁ through Xₙ where n = log(len(.)) +// .[∑ᵢ 2ⁱ⁻¹ bₙ₋ᵢ] = the polynomial evaluated at (b₁, b₂, ..., bₙ) +// It is understood that any hypercube evaluation can be extrapolated to a multilinear polynomial +type MultiLin []fr.Element + +// Fold is partial evaluation function k[X₁, X₂, ..., Xₙ] → k[X₂, ..., Xₙ] by setting X₁=r +func (m *MultiLin) Fold(r fr.Element) { + mid := len(*m) / 2 + + bottom, top := (*m)[:mid], (*m)[mid:] + + var t fr.Element // no need to update the top part + + // updating bookkeeping table + // knowing that the polynomial f ∈ (k[X₂, ..., Xₙ])[X₁] is linear, we would get f(r) = f(0) + r(f(1) - f(0)) + // the following loop computes the evaluations of f(r) accordingly: + // f(r, b₂, ..., bₙ) = f(0, b₂, ..., bₙ) + r(f(1, b₂, ..., bₙ) - f(0, b₂, ..., bₙ)) + for i := 0; i < mid; i++ { + // table[i] ← table[i] + r (table[i + mid] - table[i]) + t.Sub(&top[i], &bottom[i]) + t.Mul(&t, &r) + bottom[i].Add(&bottom[i], &t) + } + + *m = (*m)[:mid] +} + +func (m *MultiLin) FoldParallel(r fr.Element) utils.Task { + mid := len(*m) / 2 + bottom, top := (*m)[:mid], (*m)[mid:] + + *m = bottom + + return func(start, end int) { + var t fr.Element // no need to update the top part + for i := start; i < end; i++ { + // table[i] ← table[i] + r (table[i + mid] - table[i]) + t.Sub(&top[i], &bottom[i]) + t.Mul(&t, &r) + bottom[i].Add(&bottom[i], &t) + } + } +} + +func (m MultiLin) Sum() fr.Element { + s := m[0] + for i := 1; i < len(m); i++ { + s.Add(&s, &m[i]) + } + return s +} + +func _clone(m MultiLin, p *Pool) MultiLin { + if p == nil { + return m.Clone() + } else { + return p.Clone(m) + } +} + +func _dump(m MultiLin, p *Pool) { + if p != nil { + p.Dump(m) + } +} + +// Evaluate extrapolate the value of the multilinear polynomial corresponding to m +// on the given coordinates +func (m MultiLin) Evaluate(coordinates []fr.Element, p *Pool) fr.Element { + // Folding is a mutating operation + bkCopy := _clone(m, 
p) + + // Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable) + for _, r := range coordinates { + bkCopy.Fold(r) + } + + result := bkCopy[0] + + _dump(bkCopy, p) + return result +} + +// Clone creates a deep copy of a bookkeeping table. +// Both multilinear interpolation and sumcheck require folding an underlying +// array, but folding changes the array. To do both one requires a deep copy +// of the bookkeeping table. +func (m MultiLin) Clone() MultiLin { + res := make(MultiLin, len(m)) + copy(res, m) + return res +} + +// Add two bookKeepingTables +func (m *MultiLin) Add(left, right MultiLin) { + size := len(left) + // Check that left and right have the same size + if len(right) != size || len(*m) != size { + panic("left, right and destination must have the right size") + } + + // Add elementwise + for i := 0; i < size; i++ { + (*m)[i].Add(&left[i], &right[i]) + } +} + +// EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ) +// where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates +// +// _________________ +// | | | +// | 0 | 1 | +// |_______|_______| +// y | | | +// | 1 | 0 | +// |_______|_______| +// +// x +// +// In other words the polynomial evaluated here is the multilinear extrapolation of +// one that evaluates to q' == h' for vectors q', h' of binary values +func EvalEq(q, h []fr.Element) fr.Element { + var res, nxt, one, sum fr.Element + one.SetOne() + for i := 0; i < len(q); i++ { + nxt.Mul(&q[i], &h[i]) // nxt <- qᵢ * hᵢ + nxt.Double(&nxt) // nxt <- 2 * qᵢ * hᵢ + nxt.Add(&nxt, &one) // nxt <- 1 + 2 * qᵢ * hᵢ + sum.Add(&q[i], &h[i]) // sum <- qᵢ + hᵢ TODO: Why not subtract one by one from nxt? More parallel? 
+ + if i == 0 { + res.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ + } else { + nxt.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ + res.Mul(&res, &nxt) // res <- res * nxt + } + } + return res +} + +// Eq sets m to the representation of the polynomial Eq(q₁, ..., qₙ, *, ..., *) × m[0] +func (m *MultiLin) Eq(q []fr.Element) { + n := len(q) + + if len(*m) != 1<= 0; i-- { + res.Mul(&res, v) + res.Add(&res, &(*p)[i]) + } + + return res +} + +// Clone returns a copy of the polynomial +func (p *Polynomial) Clone() Polynomial { + _p := make(Polynomial, len(*p)) + copy(_p, *p) + return _p +} + +// Set to another polynomial +func (p *Polynomial) Set(p1 Polynomial) { + if len(*p) != len(p1) { + *p = p1.Clone() + return + } + + for i := 0; i < len(p1); i++ { + (*p)[i].Set(&p1[i]) + } +} + +// AddConstantInPlace adds a constant to the polynomial, modifying p +func (p *Polynomial) AddConstantInPlace(c *fr.Element) { + for i := 0; i < len(*p); i++ { + (*p)[i].Add(&(*p)[i], c) + } +} + +// SubConstantInPlace subs a constant to the polynomial, modifying p +func (p *Polynomial) SubConstantInPlace(c *fr.Element) { + for i := 0; i < len(*p); i++ { + (*p)[i].Sub(&(*p)[i], c) + } +} + +// ScaleInPlace multiplies p by v, modifying p +func (p *Polynomial) ScaleInPlace(c *fr.Element) { + for i := 0; i < len(*p); i++ { + (*p)[i].Mul(&(*p)[i], c) + } +} + +// Scale multiplies p0 by v, storing the result in p +func (p *Polynomial) Scale(c *fr.Element, p0 Polynomial) { + if len(*p) != len(p0) { + *p = make(Polynomial, len(p0)) + } + for i := 0; i < len(p0); i++ { + (*p)[i].Mul(c, &p0[i]) + } +} + +// Add adds p1 to p2 +// This function allocates a new slice unless p == p1 or p == p2 +func (p *Polynomial) Add(p1, p2 Polynomial) *Polynomial { + + bigger := p1 + smaller := p2 + if len(bigger) < len(smaller) { + bigger, smaller = smaller, bigger + } + + if len(*p) == len(bigger) && (&(*p)[0] == &bigger[0]) { + for i := 0; i < len(smaller); i++ { + (*p)[i].Add(&(*p)[i], &smaller[i]) + 
} + return p + } + + if len(*p) == len(smaller) && (&(*p)[0] == &smaller[0]) { + for i := 0; i < len(smaller); i++ { + (*p)[i].Add(&(*p)[i], &bigger[i]) + } + *p = append(*p, bigger[len(smaller):]...) + return p + } + + res := make(Polynomial, len(bigger)) + copy(res, bigger) + for i := 0; i < len(smaller); i++ { + res[i].Add(&res[i], &smaller[i]) + } + *p = res + return p +} + +// Sub subtracts p2 from p1 +// TODO make interface more consistent with Add +func (p *Polynomial) Sub(p1, p2 Polynomial) *Polynomial { + if len(p1) != len(p2) || len(p2) != len(*p) { + return nil + } + for i := 0; i < len(*p); i++ { + (*p)[i].Sub(&p1[i], &p2[i]) + } + return p +} + +// Equal checks equality between two polynomials +func (p *Polynomial) Equal(p1 Polynomial) bool { + if (*p == nil) != (p1 == nil) { + return false + } + + if len(*p) != len(p1) { + return false + } + + for i := range p1 { + if !(*p)[i].Equal(&p1[i]) { + return false + } + } + + return true +} + +func (p Polynomial) SetZero() { + for i := 0; i < len(p); i++ { + p[i].SetZero() + } +} + +func (p Polynomial) Text(base int) string { + + var builder strings.Builder + + first := true + for d := len(p) - 1; d >= 0; d-- { + if p[d].IsZero() { + continue + } + + pD := p[d] + pDText := pD.Text(base) + + initialLen := builder.Len() + + if pDText[0] == '-' { + pDText = pDText[1:] + if first { + builder.WriteString("-") + } else { + builder.WriteString(" - ") + } + } else if !first { + builder.WriteString(" + ") + } + + first = false + + if !pD.IsOne() || d == 0 { + builder.WriteString(pDText) + } + + if builder.Len()-initialLen > 10 { + builder.WriteString("×") + } + + if d != 0 { + builder.WriteString("X") + } + if d > 1 { + builder.WriteString( + utils.ToSuperscript(strconv.Itoa(d)), + ) + } + + } + + if first { + return "0" + } + + return builder.String() +} diff --git a/ecc/grumpkin/fr/polynomial/polynomial_test.go b/ecc/grumpkin/fr/polynomial/polynomial_test.go new file mode 100644 index 0000000000..6e7ada78a6 --- 
/dev/null +++ b/ecc/grumpkin/fr/polynomial/polynomial_test.go @@ -0,0 +1,207 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package polynomial + +import ( + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" + "github.com/stretchr/testify/assert" + "math/big" + "testing" +) + +func TestPolynomialEval(t *testing.T) { + + // build polynomial + f := make(Polynomial, 20) + for i := 0; i < 20; i++ { + f[i].SetOne() + } + + // random value + var point fr.Element + point.SetRandom() + + // compute manually f(val) + var expectedEval, one, den fr.Element + var expo big.Int + one.SetOne() + expo.SetUint64(20) + expectedEval.Exp(point, &expo). + Sub(&expectedEval, &one) + den.Sub(&point, &one) + expectedEval.Div(&expectedEval, &den) + + // compute purported evaluation + purportedEval := f.Eval(&point) + + // check + if !purportedEval.Equal(&expectedEval) { + t.Fatal("polynomial evaluation failed") + } +} + +func TestPolynomialAddConstantInPlace(t *testing.T) { + + // build polynomial + f := make(Polynomial, 20) + for i := 0; i < 20; i++ { + f[i].SetOne() + } + + // constant to add + var c fr.Element + c.SetRandom() + + // add constant + f.AddConstantInPlace(&c) + + // check + var expectedCoeffs, one fr.Element + one.SetOne() + expectedCoeffs.Add(&one, &c) + for i := 0; i < 20; i++ { + if !f[i].Equal(&expectedCoeffs) { + t.Fatal("AddConstantInPlace failed") + } + } +} + +func TestPolynomialSubConstantInPlace(t *testing.T) { + + // build polynomial + f := make(Polynomial, 20) + for i := 0; i < 20; i++ { + f[i].SetOne() + } + + // constant to sub + var c fr.Element + c.SetRandom() + + // sub constant + f.SubConstantInPlace(&c) + + // check + var expectedCoeffs, one fr.Element + one.SetOne() + expectedCoeffs.Sub(&one, &c) + for i := 0; i < 20; i++ { + if !f[i].Equal(&expectedCoeffs) { + t.Fatal("SubConstantInPlace failed") + } + 
} +} + +func TestPolynomialScaleInPlace(t *testing.T) { + + // build polynomial + f := make(Polynomial, 20) + for i := 0; i < 20; i++ { + f[i].SetOne() + } + + // constant to scale by + var c fr.Element + c.SetRandom() + + // scale by constant + f.ScaleInPlace(&c) + + // check + for i := 0; i < 20; i++ { + if !f[i].Equal(&c) { + t.Fatal("ScaleInPlace failed") + } + } + +} + +func TestPolynomialAdd(t *testing.T) { + + // build unbalanced polynomials + f1 := make(Polynomial, 20) + f1Backup := make(Polynomial, 20) + for i := 0; i < 20; i++ { + f1[i].SetOne() + f1Backup[i].SetOne() + } + f2 := make(Polynomial, 10) + f2Backup := make(Polynomial, 10) + for i := 0; i < 10; i++ { + f2[i].SetOne() + f2Backup[i].SetOne() + } + + // expected result + var one, two fr.Element + one.SetOne() + two.Double(&one) + expectedSum := make(Polynomial, 20) + for i := 0; i < 10; i++ { + expectedSum[i].Set(&two) + } + for i := 10; i < 20; i++ { + expectedSum[i].Set(&one) + } + + // caller is empty + var g Polynomial + g.Add(f1, f2) + if !g.Equal(expectedSum) { + t.Fatal("add polynomials fails") + } + if !f1.Equal(f1Backup) { + t.Fatal("side effect, f1 should not have been modified") + } + if !f2.Equal(f2Backup) { + t.Fatal("side effect, f2 should not have been modified") + } + + // all operands are distinct + _f1 := f1.Clone() + _f1.Add(f1, f2) + if !_f1.Equal(expectedSum) { + t.Fatal("add polynomials fails") + } + if !f1.Equal(f1Backup) { + t.Fatal("side effect, f1 should not have been modified") + } + if !f2.Equal(f2Backup) { + t.Fatal("side effect, f2 should not have been modified") + } + + // first operand = caller + _f1 = f1.Clone() + _f2 := f2.Clone() + _f1.Add(_f1, _f2) + if !_f1.Equal(expectedSum) { + t.Fatal("add polynomials fails") + } + if !_f2.Equal(f2Backup) { + t.Fatal("side effect, _f2 should not have been modified") + } + + // second operand = caller + _f1 = f1.Clone() + _f2 = f2.Clone() + _f1.Add(_f2, _f1) + if !_f1.Equal(expectedSum) { + t.Fatal("add polynomials fails") + 
} + if !_f2.Equal(f2Backup) { + t.Fatal("side effect, _f2 should not have been modified") + } +} + +func TestPolynomialText(t *testing.T) { + var one, negTwo fr.Element + one.SetOne() + negTwo.SetInt64(-2) + + p := Polynomial{one, negTwo, one} + + assert.Equal(t, "X² - 2X + 1", p.Text(10)) +} diff --git a/ecc/grumpkin/fr/polynomial/pool.go b/ecc/grumpkin/fr/polynomial/pool.go new file mode 100644 index 0000000000..4cf6e6ec0b --- /dev/null +++ b/ecc/grumpkin/fr/polynomial/pool.go @@ -0,0 +1,190 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package polynomial + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" + "runtime" + "sort" + "sync" + "unsafe" +) + +// Memory management for polynomials +// WARNING: This is not thread safe TODO: Make sure that is not a problem +// TODO: There is a lot of "unsafe" memory management here and needs to be vetted thoroughly + +type sizedPool struct { + maxN int + pool sync.Pool + stats poolStats +} + +type inUseData struct { + allocatedFor []uintptr + pool *sizedPool +} + +type Pool struct { + //lock sync.Mutex + inUse sync.Map + subPools []sizedPool +} + +func (p *sizedPool) get(n int) *fr.Element { + p.stats.make(n) + return p.pool.Get().(*fr.Element) +} + +func (p *sizedPool) put(ptr *fr.Element) { + p.stats.dump() + p.pool.Put(ptr) +} + +func NewPool(maxN ...int) (pool Pool) { + + sort.Ints(maxN) + pool = Pool{ + subPools: make([]sizedPool, len(maxN)), + } + + for i := range pool.subPools { + subPool := &pool.subPools[i] + subPool.maxN = maxN[i] + subPool.pool = sync.Pool{ + New: func() interface{} { + subPool.stats.Allocated++ + return getDataPointer(make([]fr.Element, 0, subPool.maxN)) + }, + } + } + return +} + +func (p *Pool) findCorrespondingPool(n int) *sizedPool { + poolI := 0 + for poolI < len(p.subPools) && n > 
p.subPools[poolI].maxN { + poolI++ + } + return &p.subPools[poolI] // out of bounds error here would mean that n is too large +} + +func (p *Pool) Make(n int) []fr.Element { + pool := p.findCorrespondingPool(n) + ptr := pool.get(n) + p.addInUse(ptr, pool) + return unsafe.Slice(ptr, n) +} + +// Dump dumps a set of polynomials into the pool +func (p *Pool) Dump(slices ...[]fr.Element) { + for _, slice := range slices { + ptr := getDataPointer(slice) + if metadata, ok := p.inUse.Load(ptr); ok { + p.inUse.Delete(ptr) + metadata.(inUseData).pool.put(ptr) + } else { + panic("attempting to dump a slice not created by the pool") + } + } +} + +func (p *Pool) addInUse(ptr *fr.Element, pool *sizedPool) { + pcs := make([]uintptr, 2) + n := runtime.Callers(3, pcs) + + if prevPcs, ok := p.inUse.Load(ptr); ok { // TODO: remove if unnecessary for security + panic(fmt.Errorf("re-allocated non-dumped slice, previously allocated at %v", runtime.CallersFrames(prevPcs.(inUseData).allocatedFor))) + } + p.inUse.Store(ptr, inUseData{ + allocatedFor: pcs[:n], + pool: pool, + }) +} + +func printFrame(frame runtime.Frame) { + fmt.Printf("\t%s line %d, function %s\n", frame.File, frame.Line, frame.Function) +} + +func (p *Pool) printInUse() { + fmt.Println("slices never dumped allocated at:") + p.inUse.Range(func(_, pcs any) bool { + fmt.Println("-------------------------") + + var frame runtime.Frame + frames := runtime.CallersFrames(pcs.(inUseData).allocatedFor) + more := true + for more { + frame, more = frames.Next() + printFrame(frame) + } + return true + }) +} + +type poolStats struct { + Used int + Allocated int + ReuseRate float64 + InUse int + GreatestNUsed int + SmallestNUsed int +} + +type poolsStats struct { + SubPools []poolStats + InUse int +} + +func (s *poolStats) make(n int) { + s.Used++ + s.InUse++ + if n > s.GreatestNUsed { + s.GreatestNUsed = n + } + if s.SmallestNUsed == 0 || s.SmallestNUsed > n { + s.SmallestNUsed = n + } +} + +func (s *poolStats) dump() { + s.InUse-- +} 
+ +func (s *poolStats) finalize() { + s.ReuseRate = float64(s.Used) / float64(s.Allocated) +} + +func getDataPointer(slice []fr.Element) *fr.Element { + return (*fr.Element)(unsafe.SliceData(slice)) +} + +func (p *Pool) PrintPoolStats() { + InUse := 0 + subStats := make([]poolStats, len(p.subPools)) + for i := range p.subPools { + subPool := &p.subPools[i] + subPool.stats.finalize() + subStats[i] = subPool.stats + InUse += subPool.stats.InUse + } + + stats := poolsStats{ + SubPools: subStats, + InUse: InUse, + } + serialized, _ := json.MarshalIndent(stats, "", " ") + fmt.Println(string(serialized)) + p.printInUse() +} + +func (p *Pool) Clone(slice []fr.Element) []fr.Element { + res := p.Make(len(slice)) + copy(res, slice) + return res +} diff --git a/ecc/grumpkin/fr/poseidon2/doc.go b/ecc/grumpkin/fr/poseidon2/doc.go new file mode 100644 index 0000000000..c2f3d4688e --- /dev/null +++ b/ecc/grumpkin/fr/poseidon2/doc.go @@ -0,0 +1,17 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// Package poseidon2 implements the Poseidon2 permutation +// +// Poseidon2 permutation is a cryptographic permutation for algebraic hashes. +// See the [original paper] by Grassi, Khovratovich and Schofnegger for the full details. +// +// This implementation is based on the [reference implementation] from +// HorizenLabs. See the [specifications] for parameter choices. 
+// +// [reference implementation]: https://github.com/HorizenLabs/poseidon2/blob/main/plain_implementations/src/poseidon2/poseidon2.rs +// [specifications]: https://github.com/argumentcomputer/neptune/blob/main/spec/poseidon_spec.pdf +// [original paper]: https://eprint.iacr.org/2023/323.pdf +package poseidon2 diff --git a/ecc/grumpkin/fr/poseidon2/poseidon2.go b/ecc/grumpkin/fr/poseidon2/poseidon2.go new file mode 100644 index 0000000000..61d082c7b6 --- /dev/null +++ b/ecc/grumpkin/fr/poseidon2/poseidon2.go @@ -0,0 +1,320 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package poseidon2 + +import ( + "errors" + "fmt" + + "golang.org/x/crypto/sha3" + + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" +) + +var ( + ErrInvalidSizebuffer = errors.New("the size of the input should match the size of the hash buffer") +) + +const ( + // d is the degree of the sBox + d = 5 +) + +// DegreeSBox returns the degree of the sBox function used in the Poseidon2 +// permutation. +func DegreeSBox() int { + return d +} + +// Parameters describing the Poseidon2 implementation. Use [NewParameters] or +// [NewParametersWithSeed] to initialize a new set of parameters to +// deterministically precompute the round keys. +type Parameters struct { + // len(preimage)+len(digest)=len(preimage)+ceil(log(2*/r)) + Width int + + // number of full rounds (even number) + NbFullRounds int + + // number of partial rounds + NbPartialRounds int + + // derived round keys from the parameter seed and curve ID + RoundKeys [][]fr.Element +} + +// NewParameters returns a new set of parameters for the Poseidon2 permutation. +// After creating the parameters, the round keys are initialized deterministically +// from the seed which is a digest of the parameters and curve ID. 
+func NewParameters(width, nbFullRounds, nbPartialRounds int) *Parameters { + p := Parameters{Width: width, NbFullRounds: nbFullRounds, NbPartialRounds: nbPartialRounds} + seed := p.String() + p.initRC(seed) + return &p +} + +// NewParametersWithSeed returns a new set of parameters for the Poseidon2 permutation. +// After creating the parameters, the round keys are initialized deterministically +// from the given seed. +func NewParametersWithSeed(width, nbFullRounds, nbPartialRounds int, seed string) *Parameters { + p := Parameters{Width: width, NbFullRounds: nbFullRounds, NbPartialRounds: nbPartialRounds} + p.initRC(seed) + return &p +} + +// String returns a string representation of the parameters. It is unique for +// specific parameters and curve. +func (p *Parameters) String() string { + return fmt.Sprintf("Poseidon2-GRUMPKIN[t=%d,rF=%d,rP=%d,d=%d]", p.Width, p.NbFullRounds, p.NbPartialRounds, d) +} + +// initRC initiate round keys. Only one entry is non zero for the internal +// rounds, cf https://eprint.iacr.org/2023/323.pdf page 9 +func (p *Parameters) initRC(seed string) { + + bseed := ([]byte)(seed) + hash := sha3.NewLegacyKeccak256() + _, _ = hash.Write(bseed) + rnd := hash.Sum(nil) // pre hash before use + hash.Reset() + _, _ = hash.Write(rnd) + + roundKeys := make([][]fr.Element, p.NbFullRounds+p.NbPartialRounds) + for i := 0; i < p.NbFullRounds/2; i++ { + roundKeys[i] = make([]fr.Element, p.Width) + for j := 0; j < p.Width; j++ { + rnd = hash.Sum(nil) + roundKeys[i][j].SetBytes(rnd) + hash.Reset() + _, _ = hash.Write(rnd) + } + } + for i := p.NbFullRounds / 2; i < p.NbPartialRounds+p.NbFullRounds/2; i++ { + roundKeys[i] = make([]fr.Element, 1) + rnd = hash.Sum(nil) + roundKeys[i][0].SetBytes(rnd) + hash.Reset() + _, _ = hash.Write(rnd) + } + for i := p.NbPartialRounds + p.NbFullRounds/2; i < p.NbPartialRounds+p.NbFullRounds; i++ { + roundKeys[i] = make([]fr.Element, p.Width) + for j := 0; j < p.Width; j++ { + rnd = hash.Sum(nil) + 
roundKeys[i][j].SetBytes(rnd) + hash.Reset() + _, _ = hash.Write(rnd) + } + } + p.RoundKeys = roundKeys +} + +// Permutation stores the buffer of the Poseidon2 permutation and provides +// Poseidon2 permutation methods on the buffer +type Permutation struct { + // parameters describing the instance + params *Parameters +} + +// NewPermutation returns a new Poseidon2 permutation instance. +func NewPermutation(t, rf, rp int) *Permutation { + if t < 2 || t > 3 { + panic("only t=2,3 is supported") + } + params := NewParameters(t, rf, rp) + res := &Permutation{params: params} + return res +} + +// NewPermutationWithSeed returns a new Poseidon2 permutation instance with a +// given seed. +func NewPermutationWithSeed(t, rf, rp int, seed string) *Permutation { + if t < 2 || t > 3 { + panic("only t=2,3 is supported") + } + params := NewParametersWithSeed(t, rf, rp, seed) + res := &Permutation{params: params} + return res +} + +// sBox applies the sBox on buffer[index] +func (h *Permutation) sBox(index int, input []fr.Element) { + var tmp fr.Element + tmp.Set(&input[index]) + + // sbox degree is 5 + input[index].Square(&input[index]). + Square(&input[index]). 
+ Mul(&input[index], &tmp) + +} + +// matMulM4 computes +// s <- M4*s +// where M4= +// (5 7 1 3) +// (4 6 1 1) +// (1 3 5 7) +// (1 1 4 6) +// on chunks of 4 elemts on each part of the buffer +// see https://eprint.iacr.org/2023/323.pdf appendix B for the addition chain +func (h *Permutation) matMulM4InPlace(s []fr.Element) { + c := len(s) / 4 + for i := 0; i < c; i++ { + var t0, t1, t2, t3, t4, t5, t6, t7 fr.Element + t0.Add(&s[4*i], &s[4*i+1]) // s0+s1 + t1.Add(&s[4*i+2], &s[4*i+3]) // s2+s3 + t2.Double(&s[4*i+1]).Add(&t2, &t1) // 2s1+t1 + t3.Double(&s[4*i+3]).Add(&t3, &t0) // 2s3+t0 + t4.Double(&t1).Double(&t4).Add(&t4, &t3) // 4t1+t3 + t5.Double(&t0).Double(&t5).Add(&t5, &t2) // 4t0+t2 + t6.Add(&t3, &t5) // t3+t4 + t7.Add(&t2, &t4) // t2+t4 + s[4*i].Set(&t6) + s[4*i+1].Set(&t5) + s[4*i+2].Set(&t7) + s[4*i+3].Set(&t4) + } +} + +// when T=2,3 the buffer is multiplied by circ(2,1) and circ(2,1,1) +// see https://eprint.iacr.org/2023/323.pdf page 15, case T=2,3 +// +// when T=0[4], the buffer is multiplied by circ(2M4,M4,..,M4) +// see https://eprint.iacr.org/2023/323.pdf +func (h *Permutation) matMulExternalInPlace(input []fr.Element) { + + if h.params.Width == 2 { + var tmp fr.Element + tmp.Add(&input[0], &input[1]) + input[0].Add(&tmp, &input[0]) + input[1].Add(&tmp, &input[1]) + } else if h.params.Width == 3 { + var tmp fr.Element + tmp.Add(&input[0], &input[1]). 
+ Add(&tmp, &input[2]) + input[0].Add(&tmp, &input[0]) + input[1].Add(&tmp, &input[1]) + input[2].Add(&tmp, &input[2]) + } else if h.params.Width == 4 { + h.matMulM4InPlace(input) + } else { + // at this stage t is supposed to be a multiple of 4 + // the MDS matrix is circ(2M4,M4,..,M4) + h.matMulM4InPlace(input) + tmp := make([]fr.Element, 4) + for i := 0; i < h.params.Width/4; i++ { + tmp[0].Add(&tmp[0], &input[4*i]) + tmp[1].Add(&tmp[1], &input[4*i+1]) + tmp[2].Add(&tmp[2], &input[4*i+2]) + tmp[3].Add(&tmp[3], &input[4*i+3]) + } + for i := 0; i < h.params.Width/4; i++ { + input[4*i].Add(&input[4*i], &tmp[0]) + input[4*i+1].Add(&input[4*i], &tmp[1]) + input[4*i+2].Add(&input[4*i], &tmp[2]) + input[4*i+3].Add(&input[4*i], &tmp[3]) + } + } +} + +// when T=2,3 the matrix are respectibely [[2,1][1,3]] and [[2,1,1][1,2,1][1,1,3]] +// otherwise the matrix is filled with ones except on the diagonal, +func (h *Permutation) matMulInternalInPlace(input []fr.Element) { + switch h.params.Width { + case 2: + var sum fr.Element + sum.Add(&input[0], &input[1]) + input[0].Add(&input[0], &sum) + input[1].Double(&input[1]).Add(&input[1], &sum) + case 3: + var sum fr.Element + sum.Add(&input[0], &input[1]).Add(&sum, &input[2]) + input[0].Add(&input[0], &sum) + input[1].Add(&input[1], &sum) + input[2].Double(&input[2]).Add(&input[2], &sum) + default: + // var sum fr.Element + // sum.Set(&input[0]) + // for i := 1; i < h.params.t; i++ { + // sum.Add(&sum, &input[i]) + // } + // for i := 0; i < h.params.t; i++ { + // input[i].Mul(&input[i], &h.params.diagInternalMatrices[i]). 
+ // Add(&input[i], &sum) + // } + panic("only T=2,3 is supported") + } +} + +// addRoundKeyInPlace adds the round-th key to the buffer +func (h *Permutation) addRoundKeyInPlace(round int, input []fr.Element) { + for i := 0; i < len(h.params.RoundKeys[round]); i++ { + input[i].Add(&input[i], &h.params.RoundKeys[round][i]) + } +} + +func (h *Permutation) BlockSize() int { + return fr.Bytes +} + +// Permutation applies the permutation on input, and stores the result in input. +func (h *Permutation) Permutation(input []fr.Element) error { + if len(input) != h.params.Width { + return ErrInvalidSizebuffer + } + + // external matrix multiplication, cf https://eprint.iacr.org/2023/323.pdf page 14 (part 6) + h.matMulExternalInPlace(input) + + rf := h.params.NbFullRounds / 2 + for i := 0; i < rf; i++ { + // one round = matMulExternal(sBox_Full(addRoundKey)) + h.addRoundKeyInPlace(i, input) + for j := 0; j < h.params.Width; j++ { + h.sBox(j, input) + } + h.matMulExternalInPlace(input) + } + + for i := rf; i < rf+h.params.NbPartialRounds; i++ { + // one round = matMulInternal(sBox_sparse(addRoundKey)) + h.addRoundKeyInPlace(i, input) + h.sBox(0, input) + h.matMulInternalInPlace(input) + } + for i := rf + h.params.NbPartialRounds; i < h.params.NbFullRounds+h.params.NbPartialRounds; i++ { + // one round = matMulExternal(sBox_Full(addRoundKey)) + h.addRoundKeyInPlace(i, input) + for j := 0; j < h.params.Width; j++ { + h.sBox(j, input) + } + h.matMulExternalInPlace(input) + } + + return nil +} + +// Compress applies the permutation on left and right and returns the right lane +// of the result. Panics if the permutation instance is not initialized with a +// width of 2. 
+func (h *Permutation) Compress(left []byte, right []byte) ([]byte, error) { + if h.params.Width != 2 { + return nil, errors.New("need a 2-1 function") + } + var x [2]fr.Element + + if err := x[0].SetBytesCanonical(left); err != nil { + return nil, err + } + if err := x[1].SetBytesCanonical(right); err != nil { + return nil, err + } + if err := h.Permutation(x[:]); err != nil { + return nil, err + } + res := x[1].Bytes() + return res[:], nil +} diff --git a/ecc/grumpkin/fr/poseidon2/poseidon2_test.go b/ecc/grumpkin/fr/poseidon2/poseidon2_test.go new file mode 100644 index 0000000000..a6b5d587aa --- /dev/null +++ b/ecc/grumpkin/fr/poseidon2/poseidon2_test.go @@ -0,0 +1,68 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package poseidon2 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" +) + +func TestExternalMatrix(t *testing.T) { + t.Skip("skipping test - it is initialized for width=4 for which we don't have the diagonal matrix") + + var expected [4][4]fr.Element + expected[0][0].SetUint64(5) + expected[0][1].SetUint64(4) + expected[0][2].SetUint64(1) + expected[0][3].SetUint64(1) + + expected[1][0].SetUint64(7) + expected[1][1].SetUint64(6) + expected[1][2].SetUint64(3) + expected[1][3].SetUint64(1) + + expected[2][0].SetUint64(1) + expected[2][1].SetUint64(1) + expected[2][2].SetUint64(5) + expected[2][3].SetUint64(4) + + expected[3][0].SetUint64(3) + expected[3][1].SetUint64(1) + expected[3][2].SetUint64(7) + expected[3][3].SetUint64(6) + + h := NewPermutation(4, 8, 56) + var tmp [4]fr.Element + for i := 0; i < 4; i++ { + for j := 0; j < 4; j++ { + tmp[j].SetUint64(0) + if i == j { + tmp[j].SetOne() + } + } + // h.Write(tmp[:]) + h.matMulExternalInPlace(tmp[:]) + for j := 0; j < 4; j++ { + if !tmp[j].Equal(&expected[i][j]) { + t.Fatal("error matMul4") + } + } + } + +} + +func 
BenchmarkPoseidon2(b *testing.B) { + h := NewPermutation(3, 8, 56) + var tmp [3]fr.Element + tmp[0].SetRandom() + tmp[1].SetRandom() + tmp[2].SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + h.Permutation(tmp[:]) + } +} diff --git a/ecc/grumpkin/fr/sumcheck/sumcheck.go b/ecc/grumpkin/fr/sumcheck/sumcheck.go new file mode 100644 index 0000000000..e901d8479f --- /dev/null +++ b/ecc/grumpkin/fr/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. 
It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) 
+ settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof 
Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. 
TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/ecc/grumpkin/fr/sumcheck/sumcheck_test.go b/ecc/grumpkin/fr/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..e7cc105811 --- /dev/null +++ b/ecc/grumpkin/fr/sumcheck/sumcheck_test.go @@ -0,0 +1,150 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, 
combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() 
hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/ecc/grumpkin/fr/test_vector_utils/test_vector_utils.go b/ecc/grumpkin/fr/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..df83ecc9b9 --- /dev/null +++ b/ecc/grumpkin/fr/test_vector_utils/test_vector_utils.go @@ -0,0 +1,216 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/polynomial" + "hash" + "reflect" + "strings" +) + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + 
return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? 
+ switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := 
make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/ecc/grumpkin/fr/vector.go b/ecc/grumpkin/fr/vector.go new file mode 100644 index 0000000000..7792a6f73b --- /dev/null +++ b/ecc/grumpkin/fr/vector.go @@ -0,0 +1,295 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "runtime" + "strings" + "sync" + "sync/atomic" + "unsafe" +) + +// Vector represents a slice of Element. +// +// It implements the following interfaces: +// - Stringer +// - io.WriterTo +// - io.ReaderFrom +// - encoding.BinaryMarshaler +// - encoding.BinaryUnmarshaler +// - sort.Interface +type Vector []Element + +// MarshalBinary implements encoding.BinaryMarshaler +func (vector *Vector) MarshalBinary() (data []byte, err error) { + var buf bytes.Buffer + + if _, err = vector.WriteTo(&buf); err != nil { + return + } + return buf.Bytes(), nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (vector *Vector) UnmarshalBinary(data []byte) error { + r := bytes.NewReader(data) + _, err := vector.ReadFrom(r) + return err +} + +// WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. +// Length of the vector is encoded as a uint32 on the first 4 bytes. 
func (vector *Vector) WriteTo(w io.Writer) (int64, error) {
	// encode slice length
	if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil {
		return 0, err
	}

	// 4 bytes already written for the length prefix
	n := int64(4)

	var buf [Bytes]byte
	for i := 0; i < len(*vector); i++ {
		BigEndian.PutElement(&buf, (*vector)[i])
		m, err := w.Write(buf[:])
		n += int64(m)
		if err != nil {
			return n, err
		}
	}
	return n, nil
}

// AsyncReadFrom reads a vector of big endian encoded Element.
// Length of the vector must be encoded as a uint32 on the first 4 bytes.
// It consumes the needed bytes from the reader and returns the number of bytes read and an error if any.
// It also returns a channel that will be closed when the validation is done.
// The validation consist of checking that the elements are smaller than the modulus, and
// converting them to montgomery form.
func (vector *Vector) AsyncReadFrom(r io.Reader) (int64, error, chan error) {
	// buffered so the validation goroutine never blocks on send
	chErr := make(chan error, 1)
	var buf [Bytes]byte
	if read, err := io.ReadFull(r, buf[:4]); err != nil {
		close(chErr)
		return int64(read), err, chErr
	}
	sliceLen := binary.BigEndian.Uint32(buf[:4])

	n := int64(4)
	(*vector) = make(Vector, sliceLen)
	if sliceLen == 0 {
		close(chErr)
		return n, nil, chErr
	}

	// reinterpret the freshly allocated element slice as a raw byte slice so the
	// reader fills the elements in place (zero copy)
	bSlice := unsafe.Slice((*byte)(unsafe.Pointer(&(*vector)[0])), sliceLen*Bytes)
	read, err := io.ReadFull(r, bSlice)
	n += int64(read)
	if err != nil {
		close(chErr)
		return n, err, chErr
	}

	// validation and montgomery conversion happen asynchronously; the caller
	// must wait on the returned channel before using the vector
	go func() {
		var cptErrors uint64
		// process the elements in parallel
		execute(int(sliceLen), func(start, end int) {

			var z Element
			for i := start; i < end; i++ {
				// we have to set vector[i]
				bstart := i * Bytes
				bend := bstart + Bytes
				b := bSlice[bstart:bend]
				// big endian bytes -> little endian limbs
				z[0] = binary.BigEndian.Uint64(b[24:32])
				z[1] = binary.BigEndian.Uint64(b[16:24])
				z[2] = binary.BigEndian.Uint64(b[8:16])
				z[3] = binary.BigEndian.Uint64(b[0:8])

				if !z.smallerThanModulus() {
					atomic.AddUint64(&cptErrors, 1)
					// abort this worker's range; the error count is reported below
					return
				}
				z.toMont()
				(*vector)[i] = z
			}
		})

		if cptErrors > 0 {
			chErr <- fmt.Errorf("async read: %d elements failed validation", cptErrors)
		}
		close(chErr)
	}()
	return n, nil, chErr
}

// ReadFrom implements io.ReaderFrom and reads a vector of big endian encoded Element.
// Length of the vector must be encoded as a uint32 on the first 4 bytes.
func (vector *Vector) ReadFrom(r io.Reader) (int64, error) {

	var buf [Bytes]byte
	if read, err := io.ReadFull(r, buf[:4]); err != nil {
		return int64(read), err
	}
	sliceLen := binary.BigEndian.Uint32(buf[:4])

	n := int64(4)
	(*vector) = make(Vector, sliceLen)

	for i := 0; i < int(sliceLen); i++ {
		read, err := io.ReadFull(r, buf[:])
		n += int64(read)
		if err != nil {
			return n, err
		}
		// BigEndian.Element validates (rejects values >= modulus) and converts
		// to montgomery form
		(*vector)[i], err = BigEndian.Element(&buf)
		if err != nil {
			return n, err
		}
	}

	return n, nil
}

// String implements fmt.Stringer interface
func (vector Vector) String() string {
	var sbb strings.Builder
	sbb.WriteByte('[')
	for i := 0; i < len(vector); i++ {
		sbb.WriteString(vector[i].String())
		if i != len(vector)-1 {
			sbb.WriteByte(',')
		}
	}
	sbb.WriteByte(']')
	return sbb.String()
}

// Len is the number of elements in the collection.
func (vector Vector) Len() int {
	return len(vector)
}

// Less reports whether the element with
// index i should sort before the element with index j.
func (vector Vector) Less(i, j int) bool {
	return vector[i].Cmp(&vector[j]) == -1
}

// Swap swaps the elements with indexes i and j.
func (vector Vector) Swap(i, j int) {
	vector[i], vector[j] = vector[j], vector[i]
}

// addVecGeneric sets res[i] = a[i] + b[i]; pure-Go fallback for the vector Add.
func addVecGeneric(res, a, b Vector) {
	if len(a) != len(b) || len(a) != len(res) {
		panic("vector.Add: vectors don't have the same length")
	}
	for i := 0; i < len(a); i++ {
		res[i].Add(&a[i], &b[i])
	}
}

// subVecGeneric sets res[i] = a[i] - b[i]; pure-Go fallback for the vector Sub.
func subVecGeneric(res, a, b Vector) {
	if len(a) != len(b) || len(a) != len(res) {
		panic("vector.Sub: vectors don't have the same length")
	}
	for i := 0; i < len(a); i++ {
		res[i].Sub(&a[i], &b[i])
	}
}

// scalarMulVecGeneric sets res[i] = a[i] * b; pure-Go fallback for ScalarMul.
func scalarMulVecGeneric(res, a Vector, b *Element) {
	if len(a) != len(res) {
		panic("vector.ScalarMul: vectors don't have the same length")
	}
	for i := 0; i < len(a); i++ {
		res[i].Mul(&a[i], b)
	}
}

// sumVecGeneric accumulates sum(a) into res; pure-Go fallback for Sum.
func sumVecGeneric(res *Element, a Vector) {
	for i := 0; i < len(a); i++ {
		res.Add(res, &a[i])
	}
}

// innerProductVecGeneric accumulates <a, b> into res; pure-Go fallback for InnerProduct.
func innerProductVecGeneric(res *Element, a, b Vector) {
	if len(a) != len(b) {
		panic("vector.InnerProduct: vectors don't have the same length")
	}
	var tmp Element
	for i := 0; i < len(a); i++ {
		tmp.Mul(&a[i], &b[i])
		res.Add(res, &tmp)
	}
}

// mulVecGeneric sets res[i] = a[i] * b[i]; pure-Go fallback for the vector Mul.
func mulVecGeneric(res, a, b Vector) {
	if len(a) != len(b) || len(a) != len(res) {
		panic("vector.Mul: vectors don't have the same length")
	}
	for i := 0; i < len(a); i++ {
		res[i].Mul(&a[i], &b[i])
	}
}

// TODO @gbotrel make a public package out of that.
// execute executes the work function in parallel.
// NOTE: duplicated from internal/parallel/parallel.go so that generated code
// does not import internal/.
// It partitions [0, nbIterations) into contiguous, disjoint ranges and invokes
// work on each of them. With no maxCpus argument it uses runtime.NumCPU()
// goroutines; an explicit maxCpus[0] is honored after being clamped to
// [1, 512]. The call returns only once every range has been processed.
func execute(nbIterations int, work func(int, int), maxCpus ...int) {

	workers := runtime.NumCPU()
	if len(maxCpus) == 1 {
		workers = maxCpus[0]
		switch {
		case workers < 1:
			workers = 1
		case workers > 512:
			workers = 512
		}
	}

	// single worker: run inline, no goroutine overhead
	if workers == 1 {
		work(0, nbIterations)
		return
	}

	chunk := nbIterations / workers

	// more workers than iterations: each worker handles exactly one iteration
	if chunk < 1 {
		chunk = 1
		workers = nbIterations
	}

	var wg sync.WaitGroup

	// the first `leftover` workers take one extra iteration each
	leftover := nbIterations - workers*chunk
	offset := 0

	for i := 0; i < workers; i++ {
		start := i*chunk + offset
		end := start + chunk
		if leftover > 0 {
			end++
			leftover--
			offset++
		}
		wg.Add(1)
		go func(s, e int) {
			defer wg.Done()
			work(s, e)
		}(start, end)
	}

	wg.Wait()
}
+func (vector *Vector) Add(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + n := uint64(len(a)) + addVec(&(*vector)[0], &a[0], &b[0], n) +} + +//go:noescape +func addVec(res, a, b *Element, n uint64) + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + subVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func subVec(res, a, b *Element, n uint64) + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + if len(a) != len(*vector) { + panic("vector.ScalarMul: vectors don't have the same length") + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || uint64(len(a)) >= maxN { + // call scalarMulVecGeneric + scalarMulVecGeneric(*vector, a, b) + return + } + n := uint64(len(a)) + if n == 0 { + return + } + // the code for scalarMul is identical to mulVec; and it expects at least + // 2 elements in the vector to fill the Z registers + var bb [2]Element + bb[0] = *b + bb[1] = *b + const blockSize = 16 + scalarMulVec(&(*vector)[0], &a[0], &bb[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call scalarMulVecGeneric on the rest + start := n - n%blockSize + scalarMulVecGeneric((*vector)[start:], a[start:], b) + } +} + +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64, qInvNeg uint64) + +// Sum computes the sum of all elements in the vector. 
+func (vector *Vector) Sum() (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) - 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return +} + +//go:noescape +func sumVec(res *Element, a *Element, n uint64) + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + if n != uint64(len(other)) { + panic("vector.InnerProduct: vectors don't have the same length") + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call innerProductVecGeneric + // note; we could split the vector into smaller chunks and call innerProductVec + innerProductVecGeneric(&res, *vector, other) + return + } + innerProdVec(&res[0], &(*vector)[0], &other[0], uint64(len(*vector))) + + return +} + +//go:noescape +func innerProdVec(res *uint64, a, b *Element, n uint64) + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Mul: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call mulVecGeneric + mulVecGeneric(*vector, a, b) + return + } + + const blockSize = 16 + mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call mulVecGeneric on the rest + start := n - n%blockSize + mulVecGeneric((*vector)[start:], a[start:], b[start:]) + } + +} + +// Patterns use for transposing the vectors in mulVec +var ( + pattern1 = [8]uint64{0, 8, 1, 9, 2, 10, 3, 11} + pattern2 = [8]uint64{12, 4, 13, 5, 14, 6, 15, 7} + pattern3 = [8]uint64{0, 1, 8, 9, 2, 3, 10, 11} + pattern4 = [8]uint64{12, 13, 4, 5, 14, 15, 6, 7} +) + +//go:noescape +func mulVec(res, a, b *Element, n uint64, qInvNeg uint64) diff --git a/ecc/grumpkin/fr/vector_purego.go b/ecc/grumpkin/fr/vector_purego.go new file mode 100644 index 0000000000..6608e394c7 --- /dev/null +++ b/ecc/grumpkin/fr/vector_purego.go @@ -0,0 +1,45 @@ +//go:build purego || !amd64 + +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/ecc/grumpkin/fr/vector_test.go b/ecc/grumpkin/fr/vector_test.go new file mode 100644 index 0000000000..88c4ef318f --- /dev/null +++ b/ecc/grumpkin/fr/vector_test.go @@ -0,0 +1,360 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import ( + "bytes" + "fmt" + "github.com/stretchr/testify/require" + "os" + "reflect" + "sort" + "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" +) + +func TestVectorSort(t *testing.T) { + assert := require.New(t) + + v := make(Vector, 3) + v[0].SetUint64(2) + v[1].SetUint64(3) + v[2].SetUint64(1) + + sort.Sort(v) + + assert.Equal("[1,2,3]", v.String()) +} + +func TestVectorRoundTrip(t *testing.T) { + assert := require.New(t) + + v1 := make(Vector, 3) + v1[0].SetUint64(2) + v1[1].SetUint64(3) + v1[2].SetUint64(1) + + b, err := v1.MarshalBinary() + assert.NoError(err) + + var v2, v3 Vector + + err = v2.UnmarshalBinary(b) + assert.NoError(err) + + err = v3.unmarshalBinaryAsync(b) + assert.NoError(err) + + assert.True(reflect.DeepEqual(v1, v2)) + assert.True(reflect.DeepEqual(v3, v2)) +} + +func TestVectorEmptyRoundTrip(t *testing.T) { + assert := require.New(t) + + v1 := make(Vector, 0) + + b, err := v1.MarshalBinary() + assert.NoError(err) + + var v2, v3 Vector + + err = v2.UnmarshalBinary(b) + assert.NoError(err) + + err = v3.unmarshalBinaryAsync(b) + assert.NoError(err) + + assert.True(reflect.DeepEqual(v1, v2)) + assert.True(reflect.DeepEqual(v3, v2)) +} + +func (vector *Vector) unmarshalBinaryAsync(data []byte) error { + r := bytes.NewReader(data) + _, err, chErr := vector.AsyncReadFrom(r) + if err != nil { + return err + } + return <-chErr +} + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := 
make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + 
scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). + Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). + Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) 
+ } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= (qElement[3] + 1) + } + + for !mixer.smallerThanModulus() { + mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= (qElement[3] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/grumpkin/g1.go b/ecc/grumpkin/g1.go new file mode 100644 index 0000000000..bfb26fd5fd --- /dev/null +++ b/ecc/grumpkin/g1.go @@ -0,0 +1,1181 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package grumpkin + +import ( + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fp" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" + "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" +) + +// G1Affine is a point in affine coordinates (x,y) +type G1Affine struct { + X, Y fp.Element +} + +// G1Jac is a point in Jacobian coordinates (x=X/Z², y=Y/Z³) +type G1Jac struct { + X, Y, Z fp.Element +} + +// g1JacExtended is a point in extended Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +type g1JacExtended struct { + X, Y, ZZ, ZZZ fp.Element +} + +// ------------------------------------------------------------------------------------------------- +// Affine coordinates + +// Set sets p to a in affine coordinates. +func (p *G1Affine) Set(a *G1Affine) *G1Affine { + p.X, p.Y = a.X, a.Y + return p +} + +// SetInfinity sets p to the infinity point, which is encoded as (0,0). +// N.B.: (0,0) is never on the curve for j=0 curves (Y²=X³+B). +func (p *G1Affine) SetInfinity() *G1Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + +// ScalarMultiplication computes and returns p = [s]a +// where p and a are affine points. +func (p *G1Affine) ScalarMultiplication(a *G1Affine, s *big.Int) *G1Affine { + var _p G1Jac + _p.FromAffine(a) + _p.mulGLV(&_p, s) + p.FromJacobian(&_p) + return p +} + +// ScalarMultiplicationBase computes and returns p = [s]g +// where g is the affine point generating the prime subgroup. +func (p *G1Affine) ScalarMultiplicationBase(s *big.Int) *G1Affine { + var _p G1Jac + _p.mulGLV(&g1Gen, s) + p.FromJacobian(&_p) + return p +} + +// Add adds two points in affine coordinates. +// It uses the Jacobian addition with a.Z=b.Z=1 and converts the result to affine coordinates. 
+// +// https://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#addition-mmadd-2007-bl +func (p *G1Affine) Add(a, b *G1Affine) *G1Affine { + var q G1Jac + // a is infinity, return b + if a.IsInfinity() { + p.Set(b) + return p + } + // b is infinity, return a + if b.IsInfinity() { + p.Set(a) + return p + } + if a.X.Equal(&b.X) { + // if b == a, we double instead + if a.Y.Equal(&b.Y) { + q.DoubleMixed(a) + return p.FromJacobian(&q) + } else { + // if b == -a, we return 0 + return p.SetInfinity() + } + } + var H, HH, I, J, r, V fp.Element + H.Sub(&b.X, &a.X) + HH.Square(&H) + I.Double(&HH).Double(&I) + J.Mul(&H, &I) + r.Sub(&b.Y, &a.Y) + r.Double(&r) + V.Mul(&a.X, &I) + q.X.Square(&r). + Sub(&q.X, &J). + Sub(&q.X, &V). + Sub(&q.X, &V) + q.Y.Sub(&V, &q.X). + Mul(&q.Y, &r) + J.Mul(&a.Y, &J).Double(&J) + q.Y.Sub(&q.Y, &J) + q.Z.Double(&H) + + return p.FromJacobian(&q) +} + +// Double doubles a point in affine coordinates. +// It converts the point to Jacobian coordinates, doubles it using Jacobian +// addition with a.Z=1, and converts it back to affine coordinates. +// +// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#doubling-mdbl-2007-bl +func (p *G1Affine) Double(a *G1Affine) *G1Affine { + var q G1Jac + q.FromAffine(a) + q.DoubleMixed(a) + p.FromJacobian(&q) + return p +} + +// Sub subtracts two points in affine coordinates. +// It uses a similar approach to Add, but negates the second point before adding. +func (p *G1Affine) Sub(a, b *G1Affine) *G1Affine { + var bneg G1Affine + bneg.Neg(b) + p.Add(a, &bneg) + return p +} + +// Equal tests if two points in affine coordinates are equal. +func (p *G1Affine) Equal(a *G1Affine) bool { + return p.X.Equal(&a.X) && p.Y.Equal(&a.Y) +} + +// Neg sets p to the affine negative point -a = (a.X, -a.Y). +func (p *G1Affine) Neg(a *G1Affine) *G1Affine { + p.X = a.X + p.Y.Neg(&a.Y) + return p +} + +// FromJacobian converts a point p1 from Jacobian to affine coordinates. 
+func (p *G1Affine) FromJacobian(p1 *G1Jac) *G1Affine { + + var a, b fp.Element + + if p1.Z.IsZero() { + p.X.SetZero() + p.Y.SetZero() + return p + } + + a.Inverse(&p1.Z) + b.Square(&a) + p.X.Mul(&p1.X, &b) + p.Y.Mul(&p1.Y, &b).Mul(&p.Y, &a) + + return p +} + +// String returns the string representation E(x,y) of the affine point p or "O" if it is infinity. +func (p *G1Affine) String() string { + if p.IsInfinity() { + return "O" + } + return "E([" + p.X.String() + "," + p.Y.String() + "])" +} + +// IsInfinity checks if the affine point p is infinity, which is encoded as (0,0). +// N.B.: (0,0) is never on the curve for j=0 curves (Y²=X³+B). +func (p *G1Affine) IsInfinity() bool { + return p.X.IsZero() && p.Y.IsZero() +} + +// IsOnCurve returns true if the affine point p in on the curve. +func (p *G1Affine) IsOnCurve() bool { + var point G1Jac + point.FromAffine(p) + return point.IsOnCurve() // call this function to handle infinity point +} + +// IsInSubGroup returns true if the affine point p is in the correct subgroup, false otherwise. +func (p *G1Affine) IsInSubGroup() bool { + var _p G1Jac + _p.FromAffine(p) + return _p.IsInSubGroup() +} + +// ------------------------------------------------------------------------------------------------- +// Jacobian coordinates + +// Set sets p to a in Jacobian coordinates. +func (p *G1Jac) Set(q *G1Jac) *G1Jac { + p.X, p.Y, p.Z = q.X, q.Y, q.Z + return p +} + +// Equal tests if two points in Jacobian coordinates are equal. +func (p *G1Jac) Equal(q *G1Jac) bool { + // If one point is infinity, the other must also be infinity. + if p.Z.IsZero() { + return q.Z.IsZero() + } + // If the other point is infinity, return false since we can't + // the following checks would be incorrect. 
+ if q.Z.IsZero() { + return false + } + + var pZSquare, aZSquare fp.Element + pZSquare.Square(&p.Z) + aZSquare.Square(&q.Z) + + var lhs, rhs fp.Element + lhs.Mul(&p.X, &aZSquare) + rhs.Mul(&q.X, &pZSquare) + if !lhs.Equal(&rhs) { + return false + } + lhs.Mul(&p.Y, &aZSquare).Mul(&lhs, &q.Z) + rhs.Mul(&q.Y, &pZSquare).Mul(&rhs, &p.Z) + + return lhs.Equal(&rhs) +} + +// Neg sets p to the Jacobian negative point -q = (q.X, -q.Y, q.Z). +func (p *G1Jac) Neg(q *G1Jac) *G1Jac { + *p = *q + p.Y.Neg(&q.Y) + return p +} + +// AddAssign sets p to p+a in Jacobian coordinates. +// +// https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#addition-add-2007-bl +func (p *G1Jac) AddAssign(q *G1Jac) *G1Jac { + + // p is infinity, return q + if p.Z.IsZero() { + p.Set(q) + return p + } + + // q is infinity, return p + if q.Z.IsZero() { + return p + } + + var Z1Z1, Z2Z2, U1, U2, S1, S2, H, I, J, r, V fp.Element + Z1Z1.Square(&q.Z) + Z2Z2.Square(&p.Z) + U1.Mul(&q.X, &Z2Z2) + U2.Mul(&p.X, &Z1Z1) + S1.Mul(&q.Y, &p.Z). + Mul(&S1, &Z2Z2) + S2.Mul(&p.Y, &q.Z). + Mul(&S2, &Z1Z1) + + // if p == q, we double instead + if U1.Equal(&U2) && S1.Equal(&S2) { + return p.DoubleAssign() + } + + H.Sub(&U2, &U1) + I.Double(&H). + Square(&I) + J.Mul(&H, &I) + r.Sub(&S2, &S1).Double(&r) + V.Mul(&U1, &I) + p.X.Square(&r). + Sub(&p.X, &J). + Sub(&p.X, &V). + Sub(&p.X, &V) + p.Y.Sub(&V, &p.X). + Mul(&p.Y, &r) + S1.Mul(&S1, &J).Double(&S1) + p.Y.Sub(&p.Y, &S1) + p.Z.Add(&p.Z, &q.Z) + p.Z.Square(&p.Z). + Sub(&p.Z, &Z1Z1). + Sub(&p.Z, &Z2Z2). + Mul(&p.Z, &H) + + return p +} + +// SubAssign sets p to p-a in Jacobian coordinates. +// It uses a similar approach to AddAssign, but negates the point a before adding. +func (p *G1Jac) SubAssign(q *G1Jac) *G1Jac { + var tmp G1Jac + tmp.Set(q) + tmp.Y.Neg(&tmp.Y) + p.AddAssign(&tmp) + return p +} + +// Double sets p to [2]q in Jacobian coordinates. 
+// +// https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#doubling-dbl-2007-bl +func (p *G1Jac) DoubleMixed(a *G1Affine) *G1Jac { + var XX, YY, YYYY, S, M, T fp.Element + XX.Square(&a.X) + YY.Square(&a.Y) + YYYY.Square(&YY) + S.Add(&a.X, &YY). + Square(&S). + Sub(&S, &XX). + Sub(&S, &YYYY). + Double(&S) + M.Double(&XX). + Add(&M, &XX) // -> + A, but A=0 here + T.Square(&M). + Sub(&T, &S). + Sub(&T, &S) + p.X.Set(&T) + p.Y.Sub(&S, &T). + Mul(&p.Y, &M) + YYYY.Double(&YYYY). + Double(&YYYY). + Double(&YYYY) + p.Y.Sub(&p.Y, &YYYY) + p.Z.Double(&a.Y) + + return p +} + +// AddMixed sets p to p+a in Jacobian coordinates, where a.Z = 1. +// +// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#addition-madd-2007-bl +func (p *G1Jac) AddMixed(a *G1Affine) *G1Jac { + + //if a is infinity return p + if a.IsInfinity() { + return p + } + // p is infinity, return a + if p.Z.IsZero() { + p.X = a.X + p.Y = a.Y + p.Z.SetOne() + return p + } + + var Z1Z1, U2, S2, H, HH, I, J, r, V fp.Element + Z1Z1.Square(&p.Z) + U2.Mul(&a.X, &Z1Z1) + S2.Mul(&a.Y, &p.Z). + Mul(&S2, &Z1Z1) + + // if p == a, we double instead + if U2.Equal(&p.X) && S2.Equal(&p.Y) { + return p.DoubleMixed(a) + } + + H.Sub(&U2, &p.X) + HH.Square(&H) + I.Double(&HH).Double(&I) + J.Mul(&H, &I) + r.Sub(&S2, &p.Y).Double(&r) + V.Mul(&p.X, &I) + p.X.Square(&r). + Sub(&p.X, &J). + Sub(&p.X, &V). + Sub(&p.X, &V) + J.Mul(&J, &p.Y).Double(&J) + p.Y.Sub(&V, &p.X). + Mul(&p.Y, &r) + p.Y.Sub(&p.Y, &J) + p.Z.Add(&p.Z, &H) + p.Z.Square(&p.Z). + Sub(&p.Z, &Z1Z1). + Sub(&p.Z, &HH) + + return p +} + +// Double sets p to [2]q in Jacobian coordinates. +// +// https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#doubling-dbl-2007-bl +func (p *G1Jac) Double(q *G1Jac) *G1Jac { + p.Set(q) + p.DoubleAssign() + return p +} + +// DoubleAssign doubles p in Jacobian coordinates. 
//
// https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#doubling-dbl-2007-bl
func (p *G1Jac) DoubleAssign() *G1Jac {

	var XX, YY, YYYY, ZZ, S, M, T fp.Element

	XX.Square(&p.X)
	YY.Square(&p.Y)
	YYYY.Square(&YY)
	ZZ.Square(&p.Z)
	S.Add(&p.X, &YY)
	S.Square(&S).
		Sub(&S, &XX).
		Sub(&S, &YYYY).
		Double(&S)
	M.Double(&XX).Add(&M, &XX) // -> + A, but A=0 for this curve
	p.Z.Add(&p.Z, &p.Y).
		Square(&p.Z).
		Sub(&p.Z, &YY).
		Sub(&p.Z, &ZZ)
	T.Square(&M)
	p.X = T
	T.Double(&S)
	p.X.Sub(&p.X, &T)
	p.Y.Sub(&S, &p.X).
		Mul(&p.Y, &M)
	YYYY.Double(&YYYY).Double(&YYYY).Double(&YYYY) // 8·YYYY
	p.Y.Sub(&p.Y, &YYYY)

	return p
}

// ScalarMultiplication computes and returns p = [s]q
// where p and q are Jacobian points.
// using the GLV technique.
// see https://www.iacr.org/archive/crypto2001/21390189.pdf
func (p *G1Jac) ScalarMultiplication(q *G1Jac, s *big.Int) *G1Jac {
	return p.mulGLV(q, s)
}

// ScalarMultiplicationBase computes and returns p = [s]g
// where g is the prime subgroup generator.
func (p *G1Jac) ScalarMultiplicationBase(s *big.Int) *G1Jac {
	return p.mulGLV(&g1Gen, s)

}

// String converts p to affine coordinates and returns its string representation E(x,y) or "O" if it is infinity.
func (p *G1Jac) String() string {
	_p := G1Affine{}
	_p.FromJacobian(p)
	return _p.String()
}

// FromAffine converts a point a from affine to Jacobian coordinates.
func (p *G1Jac) FromAffine(a *G1Affine) *G1Jac {
	// affine infinity maps to the Jacobian representative (1,1,0)
	if a.IsInfinity() {
		p.Z.SetZero()
		p.X.SetOne()
		p.Y.SetOne()
		return p
	}
	p.Z.SetOne()
	p.X.Set(&a.X)
	p.Y.Set(&a.Y)
	return p
}

// IsOnCurve returns true if the Jacobian point p is on the curve.
func (p *G1Jac) IsOnCurve() bool {
	// checks Y² == X³ + b·Z⁶, the Jacobian form of y² = x³ + b (a = 0)
	var left, right, tmp, ZZ fp.Element
	left.Square(&p.Y)
	right.Square(&p.X).Mul(&right, &p.X)
	ZZ.Square(&p.Z)
	tmp.Square(&ZZ).Mul(&tmp, &ZZ)
	tmp.Mul(&tmp, &bCurveCoeff)
	right.Add(&right, &tmp)
	return left.Equal(&right)
}

// IsInSubGroup returns true if p is on the r-torsion, false otherwise.
// the curve is of prime order i.e. E(𝔽p) is the full group
// so we just check that the point is on the curve.
func (p *G1Jac) IsInSubGroup() bool {

	return p.IsOnCurve()

}

// mulWindowed computes the 2-bits windowed double-and-add scalar
// multiplication p=[s]q in Jacobian coordinates.
func (p *G1Jac) mulWindowed(q *G1Jac, s *big.Int) *G1Jac {

	var res G1Jac
	var ops [3]G1Jac

	// ops[c-1] = [c]q for c = 1, 2, 3; q is negated first when s < 0
	ops[0].Set(q)
	if s.Sign() == -1 {
		ops[0].Neg(&ops[0])
	}
	res.Set(&g1Infinity)
	ops[1].Double(&ops[0])
	ops[2].Set(&ops[0]).AddAssign(&ops[1])

	// scan |s| big-endian, two bits at a time (4 windows per byte)
	b := s.Bytes()
	for i := range b {
		w := b[i]
		mask := byte(0xc0)
		for j := 0; j < 4; j++ {
			res.DoubleAssign().DoubleAssign()
			c := (w & mask) >> (6 - 2*j)
			if c != 0 {
				res.AddAssign(&ops[c-1])
			}
			mask = mask >> 2
		}
	}
	p.Set(&res)

	return p

}

// phi sets p to ϕ(q) where ϕ: (x,y) → (w x,y),
// where w is a third root of unity.
func (p *G1Jac) phi(q *G1Jac) *G1Jac {
	p.Set(q)
	p.X.Mul(&p.X, &thirdRootOneG1)
	return p
}

// mulGLV computes the scalar multiplication using a windowed-GLV method
//
// see https://www.iacr.org/archive/crypto2001/21390189.pdf
func (p *G1Jac) mulGLV(q *G1Jac, s *big.Int) *G1Jac {

	var table [15]G1Jac
	var res G1Jac
	var k1, k2 fr.Element

	res.Set(&g1Infinity)

	// table[b3b2b1b0-1] = b3b2 ⋅ ϕ(q) + b1b0*q
	table[0].Set(q)
	table[3].phi(q)

	// split the scalar, modifies ±q, ϕ(q) accordingly
	k := ecc.SplitScalar(s, &glvBasis)

	if k[0].Sign() == -1 {
		k[0].Neg(&k[0])
		table[0].Neg(&table[0])
	}
	if k[1].Sign() == -1 {
		k[1].Neg(&k[1])
		table[3].Neg(&table[3])
	}

	// precompute table (2 bits sliding window)
	// table[b3b2b1b0-1] = b3b2 ⋅ ϕ(q) + b1b0 ⋅ q if b3b2b1b0 != 0
	table[1].Double(&table[0])
	table[2].Set(&table[1]).AddAssign(&table[0])
	table[4].Set(&table[3]).AddAssign(&table[0])
	table[5].Set(&table[3]).AddAssign(&table[1])
	table[6].Set(&table[3]).AddAssign(&table[2])
	table[7].Double(&table[3])
	table[8].Set(&table[7]).AddAssign(&table[0])
	table[9].Set(&table[7]).AddAssign(&table[1])
	table[10].Set(&table[7]).AddAssign(&table[2])
	table[11].Set(&table[7]).AddAssign(&table[3])
	table[12].Set(&table[11]).AddAssign(&table[0])
	table[13].Set(&table[11]).AddAssign(&table[1])
	table[14].Set(&table[11]).AddAssign(&table[2])

	// bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max
	// this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift
	k1 = k1.SetBigInt(&k[0]).Bits()
	k2 = k2.SetBigInt(&k[1]).Bits()

	// we don't target constant-timeness so we check first if we increase the bounds or not
	maxBit := k1.BitLen()
	if k2.BitLen() > maxBit {
		maxBit = k2.BitLen()
	}
	hiWordIndex := (maxBit - 1) / 64

	// loop starts from len(k1)/2 or len(k1)/2+1 due to the bounds
	for i := hiWordIndex; i >= 0; i-- {
		mask := uint64(3) << 62
		// consume both sub-scalars two bits at a time (32 windows per 64-bit word)
		for j := 0; j < 32; j++ {
			res.Double(&res).Double(&res)
			b1 := (k1[i] & mask) >> (62 - 2*j)
			b2 := (k2[i] & mask) >> (62 - 2*j)
			if b1|b2 != 0 {
				s := (b2<<2 | b1)
				res.AddAssign(&table[s-1])
			}
			mask = mask >> 2
		}
	}

	p.Set(&res)
	return p
}

// JointScalarMultiplication computes [s1]a1+[s2]a2 using Straus-Shamir technique
// where a1 and a2 are affine points.
func (p *G1Jac) JointScalarMultiplication(a1, a2 *G1Affine, s1, s2 *big.Int) *G1Jac {

	var res, p1, p2 G1Jac
	res.Set(&g1Infinity)
	p1.FromAffine(a1)
	p2.FromAffine(a2)

	var table [15]G1Jac

	// normalize the scalars to non-negative values, negating the points instead
	var k1, k2 big.Int
	if s1.Sign() == -1 {
		k1.Neg(s1)
		table[0].Neg(&p1)
	} else {
		k1.Set(s1)
		table[0].Set(&p1)
	}
	if s2.Sign() == -1 {
		k2.Neg(s2)
		table[3].Neg(&p2)
	} else {
		k2.Set(s2)
		table[3].Set(&p2)
	}

	// precompute table (2 bits sliding window)
	// table[b3b2b1b0-1] = b3b2 ⋅ p2 + b1b0 ⋅ p1
	table[1].Double(&table[0])
	table[2].Set(&table[1]).AddAssign(&table[0])
	table[4].Set(&table[3]).AddAssign(&table[0])
	table[5].Set(&table[3]).AddAssign(&table[1])
	table[6].Set(&table[3]).AddAssign(&table[2])
	table[7].Double(&table[3])
	table[8].Set(&table[7]).AddAssign(&table[0])
	table[9].Set(&table[7]).AddAssign(&table[1])
	table[10].Set(&table[7]).AddAssign(&table[2])
	table[11].Set(&table[7]).AddAssign(&table[3])
	table[12].Set(&table[11]).AddAssign(&table[0])
	table[13].Set(&table[11]).AddAssign(&table[1])
	table[14].Set(&table[11]).AddAssign(&table[2])

	var s [2]fr.Element
	s[0] = s[0].SetBigInt(&k1).Bits()
	s[1] = s[1].SetBigInt(&k2).Bits()

	maxBit := k1.BitLen()
	if k2.BitLen() > maxBit {
		maxBit = k2.BitLen()
	}
	hiWordIndex := (maxBit - 1) / 64

	for i := hiWordIndex; i >= 0; i-- {
		mask := uint64(3) << 62
		for j := 0; j < 32; j++ {
			res.Double(&res).Double(&res)
			b1 := (s[0][i] & mask) >> (62 - 2*j)
			b2 := (s[1][i] & mask) >> (62 - 2*j)
			if b1|b2 != 0 {
				s := (b2<<2 | b1)
				res.AddAssign(&table[s-1])
			}
			mask = mask >> 2
		}
	}

	p.Set(&res)
	return p

}

// JointScalarMultiplicationBase computes [s1]g+[s2]a using Straus-Shamir technique
// where g is the prime subgroup generator.
func (p *G1Jac) JointScalarMultiplicationBase(a *G1Affine, s1, s2 *big.Int) *G1Jac {
	return p.JointScalarMultiplication(&g1GenAff, a, s1, s2)

}

// -------------------------------------------------------------------------------------------------
// extended Jacobian coordinates

// Set sets p to a in extended Jacobian coordinates.
func (p *g1JacExtended) Set(q *g1JacExtended) *g1JacExtended {
	p.X, p.Y, p.ZZ, p.ZZZ = q.X, q.Y, q.ZZ, q.ZZZ
	return p
}

// SetInfinity sets p to the infinity point (1,1,0,0).
func (p *g1JacExtended) SetInfinity() *g1JacExtended {
	p.X.SetOne()
	p.Y.SetOne()
	p.ZZ = fp.Element{}
	p.ZZZ = fp.Element{}
	return p
}

// IsInfinity checks if the p is infinity, i.e. p.ZZ=0.
func (p *g1JacExtended) IsInfinity() bool {
	return p.ZZ.IsZero()
}

// fromJacExtended converts an extended Jacobian point to an affine point.
func (p *G1Affine) fromJacExtended(q *g1JacExtended) *G1Affine {
	// infinity is encoded as (0,0) in affine
	if q.ZZ.IsZero() {
		p.X = fp.Element{}
		p.Y = fp.Element{}
		return p
	}
	// x = X/ZZ, y = Y/ZZZ
	p.X.Inverse(&q.ZZ).Mul(&p.X, &q.X)
	p.Y.Inverse(&q.ZZZ).Mul(&p.Y, &q.Y)
	return p
}

// fromJacExtended converts an extended Jacobian point to a Jacobian point.
func (p *G1Jac) fromJacExtended(q *g1JacExtended) *G1Jac {
	if q.ZZ.IsZero() {
		p.Set(&g1Infinity)
		return p
	}
	// with Z = ZZZ: X_jac = X·ZZ², Y_jac = Y·ZZZ²
	p.X.Mul(&q.ZZ, &q.X).Mul(&p.X, &q.ZZ)
	p.Y.Mul(&q.ZZZ, &q.Y).Mul(&p.Y, &q.ZZZ)
	p.Z.Set(&q.ZZZ)
	return p
}

// unsafeFromJacExtended converts an extended Jacobian point, distinct from Infinity, to a Jacobian point.
func (p *G1Jac) unsafeFromJacExtended(q *g1JacExtended) *G1Jac {
	p.X.Square(&q.ZZ).Mul(&p.X, &q.X)
	p.Y.Square(&q.ZZZ).Mul(&p.Y, &q.Y)
	p.Z = q.ZZZ
	return p
}

// add sets p to p+q in extended Jacobian coordinates.
//
// https://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#addition-add-2008-s
func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended {
	//if q is infinity return p
	if q.ZZ.IsZero() {
		return p
	}
	// p is infinity, return q
	if p.ZZ.IsZero() {
		p.Set(q)
		return p
	}

	var A, B, U1, U2, S1, S2 fp.Element

	// p2: q, p1: p
	U2.Mul(&q.X, &p.ZZ)
	U1.Mul(&p.X, &q.ZZ)
	A.Sub(&U2, &U1)
	S2.Mul(&q.Y, &p.ZZZ)
	S1.Mul(&p.Y, &q.ZZZ)
	B.Sub(&S2, &S1)

	// A == 0: same x; same y means p == q (double), otherwise p == -q (infinity)
	if A.IsZero() {
		if B.IsZero() {
			return p.double(q)

		}
		p.ZZ = fp.Element{}
		p.ZZZ = fp.Element{}
		return p
	}

	var P, R, PP, PPP, Q, V fp.Element
	P.Sub(&U2, &U1)
	R.Sub(&S2, &S1)
	PP.Square(&P)
	PPP.Mul(&P, &PP)
	Q.Mul(&U1, &PP)
	V.Mul(&S1, &PPP)

	p.X.Square(&R).
		Sub(&p.X, &PPP).
		Sub(&p.X, &Q).
		Sub(&p.X, &Q)
	p.Y.Sub(&Q, &p.X).
		Mul(&p.Y, &R).
		Sub(&p.Y, &V)
	p.ZZ.Mul(&p.ZZ, &q.ZZ).
		Mul(&p.ZZ, &PP)
	p.ZZZ.Mul(&p.ZZZ, &q.ZZZ).
		Mul(&p.ZZZ, &PPP)

	return p
}

// double sets p to [2]q in Jacobian extended coordinates.
//
// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1
// N.B.: since we consider any point on Z=0 as the point at infinity
// this doubling formula works for infinity points as well.
func (p *g1JacExtended) double(q *g1JacExtended) *g1JacExtended {
	var U, V, W, S, XX, M fp.Element

	U.Double(&q.Y)
	V.Square(&U)
	W.Mul(&U, &V)
	S.Mul(&q.X, &V)
	XX.Square(&q.X)
	M.Double(&XX).
		Add(&M, &XX) // -> + A, but A=0 here
	U.Mul(&W, &q.Y)

	p.X.Square(&M).
		Sub(&p.X, &S).
		Sub(&p.X, &S)
	p.Y.Sub(&S, &p.X).
		Mul(&p.Y, &M).
		Sub(&p.Y, &U)
	p.ZZ.Mul(&V, &q.ZZ)
	p.ZZZ.Mul(&W, &q.ZZZ)

	return p
}

// addMixed sets p to p+a in extended Jacobian coordinates, where a.ZZ=1.
//
// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#addition-madd-2008-s
func (p *g1JacExtended) addMixed(a *G1Affine) *g1JacExtended {

	//if a is infinity return p
	if a.IsInfinity() {
		return p
	}
	// p is infinity, return a
	if p.ZZ.IsZero() {
		p.X = a.X
		p.Y = a.Y
		p.ZZ.SetOne()
		p.ZZZ.SetOne()
		return p
	}

	var P, R fp.Element

	// p2: a, p1: p
	P.Mul(&a.X, &p.ZZ)
	P.Sub(&P, &p.X)

	R.Mul(&a.Y, &p.ZZZ)
	R.Sub(&R, &p.Y)

	// P == 0: same x; same y means p == a (double), otherwise p == -a (infinity)
	if P.IsZero() {
		if R.IsZero() {
			return p.doubleMixed(a)

		}
		p.ZZ = fp.Element{}
		p.ZZZ = fp.Element{}
		return p
	}

	var PP, PPP, Q, Q2, RR, X3, Y3 fp.Element

	PP.Square(&P)
	PPP.Mul(&P, &PP)
	Q.Mul(&p.X, &PP)
	RR.Square(&R)
	X3.Sub(&RR, &PPP)
	Q2.Double(&Q)
	p.X.Sub(&X3, &Q2)
	Y3.Sub(&Q, &p.X).Mul(&Y3, &R)
	R.Mul(&p.Y, &PPP)
	p.Y.Sub(&Y3, &R)
	p.ZZ.Mul(&p.ZZ, &PP)
	p.ZZZ.Mul(&p.ZZZ, &PPP)

	return p

}

// subMixed works the same as addMixed, but negates a.Y.
//
// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#addition-madd-2008-s
func (p *g1JacExtended) subMixed(a *G1Affine) *g1JacExtended {

	//if a is infinity return p
	if a.IsInfinity() {
		return p
	}
	// p is infinity, return -a
	if p.ZZ.IsZero() {
		p.X = a.X
		p.Y.Neg(&a.Y)
		p.ZZ.SetOne()
		p.ZZZ.SetOne()
		return p
	}

	var P, R fp.Element

	// p2: a, p1: p
	P.Mul(&a.X, &p.ZZ)
	P.Sub(&P, &p.X)

	R.Mul(&a.Y, &p.ZZZ)
	R.Neg(&R)
	R.Sub(&R, &p.Y)

	// P == 0: same x; R == 0 means p == -a (double the negation), otherwise p == a (infinity)
	if P.IsZero() {
		if R.IsZero() {
			return p.doubleNegMixed(a)

		}
		p.ZZ = fp.Element{}
		p.ZZZ = fp.Element{}
		return p
	}

	var PP, PPP, Q, Q2, RR, X3, Y3 fp.Element

	PP.Square(&P)
	PPP.Mul(&P, &PP)
	Q.Mul(&p.X, &PP)
	RR.Square(&R)
	X3.Sub(&RR, &PPP)
	Q2.Double(&Q)
	p.X.Sub(&X3, &Q2)
	Y3.Sub(&Q, &p.X).Mul(&Y3, &R)
	R.Mul(&p.Y, &PPP)
	p.Y.Sub(&Y3, &R)
	p.ZZ.Mul(&p.ZZ, &PP)
	p.ZZZ.Mul(&p.ZZZ, &PPP)

	return p

}

// doubleNegMixed works the same as doubleMixed, but negates a.Y.
func (p *g1JacExtended) doubleNegMixed(a *G1Affine) *g1JacExtended {

	var U, V, W, S, XX, M, S2, L fp.Element

	U.Double(&a.Y)
	U.Neg(&U)
	V.Square(&U)
	W.Mul(&U, &V)
	S.Mul(&a.X, &V)
	XX.Square(&a.X)
	M.Double(&XX).
		Add(&M, &XX) // -> + A, but A=0 here
	S2.Double(&S)
	L.Mul(&W, &a.Y)

	p.X.Square(&M).
		Sub(&p.X, &S2)
	p.Y.Sub(&S, &p.X).
		Mul(&p.Y, &M).
		Add(&p.Y, &L)
	p.ZZ.Set(&V)
	p.ZZZ.Set(&W)

	return p
}

// doubleMixed sets p to [2]a in Jacobian extended coordinates, where a.ZZ=1.
//
// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1
func (p *g1JacExtended) doubleMixed(a *G1Affine) *g1JacExtended {

	var U, V, W, S, XX, M, S2, L fp.Element

	U.Double(&a.Y)
	V.Square(&U)
	W.Mul(&U, &V)
	S.Mul(&a.X, &V)
	XX.Square(&a.X)
	M.Double(&XX).
		Add(&M, &XX) // -> + A, but A=0 here
	S2.Double(&S)
	L.Mul(&W, &a.Y)

	p.X.Square(&M).
		Sub(&p.X, &S2)
	p.Y.Sub(&S, &p.X).
		Mul(&p.Y, &M).
		Sub(&p.Y, &L)
	p.ZZ.Set(&V)
	p.ZZZ.Set(&W)

	return p
}

// BatchJacobianToAffineG1 converts points in Jacobian coordinates to Affine coordinates
// performing a single field inversion using the Montgomery batch inversion trick.
func BatchJacobianToAffineG1(points []G1Jac) []G1Affine {
	result := make([]G1Affine, len(points))
	zeroes := make([]bool, len(points))
	accumulator := fp.One()

	// batch invert all points[].Z coordinates with Montgomery batch inversion trick
	// (stores points[].Z^-1 in result[i].X to avoid allocating a slice of fp.Elements)
	for i := 0; i < len(points); i++ {
		if points[i].Z.IsZero() {
			zeroes[i] = true
			continue
		}
		// result[i].X holds the running product of all previous non-zero Zs
		result[i].X = accumulator
		accumulator.Mul(&accumulator, &points[i].Z)
	}

	var accInverse fp.Element
	accInverse.Inverse(&accumulator)

	// backward pass: peel off one Z at a time so result[i].X becomes Z_i^-1
	for i := len(points) - 1; i >= 0; i-- {
		if zeroes[i] {
			// do nothing, (X=0, Y=0) is infinity point in affine
			continue
		}
		result[i].X.Mul(&result[i].X, &accInverse)
		accInverse.Mul(&accInverse, &points[i].Z)
	}

	// batch convert to affine.
	parallel.Execute(len(points), func(start, end int) {
		for i := start; i < end; i++ {
			if zeroes[i] {
				// do nothing, (X=0, Y=0) is infinity point in affine
				continue
			}
			// a = Z^-1, b = Z^-2; x = X·Z^-2, y = Y·Z^-3
			var a, b fp.Element
			a = result[i].X
			b.Square(&a)
			result[i].X.Mul(&points[i].X, &b)
			result[i].Y.Mul(&points[i].Y, &b).
				Mul(&result[i].Y, &a)
		}
	})

	return result
}

// BatchScalarMultiplicationG1 multiplies the same base by all scalars
// and return resulting points in affine coordinates
// uses a simple windowed-NAF-like multiplication algorithm.
func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affine {
	// approximate cost in group ops is
	// cost = 2^{c-1} + n(scalar.nbBits+nbChunks)

	nbPoints := uint64(len(scalars))
	min := ^uint64(0)
	bestC := 0
	// pick the window size c ∈ [2,16] minimizing the estimated cost
	for c := 2; c <= 16; c++ {
		cost := uint64(1 << (c - 1)) // pre compute the table
		nbChunks := computeNbChunks(uint64(c))
		cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add
		if cost < min {
			min = cost
			bestC = c
		}
	}
	c := uint64(bestC) // window size
	nbChunks := int(computeNbChunks(c))

	// last window may be slightly larger than c; in which case we need to compute one
	// extra element in the baseTable
	maxC := lastC(c)
	if c > maxC {
		maxC = c
	}

	// precompute all powers of base for our window
	// note here that if performance is critical, we can implement as in the msmX methods
	// this allocation to be on the stack
	baseTable := make([]G1Jac, (1 << (maxC - 1)))
	baseTable[0].FromAffine(base)
	for i := 1; i < len(baseTable); i++ {
		baseTable[i] = baseTable[i-1]
		baseTable[i].AddMixed(base)
	}
	// convert our base exp table into affine to use AddMixed
	baseTableAff := BatchJacobianToAffineG1(baseTable)
	toReturn := make([]G1Jac, len(scalars))

	// partition the scalars into digits
	digits, _ := partitionScalars(scalars, c, runtime.NumCPU())

	// for each digit, take value in the base table, double it c time, voilà.
	parallel.Execute(len(scalars), func(start, end int) {
		var p G1Jac
		for i := start; i < end; i++ {
			p.Set(&g1Infinity)
			for chunk := nbChunks - 1; chunk >= 0; chunk-- {
				if chunk != nbChunks-1 {
					for j := uint64(0); j < c; j++ {
						p.DoubleAssign()
					}
				}
				offset := chunk * len(scalars)
				digit := digits[i+offset]

				if digit == 0 {
					continue
				}

				// if msbWindow bit is set, we need to subtract
				if digit&1 == 0 {
					// add
					p.AddMixed(&baseTableAff[(digit>>1)-1])
				} else {
					// sub
					t := baseTableAff[digit>>1]
					t.Neg(&t)
					p.AddMixed(&t)
				}
			}

			// set our result point
			toReturn[i] = p

		}
	})
	toReturnAff := BatchJacobianToAffineG1(toReturn)
	return toReturnAff
}

// batchAddG1Affine adds affine points using the Montgomery batch inversion trick.
// Special cases (doubling, infinity) must be filtered out before this call.
func batchAddG1Affine[TP pG1Affine, TPP ppG1Affine, TC cG1Affine](R *TPP, P *TP, batchSize int) {
	var lambda, lambdain TC

	// from https://docs.zkproof.org/pages/standards/accepted-workshop3/proposal-turbo_plonk.pdf
	// affine point addition formula
	// R(X1, Y1) + P(X2, Y2) = Q(X3, Y3)
	// λ = (Y2 - Y1) / (X2 - X1)
	// X3 = λ² - (X1 + X2)
	// Y3 = λ * (X1 - X3) - Y1

	// first we compute the 1 / (X2 - X1) for all points using Montgomery batch inversion trick

	// X2 - X1
	for j := 0; j < batchSize; j++ {
		lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X)
	}

	// montgomery batch inversion;
	// lambda[0] = 1 / (P[0].X - R[0].X)
	// lambda[1] = 1 / (P[1].X - R[1].X)
	// ...
	{
		var accumulator fp.Element
		lambda[0].SetOne()
		accumulator.Set(&lambdain[0])

		// forward pass: lambda[i] = product of lambdain[0..i-1]
		for i := 1; i < batchSize; i++ {
			lambda[i] = accumulator
			accumulator.Mul(&accumulator, &lambdain[i])
		}

		accumulator.Inverse(&accumulator)

		// backward pass: turn each lambda[i] into 1/lambdain[i]
		for i := batchSize - 1; i > 0; i-- {
			lambda[i].Mul(&lambda[i], &accumulator)
			accumulator.Mul(&accumulator, &lambdain[i])
		}
		lambda[0].Set(&accumulator)
	}

	var t fp.Element
	var Q G1Affine

	for j := 0; j < batchSize; j++ {
		// λ = (Y2 - Y1) / (X2 - X1)
		t.Sub(&(*P)[j].Y, &(*R)[j].Y)
		lambda[j].Mul(&lambda[j], &t)

		// X3 = λ² - (X1 + X2)
		Q.X.Square(&lambda[j])
		Q.X.Sub(&Q.X, &(*R)[j].X)
		Q.X.Sub(&Q.X, &(*P)[j].X)

		// Y3 = λ * (X1 - X3) - Y1
		t.Sub(&(*R)[j].X, &Q.X)
		Q.Y.Mul(&lambda[j], &t)
		Q.Y.Sub(&Q.Y, &(*R)[j].Y)

		(*R)[j].Set(&Q)
	}
}
diff --git a/ecc/grumpkin/g1_test.go b/ecc/grumpkin/g1_test.go
new file mode 100644
index 0000000000..bb1111e979
--- /dev/null
+++ b/ecc/grumpkin/g1_test.go
@@ -0,0 +1,827 @@
// Copyright 2020-2025 Consensys Software Inc.
// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.
// Code generated by consensys/gnark-crypto DO NOT EDIT

package grumpkin

import (
	"fmt"
	"math/big"
	"math/rand/v2"
	"testing"

	"github.com/consensys/gnark-crypto/ecc/grumpkin/fp"

	"github.com/consensys/gnark-crypto/ecc/grumpkin/fr"
	"github.com/leanovate/gopter"
	"github.com/leanovate/gopter/prop"
)

// TestG1AffineEndomorphism checks the GLV endomorphism ϕ: ϕ(P) = [λ]P and
// ϕ² + ϕ + 1 = 0 on the r-torsion.
func TestG1AffineEndomorphism(t *testing.T) {
	t.Parallel()
	parameters := gopter.DefaultTestParameters()
	if testing.Short() {
		parameters.MinSuccessfulTests = nbFuzzShort
	} else {
		parameters.MinSuccessfulTests = nbFuzz
	}

	properties := gopter.NewProperties(parameters)

	properties.Property("[GRUMPKIN] check that phi(P) = lambdaGLV * P", prop.ForAll(
		func(a fp.Element) bool {
			var p, res1, res2 G1Jac
			g := MapToG1(a)
			p.FromAffine(&g)
			res1.phi(&p)
			res2.mulWindowed(&p, &lambdaGLV)

			return p.IsInSubGroup() && res1.Equal(&res2)
		},
		GenFp(),
	))

	properties.Property("[GRUMPKIN] check that phi^2(P) + phi(P) + P = 0", prop.ForAll(
		func(a fp.Element) bool {
			var p, res, tmp G1Jac
			g := MapToG1(a)
			p.FromAffine(&g)
			tmp.phi(&p)
			res.phi(&tmp).
				AddAssign(&tmp).
				AddAssign(&p)

			// Z == 0 means the sum is the point at infinity
			return res.Z.IsZero()
		},
		GenFp(),
	))

	properties.TestingRun(t, gopter.ConsoleReporter(false))
}

// TestG1AffineIsOnCurve checks curve- and subgroup-membership predicates on
// fuzzed representatives of the generator.
func TestG1AffineIsOnCurve(t *testing.T) {
	t.Parallel()
	parameters := gopter.DefaultTestParameters()
	if testing.Short() {
		parameters.MinSuccessfulTests = nbFuzzShort
	} else {
		parameters.MinSuccessfulTests = nbFuzz
	}

	properties := gopter.NewProperties(parameters)

	properties.Property("[GRUMPKIN] g1Gen (affine) should be on the curve", prop.ForAll(
		func(a fp.Element) bool {
			var op1, op2 G1Affine
			op1.FromJacobian(&g1Gen)
			op2.Set(&op1)
			// corrupting Y must take the point off the curve
			op2.Y.Mul(&op2.Y, &a)
			return op1.IsOnCurve() && !op2.IsOnCurve()
		},
		GenFp(),
	))

	properties.Property("[GRUMPKIN] g1Gen (Jacobian) should be on the curve", prop.ForAll(
		func(a fp.Element) bool {
			var op1, op2, op3 G1Jac
			op1.Set(&g1Gen)
			op3.Set(&g1Gen)

			op2 = fuzzG1Jac(&g1Gen, a)
			op3.Y.Mul(&op3.Y, &a)
			return op1.IsOnCurve() && op2.IsOnCurve() && !op3.IsOnCurve()
		},
		GenFp(),
	))

	properties.Property("[GRUMPKIN] IsInSubGroup and MulBy subgroup order should be the same", prop.ForAll(
		func(a fp.Element) bool {
			var op1, op2 G1Jac
			op1 = fuzzG1Jac(&g1Gen, a)
			_r := fr.Modulus()
			op2.mulWindowed(&op1, _r)
			return op1.IsInSubGroup() && op2.Z.IsZero()
		},
		GenFp(),
	))

	properties.TestingRun(t, gopter.ConsoleReporter(false))
}

// TestG1AffineConversions checks the conversions between affine, Jacobian and
// extended Jacobian representations, including the infinity encodings.
func TestG1AffineConversions(t *testing.T) {
	t.Parallel()
	parameters := gopter.DefaultTestParameters()
	if testing.Short() {
		parameters.MinSuccessfulTests = nbFuzzShort
	} else {
		parameters.MinSuccessfulTests = nbFuzz
	}

	properties := gopter.NewProperties(parameters)

	properties.Property("[GRUMPKIN] Affine representation should be independent of the Jacobian representative", prop.ForAll(
		func(a fp.Element) bool {
			g := fuzzG1Jac(&g1Gen, a)
			var op1 G1Affine
			op1.FromJacobian(&g)
			return op1.X.Equal(&g1Gen.X) && op1.Y.Equal(&g1Gen.Y)
		},
		GenFp(),
	))

	properties.Property("[GRUMPKIN] Affine 
representation should be independent of a Extended Jacobian representative", prop.ForAll(
		func(a fp.Element) bool {
			var g g1JacExtended
			g.X.Set(&g1Gen.X)
			g.Y.Set(&g1Gen.Y)
			g.ZZ.Set(&g1Gen.Z)
			g.ZZZ.Set(&g1Gen.Z)
			gfuzz := fuzzg1JacExtended(&g, a)

			var op1 G1Affine
			op1.fromJacExtended(&gfuzz)
			return op1.X.Equal(&g1Gen.X) && op1.Y.Equal(&g1Gen.Y)
		},
		GenFp(),
	))

	properties.Property("[GRUMPKIN] Jacobian representation should be the same as the affine representative", prop.ForAll(
		func(a fp.Element) bool {
			var g G1Jac
			var op1 G1Affine
			op1.X.Set(&g1Gen.X)
			op1.Y.Set(&g1Gen.Y)

			var one fp.Element
			one.SetOne()

			g.FromAffine(&op1)

			return g.X.Equal(&g1Gen.X) && g.Y.Equal(&g1Gen.Y) && g.Z.Equal(&one)
		},
		GenFp(),
	))

	properties.Property("[GRUMPKIN] Converting affine symbol for infinity to Jacobian should output correct infinity in Jacobian", prop.ForAll(
		func() bool {
			// affine infinity is (0,0); Jacobian infinity is (1,1,0)
			var g G1Affine
			g.X.SetZero()
			g.Y.SetZero()
			var op1 G1Jac
			op1.FromAffine(&g)
			var one, zero fp.Element
			one.SetOne()
			return op1.X.Equal(&one) && op1.Y.Equal(&one) && op1.Z.Equal(&zero)
		},
	))

	properties.Property("[GRUMPKIN] Converting infinity in extended Jacobian to affine should output infinity symbol in Affine", prop.ForAll(
		func() bool {
			// ZZ is left zero, so op1 is the extended-Jacobian infinity
			var g G1Affine
			var op1 g1JacExtended
			var zero fp.Element
			op1.X.Set(&g1Gen.X)
			op1.Y.Set(&g1Gen.Y)
			g.fromJacExtended(&op1)
			return g.X.Equal(&zero) && g.Y.Equal(&zero)
		},
	))

	properties.Property("[GRUMPKIN] Converting infinity in extended Jacobian to Jacobian should output infinity in Jacobian", prop.ForAll(
		func() bool {
			var g G1Jac
			var op1 g1JacExtended
			var zero, one fp.Element
			one.SetOne()
			op1.X.Set(&g1Gen.X)
			op1.Y.Set(&g1Gen.Y)
			g.fromJacExtended(&op1)
			return g.X.Equal(&one) && g.Y.Equal(&one) && g.Z.Equal(&zero)
		},
	))

	properties.Property("[GRUMPKIN] [Jacobian] Two representatives of the same class should be equal", prop.ForAll(
		func(a, b fp.Element) bool {
			op1 := fuzzG1Jac(&g1Gen, a)
			op2 := fuzzG1Jac(&g1Gen, b)
			return op1.Equal(&op2)
		},
		GenFp(),
		GenFp(),
	))
	properties.Property("[GRUMPKIN] BatchJacobianToAffineG1 and FromJacobian should output the same result", prop.ForAll(
		func(a, b fp.Element) bool {
			g1 := fuzzG1Jac(&g1Gen, a)
			g2 := fuzzG1Jac(&g1Gen, b)
			var op1, op2 G1Affine
			op1.FromJacobian(&g1)
			op2.FromJacobian(&g2)
			baseTableAff := BatchJacobianToAffineG1([]G1Jac{g1, g2})
			return op1.Equal(&baseTableAff[0]) && op2.Equal(&baseTableAff[1])
		},
		GenFp(),
		GenFp(),
	))

	properties.TestingRun(t, gopter.ConsoleReporter(false))
}

// TestG1AffineOps checks the group-law operations (add, sub, double, scalar
// multiplications) against each other and against algebraic identities.
func TestG1AffineOps(t *testing.T) {
	t.Parallel()
	parameters := gopter.DefaultTestParameters()
	parameters.MinSuccessfulTests = 10

	properties := gopter.NewProperties(parameters)

	genScalar := GenFr()

	properties.Property("[GRUMPKIN] Add(P,-P) should return the point at infinity", prop.ForAll(
		func(s fr.Element) bool {
			var op1, op2 G1Affine
			var sInt big.Int
			g := g1GenAff
			s.BigInt(&sInt)
			op1.ScalarMultiplication(&g, &sInt)
			op2.Neg(&op1)

			op1.Add(&op1, &op2)
			return op1.IsInfinity()

		},
		GenFr(),
	))

	properties.Property("[GRUMPKIN] Add(P,0) and Add(0,P) should return P", prop.ForAll(
		func(s fr.Element) bool {
			var op1, op2 G1Affine
			var sInt big.Int
			g := g1GenAff
			s.BigInt(&sInt)
			op1.ScalarMultiplication(&g, &sInt)
			op2.SetInfinity()

			op1.Add(&op1, &op2)
			op2.Add(&op2, &op1)
			return op1.Equal(&op2)

		},
		GenFr(),
	))

	properties.Property("[GRUMPKIN] Add should call double when adding the same point", prop.ForAll(
		func(s fr.Element) bool {
			var op1, op2 G1Affine
			var sInt big.Int
			g := g1GenAff
			s.BigInt(&sInt)
			op1.ScalarMultiplication(&g, &sInt)

			op2.Double(&op1)
			op1.Add(&op1, &op1)
			return op1.Equal(&op2)

		},
		GenFr(),
	))

	properties.Property("[GRUMPKIN] [2]G = double(G) + G - G", prop.ForAll(
		func(s fr.Element) bool {
			var sInt big.Int
			g := 
g1GenAff
			s.BigInt(&sInt)
			g.ScalarMultiplication(&g, &sInt)
			var op1, op2 G1Affine
			op1.ScalarMultiplication(&g, big.NewInt(2))
			op2.Double(&g)
			op2.Add(&op2, &g)
			op2.Sub(&op2, &g)
			return op1.Equal(&op2)
		},
		GenFr(),
	))

	properties.Property("[GRUMPKIN] [-s]G = -[s]G", prop.ForAll(
		func(s fr.Element) bool {
			g := g1GenAff
			// NOTE(review): gj is left at its zero value (point at infinity), so the
			// mulWindowed sub-check below only exercises the infinity case —
			// presumably gj.FromAffine(&g) was intended; confirm against upstream.
			var gj G1Jac
			var nbs, bs big.Int
			s.BigInt(&bs)
			nbs.Neg(&bs)

			var res = true

			// mulGLV
			{
				var op1, op2 G1Affine
				op1.ScalarMultiplication(&g, &bs).Neg(&op1)
				op2.ScalarMultiplication(&g, &nbs)
				res = res && op1.Equal(&op2)
			}

			// mulWindowed
			{
				var op1, op2 G1Jac
				op1.mulWindowed(&gj, &bs).Neg(&op1)
				op2.mulWindowed(&gj, &nbs)
				res = res && op1.Equal(&op2)
			}

			return res
		},
		GenFr(),
	))

	properties.Property("[GRUMPKIN] [Jacobian] Add should call double when adding the same point", prop.ForAll(
		func(a, b fp.Element) bool {
			fop1 := fuzzG1Jac(&g1Gen, a)
			fop2 := fuzzG1Jac(&g1Gen, b)
			var op1, op2 G1Jac
			op1.Set(&fop1).AddAssign(&fop2)
			op2.Double(&fop2)
			return op1.Equal(&op2)
		},
		GenFp(),
		GenFp(),
	))

	properties.Property("[GRUMPKIN] [Jacobian] Adding the opposite of a point to itself should output inf", prop.ForAll(
		func(a, b fp.Element) bool {
			fop1 := fuzzG1Jac(&g1Gen, a)
			fop2 := fuzzG1Jac(&g1Gen, b)
			fop2.Neg(&fop2)
			fop1.AddAssign(&fop2)
			return fop1.Equal(&g1Infinity)
		},
		GenFp(),
		GenFp(),
	))

	properties.Property("[GRUMPKIN] [Jacobian] Adding the inf to a point should not modify the point", prop.ForAll(
		func(a fp.Element) bool {
			fop1 := fuzzG1Jac(&g1Gen, a)
			fop1.AddAssign(&g1Infinity)
			var op2 G1Jac
			op2.Set(&g1Infinity)
			op2.AddAssign(&g1Gen)
			return fop1.Equal(&g1Gen) && op2.Equal(&g1Gen)
		},
		GenFp(),
	))

	properties.Property("[GRUMPKIN] [Jacobian Extended] addMixed (-G) should equal subMixed(G)", prop.ForAll(
		func(a fp.Element) bool {
			fop1 := fuzzG1Jac(&g1Gen, a)
			var p1, p1Neg G1Affine
			p1.FromJacobian(&fop1)
			p1Neg = p1
			p1Neg.Y.Neg(&p1Neg.Y)
			var o1, o2 g1JacExtended
			o1.addMixed(&p1Neg)
			o2.subMixed(&p1)

			return o1.X.Equal(&o2.X) &&
				o1.Y.Equal(&o2.Y) &&
				o1.ZZ.Equal(&o2.ZZ) &&
				o1.ZZZ.Equal(&o2.ZZZ)
		},
		GenFp(),
	))

	properties.Property("[GRUMPKIN] [Jacobian Extended] doubleMixed (-G) should equal doubleNegMixed(G)", prop.ForAll(
		func(a fp.Element) bool {
			fop1 := fuzzG1Jac(&g1Gen, a)
			var p1, p1Neg G1Affine
			p1.FromJacobian(&fop1)
			p1Neg = p1
			p1Neg.Y.Neg(&p1Neg.Y)
			var o1, o2 g1JacExtended
			o1.doubleMixed(&p1Neg)
			o2.doubleNegMixed(&p1)

			return o1.X.Equal(&o2.X) &&
				o1.Y.Equal(&o2.Y) &&
				o1.ZZ.Equal(&o2.ZZ) &&
				o1.ZZZ.Equal(&o2.ZZZ)
		},
		GenFp(),
	))

	properties.Property("[GRUMPKIN] [Jacobian] Addmix the negation to itself should output 0", prop.ForAll(
		func(a fp.Element) bool {
			fop1 := fuzzG1Jac(&g1Gen, a)
			fop1.Neg(&fop1)
			var op2 G1Affine
			op2.FromJacobian(&g1Gen)
			fop1.AddMixed(&op2)
			return fop1.Equal(&g1Infinity)
		},
		GenFp(),
	))

	properties.Property("[GRUMPKIN] scalar multiplication (double and add) should depend only on the scalar mod r", prop.ForAll(
		func(s fr.Element) bool {

			r := fr.Modulus()
			var g G1Jac
			g.ScalarMultiplication(&g1Gen, r)

			var scalar, blindedScalar, rminusone big.Int
			var op1, op2, op3, gneg G1Jac
			rminusone.SetUint64(1).Sub(r, &rminusone)
			op3.mulWindowed(&g1Gen, &rminusone)
			gneg.Neg(&g1Gen)
			s.BigInt(&scalar)
			// blindedScalar = scalar·(r+1) ≡ scalar (mod r)
			blindedScalar.Mul(&scalar, r).Add(&blindedScalar, &scalar)
			op1.mulWindowed(&g1Gen, &scalar)
			op2.mulWindowed(&g1Gen, &blindedScalar)

			return op1.Equal(&op2) && g.Equal(&g1Infinity) && !op1.Equal(&g1Infinity) && gneg.Equal(&op3)

		},
		genScalar,
	))

	properties.Property("[GRUMPKIN] scalar multiplication (GLV) should depend only on the scalar mod r", prop.ForAll(
		func(s fr.Element) bool {

			r := fr.Modulus()
			var g G1Jac
			g.mulGLV(&g1Gen, r)

			var scalar, blindedScalar, rminusone big.Int
			var op1, op2, op3, gneg G1Jac
			rminusone.SetUint64(1).Sub(r, &rminusone)
			op3.ScalarMultiplication(&g1Gen, &rminusone)
			gneg.Neg(&g1Gen)
			s.BigInt(&scalar)
			blindedScalar.Mul(&scalar, r).Add(&blindedScalar, &scalar)
			op1.ScalarMultiplication(&g1Gen, &scalar)
			op2.ScalarMultiplication(&g1Gen, &blindedScalar)

			return op1.Equal(&op2) && g.Equal(&g1Infinity) && !op1.Equal(&g1Infinity) && gneg.Equal(&op3)

		},
		genScalar,
	))

	properties.Property("[GRUMPKIN] GLV and Double and Add should output the same result", prop.ForAll(
		func(s fr.Element) bool {

			var r big.Int
			var op1, op2 G1Jac
			s.BigInt(&r)
			op1.mulWindowed(&g1Gen, &r)
			op2.mulGLV(&g1Gen, &r)
			return op1.Equal(&op2) && !op1.Equal(&g1Infinity)

		},
		genScalar,
	))

	properties.Property("[GRUMPKIN] JointScalarMultiplicationBase and ScalarMultiplication should output the same results", prop.ForAll(
		func(s1, s2 fr.Element) bool {

			var op1, op2, temp G1Jac

			op1.JointScalarMultiplicationBase(&g1GenAff, s1.BigInt(new(big.Int)), s2.BigInt(new(big.Int)))
			temp.ScalarMultiplication(&g1Gen, s2.BigInt(new(big.Int)))
			op2.ScalarMultiplication(&g1Gen, s1.BigInt(new(big.Int))).
				AddAssign(&temp)

			return op1.Equal(&op2)

		},
		genScalar,
		genScalar,
	))

	properties.TestingRun(t, gopter.ConsoleReporter(false))
}

// TestG1AffineBatchScalarMultiplication cross-checks BatchScalarMultiplicationG1
// against per-scalar ScalarMultiplication.
func TestG1AffineBatchScalarMultiplication(t *testing.T) {

	parameters := gopter.DefaultTestParameters()
	if testing.Short() {
		parameters.MinSuccessfulTests = nbFuzzShort
	} else {
		// NOTE(review): both branches assign nbFuzzShort — presumably nbFuzz was
		// intended here; confirm against the generator template.
		parameters.MinSuccessfulTests = nbFuzzShort
	}

	properties := gopter.NewProperties(parameters)

	genScalar := GenFr()

	// size of the multiExps
	const nbSamples = 10

	properties.Property("[GRUMPKIN] BatchScalarMultiplication should be consistent with individual scalar multiplications", prop.ForAll(
		func(mixer fr.Element) bool {
			// mixer ensures that all the words of a fpElement are set
			var sampleScalars [nbSamples]fr.Element

			for i := 1; i <= nbSamples; i++ {
				sampleScalars[i-1].SetUint64(uint64(i)).
					Mul(&sampleScalars[i-1], &mixer)
			}

			result := BatchScalarMultiplicationG1(&g1GenAff, sampleScalars[:])

			if len(result) != len(sampleScalars) {
				return false
			}

			for i := 0; i < len(result); i++ {
				var expectedJac G1Jac
				var expected G1Affine
				var b big.Int
				expectedJac.ScalarMultiplication(&g1Gen, sampleScalars[i].BigInt(&b))
				expected.FromJacobian(&expectedJac)
				if !result[i].Equal(&expected) {
					return false
				}
			}
			return true
		},
		genScalar,
	))

	properties.TestingRun(t, gopter.ConsoleReporter(false))
}

// ------------------------------------------------------------
// benches

func BenchmarkG1JacIsInSubGroup(b *testing.B) {
	var a G1Jac
	a.Set(&g1Gen)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		a.IsInSubGroup()
	}

}

// BenchmarkG1JacEqual measures Equal on matching and non-matching Jacobian
// representatives (the matching one is built by rescaling Z by a random scalar).
func BenchmarkG1JacEqual(b *testing.B) {
	var scalar fp.Element
	if _, err := scalar.SetRandom(); err != nil {
		b.Fatalf("failed to set scalar: %s", err)
	}

	var a G1Jac
	a.ScalarMultiplication(&g1Gen, big.NewInt(42))

	b.Run("equal", func(b *testing.B) {
		var scalarSquared fp.Element
		scalarSquared.Square(&scalar)

		// rescale (X,Y,Z) -> (X·s², Y·s³, Z·s): same point, different representative
		aZScaled := a
aZScaled.X.Mul(&aZScaled.X, &scalarSquared) + aZScaled.Y.Mul(&aZScaled.Y, &scalarSquared).Mul(&aZScaled.Y, &scalar) + aZScaled.Z.Mul(&aZScaled.Z, &scalar) + + // Check the setup. + if !a.Equal(&aZScaled) { + b.Fatalf("invalid test setup") + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Equal(&aZScaled) + } + }) + + b.Run("not equal", func(b *testing.B) { + var aPlus1 G1Jac + aPlus1.AddAssign(&g1Gen) + + // Check the setup. + if a.Equal(&aPlus1) { + b.Fatalf("invalid test setup") + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Equal(&aPlus1) + } + }) +} + +func BenchmarkBatchAddG1Affine(b *testing.B) { + + var P, R pG1AffineC16 + var RR ppG1AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG1(P[:]) + fillBenchBasesG1(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG1Affine[pG1AffineC16, ppG1AffineC16, cG1AffineC16](&RR, &P, len(P)) + } +} + +func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { + // ensure every words of the scalars are filled + var mixer fr.Element + mixer.SetString("7716837800905789770901243404444209691916730933998574719964609384059111546487") + + const pow = 15 + const nbSamples = 1 << pow + + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). 
+ Mul(&sampleScalars[i-1], &mixer) + } + + for i := 5; i <= pow; i++ { + using := 1 << i + + b.Run(fmt.Sprintf("%d points", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + _ = BatchScalarMultiplicationG1(&g1GenAff, sampleScalars[:using]) + } + }) + } +} + +func BenchmarkG1JacScalarMultiplication(b *testing.B) { + + var scalar big.Int + r := fr.Modulus() + scalar.SetString("5243587517512619047944770508185965837690552500527637822603658699938581184513", 10) + scalar.Add(&scalar, r) + + var doubleAndAdd G1Jac + + b.Run("double and add", func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + doubleAndAdd.mulWindowed(&g1Gen, &scalar) + } + }) + + var glv G1Jac + b.Run("GLV", func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + glv.mulGLV(&g1Gen, &scalar) + } + }) + +} + +func BenchmarkG1JacAdd(b *testing.B) { + var a G1Jac + a.Double(&g1Gen) + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.AddAssign(&g1Gen) + } +} + +func BenchmarkG1JacAddMixed(b *testing.B) { + var a G1Jac + a.Double(&g1Gen) + + var c G1Affine + c.FromJacobian(&g1Gen) + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.AddMixed(&c) + } + +} + +func BenchmarkG1JacDouble(b *testing.B) { + var a G1Jac + a.Set(&g1Gen) + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.DoubleAssign() + } + +} + +func BenchmarkG1JacExtAddMixed(b *testing.B) { + var a g1JacExtended + a.doubleMixed(&g1GenAff) + + var c G1Affine + c.FromJacobian(&g1Gen) + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.addMixed(&c) + } +} + +func BenchmarkG1JacExtSubMixed(b *testing.B) { + var a g1JacExtended + a.doubleMixed(&g1GenAff) + + var c G1Affine + c.FromJacobian(&g1Gen) + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.subMixed(&c) + } +} + +func BenchmarkG1JacExtDoubleMixed(b *testing.B) { + var a g1JacExtended + a.doubleMixed(&g1GenAff) + + var c G1Affine + c.FromJacobian(&g1Gen) + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.doubleMixed(&c) + } +} + +func 
BenchmarkG1JacExtDoubleNegMixed(b *testing.B) { + var a g1JacExtended + a.doubleMixed(&g1GenAff) + + var c G1Affine + c.FromJacobian(&g1Gen) + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.doubleNegMixed(&c) + } +} + +func BenchmarkG1JacExtAdd(b *testing.B) { + var a, c g1JacExtended + a.doubleMixed(&g1GenAff) + c.double(&a) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.add(&c) + } +} + +func BenchmarkG1JacExtDouble(b *testing.B) { + var a g1JacExtended + a.doubleMixed(&g1GenAff) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.double(&a) + } +} + +func BenchmarkG1AffineAdd(b *testing.B) { + var a G1Affine + a.Double(&g1GenAff) + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Add(&a, &g1GenAff) + } +} + +func BenchmarkG1AffineDouble(b *testing.B) { + var a G1Affine + a.Double(&g1GenAff) + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Double(&a) + } +} + +func fuzzG1Jac(p *G1Jac, f fp.Element) G1Jac { + var res G1Jac + res.X.Mul(&p.X, &f).Mul(&res.X, &f) + res.Y.Mul(&p.Y, &f).Mul(&res.Y, &f).Mul(&res.Y, &f) + res.Z.Mul(&p.Z, &f) + return res +} + +func fuzzg1JacExtended(p *g1JacExtended, f fp.Element) g1JacExtended { + var res g1JacExtended + var ff, fff fp.Element + ff.Square(&f) + fff.Mul(&ff, &f) + res.X.Mul(&p.X, &ff) + res.Y.Mul(&p.Y, &fff) + res.ZZ.Mul(&p.ZZ, &ff) + res.ZZZ.Mul(&p.ZZZ, &fff) + return res +} diff --git a/ecc/grumpkin/grumpkin.go b/ecc/grumpkin/grumpkin.go new file mode 100644 index 0000000000..33207330ae --- /dev/null +++ b/ecc/grumpkin/grumpkin.go @@ -0,0 +1,86 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Package grumpkin efficient elliptic curve and hash to curve implementation for grumpkin. This curve appears forms a 2-cycle with bn254 [https://aztecprotocol.github.io/aztec-connect/primitives.html]. 
+//
+// grumpkin: A j=0 curve with
+//
+//	𝔽r: r=21888242871839275222246405745257275088696311157297823662689037894645226208583
+//	𝔽p: p=21888242871839275222246405745257275088548364400416034343698204186575808495617
+//	(E/𝔽p): Y²=X³-17
+//	r ∣ #E(Fp)
+//
+// Security: estimated 127-bit level against Pollard's Rho attack
+// (r is 254 bits)
+//
+// # Warning
+//
+// This code has been partially audited and is provided as-is. In particular, there are no security guarantees such as constant time implementation or side-channel attack resistance.
+package grumpkin
+
+import (
+	"math/big"
+
+	"github.com/consensys/gnark-crypto/ecc"
+	"github.com/consensys/gnark-crypto/ecc/grumpkin/fp"
+	"github.com/consensys/gnark-crypto/ecc/grumpkin/fr"
+)
+
+// ID grumpkin ID
+const ID = ecc.GRUMPKIN
+
+// aCurveCoeff and bCurveCoeff are the a, b coefficients of the curve Y²=X³+ax+b
+var aCurveCoeff fp.Element
+var bCurveCoeff fp.Element
+
+// generators of the r-torsion group, resp. in ker(pi-id), ker(Tr)
+var g1Gen G1Jac
+
+var g1GenAff G1Affine
+
+// point at infinity
+var g1Infinity G1Jac
+
+// Parameters useful for the GLV scalar multiplication. The third root of unity defines the
+// endomorphism ϕ₁ for G1. lambda is such that ⟨r, ϕ-λ⟩ lies above
+// ⟨r⟩ in the ring Z[ϕ]. More concretely it's the associated eigenvalue
+// of ϕ₁ restricted to G1
+// see https://www.cosic.esat.kuleuven.be/nessie/reports/phase2/GLV.pdf +var thirdRootOneG1 fp.Element +var lambdaGLV big.Int + +// glvBasis stores R-linearly independent vectors (a,b), (c,d) +// in ker((u,v) → u+vλ[r]), and their determinant +var glvBasis ecc.Lattice + +func init() { + aCurveCoeff.SetUint64(0) + bCurveCoeff.SetUint64(17).Neg(&bCurveCoeff) + + g1Gen.X.SetOne() + g1Gen.Y.SetString("17631683881184975370165255887551781615748388533673675138860") // sqrt(-16) % p + g1Gen.Z.SetOne() + + g1GenAff.FromJacobian(&g1Gen) + + // (X,Y,Z) = (1,1,0) + g1Infinity.X.SetOne() + g1Infinity.Y.SetOne() + + thirdRootOneG1.SetString("4407920970296243842393367215006156084916469457145843978461") + lambdaGLV.SetString("2203960485148121921418603742825762020974279258880205651966", 10) + _r := fr.Modulus() + ecc.PrecomputeLattice(_r, &lambdaGLV, &glvBasis) +} + +// Generators return the generators of the r-torsion group, resp. in ker(pi-id), ker(Tr) +func Generators() (g1Jac G1Jac, g1Aff G1Affine) { + g1Aff = g1GenAff + g1Jac = g1Gen + return +} + +// CurveCoefficients returns the a, b coefficients of the curve equation. +func CurveCoefficients() (a, b fp.Element) { + return aCurveCoeff, bCurveCoeff +} diff --git a/ecc/grumpkin/hash_to_g1.go b/ecc/grumpkin/hash_to_g1.go new file mode 100644 index 0000000000..fe65dce8b9 --- /dev/null +++ b/ecc/grumpkin/hash_to_g1.go @@ -0,0 +1,150 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package grumpkin + +import ( + "github.com/consensys/gnark-crypto/ecc/grumpkin/fp" +) + +// MapToCurve1 implements the Shallue and van de Woestijne method, applicable to any elliptic curve in Weierstrass form +// No cofactor clearing or isogeny +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#straightline-svdw +func MapToCurve1(u *fp.Element) G1Affine { + var tv1, tv2, tv3, tv4 fp.Element + var x1, x2, x3, gx1, gx2, gx, x, y fp.Element + var one fp.Element + var gx1NotSquare, gx1SquareOrGx2Not int + + //constants + //c1 = g(Z) + //c2 = -Z / 2 + //c3 = sqrt(-g(Z) * (3 * Z² + 4 * A)) # sgn0(c3) MUST equal 0 + //c4 = -4 * g(Z) / (3 * Z² + 4 * A) + + Z := fp.Element{12436184717236109307, 3962172157175319849, 7381016538464732718, 1011752739694698287} + c1 := fp.Element{9945788691500761173, 6430049622857769019, 3649927362066405102, 1246947498899680730} + c2 := fp.Element{14674382058109796355, 8690743149920539059, 2950087706404981015, 1237622763554136189} + c3 := fp.Element{13810035865635542012, 18434659582482110917, 5527417781862982147, 1648708720397174142} + c4 := fp.Element{10077153171078468837, 472429577399671532, 14563536826422671818, 1824401601603396358} + + one.SetOne() + + tv1.Square(u) // 1. tv1 = u² + tv1.Mul(&tv1, &c1) // 2. tv1 = tv1 * c1 + tv2.Add(&one, &tv1) // 3. tv2 = 1 + tv1 + tv1.Sub(&one, &tv1) // 4. tv1 = 1 - tv1 + tv3.Mul(&tv1, &tv2) // 5. tv3 = tv1 * tv2 + + tv3.Inverse(&tv3) // 6. tv3 = inv0(tv3) + tv4.Mul(u, &tv1) // 7. tv4 = u * tv1 + tv4.Mul(&tv4, &tv3) // 8. tv4 = tv4 * tv3 + tv4.Mul(&tv4, &c3) // 9. tv4 = tv4 * c3 + x1.Sub(&c2, &tv4) // 10. x1 = c2 - tv4 + + gx1.Square(&x1) // 11. gx1 = x1² + //12. gx1 = gx1 + A All curves in gnark-crypto have A=0 (j-invariant=0). It is crucial to include this step if the curve has nonzero A coefficient. + gx1.Mul(&gx1, &x1) // 13. gx1 = gx1 * x1 + gx1.Add(&gx1, &bCurveCoeff) // 14. 
gx1 = gx1 + B + gx1NotSquare = gx1.Legendre() >> 1 // 15. e1 = is_square(gx1) + // gx1NotSquare = 0 if gx1 is a square, -1 otherwise + + x2.Add(&c2, &tv4) // 16. x2 = c2 + tv4 + gx2.Square(&x2) // 17. gx2 = x2² + // 18. gx2 = gx2 + A See line 12 + gx2.Mul(&gx2, &x2) // 19. gx2 = gx2 * x2 + gx2.Add(&gx2, &bCurveCoeff) // 20. gx2 = gx2 + B + + { + gx2NotSquare := gx2.Legendre() >> 1 // gx2Square = 0 if gx2 is a square, -1 otherwise + gx1SquareOrGx2Not = gx2NotSquare | ^gx1NotSquare // 21. e2 = is_square(gx2) AND NOT e1 # Avoid short-circuit logic ops + } + + x3.Square(&tv2) // 22. x3 = tv2² + x3.Mul(&x3, &tv3) // 23. x3 = x3 * tv3 + x3.Square(&x3) // 24. x3 = x3² + x3.Mul(&x3, &c4) // 25. x3 = x3 * c4 + + x3.Add(&x3, &Z) // 26. x3 = x3 + Z + x.Select(gx1NotSquare, &x1, &x3) // 27. x = CMOV(x3, x1, e1) # x = x1 if gx1 is square, else x = x3 + // Select x1 iff gx1 is square iff gx1NotSquare = 0 + x.Select(gx1SquareOrGx2Not, &x2, &x) // 28. x = CMOV(x, x2, e2) # x = x2 if gx2 is square and gx1 is not + // Select x2 iff gx2 is square and gx1 is not, iff gx1SquareOrGx2Not = 0 + gx.Square(&x) // 29. gx = x² + // 30. gx = gx + A + + gx.Mul(&gx, &x) // 31. gx = gx * x + gx.Add(&gx, &bCurveCoeff) // 32. gx = gx + B + + y.Sqrt(&gx) // 33. y = sqrt(gx) + signsNotEqual := g1Sgn0(u) ^ g1Sgn0(&y) // 34. e3 = sgn0(u) == sgn0(y) + + tv1.Neg(&y) + y.Select(int(signsNotEqual), &y, &tv1) // 35. 
y = CMOV(-y, y, e3) # Select correct sign of y + return G1Affine{x, y} +} + +// g1Sgn0 is an algebraic substitute for the notion of sign in ordered fields +// Namely, every non-zero quadratic residue in a finite field of characteristic =/= 2 has exactly two square roots, one of each sign +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#name-the-sgn0-function +// The sign of an element is not obviously related to that of its Montgomery form +func g1Sgn0(z *fp.Element) uint64 { + + nonMont := z.Bits() + + // m == 1 + return nonMont[0] % 2 + +} + +// MapToG1 invokes the SVDW map, and guarantees that the result is in g1 +func MapToG1(u fp.Element) G1Affine { + res := MapToCurve1(&u) + return res +} + +// EncodeToG1 hashes a message to a point on the G1 curve using the SVDW map. +// It is faster than HashToG1, but the result is not uniformly distributed. Unsuitable as a random oracle. +// dst stands for "domain separation tag", a string unique to the construction using the hash function +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +func EncodeToG1(msg, dst []byte) (G1Affine, error) { + + var res G1Affine + u, err := fp.Hash(msg, dst, 1) + if err != nil { + return res, err + } + + res = MapToCurve1(&u[0]) + + return res, nil +} + +// HashToG1 hashes a message to a point on the G1 curve using the SVDW map. +// Slower than EncodeToG1, but usable as a random oracle. 
+// dst stands for "domain separation tag", a string unique to the construction using the hash function +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +func HashToG1(msg, dst []byte) (G1Affine, error) { + u, err := fp.Hash(msg, dst, 2*1) + if err != nil { + return G1Affine{}, err + } + + Q0 := MapToCurve1(&u[0]) + Q1 := MapToCurve1(&u[1]) + + var _Q0, _Q1 G1Jac + _Q0.FromAffine(&Q0) + _Q1.FromAffine(&Q1).AddAssign(&_Q0) + + Q1.FromJacobian(&_Q1) + return Q1, nil +} + +func g1NotZero(x *fp.Element) uint64 { + + return x[0] | x[1] | x[2] | x[3] + +} diff --git a/ecc/grumpkin/hash_to_g1_test.go b/ecc/grumpkin/hash_to_g1_test.go new file mode 100644 index 0000000000..d7ca26c836 --- /dev/null +++ b/ecc/grumpkin/hash_to_g1_test.go @@ -0,0 +1,223 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package grumpkin + +import ( + "github.com/consensys/gnark-crypto/ecc/grumpkin/fp" + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" + "math/rand" + "testing" +) + +func TestHashToFpG1(t *testing.T) { + for _, c := range encodeToG1Vector.cases { + elems, err := fp.Hash([]byte(c.msg), encodeToG1Vector.dst, 1) + if err != nil { + t.Error(err) + } + g1TestMatchCoord(t, "u", c.msg, c.u, g1CoordAt(elems, 0)) + } + + for _, c := range hashToG1Vector.cases { + elems, err := fp.Hash([]byte(c.msg), hashToG1Vector.dst, 2*1) + if err != nil { + t.Error(err) + } + g1TestMatchCoord(t, "u0", c.msg, c.u0, g1CoordAt(elems, 0)) + g1TestMatchCoord(t, "u1", c.msg, c.u1, g1CoordAt(elems, 1)) + } +} + +func TestMapToCurve1(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + properties.Property("[G1] 
mapping output must be on curve", prop.ForAll( + func(a fp.Element) bool { + + g := MapToCurve1(&a) + + if !g.IsOnCurve() { + t.Log("SVDW output not on curve") + return false + } + + return true + }, + GenFp(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + + for _, c := range encodeToG1Vector.cases { + var u fp.Element + g1CoordSetString(&u, c.u) + q := MapToCurve1(&u) + g1TestMatchPoint(t, "Q", c.msg, c.Q, &q) + } + + for _, c := range hashToG1Vector.cases { + var u fp.Element + g1CoordSetString(&u, c.u0) + q := MapToCurve1(&u) + g1TestMatchPoint(t, "Q0", c.msg, c.Q0, &q) + + g1CoordSetString(&u, c.u1) + q = MapToCurve1(&u) + g1TestMatchPoint(t, "Q1", c.msg, c.Q1, &q) + } +} + +func TestMapToG1(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + properties.Property("[G1] mapping to curve should output point on the curve", prop.ForAll( + func(a fp.Element) bool { + g := MapToG1(a) + return g.IsInSubGroup() + }, + GenFp(), + )) + + properties.Property("[G1] mapping to curve should be deterministic", prop.ForAll( + func(a fp.Element) bool { + g1 := MapToG1(a) + g2 := MapToG1(a) + return g1.Equal(&g2) + }, + GenFp(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestEncodeToG1(t *testing.T) { + t.Parallel() + for _, c := range encodeToG1Vector.cases { + p, err := EncodeToG1([]byte(c.msg), encodeToG1Vector.dst) + if err != nil { + t.Fatal(err) + } + g1TestMatchPoint(t, "P", c.msg, c.P, &p) + } +} + +func TestHashToG1(t *testing.T) { + t.Parallel() + for _, c := range hashToG1Vector.cases { + p, err := HashToG1([]byte(c.msg), hashToG1Vector.dst) + if err != nil { + t.Fatal(err) + } + g1TestMatchPoint(t, "P", c.msg, c.P, &p) + } +} + +func BenchmarkEncodeToG1(b *testing.B) { + const size = 54 + bytes := make([]byte, 
size) + dst := encodeToG1Vector.dst + b.ResetTimer() + + for i := 0; i < b.N; i++ { + + bytes[rand.Int()%size] = byte(rand.Int()) //#nosec G404 weak rng is fine here + + if _, err := EncodeToG1(bytes, dst); err != nil { + b.Fail() + } + } +} + +func BenchmarkHashToG1(b *testing.B) { + const size = 54 + bytes := make([]byte, size) + dst := hashToG1Vector.dst + b.ResetTimer() + + for i := 0; i < b.N; i++ { + + bytes[rand.Int()%size] = byte(rand.Int()) //#nosec G404 weak rng is fine here + + if _, err := HashToG1(bytes, dst); err != nil { + b.Fail() + } + } +} + +// Only works on simple extensions (two-story towers) +func g1CoordSetString(z *fp.Element, s string) { + z.SetString(s) +} + +func g1CoordAt(slice []fp.Element, i int) fp.Element { + return slice[i] +} + +func g1TestMatchCoord(t *testing.T, coordName string, msg string, expectedStr string, seen fp.Element) { + var expected fp.Element + + g1CoordSetString(&expected, expectedStr) + + if !expected.Equal(&seen) { + t.Errorf("mismatch on \"%s\", %s:\n\texpected %s\n\tsaw %s", msg, coordName, expected.String(), &seen) + } +} + +func g1TestMatchPoint(t *testing.T, pointName string, msg string, expected point, seen *G1Affine) { + g1TestMatchCoord(t, pointName+".x", msg, expected.x, seen.X) + g1TestMatchCoord(t, pointName+".y", msg, expected.y, seen.Y) +} + +type hashTestVector struct { + dst []byte + cases []hashTestCase +} + +type encodeTestVector struct { + dst []byte + cases []encodeTestCase +} + +type point struct { + x string + y string +} + +type encodeTestCase struct { + msg string + P point //pY a coordinate of P, the final output + u string //u hashed onto the field + Q point //Q map to curve output +} + +type hashTestCase struct { + msg string + P point //pY a coordinate of P, the final output + u0 string //u0 hashed onto the field + u1 string //u1 extra hashed onto the field + Q0 point //Q0 map to curve output + Q1 point //Q1 extra map to curve output +} + +var encodeToG1Vector encodeTestVector +var 
hashToG1Vector hashTestVector
diff --git a/ecc/grumpkin/marshal.go b/ecc/grumpkin/marshal.go
new file mode 100644
index 0000000000..cadf33619a
--- /dev/null
+++ b/ecc/grumpkin/marshal.go
@@ -0,0 +1,884 @@
+// Copyright 2020-2025 Consensys Software Inc.
+// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package grumpkin
+
+import (
+	"encoding/binary"
+	"errors"
+	"io"
+	"reflect"
+	"sync/atomic"
+
+	"github.com/consensys/gnark-crypto/ecc/grumpkin/fp"
+	"github.com/consensys/gnark-crypto/ecc/grumpkin/fr"
+	"github.com/consensys/gnark-crypto/internal/parallel"
+)
+
+// To encode G1Affine points, we mask the most significant bits with these bits to specify without ambiguity
+// metadata needed for point (de)compression
+// we have less than 3 bits available on the msw, so we can't follow BLS12-381 style encoding.
+// the difference is the case where a point is infinity and uncompressed is not flagged
+const (
+	mMask               byte = 0b11 << 6
+	mUncompressed       byte = 0b00 << 6
+	mCompressedSmallest byte = 0b10 << 6
+	mCompressedLargest  byte = 0b11 << 6
+	mCompressedInfinity byte = 0b01 << 6
+)
+
+var (
+	ErrInvalidInfinityEncoding = errors.New("invalid infinity point encoding")
+	ErrInvalidEncoding         = errors.New("invalid point encoding")
+)
+
+// Encoder writes grumpkin object values to an output stream
+type Encoder struct {
+	w   io.Writer
+	n   int64 // written bytes
+	raw bool  // raw vs compressed encoding
+}
+
+// Decoder reads grumpkin object values from an inbound stream
+type Decoder struct {
+	r             io.Reader
+	n             int64 // read bytes
+	subGroupCheck bool  // default to true
+}
+
+// NewDecoder returns a binary decoder supporting curve grumpkin objects in both
+// compressed and uncompressed (raw) forms
+func NewDecoder(r io.Reader, options ...func(*Decoder)) *Decoder {
+	d := &Decoder{r: r, subGroupCheck: true}
+
+	for _, o := range options {
+		o(d)
+	}
+
+	return d
+}
+
+// Decode reads the
binary encoding of v from the stream +// type must be *uint64, *fr.Element, *fp.Element, *G1Affine or *[]G1Affine +func (dec *Decoder) Decode(v interface{}) (err error) { + rv := reflect.ValueOf(v) + if v == nil || rv.Kind() != reflect.Ptr || rv.IsNil() || !rv.Elem().CanSet() { + return errors.New("bn254 decoder: unsupported type, need pointer") + } + + // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + // in particular, careful attention must be given to usage of Bytes() method on Elements and Points + // that return an array (not a slice) of bytes. Using this is beneficial to minimize memory allocations + // in very large (de)serialization upstream in gnark. + // (but detrimental to code readability here) + + var read64 int64 + if vf, ok := v.(io.ReaderFrom); ok { + read64, err = vf.ReadFrom(dec.r) + dec.n += read64 + return + } + + var buf [SizeOfG1AffineUncompressed]byte + var read int + var sliceLen uint32 + + switch t := v.(type) { + case *[][]uint64: + if sliceLen, err = dec.readUint32(); err != nil { + return + } + *t = make([][]uint64, sliceLen) + + for i := range *t { + if sliceLen, err = dec.readUint32(); err != nil { + return + } + (*t)[i] = make([]uint64, sliceLen) + for j := range (*t)[i] { + if (*t)[i][j], err = dec.readUint64(); err != nil { + return + } + } + } + return + case *[]uint64: + if sliceLen, err = dec.readUint32(); err != nil { + return + } + *t = make([]uint64, sliceLen) + for i := range *t { + if (*t)[i], err = dec.readUint64(); err != nil { + return + } + } + return + case *fr.Element: + read, err = io.ReadFull(dec.r, buf[:fr.Bytes]) + dec.n += int64(read) + if err != nil { + return + } + err = t.SetBytesCanonical(buf[:fr.Bytes]) + return + case *fp.Element: + read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) + dec.n += int64(read) + if err != nil { + return + } + err = t.SetBytesCanonical(buf[:fp.Bytes]) + return + case *[]fr.Element: + read64, err = 
(*fr.Vector)(t).ReadFrom(dec.r) + dec.n += read64 + return + case *[]fp.Element: + read64, err = (*fp.Vector)(t).ReadFrom(dec.r) + dec.n += read64 + return + case *[][]fr.Element: + if sliceLen, err = dec.readUint32(); err != nil { + return + } + if len(*t) != int(sliceLen) { + *t = make([][]fr.Element, sliceLen) + } + for i := range *t { + read64, err = (*fr.Vector)(&(*t)[i]).ReadFrom(dec.r) + dec.n += read64 + } + return + case *[][][]fr.Element: + if sliceLen, err = dec.readUint32(); err != nil { + return + } + if len(*t) != int(sliceLen) { + *t = make([][][]fr.Element, sliceLen) + } + for i := range *t { + if sliceLen, err = dec.readUint32(); err != nil { + return + } + if len((*t)[i]) != int(sliceLen) { + (*t)[i] = make([][]fr.Element, sliceLen) + } + for j := range (*t)[i] { + read64, err = (*fr.Vector)(&(*t)[i][j]).ReadFrom(dec.r) + dec.n += read64 + } + } + return + case *G1Affine: + // we start by reading compressed point size, if metadata tells us it is uncompressed, we read more. + read, err = io.ReadFull(dec.r, buf[:SizeOfG1AffineCompressed]) + dec.n += int64(read) + if err != nil { + return + } + nbBytes := SizeOfG1AffineCompressed + + // most significant byte contains metadata + if !isCompressed(buf[0]) { + nbBytes = SizeOfG1AffineUncompressed + // we read more. + read, err = io.ReadFull(dec.r, buf[SizeOfG1AffineCompressed:SizeOfG1AffineUncompressed]) + dec.n += int64(read) + if err != nil { + return + } + } + _, err = t.setBytes(buf[:nbBytes], dec.subGroupCheck) + return + case *[]G1Affine: + sliceLen, err = dec.readUint32() + if err != nil { + return + } + if len(*t) != int(sliceLen) || *t == nil { + *t = make([]G1Affine, sliceLen) + } + compressed := make([]bool, sliceLen) + for i := 0; i < len(*t); i++ { + + // we start by reading compressed point size, if metadata tells us it is uncompressed, we read more. 
+ read, err = io.ReadFull(dec.r, buf[:SizeOfG1AffineCompressed]) + dec.n += int64(read) + if err != nil { + return + } + nbBytes := SizeOfG1AffineCompressed + + // most significant byte contains metadata + if !isCompressed(buf[0]) { + nbBytes = SizeOfG1AffineUncompressed + // we read more. + read, err = io.ReadFull(dec.r, buf[SizeOfG1AffineCompressed:SizeOfG1AffineUncompressed]) + dec.n += int64(read) + if err != nil { + return + } + _, err = (*t)[i].setBytes(buf[:nbBytes], false) + if err != nil { + return + } + } else { + var r bool + if r, err = (*t)[i].unsafeSetCompressedBytes(buf[:nbBytes]); err != nil { + return + } + compressed[i] = !r + } + } + var nbErrs uint64 + parallel.Execute(len(compressed), func(start, end int) { + for i := start; i < end; i++ { + if compressed[i] { + if err := (*t)[i].unsafeComputeY(dec.subGroupCheck); err != nil { + atomic.AddUint64(&nbErrs, 1) + } + } else if dec.subGroupCheck { + if !(*t)[i].IsInSubGroup() { + atomic.AddUint64(&nbErrs, 1) + } + } + } + }) + if nbErrs != 0 { + return errors.New("point decompression failed") + } + + return nil + default: + n := binary.Size(t) + if n == -1 { + return errors.New("bn254 encoder: unsupported type") + } + err = binary.Read(dec.r, binary.BigEndian, t) + if err == nil { + dec.n += int64(n) + } + return + } +} + +// BytesRead return total bytes read from reader +func (dec *Decoder) BytesRead() int64 { + return dec.n +} + +func (dec *Decoder) readUint32() (r uint32, err error) { + var read int + var buf [4]byte + read, err = io.ReadFull(dec.r, buf[:4]) + dec.n += int64(read) + if err != nil { + return + } + r = binary.BigEndian.Uint32(buf[:4]) + return +} + +func (dec *Decoder) readUint64() (r uint64, err error) { + var read int + var buf [8]byte + read, err = io.ReadFull(dec.r, buf[:]) + dec.n += int64(read) + if err != nil { + return + } + r = binary.BigEndian.Uint64(buf[:]) + return +} + +func isCompressed(msb byte) bool { + mData := msb & mMask + return !(mData == mUncompressed) +} + 
+// NewEncoder returns a binary encoder supporting curve grumpkin objects
+func NewEncoder(w io.Writer, options ...func(*Encoder)) *Encoder {
+	// default settings
+	enc := &Encoder{
+		w:   w,
+		n:   0,
+		raw: false,
+	}
+
+	// handle options
+	for _, option := range options {
+		option(enc)
+	}
+
+	return enc
+}
+
+// Encode writes the binary encoding of v to the stream
+// type must be uint64, *fr.Element, *fp.Element, *G1Affine, []G1Affine or *[]G1Affine
+func (enc *Encoder) Encode(v interface{}) (err error) {
+	if enc.raw {
+		return enc.encodeRaw(v)
+	}
+	return enc.encode(v)
+}
+
+// BytesWritten return total bytes written on writer
+func (enc *Encoder) BytesWritten() int64 {
+	return enc.n
+}
+
+// RawEncoding returns an option to use in NewEncoder(...) which sets raw encoding mode to true
+// points will not be compressed using this option
+func RawEncoding() func(*Encoder) {
+	return func(enc *Encoder) {
+		enc.raw = true
+	}
+}
+
+// NoSubgroupChecks returns an option to use in NewDecoder(...) which disables subgroup checks on the points
+// the decoder will read. Use with caution, as crafted points from an untrusted source can lead to crypto-attacks.
+func NoSubgroupChecks() func(*Decoder) { + return func(dec *Decoder) { + dec.subGroupCheck = false + } +} + +// isZeroed checks that the provided bytes are at 0 +func isZeroed(firstByte byte, buf []byte) bool { + if firstByte != 0 { + return false + } + for _, b := range buf { + if b != 0 { + return false + } + } + return true +} + +func (enc *Encoder) encode(v interface{}) (err error) { + rv := reflect.ValueOf(v) + if v == nil || (rv.Kind() == reflect.Ptr && rv.IsNil()) { + return errors.New(" encoder: can't encode ") + } + + // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + + var written int + + switch t := v.(type) { + case []uint64: + return enc.writeUint64Slice(t) + case [][]uint64: + return enc.writeUint64SliceSlice(t) + case *fr.Element: + buf := t.Bytes() + written, err = enc.w.Write(buf[:]) + enc.n += int64(written) + return + case *fp.Element: + buf := t.Bytes() + written, err = enc.w.Write(buf[:]) + enc.n += int64(written) + return + case *G1Affine: + buf := t.Bytes() + written, err = enc.w.Write(buf[:]) + enc.n += int64(written) + return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case []fr.Element: + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return + case []fp.Element: + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return + case [][]fr.Element: + // write slice length + if err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))); err != nil { + return + } + enc.n += 4 + for i := range t { + written64, err = (*fr.Vector)(&t[i]).WriteTo(enc.w) + enc.n += written64 + } + return + case [][][]fr.Element: + // number of collections + if err = binary.Write(enc.w, binary.BigEndian, 
uint32(len(t))); err != nil { + return + } + enc.n += 4 + for i := range t { + // size of current collection + if err = binary.Write(enc.w, binary.BigEndian, uint32(len(t[i]))); err != nil { + return + } + enc.n += 4 + // write each vector of the current collection + for j := range t[i] { + written64, err = (*fr.Vector)(&t[i][j]).WriteTo(enc.w) + enc.n += written64 + } + } + return + case *[]G1Affine: + return enc.encode(*t) + case []G1Affine: + // write slice length + err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) + if err != nil { + return + } + enc.n += 4 + + var buf [SizeOfG1AffineCompressed]byte + + for i := 0; i < len(t); i++ { + buf = t[i].Bytes() + written, err = enc.w.Write(buf[:]) + enc.n += int64(written) + if err != nil { + return + } + } + return nil + default: + n := binary.Size(t) + if n == -1 { + return errors.New(" encoder: unsupported type") + } + err = binary.Write(enc.w, binary.BigEndian, t) + enc.n += int64(n) + return + } +} + +func (enc *Encoder) encodeRaw(v interface{}) (err error) { + rv := reflect.ValueOf(v) + if v == nil || (rv.Kind() == reflect.Ptr && rv.IsNil()) { + return errors.New(" encoder: can't encode ") + } + + // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + + var written int + + switch t := v.(type) { + case []uint64: + return enc.writeUint64Slice(t) + case [][]uint64: + return enc.writeUint64SliceSlice(t) + case *fr.Element: + buf := t.Bytes() + written, err = enc.w.Write(buf[:]) + enc.n += int64(written) + return + case *fp.Element: + buf := t.Bytes() + written, err = enc.w.Write(buf[:]) + enc.n += int64(written) + return + case *G1Affine: + buf := t.RawBytes() + written, err = enc.w.Write(buf[:]) + enc.n += int64(written) + return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case 
fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case []fr.Element: + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return + case []fp.Element: + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return + case [][]fr.Element: + // write slice length + if err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))); err != nil { + return + } + enc.n += 4 + for i := range t { + written64, err = (*fr.Vector)(&t[i]).WriteTo(enc.w) + enc.n += written64 + } + return + case [][][]fr.Element: + // number of collections + if err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))); err != nil { + return + } + enc.n += 4 + for i := range t { + // size of current collection + if err = binary.Write(enc.w, binary.BigEndian, uint32(len(t[i]))); err != nil { + return + } + enc.n += 4 + // write each vector of the current collection + for j := range t[i] { + written64, err = (*fr.Vector)(&t[i][j]).WriteTo(enc.w) + enc.n += written64 + } + } + return + case *[]G1Affine: + return enc.encodeRaw(*t) + case []G1Affine: + // write slice length + err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) + if err != nil { + return + } + enc.n += 4 + + var buf [SizeOfG1AffineUncompressed]byte + + for i := 0; i < len(t); i++ { + buf = t[i].RawBytes() + written, err = enc.w.Write(buf[:]) + enc.n += int64(written) + if err != nil { + return + } + } + return nil + default: + n := binary.Size(t) + if n == -1 { + return errors.New(" encoder: unsupported type") + } + err = binary.Write(enc.w, binary.BigEndian, t) + enc.n += int64(n) + return + } +} + +func (enc *Encoder) writeUint64Slice(t []uint64) (err error) { + if err = enc.writeUint32(uint32(len(t))); err != nil { + return + } + for i := range t { + if err = enc.writeUint64(t[i]); err != nil { + return + } + } + return nil +} + +func (enc *Encoder) writeUint64SliceSlice(t [][]uint64) (err error) { + if err = enc.writeUint32(uint32(len(t))); err != nil { + 
return + } + for i := range t { + if err = enc.writeUint32(uint32(len(t[i]))); err != nil { + return + } + for j := range t[i] { + if err = enc.writeUint64(t[i][j]); err != nil { + return + } + } + } + return nil +} + +func (enc *Encoder) writeUint64(a uint64) error { + var buff [64 / 8]byte + binary.BigEndian.PutUint64(buff[:], a) + written, err := enc.w.Write(buff[:]) + enc.n += int64(written) + return err +} + +func (enc *Encoder) writeUint32(a uint32) error { + var buff [32 / 8]byte + binary.BigEndian.PutUint32(buff[:], a) + written, err := enc.w.Write(buff[:]) + enc.n += int64(written) + return err +} + +// SizeOfG1AffineCompressed represents the size in bytes that a G1Affine need in binary form, compressed +const SizeOfG1AffineCompressed = 32 + +// SizeOfG1AffineUncompressed represents the size in bytes that a G1Affine need in binary form, uncompressed +const SizeOfG1AffineUncompressed = SizeOfG1AffineCompressed * 2 + +// Marshal converts p to a byte slice (without point compression) +func (p *G1Affine) Marshal() []byte { + b := p.RawBytes() + return b[:] +} + +// Unmarshal is an alias to SetBytes() +func (p *G1Affine) Unmarshal(buf []byte) error { + _, err := p.SetBytes(buf) + return err +} + +// Bytes returns binary representation of p +// will store X coordinate in regular form and a parity bit +// as we have less than 3 bits available in our coordinate, we can't follow BLS12-381 style encoding (ZCash/IETF) +// +// we use the 2 most significant bits instead +// +// 00 -> uncompressed +// 10 -> compressed, use smallest lexicographically square root of Y^2 +// 11 -> compressed, use largest lexicographically square root of Y^2 +// 01 -> compressed infinity point +// the "uncompressed infinity point" will just have 00 (uncompressed) followed by zeroes (infinity = 0,0 in affine coordinates) +func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { + + // check if p is infinity point + if p.X.IsZero() && p.Y.IsZero() { + res[0] = mCompressedInfinity + 
return + } + + msbMask := mCompressedSmallest + // compressed, we need to know if Y is lexicographically bigger than -Y + // if p.Y ">" -p.Y + if p.Y.LexicographicallyLargest() { + msbMask = mCompressedLargest + } + + // we store X and mask the most significant word with our metadata mask + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) + + res[0] |= msbMask + + return +} + +// RawBytes returns binary representation of p (stores X and Y coordinate) +// see Bytes() for a compressed representation +func (p *G1Affine) RawBytes() (res [SizeOfG1AffineUncompressed]byte) { + + // check if p is infinity point + if p.X.IsZero() && p.Y.IsZero() { + + res[0] = mUncompressed + + return + } + + // not compressed + // we store the Y coordinate + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[32:32+fp.Bytes]), p.Y) + + // we store X and mask the most significant word with our metadata mask + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) + + res[0] |= mUncompressed + + return +} + +// SetBytes sets p from binary representation in buf and returns number of consumed bytes +// +// bytes in buf must match either RawBytes() or Bytes() output +// +// if buf is too short io.ErrShortBuffer is returned +// +// if buf contains compressed representation (output from Bytes()) and we're unable to compute +// the Y coordinate (i.e the square root doesn't exist) this function returns an error +// +// this check if the resulting point is on the curve and in the correct subgroup +func (p *G1Affine) SetBytes(buf []byte) (int, error) { + return p.setBytes(buf, true) +} + +func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { + if len(buf) < SizeOfG1AffineCompressed { + return 0, io.ErrShortBuffer + } + + // most significant byte + mData := buf[0] & mMask + + // check buffer size + if mData == mUncompressed { + if len(buf) < SizeOfG1AffineUncompressed { + return 0, io.ErrShortBuffer + } + } + + // infinity encoded, we still check that the 
buffer is full of zeroes. + if mData == mCompressedInfinity { + if !isZeroed(buf[0] & ^mMask, buf[1:SizeOfG1AffineCompressed]) { + return 0, ErrInvalidInfinityEncoding + } + p.X.SetZero() + p.Y.SetZero() + return SizeOfG1AffineCompressed, nil + } + + // uncompressed point + if mData == mUncompressed { + // read X and Y coordinates + if err := p.X.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.Y.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } + + // subgroup check + if subGroupCheck && !p.IsInSubGroup() { + return 0, errors.New("invalid point: subgroup check failed") + } + + return SizeOfG1AffineUncompressed, nil + } + + // we have a compressed coordinate + // we need to + // 1. copy the buffer (to keep this method thread safe) + // 2. we need to solve the curve equation to compute Y + + var bufX [fp.Bytes]byte + copy(bufX[:fp.Bytes], buf[:fp.Bytes]) + bufX[0] &= ^mMask + + // read X coordinate + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } + + var YSquared, Y fp.Element + + YSquared.Square(&p.X).Mul(&YSquared, &p.X) + YSquared.Add(&YSquared, &bCurveCoeff) + if Y.Sqrt(&YSquared) == nil { + return 0, errors.New("invalid compressed coordinate: square root doesn't exist") + } + + if Y.LexicographicallyLargest() { + // Y ">" -Y + if mData == mCompressedSmallest { + Y.Neg(&Y) + } + } else { + // Y "<=" -Y + if mData == mCompressedLargest { + Y.Neg(&Y) + } + } + + p.Y = Y + + // subgroup check + if subGroupCheck && !p.IsInSubGroup() { + return 0, errors.New("invalid point: subgroup check failed") + } + + return SizeOfG1AffineCompressed, nil +} + +// unsafeComputeY called by Decoder when processing slices of compressed point in parallel (step 2) +// it computes the Y coordinate from the already set X coordinate and is compute intensive +func (p *G1Affine) unsafeComputeY(subGroupCheck bool) error { + // stored in unsafeSetCompressedBytes + + mData := byte(p.Y[0]) + + // 
we have a compressed coordinate, we need to solve the curve equation to compute Y + var YSquared, Y fp.Element + + YSquared.Square(&p.X).Mul(&YSquared, &p.X) + YSquared.Add(&YSquared, &bCurveCoeff) + if Y.Sqrt(&YSquared) == nil { + return errors.New("invalid compressed coordinate: square root doesn't exist") + } + + if Y.LexicographicallyLargest() { + // Y ">" -Y + if mData == mCompressedSmallest { + Y.Neg(&Y) + } + } else { + // Y "<=" -Y + if mData == mCompressedLargest { + Y.Neg(&Y) + } + } + + p.Y = Y + + // subgroup check + if subGroupCheck && !p.IsInSubGroup() { + return errors.New("invalid point: subgroup check failed") + } + + return nil +} + +// unsafeSetCompressedBytes is called by Decoder when processing slices of compressed point in parallel (step 1) +// assumes buf[:8] mask is set to compressed +// returns true if point is infinity and need no further processing +// it sets X coordinate and uses Y for scratch space to store decompression metadata +func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { + + // read the most significant byte + mData := buf[0] & mMask + + if mData == mCompressedInfinity { + isInfinity = true + if !isZeroed(buf[0] & ^mMask, buf[1:SizeOfG1AffineCompressed]) { + return isInfinity, ErrInvalidInfinityEncoding + } + p.X.SetZero() + p.Y.SetZero() + return isInfinity, nil + } + + // we need to copy the input buffer (to keep this method thread safe) + var bufX [fp.Bytes]byte + copy(bufX[:fp.Bytes], buf[:fp.Bytes]) + bufX[0] &= ^mMask + + // read X coordinate + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } + // store mData in p.Y[0] + p.Y[0] = uint64(mData) + + // recomputing Y will be done asynchronously + return isInfinity, nil +} diff --git a/ecc/grumpkin/marshal_test.go b/ecc/grumpkin/marshal_test.go new file mode 100644 index 0000000000..5378c6f148 --- /dev/null +++ b/ecc/grumpkin/marshal_test.go @@ -0,0 +1,327 @@ +// Copyright 2020-2025 Consensys Software 
Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package grumpkin + +import ( + "bytes" + crand "crypto/rand" + "io" + "math/big" + "math/rand/v2" + "reflect" + "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" + + "github.com/consensys/gnark-crypto/ecc/grumpkin/fp" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" +) + +const ( + nbFuzzShort = 10 + nbFuzz = 100 +) + +func TestEncoder(t *testing.T) { + t.Parallel() + // TODO need proper fuzz testing here + + var inA uint64 + var inB fr.Element + var inC fp.Element + var inD G1Affine + var inE G1Affine + var inG []G1Affine + var inI []fp.Element + var inJ []fr.Element + var inK fr.Vector + var inL [][]fr.Element + var inM [][]uint64 + var inN [][][]fr.Element + + // set values of inputs + inA = rand.Uint64() //#nosec G404 weak rng is fine here + inB.SetRandom() + inC.SetRandom() + inD.ScalarMultiplication(&g1GenAff, new(big.Int).SetUint64(rand.Uint64())) //#nosec G404 weak rng is fine here + // inE --> infinity + inG = make([]G1Affine, 2) + inG[1] = inD + inI = make([]fp.Element, 3) + inI[2] = inD.X + inJ = make([]fr.Element, 0) + inK = make(fr.Vector, 42) + inK[41].SetUint64(42) + inL = [][]fr.Element{inJ, inK} + inM = [][]uint64{{1, 2}, {4}, {}} + inN = make([][][]fr.Element, 4) + for i := 0; i < 4; i++ { + inN[i] = make([][]fr.Element, i+2) + for j := 0; j < i+2; j++ { + inN[i][j] = make([]fr.Element, j+3) + for k := 0; k < j+3; k++ { + inN[i][j][k].SetRandom() + } + } + } + + // encode them, compressed and raw + var buf, bufRaw bytes.Buffer + enc := NewEncoder(&buf) + encRaw := NewEncoder(&bufRaw, RawEncoding()) + toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, inG, inI, inJ, inK, inL, inM, inN} + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + t.Fatal(err) + } + if err := encRaw.Encode(v); err != nil { + t.Fatal(err) + } + } + + testDecode := 
func(t *testing.T, r io.Reader, n int64) { + dec := NewDecoder(r) + var outA uint64 + var outB fr.Element + var outC fp.Element + var outD G1Affine + var outE G1Affine + outE.X.SetOne() + outE.Y.SetUint64(42) + var outG []G1Affine + var outI []fp.Element + var outJ []fr.Element + var outK fr.Vector + var outL [][]fr.Element + var outM [][]uint64 + var outN [][][]fr.Element + + toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outG, &outI, &outJ, &outK, &outL, &outM, &outN} + for _, v := range toDecode { + if err := dec.Decode(v); err != nil { + t.Fatal(err) + } + } + + // compare values + if inA != outA { + t.Fatal("didn't encode/decode uint64 value properly") + } + + if !inB.Equal(&outB) || !inC.Equal(&outC) { + t.Fatal("decode(encode(Element) failed") + } + if !inD.Equal(&outD) || !inE.Equal(&outE) { + t.Fatal("decode(encode(G1Affine) failed") + } + for i := 0; i < len(inG); i++ { + if !inG[i].Equal(&outG[i]) { + t.Fatal("decode(encode(slice(points))) failed") + } + } + if (len(inI) != len(outI)) || (len(inJ) != len(outJ)) { + t.Fatal("decode(encode(slice(elements))) failed") + } + for i := 0; i < len(inI); i++ { + if !inI[i].Equal(&outI[i]) { + t.Fatal("decode(encode(slice(elements))) failed") + } + } + if !reflect.DeepEqual(inK, outK) { + t.Fatal("decode(encode(vector)) failed") + } + if !reflect.DeepEqual(inL, outL) { + t.Fatal("decode(encode(slice²(elements))) failed") + } + if !reflect.DeepEqual(inM, outM) { + t.Fatal("decode(encode(slice²(uint64))) failed") + } + if !reflect.DeepEqual(inN, outN) { + t.Fatal("decode(encode(slice^{3}(uint64))) failed") + } + if n != dec.BytesRead() { + t.Fatal("bytes read don't match bytes written") + } + } + + // decode them + testDecode(t, &buf, enc.BytesWritten()) + testDecode(t, &bufRaw, encRaw.BytesWritten()) + +} + +func TestIsCompressed(t *testing.T) { + t.Parallel() + var g1Inf, g1 G1Affine + + g1 = g1GenAff + + { + b := g1Inf.Bytes() + if !isCompressed(b[0]) { + t.Fatal("g1Inf.Bytes() should be 
compressed") + } + } + + { + b := g1Inf.RawBytes() + if isCompressed(b[0]) { + t.Fatal("g1Inf.RawBytes() should be uncompressed") + } + } + + { + b := g1.Bytes() + if !isCompressed(b[0]) { + t.Fatal("g1.Bytes() should be compressed") + } + } + + { + b := g1.RawBytes() + if isCompressed(b[0]) { + t.Fatal("g1.RawBytes() should be uncompressed") + } + } + +} + +func TestG1AffineSerialization(t *testing.T) { + t.Parallel() + // test round trip serialization of infinity + { + // compressed + { + var p1, p2 G1Affine + p2.X.SetRandom() + p2.Y.SetRandom() + buf := p1.Bytes() + n, err := p2.SetBytes(buf[:]) + if err != nil { + t.Fatal(err) + } + if n != SizeOfG1AffineCompressed { + t.Fatal("invalid number of bytes consumed in buffer") + } + if !(p2.X.IsZero() && p2.Y.IsZero()) { + t.Fatal("deserialization of uncompressed infinity point is not infinity") + } + } + + // uncompressed + { + var p1, p2 G1Affine + p2.X.SetRandom() + p2.Y.SetRandom() + buf := p1.RawBytes() + n, err := p2.SetBytes(buf[:]) + if err != nil { + t.Fatal(err) + } + if n != SizeOfG1AffineUncompressed { + t.Fatal("invalid number of bytes consumed in buffer") + } + if !(p2.X.IsZero() && p2.Y.IsZero()) { + t.Fatal("deserialization of uncompressed infinity point is not infinity") + } + } + } + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + properties.Property("[G1] Affine SetBytes(RawBytes) should stay the same", prop.ForAll( + func(a fp.Element) bool { + var start, end G1Affine + var ab big.Int + a.BigInt(&ab) + start.ScalarMultiplication(&g1GenAff, &ab) + + buf := start.RawBytes() + n, err := end.SetBytes(buf[:]) + if err != nil { + return false + } + if n != SizeOfG1AffineUncompressed { + return false + } + return start.X.Equal(&end.X) && start.Y.Equal(&end.Y) + }, + GenFp(), + )) + + properties.Property("[G1] Affine 
SetBytes(Bytes()) should stay the same", prop.ForAll( + func(a fp.Element) bool { + var start, end G1Affine + var ab big.Int + a.BigInt(&ab) + start.ScalarMultiplication(&g1GenAff, &ab) + + buf := start.Bytes() + n, err := end.SetBytes(buf[:]) + if err != nil { + return false + } + if n != SizeOfG1AffineCompressed { + return false + } + return start.X.Equal(&end.X) && start.Y.Equal(&end.Y) + }, + GenFp(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +// define Gopters generators + +// GenFr generates an Fr element +func GenFr() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var elmt fr.Element + + if _, err := elmt.SetRandom(); err != nil { + panic(err) + } + + return gopter.NewGenResult(elmt, gopter.NoShrinker) + } +} + +// GenFp generates an Fp element +func GenFp() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var elmt fp.Element + + if _, err := elmt.SetRandom(); err != nil { + panic(err) + } + + return gopter.NewGenResult(elmt, gopter.NoShrinker) + } +} + +// GenBigInt generates a big.Int +func GenBigInt() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var s big.Int + var b [fp.Bytes]byte + _, err := crand.Read(b[:]) + if err != nil { + panic(err) + } + s.SetBytes(b[:]) + genResult := gopter.NewGenResult(s, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/grumpkin/multiexp.go b/ecc/grumpkin/multiexp.go new file mode 100644 index 0000000000..68c0b9825f --- /dev/null +++ b/ecc/grumpkin/multiexp.go @@ -0,0 +1,531 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package grumpkin + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" + "github.com/consensys/gnark-crypto/internal/parallel" + "math" + "runtime" +) + +// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf +// +// This call returns an error if len(scalars) != len(points) or if provided config is invalid. +func (p *G1Affine) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G1Affine, error) { + var _p G1Jac + if _, err := _p.MultiExp(points, scalars, config); err != nil { + return nil, err + } + p.FromJacobian(&_p) + return p, nil +} + +// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf +// +// This call returns an error if len(scalars) != len(points) or if provided config is invalid. +func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G1Jac, error) { + // TODO @gbotrel replace the ecc.MultiExpConfig by an Option pattern for maintainability. + // note: + // each of the msmCX methods is the same, except for the c constant it declares + // duplicating (through template generation) these methods allows to declare the buckets on the stack + // the choice of c needs to be improved: + // there is a theoretical value that gives optimal asymptotics + // but in practice, other factors come into play, including: + // * if c doesn't divide 64, the word size, then we're bound to select bits over 2 words of our scalars, instead of 1 + // * number of CPUs + // * cache friendliness (which depends on the host, G1 or G2... ) + // --> for example, on BN254, a G1 point fits into one cache line of 64bytes, but a G2 point doesn't. 
+ + // for each msmCX + // step 1 + // we compute, for each scalar over c-bit wide windows, nbChunk digits + // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and subtract + // 2^{c} from the current digit, making it negative. + // negative digits will be processed in the next step as adding -G into the bucket instead of G + // (computing -G is cheap, and this saves us half of the buckets) + // step 2 + // buckets are declared on the stack + // notice that we have 2^{c-1} buckets instead of 2^{c} (see step1) + // we use jacobian extended formulas here as they are faster than mixed addition + // msmProcessChunk places points into buckets based on their selector and returns the weighted bucket sum in the given channel + // step 3 + // reduce the buckets' weighted sums into our result (msmReduceChunk) + + // ensure len(points) == len(scalars) + nbPoints := len(points) + if nbPoints != len(scalars) { + return nil, errors.New("len(points) != len(scalars)") + } + + // if nbTasks is not set, use all available CPUs + if config.NbTasks <= 0 { + config.NbTasks = runtime.NumCPU() * 2 + } else if config.NbTasks > 1024 { + return nil, errors.New("invalid config: config.NbTasks > 1024") + } + + // here, we compute the best C for nbPoints + // we split recursively until nbChunks(c) >= nbTasks, + bestC := func(nbPoints int) uint64 { + // implemented msmC methods (the c we use must be in this slice) + implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + var C uint64 + // approximate cost (in group operations) + // cost = bits/c * (nbPoints + 2^{c}) + // this needs to be verified empirically. 
+ // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results + min := math.MaxFloat64 + for _, c := range implementedCs { + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) + cost := float64(cc) / float64(c) + if cost < min { + min = cost + C = c + } + } + return C + } + + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // should we recursively split the msm in half? (see below) + // we want to minimize the execution time of the algorithm; + // splitting the msm will **add** operations, but if it allows to use more CPU, it might be worth it. + + // costFunction returns a metric that represent the "wall time" of the algorithm + costFunction := func(nbTasks, nbCpus, costPerTask int) int { + // cost for the reduction of all tasks (msmReduceChunk) + totalCost := nbTasks + + // cost for the computation of each task (msmProcessChunk) + for nbTasks >= nbCpus { + nbTasks -= nbCpus + totalCost += costPerTask + } + if nbTasks > 0 { + totalCost += costPerTask + } + return totalCost + } + + // costPerTask is the approximate number of group ops per task + costPerTask := func(c uint64, nbPoints int) int { return (nbPoints + int((1 << c))) } + + costPreSplit := costFunction(nbChunks, config.NbTasks, costPerTask(C, nbPoints)) + + cPostSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cPostSplit)) + costPostSplit := costFunction(nbChunksPostSplit*2, config.NbTasks, costPerTask(cPostSplit, nbPoints/2)) + + // if the cost of the split msm is lower than the cost of the non split msm, we split + if costPostSplit < costPreSplit { + config.NbTasks = int(math.Ceil(float64(config.NbTasks) / 2.0)) + var _p G1Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil + } + + // if we don't split, we use the best C we found + 
_innerMsmG1(p, C, points, scalars, config) + + return p, nil +} + +func _innerMsmG1(p *G1Jac, c uint64, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { + // partition the scalars + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) + } + + // we use a semaphore to limit the number of go routines running concurrently + // (only if nbTasks < nbCPU) + var sem chan struct{} + if config.NbTasks < runtime.NumCPU() { + // we add nbChunks because if chunk is overweight we split it in two + sem = make(chan struct{}, config.NbTasks+int(nbChunks)) + for i := 0; i < config.NbTasks; i++ { + sem <- struct{}{} + } + defer func() { + close(sem) + }() + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG1(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG1(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g1JacExtended, 2) + split := n / 2 + + if sem != nil { + sem <- struct{}{} // add another token to the semaphore, since we split in two. 
+ } + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split], sem) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n], sem) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n], sem) + } + + return msmReduceChunkG1Affine(p, int(c), chChunks[:]) +} + +// getChunkProcessorG1 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG1(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g1JacExtended, c uint64, points []G1Affine, digits []uint16, sem chan struct{}) { + switch c { + + case 2: + return processChunkG1Jacobian[bucketg1JacExtendedC2] + case 3: + return processChunkG1Jacobian[bucketg1JacExtendedC3] + case 4: + return processChunkG1Jacobian[bucketg1JacExtendedC4] + case 5: + return processChunkG1Jacobian[bucketg1JacExtendedC5] + case 6: + return processChunkG1Jacobian[bucketg1JacExtendedC6] + case 7: + return processChunkG1Jacobian[bucketg1JacExtendedC7] + case 8: + return processChunkG1Jacobian[bucketg1JacExtendedC8] + case 9: + return processChunkG1Jacobian[bucketg1JacExtendedC9] + case 10: + const batchSize = 80 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC10] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC10, bucketG1AffineC10, bitSetC10, pG1AffineC10, ppG1AffineC10, qG1AffineC10, cG1AffineC10] + case 11: + const batchSize = 150 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. 
+ if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC11] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC11, bucketG1AffineC11, bitSetC11, pG1AffineC11, ppG1AffineC11, qG1AffineC11, cG1AffineC11] + case 12: + const batchSize = 200 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC12] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC12, bucketG1AffineC12, bitSetC12, pG1AffineC12, ppG1AffineC12, qG1AffineC12, cG1AffineC12] + case 13: + const batchSize = 350 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC13] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC13, bucketG1AffineC13, bitSetC13, pG1AffineC13, ppG1AffineC13, qG1AffineC13, cG1AffineC13] + case 14: + const batchSize = 400 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC14] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC14, bucketG1AffineC14, bitSetC14, pG1AffineC14, ppG1AffineC14, qG1AffineC14, cG1AffineC14] + case 15: + const batchSize = 500 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. 
+ if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC15] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC15, bucketG1AffineC15, bitSetC15, pG1AffineC15, ppG1AffineC15, qG1AffineC15, cG1AffineC15] + case 16: + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC16] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC16, bucketG1AffineC16, bitSetC16, pG1AffineC16, ppG1AffineC16, qG1AffineC16, cG1AffineC16] + default: + // panic("will not happen c != previous values is not generated by templates") + return processChunkG1Jacobian[bucketg1JacExtendedC16] + } +} + +// msmReduceChunkG1Affine reduces the weighted sum of the buckets into the result of the multiExp +func msmReduceChunkG1Affine(p *G1Jac, c int, chChunks []chan g1JacExtended) *G1Jac { + var _p g1JacExtended + totalj := <-chChunks[len(chChunks)-1] + _p.Set(&totalj) + for j := len(chChunks) - 2; j >= 0; j-- { + for l := 0; l < c; l++ { + _p.double(&_p) + } + totalj := <-chChunks[j] + _p.add(&totalj) + } + + return p.unsafeFromJacExtended(&_p) +} + +// Fold computes the multi-exponentiation \sum_{i=0}^{len(points)-1} points[i] * +// combinationCoeff^i and stores the result in p. It returns error in case +// configuration is invalid. +func (p *G1Affine) Fold(points []G1Affine, combinationCoeff fr.Element, config ecc.MultiExpConfig) (*G1Affine, error) { + var _p G1Jac + if _, err := _p.Fold(points, combinationCoeff, config); err != nil { + return nil, err + } + p.FromJacobian(&_p) + return p, nil +} + +// Fold computes the multi-exponentiation \sum_{i=0}^{len(points)-1} points[i] * +// combinationCoeff^i and stores the result in p. 
It returns an error in case +// configuration is invalid. +func (p *G1Jac) Fold(points []G1Affine, combinationCoeff fr.Element, config ecc.MultiExpConfig) (*G1Jac, error) { + scalars := make([]fr.Element, len(points)) + scalar := fr.NewElement(1) + for i := 0; i < len(points); i++ { + scalars[i].Set(&scalar) + scalar.Mul(&scalar, &combinationCoeff) + } + return p.MultiExp(points, scalars, config) +} + +// selector stores the index, mask and shifts needed to select bits from a scalar +// it is used during the multiExp algorithm or the batch scalar multiplication +type selector struct { + index uint64 // index in the multi-word scalar to select bits from + mask uint64 // mask (c-bit wide) + shift uint64 // shift needed to get our bits on low positions + + multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) + maskHigh uint64 // same as mask, for index+1 + shiftHigh uint64 // same as shift, for index+1 +} + +// return number of chunks for a given window size c +// the last chunk may be bigger to accommodate a potential carry from the NAF decomposition +func computeNbChunks(c uint64) uint64 { + return (fr.Bits + c - 1) / c +} + +// return the last window size for a scalar; +// this last window should accommodate a carry (from the NAF decomposition) +// it can be == c if we have 1 available bit +// it can be > c if we have 0 available bits +// it can be < c if we have 2+ available bits +func lastC(c uint64) uint64 { + nbAvailableBits := (computeNbChunks(c) * c) - fr.Bits + return c + 1 - nbAvailableBits +} + +type chunkStat struct { + // relative weight of work compared to other chunks. 100.0 -> nominal weight. 
+ weight float32 + + // percentage of bucket filled in the window; + ppBucketFilled float32 + nbBucketFilled int +} + +// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits +// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and subtract +// 2^{c} to the current digit, making it negative. +// negative digits can be processed in a later step as adding -G into the bucket instead of G +// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) +func partitionScalars(scalars []fr.Element, c uint64, nbTasks int) ([]uint16, []chunkStat) { + // no benefit here to have more tasks than CPUs + if nbTasks > runtime.NumCPU() { + nbTasks = runtime.NumCPU() + } + + // number of c-bit radixes in a scalar + nbChunks := computeNbChunks(c) + + digits := make([]uint16, len(scalars)*int(nbChunks)) + + mask := uint64((1 << c) - 1) // low c bits are 1 + max := int(1<<(c-1)) - 1 // max value (inclusive) we want for our digits + cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words + + // compute offset and word selector / shift to select the right bits of our windows + selectors := make([]selector, nbChunks) + for chunk := uint64(0); chunk < nbChunks; chunk++ { + jc := uint64(chunk * c) + d := selector{} + d.index = jc / 64 + d.shift = jc - (d.index * 64) + d.mask = mask << d.shift + d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) + if d.multiWordSelect { + nbBitsHigh := d.shift - uint64(64-c) + d.maskHigh = (1 << nbBitsHigh) - 1 + d.shiftHigh = (c - nbBitsHigh) + } + selectors[chunk] = d + } + + parallel.Execute(len(scalars), func(start, end int) { + for i := start; i < end; i++ { + if scalars[i].IsZero() { + // everything is 0, no need to process this scalar + continue + } + scalar := scalars[i].Bits() + + var carry int + + // for each chunk in the scalar, compute the current digit, and an eventual 
carry + for chunk := uint64(0); chunk < nbChunks-1; chunk++ { + s := selectors[chunk] + + // init with carry if any + digit := carry + carry = 0 + + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) + + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } + + // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and subtract + // 2^{c} to the current digit, making it negative. + if digit > max { + digit -= (1 << c) + carry = 1 + } + + // if digit is zero, no impact on result + if digit == 0 { + continue + } + + var bits uint16 + if digit > 0 { + bits = uint16(digit) << 1 + } else { + bits = (uint16(-digit-1) << 1) + 1 + } + digits[int(chunk)*len(scalars)+i] = bits + } + + // for the last chunk, we don't want to borrow from a next window + // (but may have a larger max value) + chunk := nbChunks - 1 + s := selectors[chunk] + // init with carry if any + digit := carry + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } + digits[int(chunk)*len(scalars)+i] = uint16(digit) << 1 + } + + }, nbTasks) + + // aggregate chunk stats + chunkStats := make([]chunkStat, nbChunks) + if c <= 9 { + // no need to compute stats for small window sizes + return digits, chunkStats + } + parallel.Execute(len(chunkStats), func(start, end int) { + // for each chunk compute the statistics + for chunkID := start; chunkID < end; chunkID++ { + // indicates if a bucket is hit. 
+ var b bitSetC16 + + // digits for the chunk + chunkDigits := digits[chunkID*len(scalars) : (chunkID+1)*len(scalars)] + + totalOps := 0 + nz := 0 // non zero buckets count + for _, digit := range chunkDigits { + if digit == 0 { + continue + } + totalOps++ + bucketID := digit >> 1 + if digit&1 == 0 { + bucketID -= 1 + } + if !b[bucketID] { + nz++ + b[bucketID] = true + } + } + chunkStats[chunkID].weight = float32(totalOps) // count number of ops for now, we will compute the weight after + chunkStats[chunkID].ppBucketFilled = (float32(nz) * 100.0) / float32(int(1<<(c-1))) + chunkStats[chunkID].nbBucketFilled = nz + } + }, nbTasks) + + totalOps := float32(0.0) + for _, stat := range chunkStats { + totalOps += stat.weight + } + + target := totalOps / float32(nbChunks) + if target != 0.0 { + // if target == 0, it means all the scalars are 0 everywhere, there is no work to be done. + for i := 0; i < len(chunkStats); i++ { + chunkStats[i].weight = (chunkStats[i].weight * 100.0) / target + } + } + + return digits, chunkStats +} diff --git a/ecc/grumpkin/multiexp_affine.go b/ecc/grumpkin/multiexp_affine.go new file mode 100644 index 0000000000..601f4052f6 --- /dev/null +++ b/ecc/grumpkin/multiexp_affine.go @@ -0,0 +1,371 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package grumpkin + +import ( + "github.com/consensys/gnark-crypto/ecc/grumpkin/fp" +) + +type batchOpG1Affine struct { + bucketID uint16 + point G1Affine +} + +// processChunkG1BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. 
+// +// this is derived from a PR by 0x0ece : https://github.com/Consensys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG1BatchAffine[BJE ibg1JacExtended, B ibG1Affine, BS bitSet, TP pG1Affine, TPP ppG1Affine, TQ qOpsG1Affine, TC cG1Affine]( + chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16, + sem chan struct{}) { + + if sem != nil { + // if we are limited, wait for a token in the semaphore + <-sem + } + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g1JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediately + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. 
+ + // note that we have 2 sets of buckets + // 1 in G1Affine used with the batch affine additions + // 1 in g1JacExtended used in case the queue of conflicting points + var buckets B // in G1Affine coordinates, infinity point is represented as (0,0), no need to init + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + bucketsJE[i].SetInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG1Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG1Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. 
+ BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.SetInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G1Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.SetInfinity() + } + return + } + if isAdd { + BK.SetInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. 
+ if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g1JacExtended + runningSum.SetInfinity() + total.SetInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].IsInfinity() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + if sem != nil { + // release a token to the semaphore + // before sending to chRes + sem <- struct{}{} + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG1AffineC10 [512]G1Affine +type bucketG1AffineC11 [1024]G1Affine +type bucketG1AffineC12 [2048]G1Affine +type bucketG1AffineC13 [4096]G1Affine +type bucketG1AffineC14 [8192]G1Affine +type bucketG1AffineC15 [16384]G1Affine +type bucketG1AffineC16 [32768]G1Affine + +// buckets: array of G1Affine points of size 1 << (c-1) +type ibG1Affine interface { + bucketG1AffineC10 | + bucketG1AffineC11 | + bucketG1AffineC12 | + bucketG1AffineC13 | + bucketG1AffineC14 | + bucketG1AffineC15 | + bucketG1AffineC16 +} + +// array of coordinates fp.Element +type cG1Affine interface { + cG1AffineC10 | + cG1AffineC11 | + cG1AffineC12 | + cG1AffineC13 | + cG1AffineC14 | + cG1AffineC15 | + cG1AffineC16 +} + +// buckets: array of G1Affine points (for the batch addition) +type pG1Affine interface { + pG1AffineC10 | + pG1AffineC11 | + pG1AffineC12 | + pG1AffineC13 | + pG1AffineC14 | + pG1AffineC15 | + pG1AffineC16 +} + +// buckets: array of *G1Affine points (for the batch addition) +type ppG1Affine interface { + ppG1AffineC10 | + ppG1AffineC11 | + ppG1AffineC12 | + ppG1AffineC13 | + ppG1AffineC14 | 
+ ppG1AffineC15 | + ppG1AffineC16 +} + +// buckets: array of G1Affine queue operations (for the batch addition) +type qOpsG1Affine interface { + qG1AffineC10 | + qG1AffineC11 | + qG1AffineC12 | + qG1AffineC13 | + qG1AffineC14 | + qG1AffineC15 | + qG1AffineC16 +} + +// batch size 80 when c = 10 +type cG1AffineC10 [80]fp.Element +type pG1AffineC10 [80]G1Affine +type ppG1AffineC10 [80]*G1Affine +type qG1AffineC10 [80]batchOpG1Affine + +// batch size 150 when c = 11 +type cG1AffineC11 [150]fp.Element +type pG1AffineC11 [150]G1Affine +type ppG1AffineC11 [150]*G1Affine +type qG1AffineC11 [150]batchOpG1Affine + +// batch size 200 when c = 12 +type cG1AffineC12 [200]fp.Element +type pG1AffineC12 [200]G1Affine +type ppG1AffineC12 [200]*G1Affine +type qG1AffineC12 [200]batchOpG1Affine + +// batch size 350 when c = 13 +type cG1AffineC13 [350]fp.Element +type pG1AffineC13 [350]G1Affine +type ppG1AffineC13 [350]*G1Affine +type qG1AffineC13 [350]batchOpG1Affine + +// batch size 400 when c = 14 +type cG1AffineC14 [400]fp.Element +type pG1AffineC14 [400]G1Affine +type ppG1AffineC14 [400]*G1Affine +type qG1AffineC14 [400]batchOpG1Affine + +// batch size 500 when c = 15 +type cG1AffineC15 [500]fp.Element +type pG1AffineC15 [500]G1Affine +type ppG1AffineC15 [500]*G1Affine +type qG1AffineC15 [500]batchOpG1Affine + +// batch size 640 when c = 16 +type cG1AffineC16 [640]fp.Element +type pG1AffineC16 [640]G1Affine +type ppG1AffineC16 [640]*G1Affine +type qG1AffineC16 [640]batchOpG1Affine + +type bitSetC2 [2]bool +type bitSetC3 [4]bool +type bitSetC4 [8]bool +type bitSetC5 [16]bool +type bitSetC6 [32]bool +type bitSetC7 [64]bool +type bitSetC8 [128]bool +type bitSetC9 [256]bool +type bitSetC10 [512]bool +type bitSetC11 [1024]bool +type bitSetC12 [2048]bool +type bitSetC13 [4096]bool +type bitSetC14 [8192]bool +type bitSetC15 [16384]bool +type bitSetC16 [32768]bool + +type bitSet interface { + bitSetC2 | + bitSetC3 | + bitSetC4 | + bitSetC5 | + bitSetC6 | + bitSetC7 | + bitSetC8 | + 
bitSetC9 | + bitSetC10 | + bitSetC11 | + bitSetC12 | + bitSetC13 | + bitSetC14 | + bitSetC15 | + bitSetC16 +} diff --git a/ecc/grumpkin/multiexp_jacobian.go b/ecc/grumpkin/multiexp_jacobian.go new file mode 100644 index 0000000000..79f5ddaf0a --- /dev/null +++ b/ecc/grumpkin/multiexp_jacobian.go @@ -0,0 +1,97 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package grumpkin + +func processChunkG1Jacobian[B ibg1JacExtended](chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16, + sem chan struct{}) { + + if sem != nil { + // if we are limited, wait for a token in the semaphore + <-sem + } + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].SetInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to subtract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... 
+ n*bucket[n-1] + + var runningSum, total g1JacExtended + runningSum.SetInfinity() + total.SetInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].IsInfinity() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + if sem != nil { + // release a token to the semaphore + // before sending to chRes + sem <- struct{}{} + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg1JacExtendedC2 [2]g1JacExtended +type bucketg1JacExtendedC3 [4]g1JacExtended +type bucketg1JacExtendedC4 [8]g1JacExtended +type bucketg1JacExtendedC5 [16]g1JacExtended +type bucketg1JacExtendedC6 [32]g1JacExtended +type bucketg1JacExtendedC7 [64]g1JacExtended +type bucketg1JacExtendedC8 [128]g1JacExtended +type bucketg1JacExtendedC9 [256]g1JacExtended +type bucketg1JacExtendedC10 [512]g1JacExtended +type bucketg1JacExtendedC11 [1024]g1JacExtended +type bucketg1JacExtendedC12 [2048]g1JacExtended +type bucketg1JacExtendedC13 [4096]g1JacExtended +type bucketg1JacExtendedC14 [8192]g1JacExtended +type bucketg1JacExtendedC15 [16384]g1JacExtended +type bucketg1JacExtendedC16 [32768]g1JacExtended + +type ibg1JacExtended interface { + bucketg1JacExtendedC2 | + bucketg1JacExtendedC3 | + bucketg1JacExtendedC4 | + bucketg1JacExtendedC5 | + bucketg1JacExtendedC6 | + bucketg1JacExtendedC7 | + bucketg1JacExtendedC8 | + bucketg1JacExtendedC9 | + bucketg1JacExtendedC10 | + bucketg1JacExtendedC11 | + bucketg1JacExtendedC12 | + bucketg1JacExtendedC13 | + bucketg1JacExtendedC14 | + bucketg1JacExtendedC15 | + bucketg1JacExtendedC16 +} diff --git a/ecc/grumpkin/multiexp_test.go b/ecc/grumpkin/multiexp_test.go new file mode 100644 index 0000000000..708a1463fd --- /dev/null +++ b/ecc/grumpkin/multiexp_test.go @@ -0,0 +1,442 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package grumpkin + +import ( + "fmt" + "math/big" + "math/bits" + "math/rand/v2" + "runtime" + "sync" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" +) + +func TestMultiExpG1(t *testing.T) { + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 3 + } else { + parameters.MinSuccessfulTests = nbFuzzShort * 2 + } + + properties := gopter.NewProperties(parameters) + + genScalar := GenFr() + + // size of the multiExps + const nbSamples = 73 + + // multi exp points + var samplePoints [nbSamples]G1Affine + var g G1Jac + g.Set(&g1Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g1Gen) + } + + // sprinkle some points at infinity + samplePoints[rand.N(nbSamples)].SetInfinity() //#nosec G404 weak rng is fine here + samplePoints[rand.N(nbSamples)].SetInfinity() //#nosec G404 weak rng is fine here + samplePoints[rand.N(nbSamples)].SetInfinity() //#nosec G404 weak rng is fine here + samplePoints[rand.N(nbSamples)].SetInfinity() //#nosec G404 weak rng is fine here + + // final scalar to use in double and add method (without mixer factor) + // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) + var scalar big.Int + scalar.SetInt64(nbSamples) + scalar.Mul(&scalar, new(big.Int).SetInt64(nbSamples+1)) + scalar.Mul(&scalar, new(big.Int).SetInt64(2*nbSamples+1)) + scalar.Div(&scalar, new(big.Int).SetInt64(6)) + + // ensure a multiexp that's splitted has the same result as a non-splitted one.. 
+ properties.Property("[G1] Multi exponentiation (cmax) should be consistent with splitted multiexp", prop.ForAll( + func(mixer fr.Element) bool { + var samplePointsLarge [nbSamples * 13]G1Affine + for i := 0; i < 13; i++ { + copy(samplePointsLarge[i*nbSamples:], samplePoints[:]) + } + + var rmax, splitted1, splitted2 G1Jac + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples * 13]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + } + + rmax.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) + splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) + splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) + return rmax.Equal(&splitted1) && rmax.Equal(&splitted2) + }, + genScalar, + )) + + // cRange is generated from template and contains the available parameters for the multiexp window size + cRange := []uint64{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + if testing.Short() { + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange = []uint64{5, 14} + } + + properties.Property(fmt.Sprintf("[G1] Multi exponentiation (c in %v) should be consistent with sum of square", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var expected G1Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.BigInt(&mixerBigInt)) + expected.ScalarMultiplication(&g1Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). 
+ Mul(&sampleScalars[i-1], &mixer) + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 1; i < len(results); i++ { + if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentiation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G1Affine + + var expected G1Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.BigInt(&mixerBigInt)) + expected.ScalarMultiplication(&g1Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). 
+ Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].SetInfinity() + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentiation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + // note : this test is here as we expect to have a different multiExp than the above bucket method + // for small number of points + properties.Property("[G1] Multi exponentiation (<50points) should be consistent with sum of square", prop.ForAll( + func(mixer fr.Element) bool { + + var g G1Jac + g.Set(&g1Gen) + + // mixer ensures that all the words of a fpElement are set + samplePoints := make([]G1Affine, 30) + sampleScalars := make([]fr.Element, 30) + + for i := 1; i <= 30; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). 
+ Mul(&sampleScalars[i-1], &mixer) + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g1Gen) + } + + var op1MultiExp G1Affine + op1MultiExp.MultiExp(samplePoints, sampleScalars, ecc.MultiExpConfig{}) + + var finalBigScalar fr.Element + var finalBigScalarBi big.Int + var op1ScalarMul G1Affine + finalBigScalar.SetUint64(9455).Mul(&finalBigScalar, &mixer) + finalBigScalar.BigInt(&finalBigScalarBi) + op1ScalarMul.ScalarMultiplication(&g1GenAff, &finalBigScalarBi) + + return op1ScalarMul.Equal(&op1MultiExp) + }, + genScalar, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestCrossMultiExpG1(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G1Affine + var g G1Jac + g.Set(&g1Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g1Gen) + } + + // sprinkle some points at infinity + samplePoints[rand.N(nbSamples)].SetInfinity() //#nosec G404 weak rng is fine here + samplePoints[rand.N(nbSamples)].SetInfinity() //#nosec G404 weak rng is fine here + samplePoints[rand.N(nbSamples)].SetInfinity() //#nosec G404 weak rng is fine here + samplePoints[rand.N(nbSamples)].SetInfinity() //#nosec G404 weak rng is fine here + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + cRange := []uint64{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + if testing.Short() { + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange = []uint64{5, 14} + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G1Jac + _innerMsmG1Reference(&r, 
samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G1Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG1Reference always do ext jacobian with c == 15 +func _innerMsmG1Reference(p *G1Jac, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 15, config.NbTasks) + + nbChunks := computeNbChunks(15) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. 
+ n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG1Jacobian[bucketg1JacExtendedC15] + go processChunk(uint64(j), chChunks[j], 15, points, digits[j*n:(j+1)*n], nil) + } + + return msmReduceChunkG1Affine(p, int(15), chChunks[:]) +} + +func BenchmarkMultiExpG1(b *testing.B) { + + const ( + pow = (bits.UintSize / 2) - (bits.UintSize / 8) // 24 on 64 bits arch, 12 on 32 bits + nbSamples = 1 << pow + ) + + var ( + samplePoints [nbSamples]G1Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element + ) + + fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. 
+ for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + + fillBenchBasesG1(samplePoints[:]) + + var testPoint G1Affine + + for i := 5; i <= pow; i++ { + using := 1 << i + + b.Run(fmt.Sprintf("%d points", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) + } +} + +func BenchmarkMultiExpG1Reference(b *testing.B) { + const nbSamples = 1 << 20 + + var ( + samplePoints [nbSamples]G1Affine + sampleScalars [nbSamples]fr.Element + ) + + fillBenchScalars(sampleScalars[:]) + fillBenchBasesG1(samplePoints[:]) + + var testPoint G1Affine + + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{}) + } +} + +func BenchmarkManyMultiExpG1Reference(b *testing.B) { + const nbSamples = 1 << 20 + + var ( + samplePoints [nbSamples]G1Affine + sampleScalars [nbSamples]fr.Element + ) + + fillBenchScalars(sampleScalars[:]) + fillBenchBasesG1(samplePoints[:]) + + var t1, t2, t3 G1Affine + b.ResetTimer() + for j := 0; j < b.N; j++ { + var wg sync.WaitGroup + wg.Add(3) + go func() { + t1.MultiExp(samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{}) + wg.Done() + }() + go func() { + t2.MultiExp(samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{}) + wg.Done() + }() + go func() { + t3.MultiExp(samplePoints[:], 
sampleScalars[:], ecc.MultiExpConfig{}) + wg.Done() + }() + wg.Wait() + } +} + +// WARNING: this return points that are NOT on the curve and is meant to be use for benchmarking +// purposes only. We don't check that the result is valid but just measure "computational complexity". +// +// Rationale for generating points that are not on the curve is that for large benchmarks, generating +// a vector of different points can take minutes. Using the same point or subset will bias the benchmark result +// since bucket additions in extended jacobian coordinates will hit doubling algorithm instead of add. +func fillBenchBasesG1(samplePoints []G1Affine) { + var r big.Int + r.SetString("340444420969191673093399857471996460938405", 10) + samplePoints[0].ScalarMultiplication(&samplePoints[0], &r) + + one := samplePoints[0].X + one.SetOne() + + for i := 1; i < len(samplePoints); i++ { + samplePoints[i].X.Add(&samplePoints[i-1].X, &one) + samplePoints[i].Y.Sub(&samplePoints[i-1].Y, &one) + } +} + +func fillBenchScalars(sampleScalars []fr.Element) { + // ensure every words of the scalars are filled + for i := 0; i < len(sampleScalars); i++ { + sampleScalars[i].SetRandom() + } +} diff --git a/hash/hashes.go b/hash/hashes.go index 4f433c86d9..82a2992700 100644 --- a/hash/hashes.go +++ b/hash/hashes.go @@ -38,6 +38,8 @@ const ( MIMC_BLS24_317 // MIMC_BW6_633 is the MiMC hash function for the BW6-633 curve. MIMC_BW6_633 + // MIMC_GRUMPKIN is the MiMC hash function for the Grumpkin curve. + MIMC_GRUMPKIN // POSEIDON2_BLS12_377 is the Poseidon2 hash function for the BLS12-377 curve. 
POSEIDON2_BLS12_377 @@ -53,6 +55,7 @@ var digestSize = []uint8{ MIMC_BLS24_315: 48, MIMC_BLS24_317: 48, MIMC_BW6_633: 80, + MIMC_GRUMPKIN: 32, POSEIDON2_BLS12_377: 48, } @@ -90,6 +93,8 @@ func (m Hash) String() string { return "MIMC_BLS24_317" case MIMC_BW6_633: return "MIMC_BW6_633" + case MIMC_GRUMPKIN: + return "MIMC_GRUMPKIN" case POSEIDON2_BLS12_377: return "POSEIDON2_BLS12_377" default: diff --git a/internal/generator/config/grumpkin.go b/internal/generator/config/grumpkin.go new file mode 100644 index 0000000000..1545532326 --- /dev/null +++ b/internal/generator/config/grumpkin.go @@ -0,0 +1,28 @@ +package config + +var GRUMPKIN = Curve{ + Name: "grumpkin", + CurvePackage: "grumpkin", + EnumID: "GRUMPKIN", + FrModulus: "21888242871839275222246405745257275088696311157297823662689037894645226208583", + FpModulus: "21888242871839275222246405745257275088548364400416034343698204186575808495617", + G1: Point{ + CoordType: "fp.Element", + CoordExtDegree: 1, + PointName: "g1", + GLV: true, + CofactorCleaning: false, + CRange: defaultCRange(), + }, + HashE1: &HashSuiteSvdw{ + z: []string{"1"}, + c1: []string{"21888242871839275222246405745257275088548364400416034343698204186575808495601"}, + c2: []string{"10944121435919637611123202872628637544274182200208017171849102093287904247808"}, + c3: []string{"17631683881184975371348829942606096167675058198229016842588"}, + c4: []string{"14592161914559516814830937163504850059032242933610689562465469457717205663766"}, + }, +} + +func init() { + addCurve(&GRUMPKIN) +} diff --git a/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl b/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl index a50dc24c1c..b1130e2094 100644 --- a/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl +++ b/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl @@ -17,7 +17,7 @@ func init() { } const ( -{{ if eq .Name "bn254" }} +{{ if or (eq .Name "bn254") (eq .Name "grumpkin")}} mimcNbRounds = 110 {{- else if eq .Name "bls12-381"}} 
mimcNbRounds = 111 @@ -256,7 +256,7 @@ func (d *digest) WriteString(rawBytes []byte) error { // SetState manually sets the state of the hasher to an user-provided value. In // the context of MiMC, the method expects a byte slice of 32 elements. func (d *digest) SetState(newState []byte) error { - + if len(newState) != {{ .Fr.NbBytes }} { return errors.New("the mimc state expects a state of {{ .Fr.NbBytes }} bytes") } @@ -275,4 +275,4 @@ func (d *digest) State() []byte { _ = d.Sum(nil) // this flushes the hasher b := d.h.Bytes() return b[:] -} \ No newline at end of file +} diff --git a/internal/generator/ecc/generate.go b/internal/generator/ecc/generate.go index 9e29595153..30b8e383d5 100644 --- a/internal/generator/ecc/generate.go +++ b/internal/generator/ecc/generate.go @@ -51,7 +51,7 @@ func Generate(conf config.Curve, baseDir string, bgen *bavard.BatchGenerator) er {File: filepath.Join(baseDir, "g1_test.go"), Templates: []string{"tests/point.go.tmpl"}}, } // if not secp256k1, generate the lagrange transform - if conf.Name != config.SECP256K1.Name { + if conf.Name != config.SECP256K1.Name && conf.Name != config.GRUMPKIN.Name { os.Remove(filepath.Join(baseDir, "g1_lagrange.go")) os.Remove(filepath.Join(baseDir, "g1_lagrange_test.go")) } @@ -171,8 +171,8 @@ func Generate(conf config.Curve, baseDir string, bgen *bavard.BatchGenerator) er return err } - // No G2 for secp256k1 - if conf.Equal(config.SECP256K1) { + // No G2 for secp256k1 and grumpkin + if conf.Equal(config.SECP256K1) || conf.Equal(config.GRUMPKIN) { return nil } diff --git a/internal/generator/ecc/template/multiexp.go.tmpl b/internal/generator/ecc/template/multiexp.go.tmpl index 51fcec1c25..5fe0d0bd8d 100644 --- a/internal/generator/ecc/template/multiexp.go.tmpl +++ b/internal/generator/ecc/template/multiexp.go.tmpl @@ -16,11 +16,13 @@ import ( "runtime" ) -{{- if ne .Name "secp256k1"}} +{{- if eq .Name "secp256k1"}} +{{template "multiexp" dict "PointName" .G1.PointName "UPointName" (toUpper
.G1.PointName) "TAffine" $G1TAffine "TJacobian" $G1TJacobian "TJacobianExtended" $G1TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G1.CRange "cmax" 15}} +{{- else if eq .Name "grumpkin"}} {{template "multiexp" dict "PointName" .G1.PointName "UPointName" (toUpper .G1.PointName) "TAffine" $G1TAffine "TJacobian" $G1TJacobian "TJacobianExtended" $G1TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G1.CRange "cmax" 16}} -{{template "multiexp" dict "PointName" .G2.PointName "UPointName" (toUpper .G2.PointName) "TAffine" $G2TAffine "TJacobian" $G2TJacobian "TJacobianExtended" $G2TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G2.CRange "cmax" 16}} {{- else}} -{{template "multiexp" dict "PointName" .G1.PointName "UPointName" (toUpper .G1.PointName) "TAffine" $G1TAffine "TJacobian" $G1TJacobian "TJacobianExtended" $G1TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G1.CRange "cmax" 15}} +{{template "multiexp" dict "PointName" .G1.PointName "UPointName" (toUpper .G1.PointName) "TAffine" $G1TAffine "TJacobian" $G1TJacobian "TJacobianExtended" $G1TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G1.CRange "cmax" 16}} +{{template "multiexp" dict "PointName" .G2.PointName "UPointName" (toUpper .G2.PointName) "TAffine" $G2TAffine "TJacobian" $G2TJacobian "TJacobianExtended" $G2TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G2.CRange "cmax" 16}} {{- end}} @@ -313,7 +315,7 @@ func (p *{{ $.TJacobian }}) MultiExp(points []{{ $.TAffine }}, scalars []fr.Elem nbChunks := int(computeNbChunks(C)) // should we recursively split the msm in half? (see below) - // we want to minimize the execution time of the algorithm; + // we want to minimize the execution time of the algorithm; // splitting the msm will **add** operations, but if it allows to use more CPU, it might be worth it. 
// costFunction returns a metric that represent the "wall time" of the algorithm @@ -336,7 +338,7 @@ func (p *{{ $.TJacobian }}) MultiExp(points []{{ $.TAffine }}, scalars []fr.Elem costPerTask := func(c uint64, nbPoints int) int {return (nbPoints + int((1 << c)))} costPreSplit := costFunction(nbChunks, config.NbTasks, costPerTask(C, nbPoints)) - + cPostSplit := bestC(nbPoints/2) nbChunksPostSplit := int(computeNbChunks(cPostSplit)) costPostSplit := costFunction(nbChunksPostSplit * 2, config.NbTasks, costPerTask(cPostSplit, nbPoints/2)) @@ -383,7 +385,7 @@ func _innerMsm{{ $.UPointName }}(p *{{ $.TJacobian }}, c uint64, points []{{ $.T var sem chan struct{} if config.NbTasks < runtime.NumCPU() { // we add nbChunks because if chunk is overweight we split it in two - sem = make(chan struct{}, config.NbTasks + int(nbChunks)) + sem = make(chan struct{}, config.NbTasks + int(nbChunks)) for i:=0; i < config.NbTasks; i++ { sem <- struct{}{} } diff --git a/internal/generator/ecc/template/multiexp_affine.go.tmpl b/internal/generator/ecc/template/multiexp_affine.go.tmpl index 1c0bf7d325..666eb153fc 100644 --- a/internal/generator/ecc/template/multiexp_affine.go.tmpl +++ b/internal/generator/ecc/template/multiexp_affine.go.tmpl @@ -9,13 +9,13 @@ import ( "github.com/consensys/gnark-crypto/ecc/{{.Name}}/fp" - {{- if and (ne .G1.CoordType .G2.CoordType) (ne .Name "secp256k1") }} + {{- if and (ne .G1.CoordType .G2.CoordType) (ne .Name "secp256k1") (ne .Name "grumpkin") }} "github.com/consensys/gnark-crypto/ecc/{{.Name}}/internal/fptower" {{- end}} ) {{ template "multiexp" dict "CoordType" .G1.CoordType "PointName" .G1.PointName "UPointName" (toUpper .G1.PointName) "TAffine" $G1TAffine "TJacobian" $G1TJacobian "TJacobianExtended" $G1TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G1.CRange}} -{{- if ne .Name "secp256k1"}} +{{- if and (ne .Name "secp256k1") (ne .Name "grumpkin")}} {{ template "multiexp" dict "CoordType" .G2.CoordType "PointName" .G2.PointName "UPointName" 
(toUpper .G2.PointName) "TAffine" $G2TAffine "TJacobian" $G2TJacobian "TJacobianExtended" $G2TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G2.CRange}} {{- end}} diff --git a/internal/generator/ecc/template/multiexp_jacobian.go.tmpl b/internal/generator/ecc/template/multiexp_jacobian.go.tmpl index cd7947345e..2bc5cb3f37 100644 --- a/internal/generator/ecc/template/multiexp_jacobian.go.tmpl +++ b/internal/generator/ecc/template/multiexp_jacobian.go.tmpl @@ -9,7 +9,7 @@ {{ template "multiexp" dict "PointName" .G1.PointName "UPointName" (toUpper .G1.PointName) "TAffine" $G1TAffine "TJacobian" $G1TJacobian "TJacobianExtended" $G1TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G1.CRange }} -{{- if ne .Name "secp256k1"}} +{{- if and (ne .Name "secp256k1") (ne .Name "grumpkin")}} {{ template "multiexp" dict "PointName" .G2.PointName "UPointName" (toUpper .G2.PointName) "TAffine" $G2TAffine "TJacobian" $G2TJacobian "TJacobianExtended" $G2TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G2.CRange }} {{- end}} diff --git a/internal/generator/ecc/template/point.go.tmpl b/internal/generator/ecc/template/point.go.tmpl index 82e60b6369..de171733e4 100644 --- a/internal/generator/ecc/template/point.go.tmpl +++ b/internal/generator/ecc/template/point.go.tmpl @@ -543,7 +543,7 @@ func (p *{{ $TJacobian }}) IsOnCurve() bool { -{{- if or (eq .Name "bn254") (eq .Name "secp256k1")}} +{{- if or (eq .Name "bn254") (eq .Name "secp256k1") (eq .Name "grumpkin")}} {{- if eq .PointName "g1"}} // IsInSubGroup returns true if p is on the r-torsion, false otherwise. // the curve is of prime order i.e. E(𝔽p) is the full group @@ -1753,7 +1753,7 @@ func batchAdd{{ $TAffine }}[TP p{{ $TAffine }}, TPP pp{{ $TAffine }}, TC c{{ $TA lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) } - // montgomery batch inversion; + // montgomery batch inversion; // lambda[0] = 1 / (P[0].X - R[0].X) // lambda[1] = 1 / (P[1].X - R[1].X) // ... 
diff --git a/internal/generator/ecc/template/tests/multiexp.go.tmpl b/internal/generator/ecc/template/tests/multiexp.go.tmpl index 912b251352..3e7fcb0754 100644 --- a/internal/generator/ecc/template/tests/multiexp.go.tmpl +++ b/internal/generator/ecc/template/tests/multiexp.go.tmpl @@ -23,7 +23,7 @@ import ( ) -{{- if ne .Name "secp256k1"}} +{{- if and (ne .Name "secp256k1") (ne .Name "grumpkin")}} {{template "multiexp" dict "PointName" .G1.PointName "UPointName" (toUpper .G1.PointName) "TAffine" $G1TAffine "TJacobian" $G1TJacobian "TJacobianExtended" $G1TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G1.CRange "cmax" 16}} {{template "multiexp" dict "PointName" .G2.PointName "UPointName" (toUpper .G2.PointName) "TAffine" $G2TAffine "TJacobian" $G2TJacobian "TJacobianExtended" $G2TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G2.CRange "cmax" 16}} {{- else}} diff --git a/internal/generator/ecdsa/template/ecdsa.go.tmpl b/internal/generator/ecdsa/template/ecdsa.go.tmpl index aa0e3868cc..eec6657c67 100644 --- a/internal/generator/ecdsa/template/ecdsa.go.tmpl +++ b/internal/generator/ecdsa/template/ecdsa.go.tmpl @@ -84,7 +84,7 @@ func GenerateKey(rand io.Reader) (*PrivateKey, error) { } - {{- if or (eq .Name "secp256k1") (eq .Name "stark-curve")}} + {{- if or (eq .Name "secp256k1") (eq .Name "stark-curve") (eq .Name "grumpkin")}} _, g := {{ .CurvePackage }}.Generators() {{- else}} _, _, g, _ := {{ .CurvePackage }}.Generators() diff --git a/internal/generator/main.go b/internal/generator/main.go index c9453b7640..940f9d76f1 100644 --- a/internal/generator/main.go +++ b/internal/generator/main.go @@ -2,12 +2,13 @@ package main import ( "fmt" - "github.com/consensys/gnark-crypto/internal/generator/mpcsetup" "os" "os/exec" "path/filepath" "sync" + "github.com/consensys/gnark-crypto/internal/generator/mpcsetup" + "github.com/consensys/bavard" "github.com/consensys/gnark-crypto/field/generator" fieldConfig "github.com/consensys/gnark-crypto/field/generator/config" @@ 
-76,7 +77,7 @@ func main() { conf.FpUnusedBits = 64 - (conf.Fp.NbBits % 64) frOpts := []generator.Option{generator.WithASM(asmConfig)} - if !(conf.Equal(config.STARK_CURVE) || conf.Equal(config.SECP256K1)) { + if !(conf.Equal(config.STARK_CURVE) || conf.Equal(config.SECP256K1) || conf.Equal(config.GRUMPKIN)) { frOpts = append(frOpts, generator.WithFFT(fftConfig)) } if conf.Equal(config.BLS12_377) { @@ -99,36 +100,6 @@ func main() { return } - // generate tower of extension - assertNoError(tower.Generate(conf, filepath.Join(curveDir, "internal", "fptower"), bgen)) - - // generate pairing tests - assertNoError(pairing.Generate(conf, curveDir, bgen)) - - // generate fri on fr - assertNoError(fri.Generate(conf, filepath.Join(curveDir, "fr", "fri"), bgen)) - - // generate mpc setup tools - assertNoError(mpcsetup.Generate(conf, filepath.Join(curveDir, "mpcsetup"), bgen)) - - // generate kzg on fr - assertNoError(kzg.Generate(conf, filepath.Join(curveDir, "kzg"), bgen)) - - // generate shplonk on fr - assertNoError(shplonk.Generate(conf, filepath.Join(curveDir, "shplonk"), bgen)) - - // generate fflonk on fr - assertNoError(fflonk.Generate(conf, filepath.Join(curveDir, "fflonk"), bgen)) - - // generate pedersen on fr - assertNoError(pedersen.Generate(conf, filepath.Join(curveDir, "fr", "pedersen"), bgen)) - - // generate plookup on fr - assertNoError(plookup.Generate(conf, filepath.Join(curveDir, "fr", "plookup"), bgen)) - - // generate permutation on fr - assertNoError(permutation.Generate(conf, filepath.Join(curveDir, "fr", "permutation"), bgen)) - // generate mimc on fr assertNoError(mimc.Generate(conf, filepath.Join(curveDir, "fr", "mimc"), bgen)) @@ -144,9 +115,6 @@ func main() { // generate polynomial on fr assertNoError(polynomial.Generate(frInfo, filepath.Join(curveDir, "fr", "polynomial"), true, bgen)) - // generate eddsa on companion curves - assertNoError(fri.Generate(conf, filepath.Join(curveDir, "fr", "fri"), bgen)) - // generate sumcheck on fr 
assertNoError(sumcheck.Generate(frInfo, filepath.Join(curveDir, "fr", "sumcheck"), bgen)) @@ -163,9 +131,6 @@ func main() { RandomizeMissingHashEntries: false, }, filepath.Join(curveDir, "fr", "test_vector_utils"), bgen)) - // generate iop functions - assertNoError(iop.Generate(conf, filepath.Join(curveDir, "fr", "iop"), bgen)) - fpInfo := config.FieldDependency{ FieldPackagePath: "github.com/consensys/gnark-crypto/ecc/" + conf.Name + "/fp", FieldPackageName: "fp", @@ -176,6 +141,46 @@ func main() { assertNoError(hash_to_field.Generate(frInfo, filepath.Join(curveDir, "fr", "hash_to_field"), bgen)) assertNoError(hash_to_field.Generate(fpInfo, filepath.Join(curveDir, "fp", "hash_to_field"), bgen)) + if conf.Equal(config.GRUMPKIN) { + return + } + + // generate pedersen on fr + assertNoError(pedersen.Generate(conf, filepath.Join(curveDir, "fr", "pedersen"), bgen)) + + // generate tower of extension + assertNoError(tower.Generate(conf, filepath.Join(curveDir, "internal", "fptower"), bgen)) + + // generate pairing tests + assertNoError(pairing.Generate(conf, curveDir, bgen)) + + // generate fri on fr + assertNoError(fri.Generate(conf, filepath.Join(curveDir, "fr", "fri"), bgen)) + + // generate mpc setup tools + assertNoError(mpcsetup.Generate(conf, filepath.Join(curveDir, "mpcsetup"), bgen)) + + // generate kzg on fr + assertNoError(kzg.Generate(conf, filepath.Join(curveDir, "kzg"), bgen)) + + // generate shplonk on fr + assertNoError(shplonk.Generate(conf, filepath.Join(curveDir, "shplonk"), bgen)) + + // generate fflonk on fr + assertNoError(fflonk.Generate(conf, filepath.Join(curveDir, "fflonk"), bgen)) + + // generate plookup on fr + assertNoError(plookup.Generate(conf, filepath.Join(curveDir, "fr", "plookup"), bgen)) + + // generate permutation on fr + assertNoError(permutation.Generate(conf, filepath.Join(curveDir, "fr", "permutation"), bgen)) + + // generate eddsa on companion curves + assertNoError(fri.Generate(conf, filepath.Join(curveDir, "fr", "fri"), bgen)) 
+ + // generate iop functions + assertNoError(iop.Generate(conf, filepath.Join(curveDir, "fr", "iop"), bgen)) + }(conf) }