Skip to content
This repository was archived by the owner on Apr 18, 2025. It is now read-only.

Commit b82ea41

Browse files
authored
Feat/privacy-scaling-explorations#1752 fix witness generation - adding testing (privacy-scaling-explorations#1784)
### Description closed privacy-scaling-explorations#1752 - Adding different numbers of txs to test different trie status (e.g. one leaf only, one ext. , one branch and one leaf ...etc) - Fixing `GetProof` (if input is an ext node, it always assign first child of the root ext. node, `st.children[0]`) - Fixing `getNodeFromBranchRLP` (if input is an leaf, it could throw an exception) - Refactoring `getNodeFromBranchRLP` ### Issue Link privacy-scaling-explorations#1752 ### Type of change - [x] Bug fix (non-breaking change which fixes an issue)
1 parent 41e3408 commit b82ea41

File tree

2 files changed

+216
-99
lines changed

2 files changed

+216
-99
lines changed

geth-utils/gethutil/mpt/trie/stacktrie.go

Lines changed: 91 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -243,16 +243,17 @@ func (st *StackTrie) insert(key, value []byte) {
243243
break
244244
}
245245
}
246+
246247
// Add new child
247248
if st.children[idx] == nil {
248-
st.children[idx] = stackTrieFromPool(st.db)
249-
st.children[idx].keyOffset = st.keyOffset + 1
249+
st.children[idx] = newLeaf(st.keyOffset+1, key, value, st.db)
250+
} else {
251+
st.children[idx].insert(key, value)
250252
}
251-
st.children[idx].insert(key, value)
253+
252254
case extNode: /* Ext */
253255
// Compare both key chunks and see where they differ
254256
diffidx := st.getDiffIndex(key)
255-
256257
// Check if chunks are identical. If so, recurse into
257258
// the child node. Otherwise, the key has to be split
258259
// into 1) an optional common prefix, 2) the fullnode
@@ -551,57 +552,85 @@ func (st *StackTrie) Commit() (common.Hash, error) {
551552
return common.BytesToHash(st.val), nil
552553
}
553554

