@@ -243,16 +243,17 @@ func (st *StackTrie) insert(key, value []byte) {
243
243
break
244
244
}
245
245
}
246
+
246
247
// Add new child
247
248
if st .children [idx ] == nil {
248
- st .children [idx ] = stackTrieFromPool (st .db )
249
- st .children [idx ].keyOffset = st .keyOffset + 1
249
+ st .children [idx ] = newLeaf (st .keyOffset + 1 , key , value , st .db )
250
+ } else {
251
+ st .children [idx ].insert (key , value )
250
252
}
251
- st . children [ idx ]. insert ( key , value )
253
+
252
254
case extNode : /* Ext */
253
255
// Compare both key chunks and see where they differ
254
256
diffidx := st .getDiffIndex (key )
255
-
256
257
// Check if chunks are identical. If so, recurse into
257
258
// the child node. Otherwise, the key has to be split
258
259
// into 1) an optional common prefix, 2) the fullnode
@@ -551,57 +552,85 @@ func (st *StackTrie) Commit() (common.Hash, error) {
551
552
return common .BytesToHash (st .val ), nil
552
553
}
553
554
554
- func (st * StackTrie ) getNodeFromBranchRLP (branch []byte , ind byte ) []byte {
555
- start := 2 // when branch[0] == 248
556
- if branch [0 ] == 249 {
557
- start = 3
558
- }
559
-
560
- i := 0
561
- insideInd := - 1
562
- cInd := byte (0 )
563
- for {
564
- if start + i == len (branch )- 1 { // -1 because of the last 128 (branch value)
565
- return []byte {0 }
566
- }
567
- b := branch [start + i ]
568
- if insideInd == - 1 && b == 128 {
569
- if cInd == ind {
555
+ const RLP_SHORT_STR_FLAG = 128
556
+ const RLP_SHORT_LIST_FLAG = 192
557
+ const RLP_LONG_LIST_FLAG = 248
558
+ const LEN_OF_HASH = 32
559
+
560
+ // Note:
561
+ // In RLP encoding, if the value is between [0x80, 0xb7] ([128, 183]),
562
+ // it means following data is a short string (0 - 55bytes).
563
+ // Which implies if the value is 128, it's an empty string.
564
+ func (st * StackTrie ) getNodeFromBranchRLP (branch []byte , idx int ) []byte {
565
+
566
+ start := int (branch [0 ])
567
+ start_idx := 0
568
+ if start >= RLP_SHORT_LIST_FLAG && start < RLP_LONG_LIST_FLAG {
569
+ // In RLP encoding, length in the range of [192 248] is a short list.
570
+ // In stack trie, it usually means an extension node and the first byte is nibble
571
+ // and that's why we start from 2
572
+ start_idx = 2
573
+ } else if start >= RLP_LONG_LIST_FLAG {
574
+ // In RLP encoding, length in the range of [248 ~ ] is a long list.
575
+ // The RLP byte minus 248 (branch[0] - 248) is the length in bytes of the length of the payload
576
+ // and the payload is right after the length.
577
+ // That's why we add 2 here
578
+ // e.g. [248 81 128 160 ...]
579
+ // `81` is the length of the payload and payload starts from `128`
580
+ start_idx = start - RLP_LONG_LIST_FLAG + 2
581
+ }
582
+
583
+ // If 1st node is neither 128(empty node) nor 160, it should be a leaf
584
+ b := int (branch [start_idx ])
585
+ if b != RLP_SHORT_STR_FLAG && b != (RLP_SHORT_STR_FLAG + LEN_OF_HASH ) {
586
+ return []byte {0 }
587
+ }
588
+
589
+ current_idx := 0
590
+ for i := start_idx ; i < len (branch ); i ++ {
591
+ b = int (branch [i ])
592
+ switch b {
593
+ case RLP_SHORT_STR_FLAG : // 128
594
+ // if the current index is we're looking for, return an empty node directly
595
+ if current_idx == idx {
570
596
return []byte {128 }
571
- } else {
572
- cInd += 1
573
597
}
574
- } else if insideInd == - 1 && b != 128 {
575
- if b == 160 {
576
- if cInd == ind {
577
- return branch [start + i + 1 : start + i + 1 + 32 ]
578
- }
579
- insideInd = 32
580
- } else {
581
- // non-hashed node
582
- if cInd == ind {
583
- return branch [start + i + 1 : start + i + 1 + int (b )- 192 ]
584
- }
585
- insideInd = int (b ) - 192
598
+ current_idx ++
599
+ case RLP_SHORT_STR_FLAG + LEN_OF_HASH : // 160
600
+ if current_idx == idx {
601
+ return branch [i + 1 : i + 1 + LEN_OF_HASH ]
586
602
}
587
- cInd += 1
588
- } else {
589
- if insideInd == 1 {
590
- insideInd = - 1
591
- } else {
592
- insideInd --
603
+ // jump to next encoded element
604
+ i += LEN_OF_HASH
605
+ current_idx ++
606
+ default :
607
+ if b >= 192 && b < 248 {
608
+ length := b - 192
609
+ if current_idx == idx {
610
+ return branch [i + 1 : i + 1 + length ]
611
+ }
612
+ i += length
613
+ current_idx ++
593
614
}
594
615
}
595
-
596
- i ++
597
616
}
617
+
618
+ return []byte {0 }
598
619
}
599
620
600
621
type StackProof struct {
601
622
proofS [][]byte
602
623
proofC [][]byte
603
624
}
604
625
626
+ func (sp * StackProof ) GetProofS () [][]byte {
627
+ return sp .proofS
628
+ }
629
+
630
+ func (sp * StackProof ) GetProofC () [][]byte {
631
+ return sp .proofC
632
+ }
633
+
605
634
func (st * StackTrie ) UpdateAndGetProof (db ethdb.KeyValueReader , indexBuf , value []byte ) (StackProof , error ) {
606
635
proofS , err := st .GetProof (db , indexBuf )
607
636
if err != nil {
@@ -618,6 +647,8 @@ func (st *StackTrie) UpdateAndGetProof(db ethdb.KeyValueReader, indexBuf, value
618
647
return StackProof {proofS , proofC }, nil
619
648
}
620
649
650
+ // We refer to the link below for this function.
651
+ // https://github.com/ethereum/go-ethereum/blob/00905f7dc406cfb67f64cd74113777044fb886d8/core/types/hashing.go#L105-L134
621
652
func (st * StackTrie ) UpdateAndGetProofs (db ethdb.KeyValueReader , list types.DerivableList ) ([]StackProof , error ) {
622
653
valueBuf := types .EncodeBufferPool .Get ().(* bytes.Buffer )
623
654
defer types .EncodeBufferPool .Put (valueBuf )
@@ -631,33 +662,40 @@ func (st *StackTrie) UpdateAndGetProofs(db ethdb.KeyValueReader, list types.Deri
631
662
for i := 1 ; i < list .Len () && i <= 0x7f ; i ++ {
632
663
indexBuf = rlp .AppendUint64 (indexBuf [:0 ], uint64 (i ))
633
664
value := types .EncodeForDerive (list , i , valueBuf )
634
-
635
665
proof , err := st .UpdateAndGetProof (db , indexBuf , value )
636
666
if err != nil {
637
667
return nil , err
638
668
}
639
-
640
669
proofs = append (proofs , proof )
641
670
}
671
+
672
+ // special case when index is 0
673
+ // rlp.AppendUint64() encodes index 0 to [128]
642
674
if list .Len () > 0 {
643
675
indexBuf = rlp .AppendUint64 (indexBuf [:0 ], 0 )
644
676
value := types .EncodeForDerive (list , 0 , valueBuf )
645
- // TODO: get proof
646
- st .Update (indexBuf , value )
677
+ proof , err := st .UpdateAndGetProof (db , indexBuf , value )
678
+ if err != nil {
679
+ return nil , err
680
+ }
681
+ proofs = append (proofs , proof )
647
682
}
683
+
648
684
for i := 0x80 ; i < list .Len (); i ++ {
649
685
indexBuf = rlp .AppendUint64 (indexBuf [:0 ], uint64 (i ))
650
686
value := types .EncodeForDerive (list , i , valueBuf )
651
- // TODO: get proof
652
- st .Update (indexBuf , value )
687
+ proof , err := st .UpdateAndGetProof (db , indexBuf , value )
688
+ if err != nil {
689
+ return nil , err
690
+ }
691
+ proofs = append (proofs , proof )
653
692
}
654
693
655
694
return proofs , nil
656
695
}
657
696
658
697
func (st * StackTrie ) GetProof (db ethdb.KeyValueReader , key []byte ) ([][]byte , error ) {
659
698
k := KeybytesToHex (key )
660
-
661
699
if st .nodeType == emptyNode {
662
700
return [][]byte {}, nil
663
701
}
@@ -682,7 +720,8 @@ func (st *StackTrie) GetProof(db ethdb.KeyValueReader, key []byte) ([][]byte, er
682
720
for i := 0 ; i < len (k ); i ++ {
683
721
if c .nodeType == extNode {
684
722
nodes = append (nodes , c )
685
- c = st .children [0 ]
723
+ c = c .children [0 ]
724
+
686
725
} else if c .nodeType == branchNode {
687
726
nodes = append (nodes , c )
688
727
c = c.children [k [i ]]
@@ -700,11 +739,11 @@ func (st *StackTrie) GetProof(db ethdb.KeyValueReader, key []byte) ([][]byte, er
700
739
}
701
740
702
741
proof = append (proof , c_rlp )
703
- branchChild := st .getNodeFromBranchRLP (c_rlp , k [i ])
742
+ branchChild := st .getNodeFromBranchRLP (c_rlp , int ( k [i ]) )
704
743
705
744
// branchChild is of length 1 when there is no child at this position in the branch
706
745
// (`branchChild = [128]` in this case), but it is also of length 1 when `c_rlp` is a leaf.
707
- if len (branchChild ) == 1 {
746
+ if len (branchChild ) == 1 && ( branchChild [ 0 ] == 128 || branchChild [ 0 ] == 0 ) {
708
747
// no child at this position - 128 is RLP encoding for nil object
709
748
break
710
749
}
0 commit comments