Skip to content

Commit 3de80fe

Browse files
[Code Health] refactor: SMST#Root(), #Sum(), & #Count() (#51)
Signed-off-by: Bryan White <[email protected]>
1 parent 6c22c94 commit 3de80fe

File tree

6 files changed

+186
-112
lines changed

6 files changed

+186
-112
lines changed

root.go

Lines changed: 65 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,11 @@ import (
55
"fmt"
66
)
77

8-
const (
9-
// These are intentionally exposed to allow for for testing and custom
10-
// implementations of downstream applications.
11-
SmtRootSizeBytes = 32
12-
SmstRootSizeBytes = SmtRootSizeBytes + sumSizeBytes + countSizeBytes
13-
)
14-
158
// MustSum returns the uint64 sum of the merkle root, it checks the length of the
169
// merkle root and if it is no the same as the size of the SMST's expected
1710
// root hash it will panic.
18-
func (r MerkleRoot) MustSum() uint64 {
19-
sum, err := r.Sum()
11+
func (root MerkleSumRoot) MustSum() uint64 {
12+
sum, err := root.Sum()
2013
if err != nil {
2114
panic(err)
2215
}
@@ -27,28 +20,76 @@ func (r MerkleRoot) MustSum() uint64 {
2720
// Sum returns the uint64 sum of the merkle root, it checks the length of the
2821
// merkle root and if it is no the same as the size of the SMST's expected
2922
// root hash it will return an error.
30-
func (r MerkleRoot) Sum() (uint64, error) {
31-
if len(r)%SmtRootSizeBytes == 0 {
32-
return 0, fmt.Errorf("root#sum: not a merkle sum trie")
23+
func (root MerkleSumRoot) Sum() (uint64, error) {
24+
if err := root.validateBasic(); err != nil {
25+
return 0, err
3326
}
3427

35-
firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx([]byte(r))
28+
return root.sum(), nil
29+
}
30+
31+
// MustCount returns the uint64 count of the merkle root, a cryptographically secure
32+
// count of the number of non-empty leafs in the tree. It panics if the root length
33+
// is invalid.
34+
func (root MerkleSumRoot) MustCount() uint64 {
35+
count, err := root.Count()
36+
if err != nil {
37+
panic(err)
38+
}
3639

37-
var sumBz [sumSizeBytes]byte
38-
copy(sumBz[:], []byte(r)[firstSumByteIdx:firstCountByteIdx])
39-
return binary.BigEndian.Uint64(sumBz[:]), nil
40+
return count
4041
}
4142

4243
// Count returns the uint64 count of the merkle root, a cryptographically secure
43-
// count of the number of non-empty leafs in the tree.
44-
func (r MerkleRoot) Count() uint64 {
45-
if len(r)%SmtRootSizeBytes == 0 {
46-
panic("root#sum: not a merkle sum trie")
44+
// count of the number of non-empty leafs in the tree. It returns an error if the
45+
// root length is invalid.
46+
func (root MerkleSumRoot) Count() (uint64, error) {
47+
if err := root.validateBasic(); err != nil {
48+
return 0, err
4749
}
4850

49-
_, firstCountByteIdx := getFirstMetaByteIdx([]byte(r))
51+
return root.count(), nil
52+
}
53+
54+
// DigestSize returns the length of the digest portion of the root.
55+
func (root MerkleSumRoot) DigestSize() int {
56+
return len(root) - countSizeBytes - sumSizeBytes
57+
}
58+
59+
// HasDigestSize returns true if the root digest size is the same as
60+
// that of the size of the given hasher.
61+
func (root MerkleSumRoot) HasDigestSize(size int) bool {
62+
return root.DigestSize() == size
63+
}
5064

51-
var countBz [countSizeBytes]byte
52-
copy(countBz[:], []byte(r)[firstCountByteIdx:])
53-
return binary.BigEndian.Uint64(countBz[:])
65+
// validateBasic returns an error if the root digest size is not a power of two.
66+
func (root MerkleSumRoot) validateBasic() error {
67+
if !isPowerOfTwo(root.DigestSize()) {
68+
return fmt.Errorf("MerkleSumRoot#validateBasic: invalid root length")
69+
}
70+
71+
return nil
72+
}
73+
74+
// sum returns the sum of the node stored in the root.
75+
func (root MerkleSumRoot) sum() uint64 {
76+
firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(root)
77+
78+
return binary.BigEndian.Uint64(root[firstSumByteIdx:firstCountByteIdx])
79+
}
80+
81+
// count returns the count of the node stored in the root.
82+
func (root MerkleSumRoot) count() uint64 {
83+
_, firstCountByteIdx := getFirstMetaByteIdx(root)
84+
85+
return binary.BigEndian.Uint64(root[firstCountByteIdx:])
86+
}
87+
88+
// isPowerOfTwo function returns true if the input n is a power of 2
89+
func isPowerOfTwo(n int) bool {
90+
// A power of 2 has only one bit set in its binary representation
91+
if n <= 0 {
92+
return false
93+
}
94+
return (n & (n - 1)) == 0
5495
}

root_test.go

Lines changed: 59 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,64 +13,82 @@ import (
1313
"github.com/pokt-network/smt/kvstore/simplemap"
1414
)
1515

16-
func TestMerkleRoot_TrieTypes(t *testing.T) {
16+
func TestMerkleSumRoot_SumAndCountSuccess(t *testing.T) {
1717
tests := []struct {
18-
desc string
19-
sumTree bool
20-
hasher hash.Hash
21-
expectedPanic string
18+
desc string
19+
hasher hash.Hash
2220
}{
2321
{
24-
desc: "successfully: gets sum of sha256 hasher SMST",
25-
sumTree: true,
26-
hasher: sha256.New(),
27-
expectedPanic: "",
22+
desc: "sha256 hasher",
23+
hasher: sha256.New(),
2824
},
2925
{
30-
desc: "successfully: gets sum of sha512 hasher SMST",
31-
sumTree: true,
32-
hasher: sha512.New(),
33-
expectedPanic: "",
26+
desc: "sha512 hasher",
27+
hasher: sha512.New(),
3428
},
29+
}
30+
31+
nodeStore := simplemap.NewSimpleMap()
32+
for _, test := range tests {
33+
t.Run(test.desc, func(t *testing.T) {
34+
t.Cleanup(func() {
35+
require.NoError(t, nodeStore.ClearAll())
36+
})
37+
trie := smt.NewSparseMerkleSumTrie(nodeStore, test.hasher)
38+
for i := uint64(0); i < 10; i++ {
39+
require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i))
40+
}
41+
42+
sum, sumErr := trie.Sum()
43+
require.NoError(t, sumErr)
44+
45+
count, countErr := trie.Count()
46+
require.NoError(t, countErr)
47+
48+
require.EqualValues(t, uint64(45), sum)
49+
require.EqualValues(t, uint64(10), count)
50+
})
51+
}
52+
}
53+
54+
func TestMekleRoot_SumAndCountError(t *testing.T) {
55+
tests := []struct {
56+
desc string
57+
hasher hash.Hash
58+
}{
3559
{
36-
desc: "failure: panics for sha256 hasher SMT",
37-
sumTree: false,
38-
hasher: sha256.New(),
39-
expectedPanic: "roo#sum: not a merkle sum trie",
60+
desc: "sha256 hasher",
61+
hasher: sha256.New(),
4062
},
4163
{
42-
desc: "failure: panics for sha512 hasher SMT",
43-
sumTree: false,
44-
hasher: sha512.New(),
45-
expectedPanic: "roo#sum: not a merkle sum trie",
64+
desc: "sha512 hasher",
65+
hasher: sha512.New(),
4666
},
4767
}
4868

4969
nodeStore := simplemap.NewSimpleMap()
50-
for _, tt := range tests {
51-
tt := tt
52-
t.Run(tt.desc, func(t *testing.T) {
70+
for _, test := range tests {
71+
t.Run(test.desc, func(t *testing.T) {
5372
t.Cleanup(func() {
5473
require.NoError(t, nodeStore.ClearAll())
5574
})
56-
if tt.sumTree {
57-
trie := smt.NewSparseMerkleSumTrie(nodeStore, tt.hasher)
58-
for i := uint64(0); i < 10; i++ {
59-
require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i))
60-
}
61-
require.NotNil(t, trie.Sum())
62-
require.EqualValues(t, 45, trie.Sum())
63-
require.EqualValues(t, 10, trie.Count())
64-
65-
return
66-
}
67-
trie := smt.NewSparseMerkleTrie(nodeStore, tt.hasher)
68-
for i := 0; i < 10; i++ {
69-
require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i))))
70-
}
71-
if panicStr := recover(); panicStr != nil {
72-
require.Equal(t, tt.expectedPanic, panicStr)
75+
trie := smt.NewSparseMerkleSumTrie(nodeStore, test.hasher)
76+
for i := uint64(0); i < 10; i++ {
77+
require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i))
7378
}
79+
80+
root := trie.Root()
81+
82+
// Mangle the root bytes.
83+
root = root[:len(root)-1]
84+
85+
sum, sumErr := root.Sum()
86+
require.Error(t, sumErr)
87+
require.Equal(t, uint64(0), sum)
88+
89+
count, countErr := root.Count()
90+
require.Error(t, countErr)
91+
require.Equal(t, uint64(0), count)
7492
})
7593
}
7694
}

smst.go

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package smt
33
import (
44
"bytes"
55
"encoding/binary"
6+
"fmt"
67
"hash"
78

89
"github.com/pokt-network/smt/kvstore"
@@ -170,39 +171,48 @@ func (smst *SMST) Commit() error {
170171
}
171172

172173
// Root returns the root hash of the trie with the total sum bytes appended
173-
func (smst *SMST) Root() MerkleRoot {
174-
return smst.SMT.Root() // [digest]+[binary sum]
174+
func (smst *SMST) Root() MerkleSumRoot {
175+
return MerkleSumRoot(smst.SMT.Root()) // [digest]+[binary sum]+[binary count]
175176
}
176177

177-
// Sum returns the sum of the entire trie stored in the root.
178+
// MustSum returns the sum of the entire trie stored in the root.
178179
// If the tree is not a sum tree, it will panic.
179-
func (smst *SMST) Sum() uint64 {
180-
rootDigest := []byte(smst.Root())
180+
func (smst *SMST) MustSum() uint64 {
181+
sum, err := smst.Sum()
182+
if err != nil {
183+
panic(err)
184+
}
185+
return sum
186+
}
181187

188+
// Sum returns the sum of the entire trie stored in the root.
189+
// If the tree is not a sum tree, it will return an error.
190+
func (smst *SMST) Sum() (uint64, error) {
182191
if !smst.Spec().sumTrie {
183-
panic("SMST: not a merkle sum trie")
192+
return 0, fmt.Errorf("SMST: not a merkle sum trie")
184193
}
185194

186-
firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(rootDigest)
195+
return smst.Root().Sum()
196+
}
187197

188-
var sumBz [sumSizeBytes]byte
189-
copy(sumBz[:], rootDigest[firstSumByteIdx:firstCountByteIdx])
190-
return binary.BigEndian.Uint64(sumBz[:])
198+
// MustCount returns the number of non-empty nodes in the entire trie stored in the root.
199+
// If the tree is not a sum tree, it will panic.
200+
func (smst *SMST) MustCount() uint64 {
201+
count, err := smst.Count()
202+
if err != nil {
203+
panic(err)
204+
}
205+
return count
191206
}
192207

193208
// Count returns the number of non-empty nodes in the entire trie stored in the root.
194-
func (smst *SMST) Count() uint64 {
195-
rootDigest := []byte(smst.Root())
196-
209+
// If the tree is not a sum tree, it will return an error.
210+
func (smst *SMST) Count() (uint64, error) {
197211
if !smst.Spec().sumTrie {
198-
panic("SMST: not a merkle sum trie")
212+
return 0, fmt.Errorf("SMST: not a merkle sum trie")
199213
}
200214

201-
_, firstCountByteIdx := getFirstMetaByteIdx(rootDigest)
202-
203-
var countBz [countSizeBytes]byte
204-
copy(countBz[:], rootDigest[firstCountByteIdx:])
205-
return binary.BigEndian.Uint64(countBz[:])
215+
return smst.Root().Count()
206216
}
207217

208218
// getFirstMetaByteIdx returns the index of the first count byte and the first sum byte
@@ -211,5 +221,5 @@ func (smst *SMST) Count() uint64 {
211221
func getFirstMetaByteIdx(data []byte) (firstSumByteIdx, firstCountByteIdx int) {
212222
firstCountByteIdx = len(data) - countSizeBytes
213223
firstSumByteIdx = firstCountByteIdx - sumSizeBytes
214-
return
224+
return firstSumByteIdx, firstCountByteIdx
215225
}

smst_example_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func TestExampleSMST(t *testing.T) {
2828
_ = trie.Commit()
2929

3030
// Calculate the total sum of the trie
31-
_ = trie.Sum() // 20
31+
_ = trie.MustSum() // 20
3232

3333
// Generate a Merkle proof for "foo"
3434
proof1, _ := trie.Prove([]byte("foo"))
@@ -52,8 +52,8 @@ func TestExampleSMST(t *testing.T) {
5252
require.False(t, valid_false1)
5353

5454
// Verify the total sum of the trie
55-
require.EqualValues(t, 20, trie.Sum())
55+
require.EqualValues(t, 20, trie.MustSum())
5656

5757
// Verify the number of non-empty leafs in the trie
58-
require.EqualValues(t, 3, trie.Count())
58+
require.EqualValues(t, 3, trie.MustCount())
5959
}

0 commit comments

Comments
 (0)