Skip to content

Commit 56b51eb

Browse files
committed
arbo: add CheckProofBatch and CalculateProofNodes
1 parent 98fff29 commit 56b51eb

File tree

3 files changed

+275
-10
lines changed

3 files changed

+275
-10
lines changed

Diff for: tree/arbo/circomproofs.go

+33
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package arbo
22

33
import (
4+
"bytes"
45
"encoding/json"
6+
"fmt"
7+
"slices"
58
)
69

710
// CircomVerifierProof contains the needed data to check a Circom Verifier Proof
@@ -89,3 +92,33 @@ func (t *Tree) GenerateCircomVerifierProof(k []byte) (*CircomVerifierProof, erro
8992

9093
return &cp, nil
9194
}
95+
96+
// CalculateProofNodes calculates the chain of hashes in the path of the proof.
97+
// In the returned list, first item is the root, and last item is the hash of the leaf.
98+
func (cvp CircomVerifierProof) CalculateProofNodes(hashFunc HashFunction) ([][]byte, error) {
99+
paddedSiblings := slices.Clone(cvp.Siblings)
100+
for k, v := range paddedSiblings {
101+
if bytes.Equal(v, []byte{0}) {
102+
paddedSiblings[k] = make([]byte, hashFunc.Len())
103+
}
104+
}
105+
packedSiblings, err := PackSiblings(hashFunc, paddedSiblings)
106+
if err != nil {
107+
return nil, err
108+
}
109+
return CalculateProofNodes(hashFunc, cvp.Key, cvp.Value, packedSiblings, cvp.OldKey, (cvp.Fnc == 1))
110+
}
111+
112+
// CheckProof verifies the given proof. The proof verification depends on the
113+
// HashFunction passed as parameter.
114+
// Returns nil if the proof is valid, or an error otherwise.
115+
func (cvp CircomVerifierProof) CheckProof(hashFunc HashFunction) error {
116+
hashes, err := cvp.CalculateProofNodes(hashFunc)
117+
if err != nil {
118+
return err
119+
}
120+
if !bytes.Equal(hashes[0], cvp.Root) {
121+
return fmt.Errorf("calculated vs expected root mismatch")
122+
}
123+
return nil
124+
}

Diff for: tree/arbo/proof.go

