Skip to content

Commit 3de80fe

Browse files
authoredJul 17, 2024··
[Code Health] refactor: SMST#Root(), #Sum(), & #Count() (#51)
Signed-off-by: Bryan White <bryanchriswhite@gmail.com>
1 parent 6c22c94 commit 3de80fe

File tree

6 files changed

+186
-112
lines changed

6 files changed

+186
-112
lines changed
 

‎root.go

+65-24
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,11 @@ import (
55
"fmt"
66
)
77

8-
const (
9-
// These are intentionally exposed to allow for for testing and custom
10-
// implementations of downstream applications.
11-
SmtRootSizeBytes = 32
12-
SmstRootSizeBytes = SmtRootSizeBytes + sumSizeBytes + countSizeBytes
13-
)
14-
158
// MustSum returns the uint64 sum of the merkle root, it checks the length of the
169
// merkle root and if it is no the same as the size of the SMST's expected
1710
// root hash it will panic.
18-
func (r MerkleRoot) MustSum() uint64 {
19-
sum, err := r.Sum()
11+
func (root MerkleSumRoot) MustSum() uint64 {
12+
sum, err := root.Sum()
2013
if err != nil {
2114
panic(err)
2215
}
@@ -27,28 +20,76 @@ func (r MerkleRoot) MustSum() uint64 {
2720
// Sum returns the uint64 sum of the merkle root, it checks the length of the
2821
// merkle root and if it is no the same as the size of the SMST's expected
2922
// root hash it will return an error.
30-
func (r MerkleRoot) Sum() (uint64, error) {
31-
if len(r)%SmtRootSizeBytes == 0 {
32-
return 0, fmt.Errorf("root#sum: not a merkle sum trie")
23+
func (root MerkleSumRoot) Sum() (uint64, error) {
24+
if err := root.validateBasic(); err != nil {
25+
return 0, err
3326
}
3427

35-
firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx([]byte(r))
28+
return root.sum(), nil
29+
}
30+
31+
// MustCount returns the uint64 count of the merkle root, a cryptographically secure
32+
// count of the number of non-empty leafs in the tree. It panics if the root length
33+
// is invalid.
34+
func (root MerkleSumRoot) MustCount() uint64 {
35+
count, err := root.Count()
36+
if err != nil {
37+
panic(err)
38+
}
3639

37-
var sumBz [sumSizeBytes]byte
38-
copy(sumBz[:], []byte(r)[firstSumByteIdx:firstCountByteIdx])
39-
return binary.BigEndian.Uint64(sumBz[:]), nil
40+
return count
4041
}
4142

4243
// Count returns the uint64 count of the merkle root, a cryptographically secure
43-
// count of the number of non-empty leafs in the tree.
44-
func (r MerkleRoot) Count() uint64 {
45-
if len(r)%SmtRootSizeBytes == 0 {
46-
panic("root#sum: not a merkle sum trie")
44+
// count of the number of non-empty leafs in the tree. It returns an error if the
45+
// root length is invalid.
46+
func (root MerkleSumRoot) Count() (uint64, error) {
47+
if err := root.validateBasic(); err != nil {
48+
return 0, err
4749
}
4850

49-
_, firstCountByteIdx := getFirstMetaByteIdx([]byte(r))
51+
return root.count(), nil
52+
}
53+
54+
// DigestSize returns the length of the digest portion of the root.
55+
func (root MerkleSumRoot) DigestSize() int {
56+
return len(root) - countSizeBytes - sumSizeBytes
57+
}
58+
59+
// HasDigestSize returns true if the root digest size is the same as
60+
// that of the size of the given hasher.
61+
func (root MerkleSumRoot) HasDigestSize(size int) bool {
62+
return root.DigestSize() == size
63+
}
5064

51-
var countBz [countSizeBytes]byte
52-
copy(countBz[:], []byte(r)[firstCountByteIdx:])
53-
return binary.BigEndian.Uint64(countBz[:])
65+
// validateBasic returns an error if the root digest size is not a power of two.
66+
func (root MerkleSumRoot) validateBasic() error {
67+
if !isPowerOfTwo(root.DigestSize()) {
68+
return fmt.Errorf("MerkleSumRoot#validateBasic: invalid root length")
69+
}
70+
71+
return nil
72+
}
73+
74+
// sum returns the sum of the node stored in the root.
75+
func (root MerkleSumRoot) sum() uint64 {
76+
firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(root)
77+
78+
return binary.BigEndian.Uint64(root[firstSumByteIdx:firstCountByteIdx])
79+
}
80+
81+
// count returns the count of the node stored in the root.
82+
func (root MerkleSumRoot) count() uint64 {
83+
_, firstCountByteIdx := getFirstMetaByteIdx(root)
84+
85+
return binary.BigEndian.Uint64(root[firstCountByteIdx:])
86+
}
87+
88+
// isPowerOfTwo function returns true if the input n is a power of 2
89+
func isPowerOfTwo(n int) bool {
90+
// A power of 2 has only one bit set in its binary representation
91+
if n <= 0 {
92+
return false
93+
}
94+
return (n & (n - 1)) == 0
5495
}

‎root_test.go

+59-41
Original file line numberDiff line numberDiff line change
@@ -13,64 +13,82 @@ import (
1313
"github.com/pokt-network/smt/kvstore/simplemap"
1414
)
1515

16-
func TestMerkleRoot_TrieTypes(t *testing.T) {
16+
func TestMerkleSumRoot_SumAndCountSuccess(t *testing.T) {
1717
tests := []struct {
18-
desc string
19-
sumTree bool
20-
hasher hash.Hash
21-
expectedPanic string
18+
desc string
19+
hasher hash.Hash
2220
}{
2321
{
24-
desc: "successfully: gets sum of sha256 hasher SMST",
25-
sumTree: true,
26-
hasher: sha256.New(),
27-
expectedPanic: "",
22+
desc: "sha256 hasher",
23+
hasher: sha256.New(),
2824
},
2925
{
30-
desc: "successfully: gets sum of sha512 hasher SMST",
31-
sumTree: true,
32-
hasher: sha512.New(),
33-
expectedPanic: "",
26+
desc: "sha512 hasher",
27+
hasher: sha512.New(),
3428
},
29+
}
30+
31+
nodeStore := simplemap.NewSimpleMap()
32+
for _, test := range tests {
33+
t.Run(test.desc, func(t *testing.T) {
34+
t.Cleanup(func() {
35+
require.NoError(t, nodeStore.ClearAll())
36+
})
37+
trie := smt.NewSparseMerkleSumTrie(nodeStore, test.hasher)
38+
for i := uint64(0); i < 10; i++ {
39+
require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i))
40+
}
41+
42+
sum, sumErr := trie.Sum()
43+
require.NoError(t, sumErr)
44+
45+
count, countErr := trie.Count()
46+
require.NoError(t, countErr)
47+
48+
require.EqualValues(t, uint64(45), sum)
49+
require.EqualValues(t, uint64(10), count)
50+
})
51+
}
52+
}
53+
54+
func TestMekleRoot_SumAndCountError(t *testing.T) {
55+
tests := []struct {
56+
desc string
57+
hasher hash.Hash
58+
}{
3559
{
36-
desc: "failure: panics for sha256 hasher SMT",
37-
sumTree: false,
38-
hasher: sha256.New(),
39-
expectedPanic: "roo#sum: not a merkle sum trie",
60+
desc: "sha256 hasher",
61+
hasher: sha256.New(),
4062
},
4163
{
42-
desc: "failure: panics for sha512 hasher SMT",
43-
sumTree: false,
44-
hasher: sha512.New(),
45-
expectedPanic: "roo#sum: not a merkle sum trie",
64+
desc: "sha512 hasher",
65+
hasher: sha512.New(),
4666
},
4767
}
4868

4969
nodeStore := simplemap.NewSimpleMap()
50-
for _, tt := range tests {
51-
tt := tt
52-
t.Run(tt.desc, func(t *testing.T) {
70+
for _, test := range tests {
71+
t.Run(test.desc, func(t *testing.T) {
5372
t.Cleanup(func() {
5473
require.NoError(t, nodeStore.ClearAll())
5574
})
56-
if tt.sumTree {
57-
trie := smt.NewSparseMerkleSumTrie(nodeStore, tt.hasher)
58-
for i := uint64(0); i < 10; i++ {
59-
require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i))
60-
}
61-
require.NotNil(t, trie.Sum())
62-
require.EqualValues(t, 45, trie.Sum())
63-
require.EqualValues(t, 10, trie.Count())
64-
65-
return
66-
}
67-
trie := smt.NewSparseMerkleTrie(nodeStore, tt.hasher)
68-
for i := 0; i < 10; i++ {
69-
require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i))))
70-
}
71-
if panicStr := recover(); panicStr != nil {
72-
require.Equal(t, tt.expectedPanic, panicStr)
75+
trie := smt.NewSparseMerkleSumTrie(nodeStore, test.hasher)
76+
for i := uint64(0); i < 10; i++ {
77+
require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i))
7378
}
79+
80+
root := trie.Root()
81+
82+
// Mangle the root bytes.
83+
root = root[:len(root)-1]
84+
85+
sum, sumErr := root.Sum()
86+
require.Error(t, sumErr)
87+
require.Equal(t, uint64(0), sum)
88+
89+
count, countErr := root.Count()
90+
require.Error(t, countErr)
91+
require.Equal(t, uint64(0), count)
7492
})
7593
}
7694
}

‎smst.go

+30-20
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package smt
33
import (
44
"bytes"
55
"encoding/binary"
6+
"fmt"
67
"hash"
78

89
"github.com/pokt-network/smt/kvstore"
@@ -170,39 +171,48 @@ func (smst *SMST) Commit() error {
170171
}
171172

172173
// Root returns the root hash of the trie with the total sum bytes appended
173-
func (smst *SMST) Root() MerkleRoot {
174-
return smst.SMT.Root() // [digest]+[binary sum]
174+
func (smst *SMST) Root() MerkleSumRoot {
175+
return MerkleSumRoot(smst.SMT.Root()) // [digest]+[binary sum]+[binary count]
175176
}
176177

177-
// Sum returns the sum of the entire trie stored in the root.
178+
// MustSum returns the sum of the entire trie stored in the root.
178179
// If the tree is not a sum tree, it will panic.
179-
func (smst *SMST) Sum() uint64 {
180-
rootDigest := []byte(smst.Root())
180+
func (smst *SMST) MustSum() uint64 {
181+
sum, err := smst.Sum()
182+
if err != nil {
183+
panic(err)
184+
}
185+
return sum
186+
}
181187

188+
// Sum returns the sum of the entire trie stored in the root.
189+
// If the tree is not a sum tree, it will return an error.
190+
func (smst *SMST) Sum() (uint64, error) {
182191
if !smst.Spec().sumTrie {
183-
panic("SMST: not a merkle sum trie")
192+
return 0, fmt.Errorf("SMST: not a merkle sum trie")
184193
}
185194

186-
firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(rootDigest)
195+
return smst.Root().Sum()
196+
}
187197

188-
var sumBz [sumSizeBytes]byte
189-
copy(sumBz[:], rootDigest[firstSumByteIdx:firstCountByteIdx])
190-
return binary.BigEndian.Uint64(sumBz[:])
198+
// MustCount returns the number of non-empty nodes in the entire trie stored in the root.
199+
// If the tree is not a sum tree, it will panic.
200+
func (smst *SMST) MustCount() uint64 {
201+
count, err := smst.Count()
202+
if err != nil {
203+
panic(err)
204+
}
205+
return count
191206
}
192207

193208
// Count returns the number of non-empty nodes in the entire trie stored in the root.
194-
func (smst *SMST) Count() uint64 {
195-
rootDigest := []byte(smst.Root())
196-
209+
// If the tree is not a sum tree, it will return an error.
210+
func (smst *SMST) Count() (uint64, error) {
197211
if !smst.Spec().sumTrie {
198-
panic("SMST: not a merkle sum trie")
212+
return 0, fmt.Errorf("SMST: not a merkle sum trie")
199213
}
200214

201-
_, firstCountByteIdx := getFirstMetaByteIdx(rootDigest)
202-
203-
var countBz [countSizeBytes]byte
204-
copy(countBz[:], rootDigest[firstCountByteIdx:])
205-
return binary.BigEndian.Uint64(countBz[:])
215+
return smst.Root().Count()
206216
}
207217

208218
// getFirstMetaByteIdx returns the index of the first count byte and the first sum byte
@@ -211,5 +221,5 @@ func (smst *SMST) Count() uint64 {
211221
func getFirstMetaByteIdx(data []byte) (firstSumByteIdx, firstCountByteIdx int) {
212222
firstCountByteIdx = len(data) - countSizeBytes
213223
firstSumByteIdx = firstCountByteIdx - sumSizeBytes
214-
return
224+
return firstSumByteIdx, firstCountByteIdx
215225
}

‎smst_example_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func TestExampleSMST(t *testing.T) {
2828
_ = trie.Commit()
2929

3030
// Calculate the total sum of the trie
31-
_ = trie.Sum() // 20
31+
_ = trie.MustSum() // 20
3232

3333
// Generate a Merkle proof for "foo"
3434
proof1, _ := trie.Prove([]byte("foo"))
@@ -52,8 +52,8 @@ func TestExampleSMST(t *testing.T) {
5252
require.False(t, valid_false1)
5353

5454
// Verify the total sum of the trie
55-
require.EqualValues(t, 20, trie.Sum())
55+
require.EqualValues(t, 20, trie.MustSum())
5656

5757
// Verify the number of non-empty leafs in the trie
58-
require.EqualValues(t, 3, trie.Count())
58+
require.EqualValues(t, 3, trie.MustCount())
5959
}

‎smst_test.go

+20-20
Original file line numberDiff line numberDiff line change
@@ -359,36 +359,36 @@ func TestSMST_OrphanRemoval(t *testing.T) {
359359
err = smst.Update([]byte("testKey"), []byte("testValue"), 5)
360360
require.NoError(t, err)
361361
require.Equal(t, 1, nodeCount(t)) // only root node
362-
require.Equal(t, uint64(1), impl.Count())
362+
require.Equal(t, uint64(1), impl.MustCount())
363363
}
364364

365365
t.Run("delete 1", func(t *testing.T) {
366366
setup()
367367
err = smst.Delete([]byte("testKey"))
368368
require.NoError(t, err)
369369
require.Equal(t, 0, nodeCount(t))
370-
require.Equal(t, uint64(0), impl.Count())
370+
require.Equal(t, uint64(0), impl.MustCount())
371371
})
372372

373373
t.Run("overwrite 1", func(t *testing.T) {
374374
setup()
375375
err = smst.Update([]byte("testKey"), []byte("testValue2"), 10)
376376
require.NoError(t, err)
377377
require.Equal(t, 1, nodeCount(t))
378-
require.Equal(t, uint64(1), impl.Count())
378+
require.Equal(t, uint64(1), impl.MustCount())
379379
})
380380

381381
t.Run("overwrite and delete", func(t *testing.T) {
382382
setup()
383383
err = smst.Update([]byte("testKey"), []byte("testValue2"), 2)
384384
require.NoError(t, err)
385385
require.Equal(t, 1, nodeCount(t))
386-
require.Equal(t, uint64(1), impl.Count())
386+
require.Equal(t, uint64(1), impl.MustCount())
387387

388388
err = smst.Delete([]byte("testKey"))
389389
require.NoError(t, err)
390390
require.Equal(t, 0, nodeCount(t))
391-
require.Equal(t, uint64(0), impl.Count())
391+
require.Equal(t, uint64(0), impl.MustCount())
392392
})
393393

394394
type testCase struct {
@@ -436,29 +436,29 @@ func TestSMST_OrphanRemoval(t *testing.T) {
436436
require.NoError(t, err, tci)
437437
}
438438
require.Equal(t, tc.expectedNodeCount, nodeCount(t), tci)
439-
require.Equal(t, uint64(tc.expectedLeafCount), impl.Count())
439+
require.Equal(t, uint64(tc.expectedLeafCount), impl.MustCount())
440440

441441
// Overwrite doesn't change node or leaf count
442442
for _, key := range tc.keys {
443443
err = smst.Update([]byte(key), []byte("testValue3"), 10)
444444
require.NoError(t, err, tci)
445445
}
446446
require.Equal(t, tc.expectedNodeCount, nodeCount(t), tci)
447-
require.Equal(t, uint64(tc.expectedLeafCount), impl.Count())
447+
require.Equal(t, uint64(tc.expectedLeafCount), impl.MustCount())
448448

449449
// Deletion removes all nodes except root
450450
for _, key := range tc.keys {
451451
err = smst.Delete([]byte(key))
452452
require.NoError(t, err, tci)
453453
}
454454
require.Equal(t, 1, nodeCount(t), tci)
455-
require.Equal(t, uint64(1), impl.Count())
455+
require.Equal(t, uint64(1), impl.MustCount())
456456

457457
// Deleting and re-inserting a persisted node doesn't change count
458458
require.NoError(t, smst.Delete([]byte("testKey")))
459459
require.NoError(t, smst.Update([]byte("testKey"), []byte("testValue"), 10))
460460
require.Equal(t, 1, nodeCount(t), tci)
461-
require.Equal(t, uint64(1), impl.Count())
461+
require.Equal(t, uint64(1), impl.MustCount())
462462
})
463463
}
464464
}
@@ -486,12 +486,12 @@ func TestSMST_TotalSum(t *testing.T) {
486486
rootCount := binary.BigEndian.Uint64(countBz)
487487

488488
// Retrieve and compare the sum
489-
sum := smst.Sum()
489+
sum := smst.MustSum()
490490
require.Equal(t, sum, uint64(15))
491491
require.Equal(t, sum, rootSum)
492492

493493
// Retrieve and compare the count
494-
count := smst.Count()
494+
count := smst.MustCount()
495495
require.Equal(t, count, uint64(3))
496496
require.Equal(t, count, rootCount)
497497

@@ -506,22 +506,22 @@ func TestSMST_TotalSum(t *testing.T) {
506506
// Check that the sum is correct after deleting a key
507507
err = smst.Delete([]byte("key1"))
508508
require.NoError(t, err)
509-
sum = smst.Sum()
509+
sum = smst.MustSum()
510510
require.Equal(t, sum, uint64(10))
511511

512512
// Check that the count is correct after deleting a key
513-
count = smst.Count()
513+
count = smst.MustCount()
514514
require.Equal(t, count, uint64(2))
515515

516516
// Check that the sum is correct after importing the trie
517517
require.NoError(t, smst.Commit())
518518
root2 := smst.Root()
519519
smst = ImportSparseMerkleSumTrie(snm, sha256.New(), root2)
520-
sum = smst.Sum()
520+
sum = smst.MustSum()
521521
require.Equal(t, sum, uint64(10))
522522

523523
// Check that the count is correct after importing the trie
524-
count = smst.Count()
524+
count = smst.MustCount()
525525
require.Equal(t, count, uint64(2))
526526

527527
// Calculate the total sum of a larger trie
@@ -532,11 +532,11 @@ func TestSMST_TotalSum(t *testing.T) {
532532
require.NoError(t, err)
533533
}
534534
require.NoError(t, smst.Commit())
535-
sum = smst.Sum()
535+
sum = smst.MustSum()
536536
require.Equal(t, sum, uint64(49995000))
537537

538538
// Check that the count is correct after building a larger trie
539-
count = smst.Count()
539+
count = smst.MustCount()
540540
require.Equal(t, count, uint64(9999))
541541
}
542542

@@ -584,7 +584,7 @@ func TestSMST_Retrieval(t *testing.T) {
584584
require.Equal(t, uint64(5), sum)
585585

586586
root := smst.Root()
587-
sum = smst.Sum()
587+
sum = smst.MustSum()
588588
require.Equal(t, sum, uint64(15))
589589

590590
lazy := ImportSparseMerkleSumTrie(snm, sha256.New(), root, WithValueHasher(nil))
@@ -604,9 +604,9 @@ func TestSMST_Retrieval(t *testing.T) {
604604
require.Equal(t, []byte("value3"), value)
605605
require.Equal(t, uint64(5), sum)
606606

607-
sum = lazy.Sum()
607+
sum = lazy.MustSum()
608608
require.Equal(t, sum, uint64(15))
609609

610-
count := lazy.Count()
610+
count := lazy.MustCount()
611611
require.Equal(t, count, uint64(3))
612612
}

‎types.go

+9-4
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,12 @@ var (
2222
defaultEmptyCount [countSizeBytes]byte
2323
)
2424

25-
// MerkleRoot is a type alias for a byte slice returned from the Root method
25+
// MerkleRoot is a type alias for a byte slice returned from SparseMerkleTrie#Root().
2626
type MerkleRoot []byte
2727

28+
// MerkleSumRoot is a type alias for a byte slice returned from SparseMerkleSumTrie#Root().
29+
type MerkleSumRoot []byte
30+
2831
// A high-level interface that captures the behaviour of all types of nodes
2932
type trieNode interface {
3033
// Persisted returns a boolean to determine whether or not the node
@@ -68,11 +71,13 @@ type SparseMerkleSumTrie interface {
6871
// Get descends the trie to access a value. Returns nil if key is not present.
6972
Get(key []byte) (data []byte, sum uint64, err error)
7073
// Root computes the Merkle root digest.
71-
Root() MerkleRoot
74+
Root() MerkleSumRoot
7275
// Sum computes the total sum of the Merkle trie
73-
Sum() uint64
76+
Sum() (uint64, error)
77+
MustSum() uint64
7478
// Count returns the total number of non-empty leaves in the trie
75-
Count() uint64
79+
Count() (uint64, error)
80+
MustCount() uint64
7681
// Prove computes a Merkle proof of inclusion or exclusion of a key.
7782
Prove(key []byte) (*SparseMerkleProof, error)
7883
// ProveClosest computes a Merkle proof of inclusion for a key in the trie

0 commit comments

Comments
 (0)
Please sign in to comment.