diff --git a/core/blockchain_reader.go b/core/blockchain_reader.go
index 96199417ee..70465df4b9 100644
--- a/core/blockchain_reader.go
+++ b/core/blockchain_reader.go
@@ -247,6 +247,20 @@ func (bc *BlockChain) HasBlockAndState(hash common.Hash, number uint64) bool {
return bc.HasState(block.Root())
}
+// ContractCodeWithPrefix retrieves a blob of data associated with a contract
+// hash, either from the ephemeral in-memory cache or from persistent storage.
+//
+// If the code doesn't exist in the in-memory cache, it is looked up in
+// persistent storage using the new code scheme.
+func (bc *BlockChain) ContractCodeWithPrefix(hash common.Hash) ([]byte, error) {
+ type codeReader interface {
+ ContractCodeWithPrefix(address common.Address, codeHash common.Hash) ([]byte, error)
+ }
+ // TODO(rjl493456442) The associated account address is also required
+ // in Verkle scheme. Fix it once snap-sync is supported for Verkle.
+ return bc.stateCache.(codeReader).ContractCodeWithPrefix(common.Address{}, hash)
+}
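+
+// A caller-side sketch (assumed usage): the snap bytecode handler tolerates
+// lookup misses and only appends codes that resolve, e.g.
+//
+//	if blob, err := bc.ContractCodeWithPrefix(codeHash); err == nil {
+//		codes = append(codes, blob)
+//	}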
+
// State returns a new mutable state based on the current HEAD block.
func (bc *BlockChain) State() (*state.StateDB, error) {
return bc.StateAt(bc.CurrentBlock().Root)
diff --git a/core/state/database.go b/core/state/database.go
index b810bf2c3d..dd6dd2f096 100644
--- a/core/state/database.go
+++ b/core/state/database.go
@@ -171,6 +171,10 @@ func (db *cachingDB) ContractCode(address common.Address, codeHash common.Hash)
return nil, errors.New("not found")
}
+func (db *cachingDB) ContractCodeWithPrefix(address common.Address, codeHash common.Hash) ([]byte, error) {
+ return db.ContractCode(address, codeHash)
+}
+
// ContractCodeSize retrieves a particular contracts code's size.
func (db *cachingDB) ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error) {
if cached, ok := db.codeSizeCache.Get(codeHash); ok {
diff --git a/core/types/account.go b/core/types/account.go
index efc0927770..8965845107 100644
--- a/core/types/account.go
+++ b/core/types/account.go
@@ -17,71 +17,13 @@
package types
import (
- "bytes"
- "encoding/hex"
- "encoding/json"
- "fmt"
- "math/big"
-
- "github.com/ava-labs/libevm/common"
- "github.com/ava-labs/libevm/common/hexutil"
- "github.com/ava-labs/libevm/common/math"
+ ethtypes "github.com/ava-labs/libevm/core/types"
)
-//go:generate go run github.com/fjl/gencodec -type Account -field-override accountMarshaling -out gen_account.go
-
// Account represents an Ethereum account and its attached data.
// This type is used to specify accounts in the genesis block state, and
// is also useful for JSON encoding/decoding of accounts.
-type Account struct {
- Code []byte `json:"code,omitempty"`
- Storage map[common.Hash]common.Hash `json:"storage,omitempty"`
- Balance *big.Int `json:"balance" gencodec:"required"`
- Nonce uint64 `json:"nonce,omitempty"`
-
- // used in tests
- PrivateKey []byte `json:"secretKey,omitempty"`
-}
-
-type accountMarshaling struct {
- Code hexutil.Bytes
- Balance *math.HexOrDecimal256
- Nonce math.HexOrDecimal64
- Storage map[storageJSON]storageJSON
- PrivateKey hexutil.Bytes
-}
-
-// storageJSON represents a 256 bit byte array, but allows less than 256 bits when
-// unmarshaling from hex.
-type storageJSON common.Hash
-
-func (h *storageJSON) UnmarshalText(text []byte) error {
- text = bytes.TrimPrefix(text, []byte("0x"))
- if len(text) > 64 {
- return fmt.Errorf("too many hex characters in storage key/value %q", text)
- }
- offset := len(h) - len(text)/2 // pad on the left
- if _, err := hex.Decode(h[offset:], text); err != nil {
- return fmt.Errorf("invalid hex storage key/value %q", text)
- }
- return nil
-}
-
-func (h storageJSON) MarshalText() ([]byte, error) {
- return hexutil.Bytes(h[:]).MarshalText()
-}
+type Account = ethtypes.Account
// GenesisAlloc specifies the initial state of a genesis block.
-type GenesisAlloc map[common.Address]Account
-
-func (ga *GenesisAlloc) UnmarshalJSON(data []byte) error {
- m := make(map[common.UnprefixedAddress]Account)
- if err := json.Unmarshal(data, &m); err != nil {
- return err
- }
- *ga = make(GenesisAlloc)
- for addr, a := range m {
- (*ga)[common.Address(addr)] = a
- }
- return nil
-}
+type GenesisAlloc = ethtypes.GenesisAlloc
diff --git a/core/types/gen_account.go b/core/types/gen_account.go
deleted file mode 100644
index c3c7fb3fdf..0000000000
--- a/core/types/gen_account.go
+++ /dev/null
@@ -1,73 +0,0 @@
-// Code generated by github.com/fjl/gencodec. DO NOT EDIT.
-
-package types
-
-import (
- "encoding/json"
- "errors"
- "math/big"
-
- "github.com/ava-labs/libevm/common"
- "github.com/ava-labs/libevm/common/hexutil"
- "github.com/ava-labs/libevm/common/math"
-)
-
-var _ = (*accountMarshaling)(nil)
-
-// MarshalJSON marshals as JSON.
-func (a Account) MarshalJSON() ([]byte, error) {
- type Account struct {
- Code hexutil.Bytes `json:"code,omitempty"`
- Storage map[storageJSON]storageJSON `json:"storage,omitempty"`
- Balance *math.HexOrDecimal256 `json:"balance" gencodec:"required"`
- Nonce math.HexOrDecimal64 `json:"nonce,omitempty"`
- PrivateKey hexutil.Bytes `json:"secretKey,omitempty"`
- }
- var enc Account
- enc.Code = a.Code
- if a.Storage != nil {
- enc.Storage = make(map[storageJSON]storageJSON, len(a.Storage))
- for k, v := range a.Storage {
- enc.Storage[storageJSON(k)] = storageJSON(v)
- }
- }
- enc.Balance = (*math.HexOrDecimal256)(a.Balance)
- enc.Nonce = math.HexOrDecimal64(a.Nonce)
- enc.PrivateKey = a.PrivateKey
- return json.Marshal(&enc)
-}
-
-// UnmarshalJSON unmarshals from JSON.
-func (a *Account) UnmarshalJSON(input []byte) error {
- type Account struct {
- Code *hexutil.Bytes `json:"code,omitempty"`
- Storage map[storageJSON]storageJSON `json:"storage,omitempty"`
- Balance *math.HexOrDecimal256 `json:"balance" gencodec:"required"`
- Nonce *math.HexOrDecimal64 `json:"nonce,omitempty"`
- PrivateKey *hexutil.Bytes `json:"secretKey,omitempty"`
- }
- var dec Account
- if err := json.Unmarshal(input, &dec); err != nil {
- return err
- }
- if dec.Code != nil {
- a.Code = *dec.Code
- }
- if dec.Storage != nil {
- a.Storage = make(map[common.Hash]common.Hash, len(dec.Storage))
- for k, v := range dec.Storage {
- a.Storage[common.Hash(k)] = common.Hash(v)
- }
- }
- if dec.Balance == nil {
- return errors.New("missing required field 'balance' for Account")
- }
- a.Balance = (*big.Int)(dec.Balance)
- if dec.Nonce != nil {
- a.Nonce = uint64(*dec.Nonce)
- }
- if dec.PrivateKey != nil {
- a.PrivateKey = *dec.PrivateKey
- }
- return nil
-}
diff --git a/eth/protocols/snap/discovery.go b/eth/protocols/snap/discovery.go
new file mode 100644
index 0000000000..dece58eacc
--- /dev/null
+++ b/eth/protocols/snap/discovery.go
@@ -0,0 +1,32 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "github.com/ava-labs/libevm/rlp"
+)
+
+// enrEntry is the ENR entry which advertises the `snap` protocol on the discovery network.
+type enrEntry struct {
+ // Ignore additional fields (for forward compatibility).
+ Rest []rlp.RawValue `rlp:"tail"`
+}
+
+// ENRKey implements enr.Entry.
+func (e enrEntry) ENRKey() string {
+ return "snap"
+}
diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go
new file mode 100644
index 0000000000..4ceb4a1abc
--- /dev/null
+++ b/eth/protocols/snap/handler.go
@@ -0,0 +1,654 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/ava-labs/coreth/core"
+ "github.com/ava-labs/coreth/metrics"
+ "github.com/ava-labs/coreth/plugin/evm/message"
+ "github.com/ava-labs/coreth/sync/handlers"
+ "github.com/ava-labs/coreth/sync/handlers/stats"
+ "github.com/ava-labs/libevm/common"
+ "github.com/ava-labs/libevm/core/types"
+ "github.com/ava-labs/libevm/log"
+ "github.com/ava-labs/libevm/p2p"
+ "github.com/ava-labs/libevm/p2p/enode"
+ "github.com/ava-labs/libevm/p2p/enr"
+ "github.com/ava-labs/libevm/rlp"
+ "github.com/ava-labs/libevm/trie"
+ "github.com/ava-labs/libevm/trie/trienode"
+)
+
+const (
+ maxLeavesLimit = uint16(2048)
+ stateKeyLength = common.HashLength
+)
+
+const (
+ // softResponseLimit is the target maximum size of replies to data retrievals.
+ softResponseLimit = 1 * 1024 * 1024
+
+ // maxCodeLookups is the maximum number of bytecodes to serve. This number is
+ // there to limit the number of disk lookups.
+ maxCodeLookups = 1024
+
+ // stateLookupSlack defines the ratio by how much a state response can exceed
+ // the requested limit in order to try and avoid breaking up contracts into
+ // multiple packages and proving them.
+ stateLookupSlack = 0.1
+
+ // maxTrieNodeLookups is the maximum number of state trie nodes to serve. This
+ // number is there to limit the number of disk lookups.
+ maxTrieNodeLookups = 1024
+
+ // maxTrieNodeTimeSpent is the maximum time we should spend on looking up trie nodes.
+ // If we spend too much time, then it's a fairly high chance of timing out
+ // at the remote side, which means all the work is in vain.
+ maxTrieNodeTimeSpent = 5 * time.Second
+)
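+
+// As a worked example of the slack math: with a storage query capped at
+// softResponseLimit (1 MiB) and a stateLookupSlack of 0.1, the hard abort
+// limit used by ServiceGetStorageRangesQuery is 1 MiB * (1 + 0.1), i.e.
+// about 1.1 MiB.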
+
+// Handler is a callback to invoke from an outside runner after the boilerplate
+// exchanges have passed.
+type Handler func(peer *Peer) error
+
+// Backend defines the data retrieval methods to serve remote requests and the
+// callback methods to invoke on remote deliveries.
+type Backend interface {
+ // Chain retrieves the blockchain object to serve data.
+ Chain() *core.BlockChain
+
+ // RunPeer is invoked when a peer joins on the `eth` protocol. The handler
+ // should do any peer maintenance work, handshakes and validations. If all
+ // is passed, control should be given back to the `handler` to process the
+ // inbound messages going forward.
+ RunPeer(peer *Peer, handler Handler) error
+
+ // PeerInfo retrieves all known `snap` information about a peer.
+ PeerInfo(id enode.ID) interface{}
+
+ // Handle is a callback to be invoked when a data packet is received from
+ // the remote peer. Only packets not consumed by the protocol handler will
+ // be forwarded to the backend.
+ Handle(peer *Peer, packet Packet) error
+}
+
+// MakeProtocols constructs the P2P protocol definitions for `snap`.
+func MakeProtocols(backend Backend, dnsdisc enode.Iterator) []p2p.Protocol {
+ // Filter the discovery iterator for nodes advertising snap support.
+ dnsdisc = enode.Filter(dnsdisc, func(n *enode.Node) bool {
+ var snap enrEntry
+ return n.Load(&snap) == nil
+ })
+
+ protocols := make([]p2p.Protocol, len(ProtocolVersions))
+ for i, version := range ProtocolVersions {
+ version := version // Closure
+
+ protocols[i] = p2p.Protocol{
+ Name: ProtocolName,
+ Version: version,
+ Length: protocolLengths[version],
+ Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
+ return backend.RunPeer(NewPeer(version, p, rw), func(peer *Peer) error {
+ return Handle(backend, peer)
+ })
+ },
+ NodeInfo: func() interface{} {
+ return nodeInfo(backend.Chain())
+ },
+ PeerInfo: func(id enode.ID) interface{} {
+ return backend.PeerInfo(id)
+ },
+ Attributes: []enr.Entry{&enrEntry{}},
+ DialCandidates: dnsdisc,
+ }
+ }
+ return protocols
+}
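+
+// Wiring sketch (`backend` and `dnsdisc` are assumed to exist): hand the
+// constructed protocols to a devp2p server, which advertises the ENR entry
+// and dials snap-capable candidates.
+//
+//	srv := &p2p.Server{Config: p2p.Config{
+//		Protocols: MakeProtocols(backend, dnsdisc),
+//	}}
+//	if err := srv.Start(); err != nil {
+//		log.Crit("Failed to start p2p server", "err", err)
+//	}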
+
+// Handle is the callback invoked to manage the life cycle of a `snap` peer.
+// When this function terminates, the peer is disconnected.
+func Handle(backend Backend, peer *Peer) error {
+ for {
+ if err := HandleMessage(backend, peer); err != nil {
+ peer.Log().Debug("Message handling failed in `snap`", "err", err)
+ return err
+ }
+ }
+}
+
+// HandleMessage is invoked whenever an inbound message is received from a
+// remote peer on the `snap` protocol. The remote connection is torn down upon
+// returning any error.
+func HandleMessage(backend Backend, peer *Peer) error {
+ // Read the next message from the remote peer, and ensure it's fully consumed
+ msg, err := peer.rw.ReadMsg()
+ if err != nil {
+ return err
+ }
+ if msg.Size > maxMessageSize {
+ return fmt.Errorf("%w: %v > %v", errMsgTooLarge, msg.Size, maxMessageSize)
+ }
+ defer msg.Discard()
+ start := time.Now()
+ // Track the amount of time it takes to serve the request and run the handler
+ if metrics.Enabled {
+ h := fmt.Sprintf("%s/%s/%d/%#02x", p2p.HandleHistName, ProtocolName, peer.Version(), msg.Code)
+ defer func(start time.Time) {
+ sampler := func() metrics.Sample {
+ return metrics.ResettingSample(
+ metrics.NewExpDecaySample(1028, 0.015),
+ )
+ }
+ metrics.GetOrRegisterHistogramLazy(h, nil, sampler).Update(time.Since(start).Microseconds())
+ }(start)
+ }
+ // Handle the message depending on its contents
+ switch {
+ case msg.Code == GetAccountRangeMsg:
+ // Decode the account retrieval request
+ var req GetAccountRangePacket
+ if err := msg.Decode(&req); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ // Service the request, potentially returning nothing in case of errors
+ accounts, proofs := ServiceGetAccountRangeQuery(backend.Chain(), &req)
+
+ // Send back anything accumulated (or empty in case of errors)
+ return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{
+ ID: req.ID,
+ Accounts: accounts,
+ Proof: proofs,
+ })
+
+ case msg.Code == AccountRangeMsg:
+ // A range of accounts arrived in response to one of our previous requests
+ res := new(AccountRangePacket)
+ if err := msg.Decode(res); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ // Ensure the range is monotonically increasing
+ for i := 1; i < len(res.Accounts); i++ {
+ if bytes.Compare(res.Accounts[i-1].Hash[:], res.Accounts[i].Hash[:]) >= 0 {
+ return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, res.Accounts[i-1].Hash[:], i, res.Accounts[i].Hash[:])
+ }
+ }
+ requestTracker.Fulfil(peer.id, peer.version, AccountRangeMsg, res.ID)
+
+ return backend.Handle(peer, res)
+
+ case msg.Code == GetStorageRangesMsg:
+ // Decode the storage retrieval request
+ var req GetStorageRangesPacket
+ if err := msg.Decode(&req); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ // Service the request, potentially returning nothing in case of errors
+ slots, proofs := ServiceGetStorageRangesQuery(backend.Chain(), &req)
+
+ // Send back anything accumulated (or empty in case of errors)
+ return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{
+ ID: req.ID,
+ Slots: slots,
+ Proof: proofs,
+ })
+
+ case msg.Code == StorageRangesMsg:
+ // A range of storage slots arrived in response to one of our previous requests
+ res := new(StorageRangesPacket)
+ if err := msg.Decode(res); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ // Ensure the ranges are monotonically increasing
+ for i, slots := range res.Slots {
+ for j := 1; j < len(slots); j++ {
+ if bytes.Compare(slots[j-1].Hash[:], slots[j].Hash[:]) >= 0 {
+ return fmt.Errorf("storage slots not monotonically increasing for account #%d: #%d [%x] vs #%d [%x]", i, j-1, slots[j-1].Hash[:], j, slots[j].Hash[:])
+ }
+ }
+ }
+ requestTracker.Fulfil(peer.id, peer.version, StorageRangesMsg, res.ID)
+
+ return backend.Handle(peer, res)
+
+ case msg.Code == GetByteCodesMsg:
+ // Decode bytecode retrieval request
+ var req GetByteCodesPacket
+ if err := msg.Decode(&req); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ // Service the request, potentially returning nothing in case of errors
+ codes := ServiceGetByteCodesQuery(backend.Chain(), &req)
+
+ // Send back anything accumulated (or empty in case of errors)
+ return p2p.Send(peer.rw, ByteCodesMsg, &ByteCodesPacket{
+ ID: req.ID,
+ Codes: codes,
+ })
+
+ case msg.Code == ByteCodesMsg:
+ // A batch of byte codes arrived in response to one of our previous requests
+ res := new(ByteCodesPacket)
+ if err := msg.Decode(res); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ requestTracker.Fulfil(peer.id, peer.version, ByteCodesMsg, res.ID)
+
+ return backend.Handle(peer, res)
+
+ case msg.Code == GetTrieNodesMsg:
+ // Decode trie node retrieval request
+ var req GetTrieNodesPacket
+ if err := msg.Decode(&req); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ // Service the request, potentially returning nothing in case of errors
+ nodes, err := ServiceGetTrieNodesQuery(backend.Chain(), &req, start)
+ if err != nil {
+ return err
+ }
+ // Send back anything accumulated (or empty in case of errors)
+ return p2p.Send(peer.rw, TrieNodesMsg, &TrieNodesPacket{
+ ID: req.ID,
+ Nodes: nodes,
+ })
+
+ case msg.Code == TrieNodesMsg:
+ // A batch of trie nodes arrived in response to one of our previous requests
+ res := new(TrieNodesPacket)
+ if err := msg.Decode(res); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ requestTracker.Fulfil(peer.id, peer.version, TrieNodesMsg, res.ID)
+
+ return backend.Handle(peer, res)
+
+ default:
+ return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code)
+ }
+}
+
+// ServiceGetAccountRangeQuery assembles the response to an account range query.
+// It is exposed to allow external packages to test protocol behavior.
+func ServiceGetAccountRangeQuery(chain *core.BlockChain, req *GetAccountRangePacket) ([]*AccountData, [][]byte) {
+ if req.Bytes > softResponseLimit {
+ req.Bytes = softResponseLimit
+ }
+ // Retrieve the requested state and bail out if non-existent
+ tr, err := trie.New(trie.StateTrieID(req.Root), chain.TrieDB())
+ if err != nil {
+ log.Debug("Failed to open account trie", "root", req.Root, "err", err)
+ return nil, nil
+ }
+ //nodeIt, err := tr.NodeIterator(req.Origin[:])
+ //if err != nil {
+ // log.Debug("Failed to iterate over account range", "origin", req.Origin, "err", err)
+ // return nil, nil
+ //}
+ //it := trie.NewIterator(nodeIt)
+ // XXX: restore these
+ // Patching in the existing response mechanism for now
+ leafsRequest := &message.LeafsRequest{
+ Root: req.Root,
+ Start: req.Origin[:],
+ End: req.Limit[:],
+ NodeType: message.StateTrieNode,
+ }
+ leafsResponse := &message.LeafsResponse{}
+ handlerStats := stats.NewNoopHandlerStats()
+ handler := handlers.NewResponseBuilder(
+ leafsRequest, leafsResponse, tr, chain.Snapshots(), stateKeyLength, maxLeavesLimit, handlerStats,
+ )
+ responseTime := 1200 * time.Millisecond // the AppRequest timeout is 2s, so stay conservative here
+ ctx, cancel := context.WithTimeout(context.Background(), responseTime)
+ defer cancel() // don't leak a goroutine
+ handler.HandleRequest(ctx)
+
+ //it, err := chain.Snapshots().AccountIterator(req.Root, req.Origin, false)
+ //if err != nil {
+ // return nil, nil
+ //}
+ // Iterate over the requested range and pile accounts up
+ var (
+ accounts []*AccountData
+ size uint64
+ last common.Hash
+ )
+ for i, leafResponseKey := range leafsResponse.Keys {
+ //for it.Next() {
+ //hash, account := it.Hash(), common.CopyBytes(it.Account())
+ //hash, account := common.BytesToHash(it.Key), it.Value
+ hash, account := common.BytesToHash(leafResponseKey), leafsResponse.Vals[i]
+ acc := new(types.StateAccount)
+ if err := rlp.DecodeBytes(account, acc); err != nil {
+ log.Warn("Failed to unmarshal account", "hash", hash, "err", err)
+ continue
+ }
+ account = types.SlimAccountRLP(*acc)
+
+ // Track the returned interval for the Merkle proofs
+ last = hash
+
+ // Assemble the reply item
+ size += uint64(common.HashLength + len(account))
+ accounts = append(accounts, &AccountData{
+ Hash: hash,
+ Body: account,
+ })
+ // If we've exceeded the request threshold, abort
+ if bytes.Compare(hash[:], req.Limit[:]) >= 0 {
+ break
+ }
+ if size > req.Bytes {
+ break
+ }
+ }
+ //it.Release()
+
+ // Generate the Merkle proofs for the first and last account
+ proof := trienode.NewProofSet()
+ if err := tr.Prove(req.Origin[:], proof); err != nil {
+ log.Warn("Failed to prove account range", "origin", req.Origin, "err", err)
+ return nil, nil
+ }
+ if last != (common.Hash{}) {
+ if err := tr.Prove(last[:], proof); err != nil {
+ log.Warn("Failed to prove account range", "last", last, "err", err)
+ return nil, nil
+ }
+ }
+ var proofs [][]byte
+ for _, blob := range proof.List() {
+ proofs = append(proofs, blob)
+ }
+ return accounts, proofs
+}
+
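+// ServiceGetStorageRangesQuery assembles the response to a storage ranges query.
+// It is exposed to allow external packages to test protocol behavior.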
+func ServiceGetStorageRangesQuery(chain *core.BlockChain, req *GetStorageRangesPacket) ([][]*StorageData, [][]byte) {
+ if req.Bytes > softResponseLimit {
+ req.Bytes = softResponseLimit
+ }
+ // TODO(karalabe): Do we want to enforce > 0 accounts and 1 account if origin is set?
+ // TODO(karalabe): - Logging locally is not ideal as remote faults annoy the local user
+ // TODO(karalabe): - Dropping the remote peer is less flexible wrt client bugs (slow is better than non-functional)
+
+ // Calculate the hard limit at which to abort, even if mid storage trie
+ hardLimit := uint64(float64(req.Bytes) * (1 + stateLookupSlack))
+
+ // Retrieve storage ranges until the packet limit is reached
+ var (
+ slots [][]*StorageData
+ proofs [][]byte
+ size uint64
+ )
+ for _, account := range req.Accounts {
+ // If we've exceeded the requested data limit, abort without opening
+ // a new storage range (that we'd need to prove due to exceeded size)
+ if size >= req.Bytes {
+ break
+ }
+ // The first account might start from a different origin and end sooner
+ var origin common.Hash
+ if len(req.Origin) > 0 {
+ origin, req.Origin = common.BytesToHash(req.Origin), nil
+ }
+ var limit = common.MaxHash
+ if len(req.Limit) > 0 {
+ limit, req.Limit = common.BytesToHash(req.Limit), nil
+ }
+
+ // XXX: put this back
+ accTrie, err := trie.NewStateTrie(trie.StateTrieID(req.Root), chain.TrieDB())
+ if err != nil {
+ return nil, nil
+ }
+ acc, err := accTrie.GetAccountByHash(account)
+ if err != nil || acc == nil {
+ return nil, nil
+ }
+ id := trie.StorageTrieID(req.Root, account, acc.Root)
+ stTrie, err := trie.New(id, chain.TrieDB()) // XXX: put this back (was trie.NewStateTrie)
+ if err != nil {
+ return nil, nil
+ }
+ // Patching in the existing response mechanism for now
+ leafsRequest := &message.LeafsRequest{
+ Root: acc.Root,
+ Account: account,
+ Start: origin[:],
+ End: limit[:],
+ NodeType: message.StateTrieNode,
+ }
+ leafsResponse := &message.LeafsResponse{}
+ handlerStats := stats.NewNoopHandlerStats()
+ handler := handlers.NewResponseBuilder(
+ leafsRequest, leafsResponse, stTrie, chain.Snapshots(), stateKeyLength, maxLeavesLimit, handlerStats,
+ )
+ responseTime := 1200 * time.Millisecond // the AppRequest timeout is 2s, so stay conservative here
+ ctx, cancel := context.WithTimeout(context.Background(), responseTime)
+ defer cancel() // don't leak a goroutine
+ handler.HandleRequest(ctx)
+
+ // XXX: put this back
+ // Retrieve the requested state and bail out if non-existent
+ //it, err := chain.Snapshots().StorageIterator(req.Root, account, origin)
+ //if err != nil {
+ // return nil, nil
+ //}
+ //nodeIt, err := stTrie.NodeIterator(origin[:])
+ //if err != nil {
+ // log.Debug("Failed to iterate over storage range", "origin", origin, "err", err)
+ // return nil, nil
+ //}
+ //it := trie.NewIterator(nodeIt)
+
+ // Iterate over the requested range and pile slots up
+ var (
+ storage []*StorageData
+ last common.Hash
+ abort bool
+ )
+ //for it.Next() {
+ for i, leafResponseKey := range leafsResponse.Keys {
+ if size >= hardLimit {
+ abort = true
+ break
+ }
+ //hash, slot := it.Hash(), common.CopyBytes(it.Slot())
+ //hash, slot := common.BytesToHash(it.Key), common.CopyBytes(it.Value)
+ hash, slot := common.BytesToHash(leafResponseKey), leafsResponse.Vals[i]
+
+ // Track the returned interval for the Merkle proofs
+ last = hash
+
+ // Assemble the reply item
+ size += uint64(common.HashLength + len(slot))
+ storage = append(storage, &StorageData{
+ Hash: hash,
+ Body: slot,
+ })
+ // If we've exceeded the request threshold, abort
+ if bytes.Compare(hash[:], limit[:]) >= 0 {
+ break
+ }
+ }
+ abort = abort || leafsResponse.More
+
+ if len(storage) > 0 {
+ slots = append(slots, storage)
+ }
+ //it.Release()
+
+ // Generate the Merkle proofs for the first and last storage slot, but
+ // only if the response was capped. If the entire storage trie is included
+ // in the response, no proofs are needed.
+ if origin != (common.Hash{}) || (abort && len(storage) > 0) {
+ // Request started at a non-zero hash or was capped prematurely, add
+ // the endpoint Merkle proofs
+ proof := trienode.NewProofSet()
+ if err := stTrie.Prove(origin[:], proof); err != nil {
+ log.Warn("Failed to prove storage range", "origin", req.Origin, "err", err)
+ return nil, nil
+ }
+ if last != (common.Hash{}) {
+ if err := stTrie.Prove(last[:], proof); err != nil {
+ log.Warn("Failed to prove storage range", "last", last, "err", err)
+ return nil, nil
+ }
+ }
+ for _, blob := range proof.List() {
+ proofs = append(proofs, blob)
+ }
+ // Proof terminates the reply as proofs are only added if a node
+ // refuses to serve more data (exception when a contract fetch is
+ // finishing, but that's that).
+ break
+ }
+ }
+ return slots, proofs
+}
+
+// ServiceGetByteCodesQuery assembles the response to a byte codes query.
+// It is exposed to allow external packages to test protocol behavior.
+func ServiceGetByteCodesQuery(chain *core.BlockChain, req *GetByteCodesPacket) [][]byte {
+ if req.Bytes > softResponseLimit {
+ req.Bytes = softResponseLimit
+ }
+ if len(req.Hashes) > maxCodeLookups {
+ req.Hashes = req.Hashes[:maxCodeLookups]
+ }
+ // Retrieve bytecodes until the packet size limit is reached
+ var (
+ codes [][]byte
+ bytes uint64
+ )
+ for _, hash := range req.Hashes {
+ if hash == types.EmptyCodeHash {
+ // Peers should not request the empty code, but if they do, at
+ // least send them back a correct response without db lookups
+ codes = append(codes, []byte{})
+ } else if blob, err := chain.ContractCodeWithPrefix(hash); err == nil {
+ codes = append(codes, blob)
+ bytes += uint64(len(blob))
+ }
+ if bytes > req.Bytes {
+ break
+ }
+ }
+ return codes
+}
+
+// ServiceGetTrieNodesQuery assembles the response to a trie nodes query.
+// It is exposed to allow external packages to test protocol behavior.
+func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, start time.Time) ([][]byte, error) {
+ if req.Bytes > softResponseLimit {
+ req.Bytes = softResponseLimit
+ }
+ // Make sure we have the state associated with the request
+ triedb := chain.TrieDB()
+
+ accTrie, err := trie.NewStateTrie(trie.StateTrieID(req.Root), triedb)
+ if err != nil {
+ // We don't have the requested state available, bail out
+ return nil, nil
+ }
+ // The snapshot might be nil; in that case the storage roots are resolved via the account trie below.
+ snap := chain.Snapshots().Snapshot(req.Root)
+ // Retrieve trie nodes until the packet size limit is reached
+ var (
+ nodes [][]byte
+ bytes uint64
+ loads int // Trie hash expansions to count database reads
+ )
+ for _, pathset := range req.Paths {
+ switch len(pathset) {
+ case 0:
+ // Ensure we penalize invalid requests
+ return nil, fmt.Errorf("%w: zero-item pathset requested", errBadRequest)
+
+ case 1:
+ // If we're only retrieving an account trie node, fetch it directly
+ blob, resolved, err := accTrie.GetNode(pathset[0])
+ loads += resolved // always account database reads, even for failures
+ if err != nil {
+ break
+ }
+ nodes = append(nodes, blob)
+ bytes += uint64(len(blob))
+
+ default:
+ var stRoot common.Hash
+ // Storage slots requested, open the storage trie and retrieve from there
+ if snap == nil {
+ // We don't have the requested state snapshotted yet (or it is stale),
+ // but can look up the account via the trie instead.
+ account, err := accTrie.GetAccountByHash(common.BytesToHash(pathset[0]))
+ loads += 8 // We don't know the exact cost of lookup, this is an estimate
+ if err != nil || account == nil {
+ break
+ }
+ stRoot = account.Root
+ } else {
+ account, err := snap.Account(common.BytesToHash(pathset[0]))
+ loads++ // always account database reads, even for failures
+ if err != nil || account == nil {
+ break
+ }
+ stRoot = common.BytesToHash(account.Root)
+ }
+ id := trie.StorageTrieID(req.Root, common.BytesToHash(pathset[0]), stRoot)
+ stTrie, err := trie.NewStateTrie(id, triedb)
+ loads++ // always account database reads, even for failures
+ if err != nil {
+ break
+ }
+ for _, path := range pathset[1:] {
+ blob, resolved, err := stTrie.GetNode(path)
+ loads += resolved // always account database reads, even for failures
+ if err != nil {
+ break
+ }
+ nodes = append(nodes, blob)
+ bytes += uint64(len(blob))
+
+ // Sanity check limits to avoid DoS on the store trie loads
+ if bytes > req.Bytes || loads > maxTrieNodeLookups || time.Since(start) > maxTrieNodeTimeSpent {
+ break
+ }
+ }
+ }
+ // Abort request processing if we've exceeded our limits
+ if bytes > req.Bytes || loads > maxTrieNodeLookups || time.Since(start) > maxTrieNodeTimeSpent {
+ break
+ }
+ }
+ return nodes, nil
+}
+
+// NodeInfo represents a short summary of the `snap` sub-protocol metadata
+// known about the host peer.
+type NodeInfo struct{}
+
+// nodeInfo retrieves some `snap` protocol metadata about the running host node.
+func nodeInfo(chain *core.BlockChain) *NodeInfo {
+ return &NodeInfo{}
+}
diff --git a/eth/protocols/snap/handler_fuzzing_test.go b/eth/protocols/snap/handler_fuzzing_test.go
new file mode 100644
index 0000000000..91b9aad6fb
--- /dev/null
+++ b/eth/protocols/snap/handler_fuzzing_test.go
@@ -0,0 +1,166 @@
+// Copyright 2021 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "math/big"
+ "testing"
+ "time"
+
+ "github.com/ava-labs/coreth/consensus/dummy"
+ "github.com/ava-labs/coreth/core"
+ "github.com/ava-labs/coreth/core/rawdb"
+ "github.com/ava-labs/libevm/common"
+ "github.com/ava-labs/libevm/core/types"
+ "github.com/ava-labs/libevm/core/vm"
+ "github.com/ava-labs/libevm/p2p"
+ "github.com/ava-labs/libevm/p2p/enode"
+ "github.com/ava-labs/libevm/params"
+ "github.com/ava-labs/libevm/rlp"
+ fuzz "github.com/google/gofuzz"
+)
+
+func FuzzARange(f *testing.F) {
+ f.Fuzz(func(t *testing.T, data []byte) {
+ doFuzz(data, &GetAccountRangePacket{}, GetAccountRangeMsg)
+ })
+}
+
+func FuzzSRange(f *testing.F) {
+ f.Fuzz(func(t *testing.T, data []byte) {
+ doFuzz(data, &GetStorageRangesPacket{}, GetStorageRangesMsg)
+ })
+}
+
+func FuzzByteCodes(f *testing.F) {
+ f.Fuzz(func(t *testing.T, data []byte) {
+ doFuzz(data, &GetByteCodesPacket{}, GetByteCodesMsg)
+ })
+}
+
+func FuzzTrieNodes(f *testing.F) {
+ f.Fuzz(func(t *testing.T, data []byte) {
+ doFuzz(data, &GetTrieNodesPacket{}, GetTrieNodesMsg)
+ })
+}
+
+func doFuzz(input []byte, obj interface{}, code int) {
+ bc := getChain()
+ defer bc.Stop()
+ fuzz.NewFromGoFuzz(input).Fuzz(obj)
+ var data []byte
+ switch p := obj.(type) {
+ case *GetTrieNodesPacket:
+ p.Root = trieRoot
+ data, _ = rlp.EncodeToBytes(obj)
+ default:
+ data, _ = rlp.EncodeToBytes(obj)
+ }
+ cli := &dummyRW{
+ code: uint64(code),
+ data: data,
+ }
+ peer := NewFakePeer(65, "gazonk01", cli)
+ err := HandleMessage(&dummyBackend{bc}, peer)
+ switch {
+ case err == nil && cli.writeCount != 1:
+ panic(fmt.Sprintf("Expected 1 response, got %d", cli.writeCount))
+ case err != nil && cli.writeCount != 0:
+ panic(fmt.Sprintf("Expected 0 response, got %d", cli.writeCount))
+ }
+}
+
+var trieRoot common.Hash
+
+func getChain() *core.BlockChain {
+ ga := make(types.GenesisAlloc, 1000)
+ var a = make([]byte, 20)
+ var mkStorage = func(k, v int) (common.Hash, common.Hash) {
+ var kB = make([]byte, 32)
+ var vB = make([]byte, 32)
+ binary.LittleEndian.PutUint64(kB, uint64(k))
+ binary.LittleEndian.PutUint64(vB, uint64(v))
+ return common.BytesToHash(kB), common.BytesToHash(vB)
+ }
+ storage := make(map[common.Hash]common.Hash)
+ for i := 0; i < 10; i++ {
+ k, v := mkStorage(i, i)
+ storage[k] = v
+ }
+ for i := 0; i < 1000; i++ {
+ binary.LittleEndian.PutUint64(a, uint64(i+0xff))
+ acc := types.Account{Balance: big.NewInt(int64(i))}
+ if i%2 == 1 {
+ acc.Storage = storage
+ }
+ ga[common.BytesToAddress(a)] = acc
+ }
+ gspec := &core.Genesis{
+ Config: params.TestChainConfig,
+ Alloc: ga,
+ }
+ _, blocks, _, err := core.GenerateChainWithGenesis(gspec, dummy.NewFaker(), 2, 10, func(i int, gen *core.BlockGen) {})
+ cacheConf := &core.CacheConfig{
+ TrieCleanLimit: 0,
+ TrieDirtyLimit: 0,
+ // TrieTimeLimit: 5 * time.Minute,
+ // TrieCleanNoPrefetch: true,
+ SnapshotLimit: 100,
+ SnapshotWait: true,
+ }
+ if err != nil {
+ panic(err)
+ }
+ trieRoot = blocks[len(blocks)-1].Root()
+ bc, _ := core.NewBlockChain(rawdb.NewMemoryDatabase(), cacheConf, gspec, dummy.NewFaker(), vm.Config{}, common.Hash{}, false)
+ if _, err := bc.InsertChain(blocks); err != nil {
+ panic(err)
+ }
+ return bc
+}
+
+type dummyBackend struct {
+ chain *core.BlockChain
+}
+
+func (d *dummyBackend) Chain() *core.BlockChain { return d.chain }
+func (d *dummyBackend) RunPeer(*Peer, Handler) error { return nil }
+func (d *dummyBackend) PeerInfo(enode.ID) interface{} { return "Foo" }
+func (d *dummyBackend) Handle(*Peer, Packet) error { return nil }
+
+type dummyRW struct {
+ code uint64
+ data []byte
+ writeCount int
+}
+
+func (d *dummyRW) ReadMsg() (p2p.Msg, error) {
+ return p2p.Msg{
+ Code: d.code,
+ Payload: bytes.NewReader(d.data),
+ ReceivedAt: time.Now(),
+ Size: uint32(len(d.data)),
+ }, nil
+}
+
+func (d *dummyRW) WriteMsg(msg p2p.Msg) error {
+ d.writeCount++
+ return nil
+}
diff --git a/eth/protocols/snap/metrics.go b/eth/protocols/snap/metrics.go
new file mode 100644
index 0000000000..378a4cab8d
--- /dev/null
+++ b/eth/protocols/snap/metrics.go
@@ -0,0 +1,57 @@
+// Copyright 2023 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ metrics "github.com/ava-labs/libevm/metrics"
+)
+
+var (
+ ingressRegistrationErrorName = "eth/protocols/snap/ingress/registration/error"
+ egressRegistrationErrorName = "eth/protocols/snap/egress/registration/error"
+
+ IngressRegistrationErrorMeter = metrics.NewRegisteredMeter(ingressRegistrationErrorName, nil)
+ EgressRegistrationErrorMeter = metrics.NewRegisteredMeter(egressRegistrationErrorName, nil)
+
+ // deletionGauge is the metric to track how many trie node deletions
+ // are performed in total during the sync process.
+ deletionGauge = metrics.NewRegisteredGauge("eth/protocols/snap/sync/delete", nil)
+
+ // lookupGauge is the metric to track how many trie node lookups are
+ // performed to determine if a node needs to be deleted.
+ lookupGauge = metrics.NewRegisteredGauge("eth/protocols/snap/sync/lookup", nil)
+
+ // boundaryAccountNodesGauge is the metric to track how many boundary trie
+ // nodes in account trie are met.
+ boundaryAccountNodesGauge = metrics.NewRegisteredGauge("eth/protocols/snap/sync/boundary/account", nil)
+
+ // boundaryStorageNodesGauge is the metric to track how many boundary trie
+ // nodes in storage tries are met.
+ boundaryStorageNodesGauge = metrics.NewRegisteredGauge("eth/protocols/snap/sync/boundary/storage", nil)
+
+ // smallStorageGauge is the metric to track how many storages are small enough
+ // to be retrieved in one or two requests.
+ smallStorageGauge = metrics.NewRegisteredGauge("eth/protocols/snap/sync/storage/small", nil)
+
+ // largeStorageGauge is the metric to track how many storages are large enough
+ // to be retrieved concurrently.
+ largeStorageGauge = metrics.NewRegisteredGauge("eth/protocols/snap/sync/storage/large", nil)
+
+ // skipStorageHealingGauge is the metric to track how many storages are retrieved
+ // in multiple requests but healing is not necessary.
+ skipStorageHealingGauge = metrics.NewRegisteredGauge("eth/protocols/snap/sync/storage/noheal", nil)
+)
diff --git a/eth/protocols/snap/peer.go b/eth/protocols/snap/peer.go
new file mode 100644
index 0000000000..89eb0d10dc
--- /dev/null
+++ b/eth/protocols/snap/peer.go
@@ -0,0 +1,133 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "github.com/ava-labs/libevm/common"
+ "github.com/ava-labs/libevm/log"
+ "github.com/ava-labs/libevm/p2p"
+)
+
+// Peer is a collection of relevant information we have about a `snap` peer.
+type Peer struct {
+ id string // Unique ID for the peer, cached
+
+ *p2p.Peer // The embedded P2P package peer
+ rw p2p.MsgReadWriter // Input/output streams for snap
+ version uint // Protocol version negotiated
+
+ logger log.Logger // Contextual logger with the peer id injected
+}
+
+// NewPeer creates a wrapper for a network connection and negotiated protocol
+// version.
+func NewPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) *Peer {
+ id := p.ID().String()
+ return &Peer{
+ id: id,
+ Peer: p,
+ rw: rw,
+ version: version,
+ logger: log.New("peer", id[:8]),
+ }
+}
+
+// NewFakePeer creates a fake snap peer without a backing p2p peer, for testing purposes.
+func NewFakePeer(version uint, id string, rw p2p.MsgReadWriter) *Peer {
+ return &Peer{
+ id: id,
+ rw: rw,
+ version: version,
+ logger: log.New("peer", id[:8]),
+ }
+}
+
+// ID retrieves the peer's unique identifier.
+func (p *Peer) ID() string {
+ return p.id
+}
+
+// Version retrieves the peer's negotiated `snap` protocol version.
+func (p *Peer) Version() uint {
+ return p.version
+}
+
+// Log overrides the P2P logger with the higher level one containing only the id.
+func (p *Peer) Log() log.Logger {
+ return p.logger
+}
+
+// RequestAccountRange fetches a batch of accounts rooted in a specific account
+// trie, starting with the origin.
+func (p *Peer) RequestAccountRange(id uint64, root common.Hash, origin, limit common.Hash, bytes uint64) error {
+ p.logger.Trace("Fetching range of accounts", "reqid", id, "root", root, "origin", origin, "limit", limit, "bytes", common.StorageSize(bytes))
+
+ requestTracker.Track(p.id, p.version, GetAccountRangeMsg, AccountRangeMsg, id)
+ return p2p.Send(p.rw, GetAccountRangeMsg, &GetAccountRangePacket{
+ ID: id,
+ Root: root,
+ Origin: origin,
+ Limit: limit,
+ Bytes: bytes,
+ })
+}
+
+// RequestStorageRanges fetches a batch of storage slots belonging to one or more
+// accounts. If slots from only one account is requested, an origin marker may also
+// be used to retrieve from there.
+func (p *Peer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error {
+ if len(accounts) == 1 && origin != nil {
+ p.logger.Trace("Fetching range of large storage slots", "reqid", id, "root", root, "account", accounts[0], "origin", common.BytesToHash(origin), "limit", common.BytesToHash(limit), "bytes", common.StorageSize(bytes))
+ } else {
+ p.logger.Trace("Fetching ranges of small storage slots", "reqid", id, "root", root, "accounts", len(accounts), "first", accounts[0], "bytes", common.StorageSize(bytes))
+ }
+ requestTracker.Track(p.id, p.version, GetStorageRangesMsg, StorageRangesMsg, id)
+ return p2p.Send(p.rw, GetStorageRangesMsg, &GetStorageRangesPacket{
+ ID: id,
+ Root: root,
+ Accounts: accounts,
+ Origin: origin,
+ Limit: limit,
+ Bytes: bytes,
+ })
+}
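+
+// Sketch of the two request modes (root, accounts, acct, origin and limit are
+// assumed values): many small accounts in one shot, or a slice of a single
+// large account's storage.
+//
+//	_ = p.RequestStorageRanges(1, root, accounts, nil, nil, 512*1024)
+//	_ = p.RequestStorageRanges(2, root, []common.Hash{acct}, origin[:], limit[:], 512*1024)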
+
+// RequestByteCodes fetches a batch of bytecodes by hash.
+func (p *Peer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error {
+ p.logger.Trace("Fetching set of byte codes", "reqid", id, "hashes", len(hashes), "bytes", common.StorageSize(bytes))
+
+ requestTracker.Track(p.id, p.version, GetByteCodesMsg, ByteCodesMsg, id)
+ return p2p.Send(p.rw, GetByteCodesMsg, &GetByteCodesPacket{
+ ID: id,
+ Hashes: hashes,
+ Bytes: bytes,
+ })
+}
+
+// RequestTrieNodes fetches a batch of account or storage trie nodes rooted in
+// a specific state trie.
+func (p *Peer) RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error {
+ p.logger.Trace("Fetching set of trie nodes", "reqid", id, "root", root, "pathsets", len(paths), "bytes", common.StorageSize(bytes))
+
+ requestTracker.Track(p.id, p.version, GetTrieNodesMsg, TrieNodesMsg, id)
+ return p2p.Send(p.rw, GetTrieNodesMsg, &GetTrieNodesPacket{
+ ID: id,
+ Root: root,
+ Paths: paths,
+ Bytes: bytes,
+ })
+}
diff --git a/eth/protocols/snap/protocol.go b/eth/protocols/snap/protocol.go
new file mode 100644
index 0000000000..d615a08e0e
--- /dev/null
+++ b/eth/protocols/snap/protocol.go
@@ -0,0 +1,218 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/ava-labs/libevm/common"
+ "github.com/ava-labs/libevm/core/types"
+ "github.com/ava-labs/libevm/rlp"
+)
+
+// Constants to match up protocol versions and messages
+const (
+ SNAP1 = 1
+)
+
+// ProtocolName is the official short name of the `snap` protocol used during
+// devp2p capability negotiation.
+const ProtocolName = "snap"
+
+// ProtocolVersions are the supported versions of the `snap` protocol (first
+// is primary).
+var ProtocolVersions = []uint{SNAP1}
+
+// protocolLengths are the number of implemented messages corresponding to
+// different protocol versions.
+var protocolLengths = map[uint]uint64{SNAP1: 8}
+
+// maxMessageSize is the maximum cap on the size of a protocol message.
+const maxMessageSize = 10 * 1024 * 1024
+
+const (
+ GetAccountRangeMsg = 0x00
+ AccountRangeMsg = 0x01
+ GetStorageRangesMsg = 0x02
+ StorageRangesMsg = 0x03
+ GetByteCodesMsg = 0x04
+ ByteCodesMsg = 0x05
+ GetTrieNodesMsg = 0x06
+ TrieNodesMsg = 0x07
+)
+
+var (
+ errMsgTooLarge = errors.New("message too long")
+ errDecode = errors.New("invalid message")
+ errInvalidMsgCode = errors.New("invalid message code")
+ errBadRequest = errors.New("bad request")
+)
+
+// Packet represents a p2p message in the `snap` protocol.
+type Packet interface {
+ Name() string // Name returns a string corresponding to the message type.
+ Kind() byte // Kind returns the message type.
+}
+
+// GetAccountRangePacket represents an account query.
+type GetAccountRangePacket struct {
+ ID uint64 // Request ID to match up responses with
+ Root common.Hash // Root hash of the account trie to serve
+ Origin common.Hash // Hash of the first account to retrieve
+ Limit common.Hash // Hash of the last account to retrieve
+ Bytes uint64 // Soft limit at which to stop returning data
+}
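+
+// A full-range request sketch (the ID and byte cap are arbitrary, `root` is
+// an assumed state root): fetch all accounts, soft-capped at 512 KiB.
+//
+//	req := &GetAccountRangePacket{
+//		ID:     1,
+//		Root:   root,
+//		Origin: common.Hash{},
+//		Limit:  common.MaxHash,
+//		Bytes:  512 * 1024,
+//	}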
+
+// AccountRangePacket represents an account query response.
+type AccountRangePacket struct {
+ ID uint64 // ID of the request this is a response for
+ Accounts []*AccountData // List of consecutive accounts from the trie
+ Proof [][]byte // List of trie nodes proving the account range
+}
+
+// AccountData represents a single account in a query response.
+type AccountData struct {
+ Hash common.Hash // Hash of the account
+ Body rlp.RawValue // Account body in slim format
+}
+
+// Unpack retrieves the accounts from the range packet and converts from slim
+// wire representation to consensus format. The returned data is RLP encoded
+// since it's expected to be serialized to disk without further interpretation.
+//
+// Note, this method does a round of RLP decoding and reencoding, so only use it
+// once and cache the results if need be. Ideally discard the packet afterwards
+// to not double the memory use.
+func (p *AccountRangePacket) Unpack() ([]common.Hash, [][]byte, error) {
+ var (
+ hashes = make([]common.Hash, len(p.Accounts))
+ accounts = make([][]byte, len(p.Accounts))
+ )
+ for i, acc := range p.Accounts {
+ val, err := types.FullAccountRLP(acc.Body)
+ if err != nil {
+ return nil, nil, fmt.Errorf("invalid account %x: %v", acc.Body, err)
+ }
+ hashes[i], accounts[i] = acc.Hash, val
+ }
+ return hashes, accounts, nil
+}
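+
+// Consumer-side sketch (`res` is an assumed received *AccountRangePacket):
+// unpack once and cache the result; hashes[i] is the account hash and
+// accounts[i] its consensus-format RLP.
+//
+//	hashes, accounts, err := res.Unpack()
+//	if err != nil {
+//		return err // a malformed slim account in the response
+//	}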
+
+// GetStorageRangesPacket represents a storage slot query.
+type GetStorageRangesPacket struct {
+ ID uint64 // Request ID to match up responses with
+ Root common.Hash // Root hash of the account trie to serve
+ Accounts []common.Hash // Account hashes of the storage tries to serve
+ Origin []byte // Hash of the first storage slot to retrieve (large contract mode)
+ Limit []byte // Hash of the last storage slot to retrieve (large contract mode)
+ Bytes uint64 // Soft limit at which to stop returning data
+}
+
+// StorageRangesPacket represents a storage slot query response.
+type StorageRangesPacket struct {
+ ID uint64 // ID of the request this is a response for
+ Slots [][]*StorageData // Lists of consecutive storage slots for the requested accounts
+ Proof [][]byte // Merkle proofs for the *last* slot range, if it's incomplete
+}
+
+// StorageData represents a single storage slot in a query response.
+type StorageData struct {
+ Hash common.Hash // Hash of the storage slot
+ Body []byte // Data content of the slot
+}
+
+// Unpack retrieves the storage slots from the range packet and returns them in
+// a split flat format that's more consistent with the internal data structures.
+func (p *StorageRangesPacket) Unpack() ([][]common.Hash, [][][]byte) {
+ var (
+ hashset = make([][]common.Hash, len(p.Slots))
+ slotset = make([][][]byte, len(p.Slots))
+ )
+ for i, slots := range p.Slots {
+ hashset[i] = make([]common.Hash, len(slots))
+ slotset[i] = make([][]byte, len(slots))
+ for j, slot := range slots {
+ hashset[i][j] = slot.Hash
+ slotset[i][j] = slot.Body
+ }
+ }
+ return hashset, slotset
+}
+
+// GetByteCodesPacket represents a contract bytecode query.
+type GetByteCodesPacket struct {
+ ID uint64 // Request ID to match up responses with
+ Hashes []common.Hash // Code hashes to retrieve the code for
+ Bytes uint64 // Soft limit at which to stop returning data
+}
+
+// ByteCodesPacket represents a contract bytecode query response.
+type ByteCodesPacket struct {
+ ID uint64 // ID of the request this is a response for
+ Codes [][]byte // Requested contract bytecodes
+}
+
+// GetTrieNodesPacket represents a state trie node query.
+type GetTrieNodesPacket struct {
+ ID uint64 // Request ID to match up responses with
+ Root common.Hash // Root hash of the account trie to serve
+ Paths []TrieNodePathSet // Trie node hashes to retrieve the nodes for
+ Bytes uint64 // Soft limit at which to stop returning data
+}
+
+// TrieNodePathSet is a list of trie node paths to retrieve. A naive way to
+// represent trie nodes would be a simple list of `account || storage` path
+// segments concatenated, but that would be very wasteful on the network.
+//
+// Instead, this array special cases the first element as the path in the
+// account trie and the remaining elements as paths in the storage trie. To
+// address an account node, the slice should have a length of 1 consisting
+// of only the account path. There's no need to be able to address both an
+// account node and a storage node in the same request as it cannot happen
+// that a slot is accessed before the account path is fully expanded.
+type TrieNodePathSet [][]byte
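+
+// For illustration (hypothetical values): a one-element set addresses a node
+// in the account trie, while a longer set addresses nodes in one account's
+// storage trie, the first element identifying the account.
+//
+//	accountNode := TrieNodePathSet{accountPath}
+//	storageNodes := TrieNodePathSet{accountHash[:], pathA, pathB}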
+
+// TrieNodesPacket represents a state trie node query response.
+type TrieNodesPacket struct {
+ ID uint64 // ID of the request this is a response for
+ Nodes [][]byte // Requested state trie nodes
+}
+
+func (*GetAccountRangePacket) Name() string { return "GetAccountRange" }
+func (*GetAccountRangePacket) Kind() byte { return GetAccountRangeMsg }
+
+func (*AccountRangePacket) Name() string { return "AccountRange" }
+func (*AccountRangePacket) Kind() byte { return AccountRangeMsg }
+
+func (*GetStorageRangesPacket) Name() string { return "GetStorageRanges" }
+func (*GetStorageRangesPacket) Kind() byte { return GetStorageRangesMsg }
+
+func (*StorageRangesPacket) Name() string { return "StorageRanges" }
+func (*StorageRangesPacket) Kind() byte { return StorageRangesMsg }
+
+func (*GetByteCodesPacket) Name() string { return "GetByteCodes" }
+func (*GetByteCodesPacket) Kind() byte { return GetByteCodesMsg }
+
+func (*ByteCodesPacket) Name() string { return "ByteCodes" }
+func (*ByteCodesPacket) Kind() byte { return ByteCodesMsg }
+
+func (*GetTrieNodesPacket) Name() string { return "GetTrieNodes" }
+func (*GetTrieNodesPacket) Kind() byte { return GetTrieNodesMsg }
+
+func (*TrieNodesPacket) Name() string { return "TrieNodes" }
+func (*TrieNodesPacket) Kind() byte { return TrieNodesMsg }
diff --git a/eth/protocols/snap/range.go b/eth/protocols/snap/range.go
new file mode 100644
index 0000000000..ea7ceeb9e4
--- /dev/null
+++ b/eth/protocols/snap/range.go
@@ -0,0 +1,81 @@
+// Copyright 2021 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "math/big"
+
+ "github.com/ava-labs/libevm/common"
+ "github.com/holiman/uint256"
+)
+
+// hashRange is a utility to handle ranges of hashes: split up the
+// hash space into sections and 'walk' over them.
+type hashRange struct {
+ current *uint256.Int
+ step *uint256.Int
+}
+
+// newHashRange creates a new hashRange, initiated at the start position,
+// and with the step set to fill the desired 'num' chunks
+func newHashRange(start common.Hash, num uint64) *hashRange {
+ left := new(big.Int).Sub(hashSpace, start.Big())
+ step := new(big.Int).Div(
+ new(big.Int).Add(left, new(big.Int).SetUint64(num-1)),
+ new(big.Int).SetUint64(num),
+ )
+ step256 := new(uint256.Int)
+ step256.SetFromBig(step)
+
+ return &hashRange{
+ current: new(uint256.Int).SetBytes32(start[:]),
+ step: step256,
+ }
+}
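+
+// Usage sketch (`process` is a placeholder): the interval is [Start, End]
+// right after construction, and Next advances it until the step overflows
+// the hash space.
+//
+//	r := newHashRange(common.Hash{}, 4)
+//	for ok := true; ok; ok = r.Next() {
+//		process(r.Start(), r.End())
+//	}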
+
+// Next pushes the hash range to the next interval.
+func (r *hashRange) Next() bool {
+ next, overflow := new(uint256.Int).AddOverflow(r.current, r.step)
+ if overflow {
+ return false
+ }
+ r.current = next
+ return true
+}
+
+// Start returns the first hash in the current interval.
+func (r *hashRange) Start() common.Hash {
+ return r.current.Bytes32()
+}
+
+// End returns the last hash in the current interval.
+func (r *hashRange) End() common.Hash {
+ // If the end overflows (non divisible range), return a shorter interval
+ next, overflow := new(uint256.Int).AddOverflow(r.current, r.step)
+ if overflow {
+ return common.MaxHash
+ }
+ return next.SubUint64(next, 1).Bytes32()
+}
+
+// incHash returns the next hash, in lexicographical order (a.k.a. plus one).
+func incHash(h common.Hash) common.Hash {
+ var a uint256.Int
+ a.SetBytes32(h[:])
+ a.AddUint64(&a, 1)
+ return common.Hash(a.Bytes32())
+}
diff --git a/eth/protocols/snap/range_test.go b/eth/protocols/snap/range_test.go
new file mode 100644
index 0000000000..7aca87d778
--- /dev/null
+++ b/eth/protocols/snap/range_test.go
@@ -0,0 +1,143 @@
+// Copyright 2021 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "testing"
+
+ "github.com/ava-labs/libevm/common"
+)
+
+// Tests that given a starting hash and a density, the hash ranger can correctly
+// split up the remaining hash space into a fixed number of chunks.
+func TestHashRanges(t *testing.T) {
+ tests := []struct {
+ head common.Hash
+ chunks uint64
+ starts []common.Hash
+ ends []common.Hash
+ }{
+ // Simple test case to split the entire hash range into 4 chunks
+ {
+ head: common.Hash{},
+ chunks: 4,
+ starts: []common.Hash{
+ {},
+ common.HexToHash("0x4000000000000000000000000000000000000000000000000000000000000000"),
+ common.HexToHash("0x8000000000000000000000000000000000000000000000000000000000000000"),
+ common.HexToHash("0xc000000000000000000000000000000000000000000000000000000000000000"),
+ },
+ ends: []common.Hash{
+ common.HexToHash("0x3fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"),
+ common.HexToHash("0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"),
+ common.HexToHash("0xbfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"),
+ common.MaxHash,
+ },
+ },
+ // Split a divisible part of the hash range up into 2 chunks
+ {
+ head: common.HexToHash("0x2000000000000000000000000000000000000000000000000000000000000000"),
+ chunks: 2,
+ starts: []common.Hash{
+ {},
+ common.HexToHash("0x9000000000000000000000000000000000000000000000000000000000000000"),
+ },
+ ends: []common.Hash{
+ common.HexToHash("0x8fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"),
+ common.MaxHash,
+ },
+ },
+ // Split the entire hash range into 3 non-divisible chunks
+ {
+ head: common.Hash{},
+ chunks: 3,
+ starts: []common.Hash{
+ {},
+ common.HexToHash("0x5555555555555555555555555555555555555555555555555555555555555556"),
+ common.HexToHash("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaac"),
+ },
+ ends: []common.Hash{
+ common.HexToHash("0x5555555555555555555555555555555555555555555555555555555555555555"),
+ common.HexToHash("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"),
+ common.MaxHash,
+ },
+ },
+ // Split a part of the hash range into 3 non-divisible chunks
+ {
+ head: common.HexToHash("0x2000000000000000000000000000000000000000000000000000000000000000"),
+ chunks: 3,
+ starts: []common.Hash{
+ {},
+ common.HexToHash("0x6aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"),
+ common.HexToHash("0xb555555555555555555555555555555555555555555555555555555555555556"),
+ },
+ ends: []common.Hash{
+ common.HexToHash("0x6aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"),
+ common.HexToHash("0xb555555555555555555555555555555555555555555555555555555555555555"),
+ common.MaxHash,
+ },
+ },
+ // Split a part of the hash range into 3 non-divisible chunks, but with a
+ // meaningful space size for manual verification.
+ // - The head being 0xff...f0, we have 16 hashes left in the space (f0 through ff)
+ // - Chunking up 16 into 3 pieces is 5.(3), so we need the ceil of 6 to avoid a micro-last-chunk
+ // - Since the range is not divisible, the last interval will be shorter, capped at 0xff...f
+ // - The chunk ranges thus need to be [..0, ..5], [..6, ..b], [..c, ..f]
+ {
+ head: common.HexToHash("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0"),
+ chunks: 3,
+ starts: []common.Hash{
+ {},
+ common.HexToHash("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff6"),
+ common.HexToHash("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffc"),
+ },
+ ends: []common.Hash{
+ common.HexToHash("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff5"),
+ common.HexToHash("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffb"),
+ common.MaxHash,
+ },
+ },
+ }
+ for i, tt := range tests {
+ r := newHashRange(tt.head, tt.chunks)
+
+ var (
+ starts = []common.Hash{{}}
+ ends = []common.Hash{r.End()}
+ )
+ for r.Next() {
+ starts = append(starts, r.Start())
+ ends = append(ends, r.End())
+ }
+ if len(starts) != len(tt.starts) {
+ t.Errorf("test %d: starts count mismatch: have %d, want %d", i, len(starts), len(tt.starts))
+ }
+ for j := 0; j < len(starts) && j < len(tt.starts); j++ {
+ if starts[j] != tt.starts[j] {
+ t.Errorf("test %d, start %d: hash mismatch: have %x, want %x", i, j, starts[j], tt.starts[j])
+ }
+ }
+ if len(ends) != len(tt.ends) {
+ t.Errorf("test %d: ends count mismatch: have %d, want %d", i, len(ends), len(tt.ends))
+ }
+ for j := 0; j < len(ends) && j < len(tt.ends); j++ {
+ if ends[j] != tt.ends[j] {
+ t.Errorf("test %d, end %d: hash mismatch: have %x, want %x", i, j, ends[j], tt.ends[j])
+ }
+ }
+ }
+}
diff --git a/eth/protocols/snap/sort_test.go b/eth/protocols/snap/sort_test.go
new file mode 100644
index 0000000000..2d51de44d9
--- /dev/null
+++ b/eth/protocols/snap/sort_test.go
@@ -0,0 +1,101 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+
+ "github.com/ava-labs/libevm/common"
+)
+
+func hexToNibbles(s string) []byte {
+ if len(s) >= 2 && s[0] == '0' && s[1] == 'x' {
+ s = s[2:]
+ }
+ var s2 []byte
+ for _, ch := range []byte(s) {
+ s2 = append(s2, '0')
+ s2 = append(s2, ch)
+ }
+ return common.Hex2Bytes(string(s2))
+}
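+
+// For example, hexToNibbles("0x9") returns []byte{0x09} and
+// hexToNibbles("0x99") returns []byte{0x09, 0x09}: every hex character is
+// expanded into its own byte, matching the nibble-per-byte path encoding.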
+
+func TestRequestSorting(t *testing.T) {
+ // - Path 0x9 -> {0x19}
+ // - Path 0x99 -> {0x0099}
+ // - Path 0x01234567890123456789012345678901012345678901234567890123456789019 -> {0x0123456789012345678901234567890101234567890123456789012345678901, 0x19}
+ // - Path 0x012345678901234567890123456789010123456789012345678901234567890199 -> {0x0123456789012345678901234567890101234567890123456789012345678901, 0x0099}
+ var f = func(path string) string {
+ data := hexToNibbles(path)
+ return string(data)
+ }
+ var (
+ hashes []common.Hash
+ paths []string
+ )
+ for _, x := range []string{
+ "0x9",
+ "0x012345678901234567890123456789010123456789012345678901234567890195",
+ "0x012345678901234567890123456789010123456789012345678901234567890197",
+ "0x012345678901234567890123456789010123456789012345678901234567890196",
+ "0x99",
+ "0x012345678901234567890123456789010123456789012345678901234567890199",
+ "0x01234567890123456789012345678901012345678901234567890123456789019",
+ "0x0123456789012345678901234567890101234567890123456789012345678901",
+ "0x01234567890123456789012345678901012345678901234567890123456789010",
+ "0x01234567890123456789012345678901012345678901234567890123456789011",
+ } {
+ paths = append(paths, f(x))
+ hashes = append(hashes, common.Hash{})
+ }
+ _, _, syncPaths, pathsets := sortByAccountPath(paths, hashes)
+ {
+ var b = new(bytes.Buffer)
+ for i := 0; i < len(syncPaths); i++ {
+ fmt.Fprintf(b, "\n%d. paths %x", i, syncPaths[i])
+ }
+ want := `
+0. paths [0099]
+1. paths [0123456789012345678901234567890101234567890123456789012345678901 00]
+2. paths [0123456789012345678901234567890101234567890123456789012345678901 0095]
+3. paths [0123456789012345678901234567890101234567890123456789012345678901 0096]
+4. paths [0123456789012345678901234567890101234567890123456789012345678901 0097]
+5. paths [0123456789012345678901234567890101234567890123456789012345678901 0099]
+6. paths [0123456789012345678901234567890101234567890123456789012345678901 10]
+7. paths [0123456789012345678901234567890101234567890123456789012345678901 11]
+8. paths [0123456789012345678901234567890101234567890123456789012345678901 19]
+9. paths [19]`
+ if have := b.String(); have != want {
+ t.Errorf("have:%v\nwant:%v\n", have, want)
+ }
+ }
+ {
+ var b = new(bytes.Buffer)
+ for i := 0; i < len(pathsets); i++ {
+ fmt.Fprintf(b, "\n%d. pathset %x", i, pathsets[i])
+ }
+ want := `
+0. pathset [0099]
+1. pathset [0123456789012345678901234567890101234567890123456789012345678901 00 0095 0096 0097 0099 10 11 19]
+2. pathset [19]`
+ if have := b.String(); have != want {
+ t.Errorf("have:%v\nwant:%v\n", have, want)
+ }
+ }
+}
diff --git a/eth/protocols/snap/sync.go b/eth/protocols/snap/sync.go
new file mode 100644
index 0000000000..d47580bc17
--- /dev/null
+++ b/eth/protocols/snap/sync.go
@@ -0,0 +1,3210 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ gomath "math"
+ "math/big"
+ "math/rand"
+ "sort"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/ava-labs/libevm/common"
+ "github.com/ava-labs/libevm/common/math"
+ "github.com/ava-labs/libevm/core/rawdb"
+ "github.com/ava-labs/libevm/core/state"
+ "github.com/ava-labs/libevm/core/types"
+ "github.com/ava-labs/libevm/crypto"
+ "github.com/ava-labs/libevm/ethdb"
+ "github.com/ava-labs/libevm/event"
+ "github.com/ava-labs/libevm/log"
+ "github.com/ava-labs/libevm/p2p/msgrate"
+ "github.com/ava-labs/libevm/rlp"
+ "github.com/ava-labs/libevm/trie"
+ "github.com/ava-labs/libevm/trie/trienode"
+ "golang.org/x/crypto/sha3"
+)
+
+const (
+ // minRequestSize is the minimum number of bytes to request from a remote peer.
+ // This number is used as the low cap for account and storage range requests.
+ // Bytecode and trienode are limited inherently by item count (1).
+ minRequestSize = 64 * 1024
+
+ // maxRequestSize is the maximum number of bytes to request from a remote peer.
+ // This number is used as the high cap for account and storage range requests.
+ // Bytecode and trienode are limited more explicitly by the caps below.
+ maxRequestSize = 512 * 1024
+
+ // maxCodeRequestCount is the maximum number of bytecode blobs to request in a
+ // single query. If this number is too low, we're not filling responses fully
+ // and waste round trip times. If it's too high, we're capping responses and
+ // waste bandwidth.
+ //
+ // Deployed bytecodes are currently capped at 24KB, so the minimum request
+ // size should be maxRequestSize / 24K. Assuming that most contracts do not
+ // come close to that, requesting 4x should be a good approximation.
+ maxCodeRequestCount = maxRequestSize / (24 * 1024) * 4
+
+ // maxTrieRequestCount is the maximum number of trie node blobs to request in
+ // a single query. If this number is too low, we're not filling responses fully
+ // and waste round trip times. If it's too high, we're capping responses and
+ // waste bandwidth.
+ maxTrieRequestCount = maxRequestSize / 512
+
+ // trienodeHealRateMeasurementImpact is the impact a single measurement has on
+ // the local node's trienode processing capacity. A value closer to 0 reacts
+ // slower to sudden changes, but it is also more stable against temporary hiccups.
+ trienodeHealRateMeasurementImpact = 0.005
+
+ // minTrienodeHealThrottle is the minimum divisor for throttling trie node
+ // heal requests to avoid overloading the local node and excessively expanding
+ // the state trie breadth-wise.
+ minTrienodeHealThrottle = 1
+
+ // maxTrienodeHealThrottle is the maximum divisor for throttling trie node
+ // heal requests to avoid overloading the local node and excessively expanding
+ // the state trie breadth-wise.
+ maxTrienodeHealThrottle = maxTrieRequestCount
+
+ // trienodeHealThrottleIncrease is the multiplier for the throttle when the
+ // rate of arriving data is higher than the rate of processing it.
+ trienodeHealThrottleIncrease = 1.33
+
+ // trienodeHealThrottleDecrease is the divisor for the throttle when the
+ // rate of arriving data is lower than the rate of processing it.
+ trienodeHealThrottleDecrease = 1.25
+)
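+
+// For reference, the caps above work out to maxCodeRequestCount = 84 bytecodes
+// and maxTrieRequestCount = 1024 trie nodes per query.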
+
+var (
+ // accountConcurrency is the number of chunks to split the account trie into
+ // to allow concurrent retrievals.
+ accountConcurrency = 16
+
+ // storageConcurrency is the number of chunks to split a large contract
+ // storage trie into to allow concurrent retrievals.
+ storageConcurrency = 16
+)
+
+// ErrCancelled is returned from snap syncing if the operation was prematurely
+// terminated.
+var ErrCancelled = errors.New("sync cancelled")
+
+// accountRequest tracks a pending account range request to ensure responses are
+// to actual requests and to validate any security constraints.
+//
+// Concurrency note: account requests and responses are handled concurrently from
+// the main runloop to allow Merkle proof verifications on the peer's thread and
+// to drop on invalid response. The request struct must contain all the data to
+// construct the response without accessing runloop internals (i.e. task). That
+// is only included to allow the runloop to match a response to the task being
+// synced without having yet another set of maps.
+type accountRequest struct {
+ peer string // Peer to which this request is assigned
+ id uint64 // Request ID of this request
+ time time.Time // Timestamp when the request was sent
+
+ deliver chan *accountResponse // Channel to deliver successful response on
+ revert chan *accountRequest // Channel to deliver request failure on
+ cancel chan struct{} // Channel to track sync cancellation
+ timeout *time.Timer // Timer to track delivery timeout
+ stale chan struct{} // Channel to signal the request was dropped
+
+ origin common.Hash // First account requested to allow continuation checks
+ limit common.Hash // Last account requested to allow non-overlapping chunking
+
+ task *accountTask // Task which this request is filling (only access fields through the runloop!!)
+}
+
+// accountResponse is an already Merkle-verified remote response to an account
+// range request. It contains the subtrie for the requested account range and
+// the database that's going to be filled with the internal nodes on commit.
+type accountResponse struct {
+ task *accountTask // Task which this request is filling
+
+ hashes []common.Hash // Account hashes in the returned range
+ accounts []*types.StateAccount // Expanded accounts in the returned range
+
+ cont bool // Whether the account range has a continuation
+}
+
+// bytecodeRequest tracks a pending bytecode request to ensure responses are to
+// actual requests and to validate any security constraints.
+//
+// Concurrency note: bytecode requests and responses are handled concurrently from
+// the main runloop to allow Keccak256 hash verifications on the peer's thread and
+// to drop on invalid response. The request struct must contain all the data to
+// construct the response without accessing runloop internals (i.e. task). That
+// is only included to allow the runloop to match a response to the task being
+// synced without having yet another set of maps.
+type bytecodeRequest struct {
+ peer string // Peer to which this request is assigned
+ id uint64 // Request ID of this request
+ time time.Time // Timestamp when the request was sent
+
+ deliver chan *bytecodeResponse // Channel to deliver successful response on
+ revert chan *bytecodeRequest // Channel to deliver request failure on
+ cancel chan struct{} // Channel to track sync cancellation
+ timeout *time.Timer // Timer to track delivery timeout
+ stale chan struct{} // Channel to signal the request was dropped
+
+ hashes []common.Hash // Bytecode hashes to validate responses
+ task *accountTask // Task which this request is filling (only access fields through the runloop!!)
+}
+
+// bytecodeResponse is an already verified remote response to a bytecode request.
+type bytecodeResponse struct {
+ task *accountTask // Task which this request is filling
+
+ hashes []common.Hash // Hashes of the bytecode to avoid double hashing
+ codes [][]byte // Actual bytecodes to store into the database (nil = missing)
+}
+
+// storageRequest tracks a pending storage ranges request to ensure responses are
+// to actual requests and to validate any security constraints.
+//
+// Concurrency note: storage requests and responses are handled concurrently from
+// the main runloop to allow Merkle proof verifications on the peer's thread and
+// to drop on invalid response. The request struct must contain all the data to
+// construct the response without accessing runloop internals (i.e. tasks). That
+// is only included to allow the runloop to match a response to the task being
+// synced without having yet another set of maps.
+type storageRequest struct {
+ peer string // Peer to which this request is assigned
+ id uint64 // Request ID of this request
+ time time.Time // Timestamp when the request was sent
+
+ deliver chan *storageResponse // Channel to deliver successful response on
+ revert chan *storageRequest // Channel to deliver request failure on
+ cancel chan struct{} // Channel to track sync cancellation
+ timeout *time.Timer // Timer to track delivery timeout
+ stale chan struct{} // Channel to signal the request was dropped
+
+ accounts []common.Hash // Account hashes to validate responses
+ roots []common.Hash // Storage roots to validate responses
+
+ origin common.Hash // First storage slot requested to allow continuation checks
+ limit common.Hash // Last storage slot requested to allow non-overlapping chunking
+
+ mainTask *accountTask // Task which this response belongs to (only access fields through the runloop!!)
+ subTask *storageTask // Task which this response is filling (only access fields through the runloop!!)
+}
+
+// storageResponse is an already Merkle-verified remote response to a storage
+// range request. It contains the subtries for the requested storage ranges and
+// the databases that are going to be filled with the internal nodes on commit.
+type storageResponse struct {
+ mainTask *accountTask // Task which this response belongs to
+ subTask *storageTask // Task which this response is filling
+
+ accounts []common.Hash // Account hashes requested, may be only partially filled
+ roots []common.Hash // Storage roots requested, may be only partially filled
+
+ hashes [][]common.Hash // Storage slot hashes in the returned range
+ slots [][][]byte // Storage slot values in the returned range
+
+ cont bool // Whether the last storage range has a continuation
+}
+
+// trienodeHealRequest tracks a pending state trie request to ensure responses
+// are to actual requests and to validate any security constraints.
+//
+// Concurrency note: trie node requests and responses are handled concurrently from
+// the main runloop to allow Keccak256 hash verifications on the peer's thread and
+// to drop on invalid response. The request struct must contain all the data to
+// construct the response without accessing runloop internals (i.e. task). That
+// is only included to allow the runloop to match a response to the task being
+// synced without having yet another set of maps.
+type trienodeHealRequest struct {
+ peer string // Peer to which this request is assigned
+ id uint64 // Request ID of this request
+ time time.Time // Timestamp when the request was sent
+
+ deliver chan *trienodeHealResponse // Channel to deliver successful response on
+ revert chan *trienodeHealRequest // Channel to deliver request failure on
+ cancel chan struct{} // Channel to track sync cancellation
+ timeout *time.Timer // Timer to track delivery timeout
+ stale chan struct{} // Channel to signal the request was dropped
+
+ paths []string // Trie node paths for identifying trie node
+ hashes []common.Hash // Trie node hashes to validate responses
+
+ task *healTask // Task which this request is filling (only access fields through the runloop!!)
+}
+
+// trienodeHealResponse is an already verified remote response to a trie node request.
+type trienodeHealResponse struct {
+ task *healTask // Task which this request is filling
+
+ paths []string // Paths of the trie nodes
+ hashes []common.Hash // Hashes of the trie nodes to avoid double hashing
+ nodes [][]byte // Actual trie nodes to store into the database (nil = missing)
+}
+
+// bytecodeHealRequest tracks a pending bytecode request to ensure responses are to
+// actual requests and to validate any security constraints.
+//
+// Concurrency note: bytecode requests and responses are handled concurrently from
+// the main runloop to allow Keccak256 hash verifications on the peer's thread and
+// to drop on invalid response. The request struct must contain all the data to
+// construct the response without accessing runloop internals (i.e. task). That
+// is only included to allow the runloop to match a response to the task being
+// synced without having yet another set of maps.
+type bytecodeHealRequest struct {
+ peer string // Peer to which this request is assigned
+ id uint64 // Request ID of this request
+ time time.Time // Timestamp when the request was sent
+
+ deliver chan *bytecodeHealResponse // Channel to deliver successful response on
+ revert chan *bytecodeHealRequest // Channel to deliver request failure on
+ cancel chan struct{} // Channel to track sync cancellation
+ timeout *time.Timer // Timer to track delivery timeout
+ stale chan struct{} // Channel to signal the request was dropped
+
+ hashes []common.Hash // Bytecode hashes to validate responses
+ task *healTask // Task which this request is filling (only access fields through the runloop!!)
+}
+
+// bytecodeHealResponse is an already verified remote response to a bytecode request.
+type bytecodeHealResponse struct {
+ task *healTask // Task which this request is filling
+
+ hashes []common.Hash // Hashes of the bytecode to avoid double hashing
+ codes [][]byte // Actual bytecodes to store into the database (nil = missing)
+}
+
+// accountTask represents the sync task for a chunk of the account snapshot.
+type accountTask struct {
+ // These fields get serialized to leveldb on shutdown
+ Next common.Hash // Next account to sync in this interval
+ Last common.Hash // Last account to sync in this interval
+ SubTasks map[common.Hash][]*storageTask // Storage intervals needing fetching for large contracts
+
+ // These fields are internals used during runtime
+ req *accountRequest // Pending request to fill this task
+ res *accountResponse // Validated response filling this task
+ pend int // Number of pending subtasks for this round
+
+ needCode []bool // Flags whether the filling accounts need code retrieval
+ needState []bool // Flags whether the filling accounts need storage retrieval
+ needHeal []bool // Flags whether the filling accounts' state was chunked and needs healing
+
+ codeTasks map[common.Hash]struct{} // Code hashes that need retrieval
+ stateTasks map[common.Hash]common.Hash // Account hashes->roots that need full state retrieval
+
+ genBatch ethdb.Batch // Batch used by the node generator
+ genTrie *trie.StackTrie // Node generator from storage slots
+
+ done bool // Flag whether the task can be removed
+}
+
+// storageTask represents the sync task for a chunk of the storage snapshot.
+type storageTask struct {
+ Next common.Hash // Next account to sync in this interval
+ Last common.Hash // Last account to sync in this interval
+
+ // These fields are internals used during runtime
+ root common.Hash // Storage root hash for this instance
+ req *storageRequest // Pending request to fill this task
+
+ genBatch ethdb.Batch // Batch used by the node generator
+ genTrie *trie.StackTrie // Node generator from storage slots
+
+ done bool // Flag whether the task can be removed
+}
+
+// healTask represents the sync task for healing the snap-synced chunk boundaries.
+type healTask struct {
+ scheduler *trie.Sync // State trie sync scheduler defining the tasks
+
+ trieTasks map[string]common.Hash // Set of trie node tasks currently queued for retrieval, indexed by node path
+ codeTasks map[common.Hash]struct{} // Set of byte code tasks currently queued for retrieval, indexed by code hash
+}
+
+// SyncProgress is a database entry to allow suspending and resuming a snapshot state
+// sync. As opposed to full and fast sync, there is no way to restart a suspended
+// snap sync without prior knowledge of the suspension point.
+type SyncProgress struct {
+ Tasks []*accountTask // The suspended account tasks (contract tasks within)
+
+ // Status report during syncing phase
+ AccountSynced uint64 // Number of accounts downloaded
+ AccountBytes common.StorageSize // Number of account trie bytes persisted to disk
+ BytecodeSynced uint64 // Number of bytecodes downloaded
+ BytecodeBytes common.StorageSize // Number of bytecode bytes downloaded
+ StorageSynced uint64 // Number of storage slots downloaded
+ StorageBytes common.StorageSize // Number of storage trie bytes persisted to disk
+
+ // Status report during healing phase
+ TrienodeHealSynced uint64 // Number of state trie nodes downloaded
+ TrienodeHealBytes common.StorageSize // Number of state trie bytes persisted to disk
+ BytecodeHealSynced uint64 // Number of bytecodes downloaded
+ BytecodeHealBytes common.StorageSize // Number of bytecodes persisted to disk
+}
+
+// SyncPending is analogous to SyncProgress, but it's used to report on pending
+// ephemeral sync progress that doesn't get persisted into the database.
+type SyncPending struct {
+ TrienodeHeal uint64 // Number of state trie nodes pending
+ BytecodeHeal uint64 // Number of bytecodes pending
+}
+
+// SyncPeer abstracts out the methods required for a peer to be synced against
+// with the goal of allowing the construction of mock peers without the
+// full-blown networking stack.
+type SyncPeer interface {
+ // ID retrieves the peer's unique identifier.
+ ID() string
+
+ // RequestAccountRange fetches a batch of accounts rooted in a specific account
+ // trie, starting with the origin.
+ RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes uint64) error
+
+ // RequestStorageRanges fetches a batch of storage slots belonging to one or
+ // more accounts. If slots from only one account are requested, an origin marker
+ // may also be used to retrieve from there.
+ RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error
+
+ // RequestByteCodes fetches a batch of bytecodes by hash.
+ RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error
+
+ // RequestTrieNodes fetches a batch of account or storage trie nodes rooted in
+ // a specific state trie.
+ RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error
+
+ // Log retrieves the peer's own contextual logger.
+ Log() log.Logger
+}
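+
+// A test double only needs to satisfy this surface; a minimal sketch (the
+// mockPeer type is hypothetical):
+//
+//	type mockPeer struct{ id string }
+//
+//	func (p *mockPeer) ID() string      { return p.id }
+//	func (p *mockPeer) Log() log.Logger { return log.New("peer", p.id) }
+//
+// with the four Request* methods recording each query and feeding canned
+// fixture data back to the syncer.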
+
+// Syncer is an Ethereum account and storage trie syncer based on snapshots and
+// the snap protocol. It's purpose is to download all the accounts and storage
+// slots from remote peers and reassemble chunks of the state trie, on top of
+// which a state sync can be run to fix any gaps / overlaps.
+//
+// Every network request has a variety of failure events:
+// - The peer disconnects after task assignment, failing to send the request
+// - The peer disconnects after sending the request, before delivering on it
+// - The peer remains connected, but does not deliver a response in time
+// - The peer delivers a stale response after a previous timeout
+// - The peer delivers a refusal to serve the requested state
+type Syncer struct {
+ db ethdb.KeyValueStore // Database to store the trie nodes into (and dedup)
+ scheme string // Node scheme used in node database
+
+ root common.Hash // Current state trie root being synced
+ tasks []*accountTask // Current account task set being synced
+ snapped bool // Flag to signal that snap phase is done
+ healer *healTask // Current state healing task being executed
+ update chan struct{} // Notification channel for possible sync progression
+
+ peers map[string]SyncPeer // Currently active peers to download from
+ peerJoin *event.Feed // Event feed to react to peers joining
+ peerDrop *event.Feed // Event feed to react to peers dropping
+ rates *msgrate.Trackers // Message throughput rates for peers
+
+ // Request tracking during syncing phase
+ statelessPeers map[string]struct{} // Peers that failed to deliver state data
+ accountIdlers map[string]struct{} // Peers that aren't serving account requests
+ bytecodeIdlers map[string]struct{} // Peers that aren't serving bytecode requests
+ storageIdlers map[string]struct{} // Peers that aren't serving storage requests
+
+ accountReqs map[uint64]*accountRequest // Account requests currently running
+ bytecodeReqs map[uint64]*bytecodeRequest // Bytecode requests currently running
+ storageReqs map[uint64]*storageRequest // Storage requests currently running
+
+ accountSynced uint64 // Number of accounts downloaded
+ accountBytes common.StorageSize // Number of account trie bytes persisted to disk
+ bytecodeSynced uint64 // Number of bytecodes downloaded
+ bytecodeBytes common.StorageSize // Number of bytecode bytes downloaded
+ storageSynced uint64 // Number of storage slots downloaded
+ storageBytes common.StorageSize // Number of storage trie bytes persisted to disk
+
+ extProgress *SyncProgress // Progress that can be exposed to the external caller.
+
+ // Request tracking during healing phase
+ trienodeHealIdlers map[string]struct{} // Peers that aren't serving trie node requests
+ bytecodeHealIdlers map[string]struct{} // Peers that aren't serving bytecode requests
+
+ trienodeHealReqs map[uint64]*trienodeHealRequest // Trie node requests currently running
+ bytecodeHealReqs map[uint64]*bytecodeHealRequest // Bytecode requests currently running
+
+ trienodeHealRate float64 // Average heal rate for processing trie node data
+ trienodeHealPend atomic.Uint64 // Number of trie nodes currently pending for processing
+ trienodeHealThrottle float64 // Divisor for throttling the amount of trienode heal data requested
+ trienodeHealThrottled time.Time // Timestamp the last time the throttle was updated
+
+ trienodeHealSynced uint64 // Number of state trie nodes downloaded
+ trienodeHealBytes common.StorageSize // Number of state trie bytes persisted to disk
+ trienodeHealDups uint64 // Number of state trie nodes already processed
+ trienodeHealNops uint64 // Number of state trie nodes not requested
+ bytecodeHealSynced uint64 // Number of bytecodes downloaded
+ bytecodeHealBytes common.StorageSize // Number of bytecodes persisted to disk
+ bytecodeHealDups uint64 // Number of bytecodes already processed
+ bytecodeHealNops uint64 // Number of bytecodes not requested
+
+ stateWriter ethdb.Batch // Shared batch writer used for persisting raw states
+ accountHealed uint64 // Number of accounts downloaded during the healing stage
+ accountHealedBytes common.StorageSize // Number of raw account bytes persisted to disk during the healing stage
+ storageHealed uint64 // Number of storage slots downloaded during the healing stage
+ storageHealedBytes common.StorageSize // Number of raw storage bytes persisted to disk during the healing stage
+
+ startTime time.Time // Time instance when snapshot sync started
+ logTime time.Time // Time instance when status was last reported
+
+ pend sync.WaitGroup // Tracks network request goroutines for graceful shutdown
+ lock sync.RWMutex // Protects fields that can change outside of sync (peers, reqs, root)
+}
+
+// NewSyncer creates a new snapshot syncer to download the Ethereum state over the
+// snap protocol.
+func NewSyncer(db ethdb.KeyValueStore, scheme string) *Syncer {
+ return &Syncer{
+ db: db,
+ scheme: scheme,
+
+ peers: make(map[string]SyncPeer),
+ peerJoin: new(event.Feed),
+ peerDrop: new(event.Feed),
+ rates: msgrate.NewTrackers(log.New("proto", "snap")),
+ update: make(chan struct{}, 1),
+
+ accountIdlers: make(map[string]struct{}),
+ storageIdlers: make(map[string]struct{}),
+ bytecodeIdlers: make(map[string]struct{}),
+
+ accountReqs: make(map[uint64]*accountRequest),
+ storageReqs: make(map[uint64]*storageRequest),
+ bytecodeReqs: make(map[uint64]*bytecodeRequest),
+
+ trienodeHealIdlers: make(map[string]struct{}),
+ bytecodeHealIdlers: make(map[string]struct{}),
+
+ trienodeHealReqs: make(map[uint64]*trienodeHealRequest),
+ bytecodeHealReqs: make(map[uint64]*bytecodeHealRequest),
+ trienodeHealThrottle: maxTrienodeHealThrottle, // Tune downward instead of insta-filling with junk
+ stateWriter: db.NewBatch(),
+
+ extProgress: new(SyncProgress),
+ }
+}
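+
+// A minimal usage sketch (db, peer and root are assumed to exist in the
+// caller's context):
+//
+//	s := NewSyncer(db, rawdb.HashScheme)
+//	_ = s.Register(peer)        // peer implements SyncPeer
+//	cancel := make(chan struct{})
+//	err := s.Sync(root, cancel) // blocks until synced, failed or cancelled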
+
+// Register injects a new data source into the syncer's peerset.
+func (s *Syncer) Register(peer SyncPeer) error {
+ // Make sure the peer is not registered yet
+ id := peer.ID()
+
+ s.lock.Lock()
+ if _, ok := s.peers[id]; ok {
+ log.Error("Snap peer already registered", "id", id)
+
+ s.lock.Unlock()
+ return errors.New("already registered")
+ }
+ s.peers[id] = peer
+ s.rates.Track(id, msgrate.NewTracker(s.rates.MeanCapacities(), s.rates.MedianRoundTrip()))
+
+ // Mark the peer as idle, even if no sync is running
+ s.accountIdlers[id] = struct{}{}
+ s.storageIdlers[id] = struct{}{}
+ s.bytecodeIdlers[id] = struct{}{}
+ s.trienodeHealIdlers[id] = struct{}{}
+ s.bytecodeHealIdlers[id] = struct{}{}
+ s.lock.Unlock()
+
+ // Notify any active syncs that a new peer can be assigned data
+ s.peerJoin.Send(id)
+ return nil
+}
+
+// Unregister removes a data source from the syncer's peerset.
+func (s *Syncer) Unregister(id string) error {
+ // Remove all traces of the peer from the registry
+ s.lock.Lock()
+ if _, ok := s.peers[id]; !ok {
+ log.Error("Snap peer not registered", "id", id)
+
+ s.lock.Unlock()
+ return errors.New("not registered")
+ }
+ delete(s.peers, id)
+ s.rates.Untrack(id)
+
+ // Remove status markers, even if no sync is running
+ delete(s.statelessPeers, id)
+
+ delete(s.accountIdlers, id)
+ delete(s.storageIdlers, id)
+ delete(s.bytecodeIdlers, id)
+ delete(s.trienodeHealIdlers, id)
+ delete(s.bytecodeHealIdlers, id)
+ s.lock.Unlock()
+
+ // Notify any active syncs that pending requests need to be reverted
+ s.peerDrop.Send(id)
+ return nil
+}
+
+// Sync starts (or resumes a previous) sync cycle to iterate over a state trie
+// with the given root and reconstruct the nodes based on the snapshot leaves.
+// Previously downloaded segments will not be redownloaded or fixed; rather, any
+// errors will be healed after the leaves are fully accumulated.
+func (s *Syncer) Sync(root common.Hash, cancel chan struct{}) error {
+ // Move the trie root from any previous value, revert stateless markers for
+ // any peers and initialize the syncer if it was not yet run
+ s.lock.Lock()
+ s.root = root
+ s.healer = &healTask{
+ scheduler: state.NewStateSync(root, s.db, s.onHealState, s.scheme),
+ trieTasks: make(map[string]common.Hash),
+ codeTasks: make(map[common.Hash]struct{}),
+ }
+ s.statelessPeers = make(map[string]struct{})
+ s.lock.Unlock()
+
+ if s.startTime == (time.Time{}) {
+ s.startTime = time.Now()
+ }
+ // Retrieve the previous sync status from LevelDB and abort if already synced
+ s.loadSyncStatus()
+ if len(s.tasks) == 0 && s.healer.scheduler.Pending() == 0 {
+ log.Debug("Snapshot sync already completed")
+ return nil
+ }
+ defer func() { // Persist any progress, independent of failure
+ for _, task := range s.tasks {
+ s.forwardAccountTask(task)
+ }
+ s.cleanAccountTasks()
+ s.saveSyncStatus()
+ }()
+
+ log.Debug("Starting snapshot sync cycle", "root", root)
+
+ // Flush out the last committed raw states
+ defer func() {
+ if s.stateWriter.ValueSize() > 0 {
+ s.stateWriter.Write()
+ s.stateWriter.Reset()
+ }
+ }()
+ defer s.report(true)
+ // commit any trie- and bytecode-healing data.
+ defer s.commitHealer(true)
+
+ // Whether sync completed or not, disregard any future packets
+ defer func() {
+ log.Debug("Terminating snapshot sync cycle", "root", root)
+ s.lock.Lock()
+ s.accountReqs = make(map[uint64]*accountRequest)
+ s.storageReqs = make(map[uint64]*storageRequest)
+ s.bytecodeReqs = make(map[uint64]*bytecodeRequest)
+ s.trienodeHealReqs = make(map[uint64]*trienodeHealRequest)
+ s.bytecodeHealReqs = make(map[uint64]*bytecodeHealRequest)
+ s.lock.Unlock()
+ }()
+ // Keep scheduling sync tasks
+ peerJoin := make(chan string, 16)
+ peerJoinSub := s.peerJoin.Subscribe(peerJoin)
+ defer peerJoinSub.Unsubscribe()
+
+ peerDrop := make(chan string, 16)
+ peerDropSub := s.peerDrop.Subscribe(peerDrop)
+ defer peerDropSub.Unsubscribe()
+
+ // Create a set of unique channels for this sync cycle. We need these to be
+ // ephemeral so a data race doesn't accidentally deliver something stale on
+ // a persistent channel across syncs (yup, this happened)
+ var (
+ accountReqFails = make(chan *accountRequest)
+ storageReqFails = make(chan *storageRequest)
+ bytecodeReqFails = make(chan *bytecodeRequest)
+ accountResps = make(chan *accountResponse)
+ storageResps = make(chan *storageResponse)
+ bytecodeResps = make(chan *bytecodeResponse)
+ trienodeHealReqFails = make(chan *trienodeHealRequest)
+ bytecodeHealReqFails = make(chan *bytecodeHealRequest)
+ trienodeHealResps = make(chan *trienodeHealResponse)
+ bytecodeHealResps = make(chan *bytecodeHealResponse)
+ )
+ for {
+ // Remove all completed tasks and terminate sync if everything's done
+ s.cleanStorageTasks()
+ s.cleanAccountTasks()
+ if len(s.tasks) == 0 && s.healer.scheduler.Pending() == 0 {
+ return nil
+ }
+ // Assign all the data retrieval tasks to any free peers
+ s.assignAccountTasks(accountResps, accountReqFails, cancel)
+ s.assignBytecodeTasks(bytecodeResps, bytecodeReqFails, cancel)
+ s.assignStorageTasks(storageResps, storageReqFails, cancel)
+
+ if len(s.tasks) == 0 {
+ // Sync phase done, run heal phase
+ s.assignTrienodeHealTasks(trienodeHealResps, trienodeHealReqFails, cancel)
+ s.assignBytecodeHealTasks(bytecodeHealResps, bytecodeHealReqFails, cancel)
+ }
+ // Update sync progress
+ s.lock.Lock()
+ s.extProgress = &SyncProgress{
+ AccountSynced: s.accountSynced,
+ AccountBytes: s.accountBytes,
+ BytecodeSynced: s.bytecodeSynced,
+ BytecodeBytes: s.bytecodeBytes,
+ StorageSynced: s.storageSynced,
+ StorageBytes: s.storageBytes,
+ TrienodeHealSynced: s.trienodeHealSynced,
+ TrienodeHealBytes: s.trienodeHealBytes,
+ BytecodeHealSynced: s.bytecodeHealSynced,
+ BytecodeHealBytes: s.bytecodeHealBytes,
+ }
+ s.lock.Unlock()
+ // Wait for something to happen
+ select {
+ case <-s.update:
+ // Something happened (new peer, delivery, timeout), recheck tasks
+ case <-peerJoin:
+ // A new peer joined, try to schedule it new tasks
+ case id := <-peerDrop:
+ s.revertRequests(id)
+ case <-cancel:
+ return ErrCancelled
+
+ case req := <-accountReqFails:
+ s.revertAccountRequest(req)
+ case req := <-bytecodeReqFails:
+ s.revertBytecodeRequest(req)
+ case req := <-storageReqFails:
+ s.revertStorageRequest(req)
+ case req := <-trienodeHealReqFails:
+ s.revertTrienodeHealRequest(req)
+ case req := <-bytecodeHealReqFails:
+ s.revertBytecodeHealRequest(req)
+
+ case res := <-accountResps:
+ s.processAccountResponse(res)
+ case res := <-bytecodeResps:
+ s.processBytecodeResponse(res)
+ case res := <-storageResps:
+ s.processStorageResponse(res)
+ case res := <-trienodeHealResps:
+ s.processTrienodeHealResponse(res)
+ case res := <-bytecodeHealResps:
+ s.processBytecodeHealResponse(res)
+ }
+ // Report stats if something meaningful happened
+ s.report(false)
+ }
+}
+
+// cleanPath is used to remove the dangling nodes in the stackTrie.
+func (s *Syncer) cleanPath(batch ethdb.Batch, owner common.Hash, path []byte) {
+ if owner == (common.Hash{}) && rawdb.ExistsAccountTrieNode(s.db, path) {
+ rawdb.DeleteAccountTrieNode(batch, path)
+ deletionGauge.Inc(1)
+ }
+ if owner != (common.Hash{}) && rawdb.ExistsStorageTrieNode(s.db, owner, path) {
+ rawdb.DeleteStorageTrieNode(batch, owner, path)
+ deletionGauge.Inc(1)
+ }
+ lookupGauge.Inc(1)
+}
+
+// loadSyncStatus retrieves a previously aborted sync status from the database,
+// or generates a fresh one if none is available.
+func (s *Syncer) loadSyncStatus() {
+ var progress SyncProgress
+
+ if status := rawdb.ReadSnapshotSyncStatus(s.db); status != nil {
+ if err := json.Unmarshal(status, &progress); err != nil {
+ log.Error("Failed to decode snap sync status", "err", err)
+ } else {
+ for _, task := range progress.Tasks {
+ log.Debug("Scheduled account sync task", "from", task.Next, "last", task.Last)
+ }
+ s.tasks = progress.Tasks
+ for _, task := range s.tasks {
+ task := task // closure for task.genBatch in the stacktrie writer callback
+
+ task.genBatch = ethdb.HookedBatch{
+ Batch: s.db.NewBatch(),
+ OnPut: func(key []byte, value []byte) {
+ s.accountBytes += common.StorageSize(len(key) + len(value))
+ },
+ }
+ options := trie.NewStackTrieOptions()
+ options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
+ rawdb.WriteTrieNode(task.genBatch, common.Hash{}, path, hash, blob, s.scheme)
+ })
+ if s.scheme == rawdb.PathScheme {
+ // Configure the dangling node cleaner and also filter out boundary nodes
+ // only in the context of the path scheme. Deletion is forbidden in the
+ // hash scheme, as it can disrupt state completeness.
+ options = options.WithCleaner(func(path []byte) {
+ s.cleanPath(task.genBatch, common.Hash{}, path)
+ })
+ // Skip the left boundary if it's not the first range.
+ // Skip the right boundary if it's not the last range.
+ options = options.WithSkipBoundary(task.Next != (common.Hash{}), task.Last != common.MaxHash, boundaryAccountNodesGauge)
+ }
+ task.genTrie = trie.NewStackTrie(options)
+ for accountHash, subtasks := range task.SubTasks {
+ for _, subtask := range subtasks {
+ subtask := subtask // closure for subtask.genBatch in the stacktrie writer callback
+
+ subtask.genBatch = ethdb.HookedBatch{
+ Batch: s.db.NewBatch(),
+ OnPut: func(key []byte, value []byte) {
+ s.storageBytes += common.StorageSize(len(key) + len(value))
+ },
+ }
+ owner := accountHash // local assignment for stacktrie writer closure
+ options := trie.NewStackTrieOptions()
+ options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
+ rawdb.WriteTrieNode(subtask.genBatch, owner, path, hash, blob, s.scheme)
+ })
+ if s.scheme == rawdb.PathScheme {
+ // Configure the dangling node cleaner and also filter out boundary nodes
+ // only in the context of the path scheme. Deletion is forbidden in the
+ // hash scheme, as it can disrupt state completeness.
+ options = options.WithCleaner(func(path []byte) {
+ s.cleanPath(subtask.genBatch, owner, path)
+ })
+ // Skip the left boundary if it's not the first range.
+ // Skip the right boundary if it's not the last range.
+ options = options.WithSkipBoundary(subtask.Next != common.Hash{}, subtask.Last != common.MaxHash, boundaryStorageNodesGauge)
+ }
+ subtask.genTrie = trie.NewStackTrie(options)
+ }
+ }
+ }
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ s.snapped = len(s.tasks) == 0
+
+ s.accountSynced = progress.AccountSynced
+ s.accountBytes = progress.AccountBytes
+ s.bytecodeSynced = progress.BytecodeSynced
+ s.bytecodeBytes = progress.BytecodeBytes
+ s.storageSynced = progress.StorageSynced
+ s.storageBytes = progress.StorageBytes
+
+ s.trienodeHealSynced = progress.TrienodeHealSynced
+ s.trienodeHealBytes = progress.TrienodeHealBytes
+ s.bytecodeHealSynced = progress.BytecodeHealSynced
+ s.bytecodeHealBytes = progress.BytecodeHealBytes
+ return
+ }
+ }
+ // Either we've failed to decode the previous state, or there was none.
+ // Start a fresh sync by chunking up the account range and scheduling
+ // them for retrieval.
+ s.tasks = nil
+ s.accountSynced, s.accountBytes = 0, 0
+ s.bytecodeSynced, s.bytecodeBytes = 0, 0
+ s.storageSynced, s.storageBytes = 0, 0
+ s.trienodeHealSynced, s.trienodeHealBytes = 0, 0
+ s.bytecodeHealSynced, s.bytecodeHealBytes = 0, 0
+
+ var next common.Hash
+ step := new(big.Int).Sub(
+ new(big.Int).Div(
+ new(big.Int).Exp(common.Big2, common.Big256, nil),
+ big.NewInt(int64(accountConcurrency)),
+ ), common.Big1,
+ )
+ for i := 0; i < accountConcurrency; i++ {
+ last := common.BigToHash(new(big.Int).Add(next.Big(), step))
+ if i == accountConcurrency-1 {
+ // Make sure we don't overflow if the step is not a proper divisor
+ last = common.MaxHash
+ }
+ batch := ethdb.HookedBatch{
+ Batch: s.db.NewBatch(),
+ OnPut: func(key []byte, value []byte) {
+ s.accountBytes += common.StorageSize(len(key) + len(value))
+ },
+ }
+ options := trie.NewStackTrieOptions()
+ options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
+ rawdb.WriteTrieNode(batch, common.Hash{}, path, hash, blob, s.scheme)
+ })
+ if s.scheme == rawdb.PathScheme {
+ // Configure the dangling node cleaner and also filter out boundary nodes
+ // only in the context of the path scheme. Deletion is forbidden in the
+ // hash scheme, as it can disrupt state completeness.
+ options = options.WithCleaner(func(path []byte) {
+ s.cleanPath(batch, common.Hash{}, path)
+ })
+ // Skip the left boundary if it's not the first range.
+ // Skip the right boundary if it's not the last range.
+ options = options.WithSkipBoundary(next != common.Hash{}, last != common.MaxHash, boundaryAccountNodesGauge)
+ }
+ s.tasks = append(s.tasks, &accountTask{
+ Next: next,
+ Last: last,
+ SubTasks: make(map[common.Hash][]*storageTask),
+ genBatch: batch,
+ genTrie: trie.NewStackTrie(options),
+ })
+ log.Debug("Created account sync task", "from", next, "last", last)
+ next = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1))
+ }
+}
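+
+// With the default accountConcurrency of 16, the fresh-start path above yields
+// step = 2^252 - 1 and tasks covering [0x00..00, 0x0ff..ff], [0x10..00,
+// 0x1ff..ff] and so on, the final task being clamped to common.MaxHash.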
+
+// saveSyncStatus marshals the remaining sync tasks into leveldb.
+func (s *Syncer) saveSyncStatus() {
+ // Serialize any partial progress to disk before spinning down
+ for _, task := range s.tasks {
+ if err := task.genBatch.Write(); err != nil {
+ log.Error("Failed to persist account slots", "err", err)
+ }
+ for _, subtasks := range task.SubTasks {
+ for _, subtask := range subtasks {
+ if err := subtask.genBatch.Write(); err != nil {
+ log.Error("Failed to persist storage slots", "err", err)
+ }
+ }
+ }
+ }
+ // Store the actual progress markers
+ progress := &SyncProgress{
+ Tasks: s.tasks,
+ AccountSynced: s.accountSynced,
+ AccountBytes: s.accountBytes,
+ BytecodeSynced: s.bytecodeSynced,
+ BytecodeBytes: s.bytecodeBytes,
+ StorageSynced: s.storageSynced,
+ StorageBytes: s.storageBytes,
+ TrienodeHealSynced: s.trienodeHealSynced,
+ TrienodeHealBytes: s.trienodeHealBytes,
+ BytecodeHealSynced: s.bytecodeHealSynced,
+ BytecodeHealBytes: s.bytecodeHealBytes,
+ }
+ status, err := json.Marshal(progress)
+ if err != nil {
+ panic(err) // This can only fail during implementation
+ }
+ rawdb.WriteSnapshotSyncStatus(s.db, status)
+}
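+
+// Note that only the exported fields of accountTask and storageTask survive
+// the JSON round trip; runtime-only fields (req, res, genBatch, genTrie, ...)
+// are rebuilt by loadSyncStatus on resume.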
+
+// Progress returns the snap sync status statistics.
+func (s *Syncer) Progress() (*SyncProgress, *SyncPending) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ pending := new(SyncPending)
+ if s.healer != nil {
+ pending.TrienodeHeal = uint64(len(s.healer.trieTasks))
+ pending.BytecodeHeal = uint64(len(s.healer.codeTasks))
+ }
+ return s.extProgress, pending
+}
+
+// cleanAccountTasks removes account range retrieval tasks that have already been
+// completed.
+func (s *Syncer) cleanAccountTasks() {
+ // If the sync was already done before, don't even bother
+ if len(s.tasks) == 0 {
+ return
+ }
+ // Sync wasn't finished previously, check for any task that can be finalized
+ for i := 0; i < len(s.tasks); i++ {
+ if s.tasks[i].done {
+ s.tasks = append(s.tasks[:i], s.tasks[i+1:]...)
+ i--
+ }
+ }
+ // If everything was just finalized, generate the account trie and start the heal phase
+ if len(s.tasks) == 0 {
+ s.lock.Lock()
+ s.snapped = true
+ s.lock.Unlock()
+
+ // Push the final sync report
+ s.reportSyncProgress(true)
+ }
+}
+
+// cleanStorageTasks iterates over all the account tasks and storage sub-tasks
+// within, cleaning any that have been completed.
+func (s *Syncer) cleanStorageTasks() {
+ for _, task := range s.tasks {
+ for account, subtasks := range task.SubTasks {
+ // Remove storage range retrieval tasks that completed
+ for j := 0; j < len(subtasks); j++ {
+ if subtasks[j].done {
+ subtasks = append(subtasks[:j], subtasks[j+1:]...)
+ j--
+ }
+ }
+ if len(subtasks) > 0 {
+ task.SubTasks[account] = subtasks
+ continue
+ }
+ // If all storage chunks are done, mark the account as done too
+ for j, hash := range task.res.hashes {
+ if hash == account {
+ task.needState[j] = false
+ }
+ }
+ delete(task.SubTasks, account)
+ task.pend--
+
+ // If this was the last pending task, forward the account task
+ if task.pend == 0 {
+ s.forwardAccountTask(task)
+ }
+ }
+ }
+}
+
+// assignAccountTasks attempts to match idle peers to pending account range
+// retrievals.
+func (s *Syncer) assignAccountTasks(success chan *accountResponse, fail chan *accountRequest, cancel chan struct{}) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ // Sort the peers by download capacity to use faster ones if many available
+ idlers := &capacitySort{
+ ids: make([]string, 0, len(s.accountIdlers)),
+ caps: make([]int, 0, len(s.accountIdlers)),
+ }
+ targetTTL := s.rates.TargetTimeout()
+ for id := range s.accountIdlers {
+ if _, ok := s.statelessPeers[id]; ok {
+ continue
+ }
+ idlers.ids = append(idlers.ids, id)
+ idlers.caps = append(idlers.caps, s.rates.Capacity(id, AccountRangeMsg, targetTTL))
+ }
+ if len(idlers.ids) == 0 {
+ return
+ }
+ sort.Sort(sort.Reverse(idlers))
+
+ // Iterate over all the tasks and try to find a pending one
+ for _, task := range s.tasks {
+ // Skip any tasks already filling
+ if task.req != nil || task.res != nil {
+ continue
+ }
+ // Task pending retrieval, try to find an idle peer. If no such peer
+ // exists, we probably assigned tasks for all (or they are stateless).
+ // Abort the entire assignment mechanism.
+ if len(idlers.ids) == 0 {
+ return
+ }
+ var (
+ idle = idlers.ids[0]
+ peer = s.peers[idle]
+ cap = idlers.caps[0]
+ )
+ idlers.ids, idlers.caps = idlers.ids[1:], idlers.caps[1:]
+
+ // Matched a pending task to an idle peer, allocate a unique request id
+ var reqid uint64
+ for {
+ reqid = uint64(rand.Int63())
+ if reqid == 0 {
+ continue
+ }
+ if _, ok := s.accountReqs[reqid]; ok {
+ continue
+ }
+ break
+ }
+ // Generate the network query and send it to the peer
+ req := &accountRequest{
+ peer: idle,
+ id: reqid,
+ time: time.Now(),
+ deliver: success,
+ revert: fail,
+ cancel: cancel,
+ stale: make(chan struct{}),
+ origin: task.Next,
+ limit: task.Last,
+ task: task,
+ }
+ req.timeout = time.AfterFunc(s.rates.TargetTimeout(), func() {
+ peer.Log().Debug("Account range request timed out", "reqid", reqid)
+ s.rates.Update(idle, AccountRangeMsg, 0, 0)
+ s.scheduleRevertAccountRequest(req)
+ })
+ s.accountReqs[reqid] = req
+ delete(s.accountIdlers, idle)
+
+ s.pend.Add(1)
+ go func(root common.Hash) {
+ defer s.pend.Done()
+
+ // Attempt to send the remote request and revert if it fails
+ if cap > maxRequestSize {
+ cap = maxRequestSize
+ }
+ if cap < minRequestSize { // Don't bother with peers below a bare minimum performance
+ cap = minRequestSize
+ }
+ if err := peer.RequestAccountRange(reqid, root, req.origin, req.limit, uint64(cap)); err != nil {
+ peer.Log().Debug("Failed to request account range", "err", err)
+ s.scheduleRevertAccountRequest(req)
+ }
+ }(s.root)
+
+ // Inject the request into the task to block further assignments
+ task.req = req
+ }
+}
+
+// assignBytecodeTasks attempts to match idle peers to pending code retrievals.
+func (s *Syncer) assignBytecodeTasks(success chan *bytecodeResponse, fail chan *bytecodeRequest, cancel chan struct{}) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ // Sort the peers by download capacity to use faster ones if many available
+ idlers := &capacitySort{
+ ids: make([]string, 0, len(s.bytecodeIdlers)),
+ caps: make([]int, 0, len(s.bytecodeIdlers)),
+ }
+ targetTTL := s.rates.TargetTimeout()
+ for id := range s.bytecodeIdlers {
+ if _, ok := s.statelessPeers[id]; ok {
+ continue
+ }
+ idlers.ids = append(idlers.ids, id)
+ idlers.caps = append(idlers.caps, s.rates.Capacity(id, ByteCodesMsg, targetTTL))
+ }
+ if len(idlers.ids) == 0 {
+ return
+ }
+ sort.Sort(sort.Reverse(idlers))
+
+ // Iterate over all the tasks and try to find a pending one
+ for _, task := range s.tasks {
+ // Skip any tasks not in the bytecode retrieval phase
+ if task.res == nil {
+ continue
+ }
+ // Skip tasks that are already retrieving (or done with) all codes
+ if len(task.codeTasks) == 0 {
+ continue
+ }
+ // Task pending retrieval, try to find an idle peer. If no such peer
+ // exists, we probably assigned tasks for all (or they are stateless).
+ // Abort the entire assignment mechanism.
+ if len(idlers.ids) == 0 {
+ return
+ }
+ var (
+ idle = idlers.ids[0]
+ peer = s.peers[idle]
+ cap = idlers.caps[0]
+ )
+ idlers.ids, idlers.caps = idlers.ids[1:], idlers.caps[1:]
+
+ // Matched a pending task to an idle peer, allocate a unique request id
+ var reqid uint64
+ for {
+ reqid = uint64(rand.Int63())
+ if reqid == 0 {
+ continue
+ }
+ if _, ok := s.bytecodeReqs[reqid]; ok {
+ continue
+ }
+ break
+ }
+ // Generate the network query and send it to the peer
+ if cap > maxCodeRequestCount {
+ cap = maxCodeRequestCount
+ }
+ hashes := make([]common.Hash, 0, cap)
+ for hash := range task.codeTasks {
+ delete(task.codeTasks, hash)
+ hashes = append(hashes, hash)
+ if len(hashes) >= cap {
+ break
+ }
+ }
+ req := &bytecodeRequest{
+ peer: idle,
+ id: reqid,
+ time: time.Now(),
+ deliver: success,
+ revert: fail,
+ cancel: cancel,
+ stale: make(chan struct{}),
+ hashes: hashes,
+ task: task,
+ }
+ req.timeout = time.AfterFunc(s.rates.TargetTimeout(), func() {
+ peer.Log().Debug("Bytecode request timed out", "reqid", reqid)
+ s.rates.Update(idle, ByteCodesMsg, 0, 0)
+ s.scheduleRevertBytecodeRequest(req)
+ })
+ s.bytecodeReqs[reqid] = req
+ delete(s.bytecodeIdlers, idle)
+
+ s.pend.Add(1)
+ go func() {
+ defer s.pend.Done()
+
+ // Attempt to send the remote request and revert if it fails
+ if err := peer.RequestByteCodes(reqid, hashes, maxRequestSize); err != nil {
+ log.Debug("Failed to request bytecodes", "err", err)
+ s.scheduleRevertBytecodeRequest(req)
+ }
+ }()
+ }
+}
+
+// assignStorageTasks attempts to match idle peers to pending storage range
+// retrievals.
+func (s *Syncer) assignStorageTasks(success chan *storageResponse, fail chan *storageRequest, cancel chan struct{}) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ // Sort the peers by download capacity to use faster ones if many available
+ idlers := &capacitySort{
+ ids: make([]string, 0, len(s.storageIdlers)),
+ caps: make([]int, 0, len(s.storageIdlers)),
+ }
+ targetTTL := s.rates.TargetTimeout()
+ for id := range s.storageIdlers {
+ if _, ok := s.statelessPeers[id]; ok {
+ continue
+ }
+ idlers.ids = append(idlers.ids, id)
+ idlers.caps = append(idlers.caps, s.rates.Capacity(id, StorageRangesMsg, targetTTL))
+ }
+ if len(idlers.ids) == 0 {
+ return
+ }
+ sort.Sort(sort.Reverse(idlers))
+
+ // Iterate over all the tasks and try to find a pending one
+ for _, task := range s.tasks {
+ // Skip any tasks not in the storage retrieval phase
+ if task.res == nil {
+ continue
+ }
+ // Skip tasks that are already retrieving (or done with) all small states
+ if len(task.SubTasks) == 0 && len(task.stateTasks) == 0 {
+ continue
+ }
+ // Task pending retrieval, try to find an idle peer. If no such peer
+ // exists, we probably assigned tasks for all (or they are stateless).
+ // Abort the entire assignment mechanism.
+ if len(idlers.ids) == 0 {
+ return
+ }
+ var (
+ idle = idlers.ids[0]
+ peer = s.peers[idle]
+ cap = idlers.caps[0]
+ )
+ idlers.ids, idlers.caps = idlers.ids[1:], idlers.caps[1:]
+
+ // Matched a pending task to an idle peer, allocate a unique request id
+ var reqid uint64
+ for {
+ reqid = uint64(rand.Int63())
+ if reqid == 0 {
+ continue
+ }
+ if _, ok := s.storageReqs[reqid]; ok {
+ continue
+ }
+ break
+ }
+ // Generate the network query and send it to the peer. If there are
+ // large contract tasks pending, complete those before diving into
+ // even more new contracts.
+ if cap > maxRequestSize {
+ cap = maxRequestSize
+ }
+ if cap < minRequestSize { // Don't bother with peers below a bare minimum performance
+ cap = minRequestSize
+ }
+ storageSets := cap / 1024
+
+ var (
+ accounts = make([]common.Hash, 0, storageSets)
+ roots = make([]common.Hash, 0, storageSets)
+ subtask *storageTask
+ )
+ for account, subtasks := range task.SubTasks {
+ for _, st := range subtasks {
+ // Skip any subtasks already filling
+ if st.req != nil {
+ continue
+ }
+ // Found an incomplete storage chunk, schedule it
+ accounts = append(accounts, account)
+ roots = append(roots, st.root)
+ subtask = st
+ break // Large contract chunks are downloaded individually
+ }
+ if subtask != nil {
+ break // Large contract chunks are downloaded individually
+ }
+ }
+ if subtask == nil {
+ // No large contract requires retrieval, but small ones are available
+ for account, root := range task.stateTasks {
+ delete(task.stateTasks, account)
+
+ accounts = append(accounts, account)
+ roots = append(roots, root)
+
+ if len(accounts) >= storageSets {
+ break
+ }
+ }
+ }
+ // If nothing was found, it means this task is actually already fully
+ // retrieving, but large contracts are hard to detect. Skip to the next.
+ if len(accounts) == 0 {
+ continue
+ }
+ req := &storageRequest{
+ peer: idle,
+ id: reqid,
+ time: time.Now(),
+ deliver: success,
+ revert: fail,
+ cancel: cancel,
+ stale: make(chan struct{}),
+ accounts: accounts,
+ roots: roots,
+ mainTask: task,
+ subTask: subtask,
+ }
+ if subtask != nil {
+ req.origin = subtask.Next
+ req.limit = subtask.Last
+ }
+ req.timeout = time.AfterFunc(s.rates.TargetTimeout(), func() {
+ peer.Log().Debug("Storage request timed out", "reqid", reqid)
+ s.rates.Update(idle, StorageRangesMsg, 0, 0)
+ s.scheduleRevertStorageRequest(req)
+ })
+ s.storageReqs[reqid] = req
+ delete(s.storageIdlers, idle)
+
+ s.pend.Add(1)
+ go func(root common.Hash) {
+ defer s.pend.Done()
+
+ // Attempt to send the remote request and revert if it fails
+ var origin, limit []byte
+ if subtask != nil {
+ origin, limit = req.origin[:], req.limit[:]
+ }
+ if err := peer.RequestStorageRanges(reqid, root, accounts, origin, limit, uint64(cap)); err != nil {
+ log.Debug("Failed to request storage", "err", err)
+ s.scheduleRevertStorageRequest(req)
+ }
+ }(s.root)
+
+ // Inject the request into the subtask to block further assignments
+ if subtask != nil {
+ subtask.req = req
+ }
+ }
+}
+
+// assignTrienodeHealTasks attempts to match idle peers to trie node requests to
+// heal any trie errors caused by the snap sync's chunked retrieval model.
+func (s *Syncer) assignTrienodeHealTasks(success chan *trienodeHealResponse, fail chan *trienodeHealRequest, cancel chan struct{}) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ // Sort the peers by download capacity to use faster ones if many available
+ idlers := &capacitySort{
+ ids: make([]string, 0, len(s.trienodeHealIdlers)),
+ caps: make([]int, 0, len(s.trienodeHealIdlers)),
+ }
+ targetTTL := s.rates.TargetTimeout()
+ for id := range s.trienodeHealIdlers {
+ if _, ok := s.statelessPeers[id]; ok {
+ continue
+ }
+ idlers.ids = append(idlers.ids, id)
+ idlers.caps = append(idlers.caps, s.rates.Capacity(id, TrieNodesMsg, targetTTL))
+ }
+ if len(idlers.ids) == 0 {
+ return
+ }
+ sort.Sort(sort.Reverse(idlers))
+
+ // Iterate over pending tasks and try to find a peer to retrieve with
+ for len(s.healer.trieTasks) > 0 || s.healer.scheduler.Pending() > 0 {
+ // If there are not enough trie tasks queued to fully assign, fill the
+ // queue from the state sync scheduler. The trie syncer schedules these
+ // together with bytecodes, so we need to queue them combined.
+ var (
+ have = len(s.healer.trieTasks) + len(s.healer.codeTasks)
+ want = maxTrieRequestCount + maxCodeRequestCount
+ )
+ if have < want {
+ paths, hashes, codes := s.healer.scheduler.Missing(want - have)
+ for i, path := range paths {
+ s.healer.trieTasks[path] = hashes[i]
+ }
+ for _, hash := range codes {
+ s.healer.codeTasks[hash] = struct{}{}
+ }
+ }
+ // If all the heal tasks are bytecodes or already downloading, bail
+ if len(s.healer.trieTasks) == 0 {
+ return
+ }
+ // Task pending retrieval, try to find an idle peer. If no such peer
+ // exists, we probably assigned tasks for all (or they are stateless).
+ // Abort the entire assignment mechanism.
+ if len(idlers.ids) == 0 {
+ return
+ }
+ var (
+ idle = idlers.ids[0]
+ peer = s.peers[idle]
+ cap = idlers.caps[0]
+ )
+ idlers.ids, idlers.caps = idlers.ids[1:], idlers.caps[1:]
+
+ // Matched a pending task to an idle peer, allocate a unique request id
+ var reqid uint64
+ for {
+ reqid = uint64(rand.Int63())
+ if reqid == 0 {
+ continue
+ }
+ if _, ok := s.trienodeHealReqs[reqid]; ok {
+ continue
+ }
+ break
+ }
+ // Generate the network query and send it to the peer
+ if cap > maxTrieRequestCount {
+ cap = maxTrieRequestCount
+ }
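+ // Scale the request size down by the current throttle: a throttle above 1
+ // shrinks heal requests whenever local processing lags behind delivery.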
+ cap = int(float64(cap) / s.trienodeHealThrottle)
+ if cap <= 0 {
+ cap = 1
+ }
+ var (
+ hashes = make([]common.Hash, 0, cap)
+ paths = make([]string, 0, cap)
+ pathsets = make([]TrieNodePathSet, 0, cap)
+ )
+ for path, hash := range s.healer.trieTasks {
+ delete(s.healer.trieTasks, path)
+
+ paths = append(paths, path)
+ hashes = append(hashes, hash)
+ if len(paths) >= cap {
+ break
+ }
+ }
+ // Group requests by account hash
+ paths, hashes, _, pathsets = sortByAccountPath(paths, hashes)
+ req := &trienodeHealRequest{
+ peer: idle,
+ id: reqid,
+ time: time.Now(),
+ deliver: success,
+ revert: fail,
+ cancel: cancel,
+ stale: make(chan struct{}),
+ paths: paths,
+ hashes: hashes,
+ task: s.healer,
+ }
+ req.timeout = time.AfterFunc(s.rates.TargetTimeout(), func() {
+ peer.Log().Debug("Trienode heal request timed out", "reqid", reqid)
+ s.rates.Update(idle, TrieNodesMsg, 0, 0)
+ s.scheduleRevertTrienodeHealRequest(req)
+ })
+ s.trienodeHealReqs[reqid] = req
+ delete(s.trienodeHealIdlers, idle)
+
+ s.pend.Add(1)
+ go func(root common.Hash) {
+ defer s.pend.Done()
+
+ // Attempt to send the remote request and revert if it fails
+ if err := peer.RequestTrieNodes(reqid, root, pathsets, maxRequestSize); err != nil {
+ log.Debug("Failed to request trienode healers", "err", err)
+ s.scheduleRevertTrienodeHealRequest(req)
+ }
+ }(s.root)
+ }
+}
+
+// assignBytecodeHealTasks attempts to match idle peers to bytecode requests to
+// heal any trie errors caused by the snap sync's chunked retrieval model.
+func (s *Syncer) assignBytecodeHealTasks(success chan *bytecodeHealResponse, fail chan *bytecodeHealRequest, cancel chan struct{}) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ // Sort the peers by download capacity to use faster ones if many available
+ idlers := &capacitySort{
+ ids: make([]string, 0, len(s.bytecodeHealIdlers)),
+ caps: make([]int, 0, len(s.bytecodeHealIdlers)),
+ }
+ targetTTL := s.rates.TargetTimeout()
+ for id := range s.bytecodeHealIdlers {
+ if _, ok := s.statelessPeers[id]; ok {
+ continue
+ }
+ idlers.ids = append(idlers.ids, id)
+ idlers.caps = append(idlers.caps, s.rates.Capacity(id, ByteCodesMsg, targetTTL))
+ }
+ if len(idlers.ids) == 0 {
+ return
+ }
+ sort.Sort(sort.Reverse(idlers))
+
+ // Iterate over pending tasks and try to find a peer to retrieve with
+ for len(s.healer.codeTasks) > 0 || s.healer.scheduler.Pending() > 0 {
+ // If there are not enough trie tasks queued to fully assign, fill the
+ // queue from the state sync scheduler. The trie syncer schedules these
+ // together with trie nodes, so we need to queue them combined.
+ var (
+ have = len(s.healer.trieTasks) + len(s.healer.codeTasks)
+ want = maxTrieRequestCount + maxCodeRequestCount
+ )
+ if have < want {
+ paths, hashes, codes := s.healer.scheduler.Missing(want - have)
+ for i, path := range paths {
+ s.healer.trieTasks[path] = hashes[i]
+ }
+ for _, hash := range codes {
+ s.healer.codeTasks[hash] = struct{}{}
+ }
+ }
+ // If all the heal tasks are trienodes or already downloading, bail
+ if len(s.healer.codeTasks) == 0 {
+ return
+ }
+ // Task pending retrieval, try to find an idle peer. If no such peer
+ // exists, we probably assigned tasks for all (or they are stateless).
+ // Abort the entire assignment mechanism.
+ if len(idlers.ids) == 0 {
+ return
+ }
+ var (
+ idle = idlers.ids[0]
+ peer = s.peers[idle]
+ cap = idlers.caps[0]
+ )
+ idlers.ids, idlers.caps = idlers.ids[1:], idlers.caps[1:]
+
+ // Matched a pending task to an idle peer, allocate a unique request id
+ var reqid uint64
+ for {
+ reqid = uint64(rand.Int63())
+ if reqid == 0 {
+ continue
+ }
+ if _, ok := s.bytecodeHealReqs[reqid]; ok {
+ continue
+ }
+ break
+ }
+ // Generate the network query and send it to the peer
+ if cap > maxCodeRequestCount {
+ cap = maxCodeRequestCount
+ }
+ hashes := make([]common.Hash, 0, cap)
+ for hash := range s.healer.codeTasks {
+ delete(s.healer.codeTasks, hash)
+
+ hashes = append(hashes, hash)
+ if len(hashes) >= cap {
+ break
+ }
+ }
+ req := &bytecodeHealRequest{
+ peer: idle,
+ id: reqid,
+ time: time.Now(),
+ deliver: success,
+ revert: fail,
+ cancel: cancel,
+ stale: make(chan struct{}),
+ hashes: hashes,
+ task: s.healer,
+ }
+ req.timeout = time.AfterFunc(s.rates.TargetTimeout(), func() {
+ peer.Log().Debug("Bytecode heal request timed out", "reqid", reqid)
+ s.rates.Update(idle, ByteCodesMsg, 0, 0)
+ s.scheduleRevertBytecodeHealRequest(req)
+ })
+ s.bytecodeHealReqs[reqid] = req
+ delete(s.bytecodeHealIdlers, idle)
+
+ s.pend.Add(1)
+ go func() {
+ defer s.pend.Done()
+
+ // Attempt to send the remote request and revert if it fails
+ if err := peer.RequestByteCodes(reqid, hashes, maxRequestSize); err != nil {
+ log.Debug("Failed to request bytecode healers", "err", err)
+ s.scheduleRevertBytecodeHealRequest(req)
+ }
+ }()
+ }
+}
+
+// revertRequests locates all the currently pending requests from a particular
+// peer and reverts them, rescheduling for others to fulfill.
+func (s *Syncer) revertRequests(peer string) {
+ // Gather the requests first, revertals need the lock too
+ s.lock.Lock()
+ var accountReqs []*accountRequest
+ for _, req := range s.accountReqs {
+ if req.peer == peer {
+ accountReqs = append(accountReqs, req)
+ }
+ }
+ var bytecodeReqs []*bytecodeRequest
+ for _, req := range s.bytecodeReqs {
+ if req.peer == peer {
+ bytecodeReqs = append(bytecodeReqs, req)
+ }
+ }
+ var storageReqs []*storageRequest
+ for _, req := range s.storageReqs {
+ if req.peer == peer {
+ storageReqs = append(storageReqs, req)
+ }
+ }
+ var trienodeHealReqs []*trienodeHealRequest
+ for _, req := range s.trienodeHealReqs {
+ if req.peer == peer {
+ trienodeHealReqs = append(trienodeHealReqs, req)
+ }
+ }
+ var bytecodeHealReqs []*bytecodeHealRequest
+ for _, req := range s.bytecodeHealReqs {
+ if req.peer == peer {
+ bytecodeHealReqs = append(bytecodeHealReqs, req)
+ }
+ }
+ s.lock.Unlock()
+
+ // Revert all the requests matching the peer
+ for _, req := range accountReqs {
+ s.revertAccountRequest(req)
+ }
+ for _, req := range bytecodeReqs {
+ s.revertBytecodeRequest(req)
+ }
+ for _, req := range storageReqs {
+ s.revertStorageRequest(req)
+ }
+ for _, req := range trienodeHealReqs {
+ s.revertTrienodeHealRequest(req)
+ }
+ for _, req := range bytecodeHealReqs {
+ s.revertBytecodeHealRequest(req)
+ }
+}
+
+// scheduleRevertAccountRequest asks the event loop to clean up an account range
+// request and return all failed retrieval tasks to the scheduler for reassignment.
+func (s *Syncer) scheduleRevertAccountRequest(req *accountRequest) {
+ select {
+ case req.revert <- req:
+ // Sync event loop notified
+ case <-req.cancel:
+ // Sync cycle got cancelled
+ case <-req.stale:
+ // Request already reverted
+ }
+}
+
+// revertAccountRequest cleans up an account range request and returns all failed
+// retrieval tasks to the scheduler for reassignment.
+//
+// Note, this needs to run on the event runloop thread to reschedule to idle peers.
+// On peer threads, use scheduleRevertAccountRequest.
+func (s *Syncer) revertAccountRequest(req *accountRequest) {
+ log.Debug("Reverting account request", "peer", req.peer, "reqid", req.id)
+ select {
+ case <-req.stale:
+ log.Trace("Account request already reverted", "peer", req.peer, "reqid", req.id)
+ return
+ default:
+ }
+ close(req.stale)
+
+ // Remove the request from the tracked set
+ s.lock.Lock()
+ delete(s.accountReqs, req.id)
+ s.lock.Unlock()
+
+ // If there's a timeout timer still running, abort it and mark the account
+ // task as not-pending, ready for rescheduling
+ req.timeout.Stop()
+ if req.task.req == req {
+ req.task.req = nil
+ }
+}
+
+// scheduleRevertBytecodeRequest asks the event loop to clean up a bytecode request
+// and return all failed retrieval tasks to the scheduler for reassignment.
+func (s *Syncer) scheduleRevertBytecodeRequest(req *bytecodeRequest) {
+ select {
+ case req.revert <- req:
+ // Sync event loop notified
+ case <-req.cancel:
+ // Sync cycle got cancelled
+ case <-req.stale:
+ // Request already reverted
+ }
+}
+
+// revertBytecodeRequest cleans up a bytecode request and returns all failed
+// retrieval tasks to the scheduler for reassignment.
+//
+// Note, this needs to run on the event runloop thread to reschedule to idle peers.
+// On peer threads, use scheduleRevertBytecodeRequest.
+func (s *Syncer) revertBytecodeRequest(req *bytecodeRequest) {
+ log.Debug("Reverting bytecode request", "peer", req.peer)
+ select {
+ case <-req.stale:
+ log.Trace("Bytecode request already reverted", "peer", req.peer, "reqid", req.id)
+ return
+ default:
+ }
+ close(req.stale)
+
+ // Remove the request from the tracked set
+ s.lock.Lock()
+ delete(s.bytecodeReqs, req.id)
+ s.lock.Unlock()
+
+ // If there's a timeout timer still running, abort it and mark the code
+ // retrievals as not-pending, ready for rescheduling
+ req.timeout.Stop()
+ for _, hash := range req.hashes {
+ req.task.codeTasks[hash] = struct{}{}
+ }
+}
+
+// scheduleRevertStorageRequest asks the event loop to clean up a storage range
+// request and return all failed retrieval tasks to the scheduler for reassignment.
+func (s *Syncer) scheduleRevertStorageRequest(req *storageRequest) {
+ select {
+ case req.revert <- req:
+ // Sync event loop notified
+ case <-req.cancel:
+ // Sync cycle got cancelled
+ case <-req.stale:
+ // Request already reverted
+ }
+}
+
+// revertStorageRequest cleans up a storage range request and returns all failed
+// retrieval tasks to the scheduler for reassignment.
+//
+// Note, this needs to run on the event runloop thread to reschedule to idle peers.
+// On peer threads, use scheduleRevertStorageRequest.
+func (s *Syncer) revertStorageRequest(req *storageRequest) {
+ log.Debug("Reverting storage request", "peer", req.peer)
+ select {
+ case <-req.stale:
+ log.Trace("Storage request already reverted", "peer", req.peer, "reqid", req.id)
+ return
+ default:
+ }
+ close(req.stale)
+
+ // Remove the request from the tracked set
+ s.lock.Lock()
+ delete(s.storageReqs, req.id)
+ s.lock.Unlock()
+
+ // If there's a timeout timer still running, abort it and mark the storage
+ // task as not-pending, ready for rescheduling
+ req.timeout.Stop()
+ if req.subTask != nil {
+ req.subTask.req = nil
+ } else {
+ for i, account := range req.accounts {
+ req.mainTask.stateTasks[account] = req.roots[i]
+ }
+ }
+}
+
+// scheduleRevertTrienodeHealRequest asks the event loop to clean up a trienode heal
+// request and return all failed retrieval tasks to the scheduler for reassignment.
+func (s *Syncer) scheduleRevertTrienodeHealRequest(req *trienodeHealRequest) {
+ select {
+ case req.revert <- req:
+ // Sync event loop notified
+ case <-req.cancel:
+ // Sync cycle got cancelled
+ case <-req.stale:
+ // Request already reverted
+ }
+}
+
+// revertTrienodeHealRequest cleans up a trienode heal request and returns all
+// failed retrieval tasks to the scheduler for reassignment.
+//
+// Note, this needs to run on the event runloop thread to reschedule to idle peers.
+// On peer threads, use scheduleRevertTrienodeHealRequest.
+func (s *Syncer) revertTrienodeHealRequest(req *trienodeHealRequest) {
+ log.Debug("Reverting trienode heal request", "peer", req.peer)
+ select {
+ case <-req.stale:
+ log.Trace("Trienode heal request already reverted", "peer", req.peer, "reqid", req.id)
+ return
+ default:
+ }
+ close(req.stale)
+
+ // Remove the request from the tracked set
+ s.lock.Lock()
+ delete(s.trienodeHealReqs, req.id)
+ s.lock.Unlock()
+
+ // If there's a timeout timer still running, abort it and mark the trie node
+ // retrievals as not-pending, ready for rescheduling
+ req.timeout.Stop()
+ for i, path := range req.paths {
+ req.task.trieTasks[path] = req.hashes[i]
+ }
+}
+
+// scheduleRevertBytecodeHealRequest asks the event loop to clean up a bytecode heal
+// request and return all failed retrieval tasks to the scheduler for reassignment.
+func (s *Syncer) scheduleRevertBytecodeHealRequest(req *bytecodeHealRequest) {
+ select {
+ case req.revert <- req:
+ // Sync event loop notified
+ case <-req.cancel:
+ // Sync cycle got cancelled
+ case <-req.stale:
+ // Request already reverted
+ }
+}
+
+// revertBytecodeHealRequest cleans up a bytecode heal request and returns all
+// failed retrieval tasks to the scheduler for reassignment.
+//
+// Note, this needs to run on the event runloop thread to reschedule to idle peers.
+// On peer threads, use scheduleRevertBytecodeHealRequest.
+func (s *Syncer) revertBytecodeHealRequest(req *bytecodeHealRequest) {
+ log.Debug("Reverting bytecode heal request", "peer", req.peer)
+ select {
+ case <-req.stale:
+ log.Trace("Bytecode heal request already reverted", "peer", req.peer, "reqid", req.id)
+ return
+ default:
+ }
+ close(req.stale)
+
+ // Remove the request from the tracked set
+ s.lock.Lock()
+ delete(s.bytecodeHealReqs, req.id)
+ s.lock.Unlock()
+
+ // If there's a timeout timer still running, abort it and mark the code
+ // retrievals as not-pending, ready for rescheduling
+ req.timeout.Stop()
+ for _, hash := range req.hashes {
+ req.task.codeTasks[hash] = struct{}{}
+ }
+}
+
+// processAccountResponse integrates an already validated account range response
+// into the account tasks.
+func (s *Syncer) processAccountResponse(res *accountResponse) {
+ // Switch the task from pending to filling
+ res.task.req = nil
+ res.task.res = res
+
+ // Ensure that the response doesn't overflow into the subsequent task
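+ // As an illustrative example: if task.Last is 0x7f..ff and the response
+ // contains 0x80..00, that key and everything after it is trimmed and the
+ // range is marked complete.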
+ last := res.task.Last.Big()
+ for i, hash := range res.hashes {
+ // Mark the range complete if the last is already included.
+ // Keep iterating to delete any extra states if they exist.
+ cmp := hash.Big().Cmp(last)
+ if cmp == 0 {
+ res.cont = false
+ continue
+ }
+ if cmp > 0 {
+ // Chunk overflown, cut off excess
+ res.hashes = res.hashes[:i]
+ res.accounts = res.accounts[:i]
+ res.cont = false // Mark range completed
+ break
+ }
+ }
+ // Iterate over all the accounts and assemble which ones need further sub-
+ // filling before the entire account range can be persisted.
+ res.task.needCode = make([]bool, len(res.accounts))
+ res.task.needState = make([]bool, len(res.accounts))
+ res.task.needHeal = make([]bool, len(res.accounts))
+
+ res.task.codeTasks = make(map[common.Hash]struct{})
+ res.task.stateTasks = make(map[common.Hash]common.Hash)
+
+ resumed := make(map[common.Hash]struct{})
+
+ res.task.pend = 0
+ for i, account := range res.accounts {
+ // Check if the account is a contract with an unknown code
+ if !bytes.Equal(account.CodeHash, types.EmptyCodeHash.Bytes()) {
+ if !rawdb.HasCodeWithPrefix(s.db, common.BytesToHash(account.CodeHash)) {
+ res.task.codeTasks[common.BytesToHash(account.CodeHash)] = struct{}{}
+ res.task.needCode[i] = true
+ res.task.pend++
+ }
+ }
+ // Check if the account is a contract with an unknown storage trie
+ if account.Root != types.EmptyRootHash {
+ //if !ethrawdb.HasTrieNode(s.db, res.hashes[i], nil, account.Root, s.scheme) {
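+ // Note: upstream gates the block below on a trie-node presence check
+ // (ethrawdb.HasTrieNode); it is left commented out here, so storage
+ // retrieval is always (re)scheduled for accounts with non-empty roots.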
+ // If there was a previous large state retrieval in progress,
+ // don't restart it from scratch. This happens if a sync cycle
+ // is interrupted and resumed later. However, *do* update the
+ // previous root hash.
+ if subtasks, ok := res.task.SubTasks[res.hashes[i]]; ok {
+ log.Debug("Resuming large storage retrieval", "account", res.hashes[i], "root", account.Root)
+ for _, subtask := range subtasks {
+ subtask.root = account.Root
+ }
+ res.task.needHeal[i] = true
+ resumed[res.hashes[i]] = struct{}{}
+ } else {
+ res.task.stateTasks[res.hashes[i]] = account.Root
+ }
+ res.task.needState[i] = true
+ res.task.pend++
+ }
+ //}
+ }
+ // Delete any subtasks that have been aborted but not resumed. This may undo
+ // some progress if a new peer gives us fewer accounts than an old one, but for
+ // now we have to live with that.
+ for hash := range res.task.SubTasks {
+ if _, ok := resumed[hash]; !ok {
+ log.Debug("Aborting suspended storage retrieval", "account", hash)
+ delete(res.task.SubTasks, hash)
+ }
+ }
+ // If the account range contained no contracts, or all have been fully filled
+ // beforehand, short circuit storage filling and forward to the next task
+ if res.task.pend == 0 {
+ s.forwardAccountTask(res.task)
+ return
+ }
+ // Some accounts are incomplete, leave as is for the storage and contract
+ // task assigners to pick up and fill
+}
+
+// processBytecodeResponse integrates an already validated bytecode response
+// into the account tasks.
+func (s *Syncer) processBytecodeResponse(res *bytecodeResponse) {
+ batch := s.db.NewBatch()
+
+ var (
+ codes uint64
+ )
+ for i, hash := range res.hashes {
+ code := res.codes[i]
+
+ // If the bytecode was not delivered, reschedule it
+ if code == nil {
+ res.task.codeTasks[hash] = struct{}{}
+ continue
+ }
+ // Code was delivered, mark it not needed any more
+ for j, account := range res.task.res.accounts {
+ if res.task.needCode[j] && hash == common.BytesToHash(account.CodeHash) {
+ res.task.needCode[j] = false
+ res.task.pend--
+ }
+ }
+ // Push the bytecode into a database batch
+ codes++
+ rawdb.WriteCode(batch, hash, code)
+ }
+ bytes := common.StorageSize(batch.ValueSize())
+ if err := batch.Write(); err != nil {
+ log.Crit("Failed to persist bytecodes", "err", err)
+ }
+ s.bytecodeSynced += codes
+ s.bytecodeBytes += bytes
+
+ log.Debug("Persisted set of bytecodes", "count", codes, "bytes", bytes)
+
+ // If this delivery completed the last pending task, forward the account task
+ // to the next chunk
+ if res.task.pend == 0 {
+ s.forwardAccountTask(res.task)
+ return
+ }
+ // Some accounts are still incomplete, leave as is for the storage and contract
+ // task assigners to pick up and fill.
+}
+
+// processStorageResponse integrates an already validated storage response
+// into the account tasks.
+func (s *Syncer) processStorageResponse(res *storageResponse) {
+ // Switch the subtask from pending to idle
+ if res.subTask != nil {
+ res.subTask.req = nil
+ }
+ batch := ethdb.HookedBatch{
+ Batch: s.db.NewBatch(),
+ OnPut: func(key []byte, value []byte) {
+ s.storageBytes += common.StorageSize(len(key) + len(value))
+ },
+ }
+ var (
+ slots int
+ oldStorageBytes = s.storageBytes
+ )
+ // Iterate over all the accounts and reconstruct their storage tries from the
+ // delivered slots
+ for i, account := range res.accounts {
+ // If the account was not delivered, reschedule it
+ if i >= len(res.hashes) {
+ res.mainTask.stateTasks[account] = res.roots[i]
+ continue
+ }
+ // State was delivered, if complete mark as not needed any more, otherwise
+ // mark the account as needing healing
+ for j, hash := range res.mainTask.res.hashes {
+ if account != hash {
+ continue
+ }
+ acc := res.mainTask.res.accounts[j]
+
+ // If the packet contains multiple contract storage slots, all
+ // but the last are surely complete. The last contract may be
+ // chunked, so check its continuation flag.
+ if res.subTask == nil && res.mainTask.needState[j] && (i < len(res.hashes)-1 || !res.cont) {
+ res.mainTask.needState[j] = false
+ res.mainTask.pend--
+ smallStorageGauge.Inc(1)
+ }
+ // If the last contract was chunked, mark it as needing healing
+ // to avoid writing it out to disk prematurely.
+ if res.subTask == nil && !res.mainTask.needHeal[j] && i == len(res.hashes)-1 && res.cont {
+ res.mainTask.needHeal[j] = true
+ }
+ // If the last contract was chunked, we need to switch to large
+ // contract handling mode
+ if res.subTask == nil && i == len(res.hashes)-1 && res.cont {
+ // If we haven't yet started a large-contract retrieval, create
+ // the subtasks for it within the main account task
+ if tasks, ok := res.mainTask.SubTasks[account]; !ok {
+ var (
+ keys = res.hashes[i]
+ chunks = uint64(storageConcurrency)
+ lastKey common.Hash
+ )
+ if len(keys) > 0 {
+ lastKey = keys[len(keys)-1]
+ }
+ // If the number of slots remaining is low, decrease the
+ // number of chunks. Somewhere on the order of 10-15K slots
+ // fit into a packet of 500KB. A key/slot pair is maximum 64
+ // bytes, so pessimistically maxRequestSize/64 = 8K.
+ //
+ // Chunk so that at least 2 packets are needed to fill a task.
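+ //
+ // As an illustrative example with a 512KB maxRequestSize: one packet holds
+ // ~8K pessimistic key/slot pairs, so an estimate of 24K remaining slots
+ // yields n = 24K/16K = 1, i.e. 2 chunks instead of storageConcurrency.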
+ if estimate, err := estimateRemainingSlots(len(keys), lastKey); err == nil {
+ if n := estimate / (2 * (maxRequestSize / 64)); n+1 < chunks {
+ chunks = n + 1
+ }
+ log.Debug("Chunked large contract", "initiators", len(keys), "tail", lastKey, "remaining", estimate, "chunks", chunks)
+ } else {
+ log.Debug("Chunked large contract", "initiators", len(keys), "tail", lastKey, "chunks", chunks)
+ }
+ r := newHashRange(lastKey, chunks)
+ if chunks == 1 {
+ smallStorageGauge.Inc(1)
+ } else {
+ largeStorageGauge.Inc(1)
+ }
+ // Our first task is the one that was just filled by this response.
+ batch := ethdb.HookedBatch{
+ Batch: s.db.NewBatch(),
+ OnPut: func(key []byte, value []byte) {
+ s.storageBytes += common.StorageSize(len(key) + len(value))
+ },
+ }
+ owner := account // local assignment for stacktrie writer closure
+ options := trie.NewStackTrieOptions()
+ options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
+ rawdb.WriteTrieNode(batch, owner, path, hash, blob, s.scheme)
+ })
+ if s.scheme == rawdb.PathScheme {
+ options = options.WithCleaner(func(path []byte) {
+ s.cleanPath(batch, owner, path)
+ })
+ // Keep the left boundary as it's the first range.
+ // Skip the right boundary if it's not the last range.
+ options = options.WithSkipBoundary(false, r.End() != common.MaxHash, boundaryStorageNodesGauge)
+ }
+ tasks = append(tasks, &storageTask{
+ Next: common.Hash{},
+ Last: r.End(),
+ root: acc.Root,
+ genBatch: batch,
+ genTrie: trie.NewStackTrie(options),
+ })
+ for r.Next() {
+ batch := ethdb.HookedBatch{
+ Batch: s.db.NewBatch(),
+ OnPut: func(key []byte, value []byte) {
+ s.storageBytes += common.StorageSize(len(key) + len(value))
+ },
+ }
+ options := trie.NewStackTrieOptions()
+ options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
+ rawdb.WriteTrieNode(batch, owner, path, hash, blob, s.scheme)
+ })
+ if s.scheme == rawdb.PathScheme {
+ // Configure the dangling node cleaner and also filter out boundary nodes
+ // only in the context of the path scheme. Deletion is forbidden in the
+ // hash scheme, as it can disrupt state completeness.
+ options = options.WithCleaner(func(path []byte) {
+ s.cleanPath(batch, owner, path)
+ })
+ // Skip the left boundary as it's not the first range
+ // Skip the right boundary if it's not the last range.
+ options = options.WithSkipBoundary(true, r.End() != common.MaxHash, boundaryStorageNodesGauge)
+ }
+ tasks = append(tasks, &storageTask{
+ Next: r.Start(),
+ Last: r.End(),
+ root: acc.Root,
+ genBatch: batch,
+ genTrie: trie.NewStackTrie(options),
+ })
+ }
+ for _, task := range tasks {
+ log.Debug("Created storage sync task", "account", account, "root", acc.Root, "from", task.Next, "last", task.Last)
+ }
+ res.mainTask.SubTasks[account] = tasks
+
+ // Since we've just created the sub-tasks, this response
+ // is surely for the first one (zero origin)
+ res.subTask = tasks[0]
+ }
+ }
+ // If we're in large contract delivery mode, forward the subtask
+ if res.subTask != nil {
+ // Ensure the response doesn't overflow into the subsequent task
+ last := res.subTask.Last.Big()
+ // Find the first overflowing key. While at it, mark res as complete
+ // if we find the range to include or pass the 'last'
+ index := sort.Search(len(res.hashes[i]), func(k int) bool {
+ cmp := res.hashes[i][k].Big().Cmp(last)
+ if cmp >= 0 {
+ res.cont = false
+ }
+ return cmp > 0
+ })
+ if index >= 0 {
+ // cut off excess
+ res.hashes[i] = res.hashes[i][:index]
+ res.slots[i] = res.slots[i][:index]
+ }
+ // Forward the relevant storage chunk (even if created just now)
+ if res.cont {
+ res.subTask.Next = incHash(res.hashes[i][len(res.hashes[i])-1])
+ } else {
+ res.subTask.done = true
+ }
+ }
+ }
+ // Iterate over all the complete contracts, reconstruct the trie nodes and
+ // push them to disk. If the contract is chunked, the trie nodes will be
+ // reconstructed later.
+ slots += len(res.hashes[i])
+
+ if i < len(res.hashes)-1 || res.subTask == nil {
+ // no need to make local reassignment of account: this closure does not outlive the loop
+ options := trie.NewStackTrieOptions()
+ options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
+ rawdb.WriteTrieNode(batch, account, path, hash, blob, s.scheme)
+ })
+ if s.scheme == rawdb.PathScheme {
+ // Configure the dangling node cleaner only in the context of the
+ // path scheme. Deletion is forbidden in the hash scheme, as it can
+ // disrupt state completeness.
+ //
+ // Notably, boundary nodes can be also kept because the whole storage
+ // trie is complete.
+ options = options.WithCleaner(func(path []byte) {
+ s.cleanPath(batch, account, path)
+ })
+ }
+ tr := trie.NewStackTrie(options)
+ for j := 0; j < len(res.hashes[i]); j++ {
+ tr.Update(res.hashes[i][j][:], res.slots[i][j])
+ }
+ tr.Commit()
+ }
+ // Persist the received storage segments. This flat state may be
+ // outdated during the sync, but it can be fixed later during the
+ // snapshot generation.
+ for j := 0; j < len(res.hashes[i]); j++ {
+ rawdb.WriteStorageSnapshot(batch, account, res.hashes[i][j], res.slots[i][j])
+
+ // If we're storing large contracts, generate the trie nodes
+ // on the fly to not trash the gluing points
+ if i == len(res.hashes)-1 && res.subTask != nil {
+ res.subTask.genTrie.Update(res.hashes[i][j][:], res.slots[i][j])
+ }
+ }
+ }
+ // Large contracts could have generated new trie nodes, flush them to disk
+ if res.subTask != nil {
+ if res.subTask.done {
+ root := res.subTask.genTrie.Commit()
+ if err := res.subTask.genBatch.Write(); err != nil {
+ log.Error("Failed to persist stack slots", "err", err)
+ }
+ res.subTask.genBatch.Reset()
+
+ // If the chunk's committed root matches the expected one and the
+ // node is already persisted, clear the heal request.
+ accountHash := res.accounts[len(res.accounts)-1]
+ if root == res.subTask.root && rawdb.HasStorageTrieNode(s.db, accountHash, nil, root) {
+ for i, account := range res.mainTask.res.hashes {
+ if account == accountHash {
+ res.mainTask.needHeal[i] = false
+ skipStorageHealingGauge.Inc(1)
+ }
+ }
+ }
+ }
+ if res.subTask.genBatch.ValueSize() > ethdb.IdealBatchSize {
+ if err := res.subTask.genBatch.Write(); err != nil {
+ log.Error("Failed to persist stack slots", "err", err)
+ }
+ res.subTask.genBatch.Reset()
+ }
+ }
+ // Flush anything written just now and update the stats
+ if err := batch.Write(); err != nil {
+ log.Crit("Failed to persist storage slots", "err", err)
+ }
+ s.storageSynced += uint64(slots)
+
+ log.Debug("Persisted set of storage slots", "accounts", len(res.hashes), "slots", slots, "bytes", s.storageBytes-oldStorageBytes)
+
+ // If this delivery completed the last pending task, forward the account task
+ // to the next chunk
+ if res.mainTask.pend == 0 {
+ s.forwardAccountTask(res.mainTask)
+ return
+ }
+ // Some accounts are still incomplete, leave as is for the storage and contract
+ // task assigners to pick up and fill.
+}
+
+// processTrienodeHealResponse integrates an already validated trienode response
+// into the healer tasks.
+func (s *Syncer) processTrienodeHealResponse(res *trienodeHealResponse) {
+ var (
+ start = time.Now()
+ fills int
+ )
+ for i, hash := range res.hashes {
+ node := res.nodes[i]
+
+ // If the trie node was not delivered, reschedule it
+ if node == nil {
+ res.task.trieTasks[res.paths[i]] = res.hashes[i]
+ continue
+ }
+ fills++
+
+ // Push the trie node into the state syncer
+ s.trienodeHealSynced++
+ s.trienodeHealBytes += common.StorageSize(len(node))
+
+ err := s.healer.scheduler.ProcessNode(trie.NodeSyncResult{Path: res.paths[i], Data: node})
+ switch err {
+ case nil:
+ case trie.ErrAlreadyProcessed:
+ s.trienodeHealDups++
+ case trie.ErrNotRequested:
+ s.trienodeHealNops++
+ default:
+ log.Error("Invalid trienode processed", "hash", hash, "err", err)
+ }
+ }
+ s.commitHealer(false)
+
+ // Calculate the processing rate of one filled trie node
+ rate := float64(fills) / (float64(time.Since(start)) / float64(time.Second))
+
+ // Update the currently measured trienode queueing and processing throughput.
+ //
+ // The processing rate needs to be updated uniformly, independent of whether
+ // we've processed 1x100 trie nodes or 100x1, to keep the rate consistent in
+ // the face of varying network packets. As such, we cannot just measure the
+ // time it took to process N trie nodes and update once, we need one update
+ // per trie node.
+ //
+ // Naively, that would be:
+ //
+ // for i := 0; i < fills; i++ {
+ // healRate = (1-measurementImpact)*oldRate + measurementImpact*newRate
+ // }
+ //
+ // Essentially a recursive expansion of HR = (1-MI)*HR + MI*NR applied fills
+ // times, which collapses into the closed form HR' = (1-MI)^fills*(HR-NR) + NR.
+ s.trienodeHealRate = gomath.Pow(1-trienodeHealRateMeasurementImpact, float64(fills))*(s.trienodeHealRate-rate) + rate
+
+ pending := s.trienodeHealPend.Load()
+ if time.Since(s.trienodeHealThrottled) > time.Second {
+ // Periodically adjust the trie node throttler
+ if float64(pending) > 2*s.trienodeHealRate {
+ s.trienodeHealThrottle *= trienodeHealThrottleIncrease
+ } else {
+ s.trienodeHealThrottle /= trienodeHealThrottleDecrease
+ }
+ if s.trienodeHealThrottle > maxTrienodeHealThrottle {
+ s.trienodeHealThrottle = maxTrienodeHealThrottle
+ } else if s.trienodeHealThrottle < minTrienodeHealThrottle {
+ s.trienodeHealThrottle = minTrienodeHealThrottle
+ }
+ s.trienodeHealThrottled = time.Now()
+
+ log.Debug("Updated trie node heal throttler", "rate", s.trienodeHealRate, "pending", pending, "throttle", s.trienodeHealThrottle)
+ }
+}
+
+func (s *Syncer) commitHealer(force bool) {
+ if !force && s.healer.scheduler.MemSize() < ethdb.IdealBatchSize {
+ return
+ }
+ batch := s.db.NewBatch()
+ if err := s.healer.scheduler.Commit(batch); err != nil {
+ log.Error("Failed to commit healing data", "err", err)
+ }
+ if err := batch.Write(); err != nil {
+ log.Crit("Failed to persist healing data", "err", err)
+ }
+ log.Debug("Persisted set of healing data", "type", "trienodes", "bytes", common.StorageSize(batch.ValueSize()))
+}
+
+// processBytecodeHealResponse integrates an already validated bytecode response
+// into the healer tasks.
+func (s *Syncer) processBytecodeHealResponse(res *bytecodeHealResponse) {
+ for i, hash := range res.hashes {
+ node := res.codes[i]
+
+ // If the bytecode was not delivered, reschedule it
+ if node == nil {
+ res.task.codeTasks[hash] = struct{}{}
+ continue
+ }
+ // Push the bytecode into the state syncer
+ s.bytecodeHealSynced++
+ s.bytecodeHealBytes += common.StorageSize(len(node))
+
+ err := s.healer.scheduler.ProcessCode(trie.CodeSyncResult{Hash: hash, Data: node})
+ switch err {
+ case nil:
+ case trie.ErrAlreadyProcessed:
+ s.bytecodeHealDups++
+ case trie.ErrNotRequested:
+ s.bytecodeHealNops++
+ default:
+ log.Error("Invalid bytecode processed", "hash", hash, "err", err)
+ }
+ }
+ s.commitHealer(false)
+}
+
+// forwardAccountTask takes a filled account task and persists anything available
+// into the database, after which it forwards the next account marker so that the
+// task's next chunk may be filled.
+func (s *Syncer) forwardAccountTask(task *accountTask) {
+ // Remove any pending delivery
+ res := task.res
+ if res == nil {
+ return // nothing to forward
+ }
+ task.res = nil
+
+ // Persist the received account segments. This flat state may be
+ // outdated during the sync, but it can be fixed later during the
+ // snapshot generation.
+ oldAccountBytes := s.accountBytes
+
+ batch := ethdb.HookedBatch{
+ Batch: s.db.NewBatch(),
+ OnPut: func(key []byte, value []byte) {
+ s.accountBytes += common.StorageSize(len(key) + len(value))
+ },
+ }
+ for i, hash := range res.hashes {
+ if task.needCode[i] || task.needState[i] {
+ break
+ }
+ slim := types.SlimAccountRLP(*res.accounts[i])
+ rawdb.WriteAccountSnapshot(batch, hash, slim)
+
+ // If the task is complete, drop it into the stack trie to generate
+ // account trie nodes for it
+ if !task.needHeal[i] {
+ full, err := types.FullAccountRLP(slim) // TODO(karalabe): Slim parsing can be omitted
+ if err != nil {
+ panic(err) // Really shouldn't ever happen
+ }
+ task.genTrie.Update(hash[:], full)
+ }
+ }
+ // Flush anything written just now and update the stats
+ if err := batch.Write(); err != nil {
+ log.Crit("Failed to persist accounts", "err", err)
+ }
+ s.accountSynced += uint64(len(res.accounts))
+
+ // Task filling persisted, push the chunk marker forward to the first
+ // account still missing data.
+ for i, hash := range res.hashes {
+ if task.needCode[i] || task.needState[i] {
+ return
+ }
+ task.Next = incHash(hash)
+ }
+ // All accounts marked as complete, track if the entire task is done
+ task.done = !res.cont
+
+ // Stack trie could have generated trie nodes, push them to disk (we need to
+ // flush after finalizing task.done). It's fine even if we crash and lose this
+ // write, as it will only cause more data to be downloaded during heal.
+ if task.done {
+ task.genTrie.Commit()
+ }
+ if task.genBatch.ValueSize() > ethdb.IdealBatchSize || task.done {
+ if err := task.genBatch.Write(); err != nil {
+ log.Error("Failed to persist stack account", "err", err)
+ }
+ task.genBatch.Reset()
+ }
+ log.Debug("Persisted range of accounts", "accounts", len(res.accounts), "bytes", s.accountBytes-oldAccountBytes)
+}
+
+// OnAccounts is a callback method to invoke when a range of accounts are
+// received from a remote peer.
+func (s *Syncer) OnAccounts(peer SyncPeer, id uint64, hashes []common.Hash, accounts [][]byte, proof [][]byte) error {
+ size := common.StorageSize(len(hashes) * common.HashLength)
+ for _, account := range accounts {
+ size += common.StorageSize(len(account))
+ }
+ for _, node := range proof {
+ size += common.StorageSize(len(node))
+ }
+ logger := peer.Log().New("reqid", id)
+ logger.Trace("Delivering range of accounts", "hashes", len(hashes), "accounts", len(accounts), "proofs", len(proof), "bytes", size)
+
+ // Whether or not the response is valid, we can mark the peer as idle and
+ // notify the scheduler to assign a new task. If the response is invalid,
+ // we'll drop the peer in a bit.
+ defer func() {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ if _, ok := s.peers[peer.ID()]; ok {
+ s.accountIdlers[peer.ID()] = struct{}{}
+ }
+ select {
+ case s.update <- struct{}{}:
+ default:
+ }
+ }()
+ s.lock.Lock()
+ // Ensure the response is for a valid request
+ req, ok := s.accountReqs[id]
+ if !ok {
+ // Request stale, perhaps the peer timed out but came through in the end
+ logger.Warn("Unexpected account range packet")
+ s.lock.Unlock()
+ return nil
+ }
+ delete(s.accountReqs, id)
+ s.rates.Update(peer.ID(), AccountRangeMsg, time.Since(req.time), int(size))
+
+ // Clean up the request timeout timer, we'll see how to proceed further based
+ // on the actual delivered content
+ if !req.timeout.Stop() {
+ // The timeout is already triggered, and this request will be reverted+rescheduled
+ s.lock.Unlock()
+ return nil
+ }
+ // Response is valid, but check if peer is signalling that it does not have
+ // the requested data. For account range queries that means the state being
+ // retrieved was either already pruned remotely, or the peer is not yet
+ // synced to our head.
+ if len(hashes) == 0 && len(accounts) == 0 && len(proof) == 0 {
+ logger.Debug("Peer rejected account range request", "root", s.root)
+ s.statelessPeers[peer.ID()] = struct{}{}
+ s.lock.Unlock()
+
+ // Signal this request as failed, and ready for rescheduling
+ s.scheduleRevertAccountRequest(req)
+ return nil
+ }
+ root := s.root
+ s.lock.Unlock()
+
+ // Reconstruct a partial trie from the response and verify it
+ keys := make([][]byte, len(hashes))
+ for i, key := range hashes {
+ keys[i] = common.CopyBytes(key[:])
+ }
+ nodes := make(trienode.ProofList, len(proof))
+ for i, node := range proof {
+ nodes[i] = node
+ }
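+ // The continuation flag returned by the range proof signals whether more
+ // accounts exist in the trie beyond the proven range.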
+ cont, err := trie.VerifyRangeProof(root, req.origin[:], keys, accounts, nodes.Set())
+ if err != nil {
+ logger.Warn("Account range failed proof", "err", err)
+ // Signal this request as failed, and ready for rescheduling
+ s.scheduleRevertAccountRequest(req)
+ return err
+ }
+ accs := make([]*types.StateAccount, len(accounts))
+ for i, account := range accounts {
+ acc := new(types.StateAccount)
+ if err := rlp.DecodeBytes(account, acc); err != nil {
+ panic(err) // We created these blobs, we must be able to decode them
+ }
+ accs[i] = acc
+ }
+ response := &accountResponse{
+ task: req.task,
+ hashes: hashes,
+ accounts: accs,
+ cont: cont,
+ }
+ select {
+ case req.deliver <- response:
+ case <-req.cancel:
+ case <-req.stale:
+ }
+ return nil
+}
+
+// OnByteCodes is a callback method to invoke when a batch of contract
+ // bytecodes is received from a remote peer.
+func (s *Syncer) OnByteCodes(peer SyncPeer, id uint64, bytecodes [][]byte) error {
+ s.lock.RLock()
+ syncing := !s.snapped
+ s.lock.RUnlock()
+
+ if syncing {
+ return s.onByteCodes(peer, id, bytecodes)
+ }
+ return s.onHealByteCodes(peer, id, bytecodes)
+}
+
+// onByteCodes is a callback method to invoke when a batch of contract
+ // bytecodes is received from a remote peer in the syncing phase.
+func (s *Syncer) onByteCodes(peer SyncPeer, id uint64, bytecodes [][]byte) error {
+ var size common.StorageSize
+ for _, code := range bytecodes {
+ size += common.StorageSize(len(code))
+ }
+ logger := peer.Log().New("reqid", id)
+ logger.Trace("Delivering set of bytecodes", "bytecodes", len(bytecodes), "bytes", size)
+
+ // Whether or not the response is valid, we can mark the peer as idle and
+ // notify the scheduler to assign a new task. If the response is invalid,
+ // we'll drop the peer in a bit.
+ defer func() {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ if _, ok := s.peers[peer.ID()]; ok {
+ s.bytecodeIdlers[peer.ID()] = struct{}{}
+ }
+ select {
+ case s.update <- struct{}{}:
+ default:
+ }
+ }()
+ s.lock.Lock()
+ // Ensure the response is for a valid request
+ req, ok := s.bytecodeReqs[id]
+ if !ok {
+ // Request stale, perhaps the peer timed out but came through in the end
+ logger.Warn("Unexpected bytecode packet")
+ s.lock.Unlock()
+ return nil
+ }
+ delete(s.bytecodeReqs, id)
+ s.rates.Update(peer.ID(), ByteCodesMsg, time.Since(req.time), len(bytecodes))
+
+ // Clean up the request timeout timer, we'll see how to proceed further based
+ // on the actual delivered content
+ if !req.timeout.Stop() {
+ // The timeout is already triggered, and this request will be reverted+rescheduled
+ s.lock.Unlock()
+ return nil
+ }
+
+ // Response is valid, but check if peer is signalling that it does not have
+ // the requested data. For bytecode range queries that means the peer is not
+ // yet synced.
+ if len(bytecodes) == 0 {
+ logger.Debug("Peer rejected bytecode request")
+ s.statelessPeers[peer.ID()] = struct{}{}
+ s.lock.Unlock()
+
+ // Signal this request as failed, and ready for rescheduling
+ s.scheduleRevertBytecodeRequest(req)
+ return nil
+ }
+ s.lock.Unlock()
+
+ // Cross reference the requested bytecodes with the response to find gaps
+ // that the serving node is missing
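+ // Responses must be a subsequence of the requested hashes, so a single
+ // monotonic cursor (j) over req.hashes suffices; missed entries stay nil.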
+ hasher := sha3.NewLegacyKeccak256().(crypto.KeccakState)
+ hash := make([]byte, 32)
+
+ codes := make([][]byte, len(req.hashes))
+ for i, j := 0, 0; i < len(bytecodes); i++ {
+ // Find the next hash that we've been served, leaving misses with nils
+ hasher.Reset()
+ hasher.Write(bytecodes[i])
+ hasher.Read(hash)
+
+ for j < len(req.hashes) && !bytes.Equal(hash, req.hashes[j][:]) {
+ j++
+ }
+ if j < len(req.hashes) {
+ codes[j] = bytecodes[i]
+ j++
+ continue
+ }
+ // We've either run out of hashes, or got unrequested data
+ logger.Warn("Unexpected bytecodes", "count", len(bytecodes)-i)
+ // Signal this request as failed, and ready for rescheduling
+ s.scheduleRevertBytecodeRequest(req)
+ return errors.New("unexpected bytecode")
+ }
+ // Response validated, send it to the scheduler for filling
+ response := &bytecodeResponse{
+ task: req.task,
+ hashes: req.hashes,
+ codes: codes,
+ }
+ select {
+ case req.deliver <- response:
+ case <-req.cancel:
+ case <-req.stale:
+ }
+ return nil
+}
+
+// OnStorage is a callback method to invoke when ranges of storage slots
+// are received from a remote peer.
+func (s *Syncer) OnStorage(peer SyncPeer, id uint64, hashes [][]common.Hash, slots [][][]byte, proof [][]byte) error {
+ // Gather some trace stats to aid in debugging issues
+ var (
+ hashCount int
+ slotCount int
+ size common.StorageSize
+ )
+ for _, hashset := range hashes {
+ size += common.StorageSize(common.HashLength * len(hashset))
+ hashCount += len(hashset)
+ }
+ for _, slotset := range slots {
+ for _, slot := range slotset {
+ size += common.StorageSize(len(slot))
+ }
+ slotCount += len(slotset)
+ }
+ for _, node := range proof {
+ size += common.StorageSize(len(node))
+ }
+ logger := peer.Log().New("reqid", id)
+ logger.Trace("Delivering ranges of storage slots", "accounts", len(hashes), "hashes", hashCount, "slots", slotCount, "proofs", len(proof), "size", size)
+
+ // Whether or not the response is valid, we can mark the peer as idle and
+ // notify the scheduler to assign a new task. If the response is invalid,
+ // we'll drop the peer in a bit.
+ defer func() {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ if _, ok := s.peers[peer.ID()]; ok {
+ s.storageIdlers[peer.ID()] = struct{}{}
+ }
+ select {
+ case s.update <- struct{}{}:
+ default:
+ }
+ }()
+ s.lock.Lock()
+ // Ensure the response is for a valid request
+ req, ok := s.storageReqs[id]
+ if !ok {
+ // Request stale, perhaps the peer timed out but came through in the end
+ logger.Warn("Unexpected storage ranges packet")
+ s.lock.Unlock()
+ return nil
+ }
+ delete(s.storageReqs, id)
+ s.rates.Update(peer.ID(), StorageRangesMsg, time.Since(req.time), int(size))
+
+ // Clean up the request timeout timer, we'll see how to proceed further based
+ // on the actual delivered content
+ if !req.timeout.Stop() {
+ // The timeout is already triggered, and this request will be reverted+rescheduled
+ s.lock.Unlock()
+ return nil
+ }
+
+ // Reject the response if the hash sets and slot sets don't match, or if the
+ // peer sent more data than requested.
+ if len(hashes) != len(slots) {
+ s.lock.Unlock()
+ s.scheduleRevertStorageRequest(req) // reschedule request
+ logger.Warn("Hash and slot set size mismatch", "hashset", len(hashes), "slotset", len(slots))
+ return errors.New("hash and slot set size mismatch")
+ }
+ if len(hashes) > len(req.accounts) {
+ s.lock.Unlock()
+ s.scheduleRevertStorageRequest(req) // reschedule request
+ logger.Warn("Hash set larger than requested", "hashset", len(hashes), "requested", len(req.accounts))
+ return errors.New("hash set larger than requested")
+ }
+ // Response is valid, but check if peer is signalling that it does not have
+ // the requested data. For storage range queries that means the state being
+ // retrieved was either already pruned remotely, or the peer is not yet
+ // synced to our head.
+ if len(hashes) == 0 && len(proof) == 0 {
+ logger.Debug("Peer rejected storage request")
+ s.statelessPeers[peer.ID()] = struct{}{}
+ s.lock.Unlock()
+ s.scheduleRevertStorageRequest(req) // reschedule request
+ return nil
+ }
+ s.lock.Unlock()
+
+ // Reconstruct the partial tries from the response and verify them
+ var cont bool
+
+ // If a proof was attached while the response is empty, it indicates that the
+ // requested range specified with 'origin' is empty. Construct an empty state
+ // response locally to finalize the range.
+ if len(hashes) == 0 && len(proof) > 0 {
+ hashes = append(hashes, []common.Hash{})
+ slots = append(slots, [][]byte{})
+ }
+ for i := 0; i < len(hashes); i++ {
+ // Convert the keys and proofs into an internal format
+ keys := make([][]byte, len(hashes[i]))
+ for j, key := range hashes[i] {
+ keys[j] = common.CopyBytes(key[:])
+ }
+ nodes := make(trienode.ProofList, 0, len(proof))
+ if i == len(hashes)-1 {
+ for _, node := range proof {
+ nodes = append(nodes, node)
+ }
+ }
+ var err error
+ if len(nodes) == 0 {
+ // No proof has been attached, the response must cover the entire key
+ // space and hash to the origin root.
+ _, err = trie.VerifyRangeProof(req.roots[i], nil, keys, slots[i], nil)
+ if err != nil {
+ s.scheduleRevertStorageRequest(req) // reschedule request
+ logger.Warn("Storage slots failed proof", "err", err)
+ return err
+ }
+ } else {
+ // A proof was attached, the response is only partial, check that the
+ // returned data is indeed part of the storage trie
+ proofdb := nodes.Set()
+
+ cont, err = trie.VerifyRangeProof(req.roots[i], req.origin[:], keys, slots[i], proofdb)
+ if err != nil {
+ s.scheduleRevertStorageRequest(req) // reschedule request
+ logger.Warn("Storage range failed proof", "err", err)
+ return err
+ }
+ }
+ }
+ // Partial tries reconstructed, send them to the scheduler for storage filling
+ response := &storageResponse{
+ mainTask: req.mainTask,
+ subTask: req.subTask,
+ accounts: req.accounts,
+ roots: req.roots,
+ hashes: hashes,
+ slots: slots,
+ cont: cont,
+ }
+ select {
+ case req.deliver <- response:
+ case <-req.cancel:
+ case <-req.stale:
+ }
+ return nil
+}
+
+// OnTrieNodes is a callback method to invoke when a batch of trie nodes
+// are received from a remote peer.
+func (s *Syncer) OnTrieNodes(peer SyncPeer, id uint64, trienodes [][]byte) error {
+ var size common.StorageSize
+ for _, node := range trienodes {
+ size += common.StorageSize(len(node))
+ }
+ logger := peer.Log().New("reqid", id)
+ logger.Trace("Delivering set of healing trienodes", "trienodes", len(trienodes), "bytes", size)
+
+ // Whether or not the response is valid, we can mark the peer as idle and
+ // notify the scheduler to assign a new task. If the response is invalid,
+ // we'll drop the peer in a bit.
+ defer func() {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ if _, ok := s.peers[peer.ID()]; ok {
+ s.trienodeHealIdlers[peer.ID()] = struct{}{}
+ }
+ select {
+ case s.update <- struct{}{}:
+ default:
+ }
+ }()
+ s.lock.Lock()
+ // Ensure the response is for a valid request
+ req, ok := s.trienodeHealReqs[id]
+ if !ok {
+ // Request stale, perhaps the peer timed out but came through in the end
+ logger.Warn("Unexpected trienode heal packet")
+ s.lock.Unlock()
+ return nil
+ }
+ delete(s.trienodeHealReqs, id)
+ s.rates.Update(peer.ID(), TrieNodesMsg, time.Since(req.time), len(trienodes))
+
+ // Clean up the request timeout timer, we'll see how to proceed further based
+ // on the actual delivered content
+ if !req.timeout.Stop() {
+ // The timeout is already triggered, and this request will be reverted+rescheduled
+ s.lock.Unlock()
+ return nil
+ }
+
+ // Response is valid, but check if peer is signalling that it does not have
+ // the requested data. For trie node queries that means the peer is not
+ // yet synced.
+ if len(trienodes) == 0 {
+ logger.Debug("Peer rejected trienode heal request")
+ s.statelessPeers[peer.ID()] = struct{}{}
+ s.lock.Unlock()
+
+ // Signal this request as failed, and ready for rescheduling
+ s.scheduleRevertTrienodeHealRequest(req)
+ return nil
+ }
+ s.lock.Unlock()
+
+ // Cross reference the requested trienodes with the response to find gaps
+ // that the serving node is missing
+ var (
+ hasher = sha3.NewLegacyKeccak256().(crypto.KeccakState)
+ hash = make([]byte, 32)
+ nodes = make([][]byte, len(req.hashes))
+ fills uint64
+ )
+ for i, j := 0, 0; i < len(trienodes); i++ {
+ // Find the next hash that we've been served, leaving misses with nils
+ hasher.Reset()
+ hasher.Write(trienodes[i])
+ hasher.Read(hash)
+
+ for j < len(req.hashes) && !bytes.Equal(hash, req.hashes[j][:]) {
+ j++
+ }
+ if j < len(req.hashes) {
+ nodes[j] = trienodes[i]
+ fills++
+ j++
+ continue
+ }
+ // We've either run out of hashes, or got unrequested data
+ logger.Warn("Unexpected healing trienodes", "count", len(trienodes)-i)
+
+ // Signal this request as failed, and ready for rescheduling
+ s.scheduleRevertTrienodeHealRequest(req)
+ return errors.New("unexpected healing trienode")
+ }
+ // Response validated, send it to the scheduler for filling
+ s.trienodeHealPend.Add(fills)
+ defer func() {
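+ // Adding ^(fills-1) is two's-complement arithmetic for subtracting fills
+ // from the unsigned atomic counter.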
+ s.trienodeHealPend.Add(^(fills - 1))
+ }()
+ response := &trienodeHealResponse{
+ paths: req.paths,
+ task: req.task,
+ hashes: req.hashes,
+ nodes: nodes,
+ }
+ select {
+ case req.deliver <- response:
+ case <-req.cancel:
+ case <-req.stale:
+ }
+ return nil
+}
+
+// onHealByteCodes is a callback method to invoke when a batch of contract
+ // bytecodes is received from a remote peer in the healing phase.
+func (s *Syncer) onHealByteCodes(peer SyncPeer, id uint64, bytecodes [][]byte) error {
+ var size common.StorageSize
+ for _, code := range bytecodes {
+ size += common.StorageSize(len(code))
+ }
+ logger := peer.Log().New("reqid", id)
+ logger.Trace("Delivering set of healing bytecodes", "bytecodes", len(bytecodes), "bytes", size)
+
+ // Whether or not the response is valid, we can mark the peer as idle and
+ // notify the scheduler to assign a new task. If the response is invalid,
+ // we'll drop the peer in a bit.
+ defer func() {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ if _, ok := s.peers[peer.ID()]; ok {
+ s.bytecodeHealIdlers[peer.ID()] = struct{}{}
+ }
+ select {
+ case s.update <- struct{}{}:
+ default:
+ }
+ }()
+ s.lock.Lock()
+ // Ensure the response is for a valid request
+ req, ok := s.bytecodeHealReqs[id]
+ if !ok {
+ // Request stale, perhaps the peer timed out but came through in the end
+ logger.Warn("Unexpected bytecode heal packet")
+ s.lock.Unlock()
+ return nil
+ }
+ delete(s.bytecodeHealReqs, id)
+ s.rates.Update(peer.ID(), ByteCodesMsg, time.Since(req.time), len(bytecodes))
+
+ // Clean up the request timeout timer, we'll see how to proceed further based
+ // on the actual delivered content
+ if !req.timeout.Stop() {
+ // The timeout is already triggered, and this request will be reverted+rescheduled
+ s.lock.Unlock()
+ return nil
+ }
+
+ // Response is valid, but check if peer is signalling that it does not have
+ // the requested data. For bytecode range queries that means the peer is not
+ // yet synced.
+ if len(bytecodes) == 0 {
+ logger.Debug("Peer rejected bytecode heal request")
+ s.statelessPeers[peer.ID()] = struct{}{}
+ s.lock.Unlock()
+
+ // Signal this request as failed, and ready for rescheduling
+ s.scheduleRevertBytecodeHealRequest(req)
+ return nil
+ }
+ s.lock.Unlock()
+
+ // Cross reference the requested bytecodes with the response to find gaps
+ // that the serving node is missing
+ hasher := sha3.NewLegacyKeccak256().(crypto.KeccakState)
+ hash := make([]byte, 32)
+
+ codes := make([][]byte, len(req.hashes))
+ for i, j := 0, 0; i < len(bytecodes); i++ {
+ // Find the next hash that we've been served, leaving misses with nils
+ hasher.Reset()
+ hasher.Write(bytecodes[i])
+ hasher.Read(hash)
+
+ for j < len(req.hashes) && !bytes.Equal(hash, req.hashes[j][:]) {
+ j++
+ }
+ if j < len(req.hashes) {
+ codes[j] = bytecodes[i]
+ j++
+ continue
+ }
+ // We've either run out of hashes, or got unrequested data
+ logger.Warn("Unexpected healing bytecodes", "count", len(bytecodes)-i)
+ // Signal this request as failed, and ready for rescheduling
+ s.scheduleRevertBytecodeHealRequest(req)
+ return errors.New("unexpected healing bytecode")
+ }
+ // Response validated, send it to the scheduler for filling
+ response := &bytecodeHealResponse{
+ task: req.task,
+ hashes: req.hashes,
+ codes: codes,
+ }
+ select {
+ case req.deliver <- response:
+ case <-req.cancel:
+ case <-req.stale:
+ }
+ return nil
+}
+
+ // onHealState is a callback method to invoke when a flat state (account
+ // or storage slot) is downloaded during the healing stage. The flat states
+ // can be persisted blindly and fixed up later in the generation stage.
+ // Note it is not concurrency-safe; callers must handle synchronization.
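+ // A single-element path identifies an account (paths[0] is the account
+ // hash); a two-element path identifies a storage slot (paths[0] is the
+ // account hash, paths[1] the slot hash).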
+func (s *Syncer) onHealState(paths [][]byte, value []byte) error {
+ if len(paths) == 1 {
+ var account types.StateAccount
+ if err := rlp.DecodeBytes(value, &account); err != nil {
+ return nil // Returning the error here would drop the remote peer
+ }
+ blob := types.SlimAccountRLP(account)
+ rawdb.WriteAccountSnapshot(s.stateWriter, common.BytesToHash(paths[0]), blob)
+ s.accountHealed += 1
+ s.accountHealedBytes += common.StorageSize(1 + common.HashLength + len(blob))
+ }
+ if len(paths) == 2 {
+ rawdb.WriteStorageSnapshot(s.stateWriter, common.BytesToHash(paths[0]), common.BytesToHash(paths[1]), value)
+ s.storageHealed += 1
+ s.storageHealedBytes += common.StorageSize(1 + 2*common.HashLength + len(value))
+ }
+ if s.stateWriter.ValueSize() > ethdb.IdealBatchSize {
+ s.stateWriter.Write() // It's fine to ignore the error here
+ s.stateWriter.Reset()
+ }
+ return nil
+}
+
+// hashSpace is the total size of the 256 bit hash space for accounts.
+var hashSpace = new(big.Int).Exp(common.Big2, common.Big256, nil)
+
+ // report calculates various status reports and provides them to the user.
+func (s *Syncer) report(force bool) {
+ if len(s.tasks) > 0 {
+ s.reportSyncProgress(force)
+ return
+ }
+ s.reportHealProgress(force)
+}
+
+ // reportSyncProgress calculates various status reports and provides them to the user.
+func (s *Syncer) reportSyncProgress(force bool) {
+ // Don't report all the events, just occasionally
+ if !force && time.Since(s.logTime) < 8*time.Second {
+ return
+ }
+ // Don't report anything until we have a meaningful progress
+ synced := s.accountBytes + s.bytecodeBytes + s.storageBytes
+ if synced == 0 {
+ return
+ }
+ accountGaps := new(big.Int)
+ for _, task := range s.tasks {
+ accountGaps.Add(accountGaps, new(big.Int).Sub(task.Last.Big(), task.Next.Big()))
+ }
+ accountFills := new(big.Int).Sub(hashSpace, accountGaps)
+ if accountFills.BitLen() == 0 {
+ return
+ }
+ s.logTime = time.Now()
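+ // Extrapolate the total state size from the bytes synced so far and the
+ // fraction of the account hash space already covered: e.g. (illustrative)
+ // 2GB synced over 10% coverage suggests ~20GB of total state. This assumes
+ // accounts are uniformly distributed across the hash space.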
+ estBytes := float64(new(big.Int).Div(
+ new(big.Int).Mul(new(big.Int).SetUint64(uint64(synced)), hashSpace),
+ accountFills,
+ ).Uint64())
+ // Don't report anything until we have meaningful progress
+ if estBytes < 1.0 {
+ return
+ }
+ elapsed := time.Since(s.startTime)
+ estTime := elapsed / time.Duration(synced) * time.Duration(estBytes)
+
+ // Create a mega progress report
+ var (
+ progress = fmt.Sprintf("%.2f%%", float64(synced)*100/estBytes)
+ accounts = fmt.Sprintf("%v@%v", log.FormatLogfmtUint64(s.accountSynced), s.accountBytes.TerminalString())
+ storage = fmt.Sprintf("%v@%v", log.FormatLogfmtUint64(s.storageSynced), s.storageBytes.TerminalString())
+ bytecode = fmt.Sprintf("%v@%v", log.FormatLogfmtUint64(s.bytecodeSynced), s.bytecodeBytes.TerminalString())
+ )
+ log.Info("Syncing: state download in progress", "synced", progress, "state", synced,
+ "accounts", accounts, "slots", storage, "codes", bytecode, "eta", common.PrettyDuration(estTime-elapsed))
+}
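+
+// The estimate above extrapolates the total state size from hash-space
+// coverage: with accountFills/hashSpace of the 2^256 account space covered
+// and `synced` bytes downloaded so far, the projected total is
+// estBytes = synced * hashSpace / accountFills. As a worked example with
+// illustrative numbers: 1 GiB synced while a quarter of the hash space is
+// filled projects estBytes = 4 GiB, a reported progress of 25.00%, and an
+// ETA of three times the already elapsed duration.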
+
+// reportHealProgress calculates various healing status reports and provides them to the user.
+func (s *Syncer) reportHealProgress(force bool) {
+ // Don't report all the events, just occasionally
+ if !force && time.Since(s.logTime) < 8*time.Second {
+ return
+ }
+ s.logTime = time.Now()
+
+ // Create a mega progress report
+ var (
+ trienode = fmt.Sprintf("%v@%v", log.FormatLogfmtUint64(s.trienodeHealSynced), s.trienodeHealBytes.TerminalString())
+ bytecode = fmt.Sprintf("%v@%v", log.FormatLogfmtUint64(s.bytecodeHealSynced), s.bytecodeHealBytes.TerminalString())
+ accounts = fmt.Sprintf("%v@%v", log.FormatLogfmtUint64(s.accountHealed), s.accountHealedBytes.TerminalString())
+ storage = fmt.Sprintf("%v@%v", log.FormatLogfmtUint64(s.storageHealed), s.storageHealedBytes.TerminalString())
+ )
+ log.Info("Syncing: state healing in progress", "accounts", accounts, "slots", storage,
+ "codes", bytecode, "nodes", trienode, "pending", s.healer.scheduler.Pending())
+}
+
+// estimateRemainingSlots tries to determine roughly how many slots are left in
+// a contract storage, based on the number of keys and the last hash. This method
+// assumes that the hashes are lexicographically ordered and evenly distributed.
+func estimateRemainingSlots(hashes int, last common.Hash) (uint64, error) {
+ if last == (common.Hash{}) {
+ return 0, errors.New("last hash empty")
+ }
+ space := new(big.Int).Mul(math.MaxBig256, big.NewInt(int64(hashes)))
+ space.Div(space, last.Big())
+ if !space.IsUint64() {
+ // Gigantic address space probably due to too few or malicious slots
+ return 0, errors.New("too few slots for estimation")
+ }
+ return space.Uint64() - uint64(hashes), nil
+}
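+
+// Worked example for the estimator above (illustrative numbers): 128 slots
+// whose last hash sits at 2^254, a quarter of the hash space, extrapolate to
+// roughly 128 * 4 = 512 slots in total; the integer arithmetic yields
+// space = 128 * (2^256 - 1) / 2^254 = 511, so 511 - 128 = 383 remaining.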
+
+// capacitySort implements the Sort interface, allowing sorting by peer message
+// throughput. Note, callers should use sort.Reverse to get the desired effect
+// of highest capacity being at the front.
+type capacitySort struct {
+ ids []string
+ caps []int
+}
+
+func (s *capacitySort) Len() int {
+ return len(s.ids)
+}
+
+func (s *capacitySort) Less(i, j int) bool {
+ return s.caps[i] < s.caps[j]
+}
+
+func (s *capacitySort) Swap(i, j int) {
+ s.ids[i], s.ids[j] = s.ids[j], s.ids[i]
+ s.caps[i], s.caps[j] = s.caps[j], s.caps[i]
+}
+
+// healRequestSort implements the Sort interface, allowing trienode heal
+// requests to be sorted, which is a prerequisite for merging storage requests.
+type healRequestSort struct {
+ paths []string
+ hashes []common.Hash
+ syncPaths []trie.SyncPath
+}
+
+func (t *healRequestSort) Len() int {
+ return len(t.hashes)
+}
+
+func (t *healRequestSort) Less(i, j int) bool {
+ a := t.syncPaths[i]
+ b := t.syncPaths[j]
+ switch bytes.Compare(a[0], b[0]) {
+ case -1:
+ return true
+ case 1:
+ return false
+ }
+ // identical first part
+ if len(a) < len(b) {
+ return true
+ }
+ if len(b) < len(a) {
+ return false
+ }
+ if len(a) == 2 {
+ return bytes.Compare(a[1], b[1]) < 0
+ }
+ return false
+}
+
+func (t *healRequestSort) Swap(i, j int) {
+ t.paths[i], t.paths[j] = t.paths[j], t.paths[i]
+ t.hashes[i], t.hashes[j] = t.hashes[j], t.hashes[i]
+ t.syncPaths[i], t.syncPaths[j] = t.syncPaths[j], t.syncPaths[i]
+}
+
+// Merge merges the pathsets, so that several storage requests concerning the
+// same account are merged into one, to reduce bandwidth.
+// Note: this operation is moot if t has not first been sorted.
+func (t *healRequestSort) Merge() []TrieNodePathSet {
+ var result []TrieNodePathSet
+ for _, path := range t.syncPaths {
+ pathset := TrieNodePathSet(path)
+ if len(path) == 1 {
+ // It's an account reference.
+ result = append(result, pathset)
+ } else {
+ // It's a storage reference.
+ end := len(result) - 1
+ if len(result) == 0 || !bytes.Equal(pathset[0], result[end][0]) {
+ // The account doesn't match last, create a new entry.
+ result = append(result, pathset)
+ } else {
+ // It's the same account as the previous one, add to the storage
+ // paths of that request.
+ result[end] = append(result[end], pathset[1])
+ }
+ }
+ }
+ return result
+}
+
+// sortByAccountPath takes hashes and paths, and sorts them. After that, it generates
+// the TrieNodePaths and merges paths which belong to the same account path.
+func sortByAccountPath(paths []string, hashes []common.Hash) ([]string, []common.Hash, []trie.SyncPath, []TrieNodePathSet) {
+ var syncPaths []trie.SyncPath
+ for _, path := range paths {
+ syncPaths = append(syncPaths, trie.NewSyncPath([]byte(path)))
+ }
+ n := &healRequestSort{paths, hashes, syncPaths}
+ sort.Sort(n)
+ pathsets := n.Merge()
+ return n.paths, n.hashes, n.syncPaths, pathsets
+}
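+
+// Taken together, the sort and merge reduce wire overhead: heal requests that
+// share an account prefix collapse into a single TrieNodePathSet. An
+// illustrative sketch (pA1, pA2 and pB1 are hand-written sync paths, not real
+// trie data):
+//
+//	// Two trienode requests under account A and one under account B...
+//	paths := []string{string(pA1), string(pA2), string(pB1)}
+//	_, _, _, sets := sortByAccountPath(paths, hashes)
+//	// ...collapse into two path sets, {A, a1, a2} and {B, b1}, so the
+//	// account path A crosses the wire once instead of twice.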
diff --git a/eth/protocols/snap/sync_test.go b/eth/protocols/snap/sync_test.go
new file mode 100644
index 0000000000..056a6bfa1a
--- /dev/null
+++ b/eth/protocols/snap/sync_test.go
@@ -0,0 +1,1971 @@
+// Copyright 2021 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "bytes"
+ "crypto/rand"
+ "encoding/binary"
+ "fmt"
+ "math/big"
+ mrand "math/rand"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/ava-labs/coreth/core/rawdb"
+ "github.com/ava-labs/coreth/triedb/pathdb"
+ "github.com/ava-labs/libevm/common"
+ "github.com/ava-labs/libevm/core/types"
+ "github.com/ava-labs/libevm/crypto"
+ "github.com/ava-labs/libevm/ethdb"
+ "github.com/ava-labs/libevm/log"
+ "github.com/ava-labs/libevm/rlp"
+ "github.com/ava-labs/libevm/trie"
+ "github.com/ava-labs/libevm/trie/testutil"
+ "github.com/ava-labs/libevm/trie/trienode"
+ "github.com/ava-labs/libevm/triedb"
+ "github.com/holiman/uint256"
+ "golang.org/x/crypto/sha3"
+ "golang.org/x/exp/slices"
+)
+
+func TestHashing(t *testing.T) {
+ t.Parallel()
+
+ var bytecodes = make([][]byte, 10)
+ for i := 0; i < len(bytecodes); i++ {
+ buf := make([]byte, 100)
+ rand.Read(buf)
+ bytecodes[i] = buf
+ }
+ var want, got string
+ var old = func() {
+ hasher := sha3.NewLegacyKeccak256()
+ for i := 0; i < len(bytecodes); i++ {
+ hasher.Reset()
+ hasher.Write(bytecodes[i])
+ hash := hasher.Sum(nil)
+ got = fmt.Sprintf("%v\n%v", got, hash)
+ }
+ }
+ var new = func() {
+ hasher := sha3.NewLegacyKeccak256().(crypto.KeccakState)
+ var hash = make([]byte, 32)
+ for i := 0; i < len(bytecodes); i++ {
+ hasher.Reset()
+ hasher.Write(bytecodes[i])
+ hasher.Read(hash)
+ want = fmt.Sprintf("%v\n%v", want, hash)
+ }
+ }
+ old()
+ new()
+ if want != got {
+ t.Errorf("want\n%v\ngot\n%v\n", want, got)
+ }
+}
+
+func BenchmarkHashing(b *testing.B) {
+ var bytecodes = make([][]byte, 10000)
+ for i := 0; i < len(bytecodes); i++ {
+ buf := make([]byte, 100)
+ rand.Read(buf)
+ bytecodes[i] = buf
+ }
+ var old = func() {
+ hasher := sha3.NewLegacyKeccak256()
+ for i := 0; i < len(bytecodes); i++ {
+ hasher.Reset()
+ hasher.Write(bytecodes[i])
+ hasher.Sum(nil)
+ }
+ }
+ var new = func() {
+ hasher := sha3.NewLegacyKeccak256().(crypto.KeccakState)
+ var hash = make([]byte, 32)
+ for i := 0; i < len(bytecodes); i++ {
+ hasher.Reset()
+ hasher.Write(bytecodes[i])
+ hasher.Read(hash)
+ }
+ }
+ b.Run("old", func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ old()
+ }
+ })
+ b.Run("new", func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ new()
+ }
+ })
+}
+
+type (
+ accountHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error
+ storageHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error
+ trieHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error
+ codeHandlerFunc func(t *testPeer, id uint64, hashes []common.Hash, max uint64) error
+)
+
+type testPeer struct {
+ id string
+ test *testing.T
+ remote *Syncer
+ logger log.Logger
+ accountTrie *trie.Trie
+ accountValues []*kv
+ storageTries map[common.Hash]*trie.Trie
+ storageValues map[common.Hash][]*kv
+
+ accountRequestHandler accountHandlerFunc
+ storageRequestHandler storageHandlerFunc
+ trieRequestHandler trieHandlerFunc
+ codeRequestHandler codeHandlerFunc
+ term func()
+
+ // counters
+ nAccountRequests int
+ nStorageRequests int
+ nBytecodeRequests int
+ nTrienodeRequests int
+}
+
+func newTestPeer(id string, t *testing.T, term func()) *testPeer {
+ peer := &testPeer{
+ id: id,
+ test: t,
+ logger: log.New("id", id),
+ accountRequestHandler: defaultAccountRequestHandler,
+ trieRequestHandler: defaultTrieRequestHandler,
+ storageRequestHandler: defaultStorageRequestHandler,
+ codeRequestHandler: defaultCodeRequestHandler,
+ term: term,
+ }
+ //stderrHandler := log.StreamHandler(os.Stderr, log.TerminalFormat(true))
+ //peer.logger.SetHandler(stderrHandler)
+ return peer
+}
+
+func (t *testPeer) setStorageTries(tries map[common.Hash]*trie.Trie) {
+ t.storageTries = make(map[common.Hash]*trie.Trie)
+ for root, trie := range tries {
+ t.storageTries[root] = trie.Copy()
+ }
+}
+
+func (t *testPeer) ID() string { return t.id }
+func (t *testPeer) Log() log.Logger { return t.logger }
+
+func (t *testPeer) Stats() string {
+ return fmt.Sprintf(`Account requests: %d
+Storage requests: %d
+Bytecode requests: %d
+Trienode requests: %d
+`, t.nAccountRequests, t.nStorageRequests, t.nBytecodeRequests, t.nTrienodeRequests)
+}
+
+func (t *testPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes uint64) error {
+ t.logger.Trace("Fetching range of accounts", "reqid", id, "root", root, "origin", origin, "limit", limit, "bytes", common.StorageSize(bytes))
+ t.nAccountRequests++
+ go t.accountRequestHandler(t, id, root, origin, limit, bytes)
+ return nil
+}
+
+func (t *testPeer) RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error {
+ t.logger.Trace("Fetching set of trie nodes", "reqid", id, "root", root, "pathsets", len(paths), "bytes", common.StorageSize(bytes))
+ t.nTrienodeRequests++
+ go t.trieRequestHandler(t, id, root, paths, bytes)
+ return nil
+}
+
+func (t *testPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error {
+ t.nStorageRequests++
+ if len(accounts) == 1 && origin != nil {
+ t.logger.Trace("Fetching range of large storage slots", "reqid", id, "root", root, "account", accounts[0], "origin", common.BytesToHash(origin), "limit", common.BytesToHash(limit), "bytes", common.StorageSize(bytes))
+ } else {
+ t.logger.Trace("Fetching ranges of small storage slots", "reqid", id, "root", root, "accounts", len(accounts), "first", accounts[0], "bytes", common.StorageSize(bytes))
+ }
+ go t.storageRequestHandler(t, id, root, accounts, origin, limit, bytes)
+ return nil
+}
+
+func (t *testPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error {
+ t.nBytecodeRequests++
+ t.logger.Trace("Fetching set of byte codes", "reqid", id, "hashes", len(hashes), "bytes", common.StorageSize(bytes))
+ go t.codeRequestHandler(t, id, hashes, bytes)
+ return nil
+}
+
+// defaultTrieRequestHandler is a well-behaved handler for trie healing requests
+func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error {
+ // Pass the response
+ var nodes [][]byte
+ for _, pathset := range paths {
+ switch len(pathset) {
+ case 1:
+ blob, _, err := t.accountTrie.GetNode(pathset[0])
+ if err != nil {
+ t.logger.Info("Error handling req", "error", err)
+ break
+ }
+ nodes = append(nodes, blob)
+ default:
+ account := t.storageTries[(common.BytesToHash(pathset[0]))]
+ for _, path := range pathset[1:] {
+ blob, _, err := account.GetNode(path)
+ if err != nil {
+ t.logger.Info("Error handling req", "error", err)
+ break
+ }
+ nodes = append(nodes, blob)
+ }
+ }
+ }
+ t.remote.OnTrieNodes(t, requestId, nodes)
+ return nil
+}
+
+// defaultAccountRequestHandler is a well-behaved handler for AccountRangeRequests
+func defaultAccountRequestHandler(t *testPeer, id uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+ keys, vals, proofs := createAccountRequestResponse(t, root, origin, limit, cap)
+ if err := t.remote.OnAccounts(t, id, keys, vals, proofs); err != nil {
+ t.test.Errorf("Remote side rejected our delivery: %v", err)
+ t.term()
+ return err
+ }
+ return nil
+}
+
+func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) (keys []common.Hash, vals [][]byte, proofs [][]byte) {
+ var size uint64
+ if limit == (common.Hash{}) {
+ limit = common.MaxHash
+ }
+ for _, entry := range t.accountValues {
+ if size > cap {
+ break
+ }
+ if bytes.Compare(origin[:], entry.k) <= 0 {
+ keys = append(keys, common.BytesToHash(entry.k))
+ vals = append(vals, entry.v)
+ size += uint64(32 + len(entry.v))
+ }
+ // If we've exceeded the request threshold, abort
+ if bytes.Compare(entry.k, limit[:]) >= 0 {
+ break
+ }
+ }
+ // Unless we send the entire trie, we need to supply proofs
+ // Actually, we need to supply proofs either way! This seems to be an implementation
+ // quirk in go-ethereum
+ proof := trienode.NewProofSet()
+ if err := t.accountTrie.Prove(origin[:], proof); err != nil {
+ t.logger.Error("Could not prove inexistence of origin", "origin", origin, "error", err)
+ }
+ if len(keys) > 0 {
+ lastK := (keys[len(keys)-1])[:]
+ if err := t.accountTrie.Prove(lastK, proof); err != nil {
+ t.logger.Error("Could not prove last item", "error", err)
+ }
+ }
+ for _, blob := range proof.List() {
+ proofs = append(proofs, blob)
+ }
+ return keys, vals, proofs
+}
+
+// defaultStorageRequestHandler is a well-behaved storage request handler
+func defaultStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max uint64) error {
+ hashes, slots, proofs := createStorageRequestResponse(t, root, accounts, bOrigin, bLimit, max)
+ if err := t.remote.OnStorage(t, requestId, hashes, slots, proofs); err != nil {
+ t.test.Errorf("Remote side rejected our delivery: %v", err)
+ t.term()
+ }
+ return nil
+}
+
+func defaultCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max uint64) error {
+ var bytecodes [][]byte
+ for _, h := range hashes {
+ bytecodes = append(bytecodes, getCodeByHash(h))
+ }
+ if err := t.remote.OnByteCodes(t, id, bytecodes); err != nil {
+ t.test.Errorf("Remote side rejected our delivery: %v", err)
+ t.term()
+ }
+ return nil
+}
+
+func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) {
+ var size uint64
+ for _, account := range accounts {
+ // The first account might start from a different origin and end sooner
+ var originHash common.Hash
+ if len(origin) > 0 {
+ originHash = common.BytesToHash(origin)
+ }
+ var limitHash = common.MaxHash
+ if len(limit) > 0 {
+ limitHash = common.BytesToHash(limit)
+ }
+ var (
+ keys []common.Hash
+ vals [][]byte
+ abort bool
+ )
+ for _, entry := range t.storageValues[account] {
+ if size >= max {
+ abort = true
+ break
+ }
+ if bytes.Compare(entry.k, originHash[:]) < 0 {
+ continue
+ }
+ keys = append(keys, common.BytesToHash(entry.k))
+ vals = append(vals, entry.v)
+ size += uint64(32 + len(entry.v))
+ if bytes.Compare(entry.k, limitHash[:]) >= 0 {
+ break
+ }
+ }
+ if len(keys) > 0 {
+ hashes = append(hashes, keys)
+ slots = append(slots, vals)
+ }
+ // Generate the Merkle proofs for the first and last storage slot, but
+ // only if the response was capped. If the entire storage trie is included
+ // in the response, no proofs are needed.
+ if originHash != (common.Hash{}) || (abort && len(keys) > 0) {
+ // If we're aborting, we need to prove the first and last item
+ // This terminates the response (and thus the loop)
+ proof := trienode.NewProofSet()
+ stTrie := t.storageTries[account]
+
+ // Here's a potential gotcha: when constructing the proof, we cannot
+ // use the 'origin' slice directly, but must use the full 32-byte
+ // hash form.
+ if err := stTrie.Prove(originHash[:], proof); err != nil {
+ t.logger.Error("Could not prove inexistence of origin", "origin", originHash, "error", err)
+ }
+ if len(keys) > 0 {
+ lastK := (keys[len(keys)-1])[:]
+ if err := stTrie.Prove(lastK, proof); err != nil {
+ t.logger.Error("Could not prove last item", "error", err)
+ }
+ }
+ for _, blob := range proof.List() {
+ proofs = append(proofs, blob)
+ }
+ break
+ }
+ }
+ return hashes, slots, proofs
+}
+
+// createStorageRequestResponseAlwaysProve tests a corner case, where the peer always
+// supplies the proof for the last account, even if it is 'complete'.
+func createStorageRequestResponseAlwaysProve(t *testPeer, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max uint64) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) {
+ var size uint64
+ max = max * 3 / 4
+
+ var origin common.Hash
+ if len(bOrigin) > 0 {
+ origin = common.BytesToHash(bOrigin)
+ }
+ var exit bool
+ for i, account := range accounts {
+ var keys []common.Hash
+ var vals [][]byte
+ for _, entry := range t.storageValues[account] {
+ if bytes.Compare(entry.k, origin[:]) < 0 {
+ exit = true
+ }
+ keys = append(keys, common.BytesToHash(entry.k))
+ vals = append(vals, entry.v)
+ size += uint64(32 + len(entry.v))
+ if size > max {
+ exit = true
+ }
+ }
+ if i == len(accounts)-1 {
+ exit = true
+ }
+ hashes = append(hashes, keys)
+ slots = append(slots, vals)
+
+ if exit {
+ // If we're aborting, we need to prove the first and last item
+ // This terminates the response (and thus the loop)
+ proof := trienode.NewProofSet()
+ stTrie := t.storageTries[account]
+
+ // Here's a potential gotcha: when constructing the proof, we cannot
+ // use the 'origin' slice directly, but must use the full 32-byte
+ // hash form.
+ if err := stTrie.Prove(origin[:], proof); err != nil {
+ t.logger.Error("Could not prove inexistence of origin", "origin", origin,
+ "error", err)
+ }
+ if len(keys) > 0 {
+ lastK := (keys[len(keys)-1])[:]
+ if err := stTrie.Prove(lastK, proof); err != nil {
+ t.logger.Error("Could not prove last item", "error", err)
+ }
+ }
+ for _, blob := range proof.List() {
+ proofs = append(proofs, blob)
+ }
+ break
+ }
+ }
+ return hashes, slots, proofs
+}
+
+// emptyRequestAccountRangeFn is a handler that answers AccountRangeRequests with empty responses
+func emptyRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+ t.remote.OnAccounts(t, requestId, nil, nil, nil)
+ return nil
+}
+
+func nonResponsiveRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+ return nil
+}
+
+func emptyTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error {
+ t.remote.OnTrieNodes(t, requestId, nil)
+ return nil
+}
+
+func nonResponsiveTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error {
+ return nil
+}
+
+func emptyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+ t.remote.OnStorage(t, requestId, nil, nil, nil)
+ return nil
+}
+
+func nonResponsiveStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+ return nil
+}
+
+func proofHappyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+ hashes, slots, proofs := createStorageRequestResponseAlwaysProve(t, root, accounts, origin, limit, max)
+ if err := t.remote.OnStorage(t, requestId, hashes, slots, proofs); err != nil {
+ t.test.Errorf("Remote side rejected our delivery: %v", err)
+ t.term()
+ }
+ return nil
+}
+
+//func emptyCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max uint64) error {
+// var bytecodes [][]byte
+// t.remote.OnByteCodes(t, id, bytecodes)
+// return nil
+//}
+
+func corruptCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max uint64) error {
+ var bytecodes [][]byte
+ for _, h := range hashes {
+ // Send back the hashes
+ bytecodes = append(bytecodes, h[:])
+ }
+ if err := t.remote.OnByteCodes(t, id, bytecodes); err != nil {
+ t.logger.Info("remote error on delivery (as expected)", "error", err)
+ // Mimic the real-life handler, which drops a peer on errors
+ t.remote.Unregister(t.id)
+ }
+ return nil
+}
+
+func cappedCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max uint64) error {
+ var bytecodes [][]byte
+ for _, h := range hashes[:1] {
+ bytecodes = append(bytecodes, getCodeByHash(h))
+ }
+ // Missing bytecode can be retrieved again, no error expected
+ if err := t.remote.OnByteCodes(t, id, bytecodes); err != nil {
+ t.test.Errorf("Remote side rejected our delivery: %v", err)
+ t.term()
+ }
+ return nil
+}
+
+// starvingStorageRequestHandler is a somewhat well-behaved storage handler, but it caps the returned results to be very small
+func starvingStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+ return defaultStorageRequestHandler(t, requestId, root, accounts, origin, limit, 500)
+}
+
+func starvingAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+ return defaultAccountRequestHandler(t, requestId, root, origin, limit, 500)
+}
+
+//func misdeliveringAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, cap uint64) error {
+// return defaultAccountRequestHandler(t, requestId-1, root, origin, 500)
+//}
+
+func corruptAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+ hashes, accounts, proofs := createAccountRequestResponse(t, root, origin, limit, cap)
+ if len(proofs) > 0 {
+ proofs = proofs[1:]
+ }
+ if err := t.remote.OnAccounts(t, requestId, hashes, accounts, proofs); err != nil {
+ t.logger.Info("remote error on delivery (as expected)", "error", err)
+ // Mimic the real-life handler, which drops a peer on errors
+ t.remote.Unregister(t.id)
+ }
+ return nil
+}
+
+// corruptStorageRequestHandler doesn't provide good proofs
+func corruptStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+ hashes, slots, proofs := createStorageRequestResponse(t, root, accounts, origin, limit, max)
+ if len(proofs) > 0 {
+ proofs = proofs[1:]
+ }
+ if err := t.remote.OnStorage(t, requestId, hashes, slots, proofs); err != nil {
+ t.logger.Info("remote error on delivery (as expected)", "error", err)
+ // Mimic the real-life handler, which drops a peer on errors
+ t.remote.Unregister(t.id)
+ }
+ return nil
+}
+
+func noProofStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+ hashes, slots, _ := createStorageRequestResponse(t, root, accounts, origin, limit, max)
+ if err := t.remote.OnStorage(t, requestId, hashes, slots, nil); err != nil {
+ t.logger.Info("remote error on delivery (as expected)", "error", err)
+ // Mimic the real-life handler, which drops a peer on errors
+ t.remote.Unregister(t.id)
+ }
+ return nil
+}
+
+// TestSyncBloatedProof tests a scenario where we provide only _one_ value, but
+// also ship the entire trie inside the proof. If the attack is successful,
+// the remote side does not do any follow-up requests
+func TestSyncBloatedProof(t *testing.T) {
+ t.Parallel()
+
+ testSyncBloatedProof(t, rawdb.HashScheme)
+ testSyncBloatedProof(t, rawdb.PathScheme)
+}
+
+func testSyncBloatedProof(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(100, scheme)
+ source := newTestPeer("source", t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+
+ source.accountRequestHandler = func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+ var (
+ proofs [][]byte
+ keys []common.Hash
+ vals [][]byte
+ )
+ // The values
+ for _, entry := range t.accountValues {
+ if bytes.Compare(entry.k, origin[:]) < 0 {
+ continue
+ }
+ if bytes.Compare(entry.k, limit[:]) > 0 {
+ continue
+ }
+ keys = append(keys, common.BytesToHash(entry.k))
+ vals = append(vals, entry.v)
+ }
+ // The proofs
+ proof := trienode.NewProofSet()
+ if err := t.accountTrie.Prove(origin[:], proof); err != nil {
+ t.logger.Error("Could not prove origin", "origin", origin, "error", err)
+ t.logger.Error("Could not prove origin", "origin", origin, "error", err)
+ }
+ // The bloat: add proof of every single element
+ for _, entry := range t.accountValues {
+ if err := t.accountTrie.Prove(entry.k, proof); err != nil {
+ t.logger.Error("Could not prove item", "error", err)
+ }
+ }
+ // And remove one item from the elements
+ if len(keys) > 2 {
+ keys = append(keys[:1], keys[2:]...)
+ vals = append(vals[:1], vals[2:]...)
+ }
+ for _, blob := range proof.List() {
+ proofs = append(proofs, blob)
+ }
+ if err := t.remote.OnAccounts(t, requestId, keys, vals, proofs); err != nil {
+ t.logger.Info("remote error on delivery (as expected)", "error", err)
+ t.term()
+ // This is actually correct, signal to exit the test successfully
+ }
+ return nil
+ }
+ syncer := setupSyncer(nodeScheme, source)
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err == nil {
+ t.Fatal("No error returned from incomplete/cancelled sync")
+ }
+}
+
+func setupSyncer(scheme string, peers ...*testPeer) *Syncer {
+ stateDb := rawdb.NewMemoryDatabase()
+ syncer := NewSyncer(stateDb, scheme)
+ for _, peer := range peers {
+ syncer.Register(peer)
+ peer.remote = syncer
+ }
+ return syncer
+}
+
+// TestSync tests a basic sync with one peer
+func TestSync(t *testing.T) {
+ t.Parallel()
+
+ testSync(t, rawdb.HashScheme)
+ testSync(t, rawdb.PathScheme)
+}
+
+func testSync(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(100, scheme)
+
+ mkSource := func(name string) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+ return source
+ }
+ syncer := setupSyncer(nodeScheme, mkSource("source"))
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+}
+
+// TestSyncTinyTriePanic tests a basic sync with one peer, and a tiny trie. This caused a
+// panic within the prover
+func TestSyncTinyTriePanic(t *testing.T) {
+ t.Parallel()
+
+ testSyncTinyTriePanic(t, rawdb.HashScheme)
+ testSyncTinyTriePanic(t, rawdb.PathScheme)
+}
+
+func testSyncTinyTriePanic(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(1, scheme)
+
+ mkSource := func(name string) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+ return source
+ }
+ syncer := setupSyncer(nodeScheme, mkSource("source"))
+ done := checkStall(t, term)
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ close(done)
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+}
+
+// TestMultiSync tests a basic sync with multiple peers
+func TestMultiSync(t *testing.T) {
+ t.Parallel()
+
+ testMultiSync(t, rawdb.HashScheme)
+ testMultiSync(t, rawdb.PathScheme)
+}
+
+func testMultiSync(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(100, scheme)
+
+ mkSource := func(name string) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+ return source
+ }
+ syncer := setupSyncer(nodeScheme, mkSource("sourceA"), mkSource("sourceB"))
+ done := checkStall(t, term)
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ close(done)
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+}
+
+// TestSyncWithStorage tests basic sync using accounts + storage + code
+func TestSyncWithStorage(t *testing.T) {
+ t.Parallel()
+
+ testSyncWithStorage(t, rawdb.HashScheme)
+ testSyncWithStorage(t, rawdb.PathScheme)
+}
+
+func testSyncWithStorage(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 3, 3000, true, false, false)
+
+ mkSource := func(name string) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+ source.setStorageTries(storageTries)
+ source.storageValues = storageElems
+ return source
+ }
+ syncer := setupSyncer(scheme, mkSource("sourceA"))
+ done := checkStall(t, term)
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ close(done)
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+}
+
+// TestMultiSyncManyUseless contains one good peer, and many which don't return anything valuable at all
+func TestMultiSyncManyUseless(t *testing.T) {
+ t.Parallel()
+
+ testMultiSyncManyUseless(t, rawdb.HashScheme)
+ testMultiSyncManyUseless(t, rawdb.PathScheme)
+}
+
+func testMultiSyncManyUseless(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false, false)
+
+ mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+ source.setStorageTries(storageTries)
+ source.storageValues = storageElems
+
+ if !noAccount {
+ source.accountRequestHandler = emptyRequestAccountRangeFn
+ }
+ if !noStorage {
+ source.storageRequestHandler = emptyStorageRequestHandler
+ }
+ if !noTrieNode {
+ source.trieRequestHandler = emptyTrieRequestHandler
+ }
+ return source
+ }
+
+ syncer := setupSyncer(
+ scheme,
+ mkSource("full", true, true, true),
+ mkSource("noAccounts", false, true, true),
+ mkSource("noStorage", true, false, true),
+ mkSource("noTrie", true, true, false),
+ )
+ done := checkStall(t, term)
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ close(done)
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+}
+
+// TestMultiSyncManyUselessWithLowTimeout contains one good peer, and many which don't return anything valuable at all
+func TestMultiSyncManyUselessWithLowTimeout(t *testing.T) {
+ t.Parallel()
+
+ testMultiSyncManyUselessWithLowTimeout(t, rawdb.HashScheme)
+ testMultiSyncManyUselessWithLowTimeout(t, rawdb.PathScheme)
+}
+
+func testMultiSyncManyUselessWithLowTimeout(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false, false)
+
+ mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+ source.setStorageTries(storageTries)
+ source.storageValues = storageElems
+
+ if !noAccount {
+ source.accountRequestHandler = emptyRequestAccountRangeFn
+ }
+ if !noStorage {
+ source.storageRequestHandler = emptyStorageRequestHandler
+ }
+ if !noTrieNode {
+ source.trieRequestHandler = emptyTrieRequestHandler
+ }
+ return source
+ }
+
+ syncer := setupSyncer(
+ scheme,
+ mkSource("full", true, true, true),
+ mkSource("noAccounts", false, true, true),
+ mkSource("noStorage", true, false, true),
+ mkSource("noTrie", true, true, false),
+ )
+ // We're setting the timeout to very low, to increase the chance of the timeout
+ // being triggered. This was previously a cause of panic, when a response
+ // arrived simultaneously as a timeout was triggered.
+ syncer.rates.OverrideTTLLimit = time.Millisecond
+
+ done := checkStall(t, term)
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ close(done)
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+}
+
+// TestMultiSyncManyUnresponsive contains one good peer, and many which don't respond at all
+func TestMultiSyncManyUnresponsive(t *testing.T) {
+ t.Parallel()
+
+ testMultiSyncManyUnresponsive(t, rawdb.HashScheme)
+ testMultiSyncManyUnresponsive(t, rawdb.PathScheme)
+}
+
+func testMultiSyncManyUnresponsive(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false, false)
+
+ mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+ source.setStorageTries(storageTries)
+ source.storageValues = storageElems
+
+ if !noAccount {
+ source.accountRequestHandler = nonResponsiveRequestAccountRangeFn
+ }
+ if !noStorage {
+ source.storageRequestHandler = nonResponsiveStorageRequestHandler
+ }
+ if !noTrieNode {
+ source.trieRequestHandler = nonResponsiveTrieRequestHandler
+ }
+ return source
+ }
+
+ syncer := setupSyncer(
+ scheme,
+ mkSource("full", true, true, true),
+ mkSource("noAccounts", false, true, true),
+ mkSource("noStorage", true, false, true),
+ mkSource("noTrie", true, true, false),
+ )
+ // We're setting the timeout to very low, to make the test run a bit faster
+ syncer.rates.OverrideTTLLimit = time.Millisecond
+
+ done := checkStall(t, term)
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ close(done)
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+}
+
+func checkStall(t *testing.T, term func()) chan struct{} {
+ testDone := make(chan struct{})
+ go func() {
+ select {
+ case <-time.After(10 * time.Minute): // TODO(karalabe): Make tests smaller, this is too much
+ t.Log("Sync stalled")
+ term()
+ case <-testDone:
+ return
+ }
+ }()
+ return testDone
+}
+
+// TestSyncBoundaryAccountTrie tests sync against a few normal peers, but the
+// account trie has a few boundary elements.
+func TestSyncBoundaryAccountTrie(t *testing.T) {
+ t.Parallel()
+
+ testSyncBoundaryAccountTrie(t, rawdb.HashScheme)
+ testSyncBoundaryAccountTrie(t, rawdb.PathScheme)
+}
+
+func testSyncBoundaryAccountTrie(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ nodeScheme, sourceAccountTrie, elems := makeBoundaryAccountTrie(scheme, 3000)
+
+ mkSource := func(name string) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+ return source
+ }
+ syncer := setupSyncer(
+ nodeScheme,
+ mkSource("peer-a"),
+ mkSource("peer-b"),
+ )
+ done := checkStall(t, term)
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ close(done)
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+}
+
+// TestSyncNoStorageAndOneCappedPeer tests sync using accounts and no storage, where one peer is
+// consistently returning very small results
+func TestSyncNoStorageAndOneCappedPeer(t *testing.T) {
+ t.Parallel()
+
+ testSyncNoStorageAndOneCappedPeer(t, rawdb.HashScheme)
+ testSyncNoStorageAndOneCappedPeer(t, rawdb.PathScheme)
+}
+
+func testSyncNoStorageAndOneCappedPeer(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(3000, scheme)
+
+ mkSource := func(name string, slow bool) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+
+ if slow {
+ source.accountRequestHandler = starvingAccountRequestHandler
+ }
+ return source
+ }
+
+ syncer := setupSyncer(
+ nodeScheme,
+ mkSource("nice-a", false),
+ mkSource("nice-b", false),
+ mkSource("nice-c", false),
+ mkSource("capped", true),
+ )
+ done := checkStall(t, term)
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ close(done)
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+}
+
+// TestSyncNoStorageAndOneCodeCorruptPeer has one peer which doesn't deliver
+// code requests properly.
+func TestSyncNoStorageAndOneCodeCorruptPeer(t *testing.T) {
+ t.Parallel()
+
+ testSyncNoStorageAndOneCodeCorruptPeer(t, rawdb.HashScheme)
+ testSyncNoStorageAndOneCodeCorruptPeer(t, rawdb.PathScheme)
+}
+
+func testSyncNoStorageAndOneCodeCorruptPeer(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(3000, scheme)
+
+ mkSource := func(name string, codeFn codeHandlerFunc) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+ source.codeRequestHandler = codeFn
+ return source
+ }
+ // One is capped, one is corrupt. If we don't use a capped one, there's a 50%
+ // chance that the full set of codes requested is sent only to the
+ // non-corrupt peer, which delivers everything in one go, and makes the
+ // test moot
+ syncer := setupSyncer(
+ nodeScheme,
+ mkSource("capped", cappedCodeRequestHandler),
+ mkSource("corrupt", corruptCodeRequestHandler),
+ )
+ done := checkStall(t, term)
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ close(done)
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+}
+
+func TestSyncNoStorageAndOneAccountCorruptPeer(t *testing.T) {
+ t.Parallel()
+
+ testSyncNoStorageAndOneAccountCorruptPeer(t, rawdb.HashScheme)
+ testSyncNoStorageAndOneAccountCorruptPeer(t, rawdb.PathScheme)
+}
+
+func testSyncNoStorageAndOneAccountCorruptPeer(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(3000, scheme)
+
+ mkSource := func(name string, accFn accountHandlerFunc) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+ source.accountRequestHandler = accFn
+ return source
+ }
+ // One is capped, one is corrupt. If we don't use a capped one, there's a 50%
+ // chance that the full set of accounts requested is sent only to the
+ // non-corrupt peer, which delivers everything in one go, and makes the
+ // test moot
+ syncer := setupSyncer(
+ nodeScheme,
+ mkSource("capped", defaultAccountRequestHandler),
+ mkSource("corrupt", corruptAccountRequestHandler),
+ )
+ done := checkStall(t, term)
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ close(done)
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+}
+
+// TestSyncNoStorageAndOneCodeCappedPeer has one peer which delivers code hashes
+// one by one
+func TestSyncNoStorageAndOneCodeCappedPeer(t *testing.T) {
+ t.Parallel()
+
+ testSyncNoStorageAndOneCodeCappedPeer(t, rawdb.HashScheme)
+ testSyncNoStorageAndOneCodeCappedPeer(t, rawdb.PathScheme)
+}
+
+func testSyncNoStorageAndOneCodeCappedPeer(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(3000, scheme)
+
+ mkSource := func(name string, codeFn codeHandlerFunc) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+ source.codeRequestHandler = codeFn
+ return source
+ }
+ // Count how many times it's invoked. Remember, there are only 8 unique hashes,
+ // so it shouldn't be more than that
+ var counter int
+ syncer := setupSyncer(
+ nodeScheme,
+ mkSource("capped", func(t *testPeer, id uint64, hashes []common.Hash, max uint64) error {
+ counter++
+ return cappedCodeRequestHandler(t, id, hashes, max)
+ }),
+ )
+ done := checkStall(t, term)
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ close(done)
+
+ // There are only 8 unique hashes, and 3K accounts. However, the code
+ // deduplication is per request batch. If it were a perfect global dedup,
+ // we would expect only 8 requests. If there were no dedup, there would be
+ // 3k requests.
+ // We expect somewhere below 100 requests for these 8 unique hashes. But
+ // the number can be flaky, so don't limit it so strictly.
+ if threshold := 100; counter > threshold {
+ t.Logf("Error, expected < %d invocations, got %d", threshold, counter)
+ }
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+}
+
+// TestSyncBoundaryStorageTrie tests sync against a few normal peers, but the
+// storage trie has a few boundary elements.
+func TestSyncBoundaryStorageTrie(t *testing.T) {
+ t.Parallel()
+
+ testSyncBoundaryStorageTrie(t, rawdb.HashScheme)
+ testSyncBoundaryStorageTrie(t, rawdb.PathScheme)
+}
+
+func testSyncBoundaryStorageTrie(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 10, 1000, false, true, false)
+
+ mkSource := func(name string) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+ source.setStorageTries(storageTries)
+ source.storageValues = storageElems
+ return source
+ }
+ syncer := setupSyncer(
+ scheme,
+ mkSource("peer-a"),
+ mkSource("peer-b"),
+ )
+ done := checkStall(t, term)
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ close(done)
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+}
+
+// TestSyncWithStorageAndOneCappedPeer tests sync using accounts + storage, where one peer is
+// consistently returning very small results
+func TestSyncWithStorageAndOneCappedPeer(t *testing.T) {
+ t.Parallel()
+
+ testSyncWithStorageAndOneCappedPeer(t, rawdb.HashScheme)
+ testSyncWithStorageAndOneCappedPeer(t, rawdb.PathScheme)
+}
+
+func testSyncWithStorageAndOneCappedPeer(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 300, 1000, false, false, false)
+
+ mkSource := func(name string, slow bool) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+ source.setStorageTries(storageTries)
+ source.storageValues = storageElems
+
+ if slow {
+ source.storageRequestHandler = starvingStorageRequestHandler
+ }
+ return source
+ }
+
+ syncer := setupSyncer(
+ scheme,
+ mkSource("nice-a", false),
+ mkSource("slow", true),
+ )
+ done := checkStall(t, term)
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ close(done)
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+}
+
+// TestSyncWithStorageAndCorruptPeer tests sync using accounts + storage, where one peer is
+// sometimes sending bad proofs
+func TestSyncWithStorageAndCorruptPeer(t *testing.T) {
+ t.Parallel()
+
+ testSyncWithStorageAndCorruptPeer(t, rawdb.HashScheme)
+ testSyncWithStorageAndCorruptPeer(t, rawdb.PathScheme)
+}
+
+func testSyncWithStorageAndCorruptPeer(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false, false)
+
+ mkSource := func(name string, handler storageHandlerFunc) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+ source.setStorageTries(storageTries)
+ source.storageValues = storageElems
+ source.storageRequestHandler = handler
+ return source
+ }
+
+ syncer := setupSyncer(
+ scheme,
+ mkSource("nice-a", defaultStorageRequestHandler),
+ mkSource("nice-b", defaultStorageRequestHandler),
+ mkSource("nice-c", defaultStorageRequestHandler),
+ mkSource("corrupt", corruptStorageRequestHandler),
+ )
+ done := checkStall(t, term)
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ close(done)
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+}
+
+func TestSyncWithStorageAndNonProvingPeer(t *testing.T) {
+ t.Parallel()
+
+ testSyncWithStorageAndNonProvingPeer(t, rawdb.HashScheme)
+ testSyncWithStorageAndNonProvingPeer(t, rawdb.PathScheme)
+}
+
+func testSyncWithStorageAndNonProvingPeer(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false, false)
+
+ mkSource := func(name string, handler storageHandlerFunc) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+ source.setStorageTries(storageTries)
+ source.storageValues = storageElems
+ source.storageRequestHandler = handler
+ return source
+ }
+ syncer := setupSyncer(
+ scheme,
+ mkSource("nice-a", defaultStorageRequestHandler),
+ mkSource("nice-b", defaultStorageRequestHandler),
+ mkSource("nice-c", defaultStorageRequestHandler),
+ mkSource("corrupt", noProofStorageRequestHandler),
+ )
+ done := checkStall(t, term)
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ close(done)
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+}
+
+// TestSyncWithStorageMisbehavingProve tests basic sync using accounts + storage + code, against
+// a peer who insists on delivering full storage sets _and_ proofs. This triggered
+// an error, where the recipient erroneously clipped the boundary nodes, but
+// did not mark the account for healing.
+func TestSyncWithStorageMisbehavingProve(t *testing.T) {
+ t.Parallel()
+
+ testSyncWithStorageMisbehavingProve(t, rawdb.HashScheme)
+ testSyncWithStorageMisbehavingProve(t, rawdb.PathScheme)
+}
+
+func testSyncWithStorageMisbehavingProve(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorageWithUniqueStorage(scheme, 10, 30, false)
+
+ mkSource := func(name string) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+ source.setStorageTries(storageTries)
+ source.storageValues = storageElems
+ source.storageRequestHandler = proofHappyStorageRequestHandler
+ return source
+ }
+ syncer := setupSyncer(nodeScheme, mkSource("sourceA"))
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+}
+
+// TestSyncWithUnevenStorage tests sync where the storage slots are unevenly
+// distributed across the trie, with a few empty ranges.
+func TestSyncWithUnevenStorage(t *testing.T) {
+ t.Parallel()
+
+ testSyncWithUnevenStorage(t, rawdb.HashScheme)
+ testSyncWithUnevenStorage(t, rawdb.PathScheme)
+}
+
+func testSyncWithUnevenStorage(t *testing.T, scheme string) {
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ accountTrie, accounts, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 3, 256, false, false, true)
+
+ mkSource := func(name string) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = accountTrie.Copy()
+ source.accountValues = accounts
+ source.setStorageTries(storageTries)
+ source.storageValues = storageElems
+ source.storageRequestHandler = func(t *testPeer, reqId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+ return defaultStorageRequestHandler(t, reqId, root, accounts, origin, limit, 128) // retrieve storage in large mode
+ }
+ return source
+ }
+ syncer := setupSyncer(scheme, mkSource("source"))
+ if err := syncer.Sync(accountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ verifyTrie(scheme, syncer.db, accountTrie.Hash(), t)
+}
+
+type kv struct {
+ k, v []byte
+}
+
+func (k *kv) cmp(other *kv) int {
+ return bytes.Compare(k.k, other.k)
+}
+
+func key32(i uint64) []byte {
+ key := make([]byte, 32)
+ binary.LittleEndian.PutUint64(key, i)
+ return key
+}
+
+var (
+ codehashes = []common.Hash{
+ crypto.Keccak256Hash([]byte{0}),
+ crypto.Keccak256Hash([]byte{1}),
+ crypto.Keccak256Hash([]byte{2}),
+ crypto.Keccak256Hash([]byte{3}),
+ crypto.Keccak256Hash([]byte{4}),
+ crypto.Keccak256Hash([]byte{5}),
+ crypto.Keccak256Hash([]byte{6}),
+ crypto.Keccak256Hash([]byte{7}),
+ }
+)
+
+// getCodeHash returns a pseudo-random code hash
+func getCodeHash(i uint64) []byte {
+ h := codehashes[int(i)%len(codehashes)]
+ return common.CopyBytes(h[:])
+}
+
+// getCodeByHash is a convenience function to look up the code from the code hash
+func getCodeByHash(hash common.Hash) []byte {
+ if hash == types.EmptyCodeHash {
+ return nil
+ }
+ for i, h := range codehashes {
+ if h == hash {
+ return []byte{byte(i)}
+ }
+ }
+ return nil
+}
+
+// makeAccountTrieNoStorage spits out a trie, along with the leaves
+func makeAccountTrieNoStorage(n int, scheme string) (string, *trie.Trie, []*kv) {
+ var (
+ db = triedb.NewDatabase(rawdb.NewMemoryDatabase(), newDbConfig(scheme))
+ accTrie = trie.NewEmpty(db)
+ entries []*kv
+ )
+ for i := uint64(1); i <= uint64(n); i++ {
+ value, _ := rlp.EncodeToBytes(&types.StateAccount{
+ Nonce: i,
+ Balance: uint256.NewInt(i),
+ Root: types.EmptyRootHash,
+ CodeHash: getCodeHash(i),
+ })
+ key := key32(i)
+ elem := &kv{key, value}
+ accTrie.MustUpdate(elem.k, elem.v)
+ entries = append(entries, elem)
+ }
+ slices.SortFunc(entries, (*kv).cmp)
+
+ // Commit the state changes into db and re-create the trie
+ // for accessing later.
+ root, nodes, _ := accTrie.Commit(false)
+ db.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil)
+
+ accTrie, _ = trie.New(trie.StateTrieID(root), db)
+ return db.Scheme(), accTrie, entries
+}
+
+// makeBoundaryAccountTrie constructs an account trie. Instead of filling
+// accounts normally, this function will fill a few accounts which have
+// boundary hashes.
+func makeBoundaryAccountTrie(scheme string, n int) (string, *trie.Trie, []*kv) {
+ var (
+ entries []*kv
+ boundaries []common.Hash
+
+ db = triedb.NewDatabase(rawdb.NewMemoryDatabase(), newDbConfig(scheme))
+ accTrie = trie.NewEmpty(db)
+ )
+ // Initialize boundaries
+ var next common.Hash
+ step := new(big.Int).Sub(
+ new(big.Int).Div(
+ new(big.Int).Exp(common.Big2, common.Big256, nil),
+ big.NewInt(int64(accountConcurrency)),
+ ), common.Big1,
+ )
+ for i := 0; i < accountConcurrency; i++ {
+ last := common.BigToHash(new(big.Int).Add(next.Big(), step))
+ if i == accountConcurrency-1 {
+ last = common.MaxHash
+ }
+ boundaries = append(boundaries, last)
+ next = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1))
+ }
+ // Fill boundary accounts
+ for i := 0; i < len(boundaries); i++ {
+ value, _ := rlp.EncodeToBytes(&types.StateAccount{
+ Nonce: uint64(0),
+ Balance: uint256.NewInt(uint64(i)),
+ Root: types.EmptyRootHash,
+ CodeHash: getCodeHash(uint64(i)),
+ })
+ elem := &kv{boundaries[i].Bytes(), value}
+ accTrie.MustUpdate(elem.k, elem.v)
+ entries = append(entries, elem)
+ }
+ // Fill other accounts if required
+ for i := uint64(1); i <= uint64(n); i++ {
+ value, _ := rlp.EncodeToBytes(&types.StateAccount{
+ Nonce: i,
+ Balance: uint256.NewInt(i),
+ Root: types.EmptyRootHash,
+ CodeHash: getCodeHash(i),
+ })
+ elem := &kv{key32(i), value}
+ accTrie.MustUpdate(elem.k, elem.v)
+ entries = append(entries, elem)
+ }
+ slices.SortFunc(entries, (*kv).cmp)
+
+ // Commit the state changes into db and re-create the trie
+ // for accessing later.
+ root, nodes, _ := accTrie.Commit(false)
+ db.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil)
+
+ accTrie, _ = trie.New(trie.StateTrieID(root), db)
+ return db.Scheme(), accTrie, entries
+}
+
+// makeAccountTrieWithStorageWithUniqueStorage creates an account trie where each account
+// has a unique storage set.
+func makeAccountTrieWithStorageWithUniqueStorage(scheme string, accounts, slots int, code bool) (string, *trie.Trie, []*kv, map[common.Hash]*trie.Trie, map[common.Hash][]*kv) {
+ var (
+ db = triedb.NewDatabase(rawdb.NewMemoryDatabase(), newDbConfig(scheme))
+ accTrie = trie.NewEmpty(db)
+ entries []*kv
+ storageRoots = make(map[common.Hash]common.Hash)
+ storageTries = make(map[common.Hash]*trie.Trie)
+ storageEntries = make(map[common.Hash][]*kv)
+ nodes = trienode.NewMergedNodeSet()
+ )
+ // Create n accounts in the trie
+ for i := uint64(1); i <= uint64(accounts); i++ {
+ key := key32(i)
+ codehash := types.EmptyCodeHash.Bytes()
+ if code {
+ codehash = getCodeHash(i)
+ }
+ // Create a storage trie
+ stRoot, stNodes, stEntries := makeStorageTrieWithSeed(common.BytesToHash(key), uint64(slots), i, db)
+ nodes.Merge(stNodes)
+
+ value, _ := rlp.EncodeToBytes(&types.StateAccount{
+ Nonce: i,
+ Balance: uint256.NewInt(i),
+ Root: stRoot,
+ CodeHash: codehash,
+ })
+ elem := &kv{key, value}
+ accTrie.MustUpdate(elem.k, elem.v)
+ entries = append(entries, elem)
+
+ storageRoots[common.BytesToHash(key)] = stRoot
+ storageEntries[common.BytesToHash(key)] = stEntries
+ }
+ slices.SortFunc(entries, (*kv).cmp)
+
+ // Commit account trie
+ root, set, _ := accTrie.Commit(true)
+ nodes.Merge(set)
+
+ // Commit gathered dirty nodes into database
+ db.Update(root, types.EmptyRootHash, 0, nodes, nil)
+
+ // Re-create tries with new root
+ accTrie, _ = trie.New(trie.StateTrieID(root), db)
+ for i := uint64(1); i <= uint64(accounts); i++ {
+ key := key32(i)
+ id := trie.StorageTrieID(root, common.BytesToHash(key), storageRoots[common.BytesToHash(key)])
+ trie, _ := trie.New(id, db)
+ storageTries[common.BytesToHash(key)] = trie
+ }
+ return db.Scheme(), accTrie, entries, storageTries, storageEntries
+}
+
+// makeAccountTrieWithStorage constructs an account trie with storage tries
+// attached, along with the sorted leaves.
+func makeAccountTrieWithStorage(scheme string, accounts, slots int, code, boundary bool, uneven bool) (*trie.Trie, []*kv, map[common.Hash]*trie.Trie, map[common.Hash][]*kv) {
+ var (
+ db = triedb.NewDatabase(rawdb.NewMemoryDatabase(), newDbConfig(scheme))
+ accTrie = trie.NewEmpty(db)
+ entries []*kv
+ storageRoots = make(map[common.Hash]common.Hash)
+ storageTries = make(map[common.Hash]*trie.Trie)
+ storageEntries = make(map[common.Hash][]*kv)
+ nodes = trienode.NewMergedNodeSet()
+ )
+ // Create n accounts in the trie
+ for i := uint64(1); i <= uint64(accounts); i++ {
+ key := key32(i)
+ codehash := types.EmptyCodeHash.Bytes()
+ if code {
+ codehash = getCodeHash(i)
+ }
+ // Make a storage trie
+ var (
+ stRoot common.Hash
+ stNodes *trienode.NodeSet
+ stEntries []*kv
+ )
+ if boundary {
+ stRoot, stNodes, stEntries = makeBoundaryStorageTrie(common.BytesToHash(key), slots, db)
+ } else if uneven {
+ stRoot, stNodes, stEntries = makeUnevenStorageTrie(common.BytesToHash(key), slots, db)
+ } else {
+ stRoot, stNodes, stEntries = makeStorageTrieWithSeed(common.BytesToHash(key), uint64(slots), 0, db)
+ }
+ nodes.Merge(stNodes)
+
+ value, _ := rlp.EncodeToBytes(&types.StateAccount{
+ Nonce: i,
+ Balance: uint256.NewInt(i),
+ Root: stRoot,
+ CodeHash: codehash,
+ })
+ elem := &kv{key, value}
+ accTrie.MustUpdate(elem.k, elem.v)
+ entries = append(entries, elem)
+
+ // we reuse the same one for all accounts
+ storageRoots[common.BytesToHash(key)] = stRoot
+ storageEntries[common.BytesToHash(key)] = stEntries
+ }
+ slices.SortFunc(entries, (*kv).cmp)
+
+ // Commit account trie
+ root, set, _ := accTrie.Commit(true)
+ nodes.Merge(set)
+
+ // Commit gathered dirty nodes into database
+ db.Update(root, types.EmptyRootHash, 0, nodes, nil)
+
+ // Re-create tries with new root
+ accTrie, err := trie.New(trie.StateTrieID(root), db)
+ if err != nil {
+ panic(err)
+ }
+ for i := uint64(1); i <= uint64(accounts); i++ {
+ key := key32(i)
+ id := trie.StorageTrieID(root, common.BytesToHash(key), storageRoots[common.BytesToHash(key)])
+ trie, err := trie.New(id, db)
+ if err != nil {
+ panic(err)
+ }
+ storageTries[common.BytesToHash(key)] = trie
+ }
+ return accTrie, entries, storageTries, storageEntries
+}
+
+// makeStorageTrieWithSeed fills a storage trie with n items, returning the
+// root, the dirty trie nodes (not yet committed to the database) and the
+// sorted entries. The seed can be used to ensure that tries are unique.
+func makeStorageTrieWithSeed(owner common.Hash, n, seed uint64, db *triedb.Database) (common.Hash, *trienode.NodeSet, []*kv) {
+ trie, _ := trie.New(trie.StorageTrieID(types.EmptyRootHash, owner, types.EmptyRootHash), db)
+ var entries []*kv
+ for i := uint64(1); i <= n; i++ {
+ // store 'x' at slot 'x'
+ slotValue := key32(i + seed)
+ rlpSlotValue, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(slotValue[:]))
+
+ slotKey := key32(i)
+ key := crypto.Keccak256Hash(slotKey[:])
+
+ elem := &kv{key[:], rlpSlotValue}
+ trie.MustUpdate(elem.k, elem.v)
+ entries = append(entries, elem)
+ }
+ slices.SortFunc(entries, (*kv).cmp)
+ root, nodes, _ := trie.Commit(false)
+ return root, nodes, entries
+}
+
+// makeBoundaryStorageTrie constructs a storage trie. Instead of filling
+// storage slots normally, this function will fill a few slots whose keys
+// land on the range boundary hashes.
+func makeBoundaryStorageTrie(owner common.Hash, n int, db *triedb.Database) (common.Hash, *trienode.NodeSet, []*kv) {
+ var (
+ entries []*kv
+ boundaries []common.Hash
+ trie, _ = trie.New(trie.StorageTrieID(types.EmptyRootHash, owner, types.EmptyRootHash), db)
+ )
+ // Initialize boundaries
+ var next common.Hash
+ step := new(big.Int).Sub(
+ new(big.Int).Div(
+ new(big.Int).Exp(common.Big2, common.Big256, nil),
+ big.NewInt(int64(accountConcurrency)),
+ ), common.Big1,
+ )
+ for i := 0; i < accountConcurrency; i++ {
+ last := common.BigToHash(new(big.Int).Add(next.Big(), step))
+ if i == accountConcurrency-1 {
+ last = common.MaxHash
+ }
+ boundaries = append(boundaries, last)
+ next = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1))
+ }
+ // Fill boundary slots
+ for i := 0; i < len(boundaries); i++ {
+ key := boundaries[i]
+ val := []byte{0xde, 0xad, 0xbe, 0xef}
+
+ elem := &kv{key[:], val}
+ trie.MustUpdate(elem.k, elem.v)
+ entries = append(entries, elem)
+ }
+ // Fill other slots if required
+ for i := uint64(1); i <= uint64(n); i++ {
+ slotKey := key32(i)
+ key := crypto.Keccak256Hash(slotKey[:])
+
+ slotValue := key32(i)
+ rlpSlotValue, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(slotValue[:]))
+
+ elem := &kv{key[:], rlpSlotValue}
+ trie.MustUpdate(elem.k, elem.v)
+ entries = append(entries, elem)
+ }
+ slices.SortFunc(entries, (*kv).cmp)
+ root, nodes, _ := trie.Commit(false)
+ return root, nodes, entries
+}
+
+// makeUnevenStorageTrie constructs a storage trie with slots distributed
+// unevenly across different key ranges.
+func makeUnevenStorageTrie(owner common.Hash, slots int, db *triedb.Database) (common.Hash, *trienode.NodeSet, []*kv) {
+ var (
+ entries []*kv
+ tr, _ = trie.New(trie.StorageTrieID(types.EmptyRootHash, owner, types.EmptyRootHash), db)
+ chosen = make(map[byte]struct{})
+ )
+ for i := 0; i < 3; i++ {
+ var n int
+ for {
+ n = mrand.Intn(15) // the last range is set empty deliberately
+ if _, ok := chosen[byte(n)]; ok {
+ continue
+ }
+ chosen[byte(n)] = struct{}{}
+ break
+ }
+ for j := 0; j < slots/3; j++ {
+ key := append([]byte{byte(n)}, testutil.RandBytes(31)...)
+ val, _ := rlp.EncodeToBytes(testutil.RandBytes(32))
+
+ elem := &kv{key, val}
+ tr.MustUpdate(elem.k, elem.v)
+ entries = append(entries, elem)
+ }
+ }
+ slices.SortFunc(entries, (*kv).cmp)
+ root, nodes, _ := tr.Commit(false)
+ return root, nodes, entries
+}
+
+func verifyTrie(scheme string, db ethdb.KeyValueStore, root common.Hash, t *testing.T) {
+ t.Helper()
+ triedb := triedb.NewDatabase(rawdb.NewDatabase(db), newDbConfig(scheme))
+ accTrie, err := trie.New(trie.StateTrieID(root), triedb)
+ if err != nil {
+ t.Fatal(err)
+ }
+ accounts, slots := 0, 0
+ accIt := trie.NewIterator(accTrie.MustNodeIterator(nil))
+ for accIt.Next() {
+ var acc types.StateAccount
+ if err := rlp.DecodeBytes(accIt.Value, &acc); err != nil {
+ log.Crit("Invalid account encountered during snapshot creation", "err", err)
+ }
+ accounts++
+ if acc.Root != types.EmptyRootHash {
+ id := trie.StorageTrieID(root, common.BytesToHash(accIt.Key), acc.Root)
+ storeTrie, err := trie.NewStateTrie(id, triedb)
+ if err != nil {
+ t.Fatal(err)
+ }
+ storeIt := trie.NewIterator(storeTrie.MustNodeIterator(nil))
+ for storeIt.Next() {
+ slots++
+ }
+ if err := storeIt.Err; err != nil {
+ t.Fatal(err)
+ }
+ }
+ }
+ if err := accIt.Err; err != nil {
+ t.Fatal(err)
+ }
+ t.Logf("accounts: %d, slots: %d", accounts, slots)
+}
+
+// TestSyncAccountPerformance tests how efficient the snap algorithm is at
+// minimizing state healing.
+func TestSyncAccountPerformance(t *testing.T) {
+ //t.Parallel()
+
+ testSyncAccountPerformance(t, rawdb.HashScheme)
+ testSyncAccountPerformance(t, rawdb.PathScheme)
+}
+
+func testSyncAccountPerformance(t *testing.T, scheme string) {
+ // Set the account concurrency to 1. This _should_ result in the
+ // range root becoming correct, and no healing should be needed.
+ defer func(old int) { accountConcurrency = old }(accountConcurrency)
+ accountConcurrency = 1
+
+ var (
+ once sync.Once
+ cancel = make(chan struct{})
+ term = func() {
+ once.Do(func() {
+ close(cancel)
+ })
+ }
+ )
+ nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(100, scheme)
+
+ mkSource := func(name string) *testPeer {
+ source := newTestPeer(name, t, term)
+ source.accountTrie = sourceAccountTrie.Copy()
+ source.accountValues = elems
+ return source
+ }
+ src := mkSource("source")
+ syncer := setupSyncer(nodeScheme, src)
+ if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
+ t.Fatalf("sync failed: %v", err)
+ }
+ verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
+ // The trie root will always be requested, since it is added when the snap
+ // sync cycle starts. When popping the queue, we do not look it up again.
+ // Doing so would bring this number down to zero in this artificial testcase,
+ // but only add extra IO for no reason in practice.
+ if have, want := src.nTrienodeRequests, 1; have != want {
+ fmt.Print(src.Stats())
+ t.Errorf("trie node heal requests wrong, want %d, have %d", want, have)
+ }
+}
+
+func TestSlotEstimation(t *testing.T) {
+ for i, tc := range []struct {
+ last common.Hash
+ count int
+ want uint64
+ }{
+ {
+ // Half the space
+ common.HexToHash("0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"),
+ 100,
+ 100,
+ },
+ {
+ // 1 / 16th
+ common.HexToHash("0x0fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"),
+ 100,
+ 1500,
+ },
+ {
+ // Bit more than 1 / 16th
+ common.HexToHash("0x1000000000000000000000000000000000000000000000000000000000000000"),
+ 100,
+ 1499,
+ },
+ {
+ // Almost everything
+ common.HexToHash("0xF000000000000000000000000000000000000000000000000000000000000000"),
+ 100,
+ 6,
+ },
+ {
+ // Almost nothing -- should lead to error
+ common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000001"),
+ 1,
+ 0,
+ },
+ {
+ // Nothing -- should lead to error
+ common.Hash{},
+ 100,
+ 0,
+ },
+ } {
+ have, _ := estimateRemainingSlots(tc.count, tc.last)
+ if want := tc.want; have != want {
+ t.Errorf("test %d: have %d want %d", i, have, want)
+ }
+ }
+}
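
The expected values in the table follow from a simple extrapolation: if count slots were observed in the hash range [0, last], the total population is estimated as count * 2^256 / last, leaving roughly count * 2^256 / last - count slots to fetch. A hedged sketch of that arithmetic (not the exact estimateRemainingSlots implementation, which additionally guards against overflow and an empty last):

package main

import (
	"fmt"
	"math/big"
)

func estimate(count int64, last *big.Int) *big.Int {
	total := new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil)
	total.Mul(total, big.NewInt(count))
	total.Div(total, last)
	return total.Sub(total, big.NewInt(count))
}

func main() {
	// 1/16th of the hash space covered by 100 slots => ~1500 remaining,
	// matching the second case in the test table above.
	last, _ := new(big.Int).SetString("0fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16)
	fmt.Println(estimate(100, last))
}
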
+
+func newDbConfig(scheme string) *triedb.Config {
+ if scheme == rawdb.HashScheme {
+ return &triedb.Config{}
+ }
+ return &triedb.Config{DBOverride: pathdb.Defaults.BackendConstructor}
+}
diff --git a/eth/protocols/snap/tracker.go b/eth/protocols/snap/tracker.go
new file mode 100644
index 0000000000..f4b6b779df
--- /dev/null
+++ b/eth/protocols/snap/tracker.go
@@ -0,0 +1,26 @@
+// Copyright 2021 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "time"
+
+ "github.com/ava-labs/libevm/p2p/tracker"
+)
+
+// requestTracker is a singleton tracker for request times.
+var requestTracker = tracker.New(ProtocolName, time.Minute)
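
This file only introduces the singleton; its consumers are not part of the diff. Assuming the libevm fork keeps upstream go-ethereum's Track/Fulfil API on p2p/tracker.Tracker, usage inside package snap would look roughly like the following (the helper names are illustrative, not part of this change):

package snap

// trackAccountRange would be called when a GetAccountRange request is sent:
// the tracker records the pending request and flags peers whose responses
// never arrive within the one-minute timeout configured above.
func trackAccountRange(peerID string, version uint, reqID uint64) {
	requestTracker.Track(peerID, version, GetAccountRangeMsg, AccountRangeMsg, reqID)
}

// fulfilAccountRange would be called when the matching response arrives.
func fulfilAccountRange(peerID string, version uint, reqID uint64) {
	requestTracker.Fulfil(peerID, version, AccountRangeMsg, reqID)
}
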
diff --git a/go.mod b/go.mod
index 7b23e789b0..afca046650 100644
--- a/go.mod
+++ b/go.mod
@@ -9,6 +9,7 @@ require (
github.com/crate-crypto/go-ipa v0.0.0-20231025140028-3c0104f4b233
github.com/davecgh/go-spew v1.1.1
github.com/deckarep/golang-set/v2 v2.1.0
+ github.com/google/gofuzz v1.2.0
github.com/gorilla/rpc v1.2.0
github.com/gorilla/websocket v1.5.0
github.com/hashicorp/go-bexpr v0.1.10
diff --git a/peer/network.go b/peer/network.go
index 7ce1203b45..8c5d6d95b7 100644
--- a/peer/network.go
+++ b/peer/network.go
@@ -70,6 +70,10 @@ type Network interface {
NewClient(protocol uint64, options ...p2p.ClientOption) *p2p.Client
// AddHandler registers a server handler for an application protocol
AddHandler(protocol uint64, handler p2p.Handler) error
+
+ // AddConnector adds a listener for Connected/Disconnected events.
+ // When a connector is added, it will be called for all currently connected peers.
+ AddConnector(validators.Connector) error
}
// network is an implementation of Network that processes message requests for
@@ -86,6 +90,7 @@ type network struct {
appRequestHandler message.RequestHandler // maps request type => handler
peers *peerTracker // tracking of peers & bandwidth
appStats stats.RequestHandlerStats // Provide request handler metrics
+ connectors []validators.Connector // List of connectors to notify on Connected/Disconnected events
// Set to true when Shutdown is called, after which all operations on this
// struct are no-ops.
@@ -112,6 +117,22 @@ func NewNetwork(p2pNetwork *p2p.Network, appSender common.AppSender, codec codec
}
}
+func (n *network) AddConnector(connector validators.Connector) error {
+ n.lock.Lock()
+ defer n.lock.Unlock()
+
+ n.connectors = append(n.connectors, connector)
+
+ // Notify the connector of all currently connected peers
+ for peerID, peer := range n.peers.peers {
+ if err := connector.Connected(context.Background(), peerID, peer.version); err != nil {
+ return fmt.Errorf("failed to notify connector of connected peer %s: %w", peerID, err)
+ }
+ }
+
+ return nil
+}
+
// SendAppRequestAny synchronously sends request to an arbitrary peer with a
// node version greater than or equal to minVersion. If minVersion is nil,
// the request will be sent to any peer regardless of their version.
@@ -359,6 +380,12 @@ func (n *network) Connected(ctx context.Context, nodeID ids.NodeID, nodeVersion
n.peers.Connected(nodeID, nodeVersion)
}
+ for _, connector := range n.connectors {
+ if err := connector.Connected(ctx, nodeID, nodeVersion); err != nil {
+ return fmt.Errorf("failed to notify connector of connected peer %s: %w", nodeID, err)
+ }
+ }
+
return n.p2pNetwork.Connected(ctx, nodeID, nodeVersion)
}
@@ -377,6 +404,12 @@ func (n *network) Disconnected(ctx context.Context, nodeID ids.NodeID) error {
n.peers.Disconnected(nodeID)
}
+ for _, connector := range n.connectors {
+ if err := connector.Disconnected(ctx, nodeID); err != nil {
+ return fmt.Errorf("failed to notify connector of disconnected peer %s: %w", nodeID, err)
+ }
+ }
+
return n.p2pNetwork.Disconnected(ctx, nodeID)
}
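
With AddConnector in place, any component can observe peer churn by implementing avalanchego's validators.Connector; the required signatures can be read off the calls above. A minimal sketch (logConnector is an illustrative name, not part of this change):

package connectordemo

import (
	"context"
	"fmt"

	"github.com/ava-labs/avalanchego/ids"
	"github.com/ava-labs/avalanchego/snow/validators"
	"github.com/ava-labs/avalanchego/version"
)

// logConnector just logs peer churn; network.AddConnector invokes Connected
// for every already-connected peer at registration time, then for each
// future connect/disconnect event.
type logConnector struct{}

var _ validators.Connector = (*logConnector)(nil)

func (logConnector) Connected(_ context.Context, nodeID ids.NodeID, v *version.Application) error {
	fmt.Println("connected:", nodeID, v)
	return nil
}

func (logConnector) Disconnected(_ context.Context, nodeID ids.NodeID) error {
	fmt.Println("disconnected:", nodeID)
	return nil
}
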
diff --git a/plugin/evm/config/config.go b/plugin/evm/config/config.go
index 864c990911..a3c28f083d 100644
--- a/plugin/evm/config/config.go
+++ b/plugin/evm/config/config.go
@@ -23,7 +23,7 @@ const (
defaultTrieDirtyCommitTarget = 20
defaultTriePrefetcherParallelism = 16
defaultSnapshotCache = 256
- defaultSyncableCommitInterval = defaultCommitInterval * 4
+ defaultSyncableCommitInterval = defaultCommitInterval
defaultSnapshotWait = false
defaultRpcGasCap = 50_000_000 // Default to 50M Gas Limit
defaultRpcTxFeeCap = 100 // 100 AVAX
@@ -220,6 +220,9 @@ type Config struct {
// RPC settings
HttpBodyLimit uint64 `json:"http-body-limit"`
+
+ // Experimental
+ StateSyncUseUpstream bool `json:"state-sync-use-upstream"`
}
// TxPoolConfig contains the transaction pool config to be passed
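
Since the new field rides on the standard JSON chain config, enabling the experimental path is a matter of setting the state-sync-use-upstream key. A minimal decoding sketch (the config type here is a one-field stand-in for the real struct):

package main

import (
	"encoding/json"
	"fmt"
)

type config struct {
	StateSyncUseUpstream bool `json:"state-sync-use-upstream"`
}

func main() {
	raw := []byte(`{"state-sync-use-upstream": true}`)
	var c config
	if err := json.Unmarshal(raw, &c); err != nil {
		panic(err)
	}
	fmt.Println(c.StateSyncUseUpstream) // true
}
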
diff --git a/plugin/evm/statesync/connector.go b/plugin/evm/statesync/connector.go
new file mode 100644
index 0000000000..813f1f87ea
--- /dev/null
+++ b/plugin/evm/statesync/connector.go
@@ -0,0 +1,172 @@
+// (c) 2025, Ava Labs, Inc. All rights reserved.
+// See the file LICENSE for licensing terms.
+
+package statesync
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/ava-labs/avalanchego/ids"
+ "github.com/ava-labs/avalanchego/network/p2p"
+ "github.com/ava-labs/avalanchego/snow/engine/common"
+ "github.com/ava-labs/avalanchego/snow/validators"
+ "github.com/ava-labs/avalanchego/utils/set"
+ "github.com/ava-labs/avalanchego/version"
+ "github.com/ava-labs/coreth/core"
+ "github.com/ava-labs/coreth/eth/protocols/snap"
+ "github.com/ava-labs/libevm/log"
+ ethp2p "github.com/ava-labs/libevm/p2p"
+ "github.com/ava-labs/libevm/p2p/enode"
+)
+
+const (
+ maxRetries = 5
+ failedRequestSleepInterval = 10 * time.Millisecond
+)
+
+var (
+ _ validators.Connector = (*Connector)(nil)
+ _ ethp2p.MsgReadWriter = (*outbound)(nil)
+)
+
+type Connector struct {
+ sync *snap.Syncer
+ sender *p2p.Client
+}
+
+func NewConnector(sync *snap.Syncer, sender *p2p.Client) *Connector {
+ return &Connector{sync: sync, sender: sender}
+}
+
+func (c *Connector) Connected(ctx context.Context, nodeID ids.NodeID, version *version.Application) error {
+ return c.sync.Register(NewOutboundPeer(nodeID, c.sync, c.sender))
+}
+
+func (c *Connector) Disconnected(ctx context.Context, nodeID ids.NodeID) error {
+ return c.sync.Unregister(nodeID.String())
+}
+
+type outbound struct {
+ peerID ids.NodeID
+ sync *snap.Syncer
+ sender *p2p.Client
+}
+
+func NewOutboundPeer(nodeID ids.NodeID, sync *snap.Syncer, sender *p2p.Client) *snap.Peer {
+ return snap.NewFakePeer(protocolVersion, nodeID.String(), &outbound{
+ peerID: nodeID,
+ sync: sync,
+ sender: sender,
+ })
+}
+
+// ReadMsg implements the ethp2p.MsgReadWriter interface.
+// It is not expected to be called in the used code path.
+func (o *outbound) ReadMsg() (ethp2p.Msg, error) { panic("not expected to be called") }
+
+func (o *outbound) WriteMsg(msg ethp2p.Msg) error {
+ bytes, err := toBytes(msg)
+ if err != nil {
+ return fmt.Errorf("failed to convert message to bytes: %w, expected: %d", err, msg.Size)
+ }
+
+ message := &retryableMessage{outbound: o, outBytes: bytes}
+ return message.send()
+}
+
+type retryableMessage struct {
+ outbound *outbound
+ outBytes []byte
+ retries int
+}
+
+func (r *retryableMessage) send() error {
+ nodeIDs := set.NewSet[ids.NodeID](1)
+ nodeIDs.Add(r.outbound.peerID)
+
+ return r.outbound.sender.AppRequest(context.Background(), nodeIDs, r.outBytes, r.handleResponse)
+}
+
+func (r *retryableMessage) handleResponse(
+ ctx context.Context,
+ nodeID ids.NodeID,
+ responseBytes []byte,
+ err error,
+) {
+ if err == nil { // Handle successful response
+ log.Debug("statesync AppRequest response", "nodeID", nodeID, "responseBytes", len(responseBytes))
+ p := snap.NewFakePeer(protocolVersion, nodeID.String(), &rw{readBytes: responseBytes})
+ if err := snap.HandleMessage(r.outbound, p); err != nil {
+ log.Warn("failed to handle response", "peer", nodeID, "err", err)
+ }
+ return
+ }
+
+ // Handle retry
+
+ log.Warn("got error response from peer", "peer", nodeID, "err", err)
+ // TODO: Is this the right way to check for AppError?
+ // Notably, errors.As expects a pointer to a type that implements the
+ // error interface, and it is *AppError, not AppError, that implements error.
+ appErr, ok := err.(*common.AppError)
+ if !ok {
+ log.Warn("unexpected error type", "err", err)
+ return
+ }
+ if appErr.Code != common.ErrTimeout.Code {
+ log.Debug("dropping non-timeout error", "peer", nodeID, "err", err)
+ return // only retry on timeout
+ }
+ if r.retries >= maxRetries {
+ log.Warn("reached max retries", "peer", nodeID)
+ return
+ }
+ r.retries++
+ log.Debug("retrying request", "peer", nodeID, "retries", r.retries)
+ time.Sleep(failedRequestSleepInterval)
+ if err := r.send(); err != nil {
+ log.Warn("failed to retry request, dropping", "peer", nodeID, "err", err)
+ }
+}
+
+func (o *outbound) Chain() *core.BlockChain { panic("not expected to be called") }
+func (o *outbound) RunPeer(*snap.Peer, snap.Handler) error { panic("not expected to be called") }
+func (o *outbound) PeerInfo(id enode.ID) interface{} { panic("not expected to be called") }
+
+func (o *outbound) Handle(peer *snap.Peer, packet snap.Packet) error {
+ d := &Downloader{SnapSyncer: o.sync}
+ return d.DeliverSnapPacket(peer, packet)
+}
+
+// Downloader is copied from eth/downloader/downloader.go
+type Downloader struct {
+ SnapSyncer *snap.Syncer
+}
+
+// DeliverSnapPacket is invoked from a peer's message handler when it transmits a
+// data packet for the local node to consume.
+func (d *Downloader) DeliverSnapPacket(peer *snap.Peer, packet snap.Packet) error {
+ switch packet := packet.(type) {
+ case *snap.AccountRangePacket:
+ hashes, accounts, err := packet.Unpack()
+ if err != nil {
+ return err
+ }
+ return d.SnapSyncer.OnAccounts(peer, packet.ID, hashes, accounts, packet.Proof)
+
+ case *snap.StorageRangesPacket:
+ hashset, slotset := packet.Unpack()
+ return d.SnapSyncer.OnStorage(peer, packet.ID, hashset, slotset, packet.Proof)
+
+ case *snap.ByteCodesPacket:
+ return d.SnapSyncer.OnByteCodes(peer, packet.ID, packet.Codes)
+
+ case *snap.TrieNodesPacket:
+ return d.SnapSyncer.OnTrieNodes(peer, packet.ID, packet.Nodes)
+
+ default:
+ return fmt.Errorf("unexpected snap packet type: %T", packet)
+ }
+}
diff --git a/plugin/evm/statesync/handler.go b/plugin/evm/statesync/handler.go
new file mode 100644
index 0000000000..cffe5242f2
--- /dev/null
+++ b/plugin/evm/statesync/handler.go
@@ -0,0 +1,128 @@
+// (c) 2025, Ava Labs, Inc. All rights reserved.
+// See the file LICENSE for licensing terms.
+
+package statesync
+
+import (
+ "bytes"
+ "context"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "time"
+
+ "github.com/ava-labs/avalanchego/ids"
+ "github.com/ava-labs/avalanchego/network/p2p"
+ "github.com/ava-labs/avalanchego/snow/engine/common"
+ "github.com/ava-labs/avalanchego/utils/wrappers"
+ "github.com/ava-labs/coreth/core"
+ "github.com/ava-labs/coreth/eth/protocols/snap"
+ "github.com/ava-labs/libevm/log"
+ ethp2p "github.com/ava-labs/libevm/p2p"
+ "github.com/ava-labs/libevm/p2p/enode"
+)
+
+var (
+ _ p2p.Handler = (*Handler)(nil)
+ _ snap.Backend = (*Handler)(nil)
+ _ ethp2p.MsgReadWriter = (*rw)(nil)
+)
+
+const (
+ ProtocolID = 128 // ID for the state sync handler, leaving earlier IDs reserved
+ ErrCodeSnapHandlerFailed = 1
+
+ protocolVersion = 0
+)
+
+type Handler struct {
+ chain *core.BlockChain
+}
+
+func NewHandler(chain *core.BlockChain) *Handler {
+ return &Handler{chain: chain}
+}
+
+func (h *Handler) AppRequest(
+ ctx context.Context,
+ nodeID ids.NodeID,
+ deadline time.Time,
+ requestBytes []byte,
+) ([]byte, *common.AppError) {
+ start := time.Now()
+ rw := &rw{readBytes: requestBytes}
+ p := snap.NewFakePeer(protocolVersion, nodeID.String(), rw)
+ err := snap.HandleMessage(h, p)
+ if err != nil {
+ log.Debug("statesync AppRequest err", "nodeID", nodeID, "err", err)
+ return nil, &common.AppError{
+ Code: ErrCodeSnapHandlerFailed,
+ Message: err.Error(),
+ }
+ }
+ log.Debug("statesync AppRequest response", "nodeID", nodeID, "responseBytes", len(rw.writeBytes), "duration", time.Since(start))
+ return rw.writeBytes, nil
+}
+
+// AppGossip implements p2p.Handler.
+// It is implemented as a no-op as gossip is not used in the state sync protocol.
+func (h *Handler) AppGossip(ctx context.Context, nodeID ids.NodeID, gossipBytes []byte) {}
+
+// Chain implements snap.Backend.
+func (h *Handler) Chain() *core.BlockChain { return h.chain }
+
+// Handle implements snap.Backend.
+// It is implemented as a no-op as the main handling is done in AppRequest.
+func (h *Handler) Handle(*snap.Peer, snap.Packet) error { return nil }
+
+// RunPeer implements snap.Handler.
+// It is not expected to be called in the used code path.
+func (h *Handler) RunPeer(*snap.Peer, snap.Handler) error { panic("calling not expected") }
+
+// PeerInfo implements snap.Handler.
+// It is not expected to be called in the used code path.
+func (h *Handler) PeerInfo(id enode.ID) interface{} { panic("calling not expected") }
+
+// rw is a helper struct that implements ethp2p.MsgReadWriter.
+type rw struct {
+ readBytes []byte
+ writeBytes []byte
+}
+
+// ReadMsg implements ethp2p.MsgReadWriter.
+// It is expected to be called exactly once, immediately after the request is received.
+func (rw *rw) ReadMsg() (ethp2p.Msg, error) {
+ return fromBytes(rw.readBytes)
+}
+
+// WriteMsg implements ethp2p.MsgReadWriter.
+// It is expected to be called exactly once, immediately after the response is prepared.
+func (rw *rw) WriteMsg(msg ethp2p.Msg) error {
+ var err error
+ rw.writeBytes, err = toBytes(msg)
+ return err
+}
+
+func fromBytes(msgBytes []byte) (ethp2p.Msg, error) {
+ if len(msgBytes) < wrappers.LongLen {
+ return ethp2p.Msg{}, fmt.Errorf("bytes too short: %d", len(msgBytes))
+ }
+ code := binary.BigEndian.Uint64(msgBytes)
+ return ethp2p.Msg{
+ Code: code,
+ Size: uint32(len(msgBytes) - wrappers.LongLen),
+ Payload: bytes.NewReader(msgBytes[wrappers.LongLen:]),
+ ReceivedAt: time.Now(),
+ }, nil
+}
+
+func toBytes(msg ethp2p.Msg) ([]byte, error) {
+ bytes := make([]byte, msg.Size+wrappers.LongLen)
+ binary.BigEndian.PutUint64(bytes, msg.Code)
+ n, err := msg.Payload.Read(bytes[wrappers.LongLen:])
+ if n == int(msg.Size) && errors.Is(err, io.EOF) {
+ err = nil
+ }
+ return bytes, err
+}
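
The framing is simple enough to pin down with a package-internal test: toBytes prefixes the payload with an 8-byte big-endian message code, and fromBytes strips it back off. A hypothetical sketch (TestMsgRoundTrip is not part of this change):

package statesync

import (
	"bytes"
	"testing"

	ethp2p "github.com/ava-labs/libevm/p2p"
)

func TestMsgRoundTrip(t *testing.T) {
	payload := []byte{0xde, 0xad, 0xbe, 0xef}
	in := ethp2p.Msg{Code: 0x42, Size: uint32(len(payload)), Payload: bytes.NewReader(payload)}

	raw, err := toBytes(in)
	if err != nil {
		t.Fatal(err)
	}
	out, err := fromBytes(raw)
	if err != nil {
		t.Fatal(err)
	}
	// The code and size survive the round trip; the payload reader is
	// positioned just past the 8-byte prefix.
	if out.Code != in.Code || out.Size != in.Size {
		t.Fatalf("round trip mismatch: code %d, size %d", out.Code, out.Size)
	}
}
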
diff --git a/plugin/evm/syncervm_client.go b/plugin/evm/syncervm_client.go
index 2807b6b0da..8a3438108e 100644
--- a/plugin/evm/syncervm_client.go
+++ b/plugin/evm/syncervm_client.go
@@ -17,8 +17,11 @@ import (
"github.com/ava-labs/coreth/core/rawdb"
"github.com/ava-labs/coreth/core/state/snapshot"
"github.com/ava-labs/coreth/eth"
+ "github.com/ava-labs/coreth/eth/protocols/snap"
"github.com/ava-labs/coreth/params"
+ "github.com/ava-labs/coreth/peer"
"github.com/ava-labs/coreth/plugin/evm/message"
+ ethstatesync "github.com/ava-labs/coreth/plugin/evm/statesync"
syncclient "github.com/ava-labs/coreth/sync/client"
"github.com/ava-labs/coreth/sync/statesync"
"github.com/ava-labs/libevm/common"
@@ -57,6 +60,12 @@ type stateSyncClientConfig struct {
client syncclient.Client
toEngine chan<- commonEng.Message
+
+ // Configuration for the experimental upstream state syncer.
+ useUpstream bool
+ network peer.Network
+ appSender commonEng.AppSender
+ stateSyncNodes []ids.NodeID
}
type stateSyncerClient struct {
@@ -151,13 +160,43 @@ func (client *stateSyncerClient) stateSync(ctx context.Context) error {
// Sync the EVM trie and then the atomic trie. These steps could be done
// in parallel or in the opposite order. Keeping them serial for simplicity for now.
- if err := client.syncStateTrie(ctx); err != nil {
- return err
+ if client.useUpstream {
+ log.Warn("Using upstream state syncer (untested)")
+ syncer := snap.NewSyncer(client.chaindb, rawdb.HashScheme)
+ p2pClient := client.network.NewClient(ethstatesync.ProtocolID)
+ if len(client.stateSyncNodes) > 0 {
+ for _, nodeID := range client.stateSyncNodes {
+ syncer.Register(ethstatesync.NewOutboundPeer(nodeID, syncer, p2pClient))
+ }
+ } else {
+ client.network.AddConnector(ethstatesync.NewConnector(syncer, p2pClient))
+ }
+ if err := syncer.Sync(client.syncSummary.BlockRoot, convertReadOnlyToBidirectional(ctx.Done())); err != nil {
+ return err
+ }
+ log.Info("Upstream state syncer completed")
+ } else {
+ if err := client.syncStateTrie(ctx); err != nil {
+ return err
+ }
}
return client.syncAtomicTrie(ctx)
}
+func convertReadOnlyToBidirectional[T any](readOnly <-chan T) chan T {
+ bidirectional := make(chan T)
+
+ go func() {
+ defer close(bidirectional)
+ for value := range readOnly {
+ bidirectional <- value
+ }
+ }()
+
+ return bidirectional
+}
+
// acceptSyncSummary returns true if sync will be performed and launches the state sync process
// in a goroutine.
func (client *stateSyncerClient) acceptSyncSummary(proposedSummary message.SyncSummary) (block.StateSyncMode, error) {
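
The adapter exists because snap.Syncer.Sync takes a bidirectional chan struct{} for cancellation while ctx.Done() is receive-only. A self-contained sketch of the behaviour; note that when the source channel is closed (as ctx.Done() is on cancellation) no values are forwarded, and the returned channel simply closes:

package main

import (
	"context"
	"fmt"
)

func convertReadOnlyToBidirectional[T any](readOnly <-chan T) chan T {
	bidirectional := make(chan T)
	go func() {
		defer close(bidirectional)
		for value := range readOnly {
			bidirectional <- value
		}
	}()
	return bidirectional
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	done := convertReadOnlyToBidirectional(ctx.Done())
	cancel()
	<-done // unblocks once the context is cancelled and the channel closes
	fmt.Println("cancelled")
}
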
diff --git a/plugin/evm/vm.go b/plugin/evm/vm.go
index 99374fc3ab..cadafd80a1 100644
--- a/plugin/evm/vm.go
+++ b/plugin/evm/vm.go
@@ -44,6 +44,7 @@ import (
"github.com/ava-labs/coreth/plugin/evm/atomic"
"github.com/ava-labs/coreth/plugin/evm/config"
"github.com/ava-labs/coreth/plugin/evm/message"
+ "github.com/ava-labs/coreth/plugin/evm/statesync"
"github.com/ava-labs/coreth/triedb/hashdb"
"github.com/ava-labs/coreth/utils"
"github.com/ava-labs/libevm/triedb"
@@ -621,6 +622,9 @@ func (vm *VM) Initialize(
warpHandler := acp118.NewCachedHandler(meteredCache, vm.warpBackend, vm.ctx.WarpSigner)
vm.Network.AddHandler(p2p.SignatureRequestHandlerID, warpHandler)
+ // Serve upstream snap-sync requests over the app protocol.
+ vm.Network.AddHandler(statesync.ProtocolID, statesync.NewHandler(vm.blockChain))
+
vm.setAppRequestHandlers()
vm.StateSyncServer = NewStateSyncServer(&stateSyncServerConfig{
@@ -726,6 +730,10 @@ func (vm *VM) initializeStateSyncClient(lastAcceptedHeight uint64) error {
db: vm.versiondb,
atomicBackend: vm.atomicBackend,
toEngine: vm.toEngine,
+ network: vm.Network,
+ appSender: vm.p2pSender,
+ stateSyncNodes: stateSyncIDs,
+ useUpstream: vm.config.StateSyncUseUpstream,
})
// If StateSync is disabled, clear any ongoing summary so that we will not attempt to resume
diff --git a/scripts/eth-allowed-packages.txt b/scripts/eth-allowed-packages.txt
index 6f69e52bcd..8412e8d94e 100644
--- a/scripts/eth-allowed-packages.txt
+++ b/scripts/eth-allowed-packages.txt
@@ -27,6 +27,7 @@
"github.com/ava-labs/libevm/libevm"
"github.com/ava-labs/libevm/libevm/stateconf"
"github.com/ava-labs/libevm/log"
+"github.com/ava-labs/libevm/p2p/enode"
"github.com/ava-labs/libevm/params"
"github.com/ava-labs/libevm/rlp"
"github.com/ava-labs/libevm/trie"
diff --git a/scripts/lint_allowed_eth_imports.sh b/scripts/lint_allowed_eth_imports.sh
index fcbfe7c6f4..be4af2bdf9 100755
--- a/scripts/lint_allowed_eth_imports.sh
+++ b/scripts/lint_allowed_eth_imports.sh
@@ -11,7 +11,7 @@ set -o pipefail
# 4. Print out the difference between the search results and the list of specified allowed package imports from libevm.
libevm_regexp='"github.com/ava-labs/libevm/.*"'
allow_named_imports='eth\w\+ "'
-extra_imports=$(grep -r --include='*.go' --exclude=mocks.go "${libevm_regexp}" -h | grep -v "${allow_named_imports}" | grep -o "${libevm_regexp}" | sort -u | comm -23 - ./scripts/eth-allowed-packages.txt)
+extra_imports=$(find . -type f -name '*.go' ! -path './eth/protocols/snap/*' ! -name 'mocks.go' -print0 | xargs -0 grep "${libevm_regexp}" -h | grep -v "${allow_named_imports}" | grep -o "${libevm_regexp}" | sort -u | comm -23 - ./scripts/eth-allowed-packages.txt)
if [ -n "${extra_imports}" ]; then
echo "new ethereum imports should be added to ./scripts/eth-allowed-packages.txt to prevent accidental imports:"
echo "${extra_imports}"
diff --git a/sync/handlers/leafs_request.go b/sync/handlers/leafs_request.go
index c9df10ef10..1b628a8327 100644
--- a/sync/handlers/leafs_request.go
+++ b/sync/handlers/leafs_request.go
@@ -185,6 +185,30 @@ type responseBuilder struct {
stats stats.LeafsRequestHandlerStats
}
+func NewResponseBuilder(
+ request *message.LeafsRequest,
+ response *message.LeafsResponse,
+ t *trie.Trie,
+ snap *snapshot.Tree,
+ keyLength int,
+ limit uint16,
+ stats stats.LeafsRequestHandlerStats,
+) *responseBuilder {
+ return &responseBuilder{
+ request: request,
+ response: response,
+ t: t,
+ snap: snap,
+ keyLength: keyLength,
+ limit: limit,
+ stats: stats,
+ }
+}
+
+func (rb *responseBuilder) HandleRequest(ctx context.Context) error {
+ return rb.handleRequest(ctx)
+}
+
func (rb *responseBuilder) handleRequest(ctx context.Context) error {
// Read from snapshot if a [snapshot.Tree] was provided in initialization
if rb.snap != nil {
@@ -223,6 +247,8 @@ func (rb *responseBuilder) handleRequest(ctx context.Context) error {
rb.stats.IncProofError()
return err
}
+ rb.response.More = true // set to signal a proof is needed
+
return nil
}
@@ -271,6 +297,7 @@ func (rb *responseBuilder) fillFromSnapshot(ctx context.Context) (bool, error) {
rb.stats.IncSnapshotReadSuccess()
return true, nil
}
+ rb.response.More = true // set to signal a proof is needed
rb.response.ProofVals, err = iterateVals(proof)
if err != nil {
rb.stats.IncProofError()
@@ -436,7 +463,7 @@ func (rb *responseBuilder) fillFromTrie(ctx context.Context, end []byte) (bool,
more := false
for it.Next() {
// if we're at the end, break this loop
- if len(end) > 0 && bytes.Compare(it.Key, end) > 0 {
+ if len(rb.response.Keys) > 0 && len(end) > 0 && bytes.Compare(it.Key, end) > 0 {
more = true
break
}
@@ -487,7 +514,7 @@ func (rb *responseBuilder) readLeafsFromSnapshot(ctx context.Context) ([][]byte,
defer snapIt.Release()
for snapIt.Next() {
// if we're at the end, break this loop
- if len(rb.request.End) > 0 && bytes.Compare(snapIt.Key(), rb.request.End) > 0 {
+ if len(keys) > 0 && len(rb.request.End) > 0 && bytes.Compare(snapIt.Key(), rb.request.End) > 0 {
break
}
// If we've returned enough data or run out of time, set the more flag and exit
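
On the two iterator guards added above: requiring at least one collected key before honoring the end bound means a range that starts beyond end still returns its first leaf, giving the requester a non-empty, provable response rather than an empty one. A distilled sketch of the guard (collect stands in for the trie/snapshot iteration loops):

package main

import (
	"bytes"
	"fmt"
)

func collect(keys [][]byte, end []byte) [][]byte {
	var out [][]byte
	for _, k := range keys {
		// Only stop at `end` once something has been collected.
		if len(out) > 0 && len(end) > 0 && bytes.Compare(k, end) > 0 {
			break
		}
		out = append(out, k)
	}
	return out
}

func main() {
	keys := [][]byte{{0x05}, {0x07}}
	// The first key exceeds end but is still returned.
	fmt.Println(len(collect(keys, []byte{0x03}))) // 1
}
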