+92-10
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package arbo
33
import (
44
"bytes"
55
"encoding/binary"
6+
"encoding/hex"
67
"fmt"
78
"math"
89
"slices"
@@ -161,32 +162,113 @@ func bytesToBitmap(b []byte) []bool {
161162
// HashFunction passed as parameter.
162163
// Returns nil if the proof is valid, or an error otherwise.
163164
func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) error {
164-
siblings, err := UnpackSiblings(hashFunc, packedSiblings)
165+
hashes, err := CalculateProofNodes(hashFunc, k, v, packedSiblings, nil, false)
165166
if err != nil {
166167
return err
167168
}
169+
if !bytes.Equal(hashes[0], root) {
170+
return fmt.Errorf("calculated vs expected root mismatch")
171+
}
172+
return nil
173+
}
174+
175+
// CalculateProofNodes calculates the chain of hashes in the path of the given proof.
176+
// In the returned list, first item is the root, and last item is the hash of the leaf.
177+
func CalculateProofNodes(hashFunc HashFunction, k, v, packedSiblings, oldKey []byte, exclusion bool) ([][]byte, error) {
178+
siblings, err := UnpackSiblings(hashFunc, packedSiblings)
179+
if err != nil {
180+
return nil, err
181+
}
168182

169183
keyPath := make([]byte, int(math.Ceil(float64(len(siblings))/float64(8))))
170184
copy(keyPath, k)
185+
path := getPath(len(siblings), keyPath)
171186

172-
key, _, err := newLeafValue(hashFunc, k, v)
173-
if err != nil {
174-
return err
187+
key := slices.Clone(k)
188+
189+
if exclusion {
190+
if slices.Equal(k, oldKey) {
191+
return nil, fmt.Errorf("exclusion proof invalid, key and oldKey are equal")
192+
}
193+
// we'll prove the path to the existing key (passed as oldKey)
194+
key = slices.Clone(oldKey)
175195
}
176196

177-
path := getPath(len(siblings), keyPath)
197+
hash, _, err := newLeafValue(hashFunc, key, v)
198+
if err != nil {
199+
return nil, err
200+
}
201+
hashes := [][]byte{hash}
178202
for i, sibling := range slices.Backward(siblings) {
179203
if path[i] {
180-
key, _, err = newIntermediate(hashFunc, sibling, key)
204+
hash, _, err = newIntermediate(hashFunc, sibling, hash)
181205
} else {
182-
key, _, err = newIntermediate(hashFunc, key, sibling)
206+
hash, _, err = newIntermediate(hashFunc, hash, sibling)
183207
}
184208
if err != nil {
185-
return err
209+
return nil, err
186210
}
211+
hashes = append(hashes, hash)
187212
}
188-
if !bytes.Equal(key, root) {
189-
return fmt.Errorf("calculated vs expected root mismatch")
213+
slices.Reverse(hashes)
214+
return hashes, nil
215+
}
216+
217+
// CheckProofBatch verifies a batch of N proofs pairs (old and new). The proof verification depends on the
218+
// HashFunction passed as parameter.
219+
// Returns nil if the batch is valid, or an error otherwise.
220+
//
221+
// TODO: doesn't support removing leaves (newProofs can only update or add new leaves)
222+
func CheckProofBatch(hashFunc HashFunction, oldProofs, newProofs []*CircomVerifierProof) error {
223+
newBranches := make(map[string]int)
224+
newSiblings := make(map[string]int)
225+
226+
if len(oldProofs) != len(newProofs) {
227+
return fmt.Errorf("batch of proofs incomplete")
228+
}
229+
230+
if len(oldProofs) == 0 {
231+
return fmt.Errorf("empty batch")
232+
}
233+
234+
for i := range oldProofs {
235+
// Map all old branches
236+
oldNodes, err := oldProofs[i].CalculateProofNodes(hashFunc)
237+
if err != nil {
238+
return fmt.Errorf("old proof invalid: %w", err)
239+
}
240+
// and check they are valid
241+
if !bytes.Equal(oldProofs[i].Root, oldNodes[0]) {
242+
return fmt.Errorf("old proof invalid: root doesn't match")
243+
}
244+
245+
// Map all new branches
246+
newNodes, err := newProofs[i].CalculateProofNodes(hashFunc)
247+
if err != nil {
248+
return fmt.Errorf("new proof invalid: %w", err)
249+
}
250+
// and check they are valid
251+
if !bytes.Equal(newProofs[i].Root, newNodes[0]) {
252+
return fmt.Errorf("new proof invalid: root doesn't match")
253+
}
254+
255+
for level, hash := range newNodes {
256+
newBranches[hex.EncodeToString(hash)] = level
257+
}
258+
259+
for level := range newProofs[i].Siblings {
260+
if !slices.Equal(oldProofs[i].Siblings[level], newProofs[i].Siblings[level]) {
261+
// since in newBranch the root is level 0, we shift siblings to level + 1
262+
newSiblings[hex.EncodeToString(newProofs[i].Siblings[level])] = level + 1
263+
}
264+
}
190265
}
266+
267+
for hash, level := range newSiblings {
268+
if newBranches[hash] != newSiblings[hash] {
269+
return fmt.Errorf("sibling %s (at level %d) changed but there's no proof why", hash, level)
270+
}
271+
}
272+
191273
return nil
192274
}

Diff for: tree/arbo/proof_test.go

+150
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
package arbo
2+
3+
import (
4+
"math/big"
5+
"slices"
6+
"testing"
7+
8+
qt "github.com/frankban/quicktest"
9+
"go.vocdoni.io/dvote/db/metadb"
10+
)
11+
12+
func TestCheckProofBatch(t *testing.T) {
13+
database := metadb.NewTest(t)
14+
c := qt.New(t)
15+
16+
keyLen := 1
17+
maxLevels := keyLen * 8
18+
tree, err := NewTree(Config{
19+
Database: database, MaxLevels: maxLevels,
20+
HashFunction: HashFunctionBlake3,
21+
})
22+
c.Assert(err, qt.IsNil)
23+
24+
censusRoot := []byte("01234567890123456789012345678901")
25+
ballotMode := []byte("1234")
26+
27+
err = tree.Add(BigIntToBytesLE(keyLen, big.NewInt(0x01)), censusRoot)
28+
c.Assert(err, qt.IsNil)
29+
30+
err = tree.Add(BigIntToBytesLE(keyLen, big.NewInt(0x02)), ballotMode)
31+
c.Assert(err, qt.IsNil)
32+
33+
var oldProofs, newProofs []*CircomVerifierProof
34+
35+
for i := int64(0x00); i <= int64(0x04); i++ {
36+
proof, err := tree.GenerateCircomVerifierProof(BigIntToBytesLE(keyLen, big.NewInt(i)))
37+
c.Assert(err, qt.IsNil)
38+
oldProofs = append(oldProofs, proof)
39+
}
40+
41+
censusRoot[0] = byte(0x02)
42+
ballotMode[0] = byte(0x02)
43+
44+
err = tree.Update(BigIntToBytesLE(keyLen, big.NewInt(0x01)), censusRoot)
45+
c.Assert(err, qt.IsNil)
46+
47+
err = tree.Update(BigIntToBytesLE(keyLen, big.NewInt(0x02)), ballotMode)
48+
c.Assert(err, qt.IsNil)
49+
50+
err = tree.Add(BigIntToBytesLE(keyLen, big.NewInt(0x03)), ballotMode)
51+
c.Assert(err, qt.IsNil)
52+
53+
for i := int64(0x00); i <= int64(0x04); i++ {
54+
proof, err := tree.GenerateCircomVerifierProof(BigIntToBytesLE(keyLen, big.NewInt(i)))
55+
c.Assert(err, qt.IsNil)
56+
newProofs = append(newProofs, proof)
57+
}
58+
59+
// passing all proofs should be OK:
60+
// proof 1 + 2 + 3 are required
61+
// proof 0 and 4 are of unchanged keys, but the new siblings are explained by the other proofs
62+
err = CheckProofBatch(HashFunctionBlake3, oldProofs, newProofs)
63+
c.Assert(err, qt.IsNil)
64+
65+
// omitting proof 0 and 4 (unchanged keys) should also be OK
66+
err = CheckProofBatch(HashFunctionBlake3, oldProofs[1:4], newProofs[1:4])
67+
c.Assert(err, qt.IsNil)
68+
69+
// providing an empty batch should not pass
70+
err = CheckProofBatch(HashFunctionBlake3, []*CircomVerifierProof{}, []*CircomVerifierProof{})
71+
c.Assert(err, qt.ErrorMatches, "empty batch")
72+
73+
// length mismatch
74+
err = CheckProofBatch(HashFunctionBlake3, oldProofs, newProofs[:1])
75+
c.Assert(err, qt.ErrorMatches, "batch of proofs incomplete")
76+
77+
// providing just proof 0 (unchanged key) should not pass since siblings can't be explained
78+
err = CheckProofBatch(HashFunctionBlake3, oldProofs[:1], newProofs[:1])
79+
c.Assert(err, qt.ErrorMatches, ".*changed but there's no proof why.*")
80+
81+
// providing just proof 0 (unchanged key) and an add, should fail
82+
err = CheckProofBatch(HashFunctionBlake3, oldProofs[:1], newProofs[3:4])
83+
c.Assert(err, qt.ErrorMatches, ".*changed but there's no proof why.*")
84+
85+
// omitting proof 3 should fail (since changed siblings in other proofs can't be explained)
86+
err = CheckProofBatch(HashFunctionBlake3, oldProofs[:3], newProofs[:3])
87+
c.Assert(err, qt.ErrorMatches, ".*changed but there's no proof why.*")
88+
89+
// the next 4 are mangling proofs to simulate other unexplained changes in the tree, all of these should fail
90+
badProofs := deepClone(oldProofs)
91+
badProofs[0].Root = []byte("01234567890123456789012345678900")
92+
err = CheckProofBatch(HashFunctionBlake3, badProofs, newProofs)
93+
c.Assert(err, qt.ErrorMatches, "old proof invalid: root doesn't match")
94+
95+
badProofs = deepClone(oldProofs)
96+
badProofs[0].Siblings[0] = []byte("01234567890123456789012345678900")
97+
err = CheckProofBatch(HashFunctionBlake3, badProofs, newProofs)
98+
c.Assert(err, qt.ErrorMatches, "old proof invalid: root doesn't match")
99+
100+
badProofs = deepClone(newProofs)
101+
badProofs[0].Root = []byte("01234567890123456789012345678900")
102+
err = CheckProofBatch(HashFunctionBlake3, oldProofs, badProofs)
103+
c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match")
104+
105+
badProofs = deepClone(newProofs)
106+
badProofs[0].Siblings[0] = []byte("01234567890123456789012345678900")
107+
err = CheckProofBatch(HashFunctionBlake3, oldProofs, badProofs)
108+
c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match")
109+
110+
// also test exclusion proofs:
111+
// exclusion proof of key 0x04 can't be used to prove exclusion of 0x01, 0x03 or 0x05 obviously
112+
badProofs = deepClone(oldProofs)
113+
badProofs[4].Key = []byte{0x01}
114+
err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:])
115+
c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match")
116+
badProofs[4].Key = []byte{0x03}
117+
err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:])
118+
c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match")
119+
badProofs[4].Key = []byte{0x05}
120+
err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:])
121+
c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match")
122+
// also can't prove key 0x02 exclusion (since that leaf exists and is indeed the starting point of the proof)
123+
badProofs[4].Key = []byte{0x02}
124+
err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:])
125+
c.Assert(err, qt.ErrorMatches, "new proof invalid: exclusion proof invalid, key and oldKey are equal")
126+
// but exclusion proof of key 0x04 can also prove exclusion of the whole prefix (0x00, 0x08, 0x0c, 0x10, etc)
127+
badProofs[4].Key = []byte{0x00}
128+
err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:])
129+
c.Assert(err, qt.IsNil)
130+
badProofs[4].Key = []byte{0x08}
131+
err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:])
132+
c.Assert(err, qt.IsNil)
133+
badProofs[4].Key = []byte{0x0c}
134+
err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:])
135+
c.Assert(err, qt.IsNil)
136+
badProofs[4].Key = []byte{0x10}
137+
err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:])
138+
c.Assert(err, qt.IsNil)
139+
}
140+
141+
func deepClone(src []*CircomVerifierProof) []*CircomVerifierProof {
142+
dst := slices.Clone(src)
143+
for i := range src {
144+
proof := *src[i]
145+
dst[i] = &proof
146+
147+
dst[i].Siblings = slices.Clone(src[i].Siblings)
148+
}
149+
return dst
150+
}

0 commit comments

Comments
 (0)