554-
func (st *StackTrie) getNodeFromBranchRLP(branch []byte, ind byte) []byte {
555-
start := 2 // when branch[0] == 248
556-
if branch[0] == 249 {
557-
start = 3
558-
}
559-
560-
i := 0
561-
insideInd := -1
562-
cInd := byte(0)
563-
for {
564-
if start+i == len(branch)-1 { // -1 because of the last 128 (branch value)
565-
return []byte{0}
566-
}
567-
b := branch[start+i]
568-
if insideInd == -1 && b == 128 {
569-
if cInd == ind {
555+
const RLP_SHORT_STR_FLAG = 128
556+
const RLP_SHORT_LIST_FLAG = 192
557+
const RLP_LONG_LIST_FLAG = 248
558+
const LEN_OF_HASH = 32
559+
560+
// Note:
561+
// In RLP encoding, if the value is between [0x80, 0xb7] ([128, 183]),
562+
// it means following data is a short string (0 - 55bytes).
563+
// Which implies if the value is 128, it's an empty string.
564+
func (st *StackTrie) getNodeFromBranchRLP(branch []byte, idx int) []byte {
565+
566+
start := int(branch[0])
567+
start_idx := 0
568+
if start >= RLP_SHORT_LIST_FLAG && start < RLP_LONG_LIST_FLAG {
569+
// In RLP encoding, length in the range of [192 248] is a short list.
570+
// In stack trie, it usually means an extension node and the first byte is nibble
571+
// and that's why we start from 2
572+
start_idx = 2
573+
} else if start >= RLP_LONG_LIST_FLAG {
574+
// In RLP encoding, length in the range of [248 ~ ] is a long list.
575+
// The RLP byte minus 248 (branch[0] - 248) is the length in bytes of the length of the payload
576+
// and the payload is right after the length.
577+
// That's why we add 2 here
578+
// e.g. [248 81 128 160 ...]
579+
// `81` is the length of the payload and payload starts from `128`
580+
start_idx = start - RLP_LONG_LIST_FLAG + 2
581+
}
582+
583+
// If 1st node is neither 128(empty node) nor 160, it should be a leaf
584+
b := int(branch[start_idx])
585+
if b != RLP_SHORT_STR_FLAG && b != (RLP_SHORT_STR_FLAG+LEN_OF_HASH) {
586+
return []byte{0}
587+
}
588+
589+
current_idx := 0
590+
for i := start_idx; i < len(branch); i++ {
591+
b = int(branch[i])
592+
switch b {
593+
case RLP_SHORT_STR_FLAG: // 128
594+
// if the current index is we're looking for, return an empty node directly
595+
if current_idx == idx {
570596
return []byte{128}
571-
} else {
572-
cInd += 1
573597
}
574-
} else if insideInd == -1 && b != 128 {
575-
if b == 160 {
576-
if cInd == ind {
577-
return branch[start+i+1 : start+i+1+32]
578-
}
579-
insideInd = 32
580-
} else {
581-
// non-hashed node
582-
if cInd == ind {
583-
return branch[start+i+1 : start+i+1+int(b)-192]
584-
}
585-
insideInd = int(b) - 192
598+
current_idx++
599+
case RLP_SHORT_STR_FLAG + LEN_OF_HASH: // 160
600+
if current_idx == idx {
601+
return branch[i+1 : i+1+LEN_OF_HASH]
586602
}
587-
cInd += 1
588-
} else {
589-
if insideInd == 1 {
590-
insideInd = -1
591-
} else {
592-
insideInd--
603+
// jump to next encoded element
604+
i += LEN_OF_HASH
605+
current_idx++
606+
default:
607+
if b >= 192 && b < 248 {
608+
length := b - 192
609+
if current_idx == idx {
610+
return branch[i+1 : i+1+length]
611+
}
612+
i += length
613+
current_idx++
593614
}
594615
}
595-
596-
i++
597616
}
617+
618+
return []byte{0}
598619
}
599620

600621
type StackProof struct {
601622
proofS [][]byte
602623
proofC [][]byte
603624
}
604625

626+
func (sp *StackProof) GetProofS() [][]byte {
627+
return sp.proofS
628+
}
629+
630+
func (sp *StackProof) GetProofC() [][]byte {
631+
return sp.proofC
632+
}
633+
605634
func (st *StackTrie) UpdateAndGetProof(db ethdb.KeyValueReader, indexBuf, value []byte) (StackProof, error) {
606635
proofS, err := st.GetProof(db, indexBuf)
607636
if err != nil {
@@ -618,6 +647,8 @@ func (st *StackTrie) UpdateAndGetProof(db ethdb.KeyValueReader, indexBuf, value
618647
return StackProof{proofS, proofC}, nil
619648
}
620649

650+
// We refer to the link below for this function.
651+
// https://github.com/ethereum/go-ethereum/blob/00905f7dc406cfb67f64cd74113777044fb886d8/core/types/hashing.go#L105-L134
621652
func (st *StackTrie) UpdateAndGetProofs(db ethdb.KeyValueReader, list types.DerivableList) ([]StackProof, error) {
622653
valueBuf := types.EncodeBufferPool.Get().(*bytes.Buffer)
623654
defer types.EncodeBufferPool.Put(valueBuf)
@@ -631,33 +662,40 @@ func (st *StackTrie) UpdateAndGetProofs(db ethdb.KeyValueReader, list types.Deri
631662
for i := 1; i < list.Len() && i <= 0x7f; i++ {
632663
indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i))
633664
value := types.EncodeForDerive(list, i, valueBuf)
634-
635665
proof, err := st.UpdateAndGetProof(db, indexBuf, value)
636666
if err != nil {
637667
return nil, err
638668
}
639-
640669
proofs = append(proofs, proof)
641670
}
671+
672+
// special case when index is 0
673+
// rlp.AppendUint64() encodes index 0 to [128]
642674
if list.Len() > 0 {
643675
indexBuf = rlp.AppendUint64(indexBuf[:0], 0)
644676
value := types.EncodeForDerive(list, 0, valueBuf)
645-
// TODO: get proof
646-
st.Update(indexBuf, value)
677+
proof, err := st.UpdateAndGetProof(db, indexBuf, value)
678+
if err != nil {
679+
return nil, err
680+
}
681+
proofs = append(proofs, proof)
647682
}
683+
648684
for i := 0x80; i < list.Len(); i++ {
649685
indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i))
650686
value := types.EncodeForDerive(list, i, valueBuf)
651-
// TODO: get proof
652-
st.Update(indexBuf, value)
687+
proof, err := st.UpdateAndGetProof(db, indexBuf, value)
688+
if err != nil {
689+
return nil, err
690+
}
691+
proofs = append(proofs, proof)
653692
}
654693

655694
return proofs, nil
656695
}
657696

658697
func (st *StackTrie) GetProof(db ethdb.KeyValueReader, key []byte) ([][]byte, error) {
659698
k := KeybytesToHex(key)
660-
661699
if st.nodeType == emptyNode {
662700
return [][]byte{}, nil
663701
}
@@ -682,7 +720,8 @@ func (st *StackTrie) GetProof(db ethdb.KeyValueReader, key []byte) ([][]byte, er
682720
for i := 0; i < len(k); i++ {
683721
if c.nodeType == extNode {
684722
nodes = append(nodes, c)
685-
c = st.children[0]
723+
c = c.children[0]
724+
686725
} else if c.nodeType == branchNode {
687726
nodes = append(nodes, c)
688727
c = c.children[k[i]]
@@ -700,11 +739,11 @@ func (st *StackTrie) GetProof(db ethdb.KeyValueReader, key []byte) ([][]byte, er
700739
}
701740

702741
proof = append(proof, c_rlp)
703-
branchChild := st.getNodeFromBranchRLP(c_rlp, k[i])
742+
branchChild := st.getNodeFromBranchRLP(c_rlp, int(k[i]))
704743

705744
// branchChild is of length 1 when there is no child at this position in the branch
706745
// (`branchChild = [128]` in this case), but it is also of length 1 when `c_rlp` is a leaf.
707-
if len(branchChild) == 1 {
746+
if len(branchChild) == 1 && (branchChild[0] == 128 || branchChild[0] == 0) {
708747
// no child at this position - 128 is RLP encoding for nil object
709748
break
710749
}

0 commit comments

Comments
 (0)