diff --git a/cmd/geth/misccmd.go b/cmd/geth/misccmd.go
index 967df2ada..b347d31d9 100644
--- a/cmd/geth/misccmd.go
+++ b/cmd/geth/misccmd.go
@@ -25,7 +25,6 @@ import (
"github.com/ethereum/go-ethereum/cmd/utils"
"github.com/ethereum/go-ethereum/consensus/ethash"
- "github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/params"
"gopkg.in/urfave/cli.v1"
)
@@ -143,7 +142,6 @@ func version(ctx *cli.Context) error {
fmt.Println("Git Commit Date:", gitDate)
}
fmt.Println("Architecture:", runtime.GOARCH)
- fmt.Println("Protocol Versions:", eth.ProtocolVersions)
fmt.Println("Go Version:", runtime.Version())
fmt.Println("Operating System:", runtime.GOOS)
fmt.Printf("GOPATH=%s\n", os.Getenv("GOPATH"))
diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go
index 051bdd630..0b1695d0a 100644
--- a/cmd/utils/flags.go
+++ b/cmd/utils/flags.go
@@ -187,7 +187,7 @@ var (
defaultSyncMode = eth.DefaultConfig.SyncMode
SyncModeFlag = TextMarshalerFlag{
Name: "syncmode",
- Usage: `Blockchain sync mode ("fast", "full", or "light")`,
+ Usage: `Blockchain sync mode ("fast", "full", "snap", or "light")`,
Value: &defaultSyncMode,
}
GCModeFlag = cli.StringFlag{
@@ -1555,8 +1555,14 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) {
cfg.SnapshotCache = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheSnapshotFlag.Name) / 100
}
if !ctx.GlobalIsSet(SnapshotFlag.Name) {
- cfg.TrieCleanCache += cfg.SnapshotCache
- cfg.SnapshotCache = 0 // Disabled
+ // If snap-sync is requested, this flag is also required
+ if cfg.SyncMode == downloader.SnapSync {
+ log.Info("Snap sync requested, enabling --snapshot")
+ ctx.Set(SnapshotFlag.Name, "true")
+ } else {
+ cfg.TrieCleanCache += cfg.SnapshotCache
+ cfg.SnapshotCache = 0 // Disabled
+ }
}
if ctx.GlobalIsSet(DocRootFlag.Name) {
cfg.DocRoot = ctx.GlobalString(DocRootFlag.Name)
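
The auto-enable logic above follows a common implied-flag pattern: an explicit user choice always wins, otherwise the dependent flag is switched on when the requesting feature needs it. A minimal standalone sketch with hypothetical names (syncConfig, applySnapshotDefaults) rather than the real cli plumbing:

```go
package main

import "fmt"

type syncMode int

const (
	fullSync syncMode = iota
	fastSync
	snapSync
)

type syncConfig struct {
	mode          syncMode
	snapshot      bool // value of --snapshot
	snapshotIsSet bool // whether the user passed --snapshot at all
	trieCleanMB   int
	snapshotMB    int
}

func applySnapshotDefaults(cfg *syncConfig) {
	if cfg.snapshotIsSet {
		return // explicit user choice always wins
	}
	if cfg.mode == snapSync {
		// Snap sync cannot work without the snapshot tree, so imply the flag.
		fmt.Println("Snap sync requested, enabling --snapshot")
		cfg.snapshot = true
		return
	}
	// Snapshots disabled: hand their cache allowance to the clean trie cache.
	cfg.trieCleanMB += cfg.snapshotMB
	cfg.snapshotMB = 0
}

func main() {
	cfg := &syncConfig{mode: snapSync, trieCleanMB: 154, snapshotMB: 102}
	applySnapshotDefaults(cfg)
	fmt.Printf("snapshot=%v trie=%dMB snap=%dMB\n", cfg.snapshot, cfg.trieCleanMB, cfg.snapshotMB)
}
```
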
@@ -1585,16 +1591,15 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) {
cfg.RPCTxFeeCap = ctx.GlobalFloat64(RPCGlobalTxFeeCapFlag.Name)
}
if ctx.GlobalIsSet(NoDiscoverFlag.Name) {
- cfg.DiscoveryURLs = []string{}
+ cfg.EthDiscoveryURLs, cfg.SnapDiscoveryURLs = []string{}, []string{}
} else if ctx.GlobalIsSet(DNSDiscoveryFlag.Name) {
urls := ctx.GlobalString(DNSDiscoveryFlag.Name)
if urls == "" {
- cfg.DiscoveryURLs = []string{}
+ cfg.EthDiscoveryURLs = []string{}
} else {
- cfg.DiscoveryURLs = SplitAndTrim(urls)
+ cfg.EthDiscoveryURLs = SplitAndTrim(urls)
}
}
-
// Override any default configs for hard coded networks.
switch {
case ctx.GlobalBool(LegacyTestnetFlag.Name) || ctx.GlobalBool(RopstenFlag.Name):
@@ -1676,16 +1681,20 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) {
// SetDNSDiscoveryDefaults configures DNS discovery with the given URL if
// no URLs are set.
func SetDNSDiscoveryDefaults(cfg *eth.Config, genesis common.Hash) {
- if cfg.DiscoveryURLs != nil {
+ if cfg.EthDiscoveryURLs != nil {
return // already set through flags/config
}
-
protocol := "all"
if cfg.SyncMode == downloader.LightSync {
protocol = "les"
}
if url := params.KnownDNSNetwork(genesis, protocol); url != "" {
- cfg.DiscoveryURLs = []string{url}
+ cfg.EthDiscoveryURLs = []string{url}
+ }
+ if cfg.SyncMode == downloader.SnapSync {
+ if url := params.KnownDNSNetwork(genesis, "snap"); url != "" {
+ cfg.SnapDiscoveryURLs = []string{url}
+ }
}
}
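
For reference, the defaulting behaviour of SetDNSDiscoveryDefaults reduces to the sketch below: the `eth` URL is always filled in (or `les` for light clients), and a `snap` URL is additionally looked up when snap sync is selected. knownDNSNetwork is a stand-in for params.KnownDNSNetwork and the URLs are placeholders:

```go
package main

import "fmt"

var knownDNSNetwork = map[string]string{
	"all":  "enrtree://.../all.example.org",
	"les":  "enrtree://.../les.example.org",
	"snap": "enrtree://.../snap.example.org",
}

type discoveryConfig struct {
	lightSync, snapSync bool
	ethURLs, snapURLs   []string
}

func setDiscoveryDefaults(cfg *discoveryConfig) {
	if cfg.ethURLs != nil {
		return // already set through flags/config
	}
	protocol := "all"
	if cfg.lightSync {
		protocol = "les"
	}
	if url, ok := knownDNSNetwork[protocol]; ok {
		cfg.ethURLs = []string{url}
	}
	if cfg.snapSync {
		if url, ok := knownDNSNetwork["snap"]; ok {
			cfg.snapURLs = []string{url}
		}
	}
}

func main() {
	cfg := &discoveryConfig{snapSync: true}
	setDiscoveryDefaults(cfg)
	fmt.Println(cfg.ethURLs, cfg.snapURLs)
}
```
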
diff --git a/core/blockchain.go b/core/blockchain.go
index bc1db49f3..d9505dcf6 100644
--- a/core/blockchain.go
+++ b/core/blockchain.go
@@ -659,12 +659,8 @@ func (bc *BlockChain) CurrentBlock() *types.Block {
return bc.currentBlock.Load().(*types.Block)
}
-// Snapshot returns the blockchain snapshot tree. This method is mainly used for
-// testing, to make it possible to verify the snapshot after execution.
-//
-// Warning: There are no guarantees about the safety of using the returned 'snap' if the
-// blockchain is simultaneously importing blocks, so take care.
-func (bc *BlockChain) Snapshot() *snapshot.Tree {
+// Snapshots returns the blockchain snapshot tree.
+func (bc *BlockChain) Snapshots() *snapshot.Tree {
return bc.snaps
}
diff --git a/core/blockchain_snapshot_test.go b/core/blockchain_snapshot_test.go
index e8d3b2470..f35dae167 100644
--- a/core/blockchain_snapshot_test.go
+++ b/core/blockchain_snapshot_test.go
@@ -751,7 +751,7 @@ func testSnapshot(t *testing.T, tt *snapshotTest) {
t.Fatalf("Failed to recreate chain: %v", err)
}
chain.InsertChain(newBlocks)
- chain.Snapshot().Cap(newBlocks[len(newBlocks)-1].Root(), 0)
+ chain.Snapshots().Cap(newBlocks[len(newBlocks)-1].Root(), 0)
// Simulate the blockchain crash
// Don't call chain.Stop here, so that no snapshot
diff --git a/core/forkid/forkid.go b/core/forkid/forkid.go
index c43285861..1bf340682 100644
--- a/core/forkid/forkid.go
+++ b/core/forkid/forkid.go
@@ -84,6 +84,15 @@ func NewID(config *params.ChainConfig, genesis common.Hash, head uint64) ID {
return ID{Hash: checksumToBytes(hash), Next: next}
}
+// NewIDWithChain calculates the Ethereum fork ID from an existing chain instance.
+func NewIDWithChain(chain Blockchain) ID {
+ return NewID(
+ chain.Config(),
+ chain.Genesis().Hash(),
+ chain.CurrentHeader().Number.Uint64(),
+ )
+}
+
// NewFilter creates a filter that returns if a fork ID should be rejected or not
// based on the local chain's status.
func NewFilter(chain Blockchain) Filter {
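
A hedged usage sketch for the new helper, assuming forkid.Blockchain exposes exactly the three methods the helper touches (Config, Genesis, CurrentHeader). The mock values make the resulting ID illustrative only; production code would pass a real *core.BlockChain:

```go
package main

import (
	"fmt"
	"math/big"

	"github.com/ethereum/go-ethereum/core/forkid"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/params"
)

// mockChain satisfies the three methods NewIDWithChain uses.
type mockChain struct{ head uint64 }

func (c *mockChain) Config() *params.ChainConfig { return params.MainnetChainConfig }
func (c *mockChain) Genesis() *types.Block {
	return types.NewBlockWithHeader(&types.Header{Number: big.NewInt(0)})
}
func (c *mockChain) CurrentHeader() *types.Header {
	return &types.Header{Number: new(big.Int).SetUint64(c.head)}
}

func main() {
	id := forkid.NewIDWithChain(&mockChain{head: 11_000_000})
	fmt.Printf("fork id: %x next: %d\n", id.Hash, id.Next)
}
```
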
diff --git a/core/rawdb/accessors_snapshot.go b/core/rawdb/accessors_snapshot.go
index 5bd48ad5f..0a91d9353 100644
--- a/core/rawdb/accessors_snapshot.go
+++ b/core/rawdb/accessors_snapshot.go
@@ -175,3 +175,24 @@ func DeleteSnapshotRecoveryNumber(db ethdb.KeyValueWriter) {
log.Crit("Failed to remove snapshot recovery number", "err", err)
}
}
+
+// ReadSnapshotSyncStatus retrieves the serialized sync status saved at shutdown.
+func ReadSnapshotSyncStatus(db ethdb.KeyValueReader) []byte {
+ data, _ := db.Get(snapshotSyncStatusKey)
+ return data
+}
+
+// WriteSnapshotSyncStatus stores the serialized sync status to save at shutdown.
+func WriteSnapshotSyncStatus(db ethdb.KeyValueWriter, status []byte) {
+ if err := db.Put(snapshotSyncStatusKey, status); err != nil {
+ log.Crit("Failed to store snapshot sync status", "err", err)
+ }
+}
+
+// DeleteSnapshotSyncStatus deletes the serialized sync status saved at the last
+// shutdown.
+func DeleteSnapshotSyncStatus(db ethdb.KeyValueWriter) {
+ if err := db.Delete(snapshotSyncStatusKey); err != nil {
+ log.Crit("Failed to remove snapshot sync status", "err", err)
+ }
+}
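
A quick round trip against an in-memory database shows the intended lifecycle: write at shutdown, read back at the next startup, delete once sync completes. This uses the corrected ReadSnapshotSyncStatus spelling from above:

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/core/rawdb"
)

func main() {
	db := rawdb.NewMemoryDatabase()

	// Persist an opaque, serialized sync status at shutdown...
	rawdb.WriteSnapshotSyncStatus(db, []byte("snap-sync-progress-blob"))

	// ...and load it back on the next startup to resume where we left off.
	fmt.Printf("%s\n", rawdb.ReadSnapshotSyncStatus(db))

	// Once sync completes, the marker is no longer needed.
	rawdb.DeleteSnapshotSyncStatus(db)
	fmt.Println(rawdb.ReadSnapshotSyncStatus(db) == nil)
}
```
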
diff --git a/core/rawdb/schema.go b/core/rawdb/schema.go
index cff27b4bb..2aabfd3ba 100644
--- a/core/rawdb/schema.go
+++ b/core/rawdb/schema.go
@@ -57,6 +57,9 @@ var (
// snapshotRecoveryKey tracks the snapshot recovery marker across restarts.
snapshotRecoveryKey = []byte("SnapshotRecovery")
+ // snapshotSyncStatusKey tracks the snapshot sync status across restarts.
+ snapshotSyncStatusKey = []byte("SnapshotSyncStatus")
+
// txIndexTailKey tracks the oldest block whose transactions have been indexed.
txIndexTailKey = []byte("TransactionIndexTail")
diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go
index 92c7640c4..4a2fa78d3 100644
--- a/core/state/snapshot/generate.go
+++ b/core/state/snapshot/generate.go
@@ -241,7 +241,7 @@ func (dl *diskLayer) generate(stats *generatorStats) {
if acc.Root != emptyRoot {
storeTrie, err := trie.NewSecure(acc.Root, dl.triedb)
if err != nil {
- log.Error("Generator failed to access storage trie", "accroot", dl.root, "acchash", common.BytesToHash(accIt.Key), "stroot", acc.Root, "err", err)
+ log.Error("Generator failed to access storage trie", "root", dl.root, "account", accountHash, "stroot", acc.Root, "err", err)
abort := <-dl.genAbort
abort <- stats
return
diff --git a/core/state/statedb.go b/core/state/statedb.go
index ed9a82379..a9d1de2e0 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -314,14 +314,19 @@ func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash {
return common.Hash{}
}
-// GetProof returns the MerkleProof for a given Account
-func (s *StateDB) GetProof(a common.Address) ([][]byte, error) {
+// GetProof returns the Merkle proof for a given account.
+func (s *StateDB) GetProof(addr common.Address) ([][]byte, error) {
+ return s.GetProofByHash(crypto.Keccak256Hash(addr.Bytes()))
+}
+
+// GetProofByHash returns the Merkle proof for the account keyed by the given address hash.
+func (s *StateDB) GetProofByHash(addrHash common.Hash) ([][]byte, error) {
var proof proofList
- err := s.trie.Prove(crypto.Keccak256(a.Bytes()), 0, &proof)
+ err := s.trie.Prove(addrHash[:], 0, &proof)
return proof, err
}
-// GetStorageProof returns the StorageProof for given key
+// GetStorageProof returns the Merkle proof for a given storage slot.
func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, error) {
var proof proofList
trie := s.StorageTrie(a)
@@ -332,6 +337,17 @@ func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte,
return proof, err
}
+// GetStorageProofByHash returns the Merkle proof for a given storage slot.
+func (s *StateDB) GetStorageProofByHash(a common.Address, key common.Hash) ([][]byte, error) {
+ var proof proofList
+ trie := s.StorageTrie(a)
+ if trie == nil {
+ return proof, errors.New("storage trie for requested address does not exist")
+ }
+ err := trie.Prove(crypto.Keccak256(key.Bytes()), 0, &proof)
+ return proof, err
+}
+
// GetCommittedState retrieves a value from the given account's committed storage trie.
func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash {
stateObject := s.getStateObject(addr)
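
To see that the new hash-keyed variant agrees with the address-keyed one, here is a sketch mirroring the state construction used by the tests elsewhere in this diff; the address is arbitrary:

```go
package main

import (
	"fmt"
	"math/big"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/rawdb"
	"github.com/ethereum/go-ethereum/core/state"
	"github.com/ethereum/go-ethereum/crypto"
)

func main() {
	sdb := state.NewDatabase(rawdb.NewMemoryDatabase())
	statedb, _ := state.New(common.Hash{}, sdb, nil)

	addr := common.HexToAddress("0x0100000000000000000000000000000000000000")
	statedb.SetBalance(addr, big.NewInt(1))

	// Commit and reopen so the account is in the trie before proving.
	root, _ := statedb.Commit(false)
	statedb, _ = state.New(root, sdb, nil)

	byAddr, _ := statedb.GetProof(addr)
	byHash, _ := statedb.GetProofByHash(crypto.Keccak256Hash(addr.Bytes()))
	fmt.Println(len(byAddr), len(byHash)) // identical proofs
}
```
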
diff --git a/eth/api_backend.go b/eth/api_backend.go
index e7f676f17..2f7020475 100644
--- a/eth/api_backend.go
+++ b/eth/api_backend.go
@@ -56,7 +56,7 @@ func (b *EthAPIBackend) CurrentBlock() *types.Block {
}
func (b *EthAPIBackend) SetHead(number uint64) {
- b.eth.protocolManager.downloader.Cancel()
+ b.eth.handler.downloader.Cancel()
b.eth.blockchain.SetHead(number)
}
@@ -272,10 +272,6 @@ func (b *EthAPIBackend) Downloader() *downloader.Downloader {
return b.eth.Downloader()
}
-func (b *EthAPIBackend) ProtocolVersion() int {
- return b.eth.EthVersion()
-}
-
func (b *EthAPIBackend) SuggestPrice(ctx context.Context) (*big.Int, error) {
return b.gpo.SuggestPrice(ctx)
}
diff --git a/eth/api_test.go b/eth/api_test.go
index 2c9a2e54e..b44eed40b 100644
--- a/eth/api_test.go
+++ b/eth/api_test.go
@@ -57,6 +57,8 @@ func (h resultHash) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h resultHash) Less(i, j int) bool { return bytes.Compare(h[i].Bytes(), h[j].Bytes()) < 0 }
func TestAccountRange(t *testing.T) {
+ t.Parallel()
+
var (
statedb = state.NewDatabaseWithConfig(rawdb.NewMemoryDatabase(), nil)
state, _ = state.New(common.Hash{}, statedb, nil)
@@ -126,6 +128,8 @@ func TestAccountRange(t *testing.T) {
}
func TestEmptyAccountRange(t *testing.T) {
+ t.Parallel()
+
var (
statedb = state.NewDatabase(rawdb.NewMemoryDatabase())
state, _ = state.New(common.Hash{}, statedb, nil)
@@ -142,6 +146,8 @@ func TestEmptyAccountRange(t *testing.T) {
}
func TestStorageRangeAt(t *testing.T) {
+ t.Parallel()
+
// Create a state where account 0x010000... has a few storage entries.
var (
state, _ = state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil)
diff --git a/eth/backend.go b/eth/backend.go
index bb4275b92..987dee6d5 100644
--- a/eth/backend.go
+++ b/eth/backend.go
@@ -40,6 +40,8 @@ import (
"github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/eth/filters"
"github.com/ethereum/go-ethereum/eth/gasprice"
+ "github.com/ethereum/go-ethereum/eth/protocols/eth"
+ "github.com/ethereum/go-ethereum/eth/protocols/snap"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/internal/ethapi"
@@ -48,7 +50,6 @@ import (
"github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
- "github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/rpc"
@@ -59,10 +60,11 @@ type Ethereum struct {
config *Config
// Handlers
- txPool *core.TxPool
- blockchain *core.BlockChain
- protocolManager *ProtocolManager
- dialCandidates enode.Iterator
+ txPool *core.TxPool
+ blockchain *core.BlockChain
+ handler *handler
+ ethDialCandidates enode.Iterator
+ snapDialCandidates enode.Iterator
// DB interfaces
chainDb ethdb.Database // Block chain database
@@ -145,7 +147,7 @@ func New(stack *node.Node, config *Config) (*Ethereum, error) {
if bcVersion != nil {
dbVer = fmt.Sprintf("%d", *bcVersion)
}
- log.Info("Initialising Ethereum protocol", "versions", ProtocolVersions, "network", config.NetworkId, "dbversion", dbVer)
+ log.Info("Initialising Ethereum protocol", "network", config.NetworkId, "dbversion", dbVer)
if !config.SkipBcVersionCheck {
if bcVersion != nil && *bcVersion > core.BlockChainVersion {
@@ -196,7 +198,17 @@ func New(stack *node.Node, config *Config) (*Ethereum, error) {
if checkpoint == nil {
checkpoint = params.TrustedCheckpoints[genesisHash]
}
- if eth.protocolManager, err = NewProtocolManager(chainConfig, checkpoint, config.SyncMode, config.NetworkId, eth.eventMux, eth.txPool, eth.engine, eth.blockchain, chainDb, cacheLimit, config.Whitelist); err != nil {
+ if eth.handler, err = newHandler(&handlerConfig{
+ Database: chainDb,
+ Chain: eth.blockchain,
+ TxPool: eth.txPool,
+ Network: config.NetworkId,
+ Sync: config.SyncMode,
+ BloomCache: uint64(cacheLimit),
+ EventMux: eth.eventMux,
+ Checkpoint: checkpoint,
+ Whitelist: config.Whitelist,
+ }); err != nil {
return nil, err
}
eth.miner = miner.New(eth, &config.Miner, chainConfig, eth.EventMux(), eth.engine, eth.isLocalBlock)
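
The switch from NewProtocolManager's long positional parameter list to a handlerConfig struct is a readability refactor: fields are named at the call site, so adding one later never silently shifts the meaning of existing arguments. A minimal sketch of the pattern, with hypothetical fields:

```go
package main

import "fmt"

// handlerConfig mirrors the pattern adopted above: bundle the handler's
// dependencies into one struct instead of a positional argument list.
type handlerConfig struct {
	Network   uint64
	Sync      string
	Whitelist map[uint64]string
}

func newHandler(cfg *handlerConfig) error {
	if cfg.Network == 0 {
		return fmt.Errorf("network ID must be set")
	}
	fmt.Printf("handler: network=%d sync=%s\n", cfg.Network, cfg.Sync)
	return nil
}

func main() {
	if err := newHandler(&handlerConfig{Network: 1, Sync: "snap"}); err != nil {
		panic(err)
	}
}
```
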
@@ -209,13 +221,16 @@ func New(stack *node.Node, config *Config) (*Ethereum, error) {
}
eth.APIBackend.gpo = gasprice.NewOracle(eth.APIBackend, gpoParams)
- eth.dialCandidates, err = eth.setupDiscovery()
+ eth.ethDialCandidates, err = setupDiscovery(eth.config.EthDiscoveryURLs)
+ if err != nil {
+ return nil, err
+ }
+ eth.snapDialCandidates, err = setupDiscovery(eth.config.SnapDiscoveryURLs)
if err != nil {
return nil, err
}
-
// Start the RPC service
- eth.netRPCService = ethapi.NewPublicNetAPI(eth.p2pServer, eth.NetVersion())
+ eth.netRPCService = ethapi.NewPublicNetAPI(eth.p2pServer)
// Register the backend on the node
stack.RegisterAPIs(eth.APIs())
@@ -310,7 +325,7 @@ func (s *Ethereum) APIs() []rpc.API {
}, {
Namespace: "eth",
Version: "1.0",
- Service: downloader.NewPublicDownloaderAPI(s.protocolManager.downloader, s.eventMux),
+ Service: downloader.NewPublicDownloaderAPI(s.handler.downloader, s.eventMux),
Public: true,
}, {
Namespace: "miner",
@@ -473,7 +488,7 @@ func (s *Ethereum) StartMining(threads int) error {
}
// If mining is started, we can disable the transaction rejection mechanism
// introduced to speed sync times.
- atomic.StoreUint32(&s.protocolManager.acceptTxs, 1)
+ atomic.StoreUint32(&s.handler.acceptTxs, 1)
go s.miner.Start(eb)
}
@@ -504,21 +519,17 @@ func (s *Ethereum) EventMux() *event.TypeMux { return s.eventMux }
func (s *Ethereum) Engine() consensus.Engine { return s.engine }
func (s *Ethereum) ChainDb() ethdb.Database { return s.chainDb }
func (s *Ethereum) IsListening() bool { return true } // Always listening
-func (s *Ethereum) EthVersion() int { return int(ProtocolVersions[0]) }
-func (s *Ethereum) NetVersion() uint64 { return s.networkID }
-func (s *Ethereum) Downloader() *downloader.Downloader { return s.protocolManager.downloader }
-func (s *Ethereum) Synced() bool { return atomic.LoadUint32(&s.protocolManager.acceptTxs) == 1 }
+func (s *Ethereum) Downloader() *downloader.Downloader { return s.handler.downloader }
+func (s *Ethereum) Synced() bool { return atomic.LoadUint32(&s.handler.acceptTxs) == 1 }
func (s *Ethereum) ArchiveMode() bool { return s.config.NoPruning }
func (s *Ethereum) BloomIndexer() *core.ChainIndexer { return s.bloomIndexer }
// Protocols returns all the currently configured
// network protocols to start.
func (s *Ethereum) Protocols() []p2p.Protocol {
- protos := make([]p2p.Protocol, len(ProtocolVersions))
- for i, vsn := range ProtocolVersions {
- protos[i] = s.protocolManager.makeProtocol(vsn)
- protos[i].Attributes = []enr.Entry{s.currentEthEntry()}
- protos[i].DialCandidates = s.dialCandidates
+ protos := eth.MakeProtocols((*ethHandler)(s.handler), s.networkID, s.ethDialCandidates)
+ if s.config.SnapshotCache > 0 {
+ protos = append(protos, snap.MakeProtocols((*snapHandler)(s.handler), s.snapDialCandidates)...)
}
return protos
}
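
The conditional advertisement above boils down to: always speak `eth`, and additionally speak `snap` when the snapshot tree is enabled (SnapshotCache > 0). A toy sketch — the versions mirror the ones used in this diff, but the types are stand-ins for p2p.Protocol:

```go
package main

import "fmt"

type protocol struct {
	name    string
	version uint
}

// advertised mirrors Ethereum.Protocols: `eth` unconditionally, `snap` only
// when the snapshot tree is running.
func advertised(snapshotCacheMB int) []protocol {
	protos := []protocol{{"eth", 64}, {"eth", 65}}
	if snapshotCacheMB > 0 {
		protos = append(protos, protocol{"snap", 1})
	}
	return protos
}

func main() {
	fmt.Println(advertised(0))   // [{eth 64} {eth 65}]
	fmt.Println(advertised(102)) // [{eth 64} {eth 65} {snap 1}]
}
```
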
@@ -526,7 +537,7 @@ func (s *Ethereum) Protocols() []p2p.Protocol {
// Start implements node.Lifecycle, starting all internal goroutines needed by the
// Ethereum protocol implementation.
func (s *Ethereum) Start() error {
- s.startEthEntryUpdate(s.p2pServer.LocalNode())
+ eth.StartENRUpdater(s.blockchain, s.p2pServer.LocalNode())
// Start the bloom bits servicing goroutines
s.startBloomHandlers(params.BloomBitsBlocks)
@@ -540,7 +551,7 @@ func (s *Ethereum) Start() error {
maxPeers -= s.config.LightPeers
}
// Start the networking layer and the light server if requested
- s.protocolManager.Start(maxPeers)
+ s.handler.Start(maxPeers)
return nil
}
@@ -548,7 +559,7 @@ func (s *Ethereum) Start() error {
// Ethereum protocol.
func (s *Ethereum) Stop() error {
// Stop all the peer-related stuff first.
- s.protocolManager.Stop()
+ s.handler.Stop()
// Then stop everything else.
s.bloomIndexer.Close()
@@ -560,5 +571,6 @@ func (s *Ethereum) Stop() error {
rawdb.PopUncleanShutdownMarker(s.chainDb)
s.chainDb.Close()
s.eventMux.Stop()
+
return nil
}
diff --git a/eth/config.go b/eth/config.go
index 0d90376d9..77d03e956 100644
--- a/eth/config.go
+++ b/eth/config.go
@@ -115,7 +115,8 @@ type Config struct {
// This can be set to a list of enrtree:// URLs which will be queried for
// nodes to connect to.
- DiscoveryURLs []string
+ EthDiscoveryURLs []string
+ SnapDiscoveryURLs []string
NoPruning bool // Whether to disable pruning and flush everything to disk
NoPrefetch bool // Whether to disable prefetching and only load state on demand
diff --git a/eth/discovery.go b/eth/discovery.go
index e7a281d35..855ce3b0e 100644
--- a/eth/discovery.go
+++ b/eth/discovery.go
@@ -63,11 +63,12 @@ func (eth *Ethereum) currentEthEntry() *ethEntry {
eth.blockchain.CurrentHeader().Number.Uint64())}
}
-// setupDiscovery creates the node discovery source for the eth protocol.
-func (eth *Ethereum) setupDiscovery() (enode.Iterator, error) {
- if len(eth.config.DiscoveryURLs) == 0 {
+// setupDiscovery creates the node discovery source for the `eth` and `snap`
+// protocols.
+func setupDiscovery(urls []string) (enode.Iterator, error) {
+ if len(urls) == 0 {
return nil, nil
}
client := dnsdisc.NewClient(dnsdisc.Config{})
- return client.NewIterator(eth.config.DiscoveryURLs...)
+ return client.NewIterator(urls...)
}
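
A usage sketch for the now chain-agnostic helper: the same function serves both protocols, each with its own URL list, and an empty list yields a nil iterator rather than an error. The enrtree URL is a placeholder; a malformed one simply makes NewIterator return an error, printed here:

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/p2p/dnsdisc"
	"github.com/ethereum/go-ethereum/p2p/enode"
)

func setupDiscovery(urls []string) (enode.Iterator, error) {
	if len(urls) == 0 {
		return nil, nil
	}
	client := dnsdisc.NewClient(dnsdisc.Config{})
	return client.NewIterator(urls...)
}

func main() {
	ethIt, err := setupDiscovery([]string{"enrtree://AKA.../eth.example.org"})
	snapIt, _ := setupDiscovery(nil) // no snap URLs -> nil iterator, not an error
	fmt.Println(ethIt, snapIt, err)
}
```
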
diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go
index 686c1ace1..312359843 100644
--- a/eth/downloader/downloader.go
+++ b/eth/downloader/downloader.go
@@ -29,6 +29,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/eth/protocols/snap"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/log"
@@ -38,7 +39,6 @@ import (
)
var (
- MaxHashFetch = 512 // Amount of hashes to be fetched per retrieval request
MaxBlockFetch = 128 // Amount of blocks to be fetched per retrieval request
MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request
MaxSkeletonSize = 128 // Number of header fetches to need for a skeleton assembly
@@ -89,7 +89,7 @@ var (
errCancelContentProcessing = errors.New("content processing canceled (requested)")
errCanceled = errors.New("syncing canceled (requested)")
errNoSyncActive = errors.New("no sync active")
- errTooOld = errors.New("peer doesn't speak recent enough protocol version (need version >= 63)")
+ errTooOld = errors.New("peer doesn't speak recent enough protocol version (need version >= 64)")
)
type Downloader struct {
@@ -131,20 +131,22 @@ type Downloader struct {
ancientLimit uint64 // The maximum block number which can be regarded as ancient data.
// Channels
- headerCh chan dataPack // [eth/62] Channel receiving inbound block headers
- bodyCh chan dataPack // [eth/62] Channel receiving inbound block bodies
- receiptCh chan dataPack // [eth/63] Channel receiving inbound receipts
- bodyWakeCh chan bool // [eth/62] Channel to signal the block body fetcher of new tasks
- receiptWakeCh chan bool // [eth/63] Channel to signal the receipt fetcher of new tasks
- headerProcCh chan []*types.Header // [eth/62] Channel to feed the header processor new tasks
+ headerCh chan dataPack // Channel receiving inbound block headers
+ bodyCh chan dataPack // Channel receiving inbound block bodies
+ receiptCh chan dataPack // Channel receiving inbound receipts
+ bodyWakeCh chan bool // Channel to signal the block body fetcher of new tasks
+ receiptWakeCh chan bool // Channel to signal the receipt fetcher of new tasks
+ headerProcCh chan []*types.Header // Channel to feed the header processor new tasks
// State sync
pivotHeader *types.Header // Pivot block header to dynamically push the syncing state root
pivotLock sync.RWMutex // Lock protecting pivot header reads from updates
+ snapSync bool // Whether to run state sync over the snap protocol
+ SnapSyncer *snap.Syncer // TODO(karalabe): make private! hack for now
stateSyncStart chan *stateSync
trackStateReq chan *stateReq
- stateCh chan dataPack // [eth/63] Channel receiving inbound node state data
+ stateCh chan dataPack // Channel receiving inbound node state data
// Cancellation and termination
cancelPeer string // Identifier of the peer currently being used as the master (cancel on drop)
@@ -237,6 +239,7 @@ func New(checkpoint uint64, stateDb ethdb.Database, stateBloom *trie.SyncBloom,
headerProcCh: make(chan []*types.Header, 1),
quitCh: make(chan struct{}),
stateCh: make(chan dataPack),
+ SnapSyncer: snap.NewSyncer(stateDb, stateBloom),
stateSyncStart: make(chan *stateSync),
syncStatsState: stateSyncStats{
processed: rawdb.ReadFastTrieProgress(stateDb),
@@ -286,19 +289,16 @@ func (d *Downloader) Synchronising() bool {
return atomic.LoadInt32(&d.synchronising) > 0
}
-// SyncBloomContains tests if the syncbloom filter contains the given hash:
-// - false: the bloom definitely does not contain hash
-// - true: the bloom maybe contains hash
-//
-// While the bloom is being initialized (or is closed), all queries will return true.
-func (d *Downloader) SyncBloomContains(hash []byte) bool {
- return d.stateBloom == nil || d.stateBloom.Contains(hash)
-}
-
// RegisterPeer injects a new download peer into the set of block source to be
// used for fetching hashes and blocks from.
-func (d *Downloader) RegisterPeer(id string, version int, peer Peer) error {
- logger := log.New("peer", id)
+func (d *Downloader) RegisterPeer(id string, version uint, peer Peer) error {
+ var logger log.Logger
+ if len(id) < 16 {
+ // Tests use short IDs, don't choke on them
+ logger = log.New("peer", id)
+ } else {
+ logger = log.New("peer", id[:16])
+ }
logger.Trace("Registering sync peer")
if err := d.peers.Register(newPeerConnection(id, version, peer, logger)); err != nil {
logger.Error("Failed to register sync peer", "err", err)
@@ -310,7 +310,7 @@ func (d *Downloader) RegisterPeer(id string, version int, peer Peer) error {
}
// RegisterLightPeer injects a light client peer, wrapping it so it appears as a regular peer.
-func (d *Downloader) RegisterLightPeer(id string, version int, peer LightPeer) error {
+func (d *Downloader) RegisterLightPeer(id string, version uint, peer LightPeer) error {
return d.RegisterPeer(id, version, &lightPeerWrapper{peer})
}
@@ -319,7 +319,13 @@ func (d *Downloader) RegisterLightPeer(id string, version int, peer LightPeer) e
// the queue.
func (d *Downloader) UnregisterPeer(id string) error {
// Unregister the peer from the active peer set and revoke any fetch tasks
- logger := log.New("peer", id)
+ var logger log.Logger
+ if len(id) < 16 {
+ // Tests use short IDs, don't choke on them
+ logger = log.New("peer", id)
+ } else {
+ logger = log.New("peer", id[:16])
+ }
logger.Trace("Unregistering sync peer")
if err := d.peers.Unregister(id); err != nil {
logger.Error("Failed to unregister sync peer", "err", err)
@@ -381,6 +387,16 @@ func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int, mode
if mode == FullSync && d.stateBloom != nil {
d.stateBloom.Close()
}
+ // If snap sync was requested, create the snap scheduler and switch to fast
+ // sync mode. Long term we could drop fast sync or merge the two together,
+ // but until snap becomes prevalent, we should support both. TODO(karalabe).
+ if mode == SnapSync {
+ if !d.snapSync {
+ log.Warn("Enabling snapshot sync prototype")
+ d.snapSync = true
+ }
+ mode = FastSync
+ }
// Reset the queue, peer set and wake channels to clean any internal leftover state
d.queue.Reset(blockCacheMaxItems, blockCacheInitialItems)
d.peers.Reset()
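
The mode normalization above, distilled: snap sync currently rides on the fast-sync machinery, so the public mode collapses to FastSync after flipping an internal flag. A self-contained sketch of that shape:

```go
package main

import "fmt"

type syncMode int

const (
	fullSync syncMode = iota
	fastSync
	snapSync
)

// normalize returns the mode the downloader actually runs, plus whether the
// state should be fetched over the snap protocol instead of eth node data.
func normalize(mode syncMode) (syncMode, bool) {
	if mode == snapSync {
		return fastSync, true
	}
	return mode, false
}

func main() {
	mode, snap := normalize(snapSync)
	fmt.Println(mode == fastSync, snap) // true true
}
```
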
@@ -443,8 +459,8 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.I
d.mux.Post(DoneEvent{latest})
}
}()
- if p.version < 63 {
- return errTooOld
+ if p.version < 64 {
+ return fmt.Errorf("%w, peer version: %d", errTooOld, p.version)
}
mode := d.getMode()
@@ -1910,27 +1926,53 @@ func (d *Downloader) commitPivotBlock(result *fetchResult) error {
// DeliverHeaders injects a new batch of block headers received from a remote
// node into the download schedule.
-func (d *Downloader) DeliverHeaders(id string, headers []*types.Header) (err error) {
- return d.deliver(id, d.headerCh, &headerPack{id, headers}, headerInMeter, headerDropMeter)
+func (d *Downloader) DeliverHeaders(id string, headers []*types.Header) error {
+ return d.deliver(d.headerCh, &headerPack{id, headers}, headerInMeter, headerDropMeter)
}
// DeliverBodies injects a new batch of block bodies received from a remote node.
-func (d *Downloader) DeliverBodies(id string, transactions [][]*types.Transaction, uncles [][]*types.Header) (err error) {
- return d.deliver(id, d.bodyCh, &bodyPack{id, transactions, uncles}, bodyInMeter, bodyDropMeter)
+func (d *Downloader) DeliverBodies(id string, transactions [][]*types.Transaction, uncles [][]*types.Header) error {
+ return d.deliver(d.bodyCh, &bodyPack{id, transactions, uncles}, bodyInMeter, bodyDropMeter)
}
// DeliverReceipts injects a new batch of receipts received from a remote node.
-func (d *Downloader) DeliverReceipts(id string, receipts [][]*types.Receipt) (err error) {
- return d.deliver(id, d.receiptCh, &receiptPack{id, receipts}, receiptInMeter, receiptDropMeter)
+func (d *Downloader) DeliverReceipts(id string, receipts [][]*types.Receipt) error {
+ return d.deliver(d.receiptCh, &receiptPack{id, receipts}, receiptInMeter, receiptDropMeter)
}
// DeliverNodeData injects a new batch of node state data received from a remote node.
-func (d *Downloader) DeliverNodeData(id string, data [][]byte) (err error) {
- return d.deliver(id, d.stateCh, &statePack{id, data}, stateInMeter, stateDropMeter)
+func (d *Downloader) DeliverNodeData(id string, data [][]byte) error {
+ return d.deliver(d.stateCh, &statePack{id, data}, stateInMeter, stateDropMeter)
+}
+
+// DeliverSnapPacket is invoked from a peer's message handler when it transmits a
+// data packet for the local node to consume.
+func (d *Downloader) DeliverSnapPacket(peer *snap.Peer, packet snap.Packet) error {
+ switch packet := packet.(type) {
+ case *snap.AccountRangePacket:
+ hashes, accounts, err := packet.Unpack()
+ if err != nil {
+ return err
+ }
+ return d.SnapSyncer.OnAccounts(peer, packet.ID, hashes, accounts, packet.Proof)
+
+ case *snap.StorageRangesPacket:
+ hashset, slotset := packet.Unpack()
+ return d.SnapSyncer.OnStorage(peer, packet.ID, hashset, slotset, packet.Proof)
+
+ case *snap.ByteCodesPacket:
+ return d.SnapSyncer.OnByteCodes(peer, packet.ID, packet.Codes)
+
+ case *snap.TrieNodesPacket:
+ return d.SnapSyncer.OnTrieNodes(peer, packet.ID, packet.Nodes)
+
+ default:
+ return fmt.Errorf("unexpected snap packet type: %T", packet)
+ }
}
// deliver injects a new batch of data received from a remote node.
-func (d *Downloader) deliver(id string, destCh chan dataPack, packet dataPack, inMeter, dropMeter metrics.Meter) (err error) {
+func (d *Downloader) deliver(destCh chan dataPack, packet dataPack, inMeter, dropMeter metrics.Meter) (err error) {
// Update the delivery metrics for both good and failed deliveries
inMeter.Mark(int64(packet.Items()))
defer func() {
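
The dispatch in DeliverSnapPacket is a plain type switch routing each inbound packet to its syncer callback, with unknown types rejected rather than dropped. A stripped-down sketch, with hypothetical packet types standing in for the snap package's:

```go
package main

import "fmt"

type packet interface{ Name() string }

type accountRangePacket struct{ ID uint64 }
type byteCodesPacket struct{ ID uint64 }

func (*accountRangePacket) Name() string { return "AccountRange" }
func (*byteCodesPacket) Name() string    { return "ByteCodes" }

func deliver(p packet) error {
	switch p := p.(type) {
	case *accountRangePacket:
		fmt.Println("-> OnAccounts, request", p.ID)
	case *byteCodesPacket:
		fmt.Println("-> OnByteCodes, request", p.ID)
	default:
		// Unknown packets are an error, never silently dropped.
		return fmt.Errorf("unexpected snap packet type: %T", p)
	}
	return nil
}

func main() {
	_ = deliver(&accountRangePacket{ID: 1})
	_ = deliver(&byteCodesPacket{ID: 2})
	fmt.Println(deliver(nil))
}
```
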
diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go
index 5e46042ae..6578275d0 100644
--- a/eth/downloader/downloader_test.go
+++ b/eth/downloader/downloader_test.go
@@ -390,7 +390,7 @@ func (dl *downloadTester) Rollback(hashes []common.Hash) {
}
// newPeer registers a new block download source into the downloader.
-func (dl *downloadTester) newPeer(id string, version int, chain *testChain) error {
+func (dl *downloadTester) newPeer(id string, version uint, chain *testChain) error {
dl.lock.Lock()
defer dl.lock.Unlock()
@@ -518,8 +518,6 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng
// Tests that simple synchronization against a canonical chain works correctly.
// In this test common ancestor lookup should be short circuited and not require
// binary searching.
-func TestCanonicalSynchronisation63Full(t *testing.T) { testCanonicalSynchronisation(t, 63, FullSync) }
-func TestCanonicalSynchronisation63Fast(t *testing.T) { testCanonicalSynchronisation(t, 63, FastSync) }
func TestCanonicalSynchronisation64Full(t *testing.T) { testCanonicalSynchronisation(t, 64, FullSync) }
func TestCanonicalSynchronisation64Fast(t *testing.T) { testCanonicalSynchronisation(t, 64, FastSync) }
func TestCanonicalSynchronisation65Full(t *testing.T) { testCanonicalSynchronisation(t, 65, FullSync) }
@@ -528,7 +526,7 @@ func TestCanonicalSynchronisation65Light(t *testing.T) {
testCanonicalSynchronisation(t, 65, LightSync)
}
-func testCanonicalSynchronisation(t *testing.T, protocol int, mode SyncMode) {
+func testCanonicalSynchronisation(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
@@ -547,14 +545,12 @@ func testCanonicalSynchronisation(t *testing.T, protocol int, mode SyncMode) {
// Tests that if a large batch of blocks are being downloaded, it is throttled
// until the cached blocks are retrieved.
-func TestThrottling63Full(t *testing.T) { testThrottling(t, 63, FullSync) }
-func TestThrottling63Fast(t *testing.T) { testThrottling(t, 63, FastSync) }
func TestThrottling64Full(t *testing.T) { testThrottling(t, 64, FullSync) }
func TestThrottling64Fast(t *testing.T) { testThrottling(t, 64, FastSync) }
func TestThrottling65Full(t *testing.T) { testThrottling(t, 65, FullSync) }
func TestThrottling65Fast(t *testing.T) { testThrottling(t, 65, FastSync) }
-func testThrottling(t *testing.T, protocol int, mode SyncMode) {
+func testThrottling(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
@@ -632,15 +628,13 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
// Tests that simple synchronization against a forked chain works correctly. In
// this test common ancestor lookup should *not* be short circuited, and a full
// binary search should be executed.
-func TestForkedSync63Full(t *testing.T) { testForkedSync(t, 63, FullSync) }
-func TestForkedSync63Fast(t *testing.T) { testForkedSync(t, 63, FastSync) }
func TestForkedSync64Full(t *testing.T) { testForkedSync(t, 64, FullSync) }
func TestForkedSync64Fast(t *testing.T) { testForkedSync(t, 64, FastSync) }
func TestForkedSync65Full(t *testing.T) { testForkedSync(t, 65, FullSync) }
func TestForkedSync65Fast(t *testing.T) { testForkedSync(t, 65, FastSync) }
func TestForkedSync65Light(t *testing.T) { testForkedSync(t, 65, LightSync) }
-func testForkedSync(t *testing.T, protocol int, mode SyncMode) {
+func testForkedSync(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
@@ -665,15 +659,13 @@ func testForkedSync(t *testing.T, protocol int, mode SyncMode) {
// Tests that synchronising against a much shorter but much heavier fork works
// correctly and is not dropped.
-func TestHeavyForkedSync63Full(t *testing.T) { testHeavyForkedSync(t, 63, FullSync) }
-func TestHeavyForkedSync63Fast(t *testing.T) { testHeavyForkedSync(t, 63, FastSync) }
func TestHeavyForkedSync64Full(t *testing.T) { testHeavyForkedSync(t, 64, FullSync) }
func TestHeavyForkedSync64Fast(t *testing.T) { testHeavyForkedSync(t, 64, FastSync) }
func TestHeavyForkedSync65Full(t *testing.T) { testHeavyForkedSync(t, 65, FullSync) }
func TestHeavyForkedSync65Fast(t *testing.T) { testHeavyForkedSync(t, 65, FastSync) }
func TestHeavyForkedSync65Light(t *testing.T) { testHeavyForkedSync(t, 65, LightSync) }
-func testHeavyForkedSync(t *testing.T, protocol int, mode SyncMode) {
+func testHeavyForkedSync(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
@@ -700,15 +692,13 @@ func testHeavyForkedSync(t *testing.T, protocol int, mode SyncMode) {
// Tests that chain forks are contained within a certain interval of the current
// chain head, ensuring that malicious peers cannot waste resources by feeding
// long dead chains.
-func TestBoundedForkedSync63Full(t *testing.T) { testBoundedForkedSync(t, 63, FullSync) }
-func TestBoundedForkedSync63Fast(t *testing.T) { testBoundedForkedSync(t, 63, FastSync) }
func TestBoundedForkedSync64Full(t *testing.T) { testBoundedForkedSync(t, 64, FullSync) }
func TestBoundedForkedSync64Fast(t *testing.T) { testBoundedForkedSync(t, 64, FastSync) }
func TestBoundedForkedSync65Full(t *testing.T) { testBoundedForkedSync(t, 65, FullSync) }
func TestBoundedForkedSync65Fast(t *testing.T) { testBoundedForkedSync(t, 65, FastSync) }
func TestBoundedForkedSync65Light(t *testing.T) { testBoundedForkedSync(t, 65, LightSync) }
-func testBoundedForkedSync(t *testing.T, protocol int, mode SyncMode) {
+func testBoundedForkedSync(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
@@ -734,15 +724,13 @@ func testBoundedForkedSync(t *testing.T, protocol int, mode SyncMode) {
// Tests that chain forks are contained within a certain interval of the current
// chain head for short but heavy forks too. These are a bit special because they
// take different ancestor lookup paths.
-func TestBoundedHeavyForkedSync63Full(t *testing.T) { testBoundedHeavyForkedSync(t, 63, FullSync) }
-func TestBoundedHeavyForkedSync63Fast(t *testing.T) { testBoundedHeavyForkedSync(t, 63, FastSync) }
func TestBoundedHeavyForkedSync64Full(t *testing.T) { testBoundedHeavyForkedSync(t, 64, FullSync) }
func TestBoundedHeavyForkedSync64Fast(t *testing.T) { testBoundedHeavyForkedSync(t, 64, FastSync) }
func TestBoundedHeavyForkedSync65Full(t *testing.T) { testBoundedHeavyForkedSync(t, 65, FullSync) }
func TestBoundedHeavyForkedSync65Fast(t *testing.T) { testBoundedHeavyForkedSync(t, 65, FastSync) }
func TestBoundedHeavyForkedSync65Light(t *testing.T) { testBoundedHeavyForkedSync(t, 65, LightSync) }
-func testBoundedHeavyForkedSync(t *testing.T, protocol int, mode SyncMode) {
+func testBoundedHeavyForkedSync(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
@@ -786,15 +774,13 @@ func TestInactiveDownloader63(t *testing.T) {
}
// Tests that a canceled download wipes all previously accumulated state.
-func TestCancel63Full(t *testing.T) { testCancel(t, 63, FullSync) }
-func TestCancel63Fast(t *testing.T) { testCancel(t, 63, FastSync) }
func TestCancel64Full(t *testing.T) { testCancel(t, 64, FullSync) }
func TestCancel64Fast(t *testing.T) { testCancel(t, 64, FastSync) }
func TestCancel65Full(t *testing.T) { testCancel(t, 65, FullSync) }
func TestCancel65Fast(t *testing.T) { testCancel(t, 65, FastSync) }
func TestCancel65Light(t *testing.T) { testCancel(t, 65, LightSync) }
-func testCancel(t *testing.T, protocol int, mode SyncMode) {
+func testCancel(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
@@ -819,15 +805,13 @@ func testCancel(t *testing.T, protocol int, mode SyncMode) {
}
// Tests that synchronisation from multiple peers works as intended (multi thread sanity test).
-func TestMultiSynchronisation63Full(t *testing.T) { testMultiSynchronisation(t, 63, FullSync) }
-func TestMultiSynchronisation63Fast(t *testing.T) { testMultiSynchronisation(t, 63, FastSync) }
func TestMultiSynchronisation64Full(t *testing.T) { testMultiSynchronisation(t, 64, FullSync) }
func TestMultiSynchronisation64Fast(t *testing.T) { testMultiSynchronisation(t, 64, FastSync) }
func TestMultiSynchronisation65Full(t *testing.T) { testMultiSynchronisation(t, 65, FullSync) }
func TestMultiSynchronisation65Fast(t *testing.T) { testMultiSynchronisation(t, 65, FastSync) }
func TestMultiSynchronisation65Light(t *testing.T) { testMultiSynchronisation(t, 65, LightSync) }
-func testMultiSynchronisation(t *testing.T, protocol int, mode SyncMode) {
+func testMultiSynchronisation(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
@@ -849,15 +833,13 @@ func testMultiSynchronisation(t *testing.T, protocol int, mode SyncMode) {
// Tests that synchronisations behave well in multi-version protocol environments
// and not wreak havoc on other nodes in the network.
-func TestMultiProtoSynchronisation63Full(t *testing.T) { testMultiProtoSync(t, 63, FullSync) }
-func TestMultiProtoSynchronisation63Fast(t *testing.T) { testMultiProtoSync(t, 63, FastSync) }
func TestMultiProtoSynchronisation64Full(t *testing.T) { testMultiProtoSync(t, 64, FullSync) }
func TestMultiProtoSynchronisation64Fast(t *testing.T) { testMultiProtoSync(t, 64, FastSync) }
func TestMultiProtoSynchronisation65Full(t *testing.T) { testMultiProtoSync(t, 65, FullSync) }
func TestMultiProtoSynchronisation65Fast(t *testing.T) { testMultiProtoSync(t, 65, FastSync) }
func TestMultiProtoSynchronisation65Light(t *testing.T) { testMultiProtoSync(t, 65, LightSync) }
-func testMultiProtoSync(t *testing.T, protocol int, mode SyncMode) {
+func testMultiProtoSync(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
@@ -888,15 +870,13 @@ func testMultiProtoSync(t *testing.T, protocol int, mode SyncMode) {
// Tests that if a block is empty (e.g. header only), no body request should be
// made, and instead the header should be assembled into a whole block in itself.
-func TestEmptyShortCircuit63Full(t *testing.T) { testEmptyShortCircuit(t, 63, FullSync) }
-func TestEmptyShortCircuit63Fast(t *testing.T) { testEmptyShortCircuit(t, 63, FastSync) }
func TestEmptyShortCircuit64Full(t *testing.T) { testEmptyShortCircuit(t, 64, FullSync) }
func TestEmptyShortCircuit64Fast(t *testing.T) { testEmptyShortCircuit(t, 64, FastSync) }
func TestEmptyShortCircuit65Full(t *testing.T) { testEmptyShortCircuit(t, 65, FullSync) }
func TestEmptyShortCircuit65Fast(t *testing.T) { testEmptyShortCircuit(t, 65, FastSync) }
func TestEmptyShortCircuit65Light(t *testing.T) { testEmptyShortCircuit(t, 65, LightSync) }
-func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) {
+func testEmptyShortCircuit(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
@@ -942,15 +922,13 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) {
// Tests that headers are enqueued continuously, preventing malicious nodes from
// stalling the downloader by feeding gapped header chains.
-func TestMissingHeaderAttack63Full(t *testing.T) { testMissingHeaderAttack(t, 63, FullSync) }
-func TestMissingHeaderAttack63Fast(t *testing.T) { testMissingHeaderAttack(t, 63, FastSync) }
func TestMissingHeaderAttack64Full(t *testing.T) { testMissingHeaderAttack(t, 64, FullSync) }
func TestMissingHeaderAttack64Fast(t *testing.T) { testMissingHeaderAttack(t, 64, FastSync) }
func TestMissingHeaderAttack65Full(t *testing.T) { testMissingHeaderAttack(t, 65, FullSync) }
func TestMissingHeaderAttack65Fast(t *testing.T) { testMissingHeaderAttack(t, 65, FastSync) }
func TestMissingHeaderAttack65Light(t *testing.T) { testMissingHeaderAttack(t, 65, LightSync) }
-func testMissingHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
+func testMissingHeaderAttack(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
@@ -974,15 +952,13 @@ func testMissingHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
// Tests that if requested headers are shifted (i.e. first is missing), the queue
// detects the invalid numbering.
-func TestShiftedHeaderAttack63Full(t *testing.T) { testShiftedHeaderAttack(t, 63, FullSync) }
-func TestShiftedHeaderAttack63Fast(t *testing.T) { testShiftedHeaderAttack(t, 63, FastSync) }
func TestShiftedHeaderAttack64Full(t *testing.T) { testShiftedHeaderAttack(t, 64, FullSync) }
func TestShiftedHeaderAttack64Fast(t *testing.T) { testShiftedHeaderAttack(t, 64, FastSync) }
func TestShiftedHeaderAttack65Full(t *testing.T) { testShiftedHeaderAttack(t, 65, FullSync) }
func TestShiftedHeaderAttack65Fast(t *testing.T) { testShiftedHeaderAttack(t, 65, FastSync) }
func TestShiftedHeaderAttack65Light(t *testing.T) { testShiftedHeaderAttack(t, 65, LightSync) }
-func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
+func testShiftedHeaderAttack(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
@@ -1011,11 +987,10 @@ func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
// Tests that upon detecting an invalid header, the recent ones are rolled back
// for various failure scenarios. Afterwards a full sync is attempted to make
// sure no state was corrupted.
-func TestInvalidHeaderRollback63Fast(t *testing.T) { testInvalidHeaderRollback(t, 63, FastSync) }
func TestInvalidHeaderRollback64Fast(t *testing.T) { testInvalidHeaderRollback(t, 64, FastSync) }
func TestInvalidHeaderRollback65Fast(t *testing.T) { testInvalidHeaderRollback(t, 65, FastSync) }
-func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
+func testInvalidHeaderRollback(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
@@ -1103,15 +1078,13 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
// Tests that a peer advertising a high TD doesn't get to stall the downloader
// afterwards by not sending any useful hashes.
-func TestHighTDStarvationAttack63Full(t *testing.T) { testHighTDStarvationAttack(t, 63, FullSync) }
-func TestHighTDStarvationAttack63Fast(t *testing.T) { testHighTDStarvationAttack(t, 63, FastSync) }
func TestHighTDStarvationAttack64Full(t *testing.T) { testHighTDStarvationAttack(t, 64, FullSync) }
func TestHighTDStarvationAttack64Fast(t *testing.T) { testHighTDStarvationAttack(t, 64, FastSync) }
func TestHighTDStarvationAttack65Full(t *testing.T) { testHighTDStarvationAttack(t, 65, FullSync) }
func TestHighTDStarvationAttack65Fast(t *testing.T) { testHighTDStarvationAttack(t, 65, FastSync) }
func TestHighTDStarvationAttack65Light(t *testing.T) { testHighTDStarvationAttack(t, 65, LightSync) }
-func testHighTDStarvationAttack(t *testing.T, protocol int, mode SyncMode) {
+func testHighTDStarvationAttack(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
@@ -1125,11 +1098,10 @@ func testHighTDStarvationAttack(t *testing.T, protocol int, mode SyncMode) {
}
// Tests that misbehaving peers are disconnected, whilst behaving ones are not.
-func TestBlockHeaderAttackerDropping63(t *testing.T) { testBlockHeaderAttackerDropping(t, 63) }
func TestBlockHeaderAttackerDropping64(t *testing.T) { testBlockHeaderAttackerDropping(t, 64) }
func TestBlockHeaderAttackerDropping65(t *testing.T) { testBlockHeaderAttackerDropping(t, 65) }
-func testBlockHeaderAttackerDropping(t *testing.T, protocol int) {
+func testBlockHeaderAttackerDropping(t *testing.T, protocol uint) {
t.Parallel()
// Define the disconnection requirement for individual hash fetch errors
@@ -1179,15 +1151,13 @@ func testBlockHeaderAttackerDropping(t *testing.T, protocol int) {
// Tests that synchronisation progress (origin block number, current block number
// and highest block number) is tracked and updated correctly.
-func TestSyncProgress63Full(t *testing.T) { testSyncProgress(t, 63, FullSync) }
-func TestSyncProgress63Fast(t *testing.T) { testSyncProgress(t, 63, FastSync) }
func TestSyncProgress64Full(t *testing.T) { testSyncProgress(t, 64, FullSync) }
func TestSyncProgress64Fast(t *testing.T) { testSyncProgress(t, 64, FastSync) }
func TestSyncProgress65Full(t *testing.T) { testSyncProgress(t, 65, FullSync) }
func TestSyncProgress65Fast(t *testing.T) { testSyncProgress(t, 65, FastSync) }
func TestSyncProgress65Light(t *testing.T) { testSyncProgress(t, 65, LightSync) }
-func testSyncProgress(t *testing.T, protocol int, mode SyncMode) {
+func testSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
@@ -1263,21 +1233,19 @@ func checkProgress(t *testing.T, d *Downloader, stage string, want ethereum.Sync
// Tests that synchronisation progress (origin block number and highest block
// number) is tracked and updated correctly in case of a fork (or manual head
// reversal).
-func TestForkedSyncProgress63Full(t *testing.T) { testForkedSyncProgress(t, 63, FullSync) }
-func TestForkedSyncProgress63Fast(t *testing.T) { testForkedSyncProgress(t, 63, FastSync) }
func TestForkedSyncProgress64Full(t *testing.T) { testForkedSyncProgress(t, 64, FullSync) }
func TestForkedSyncProgress64Fast(t *testing.T) { testForkedSyncProgress(t, 64, FastSync) }
func TestForkedSyncProgress65Full(t *testing.T) { testForkedSyncProgress(t, 65, FullSync) }
func TestForkedSyncProgress65Fast(t *testing.T) { testForkedSyncProgress(t, 65, FastSync) }
func TestForkedSyncProgress65Light(t *testing.T) { testForkedSyncProgress(t, 65, LightSync) }
-func testForkedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
+func testForkedSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
defer tester.terminate()
- chainA := testChainForkLightA.shorten(testChainBase.len() + MaxHashFetch)
- chainB := testChainForkLightB.shorten(testChainBase.len() + MaxHashFetch)
+ chainA := testChainForkLightA.shorten(testChainBase.len() + MaxHeaderFetch)
+ chainB := testChainForkLightB.shorten(testChainBase.len() + MaxHeaderFetch)
// Set a sync init hook to catch progress changes
starting := make(chan struct{})
@@ -1339,15 +1307,13 @@ func testForkedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
// Tests that if synchronisation is aborted due to some failure, then the progress
// origin is not updated in the next sync cycle, as it should be considered the
// continuation of the previous sync and not a new instance.
-func TestFailedSyncProgress63Full(t *testing.T) { testFailedSyncProgress(t, 63, FullSync) }
-func TestFailedSyncProgress63Fast(t *testing.T) { testFailedSyncProgress(t, 63, FastSync) }
func TestFailedSyncProgress64Full(t *testing.T) { testFailedSyncProgress(t, 64, FullSync) }
func TestFailedSyncProgress64Fast(t *testing.T) { testFailedSyncProgress(t, 64, FastSync) }
func TestFailedSyncProgress65Full(t *testing.T) { testFailedSyncProgress(t, 65, FullSync) }
func TestFailedSyncProgress65Fast(t *testing.T) { testFailedSyncProgress(t, 65, FastSync) }
func TestFailedSyncProgress65Light(t *testing.T) { testFailedSyncProgress(t, 65, LightSync) }
-func testFailedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
+func testFailedSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
@@ -1412,15 +1378,13 @@ func testFailedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
// Tests that if an attacker fakes a chain height, after the attack is detected,
// the progress height is successfully reduced at the next sync invocation.
-func TestFakedSyncProgress63Full(t *testing.T) { testFakedSyncProgress(t, 63, FullSync) }
-func TestFakedSyncProgress63Fast(t *testing.T) { testFakedSyncProgress(t, 63, FastSync) }
func TestFakedSyncProgress64Full(t *testing.T) { testFakedSyncProgress(t, 64, FullSync) }
func TestFakedSyncProgress64Fast(t *testing.T) { testFakedSyncProgress(t, 64, FastSync) }
func TestFakedSyncProgress65Full(t *testing.T) { testFakedSyncProgress(t, 65, FullSync) }
func TestFakedSyncProgress65Fast(t *testing.T) { testFakedSyncProgress(t, 65, FastSync) }
func TestFakedSyncProgress65Light(t *testing.T) { testFakedSyncProgress(t, 65, LightSync) }
-func testFakedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
+func testFakedSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
tester := newTester()
@@ -1489,31 +1453,15 @@ func testFakedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
// This test reproduces an issue where unexpected deliveries would
// block indefinitely if they arrived at the right time.
-func TestDeliverHeadersHang(t *testing.T) {
+func TestDeliverHeadersHang64Full(t *testing.T) { testDeliverHeadersHang(t, 64, FullSync) }
+func TestDeliverHeadersHang64Fast(t *testing.T) { testDeliverHeadersHang(t, 64, FastSync) }
+func TestDeliverHeadersHang65Full(t *testing.T) { testDeliverHeadersHang(t, 65, FullSync) }
+func TestDeliverHeadersHang65Fast(t *testing.T) { testDeliverHeadersHang(t, 65, FastSync) }
+func TestDeliverHeadersHang65Light(t *testing.T) { testDeliverHeadersHang(t, 65, LightSync) }
+
+func testDeliverHeadersHang(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
- testCases := []struct {
- protocol int
- syncMode SyncMode
- }{
- {63, FullSync},
- {63, FastSync},
- {64, FullSync},
- {64, FastSync},
- {64, LightSync},
- {65, FullSync},
- {65, FastSync},
- {65, LightSync},
- }
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("protocol %d mode %v", tc.protocol, tc.syncMode), func(t *testing.T) {
- t.Parallel()
- testDeliverHeadersHang(t, tc.protocol, tc.syncMode)
- })
- }
-}
-
-func testDeliverHeadersHang(t *testing.T, protocol int, mode SyncMode) {
master := newTester()
defer master.terminate()
chain := testChainBase.shorten(15)
@@ -1664,15 +1612,13 @@ func TestRemoteHeaderRequestSpan(t *testing.T) {
// Tests that peers below a pre-configured checkpoint block are prevented from
// being fast-synced from, avoiding potential cheap eclipse attacks.
-func TestCheckpointEnforcement63Full(t *testing.T) { testCheckpointEnforcement(t, 63, FullSync) }
-func TestCheckpointEnforcement63Fast(t *testing.T) { testCheckpointEnforcement(t, 63, FastSync) }
func TestCheckpointEnforcement64Full(t *testing.T) { testCheckpointEnforcement(t, 64, FullSync) }
func TestCheckpointEnforcement64Fast(t *testing.T) { testCheckpointEnforcement(t, 64, FastSync) }
func TestCheckpointEnforcement65Full(t *testing.T) { testCheckpointEnforcement(t, 65, FullSync) }
func TestCheckpointEnforcement65Fast(t *testing.T) { testCheckpointEnforcement(t, 65, FastSync) }
func TestCheckpointEnforcement65Light(t *testing.T) { testCheckpointEnforcement(t, 65, LightSync) }
-func testCheckpointEnforcement(t *testing.T, protocol int, mode SyncMode) {
+func testCheckpointEnforcement(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel()
// Create a new tester with a particular hard coded checkpoint block
diff --git a/eth/downloader/modes.go b/eth/downloader/modes.go
index d866ceabc..8ea7876a1 100644
--- a/eth/downloader/modes.go
+++ b/eth/downloader/modes.go
@@ -24,7 +24,8 @@ type SyncMode uint32
const (
FullSync SyncMode = iota // Synchronise the entire blockchain history from full blocks
- FastSync // Quickly download the headers, full sync only at the chain head
+ FastSync // Quickly download the headers, full sync only at the chain head
+ SnapSync // Download the chain and the state via compact snapshots
LightSync // Download only the headers and terminate afterwards
)
@@ -39,6 +40,8 @@ func (mode SyncMode) String() string {
return "full"
case FastSync:
return "fast"
+ case SnapSync:
+ return "snap"
case LightSync:
return "light"
default:
@@ -52,6 +55,8 @@ func (mode SyncMode) MarshalText() ([]byte, error) {
return []byte("full"), nil
case FastSync:
return []byte("fast"), nil
+ case SnapSync:
+ return []byte("snap"), nil
case LightSync:
return []byte("light"), nil
default:
@@ -65,6 +70,8 @@ func (mode *SyncMode) UnmarshalText(text []byte) error {
*mode = FullSync
case "fast":
*mode = FastSync
+ case "snap":
+ *mode = SnapSync
case "light":
*mode = LightSync
default:
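
A round trip through the new "snap" text value, using the real SyncMode marshalling API added above:

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/eth/downloader"
)

func main() {
	var mode downloader.SyncMode
	if err := mode.UnmarshalText([]byte("snap")); err != nil {
		panic(err)
	}
	text, _ := mode.MarshalText()
	fmt.Println(mode == downloader.SnapSync, string(text)) // true snap
}
```
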
diff --git a/eth/downloader/peer.go b/eth/downloader/peer.go
index c6671436f..ba90bf31c 100644
--- a/eth/downloader/peer.go
+++ b/eth/downloader/peer.go
@@ -69,7 +69,7 @@ type peerConnection struct {
peer Peer
- version int // Eth protocol version number to switch strategies
+ version uint // Eth protocol version number to switch strategies
log log.Logger // Contextual logger to add extra infos to peer logs
lock sync.RWMutex
}
@@ -112,7 +112,7 @@ func (w *lightPeerWrapper) RequestNodeData([]common.Hash) error {
}
// newPeerConnection creates a new downloader peer.
-func newPeerConnection(id string, version int, peer Peer, logger log.Logger) *peerConnection {
+func newPeerConnection(id string, version uint, peer Peer, logger log.Logger) *peerConnection {
return &peerConnection{
id: id,
lacking: make(map[common.Hash]struct{}),
@@ -457,7 +457,7 @@ func (ps *peerSet) HeaderIdlePeers() ([]*peerConnection, int) {
defer p.lock.RUnlock()
return p.headerThroughput
}
- return ps.idlePeers(63, 65, idle, throughput)
+ return ps.idlePeers(64, 65, idle, throughput)
}
// BodyIdlePeers retrieves a flat list of all the currently body-idle peers within
@@ -471,7 +471,7 @@ func (ps *peerSet) BodyIdlePeers() ([]*peerConnection, int) {
defer p.lock.RUnlock()
return p.blockThroughput
}
- return ps.idlePeers(63, 65, idle, throughput)
+ return ps.idlePeers(64, 65, idle, throughput)
}
// ReceiptIdlePeers retrieves a flat list of all the currently receipt-idle peers
@@ -485,7 +485,7 @@ func (ps *peerSet) ReceiptIdlePeers() ([]*peerConnection, int) {
defer p.lock.RUnlock()
return p.receiptThroughput
}
- return ps.idlePeers(63, 65, idle, throughput)
+ return ps.idlePeers(64, 65, idle, throughput)
}
// NodeDataIdlePeers retrieves a flat list of all the currently node-data-idle
@@ -499,13 +499,13 @@ func (ps *peerSet) NodeDataIdlePeers() ([]*peerConnection, int) {
defer p.lock.RUnlock()
return p.stateThroughput
}
- return ps.idlePeers(63, 65, idle, throughput)
+ return ps.idlePeers(64, 65, idle, throughput)
}
// idlePeers retrieves a flat list of all currently idle peers satisfying the
// protocol version constraints, using the provided function to check idleness.
// The resulting set of peers are sorted by their measured throughput.
-func (ps *peerSet) idlePeers(minProtocol, maxProtocol int, idleCheck func(*peerConnection) bool, throughput func(*peerConnection) float64) ([]*peerConnection, int) {
+func (ps *peerSet) idlePeers(minProtocol, maxProtocol uint, idleCheck func(*peerConnection) bool, throughput func(*peerConnection) float64) ([]*peerConnection, int) {
ps.lock.RLock()
defer ps.lock.RUnlock()
diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go
index d2ec8ba69..2150842f8 100644
--- a/eth/downloader/queue.go
+++ b/eth/downloader/queue.go
@@ -113,24 +113,24 @@ type queue struct {
mode SyncMode // Synchronisation mode to decide on the block parts to schedule for fetching
// Headers are "special", they download in batches, supported by a skeleton chain
- headerHead common.Hash // [eth/62] Hash of the last queued header to verify order
- headerTaskPool map[uint64]*types.Header // [eth/62] Pending header retrieval tasks, mapping starting indexes to skeleton headers
- headerTaskQueue *prque.Prque // [eth/62] Priority queue of the skeleton indexes to fetch the filling headers for
- headerPeerMiss map[string]map[uint64]struct{} // [eth/62] Set of per-peer header batches known to be unavailable
- headerPendPool map[string]*fetchRequest // [eth/62] Currently pending header retrieval operations
- headerResults []*types.Header // [eth/62] Result cache accumulating the completed headers
- headerProced int // [eth/62] Number of headers already processed from the results
- headerOffset uint64 // [eth/62] Number of the first header in the result cache
- headerContCh chan bool // [eth/62] Channel to notify when header download finishes
+ headerHead common.Hash // Hash of the last queued header to verify order
+ headerTaskPool map[uint64]*types.Header // Pending header retrieval tasks, mapping starting indexes to skeleton headers
+ headerTaskQueue *prque.Prque // Priority queue of the skeleton indexes to fetch the filling headers for
+ headerPeerMiss map[string]map[uint64]struct{} // Set of per-peer header batches known to be unavailable
+ headerPendPool map[string]*fetchRequest // Currently pending header retrieval operations
+ headerResults []*types.Header // Result cache accumulating the completed headers
+ headerProced int // Number of headers already processed from the results
+ headerOffset uint64 // Number of the first header in the result cache
+ headerContCh chan bool // Channel to notify when header download finishes
// All data retrievals below are based on an already assembled header chain
- blockTaskPool map[common.Hash]*types.Header // [eth/62] Pending block (body) retrieval tasks, mapping hashes to headers
- blockTaskQueue *prque.Prque // [eth/62] Priority queue of the headers to fetch the blocks (bodies) for
- blockPendPool map[string]*fetchRequest // [eth/62] Currently pending block (body) retrieval operations
+ blockTaskPool map[common.Hash]*types.Header // Pending block (body) retrieval tasks, mapping hashes to headers
+ blockTaskQueue *prque.Prque // Priority queue of the headers to fetch the blocks (bodies) for
+ blockPendPool map[string]*fetchRequest // Currently pending block (body) retrieval operations
- receiptTaskPool map[common.Hash]*types.Header // [eth/63] Pending receipt retrieval tasks, mapping hashes to headers
- receiptTaskQueue *prque.Prque // [eth/63] Priority queue of the headers to fetch the receipts for
- receiptPendPool map[string]*fetchRequest // [eth/63] Currently pending receipt retrieval operations
+ receiptTaskPool map[common.Hash]*types.Header // Pending receipt retrieval tasks, mapping hashes to headers
+ receiptTaskQueue *prque.Prque // Priority queue of the headers to fetch the receipts for
+ receiptPendPool map[string]*fetchRequest // Currently pending receipt retrieval operations
resultCache *resultStore // Downloaded but not yet delivered fetch results
resultSize common.StorageSize // Approximate size of a block (exponential moving average)
@@ -690,6 +690,13 @@ func (q *queue) DeliverHeaders(id string, headers []*types.Header, headerProcCh
q.lock.Lock()
defer q.lock.Unlock()
+ var logger log.Logger
+ if len(id) < 16 {
+ // Tests use short IDs, don't choke on them
+ logger = log.New("peer", id)
+ } else {
+ logger = log.New("peer", id[:16])
+ }
// Short circuit if the data was never requested
request := q.headerPendPool[id]
if request == nil {
@@ -704,10 +711,10 @@ func (q *queue) DeliverHeaders(id string, headers []*types.Header, headerProcCh
accepted := len(headers) == MaxHeaderFetch
if accepted {
if headers[0].Number.Uint64() != request.From {
- log.Trace("First header broke chain ordering", "peer", id, "number", headers[0].Number, "hash", headers[0].Hash(), request.From)
+ logger.Trace("First header broke chain ordering", "number", headers[0].Number, "hash", headers[0].Hash(), "expected", request.From)
accepted = false
} else if headers[len(headers)-1].Hash() != target {
- log.Trace("Last header broke skeleton structure ", "peer", id, "number", headers[len(headers)-1].Number, "hash", headers[len(headers)-1].Hash(), "expected", target)
+ logger.Trace("Last header broke skeleton structure ", "number", headers[len(headers)-1].Number, "hash", headers[len(headers)-1].Hash(), "expected", target)
accepted = false
}
}
@@ -716,12 +723,12 @@ func (q *queue) DeliverHeaders(id string, headers []*types.Header, headerProcCh
for i, header := range headers[1:] {
hash := header.Hash()
if want := request.From + 1 + uint64(i); header.Number.Uint64() != want {
- log.Warn("Header broke chain ordering", "peer", id, "number", header.Number, "hash", hash, "expected", want)
+ logger.Warn("Header broke chain ordering", "number", header.Number, "hash", hash, "expected", want)
accepted = false
break
}
if parentHash != header.ParentHash {
- log.Warn("Header broke chain ancestry", "peer", id, "number", header.Number, "hash", hash)
+ logger.Warn("Header broke chain ancestry", "number", header.Number, "hash", hash)
accepted = false
break
}
@@ -731,7 +738,7 @@ func (q *queue) DeliverHeaders(id string, headers []*types.Header, headerProcCh
}
// If the batch of headers wasn't accepted, mark as unavailable
if !accepted {
- log.Trace("Skeleton filling not accepted", "peer", id, "from", request.From)
+ logger.Trace("Skeleton filling not accepted", "from", request.From)
miss := q.headerPeerMiss[id]
if miss == nil {
@@ -758,7 +765,7 @@ func (q *queue) DeliverHeaders(id string, headers []*types.Header, headerProcCh
select {
case headerProcCh <- process:
- log.Trace("Pre-scheduled new headers", "peer", id, "count", len(process), "from", process[0].Number)
+ logger.Trace("Pre-scheduled new headers", "count", len(process), "from", process[0].Number)
q.headerProced += len(process)
default:
}
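DeliverHeaders above now builds a peer-scoped logger up front, truncating long peer IDs to 16 characters while leaving short test IDs intact. A self-contained sketch of that guard (shortID is an illustrative helper, not part of the diff):

package main

import "fmt"

// shortID trims a peer identifier for logging, mirroring the length check
// used before slicing id[:16].
func shortID(id string) string {
	if len(id) < 16 {
		return id // tests use short IDs, don't choke on them
	}
	return id[:16]
}

func main() {
	fmt.Println(shortID("tiny"))                 // tiny
	fmt.Println(shortID("0123456789abcdef0123")) // 0123456789abcdef
}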
diff --git a/eth/downloader/statesync.go b/eth/downloader/statesync.go
index 6745aa54a..69bd13c2f 100644
--- a/eth/downloader/statesync.go
+++ b/eth/downloader/statesync.go
@@ -101,8 +101,16 @@ func (d *Downloader) runStateSync(s *stateSync) *stateSync {
finished []*stateReq // Completed or failed requests
timeout = make(chan *stateReq) // Timed out active requests
)
- // Run the state sync.
log.Trace("State sync starting", "root", s.root)
+
+ defer func() {
+ // Cancel active request timers on exit. Also set peers to idle so they're
+ // available for the next sync.
+ for _, req := range active {
+ req.timer.Stop()
+ req.peer.SetNodeDataIdle(int(req.nItems), time.Now())
+ }
+ }()
go s.run()
defer s.Cancel()
@@ -252,8 +260,9 @@ func (d *Downloader) spindownStateSync(active map[string]*stateReq, finished []*
type stateSync struct {
d *Downloader // Downloader instance to access and manage current peerset
- sched *trie.Sync // State trie sync scheduler defining the tasks
- keccak hash.Hash // Keccak256 hasher to verify deliveries with
+ root common.Hash // State root currently being synced
+ sched *trie.Sync // State trie sync scheduler defining the tasks
+ keccak hash.Hash // Keccak256 hasher to verify deliveries with
trieTasks map[common.Hash]*trieTask // Set of trie node tasks currently queued for retrieval
codeTasks map[common.Hash]*codeTask // Set of byte code tasks currently queued for retrieval
@@ -268,8 +277,6 @@ type stateSync struct {
cancelOnce sync.Once // Ensures cancel only ever gets called once
done chan struct{} // Channel to signal termination completion
err error // Any error hit during sync (set before completion)
-
- root common.Hash
}
// trieTask represents a single trie node download task, containing a set of
@@ -290,6 +297,7 @@ type codeTask struct {
func newStateSync(d *Downloader, root common.Hash) *stateSync {
return &stateSync{
d: d,
+ root: root,
sched: state.NewStateSync(root, d.stateDB, d.stateBloom),
keccak: sha3.NewLegacyKeccak256(),
trieTasks: make(map[common.Hash]*trieTask),
@@ -298,7 +306,6 @@ func newStateSync(d *Downloader, root common.Hash) *stateSync {
cancel: make(chan struct{}),
done: make(chan struct{}),
started: make(chan struct{}),
- root: root,
}
}
@@ -306,7 +313,12 @@ func newStateSync(d *Downloader, root common.Hash) *stateSync {
// it finishes, and finally notifying any goroutines waiting for the loop to
// finish.
func (s *stateSync) run() {
- s.err = s.loop()
+ close(s.started)
+ if s.d.snapSync {
+ s.err = s.d.SnapSyncer.Sync(s.root, s.cancel)
+ } else {
+ s.err = s.loop()
+ }
close(s.done)
}
@@ -318,7 +330,9 @@ func (s *stateSync) Wait() error {
// Cancel cancels the sync and waits until it has shut down.
func (s *stateSync) Cancel() error {
- s.cancelOnce.Do(func() { close(s.cancel) })
+ s.cancelOnce.Do(func() {
+ close(s.cancel)
+ })
return s.Wait()
}
@@ -329,7 +343,6 @@ func (s *stateSync) Cancel() error {
// pushed here async. The reason is to decouple processing from data receipt
// and timeouts.
func (s *stateSync) loop() (err error) {
- close(s.started)
// Listen for new peer events to assign tasks to them
newPeer := make(chan *peerConnection, 1024)
peerSub := s.d.peers.SubscribeNewPeers(newPeer)
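The statesync.go changes route run() through either the snap syncer or the legacy loop, and move close(s.started) out of loop() so it fires on both paths. Cancel() stays safe to call repeatedly because the channel close is wrapped in sync.Once; a minimal sketch of that idiom, with illustrative names:

package main

import (
	"fmt"
	"sync"
)

type syncTask struct {
	cancel     chan struct{}
	cancelOnce sync.Once
	done       chan struct{}
}

// Cancel closes the cancel channel exactly once, then waits for the run
// loop to acknowledge by closing done.
func (s *syncTask) Cancel() {
	s.cancelOnce.Do(func() {
		close(s.cancel)
	})
	<-s.done
}

func main() {
	s := &syncTask{cancel: make(chan struct{}), done: make(chan struct{})}
	go func() {
		<-s.cancel    // the run loop observes cancellation...
		close(s.done) // ...and signals termination
	}()
	s.Cancel()
	s.Cancel() // second call is a no-op thanks to sync.Once
	fmt.Println("cancelled once, waited twice")
}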
diff --git a/eth/gen_config.go b/eth/gen_config.go
index b0674c7d7..dd04635ee 100644
--- a/eth/gen_config.go
+++ b/eth/gen_config.go
@@ -20,7 +20,7 @@ func (c Config) MarshalTOML() (interface{}, error) {
Genesis *core.Genesis `toml:",omitempty"`
NetworkId uint64
SyncMode downloader.SyncMode
- DiscoveryURLs []string
+ EthDiscoveryURLs []string
NoPruning bool
NoPrefetch bool
TxLookupLimit uint64 `toml:",omitempty"`
@@ -61,7 +61,7 @@ func (c Config) MarshalTOML() (interface{}, error) {
enc.Genesis = c.Genesis
enc.NetworkId = c.NetworkId
enc.SyncMode = c.SyncMode
- enc.DiscoveryURLs = c.DiscoveryURLs
+ enc.EthDiscoveryURLs = c.EthDiscoveryURLs
enc.NoPruning = c.NoPruning
enc.NoPrefetch = c.NoPrefetch
enc.TxLookupLimit = c.TxLookupLimit
@@ -106,7 +106,7 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error {
Genesis *core.Genesis `toml:",omitempty"`
NetworkId *uint64
SyncMode *downloader.SyncMode
- DiscoveryURLs []string
+ EthDiscoveryURLs []string
NoPruning *bool
NoPrefetch *bool
TxLookupLimit *uint64 `toml:",omitempty"`
@@ -156,8 +156,8 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error {
if dec.SyncMode != nil {
c.SyncMode = *dec.SyncMode
}
- if dec.DiscoveryURLs != nil {
- c.DiscoveryURLs = dec.DiscoveryURLs
+ if dec.EthDiscoveryURLs != nil {
+ c.EthDiscoveryURLs = dec.EthDiscoveryURLs
}
if dec.NoPruning != nil {
c.NoPruning = *dec.NoPruning
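The gen_config.go rename keeps the generated TOML round-tripping intact: unmarshalling goes through an intermediate struct whose pointer (or nil-able slice) fields distinguish "absent from the file" from "explicit zero value", and only non-nil fields overwrite the config. A minimal sketch of that pattern, with an illustrative two-field config:

package main

import "fmt"

type config struct {
	NetworkId        uint64
	EthDiscoveryURLs []string
}

// partial mirrors the generated intermediate struct: nil means the key was
// missing from the TOML input, so the existing value is preserved.
type partial struct {
	NetworkId        *uint64
	EthDiscoveryURLs []string
}

func apply(c *config, dec partial) {
	if dec.NetworkId != nil {
		c.NetworkId = *dec.NetworkId
	}
	if dec.EthDiscoveryURLs != nil {
		c.EthDiscoveryURLs = dec.EthDiscoveryURLs
	}
}

func main() {
	c := config{NetworkId: 1, EthDiscoveryURLs: []string{"enrtree://example"}}
	id := uint64(5)
	apply(&c, partial{NetworkId: &id}) // URLs absent: keep the default
	fmt.Println(c.NetworkId, c.EthDiscoveryURLs)
}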
diff --git a/eth/handler.go b/eth/handler.go
index 5b8998653..76a429f6d 100644
--- a/eth/handler.go
+++ b/eth/handler.go
@@ -17,9 +17,7 @@
package eth
import (
- "encoding/json"
"errors"
- "fmt"
"math"
"math/big"
"sync"
@@ -27,26 +25,22 @@ import (
"time"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/consensus"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/eth/fetcher"
+ "github.com/ethereum/go-ethereum/eth/protocols/eth"
+ "github.com/ethereum/go-ethereum/eth/protocols/snap"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p"
- "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/params"
- "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
)
const (
- softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data.
- estHeaderRlpSize = 500 // Approximate size of an RLP encoded block header
-
// txChanSize is the size of channel listening to NewTxsEvent.
// The number is referenced from the size of tx pool.
txChanSize = 4096
@@ -56,26 +50,61 @@ var (
syncChallengeTimeout = 15 * time.Second // Time allowance for a node to reply to the sync progress challenge
)
-func errResp(code errCode, format string, v ...interface{}) error {
- return fmt.Errorf("%v - %v", code, fmt.Sprintf(format, v...))
+// txPool defines the methods needed from a transaction pool implementation to
+// support all the operations needed by the Ethereum chain protocols.
+type txPool interface {
+ // Has reports whether the txpool has a transaction
+ // cached with the given hash.
+ Has(hash common.Hash) bool
+
+ // Get retrieves the transaction from the local txpool with the
+ // given tx hash.
+ Get(hash common.Hash) *types.Transaction
+
+ // AddRemotes should add the given transactions to the pool.
+ AddRemotes([]*types.Transaction) []error
+
+ // Pending should return pending transactions.
+ // The slice should be modifiable by the caller.
+ Pending() (map[common.Address]types.Transactions, error)
+
+ // SubscribeNewTxsEvent should return an event subscription of
+ // NewTxsEvent and send events to the given channel.
+ SubscribeNewTxsEvent(chan<- core.NewTxsEvent) event.Subscription
}
-type ProtocolManager struct {
+// handlerConfig is the collection of initialization parameters to create a full
+// node network handler.
+type handlerConfig struct {
+ Database ethdb.Database // Database for direct sync insertions
+ Chain *core.BlockChain // Blockchain to serve data from
+ TxPool txPool // Transaction pool to propagate from
+ Network uint64 // Network identifier to advertise
+ Sync downloader.SyncMode // Whether to do fast, full or snap sync
+ BloomCache uint64 // Megabytes to alloc for fast sync bloom
+ EventMux *event.TypeMux // Legacy event mux, deprecate for `feed`
+ Checkpoint *params.TrustedCheckpoint // Hard coded checkpoint for sync challenges
+ Whitelist map[uint64]common.Hash // Hard coded whitelist for sync challenges
+}
+
+type handler struct {
networkID uint64
forkFilter forkid.Filter // Fork ID filter, constant across the lifetime of the node
fastSync uint32 // Flag whether fast sync is enabled (gets disabled if we already have blocks)
+ snapSync uint32 // Flag whether fast sync should operate on top of the snap protocol
acceptTxs uint32 // Flag whether we're considered synchronised (enables transaction processing)
checkpointNumber uint64 // Block number for the sync progress validator to cross reference
checkpointHash common.Hash // Block hash for the sync progress validator to cross reference
- txpool txPool
- blockchain *core.BlockChain
- chaindb ethdb.Database
- maxPeers int
+ database ethdb.Database
+ txpool txPool
+ chain *core.BlockChain
+ maxPeers int
downloader *downloader.Downloader
+ stateBloom *trie.SyncBloom
blockFetcher *fetcher.BlockFetcher
txFetcher *fetcher.TxFetcher
peers *peerSet
@@ -94,29 +123,27 @@ type ProtocolManager struct {
chainSync *chainSyncer
wg sync.WaitGroup
peerWG sync.WaitGroup
-
- // Test fields or hooks
- broadcastTxAnnouncesOnly bool // Testing field, disable transaction propagation
}
-// NewProtocolManager returns a new Ethereum sub protocol manager. The Ethereum sub protocol manages peers capable
-// with the Ethereum network.
-func NewProtocolManager(config *params.ChainConfig, checkpoint *params.TrustedCheckpoint, mode downloader.SyncMode, networkID uint64, mux *event.TypeMux, txpool txPool, engine consensus.Engine, blockchain *core.BlockChain, chaindb ethdb.Database, cacheLimit int, whitelist map[uint64]common.Hash) (*ProtocolManager, error) {
+// newHandler returns a handler for all Ethereum chain management protocols.
+func newHandler(config *handlerConfig) (*handler, error) {
// Create the protocol manager with the base fields
- manager := &ProtocolManager{
- networkID: networkID,
- forkFilter: forkid.NewFilter(blockchain),
- eventMux: mux,
- txpool: txpool,
- blockchain: blockchain,
- chaindb: chaindb,
+ if config.EventMux == nil {
+ config.EventMux = new(event.TypeMux) // Nicety initialization for tests
+ }
+ h := &handler{
+ networkID: config.Network,
+ forkFilter: forkid.NewFilter(config.Chain),
+ eventMux: config.EventMux,
+ database: config.Database,
+ txpool: config.TxPool,
+ chain: config.Chain,
peers: newPeerSet(),
- whitelist: whitelist,
+ whitelist: config.Whitelist,
txsyncCh: make(chan *txsync),
quitSync: make(chan struct{}),
}
-
- if mode == downloader.FullSync {
+ if config.Sync == downloader.FullSync {
// The database seems empty as the current block is the genesis. Yet the fast
// block is ahead, so fast sync was enabled for this node at a certain point.
// The scenarios where this can happen is
@@ -125,42 +152,42 @@ func NewProtocolManager(config *params.ChainConfig, checkpoint *params.TrustedCh
// * the last fast sync is not finished while user specifies a full sync this
// time. But we don't have any recent state for full sync.
// In these cases however it's safe to reenable fast sync.
- fullBlock, fastBlock := blockchain.CurrentBlock(), blockchain.CurrentFastBlock()
+ fullBlock, fastBlock := h.chain.CurrentBlock(), h.chain.CurrentFastBlock()
if fullBlock.NumberU64() == 0 && fastBlock.NumberU64() > 0 {
- manager.fastSync = uint32(1)
+ h.fastSync = uint32(1)
log.Warn("Switch sync mode from full sync to fast sync")
}
} else {
- if blockchain.CurrentBlock().NumberU64() > 0 {
+ if h.chain.CurrentBlock().NumberU64() > 0 {
// Print warning log if database is not empty to run fast sync.
log.Warn("Switch sync mode from fast sync to full sync")
} else {
// If fast sync was requested and our database is empty, grant it
- manager.fastSync = uint32(1)
+ h.fastSync = uint32(1)
+ if config.Sync == downloader.SnapSync {
+ h.snapSync = uint32(1)
+ }
}
}
-
// If we have trusted checkpoints, enforce them on the chain
- if checkpoint != nil {
- manager.checkpointNumber = (checkpoint.SectionIndex+1)*params.CHTFrequency - 1
- manager.checkpointHash = checkpoint.SectionHead
+ if config.Checkpoint != nil {
+ h.checkpointNumber = (config.Checkpoint.SectionIndex+1)*params.CHTFrequency - 1
+ h.checkpointHash = config.Checkpoint.SectionHead
}
-
// Construct the downloader (long sync) and its backing state bloom if fast
// sync is requested. The downloader is responsible for deallocating the state
// bloom when it's done.
- var stateBloom *trie.SyncBloom
- if atomic.LoadUint32(&manager.fastSync) == 1 {
- stateBloom = trie.NewSyncBloom(uint64(cacheLimit), chaindb)
+ if atomic.LoadUint32(&h.fastSync) == 1 {
+ h.stateBloom = trie.NewSyncBloom(config.BloomCache, config.Database)
}
- manager.downloader = downloader.New(manager.checkpointNumber, chaindb, stateBloom, manager.eventMux, blockchain, nil, manager.removePeer)
+ h.downloader = downloader.New(h.checkpointNumber, config.Database, h.stateBloom, h.eventMux, h.chain, nil, h.removePeer)
// Construct the fetcher (short sync)
validator := func(header *types.Header) error {
- return engine.VerifyHeader(blockchain, header, true)
+ return h.chain.Engine().VerifyHeader(h.chain, header, true)
}
heighter := func() uint64 {
- return blockchain.CurrentBlock().NumberU64()
+ return h.chain.CurrentBlock().NumberU64()
}
inserter := func(blocks types.Blocks) (int, error) {
// If sync hasn't reached the checkpoint yet, deny importing weird blocks.
@@ -169,7 +196,7 @@ func NewProtocolManager(config *params.ChainConfig, checkpoint *params.TrustedCh
// the propagated block if the head is too old. Unfortunately there is a corner
// case when starting new networks, where the genesis might be ancient (0 unix)
// which would prevent full nodes from accepting it.
- if manager.blockchain.CurrentBlock().NumberU64() < manager.checkpointNumber {
+ if h.chain.CurrentBlock().NumberU64() < h.checkpointNumber {
log.Warn("Unsynced yet, discarded propagated block", "number", blocks[0].Number(), "hash", blocks[0].Hash())
return 0, nil
}
@@ -178,180 +205,88 @@ func NewProtocolManager(config *params.ChainConfig, checkpoint *params.TrustedCh
// accept each others' blocks until a restart. Unfortunately we haven't figured
// out a way yet where nodes can decide unilaterally whether the network is new
// or not. This should be fixed if we figure out a solution.
- if atomic.LoadUint32(&manager.fastSync) == 1 {
+ if atomic.LoadUint32(&h.fastSync) == 1 {
log.Warn("Fast syncing, discarded propagated block", "number", blocks[0].Number(), "hash", blocks[0].Hash())
return 0, nil
}
- n, err := manager.blockchain.InsertChain(blocks)
+ n, err := h.chain.InsertChain(blocks)
if err == nil {
- atomic.StoreUint32(&manager.acceptTxs, 1) // Mark initial sync done on any fetcher import
+ atomic.StoreUint32(&h.acceptTxs, 1) // Mark initial sync done on any fetcher import
}
return n, err
}
- manager.blockFetcher = fetcher.NewBlockFetcher(false, nil, blockchain.GetBlockByHash, validator, manager.BroadcastBlock, heighter, nil, inserter, manager.removePeer)
+ h.blockFetcher = fetcher.NewBlockFetcher(false, nil, h.chain.GetBlockByHash, validator, h.BroadcastBlock, heighter, nil, inserter, h.removePeer)
fetchTx := func(peer string, hashes []common.Hash) error {
- p := manager.peers.Peer(peer)
+ p := h.peers.ethPeer(peer)
if p == nil {
return errors.New("unknown peer")
}
return p.RequestTxs(hashes)
}
- manager.txFetcher = fetcher.NewTxFetcher(txpool.Has, txpool.AddRemotes, fetchTx)
-
- manager.chainSync = newChainSyncer(manager)
-
- return manager, nil
+ h.txFetcher = fetcher.NewTxFetcher(h.txpool.Has, h.txpool.AddRemotes, fetchTx)
+ h.chainSync = newChainSyncer(h)
+ return h, nil
}
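newHandler above stores fastSync, snapSync and acceptTxs as uint32 flags manipulated through sync/atomic, since they are flipped by the sync logic while being read concurrently from many peer goroutines. A tiny sketch of that flag idiom (the flags type is illustrative):

package main

import (
	"fmt"
	"sync/atomic"
)

type flags struct{ fastSync uint32 }

func (f *flags) enable()       { atomic.StoreUint32(&f.fastSync, 1) }
func (f *flags) enabled() bool { return atomic.LoadUint32(&f.fastSync) == 1 }

func main() {
	var f flags
	fmt.Println(f.enabled()) // false
	f.enable()               // e.g. "database empty, grant fast sync"
	fmt.Println(f.enabled()) // true
}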
-func (pm *ProtocolManager) makeProtocol(version uint) p2p.Protocol {
- length, ok := protocolLengths[version]
- if !ok {
- panic("makeProtocol for unknown version")
- }
-
- return p2p.Protocol{
- Name: protocolName,
- Version: version,
- Length: length,
- Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
- return pm.runPeer(pm.newPeer(int(version), p, rw, pm.txpool.Get))
- },
- NodeInfo: func() interface{} {
- return pm.NodeInfo()
- },
- PeerInfo: func(id enode.ID) interface{} {
- if p := pm.peers.Peer(fmt.Sprintf("%x", id[:8])); p != nil {
- return p.Info()
- }
- return nil
- },
- }
-}
-
-func (pm *ProtocolManager) removePeer(id string) {
- // Short circuit if the peer was already removed
- peer := pm.peers.Peer(id)
- if peer == nil {
- return
- }
- log.Debug("Removing Ethereum peer", "peer", id)
-
- // Unregister the peer from the downloader and Ethereum peer set
- pm.downloader.UnregisterPeer(id)
- pm.txFetcher.Drop(id)
-
- if err := pm.peers.Unregister(id); err != nil {
- log.Error("Peer removal failed", "peer", id, "err", err)
- }
- // Hard disconnect at the networking layer
- if peer != nil {
- peer.Peer.Disconnect(p2p.DiscUselessPeer)
- }
-}
-
-func (pm *ProtocolManager) Start(maxPeers int) {
- pm.maxPeers = maxPeers
-
- // broadcast transactions
- pm.wg.Add(1)
- pm.txsCh = make(chan core.NewTxsEvent, txChanSize)
- pm.txsSub = pm.txpool.SubscribeNewTxsEvent(pm.txsCh)
- go pm.txBroadcastLoop()
-
- // broadcast mined blocks
- pm.wg.Add(1)
- pm.minedBlockSub = pm.eventMux.Subscribe(core.NewMinedBlockEvent{})
- go pm.minedBroadcastLoop()
-
- // start sync handlers
- pm.wg.Add(2)
- go pm.chainSync.loop()
- go pm.txsyncLoop64() // TODO(karalabe): Legacy initial tx echange, drop with eth/64.
-}
-
-func (pm *ProtocolManager) Stop() {
- pm.txsSub.Unsubscribe() // quits txBroadcastLoop
- pm.minedBlockSub.Unsubscribe() // quits blockBroadcastLoop
-
- // Quit chainSync and txsync64.
- // After this is done, no new peers will be accepted.
- close(pm.quitSync)
- pm.wg.Wait()
-
- // Disconnect existing sessions.
- // This also closes the gate for any new registrations on the peer set.
- // sessions which are already established but not added to pm.peers yet
- // will exit when they try to register.
- pm.peers.Close()
- pm.peerWG.Wait()
-
- log.Info("Ethereum protocol stopped")
-}
-
-func (pm *ProtocolManager) newPeer(pv int, p *p2p.Peer, rw p2p.MsgReadWriter, getPooledTx func(hash common.Hash) *types.Transaction) *peer {
- return newPeer(pv, p, rw, getPooledTx)
-}
-
-func (pm *ProtocolManager) runPeer(p *peer) error {
- if !pm.chainSync.handlePeerEvent(p) {
+// runEthPeer performs the Ethereum handshake, registers the peer into the
+// handler's peer set, downloader and syncer, and then handles inbound
+// messages until the connection is torn down.
+func (h *handler) runEthPeer(peer *eth.Peer, handler eth.Handler) error {
+ if !h.chainSync.handlePeerEvent(peer) {
return p2p.DiscQuitting
}
- pm.peerWG.Add(1)
- defer pm.peerWG.Done()
- return pm.handle(p)
-}
-
-// handle is the callback invoked to manage the life cycle of an eth peer. When
-// this function terminates, the peer is disconnected.
-func (pm *ProtocolManager) handle(p *peer) error {
- // Ignore maxPeers if this is a trusted peer
- if pm.peers.Len() >= pm.maxPeers && !p.Peer.Info().Network.Trusted {
- return p2p.DiscTooManyPeers
- }
- p.Log().Debug("Ethereum peer connected", "name", p.Name())
+ h.peerWG.Add(1)
+ defer h.peerWG.Done()
// Execute the Ethereum handshake
var (
- genesis = pm.blockchain.Genesis()
- head = pm.blockchain.CurrentHeader()
+ genesis = h.chain.Genesis()
+ head = h.chain.CurrentHeader()
hash = head.Hash()
number = head.Number.Uint64()
- td = pm.blockchain.GetTd(hash, number)
+ td = h.chain.GetTd(hash, number)
)
- forkID := forkid.NewID(pm.blockchain.Config(), pm.blockchain.Genesis().Hash(), pm.blockchain.CurrentHeader().Number.Uint64())
- if err := p.Handshake(pm.networkID, td, hash, genesis.Hash(), forkID, pm.forkFilter); err != nil {
- p.Log().Debug("Ethereum handshake failed", "err", err)
+ forkID := forkid.NewID(h.chain.Config(), h.chain.Genesis().Hash(), h.chain.CurrentHeader().Number.Uint64())
+ if err := peer.Handshake(h.networkID, td, hash, genesis.Hash(), forkID, h.forkFilter); err != nil {
+ peer.Log().Debug("Ethereum handshake failed", "err", err)
return err
}
+ // Ignore maxPeers if this is a trusted peer
+ if h.peers.Len() >= h.maxPeers && !peer.Peer.Info().Network.Trusted {
+ return p2p.DiscTooManyPeers
+ }
+ peer.Log().Debug("Ethereum peer connected", "name", peer.Name())
// Register the peer locally
- if err := pm.peers.Register(p, pm.removePeer); err != nil {
- p.Log().Error("Ethereum peer registration failed", "err", err)
+ if err := h.peers.registerEthPeer(peer); err != nil {
+ peer.Log().Error("Ethereum peer registration failed", "err", err)
return err
}
- defer pm.removePeer(p.id)
+ defer h.removePeer(peer.ID())
+ p := h.peers.ethPeer(peer.ID())
+ if p == nil {
+ return errors.New("peer dropped during handling")
+ }
// Register the peer in the downloader. If the downloader considers it banned, we disconnect
- if err := pm.downloader.RegisterPeer(p.id, p.version, p); err != nil {
+ if err := h.downloader.RegisterPeer(peer.ID(), peer.Version(), peer); err != nil {
return err
}
- pm.chainSync.handlePeerEvent(p)
+ h.chainSync.handlePeerEvent(peer)
// Propagate existing transactions. new transactions appearing
// after this will be sent via broadcasts.
- pm.syncTransactions(p)
+ h.syncTransactions(peer)
// If we have a trusted CHT, reject all peers below that (avoid fast sync eclipse)
- if pm.checkpointHash != (common.Hash{}) {
+ if h.checkpointHash != (common.Hash{}) {
// Request the peer's checkpoint header for chain height/weight validation
- if err := p.RequestHeadersByNumber(pm.checkpointNumber, 1, 0, false); err != nil {
+ if err := peer.RequestHeadersByNumber(h.checkpointNumber, 1, 0, false); err != nil {
return err
}
// Start a timer to disconnect if the peer doesn't reply in time
p.syncDrop = time.AfterFunc(syncChallengeTimeout, func() {
- p.Log().Warn("Checkpoint challenge timed out, dropping", "addr", p.RemoteAddr(), "type", p.Name())
- pm.removePeer(p.id)
+ peer.Log().Warn("Checkpoint challenge timed out, dropping", "addr", peer.RemoteAddr(), "type", peer.Name())
+ h.removePeer(peer.ID())
})
// Make sure it's cleaned up if the peer dies off
defer func() {
@@ -362,474 +297,115 @@ func (pm *ProtocolManager) handle(p *peer) error {
}()
}
// If we have any explicit whitelist block hashes, request them
- for number := range pm.whitelist {
- if err := p.RequestHeadersByNumber(number, 1, 0, false); err != nil {
+ for number := range h.whitelist {
+ if err := peer.RequestHeadersByNumber(number, 1, 0, false); err != nil {
return err
}
}
// Handle incoming messages until the connection is torn down
- for {
- if err := pm.handleMsg(p); err != nil {
- p.Log().Debug("Ethereum message handling failed", "err", err)
- return err
+ return handler(peer)
+}
+
+// runSnapPeer registers a `snap` peer into the handler's peer set and the
+// snap syncer, and then handles inbound messages until the connection is
+// torn down.
+func (h *handler) runSnapPeer(peer *snap.Peer, handler snap.Handler) error {
+ h.peerWG.Add(1)
+ defer h.peerWG.Done()
+
+ // Register the peer locally
+ if err := h.peers.registerSnapPeer(peer); err != nil {
+ peer.Log().Error("Snapshot peer registration failed", "err", err)
+ return err
+ }
+ defer h.removePeer(peer.ID())
+
+ if err := h.downloader.SnapSyncer.Register(peer); err != nil {
+ return err
+ }
+ // Handle incoming messages until the connection is torn down
+ return handler(peer)
+}
+
+func (h *handler) removePeer(id string) {
+ // Remove the eth peer if it exists
+ eth := h.peers.ethPeer(id)
+ if eth != nil {
+ log.Debug("Removing Ethereum peer", "peer", id)
+ h.downloader.UnregisterPeer(id)
+ h.txFetcher.Drop(id)
+
+ if err := h.peers.unregisterEthPeer(id); err != nil {
+ log.Error("Peer removal failed", "peer", id, "err", err)
}
}
+ // Remove the snap peer if it exists
+ snap := h.peers.snapPeer(id)
+ if snap != nil {
+ log.Debug("Removing Snapshot peer", "peer", id)
+ h.downloader.SnapSyncer.Unregister(id)
+ if err := h.peers.unregisterSnapPeer(id); err != nil {
+ log.Error("Peer removal failed", "peer", id, "err", err)
+ }
+ }
+ // Hard disconnect at the networking layer
+ if eth != nil {
+ eth.Peer.Disconnect(p2p.DiscUselessPeer)
+ }
+ if snap != nil {
+ snap.Peer.Disconnect(p2p.DiscUselessPeer)
+ }
}
-// handleMsg is invoked whenever an inbound message is received from a remote
-// peer. The remote connection is torn down upon returning any error.
-func (pm *ProtocolManager) handleMsg(p *peer) error {
- // Read the next message from the remote peer, and ensure it's fully consumed
- msg, err := p.rw.ReadMsg()
- if err != nil {
- return err
- }
- if msg.Size > protocolMaxMsgSize {
- return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, protocolMaxMsgSize)
- }
- defer msg.Discard()
+func (h *handler) Start(maxPeers int) {
+ h.maxPeers = maxPeers
- // Handle the message depending on its contents
- switch {
- case msg.Code == StatusMsg:
- // Status messages should never arrive after the handshake
- return errResp(ErrExtraStatusMsg, "uncontrolled status message")
+ // broadcast transactions
+ h.wg.Add(1)
+ h.txsCh = make(chan core.NewTxsEvent, txChanSize)
+ h.txsSub = h.txpool.SubscribeNewTxsEvent(h.txsCh)
+ go h.txBroadcastLoop()
- // Block header query, collect the requested headers and reply
- case msg.Code == GetBlockHeadersMsg:
- // Decode the complex header query
- var query getBlockHeadersData
- if err := msg.Decode(&query); err != nil {
- return errResp(ErrDecode, "%v: %v", msg, err)
- }
- hashMode := query.Origin.Hash != (common.Hash{})
- first := true
- maxNonCanonical := uint64(100)
+ // broadcast mined blocks
+ h.wg.Add(1)
+ h.minedBlockSub = h.eventMux.Subscribe(core.NewMinedBlockEvent{})
+ go h.minedBroadcastLoop()
- // Gather headers until the fetch or network limits is reached
- var (
- bytes common.StorageSize
- headers []*types.Header
- unknown bool
- )
- for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit && len(headers) < downloader.MaxHeaderFetch {
- // Retrieve the next header satisfying the query
- var origin *types.Header
- if hashMode {
- if first {
- first = false
- origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash)
- if origin != nil {
- query.Origin.Number = origin.Number.Uint64()
- }
- } else {
- origin = pm.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number)
- }
- } else {
- origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number)
- }
- if origin == nil {
- break
- }
- headers = append(headers, origin)
- bytes += estHeaderRlpSize
+ // start sync handlers
+ h.wg.Add(2)
+ go h.chainSync.loop()
+ go h.txsyncLoop64() // TODO(karalabe): Legacy initial tx exchange, drop with eth/64.
+}
- // Advance to the next header of the query
- switch {
- case hashMode && query.Reverse:
- // Hash based traversal towards the genesis block
- ancestor := query.Skip + 1
- if ancestor == 0 {
- unknown = true
- } else {
- query.Origin.Hash, query.Origin.Number = pm.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
- unknown = (query.Origin.Hash == common.Hash{})
- }
- case hashMode && !query.Reverse:
- // Hash based traversal towards the leaf block
- var (
- current = origin.Number.Uint64()
- next = current + query.Skip + 1
- )
- if next <= current {
- infos, _ := json.MarshalIndent(p.Peer.Info(), "", " ")
- p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos)
- unknown = true
- } else {
- if header := pm.blockchain.GetHeaderByNumber(next); header != nil {
- nextHash := header.Hash()
- expOldHash, _ := pm.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical)
- if expOldHash == query.Origin.Hash {
- query.Origin.Hash, query.Origin.Number = nextHash, next
- } else {
- unknown = true
- }
- } else {
- unknown = true
- }
- }
- case query.Reverse:
- // Number based traversal towards the genesis block
- if query.Origin.Number >= query.Skip+1 {
- query.Origin.Number -= query.Skip + 1
- } else {
- unknown = true
- }
+func (h *handler) Stop() {
+ h.txsSub.Unsubscribe() // quits txBroadcastLoop
+ h.minedBlockSub.Unsubscribe() // quits blockBroadcastLoop
- case !query.Reverse:
- // Number based traversal towards the leaf block
- query.Origin.Number += query.Skip + 1
- }
- }
- return p.SendBlockHeaders(headers)
+ // Quit chainSync and txsync64.
+ // After this is done, no new peers will be accepted.
+ close(h.quitSync)
+ h.wg.Wait()
- case msg.Code == BlockHeadersMsg:
- // A batch of headers arrived to one of our previous requests
- var headers []*types.Header
- if err := msg.Decode(&headers); err != nil {
- return errResp(ErrDecode, "msg %v: %v", msg, err)
- }
- // If no headers were received, but we're expencting a checkpoint header, consider it that
- if len(headers) == 0 && p.syncDrop != nil {
- // Stop the timer either way, decide later to drop or not
- p.syncDrop.Stop()
- p.syncDrop = nil
+ // Disconnect existing sessions.
+ // This also closes the gate for any new registrations on the peer set.
+ // sessions which are already established but not added to h.peers yet
+ // will exit when they try to register.
+ h.peers.close()
+ h.peerWG.Wait()
- // If we're doing a fast sync, we must enforce the checkpoint block to avoid
- // eclipse attacks. Unsynced nodes are welcome to connect after we're done
- // joining the network
- if atomic.LoadUint32(&pm.fastSync) == 1 {
- p.Log().Warn("Dropping unsynced node during fast sync", "addr", p.RemoteAddr(), "type", p.Name())
- return errors.New("unsynced node cannot serve fast sync")
- }
- }
- // Filter out any explicitly requested headers, deliver the rest to the downloader
- filter := len(headers) == 1
- if filter {
- // If it's a potential sync progress check, validate the content and advertised chain weight
- if p.syncDrop != nil && headers[0].Number.Uint64() == pm.checkpointNumber {
- // Disable the sync drop timer
- p.syncDrop.Stop()
- p.syncDrop = nil
-
- // Validate the header and either drop the peer or continue
- if headers[0].Hash() != pm.checkpointHash {
- return errors.New("checkpoint hash mismatch")
- }
- return nil
- }
- // Otherwise if it's a whitelisted block, validate against the set
- if want, ok := pm.whitelist[headers[0].Number.Uint64()]; ok {
- if hash := headers[0].Hash(); want != hash {
- p.Log().Info("Whitelist mismatch, dropping peer", "number", headers[0].Number.Uint64(), "hash", hash, "want", want)
- return errors.New("whitelist block mismatch")
- }
- p.Log().Debug("Whitelist block verified", "number", headers[0].Number.Uint64(), "hash", want)
- }
- // Irrelevant of the fork checks, send the header to the fetcher just in case
- headers = pm.blockFetcher.FilterHeaders(p.id, headers, time.Now())
- }
- if len(headers) > 0 || !filter {
- err := pm.downloader.DeliverHeaders(p.id, headers)
- if err != nil {
- log.Debug("Failed to deliver headers", "err", err)
- }
- }
-
- case msg.Code == GetBlockBodiesMsg:
- // Decode the retrieval message
- msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size))
- if _, err := msgStream.List(); err != nil {
- return err
- }
- // Gather blocks until the fetch or network limits is reached
- var (
- hash common.Hash
- bytes int
- bodies []rlp.RawValue
- )
- for bytes < softResponseLimit && len(bodies) < downloader.MaxBlockFetch {
- // Retrieve the hash of the next block
- if err := msgStream.Decode(&hash); err == rlp.EOL {
- break
- } else if err != nil {
- return errResp(ErrDecode, "msg %v: %v", msg, err)
- }
- // Retrieve the requested block body, stopping if enough was found
- if data := pm.blockchain.GetBodyRLP(hash); len(data) != 0 {
- bodies = append(bodies, data)
- bytes += len(data)
- }
- }
- return p.SendBlockBodiesRLP(bodies)
-
- case msg.Code == BlockBodiesMsg:
- // A batch of block bodies arrived to one of our previous requests
- var request blockBodiesData
- if err := msg.Decode(&request); err != nil {
- return errResp(ErrDecode, "msg %v: %v", msg, err)
- }
- // Deliver them all to the downloader for queuing
- transactions := make([][]*types.Transaction, len(request))
- uncles := make([][]*types.Header, len(request))
-
- for i, body := range request {
- transactions[i] = body.Transactions
- uncles[i] = body.Uncles
- }
- // Filter out any explicitly requested bodies, deliver the rest to the downloader
- filter := len(transactions) > 0 || len(uncles) > 0
- if filter {
- transactions, uncles = pm.blockFetcher.FilterBodies(p.id, transactions, uncles, time.Now())
- }
- if len(transactions) > 0 || len(uncles) > 0 || !filter {
- err := pm.downloader.DeliverBodies(p.id, transactions, uncles)
- if err != nil {
- log.Debug("Failed to deliver bodies", "err", err)
- }
- }
-
- case p.version >= eth63 && msg.Code == GetNodeDataMsg:
- // Decode the retrieval message
- msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size))
- if _, err := msgStream.List(); err != nil {
- return err
- }
- // Gather state data until the fetch or network limits is reached
- var (
- hash common.Hash
- bytes int
- data [][]byte
- )
- for bytes < softResponseLimit && len(data) < downloader.MaxStateFetch {
- // Retrieve the hash of the next state entry
- if err := msgStream.Decode(&hash); err == rlp.EOL {
- break
- } else if err != nil {
- return errResp(ErrDecode, "msg %v: %v", msg, err)
- }
- // Retrieve the requested state entry, stopping if enough was found
- // todo now the code and trienode is mixed in the protocol level,
- // separate these two types.
- if !pm.downloader.SyncBloomContains(hash[:]) {
- // Only lookup the trie node if there's chance that we actually have it
- continue
- }
- entry, err := pm.blockchain.TrieNode(hash)
- if len(entry) == 0 || err != nil {
- // Read the contract code with prefix only to save unnecessary lookups.
- entry, err = pm.blockchain.ContractCodeWithPrefix(hash)
- }
- if err == nil && len(entry) > 0 {
- data = append(data, entry)
- bytes += len(entry)
- }
- }
- return p.SendNodeData(data)
-
- case p.version >= eth63 && msg.Code == NodeDataMsg:
- // A batch of node state data arrived to one of our previous requests
- var data [][]byte
- if err := msg.Decode(&data); err != nil {
- return errResp(ErrDecode, "msg %v: %v", msg, err)
- }
- // Deliver all to the downloader
- if err := pm.downloader.DeliverNodeData(p.id, data); err != nil {
- log.Debug("Failed to deliver node state data", "err", err)
- }
-
- case p.version >= eth63 && msg.Code == GetReceiptsMsg:
- // Decode the retrieval message
- msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size))
- if _, err := msgStream.List(); err != nil {
- return err
- }
- // Gather state data until the fetch or network limits is reached
- var (
- hash common.Hash
- bytes int
- receipts []rlp.RawValue
- )
- for bytes < softResponseLimit && len(receipts) < downloader.MaxReceiptFetch {
- // Retrieve the hash of the next block
- if err := msgStream.Decode(&hash); err == rlp.EOL {
- break
- } else if err != nil {
- return errResp(ErrDecode, "msg %v: %v", msg, err)
- }
- // Retrieve the requested block's receipts, skipping if unknown to us
- results := pm.blockchain.GetReceiptsByHash(hash)
- if results == nil {
- if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
- continue
- }
- }
- // If known, encode and queue for response packet
- if encoded, err := rlp.EncodeToBytes(results); err != nil {
- log.Error("Failed to encode receipt", "err", err)
- } else {
- receipts = append(receipts, encoded)
- bytes += len(encoded)
- }
- }
- return p.SendReceiptsRLP(receipts)
-
- case p.version >= eth63 && msg.Code == ReceiptsMsg:
- // A batch of receipts arrived to one of our previous requests
- var receipts [][]*types.Receipt
- if err := msg.Decode(&receipts); err != nil {
- return errResp(ErrDecode, "msg %v: %v", msg, err)
- }
- // Deliver all to the downloader
- if err := pm.downloader.DeliverReceipts(p.id, receipts); err != nil {
- log.Debug("Failed to deliver receipts", "err", err)
- }
-
- case msg.Code == NewBlockHashesMsg:
- var announces newBlockHashesData
- if err := msg.Decode(&announces); err != nil {
- return errResp(ErrDecode, "%v: %v", msg, err)
- }
- // Mark the hashes as present at the remote node
- for _, block := range announces {
- p.MarkBlock(block.Hash)
- }
- // Schedule all the unknown hashes for retrieval
- unknown := make(newBlockHashesData, 0, len(announces))
- for _, block := range announces {
- if !pm.blockchain.HasBlock(block.Hash, block.Number) {
- unknown = append(unknown, block)
- }
- }
- for _, block := range unknown {
- pm.blockFetcher.Notify(p.id, block.Hash, block.Number, time.Now(), p.RequestOneHeader, p.RequestBodies)
- }
-
- case msg.Code == NewBlockMsg:
- // Retrieve and decode the propagated block
- var request newBlockData
- if err := msg.Decode(&request); err != nil {
- return errResp(ErrDecode, "%v: %v", msg, err)
- }
- if hash := types.CalcUncleHash(request.Block.Uncles()); hash != request.Block.UncleHash() {
- log.Warn("Propagated block has invalid uncles", "have", hash, "exp", request.Block.UncleHash())
- break // TODO(karalabe): return error eventually, but wait a few releases
- }
- if hash := types.DeriveSha(request.Block.Transactions(), trie.NewStackTrie(nil)); hash != request.Block.TxHash() {
- log.Warn("Propagated block has invalid body", "have", hash, "exp", request.Block.TxHash())
- break // TODO(karalabe): return error eventually, but wait a few releases
- }
- if err := request.sanityCheck(); err != nil {
- return err
- }
- request.Block.ReceivedAt = msg.ReceivedAt
- request.Block.ReceivedFrom = p
-
- // Mark the peer as owning the block and schedule it for import
- p.MarkBlock(request.Block.Hash())
- pm.blockFetcher.Enqueue(p.id, request.Block)
-
- // Assuming the block is importable by the peer, but possibly not yet done so,
- // calculate the head hash and TD that the peer truly must have.
- var (
- trueHead = request.Block.ParentHash()
- trueTD = new(big.Int).Sub(request.TD, request.Block.Difficulty())
- )
- // Update the peer's total difficulty if better than the previous
- if _, td := p.Head(); trueTD.Cmp(td) > 0 {
- p.SetHead(trueHead, trueTD)
- pm.chainSync.handlePeerEvent(p)
- }
-
- case msg.Code == NewPooledTransactionHashesMsg && p.version >= eth65:
- // New transaction announcement arrived, make sure we have
- // a valid and fresh chain to handle them
- if atomic.LoadUint32(&pm.acceptTxs) == 0 {
- break
- }
- var hashes []common.Hash
- if err := msg.Decode(&hashes); err != nil {
- return errResp(ErrDecode, "msg %v: %v", msg, err)
- }
- // Schedule all the unknown hashes for retrieval
- for _, hash := range hashes {
- p.MarkTransaction(hash)
- }
- pm.txFetcher.Notify(p.id, hashes)
-
- case msg.Code == GetPooledTransactionsMsg && p.version >= eth65:
- // Decode the retrieval message
- msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size))
- if _, err := msgStream.List(); err != nil {
- return err
- }
- // Gather transactions until the fetch or network limits is reached
- var (
- hash common.Hash
- bytes int
- hashes []common.Hash
- txs []rlp.RawValue
- )
- for bytes < softResponseLimit {
- // Retrieve the hash of the next block
- if err := msgStream.Decode(&hash); err == rlp.EOL {
- break
- } else if err != nil {
- return errResp(ErrDecode, "msg %v: %v", msg, err)
- }
- // Retrieve the requested transaction, skipping if unknown to us
- tx := pm.txpool.Get(hash)
- if tx == nil {
- continue
- }
- // If known, encode and queue for response packet
- if encoded, err := rlp.EncodeToBytes(tx); err != nil {
- log.Error("Failed to encode transaction", "err", err)
- } else {
- hashes = append(hashes, hash)
- txs = append(txs, encoded)
- bytes += len(encoded)
- }
- }
- return p.SendPooledTransactionsRLP(hashes, txs)
-
- case msg.Code == TransactionMsg || (msg.Code == PooledTransactionsMsg && p.version >= eth65):
- // Transactions arrived, make sure we have a valid and fresh chain to handle them
- if atomic.LoadUint32(&pm.acceptTxs) == 0 {
- break
- }
- // Transactions can be processed, parse all of them and deliver to the pool
- var txs []*types.Transaction
- if err := msg.Decode(&txs); err != nil {
- return errResp(ErrDecode, "msg %v: %v", msg, err)
- }
- for i, tx := range txs {
- // Validate and mark the remote transaction
- if tx == nil {
- return errResp(ErrDecode, "transaction %d is nil", i)
- }
- p.MarkTransaction(tx.Hash())
- }
- pm.txFetcher.Enqueue(p.id, txs, msg.Code == PooledTransactionsMsg)
-
- default:
- return errResp(ErrInvalidMsgCode, "%v", msg.Code)
- }
- return nil
+ log.Info("Ethereum protocol stopped")
}
// BroadcastBlock will either propagate a block to a subset of its peers, or
// will only announce its availability (depending what's requested).
-func (pm *ProtocolManager) BroadcastBlock(block *types.Block, propagate bool) {
+func (h *handler) BroadcastBlock(block *types.Block, propagate bool) {
hash := block.Hash()
- peers := pm.peers.PeersWithoutBlock(hash)
+ peers := h.peers.ethPeersWithoutBlock(hash)
// If propagation is requested, send to a subset of the peer
if propagate {
// Calculate the TD of the block (it's not imported yet, so block.Td is not valid)
var td *big.Int
- if parent := pm.blockchain.GetBlock(block.ParentHash(), block.NumberU64()-1); parent != nil {
- td = new(big.Int).Add(block.Difficulty(), pm.blockchain.GetTd(block.ParentHash(), block.NumberU64()-1))
+ if parent := h.chain.GetBlock(block.ParentHash(), block.NumberU64()-1); parent != nil {
+ td = new(big.Int).Add(block.Difficulty(), h.chain.GetTd(block.ParentHash(), block.NumberU64()-1))
} else {
log.Error("Propagating dangling block", "number", block.Number(), "hash", hash)
return
@@ -843,7 +419,7 @@ func (pm *ProtocolManager) BroadcastBlock(block *types.Block, propagate bool) {
return
}
// Otherwise if the block is indeed in our own chain, announce it
- if pm.blockchain.HasBlock(hash, block.NumberU64()) {
+ if h.chain.HasBlock(hash, block.NumberU64()) {
for _, peer := range peers {
peer.AsyncSendNewBlockHash(block)
}
@@ -853,15 +429,15 @@ func (pm *ProtocolManager) BroadcastBlock(block *types.Block, propagate bool) {
// BroadcastTransactions will propagate a batch of transactions to all peers which are not known to
// already have the given transaction.
-func (pm *ProtocolManager) BroadcastTransactions(txs types.Transactions, propagate bool) {
+func (h *handler) BroadcastTransactions(txs types.Transactions, propagate bool) {
var (
- txset = make(map[*peer][]common.Hash)
- annos = make(map[*peer][]common.Hash)
+ txset = make(map[*ethPeer][]common.Hash)
+ annos = make(map[*ethPeer][]common.Hash)
)
// Broadcast transactions to a batch of peers not knowing about it
if propagate {
for _, tx := range txs {
- peers := pm.peers.PeersWithoutTx(tx.Hash())
+ peers := h.peers.ethPeersWithoutTransacion(tx.Hash())
// Send the block to a subset of our peers
transfer := peers[:int(math.Sqrt(float64(len(peers))))]
@@ -877,13 +453,13 @@ func (pm *ProtocolManager) BroadcastTransactions(txs types.Transactions, propaga
}
// Otherwise only broadcast the announcement to peers
for _, tx := range txs {
- peers := pm.peers.PeersWithoutTx(tx.Hash())
+ peers := h.peers.ethPeersWithoutTransacion(tx.Hash())
for _, peer := range peers {
annos[peer] = append(annos[peer], tx.Hash())
}
}
for peer, hashes := range annos {
- if peer.version >= eth65 {
+ if peer.Version() >= eth.ETH65 {
peer.AsyncSendPooledTransactionHashes(hashes)
} else {
peer.AsyncSendTransactions(hashes)
@@ -892,56 +468,29 @@ func (pm *ProtocolManager) BroadcastTransactions(txs types.Transactions, propaga
}
// minedBroadcastLoop sends mined blocks to connected peers.
-func (pm *ProtocolManager) minedBroadcastLoop() {
- defer pm.wg.Done()
+func (h *handler) minedBroadcastLoop() {
+ defer h.wg.Done()
- for obj := range pm.minedBlockSub.Chan() {
+ for obj := range h.minedBlockSub.Chan() {
if ev, ok := obj.Data.(core.NewMinedBlockEvent); ok {
- pm.BroadcastBlock(ev.Block, true) // First propagate block to peers
- pm.BroadcastBlock(ev.Block, false) // Only then announce to the rest
+ h.BroadcastBlock(ev.Block, true) // First propagate block to peers
+ h.BroadcastBlock(ev.Block, false) // Only then announce to the rest
}
}
}
// txBroadcastLoop announces new transactions to connected peers.
-func (pm *ProtocolManager) txBroadcastLoop() {
- defer pm.wg.Done()
+func (h *handler) txBroadcastLoop() {
+ defer h.wg.Done()
for {
select {
- case event := <-pm.txsCh:
- // For testing purpose only, disable propagation
- if pm.broadcastTxAnnouncesOnly {
- pm.BroadcastTransactions(event.Txs, false)
- continue
- }
- pm.BroadcastTransactions(event.Txs, true) // First propagate transactions to peers
- pm.BroadcastTransactions(event.Txs, false) // Only then announce to the rest
+ case event := <-h.txsCh:
+ h.BroadcastTransactions(event.Txs, true) // First propagate transactions to peers
+ h.BroadcastTransactions(event.Txs, false) // Only then announce to the rest
- case <-pm.txsSub.Err():
+ case <-h.txsSub.Err():
return
}
}
}
-
-// NodeInfo represents a short summary of the Ethereum sub-protocol metadata
-// known about the host peer.
-type NodeInfo struct {
- Network uint64 `json:"network"` // Ethereum network ID (1=Frontier, 2=Morden, Ropsten=3, Rinkeby=4)
- Difficulty *big.Int `json:"difficulty"` // Total difficulty of the host's blockchain
- Genesis common.Hash `json:"genesis"` // SHA3 hash of the host's genesis block
- Config *params.ChainConfig `json:"config"` // Chain configuration for the fork rules
- Head common.Hash `json:"head"` // SHA3 hash of the host's best owned block
-}
-
-// NodeInfo retrieves some protocol metadata about the running host node.
-func (pm *ProtocolManager) NodeInfo() *NodeInfo {
- currentBlock := pm.blockchain.CurrentBlock()
- return &NodeInfo{
- Network: pm.networkID,
- Difficulty: pm.blockchain.GetTd(currentBlock.Hash(), currentBlock.NumberU64()),
- Genesis: pm.blockchain.Genesis().Hash(),
- Config: pm.blockchain.Config(),
- Head: currentBlock.Hash(),
- }
-}
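BroadcastTransactions above keeps the established propagation policy: full transaction bodies go to only the square root of the peers lacking them, while the remainder only receive announcements. A sketch of the subset split, using plain strings in place of peers:

package main

import (
	"fmt"
	"math"
)

// splitPropagation divides peers into a sqrt-sized direct-send set and an
// announce-only remainder, mirroring peers[:int(math.Sqrt(...))] above.
func splitPropagation(peers []string) (direct, announce []string) {
	transfer := int(math.Sqrt(float64(len(peers))))
	return peers[:transfer], peers[transfer:]
}

func main() {
	peers := []string{"p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9"}
	direct, announce := splitPropagation(peers)
	fmt.Println(len(direct), len(announce)) // 3 6
}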
diff --git a/eth/handler_eth.go b/eth/handler_eth.go
new file mode 100644
index 000000000..84bdac659
--- /dev/null
+++ b/eth/handler_eth.go
@@ -0,0 +1,218 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+ "errors"
+ "fmt"
+ "math/big"
+ "sync/atomic"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/eth/protocols/eth"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+// ethHandler implements the eth.Backend interface to handle the various network
+// packets that are sent as replies or broadcasts.
+type ethHandler handler
+
+func (h *ethHandler) Chain() *core.BlockChain { return h.chain }
+func (h *ethHandler) StateBloom() *trie.SyncBloom { return h.stateBloom }
+func (h *ethHandler) TxPool() eth.TxPool { return h.txpool }
+
+// RunPeer is invoked when a peer joins on the `eth` protocol.
+func (h *ethHandler) RunPeer(peer *eth.Peer, hand eth.Handler) error {
+ return (*handler)(h).runEthPeer(peer, hand)
+}
+
+// PeerInfo retrieves all known `eth` information about a peer.
+func (h *ethHandler) PeerInfo(id enode.ID) interface{} {
+ if p := h.peers.ethPeer(id.String()); p != nil {
+ return p.info()
+ }
+ return nil
+}
+
+// AcceptTxs retrieves whether transaction processing is enabled on the node
+// or if inbound transactions should simply be dropped.
+func (h *ethHandler) AcceptTxs() bool {
+ return atomic.LoadUint32(&h.acceptTxs) == 1
+}
+
+// Handle is invoked from a peer's message handler when it receives a new remote
+// message that the handler couldn't consume and serve itself.
+func (h *ethHandler) Handle(peer *eth.Peer, packet eth.Packet) error {
+ // Consume any broadcasts and announces, forwarding the rest to the downloader
+ switch packet := packet.(type) {
+ case *eth.BlockHeadersPacket:
+ return h.handleHeaders(peer, *packet)
+
+ case *eth.BlockBodiesPacket:
+ txset, uncleset := packet.Unpack()
+ return h.handleBodies(peer, txset, uncleset)
+
+ case *eth.NodeDataPacket:
+ if err := h.downloader.DeliverNodeData(peer.ID(), *packet); err != nil {
+ log.Debug("Failed to deliver node state data", "err", err)
+ }
+ return nil
+
+ case *eth.ReceiptsPacket:
+ if err := h.downloader.DeliverReceipts(peer.ID(), *packet); err != nil {
+ log.Debug("Failed to deliver receipts", "err", err)
+ }
+ return nil
+
+ case *eth.NewBlockHashesPacket:
+ hashes, numbers := packet.Unpack()
+ return h.handleBlockAnnounces(peer, hashes, numbers)
+
+ case *eth.NewBlockPacket:
+ return h.handleBlockBroadcast(peer, packet.Block, packet.TD)
+
+ case *eth.NewPooledTransactionHashesPacket:
+ return h.txFetcher.Notify(peer.ID(), *packet)
+
+ case *eth.TransactionsPacket:
+ return h.txFetcher.Enqueue(peer.ID(), *packet, false)
+
+ case *eth.PooledTransactionsPacket:
+ return h.txFetcher.Enqueue(peer.ID(), *packet, true)
+
+ default:
+ return fmt.Errorf("unexpected eth packet type: %T", packet)
+ }
+}
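The new Handle method replaces the old handleMsg switch on raw message codes with a type switch over decoded packet structs, so each protocol message arrives pre-parsed. A minimal sketch of that dispatch style (the packet types below are illustrative, not the real eth.Packet implementations):

package main

import "fmt"

// packet is an illustrative stand-in for the eth.Packet interface.
type packet interface{ Name() string }

type headersPacket struct{ count int }

func (*headersPacket) Name() string { return "BlockHeaders" }

type receiptsPacket struct{ count int }

func (*receiptsPacket) Name() string { return "Receipts" }

// handle dispatches on the concrete packet type, as ethHandler.Handle does.
func handle(p packet) error {
	switch p := p.(type) {
	case *headersPacket:
		fmt.Println("deliver", p.count, "headers to the downloader")
		return nil
	case *receiptsPacket:
		fmt.Println("deliver", p.count, "receipts to the downloader")
		return nil
	default:
		return fmt.Errorf("unexpected packet type: %T", p)
	}
}

func main() {
	_ = handle(&headersPacket{count: 192})
	_ = handle(&receiptsPacket{count: 64})
}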
+
+// handleHeaders is invoked from a peer's message handler when it transmits a batch
+// of headers for the local node to process.
+func (h *ethHandler) handleHeaders(peer *eth.Peer, headers []*types.Header) error {
+ p := h.peers.ethPeer(peer.ID())
+ if p == nil {
+ return errors.New("unregistered during callback")
+ }
+ // If no headers were received, but we're expecting a checkpoint header, treat it as such
+ if len(headers) == 0 && p.syncDrop != nil {
+ // Stop the timer either way, decide later to drop or not
+ p.syncDrop.Stop()
+ p.syncDrop = nil
+
+ // If we're doing a fast (or snap) sync, we must enforce the checkpoint block to avoid
+ // eclipse attacks. Unsynced nodes are welcome to connect after we're done
+ // joining the network
+ if atomic.LoadUint32(&h.fastSync) == 1 {
+ peer.Log().Warn("Dropping unsynced node during sync", "addr", peer.RemoteAddr(), "type", peer.Name())
+ return errors.New("unsynced node cannot serve sync")
+ }
+ }
+ // Filter out any explicitly requested headers, deliver the rest to the downloader
+ filter := len(headers) == 1
+ if filter {
+ // If it's a potential sync progress check, validate the content and advertised chain weight
+ if p.syncDrop != nil && headers[0].Number.Uint64() == h.checkpointNumber {
+ // Disable the sync drop timer
+ p.syncDrop.Stop()
+ p.syncDrop = nil
+
+ // Validate the header and either drop the peer or continue
+ if headers[0].Hash() != h.checkpointHash {
+ return errors.New("checkpoint hash mismatch")
+ }
+ return nil
+ }
+ // Otherwise if it's a whitelisted block, validate against the set
+ if want, ok := h.whitelist[headers[0].Number.Uint64()]; ok {
+ if hash := headers[0].Hash(); want != hash {
+ peer.Log().Info("Whitelist mismatch, dropping peer", "number", headers[0].Number.Uint64(), "hash", hash, "want", want)
+ return errors.New("whitelist block mismatch")
+ }
+ peer.Log().Debug("Whitelist block verified", "number", headers[0].Number.Uint64(), "hash", want)
+ }
+ // Irrelevant of the fork checks, send the header to the fetcher just in case
+ headers = h.blockFetcher.FilterHeaders(peer.ID(), headers, time.Now())
+ }
+ if len(headers) > 0 || !filter {
+ err := h.downloader.DeliverHeaders(peer.ID(), headers)
+ if err != nil {
+ log.Debug("Failed to deliver headers", "err", err)
+ }
+ }
+ return nil
+}
+
+// handleBodies is invoked from a peer's message handler when it transmits a batch
+// of block bodies for the local node to process.
+func (h *ethHandler) handleBodies(peer *eth.Peer, txs [][]*types.Transaction, uncles [][]*types.Header) error {
+ // Filter out any explicitly requested bodies, deliver the rest to the downloader
+ filter := len(txs) > 0 || len(uncles) > 0
+ if filter {
+ txs, uncles = h.blockFetcher.FilterBodies(peer.ID(), txs, uncles, time.Now())
+ }
+ if len(txs) > 0 || len(uncles) > 0 || !filter {
+ err := h.downloader.DeliverBodies(peer.ID(), txs, uncles)
+ if err != nil {
+ log.Debug("Failed to deliver bodies", "err", err)
+ }
+ }
+ return nil
+}
+
+// handleBlockAnnounces is invoked from a peer's message handler when it transmits a
+// batch of block announcements for the local node to process.
+func (h *ethHandler) handleBlockAnnounces(peer *eth.Peer, hashes []common.Hash, numbers []uint64) error {
+ // Schedule all the unknown hashes for retrieval
+ var (
+ unknownHashes = make([]common.Hash, 0, len(hashes))
+ unknownNumbers = make([]uint64, 0, len(numbers))
+ )
+ for i := 0; i < len(hashes); i++ {
+ if !h.chain.HasBlock(hashes[i], numbers[i]) {
+ unknownHashes = append(unknownHashes, hashes[i])
+ unknownNumbers = append(unknownNumbers, numbers[i])
+ }
+ }
+ for i := 0; i < len(unknownHashes); i++ {
+ h.blockFetcher.Notify(peer.ID(), unknownHashes[i], unknownNumbers[i], time.Now(), peer.RequestOneHeader, peer.RequestBodies)
+ }
+ return nil
+}
+
+// handleBlockBroadcast is invoked from a peer's message handler when it transmits a
+// block broadcast for the local node to process.
+func (h *ethHandler) handleBlockBroadcast(peer *eth.Peer, block *types.Block, td *big.Int) error {
+ // Schedule the block for import
+ h.blockFetcher.Enqueue(peer.ID(), block)
+
+ // Assume the block is importable by the peer, though it may not have done so yet,
+ // and calculate the head hash and TD that the peer truly must have.
+ var (
+ trueHead = block.ParentHash()
+ trueTD = new(big.Int).Sub(td, block.Difficulty())
+ )
+ // Update the peer's total difficulty if better than the previous
+ if _, td := peer.Head(); trueTD.Cmp(td) > 0 {
+ peer.SetHead(trueHead, trueTD)
+ h.chainSync.handlePeerEvent(peer)
+ }
+ return nil
+}
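+
+// trueHeadOfAnnouncement is an illustrative helper (not part of the handler
+// API) condensing the rule above: a NewBlock announcement proves the peer
+// owns the block's parent, at the advertised TD minus the block's own
+// difficulty.
+func trueHeadOfAnnouncement(block *types.Block, td *big.Int) (common.Hash, *big.Int) {
+ return block.ParentHash(), new(big.Int).Sub(td, block.Difficulty())
+}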
diff --git a/eth/handler_eth_test.go b/eth/handler_eth_test.go
new file mode 100644
index 000000000..0e5c0c90e
--- /dev/null
+++ b/eth/handler_eth_test.go
@@ -0,0 +1,740 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+ "fmt"
+ "math/big"
+ "math/rand"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/consensus/ethash"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/forkid"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/core/vm"
+ "github.com/ethereum/go-ethereum/eth/downloader"
+ "github.com/ethereum/go-ethereum/eth/protocols/eth"
+ "github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/params"
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+// testEthHandler is a mock event handler to listen for inbound network requests
+// on the `eth` protocol and convert them into a more easily testable form.
+type testEthHandler struct {
+ blockBroadcasts event.Feed
+ txAnnounces event.Feed
+ txBroadcasts event.Feed
+}
+
+func (h *testEthHandler) Chain() *core.BlockChain { panic("no backing chain") }
+func (h *testEthHandler) StateBloom() *trie.SyncBloom { panic("no backing state bloom") }
+func (h *testEthHandler) TxPool() eth.TxPool { panic("no backing tx pool") }
+func (h *testEthHandler) AcceptTxs() bool { return true }
+func (h *testEthHandler) RunPeer(*eth.Peer, eth.Handler) error { panic("not used in tests") }
+func (h *testEthHandler) PeerInfo(enode.ID) interface{} { panic("not used in tests") }
+
+func (h *testEthHandler) Handle(peer *eth.Peer, packet eth.Packet) error {
+ switch packet := packet.(type) {
+ case *eth.NewBlockPacket:
+ h.blockBroadcasts.Send(packet.Block)
+ return nil
+
+ case *eth.NewPooledTransactionHashesPacket:
+ h.txAnnounces.Send(([]common.Hash)(*packet))
+ return nil
+
+ case *eth.TransactionsPacket:
+ h.txBroadcasts.Send(([]*types.Transaction)(*packet))
+ return nil
+
+ case *eth.PooledTransactionsPacket:
+ h.txBroadcasts.Send(([]*types.Transaction)(*packet))
+ return nil
+
+ default:
+ panic(fmt.Sprintf("unexpected eth packet type in tests: %T", packet))
+ }
+}
+
+// Tests that peers are correctly accepted (or rejected) based on the advertised
+// fork IDs in the protocol handshake.
+func TestForkIDSplit64(t *testing.T) { testForkIDSplit(t, 64) }
+func TestForkIDSplit65(t *testing.T) { testForkIDSplit(t, 65) }
+
+func testForkIDSplit(t *testing.T, protocol uint) {
+ t.Parallel()
+
+ var (
+ engine = ethash.NewFaker()
+
+ configNoFork = &params.ChainConfig{HomesteadBlock: big.NewInt(1)}
+ configProFork = &params.ChainConfig{
+ HomesteadBlock: big.NewInt(1),
+ EIP150Block: big.NewInt(2),
+ EIP155Block: big.NewInt(2),
+ EIP158Block: big.NewInt(2),
+ ByzantiumBlock: big.NewInt(3),
+ }
+ dbNoFork = rawdb.NewMemoryDatabase()
+ dbProFork = rawdb.NewMemoryDatabase()
+
+ gspecNoFork = &core.Genesis{Config: configNoFork}
+ gspecProFork = &core.Genesis{Config: configProFork}
+
+ genesisNoFork = gspecNoFork.MustCommit(dbNoFork)
+ genesisProFork = gspecProFork.MustCommit(dbProFork)
+
+ chainNoFork, _ = core.NewBlockChain(dbNoFork, nil, configNoFork, engine, vm.Config{}, nil, nil)
+ chainProFork, _ = core.NewBlockChain(dbProFork, nil, configProFork, engine, vm.Config{}, nil, nil)
+
+ blocksNoFork, _ = core.GenerateChain(configNoFork, genesisNoFork, engine, dbNoFork, 2, nil)
+ blocksProFork, _ = core.GenerateChain(configProFork, genesisProFork, engine, dbProFork, 2, nil)
+
+ ethNoFork, _ = newHandler(&handlerConfig{
+ Database: dbNoFork,
+ Chain: chainNoFork,
+ TxPool: newTestTxPool(),
+ Network: 1,
+ Sync: downloader.FullSync,
+ BloomCache: 1,
+ })
+ ethProFork, _ = newHandler(&handlerConfig{
+ Database: dbProFork,
+ Chain: chainProFork,
+ TxPool: newTestTxPool(),
+ Network: 1,
+ Sync: downloader.FullSync,
+ BloomCache: 1,
+ })
+ )
+ ethNoFork.Start(1000)
+ ethProFork.Start(1000)
+
+ // Clean up everything after ourselves
+ defer chainNoFork.Stop()
+ defer chainProFork.Stop()
+
+ defer ethNoFork.Stop()
+ defer ethProFork.Stop()
+
+ // Both nodes should allow the other to connect (same genesis, next fork is the same)
+ p2pNoFork, p2pProFork := p2p.MsgPipe()
+ defer p2pNoFork.Close()
+ defer p2pProFork.Close()
+
+ peerNoFork := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil)
+ peerProFork := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil)
+ defer peerNoFork.Close()
+ defer peerProFork.Close()
+
+ errc := make(chan error, 2)
+ go func(errc chan error) {
+ errc <- ethNoFork.runEthPeer(peerProFork, func(peer *eth.Peer) error { return nil })
+ }(errc)
+ go func(errc chan error) {
+ errc <- ethProFork.runEthPeer(peerNoFork, func(peer *eth.Peer) error { return nil })
+ }(errc)
+
+ for i := 0; i < 2; i++ {
+ select {
+ case err := <-errc:
+ if err != nil {
+ t.Fatalf("frontier nofork <-> profork failed: %v", err)
+ }
+ case <-time.After(250 * time.Millisecond):
+ t.Fatalf("frontier nofork <-> profork handler timeout")
+ }
+ }
+ // Progress into Homestead. Forks match, so we don't care what the future holds
+ chainNoFork.InsertChain(blocksNoFork[:1])
+ chainProFork.InsertChain(blocksProFork[:1])
+
+ p2pNoFork, p2pProFork = p2p.MsgPipe()
+ defer p2pNoFork.Close()
+ defer p2pProFork.Close()
+
+ peerNoFork = eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil)
+ peerProFork = eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil)
+ defer peerNoFork.Close()
+ defer peerProFork.Close()
+
+ errc = make(chan error, 2)
+ go func(errc chan error) {
+ errc <- ethNoFork.runEthPeer(peerProFork, func(peer *eth.Peer) error { return nil })
+ }(errc)
+ go func(errc chan error) {
+ errc <- ethProFork.runEthPeer(peerNoFork, func(peer *eth.Peer) error { return nil })
+ }(errc)
+
+ for i := 0; i < 2; i++ {
+ select {
+ case err := <-errc:
+ if err != nil {
+ t.Fatalf("homestead nofork <-> profork failed: %v", err)
+ }
+ case <-time.After(250 * time.Millisecond):
+ t.Fatalf("homestead nofork <-> profork handler timeout")
+ }
+ }
+ // Progress into Spurious. Forks mismatch, signalling differing chains, reject
+ chainNoFork.InsertChain(blocksNoFork[1:2])
+ chainProFork.InsertChain(blocksProFork[1:2])
+
+ p2pNoFork, p2pProFork = p2p.MsgPipe()
+ defer p2pNoFork.Close()
+ defer p2pProFork.Close()
+
+ peerNoFork = eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil)
+ peerProFork = eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil)
+ defer peerNoFork.Close()
+ defer peerProFork.Close()
+
+ errc = make(chan error, 2)
+ go func(errc chan error) {
+ errc <- ethNoFork.runEthPeer(peerProFork, func(peer *eth.Peer) error { return nil })
+ }(errc)
+ go func(errc chan error) {
+ errc <- ethProFork.runEthPeer(peerNoFork, func(peer *eth.Peer) error { return nil })
+ }(errc)
+
+ var successes int
+ for i := 0; i < 2; i++ {
+ select {
+ case err := <-errc:
+ if err == nil {
+ successes++
+ if successes == 2 { // Only one side disconnects
+ t.Fatalf("fork ID rejection didn't happen")
+ }
+ }
+ case <-time.After(250 * time.Millisecond):
+ t.Fatalf("split peers not rejected")
+ }
+ }
+}
+
+// Tests that received transactions are added to the local pool.
+func TestRecvTransactions64(t *testing.T) { testRecvTransactions(t, 64) }
+func TestRecvTransactions65(t *testing.T) { testRecvTransactions(t, 65) }
+
+func testRecvTransactions(t *testing.T, protocol uint) {
+ t.Parallel()
+
+ // Create a message handler, configure it to accept transactions and watch them
+ handler := newTestHandler()
+ defer handler.close()
+
+ handler.handler.acceptTxs = 1 // mark synced to accept transactions
+
+ txs := make(chan core.NewTxsEvent)
+ sub := handler.txpool.SubscribeNewTxsEvent(txs)
+ defer sub.Unsubscribe()
+
+ // Create a source peer to send messages through and a sink handler to receive them
+ p2pSrc, p2pSink := p2p.MsgPipe()
+ defer p2pSrc.Close()
+ defer p2pSink.Close()
+
+ src := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pSrc, handler.txpool)
+ sink := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pSink, handler.txpool)
+ defer src.Close()
+ defer sink.Close()
+
+ go handler.handler.runEthPeer(sink, func(peer *eth.Peer) error {
+ return eth.Handle((*ethHandler)(handler.handler), peer)
+ })
+ // Run the handshake locally to avoid spinning up a source handler
+ var (
+ genesis = handler.chain.Genesis()
+ head = handler.chain.CurrentBlock()
+ td = handler.chain.GetTd(head.Hash(), head.NumberU64())
+ )
+ if err := src.Handshake(1, td, head.Hash(), genesis.Hash(), forkid.NewIDWithChain(handler.chain), forkid.NewFilter(handler.chain)); err != nil {
+ t.Fatalf("failed to run protocol handshake")
+ }
+ // Send the transaction to the sink and verify that it's added to the tx pool
+ tx := types.NewTransaction(0, common.Address{}, big.NewInt(0), 100000, big.NewInt(0), nil)
+ tx, _ = types.SignTx(tx, types.HomesteadSigner{}, testKey)
+
+ if err := src.SendTransactions([]*types.Transaction{tx}); err != nil {
+ t.Fatalf("failed to send transaction: %v", err)
+ }
+ select {
+ case event := <-txs:
+ if len(event.Txs) != 1 {
+ t.Errorf("wrong number of added transactions: got %d, want 1", len(event.Txs))
+ } else if event.Txs[0].Hash() != tx.Hash() {
+ t.Errorf("added wrong tx hash: got %v, want %v", event.Txs[0].Hash(), tx.Hash())
+ }
+ case <-time.After(2 * time.Second):
+ t.Errorf("no NewTxsEvent received within 2 seconds")
+ }
+}
+
+// This test checks that pending transactions are sent.
+func TestSendTransactions64(t *testing.T) { testSendTransactions(t, 64) }
+func TestSendTransactions65(t *testing.T) { testSendTransactions(t, 65) }
+
+func testSendTransactions(t *testing.T, protocol uint) {
+ t.Parallel()
+
+ // Create a message handler and fill the pool with big transactions
+ handler := newTestHandler()
+ defer handler.close()
+
+ insert := make([]*types.Transaction, 100)
+ for nonce := range insert {
+ tx := types.NewTransaction(uint64(nonce), common.Address{}, big.NewInt(0), 100000, big.NewInt(0), make([]byte, txsyncPackSize/10))
+ tx, _ = types.SignTx(tx, types.HomesteadSigner{}, testKey)
+
+ insert[nonce] = tx
+ }
+ go handler.txpool.AddRemotes(insert) // Need goroutine to not block on feed
+ time.Sleep(250 * time.Millisecond) // Wait until tx events get out of the system (can't use events, tx broadcaster races with peer join)
+
+ // Create a source handler to send messages through and a sink peer to receive them
+ p2pSrc, p2pSink := p2p.MsgPipe()
+ defer p2pSrc.Close()
+ defer p2pSink.Close()
+
+ src := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pSrc, handler.txpool)
+ sink := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pSink, handler.txpool)
+ defer src.Close()
+ defer sink.Close()
+
+ go handler.handler.runEthPeer(src, func(peer *eth.Peer) error {
+ return eth.Handle((*ethHandler)(handler.handler), peer)
+ })
+ // Run the handshake locally to avoid spinning up a source handler
+ var (
+ genesis = handler.chain.Genesis()
+ head = handler.chain.CurrentBlock()
+ td = handler.chain.GetTd(head.Hash(), head.NumberU64())
+ )
+ if err := sink.Handshake(1, td, head.Hash(), genesis.Hash(), forkid.NewIDWithChain(handler.chain), forkid.NewFilter(handler.chain)); err != nil {
+ t.Fatalf("failed to run protocol handshake")
+ }
+ // After the handshake completes, the source handler should stream the
+ // transactions to the sink; subscribe to all inbound network events
+ backend := new(testEthHandler)
+
+ anns := make(chan []common.Hash)
+ annSub := backend.txAnnounces.Subscribe(anns)
+ defer annSub.Unsubscribe()
+
+ bcasts := make(chan []*types.Transaction)
+ bcastSub := backend.txBroadcasts.Subscribe(bcasts)
+ defer bcastSub.Unsubscribe()
+
+ go eth.Handle(backend, sink)
+
+ // Make sure we get all the transactions on the correct channels
+ seen := make(map[common.Hash]struct{})
+ for len(seen) < len(insert) {
+ switch protocol {
+ case 63, 64:
+ select {
+ case <-anns:
+ t.Errorf("tx announce received on pre eth/65")
+ case txs := <-bcasts:
+ for _, tx := range txs {
+ if _, ok := seen[tx.Hash()]; ok {
+ t.Errorf("duplicate transaction announced: %x", tx.Hash())
+ }
+ seen[tx.Hash()] = struct{}{}
+ }
+ }
+ case 65:
+ select {
+ case hashes := <-anns:
+ for _, hash := range hashes {
+ if _, ok := seen[hash]; ok {
+ t.Errorf("duplicate transaction announced: %x", hash)
+ }
+ seen[hash] = struct{}{}
+ }
+ case <-bcasts:
+ t.Errorf("initial tx broadcast received on post eth/65")
+ }
+
+ default:
+ panic("unsupported protocol, please extend test")
+ }
+ }
+ for _, tx := range insert {
+ if _, ok := seen[tx.Hash()]; !ok {
+ t.Errorf("missing transaction: %x", tx.Hash())
+ }
+ }
+}
+
+// Tests that transactions get propagated to all attached peers, either via direct
+// broadcasts or via announcements/retrievals.
+func TestTransactionPropagation64(t *testing.T) { testTransactionPropagation(t, 64) }
+func TestTransactionPropagation65(t *testing.T) { testTransactionPropagation(t, 65) }
+
+func testTransactionPropagation(t *testing.T, protocol uint) {
+ t.Parallel()
+
+ // Create a source handler to send transactions from and a number of sinks
+ // to receive them. We need multiple sinks since a one-to-one peering would
+ // broadcast all transactions without announcement.
+ source := newTestHandler()
+ defer source.close()
+
+ sinks := make([]*testHandler, 10)
+ for i := 0; i < len(sinks); i++ {
+ sinks[i] = newTestHandler()
+ defer sinks[i].close()
+
+ sinks[i].handler.acceptTxs = 1 // mark synced to accept transactions
+ }
+ // Interconnect all the sink handlers with the source handler
+ for i, sink := range sinks {
+ sink := sink // Closure for goroutine below
+
+ sourcePipe, sinkPipe := p2p.MsgPipe()
+ defer sourcePipe.Close()
+ defer sinkPipe.Close()
+
+ sourcePeer := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{byte(i)}, "", nil), sourcePipe, source.txpool)
+ sinkPeer := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{0}, "", nil), sinkPipe, sink.txpool)
+ defer sourcePeer.Close()
+ defer sinkPeer.Close()
+
+ go source.handler.runEthPeer(sourcePeer, func(peer *eth.Peer) error {
+ return eth.Handle((*ethHandler)(source.handler), peer)
+ })
+ go sink.handler.runEthPeer(sinkPeer, func(peer *eth.Peer) error {
+ return eth.Handle((*ethHandler)(sink.handler), peer)
+ })
+ }
+ // Subscribe to all the transaction pools
+ txChs := make([]chan core.NewTxsEvent, len(sinks))
+ for i := 0; i < len(sinks); i++ {
+ txChs[i] = make(chan core.NewTxsEvent, 1024)
+
+ sub := sinks[i].txpool.SubscribeNewTxsEvent(txChs[i])
+ defer sub.Unsubscribe()
+ }
+ // Fill the source pool with transactions and wait for them at the sinks
+ txs := make([]*types.Transaction, 1024)
+ for nonce := range txs {
+ tx := types.NewTransaction(uint64(nonce), common.Address{}, big.NewInt(0), 100000, big.NewInt(0), nil)
+ tx, _ = types.SignTx(tx, types.HomesteadSigner{}, testKey)
+
+ txs[nonce] = tx
+ }
+ source.txpool.AddRemotes(txs)
+
+ // Iterate through all the sinks and ensure they all got the transactions
+ for i := range sinks {
+ for arrived := 0; arrived < len(txs); {
+ select {
+ case event := <-txChs[i]:
+ arrived += len(event.Txs)
+ case <-time.NewTimer(time.Second).C:
+ t.Errorf("sink %d: transaction propagation timed out: have %d, want %d", i, arrived, len(txs))
+ }
+ }
+ }
+}
+
+// Tests that post eth protocol handshake, clients perform a mutual checkpoint
+// challenge to validate each other's chains. Hash mismatches, or missing ones
+// during a fast sync should lead to the peer getting dropped.
+func TestCheckpointChallenge(t *testing.T) {
+ tests := []struct {
+ syncmode downloader.SyncMode
+ checkpoint bool
+ timeout bool
+ empty bool
+ match bool
+ drop bool
+ }{
+ // If checkpointing is not enabled locally, don't challenge and don't drop
+ {downloader.FullSync, false, false, false, false, false},
+ {downloader.FastSync, false, false, false, false, false},
+
+ // If checkpointing is enabled locally and remote response is empty, only drop during fast sync
+ {downloader.FullSync, true, false, true, false, false},
+ {downloader.FastSync, true, false, true, false, true}, // Special case, fast sync, unsynced peer
+
+ // If checkpointing is enabled locally and remote response mismatches, always drop
+ {downloader.FullSync, true, false, false, false, true},
+ {downloader.FastSync, true, false, false, false, true},
+
+ // If checkpointing is enabled locally and remote response matches, never drop
+ {downloader.FullSync, true, false, false, true, false},
+ {downloader.FastSync, true, false, false, true, false},
+
+ // If checkpointing is enabled locally and remote times out, always drop
+ {downloader.FullSync, true, true, false, true, true},
+ {downloader.FastSync, true, true, false, true, true},
+ }
+ for _, tt := range tests {
+ t.Run(fmt.Sprintf("sync %v checkpoint %v timeout %v empty %v match %v", tt.syncmode, tt.checkpoint, tt.timeout, tt.empty, tt.match), func(t *testing.T) {
+ testCheckpointChallenge(t, tt.syncmode, tt.checkpoint, tt.timeout, tt.empty, tt.match, tt.drop)
+ })
+ }
+}
+
+func testCheckpointChallenge(t *testing.T, syncmode downloader.SyncMode, checkpoint bool, timeout bool, empty bool, match bool, drop bool) {
+ // Reduce the checkpoint handshake challenge timeout
+ defer func(old time.Duration) { syncChallengeTimeout = old }(syncChallengeTimeout)
+ syncChallengeTimeout = 250 * time.Millisecond
+
+ // Create a test handler and inject a CHT into it. The injection is a bit
+ // ugly, but it beats creating everything manually just to avoid reaching
+ // into the internals a bit.
+ handler := newTestHandler()
+ defer handler.close()
+
+ if syncmode == downloader.FastSync {
+ atomic.StoreUint32(&handler.handler.fastSync, 1)
+ } else {
+ atomic.StoreUint32(&handler.handler.fastSync, 0)
+ }
+ var response *types.Header
+ if checkpoint {
+ number := (uint64(rand.Intn(500))+1)*params.CHTFrequency - 1
+ response = &types.Header{Number: big.NewInt(int64(number)), Extra: []byte("valid")}
+
+ handler.handler.checkpointNumber = number
+ handler.handler.checkpointHash = response.Hash()
+ }
+ // Create a challenger peer and a challenged one
+ p2pLocal, p2pRemote := p2p.MsgPipe()
+ defer p2pLocal.Close()
+ defer p2pRemote.Close()
+
+ local := eth.NewPeer(eth.ETH64, p2p.NewPeer(enode.ID{1}, "", nil), p2pLocal, handler.txpool)
+ remote := eth.NewPeer(eth.ETH64, p2p.NewPeer(enode.ID{2}, "", nil), p2pRemote, handler.txpool)
+ defer local.Close()
+ defer remote.Close()
+
+ go handler.handler.runEthPeer(local, func(peer *eth.Peer) error {
+ return eth.Handle((*ethHandler)(handler.handler), peer)
+ })
+ // Run the handshake locally to avoid spinning up a remote handler
+ var (
+ genesis = handler.chain.Genesis()
+ head = handler.chain.CurrentBlock()
+ td = handler.chain.GetTd(head.Hash(), head.NumberU64())
+ )
+ if err := remote.Handshake(1, td, head.Hash(), genesis.Hash(), forkid.NewIDWithChain(handler.chain), forkid.NewFilter(handler.chain)); err != nil {
+ t.Fatalf("failed to run protocol handshake")
+ }
+ // Connect a new peer and check that we receive the checkpoint challenge
+ if checkpoint {
+ if err := remote.ExpectRequestHeadersByNumber(response.Number.Uint64(), 1, 0, false); err != nil {
+ t.Fatalf("challenge mismatch: %v", err)
+ }
+ // Create a block to reply to the challenge if no timeout is simulated
+ if !timeout {
+ if empty {
+ if err := remote.SendBlockHeaders([]*types.Header{}); err != nil {
+ t.Fatalf("failed to answer challenge: %v", err)
+ }
+ } else if match {
+ if err := remote.SendBlockHeaders([]*types.Header{response}); err != nil {
+ t.Fatalf("failed to answer challenge: %v", err)
+ }
+ } else {
+ if err := remote.SendBlockHeaders([]*types.Header{{Number: response.Number}}); err != nil {
+ t.Fatalf("failed to answer challenge: %v", err)
+ }
+ }
+ }
+ }
+ // Wait until the test timeout passes to ensure proper cleanup
+ time.Sleep(syncChallengeTimeout + 300*time.Millisecond)
+
+ // Verify that the remote peer is maintained or dropped
+ if drop {
+ if peers := handler.handler.peers.Len(); peers != 0 {
+ t.Fatalf("peer count mismatch: have %d, want %d", peers, 0)
+ }
+ } else {
+ if peers := handler.handler.peers.Len(); peers != 1 {
+ t.Fatalf("peer count mismatch: have %d, want %d", peers, 1)
+ }
+ }
+}
+
+// Tests that blocks are only broadcast to a square root of the connected peers.
+func TestBroadcastBlock1Peer(t *testing.T) { testBroadcastBlock(t, 1, 1) }
+func TestBroadcastBlock2Peers(t *testing.T) { testBroadcastBlock(t, 2, 1) }
+func TestBroadcastBlock3Peers(t *testing.T) { testBroadcastBlock(t, 3, 1) }
+func TestBroadcastBlock4Peers(t *testing.T) { testBroadcastBlock(t, 4, 2) }
+func TestBroadcastBlock5Peers(t *testing.T) { testBroadcastBlock(t, 5, 2) }
+func TestBroadcastBlock9Peers(t *testing.T) { testBroadcastBlock(t, 9, 3) }
+func TestBroadcastBlock12Peers(t *testing.T) { testBroadcastBlock(t, 12, 3) }
+func TestBroadcastBlock16Peers(t *testing.T) { testBroadcastBlock(t, 16, 4) }
+func TestBroadcastBlock26Peers(t *testing.T) { testBroadcastBlock(t, 26, 5) }
+func TestBroadcastBlock100Peers(t *testing.T) { testBroadcastBlock(t, 100, 10) }
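+
+// The expected counts above follow the propagation rule of pushing full
+// blocks to the square root of the connected peer count, rounded down (a
+// sketch of the assumed selection inside BroadcastBlock):
+//
+// transfer := peers[:int(math.Sqrt(float64(len(peers))))] // 9 -> 3, 26 -> 5, 100 -> 10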
+
+func testBroadcastBlock(t *testing.T, peers, bcasts int) {
+ t.Parallel()
+
+ // Create a source handler to broadcast blocks from and a number of sinks
+ // to receive them.
+ source := newTestHandlerWithBlocks(1)
+ defer source.close()
+
+ sinks := make([]*testEthHandler, peers)
+ for i := 0; i < len(sinks); i++ {
+ sinks[i] = new(testEthHandler)
+ }
+ // Interconnect all the sink handlers with the source handler
+ var (
+ genesis = source.chain.Genesis()
+ td = source.chain.GetTd(genesis.Hash(), genesis.NumberU64())
+ )
+ for i, sink := range sinks {
+ sink := sink // Closure for goroutine below
+
+ sourcePipe, sinkPipe := p2p.MsgPipe()
+ defer sourcePipe.Close()
+ defer sinkPipe.Close()
+
+ sourcePeer := eth.NewPeer(eth.ETH64, p2p.NewPeer(enode.ID{byte(i)}, "", nil), sourcePipe, nil)
+ sinkPeer := eth.NewPeer(eth.ETH64, p2p.NewPeer(enode.ID{0}, "", nil), sinkPipe, nil)
+ defer sourcePeer.Close()
+ defer sinkPeer.Close()
+
+ go source.handler.runEthPeer(sourcePeer, func(peer *eth.Peer) error {
+ return eth.Handle((*ethHandler)(source.handler), peer)
+ })
+ if err := sinkPeer.Handshake(1, td, genesis.Hash(), genesis.Hash(), forkid.NewIDWithChain(source.chain), forkid.NewFilter(source.chain)); err != nil {
+ t.Fatalf("failed to run protocol handshake")
+ }
+ go eth.Handle(sink, sinkPeer)
+ }
+ // Subscribe to all the block broadcast feeds
+ blockChs := make([]chan *types.Block, len(sinks))
+ for i := 0; i < len(sinks); i++ {
+ blockChs[i] = make(chan *types.Block, 1)
+ defer close(blockChs[i])
+
+ sub := sinks[i].blockBroadcasts.Subscribe(blockChs[i])
+ defer sub.Unsubscribe()
+ }
+ // Initiate a block propagation across the peers
+ time.Sleep(100 * time.Millisecond)
+ source.handler.BroadcastBlock(source.chain.CurrentBlock(), true)
+
+ // Iterate through all the sinks and ensure the correct number got the block
+ done := make(chan struct{}, peers)
+ for _, ch := range blockChs {
+ ch := ch
+ go func() {
+ <-ch
+ done <- struct{}{}
+ }()
+ }
+ var received int
+ for {
+ select {
+ case <-done:
+ received++
+
+ case <-time.After(100 * time.Millisecond):
+ if received != bcasts {
+ t.Errorf("broadcast count mismatch: have %d, want %d", received, bcasts)
+ }
+ return
+ }
+ }
+}
+
+// Tests that a propagated malformed block (uncles or transactions don't match
+// with the hashes in the header) gets discarded and not broadcast forward.
+func TestBroadcastMalformedBlock64(t *testing.T) { testBroadcastMalformedBlock(t, 64) }
+func TestBroadcastMalformedBlock65(t *testing.T) { testBroadcastMalformedBlock(t, 65) }
+
+func testBroadcastMalformedBlock(t *testing.T, protocol uint) {
+ t.Parallel()
+
+ // Create a source handler to broadcast malformed blocks from and a sink
+ // peer to receive them.
+ source := newTestHandlerWithBlocks(1)
+ defer source.close()
+
+ // Create a message pipe between the source handler and the sink peer
+ p2pSrc, p2pSink := p2p.MsgPipe()
+ defer p2pSrc.Close()
+ defer p2pSink.Close()
+
+ src := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pSrc, source.txpool)
+ sink := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pSink, source.txpool)
+ defer src.Close()
+ defer sink.Close()
+
+ go source.handler.runEthPeer(src, func(peer *eth.Peer) error {
+ return eth.Handle((*ethHandler)(source.handler), peer)
+ })
+ // Run the handshake locally to avoid spinning up a sink handler
+ var (
+ genesis = source.chain.Genesis()
+ td = source.chain.GetTd(genesis.Hash(), genesis.NumberU64())
+ )
+ if err := sink.Handshake(1, td, genesis.Hash(), genesis.Hash(), forkid.NewIDWithChain(source.chain), forkid.NewFilter(source.chain)); err != nil {
+ t.Fatalf("failed to run protocol handshake")
+ }
+ // After the handshake completes, the source handler should stream the
+ // blocks to the sink; subscribe to inbound network events
+ backend := new(testEthHandler)
+
+ blocks := make(chan *types.Block, 1)
+ sub := backend.blockBroadcasts.Subscribe(blocks)
+ defer sub.Unsubscribe()
+
+ go eth.Handle(backend, sink)
+
+ // Create various combinations of malformed blocks
+ head := source.chain.CurrentBlock()
+
+ malformedUncles := head.Header()
+ malformedUncles.UncleHash[0]++
+ malformedTransactions := head.Header()
+ malformedTransactions.TxHash[0]++
+ malformedEverything := head.Header()
+ malformedEverything.UncleHash[0]++
+ malformedEverything.TxHash[0]++
+
+ // Try to broadcast all malformations and ensure they all get discarded
+ for _, header := range []*types.Header{malformedUncles, malformedTransactions, malformedEverything} {
+ block := types.NewBlockWithHeader(header).WithBody(head.Transactions(), head.Uncles())
+ if err := src.SendNewBlock(block, big.NewInt(131136)); err != nil {
+ t.Fatalf("failed to broadcast block: %v", err)
+ }
+ select {
+ case <-blocks:
+ t.Fatalf("malformed block forwarded")
+ case <-time.After(100 * time.Millisecond):
+ }
+ }
+}
diff --git a/eth/handler_snap.go b/eth/handler_snap.go
new file mode 100644
index 000000000..25975bf60
--- /dev/null
+++ b/eth/handler_snap.go
@@ -0,0 +1,48 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/eth/protocols/snap"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+)
+
+// snapHandler implements the snap.Backend interface to handle the various network
+// packets that are sent as replies or broadcasts.
+type snapHandler handler
+
+func (h *snapHandler) Chain() *core.BlockChain { return h.chain }
+
+// RunPeer is invoked when a peer joins on the `snap` protocol.
+func (h *snapHandler) RunPeer(peer *snap.Peer, hand snap.Handler) error {
+ return (*handler)(h).runSnapPeer(peer, hand)
+}
+
+// PeerInfo retrieves all known `snap` information about a peer.
+func (h *snapHandler) PeerInfo(id enode.ID) interface{} {
+ if p := h.peers.snapPeer(id.String()); p != nil {
+ return p.info()
+ }
+ return nil
+}
+
+// Handle is invoked from a peer's message handler when it receives a new remote
+// message that the handler couldn't consume and serve itself.
+func (h *snapHandler) Handle(peer *snap.Peer, packet snap.Packet) error {
+ return h.downloader.DeliverSnapPacket(peer, packet)
+}
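+
+// A wiring sketch for illustration (assumed names, mirroring the eth side):
+// the handler is exposed to the snap protocol through a plain type cast.
+//
+// protos := snap.MakeProtocols((*snapHandler)(h), dnsdisc)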
diff --git a/eth/handler_test.go b/eth/handler_test.go
index fc6c6f274..a90ef5c34 100644
--- a/eth/handler_test.go
+++ b/eth/handler_test.go
@@ -17,678 +17,154 @@
package eth
import (
- "fmt"
- "math"
"math/big"
- "math/rand"
- "testing"
- "time"
+ "sort"
+ "sync"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus/ethash"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb"
- "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/eth/downloader"
+ "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
- "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/params"
)
-// Tests that block headers can be retrieved from a remote chain based on user queries.
-func TestGetBlockHeaders63(t *testing.T) { testGetBlockHeaders(t, 63) }
-func TestGetBlockHeaders64(t *testing.T) { testGetBlockHeaders(t, 64) }
+var (
+ // testKey is a private key to use for funding a tester account.
+ testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
-func testGetBlockHeaders(t *testing.T, protocol int) {
- pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxHashFetch+15, nil, nil)
- peer, _ := newTestPeer("peer", protocol, pm, true)
- defer peer.close()
+ // testAddr is the Ethereum address of the tester account.
+ testAddr = crypto.PubkeyToAddress(testKey.PublicKey)
+)
- // Create a "random" unknown hash for testing
- var unknown common.Hash
- for i := range unknown {
- unknown[i] = byte(i)
- }
- // Create a batch of tests for various scenarios
- limit := uint64(downloader.MaxHeaderFetch)
- tests := []struct {
- query *getBlockHeadersData // The query to execute for header retrieval
- expect []common.Hash // The hashes of the block whose headers are expected
- }{
- // A single random block should be retrievable by hash and number too
- {
- &getBlockHeadersData{Origin: hashOrNumber{Hash: pm.blockchain.GetBlockByNumber(limit / 2).Hash()}, Amount: 1},
- []common.Hash{pm.blockchain.GetBlockByNumber(limit / 2).Hash()},
- }, {
- &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 1},
- []common.Hash{pm.blockchain.GetBlockByNumber(limit / 2).Hash()},
- },
- // Multiple headers should be retrievable in both directions
- {
- &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3},
- []common.Hash{
- pm.blockchain.GetBlockByNumber(limit / 2).Hash(),
- pm.blockchain.GetBlockByNumber(limit/2 + 1).Hash(),
- pm.blockchain.GetBlockByNumber(limit/2 + 2).Hash(),
- },
- }, {
- &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3, Reverse: true},
- []common.Hash{
- pm.blockchain.GetBlockByNumber(limit / 2).Hash(),
- pm.blockchain.GetBlockByNumber(limit/2 - 1).Hash(),
- pm.blockchain.GetBlockByNumber(limit/2 - 2).Hash(),
- },
- },
- // Multiple headers with skip lists should be retrievable
- {
- &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3},
- []common.Hash{
- pm.blockchain.GetBlockByNumber(limit / 2).Hash(),
- pm.blockchain.GetBlockByNumber(limit/2 + 4).Hash(),
- pm.blockchain.GetBlockByNumber(limit/2 + 8).Hash(),
- },
- }, {
- &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3, Reverse: true},
- []common.Hash{
- pm.blockchain.GetBlockByNumber(limit / 2).Hash(),
- pm.blockchain.GetBlockByNumber(limit/2 - 4).Hash(),
- pm.blockchain.GetBlockByNumber(limit/2 - 8).Hash(),
- },
- },
- // The chain endpoints should be retrievable
- {
- &getBlockHeadersData{Origin: hashOrNumber{Number: 0}, Amount: 1},
- []common.Hash{pm.blockchain.GetBlockByNumber(0).Hash()},
- }, {
- &getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64()}, Amount: 1},
- []common.Hash{pm.blockchain.CurrentBlock().Hash()},
- },
- // Ensure protocol limits are honored
- {
- &getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true},
- pm.blockchain.GetBlockHashesFromHash(pm.blockchain.CurrentBlock().Hash(), limit),
- },
- // Check that requesting more than available is handled gracefully
- {
- &getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3},
- []common.Hash{
- pm.blockchain.GetBlockByNumber(pm.blockchain.CurrentBlock().NumberU64() - 4).Hash(),
- pm.blockchain.GetBlockByNumber(pm.blockchain.CurrentBlock().NumberU64()).Hash(),
- },
- }, {
- &getBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 3, Amount: 3, Reverse: true},
- []common.Hash{
- pm.blockchain.GetBlockByNumber(4).Hash(),
- pm.blockchain.GetBlockByNumber(0).Hash(),
- },
- },
- // Check that requesting more than available is handled gracefully, even if mid skip
- {
- &getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64() - 4}, Skip: 2, Amount: 3},
- []common.Hash{
- pm.blockchain.GetBlockByNumber(pm.blockchain.CurrentBlock().NumberU64() - 4).Hash(),
- pm.blockchain.GetBlockByNumber(pm.blockchain.CurrentBlock().NumberU64() - 1).Hash(),
- },
- }, {
- &getBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 2, Amount: 3, Reverse: true},
- []common.Hash{
- pm.blockchain.GetBlockByNumber(4).Hash(),
- pm.blockchain.GetBlockByNumber(1).Hash(),
- },
- },
- // Check a corner case where requesting more can iterate past the endpoints
- {
- &getBlockHeadersData{Origin: hashOrNumber{Number: 2}, Amount: 5, Reverse: true},
- []common.Hash{
- pm.blockchain.GetBlockByNumber(2).Hash(),
- pm.blockchain.GetBlockByNumber(1).Hash(),
- pm.blockchain.GetBlockByNumber(0).Hash(),
- },
- },
- // Check a corner case where skipping overflow loops back into the chain start
- {
- &getBlockHeadersData{Origin: hashOrNumber{Hash: pm.blockchain.GetBlockByNumber(3).Hash()}, Amount: 2, Reverse: false, Skip: math.MaxUint64 - 1},
- []common.Hash{
- pm.blockchain.GetBlockByNumber(3).Hash(),
- },
- },
- // Check a corner case where skipping overflow loops back to the same header
- {
- &getBlockHeadersData{Origin: hashOrNumber{Hash: pm.blockchain.GetBlockByNumber(1).Hash()}, Amount: 2, Reverse: false, Skip: math.MaxUint64},
- []common.Hash{
- pm.blockchain.GetBlockByNumber(1).Hash(),
- },
- },
- // Check that non existing headers aren't returned
- {
- &getBlockHeadersData{Origin: hashOrNumber{Hash: unknown}, Amount: 1},
- []common.Hash{},
- }, {
- &getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64() + 1}, Amount: 1},
- []common.Hash{},
- },
- }
- // Run each of the tests and verify the results against the chain
- for i, tt := range tests {
- // Collect the headers to expect in the response
- headers := []*types.Header{}
- for _, hash := range tt.expect {
- headers = append(headers, pm.blockchain.GetBlockByHash(hash).Header())
- }
- // Send the hash request and verify the response
- p2p.Send(peer.app, 0x03, tt.query)
- if err := p2p.ExpectMsg(peer.app, 0x04, headers); err != nil {
- t.Errorf("test %d: headers mismatch: %v", i, err)
- }
- // If the test used number origins, repeat with hashes as the too
- if tt.query.Origin.Hash == (common.Hash{}) {
- if origin := pm.blockchain.GetBlockByNumber(tt.query.Origin.Number); origin != nil {
- tt.query.Origin.Hash, tt.query.Origin.Number = origin.Hash(), 0
+// testTxPool is a mock transaction pool that blindly accepts all transactions.
+// Its goal is to get around setting up a valid statedb for the balance and nonce
+// checks.
+type testTxPool struct {
+ pool map[common.Hash]*types.Transaction // Hash map of collected transactions
- p2p.Send(peer.app, 0x03, tt.query)
- if err := p2p.ExpectMsg(peer.app, 0x04, headers); err != nil {
- t.Errorf("test %d: headers mismatch: %v", i, err)
- }
- }
- }
+ txFeed event.Feed // Notification feed to allow waiting for inclusion
+ lock sync.RWMutex // Protects the transaction pool
+}
+
+// newTestTxPool creates a mock transaction pool.
+func newTestTxPool() *testTxPool {
+ return &testTxPool{
+ pool: make(map[common.Hash]*types.Transaction),
}
}
-// Tests that block contents can be retrieved from a remote chain based on their hashes.
-func TestGetBlockBodies63(t *testing.T) { testGetBlockBodies(t, 63) }
-func TestGetBlockBodies64(t *testing.T) { testGetBlockBodies(t, 64) }
+// Has returns whether the pool already contains a transaction cached with
+// the given hash.
+func (p *testTxPool) Has(hash common.Hash) bool {
+ p.lock.Lock()
+ defer p.lock.Unlock()
-func testGetBlockBodies(t *testing.T, protocol int) {
- pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxBlockFetch+15, nil, nil)
- peer, _ := newTestPeer("peer", protocol, pm, true)
- defer peer.close()
+ return p.pool[hash] != nil
+}
- // Create a batch of tests for various scenarios
- limit := downloader.MaxBlockFetch
- tests := []struct {
- random int // Number of blocks to fetch randomly from the chain
- explicit []common.Hash // Explicitly requested blocks
- available []bool // Availability of explicitly requested blocks
- expected int // Total number of existing blocks to expect
- }{
- {1, nil, nil, 1}, // A single random block should be retrievable
- {10, nil, nil, 10}, // Multiple random blocks should be retrievable
- {limit, nil, nil, limit}, // The maximum possible blocks should be retrievable
- {limit + 1, nil, nil, limit}, // No more than the possible block count should be returned
- {0, []common.Hash{pm.blockchain.Genesis().Hash()}, []bool{true}, 1}, // The genesis block should be retrievable
- {0, []common.Hash{pm.blockchain.CurrentBlock().Hash()}, []bool{true}, 1}, // The chains head block should be retrievable
- {0, []common.Hash{{}}, []bool{false}, 0}, // A non existent block should not be returned
+// Get retrieves the transaction from the local pool with the given hash,
+// returning nil if it is not known.
+func (p *testTxPool) Get(hash common.Hash) *types.Transaction {
+ p.lock.Lock()
+ defer p.lock.Unlock()
- // Existing and non-existing blocks interleaved should not cause problems
- {0, []common.Hash{
- {},
- pm.blockchain.GetBlockByNumber(1).Hash(),
- {},
- pm.blockchain.GetBlockByNumber(10).Hash(),
- {},
- pm.blockchain.GetBlockByNumber(100).Hash(),
- {},
- }, []bool{false, true, false, true, false, true, false}, 3},
+ return p.pool[hash]
+}
+
+// AddRemotes appends a batch of transactions to the pool and notifies all
+// subscribed listeners through the transaction feed.
+func (p *testTxPool) AddRemotes(txs []*types.Transaction) []error {
+ p.lock.Lock()
+ defer p.lock.Unlock()
+
+ for _, tx := range txs {
+ p.pool[tx.Hash()] = tx
}
- // Run each of the tests and verify the results against the chain
- for i, tt := range tests {
- // Collect the hashes to request, and the response to expect
- hashes, seen := []common.Hash{}, make(map[int64]bool)
- bodies := []*blockBody{}
+ p.txFeed.Send(core.NewTxsEvent{Txs: txs})
+ return make([]error, len(txs))
+}
- for j := 0; j < tt.random; j++ {
- for {
- num := rand.Int63n(int64(pm.blockchain.CurrentBlock().NumberU64()))
- if !seen[num] {
- seen[num] = true
+// Pending returns all the transactions known to the pool
+func (p *testTxPool) Pending() (map[common.Address]types.Transactions, error) {
+ p.lock.RLock()
+ defer p.lock.RUnlock()
- block := pm.blockchain.GetBlockByNumber(uint64(num))
- hashes = append(hashes, block.Hash())
- if len(bodies) < tt.expected {
- bodies = append(bodies, &blockBody{Transactions: block.Transactions(), Uncles: block.Uncles()})
- }
- break
- }
- }
- }
- for j, hash := range tt.explicit {
- hashes = append(hashes, hash)
- if tt.available[j] && len(bodies) < tt.expected {
- block := pm.blockchain.GetBlockByHash(hash)
- bodies = append(bodies, &blockBody{Transactions: block.Transactions(), Uncles: block.Uncles()})
- }
- }
- // Send the hash request and verify the response
- p2p.Send(peer.app, 0x05, hashes)
- if err := p2p.ExpectMsg(peer.app, 0x06, bodies); err != nil {
- t.Errorf("test %d: bodies mismatch: %v", i, err)
- }
+ batches := make(map[common.Address]types.Transactions)
+ for _, tx := range p.pool {
+ from, _ := types.Sender(types.HomesteadSigner{}, tx)
+ batches[from] = append(batches[from], tx)
+ }
+ for _, batch := range batches {
+ sort.Sort(types.TxByNonce(batch))
+ }
+ return batches, nil
+}
+
+// SubscribeNewTxsEvent registers a subscription of NewTxsEvent and starts
+// sending events to the given channel.
+func (p *testTxPool) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription {
+ return p.txFeed.Subscribe(ch)
+}
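+
+// A minimal usage sketch (illustrative): subscribe before injecting so no
+// event is missed, then wait on the channel.
+//
+// ch := make(chan core.NewTxsEvent, 16)
+// sub := pool.SubscribeNewTxsEvent(ch)
+// defer sub.Unsubscribe()
+// pool.AddRemotes(txs) // feeds ch through txFeed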
+
+// testHandler is a live implementation of the Ethereum protocol handler, just
+// preinitialized with some sane testing defaults and the transaction pool mocked
+// out.
+type testHandler struct {
+ db ethdb.Database
+ chain *core.BlockChain
+ txpool *testTxPool
+ handler *handler
+}
+
+// newTestHandler creates a new handler for testing purposes with no blocks.
+func newTestHandler() *testHandler {
+ return newTestHandlerWithBlocks(0)
+}
+
+// newTestHandlerWithBlocks creates a new handler for testing purposes, with a
+// given number of initial blocks.
+func newTestHandlerWithBlocks(blocks int) *testHandler {
+ // Create a database pre-initialized with a genesis block
+ db := rawdb.NewMemoryDatabase()
+ (&core.Genesis{
+ Config: params.TestChainConfig,
+ Alloc: core.GenesisAlloc{testAddr: {Balance: big.NewInt(1000000)}},
+ }).MustCommit(db)
+
+ chain, _ := core.NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{}, nil, nil)
+
+ bs, _ := core.GenerateChain(params.TestChainConfig, chain.Genesis(), ethash.NewFaker(), db, blocks, nil)
+ if _, err := chain.InsertChain(bs); err != nil {
+ panic(err)
+ }
+ txpool := newTestTxPool()
+
+ handler, _ := newHandler(&handlerConfig{
+ Database: db,
+ Chain: chain,
+ TxPool: txpool,
+ Network: 1,
+ Sync: downloader.FastSync,
+ BloomCache: 1,
+ })
+ handler.Start(1000)
+
+ return &testHandler{
+ db: db,
+ chain: chain,
+ txpool: txpool,
+ handler: handler,
}
}
-// Tests that the node state database can be retrieved based on hashes.
-func TestGetNodeData63(t *testing.T) { testGetNodeData(t, 63) }
-func TestGetNodeData64(t *testing.T) { testGetNodeData(t, 64) }
-
-func testGetNodeData(t *testing.T, protocol int) {
- // Define three accounts to simulate transactions with
- acc1Key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a")
- acc2Key, _ := crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee")
- acc1Addr := crypto.PubkeyToAddress(acc1Key.PublicKey)
- acc2Addr := crypto.PubkeyToAddress(acc2Key.PublicKey)
-
- signer := types.HomesteadSigner{}
- // Create a chain generator with some simple transactions (blatantly stolen from @fjl/chain_markets_test)
- generator := func(i int, block *core.BlockGen) {
- switch i {
- case 0:
- // In block 1, the test bank sends account #1 some ether.
- tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBank), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil), signer, testBankKey)
- block.AddTx(tx)
- case 1:
- // In block 2, the test bank sends some more ether to account #1.
- // acc1Addr passes it on to account #2.
- tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBank), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, testBankKey)
- tx2, _ := types.SignTx(types.NewTransaction(block.TxNonce(acc1Addr), acc2Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, acc1Key)
- block.AddTx(tx1)
- block.AddTx(tx2)
- case 2:
- // Block 3 is empty but was mined by account #2.
- block.SetCoinbase(acc2Addr)
- block.SetExtra([]byte("yeehaw"))
- case 3:
- // Block 4 includes blocks 2 and 3 as uncle headers (with modified extra data).
- b2 := block.PrevBlock(1).Header()
- b2.Extra = []byte("foo")
- block.AddUncle(b2)
- b3 := block.PrevBlock(2).Header()
- b3.Extra = []byte("foo")
- block.AddUncle(b3)
- }
- }
- // Assemble the test environment
- pm, db := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil)
- peer, _ := newTestPeer("peer", protocol, pm, true)
- defer peer.close()
-
- // Fetch for now the entire chain db
- hashes := []common.Hash{}
-
- it := db.NewIterator(nil, nil)
- for it.Next() {
- if key := it.Key(); len(key) == common.HashLength {
- hashes = append(hashes, common.BytesToHash(key))
- }
- }
- it.Release()
-
- p2p.Send(peer.app, 0x0d, hashes)
- msg, err := peer.app.ReadMsg()
- if err != nil {
- t.Fatalf("failed to read node data response: %v", err)
- }
- if msg.Code != 0x0e {
- t.Fatalf("response packet code mismatch: have %x, want %x", msg.Code, 0x0c)
- }
- var data [][]byte
- if err := msg.Decode(&data); err != nil {
- t.Fatalf("failed to decode response node data: %v", err)
- }
- // Verify that all hashes correspond to the requested data, and reconstruct a state tree
- for i, want := range hashes {
- if hash := crypto.Keccak256Hash(data[i]); hash != want {
- t.Errorf("data hash mismatch: have %x, want %x", hash, want)
- }
- }
- statedb := rawdb.NewMemoryDatabase()
- for i := 0; i < len(data); i++ {
- statedb.Put(hashes[i].Bytes(), data[i])
- }
- accounts := []common.Address{testBank, acc1Addr, acc2Addr}
- for i := uint64(0); i <= pm.blockchain.CurrentBlock().NumberU64(); i++ {
- trie, _ := state.New(pm.blockchain.GetBlockByNumber(i).Root(), state.NewDatabase(statedb), nil)
-
- for j, acc := range accounts {
- state, _ := pm.blockchain.State()
- bw := state.GetBalance(acc)
- bh := trie.GetBalance(acc)
-
- if (bw != nil && bh == nil) || (bw == nil && bh != nil) {
- t.Errorf("test %d, account %d: balance mismatch: have %v, want %v", i, j, bh, bw)
- }
- if bw != nil && bh != nil && bw.Cmp(bw) != 0 {
- t.Errorf("test %d, account %d: balance mismatch: have %v, want %v", i, j, bh, bw)
- }
- }
- }
-}
-
-// Tests that the transaction receipts can be retrieved based on hashes.
-func TestGetReceipt63(t *testing.T) { testGetReceipt(t, 63) }
-func TestGetReceipt64(t *testing.T) { testGetReceipt(t, 64) }
-
-func testGetReceipt(t *testing.T, protocol int) {
- // Define three accounts to simulate transactions with
- acc1Key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a")
- acc2Key, _ := crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee")
- acc1Addr := crypto.PubkeyToAddress(acc1Key.PublicKey)
- acc2Addr := crypto.PubkeyToAddress(acc2Key.PublicKey)
-
- signer := types.HomesteadSigner{}
- // Create a chain generator with some simple transactions (blatantly stolen from @fjl/chain_markets_test)
- generator := func(i int, block *core.BlockGen) {
- switch i {
- case 0:
- // In block 1, the test bank sends account #1 some ether.
- tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBank), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil), signer, testBankKey)
- block.AddTx(tx)
- case 1:
- // In block 2, the test bank sends some more ether to account #1.
- // acc1Addr passes it on to account #2.
- tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBank), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, testBankKey)
- tx2, _ := types.SignTx(types.NewTransaction(block.TxNonce(acc1Addr), acc2Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, acc1Key)
- block.AddTx(tx1)
- block.AddTx(tx2)
- case 2:
- // Block 3 is empty but was mined by account #2.
- block.SetCoinbase(acc2Addr)
- block.SetExtra([]byte("yeehaw"))
- case 3:
- // Block 4 includes blocks 2 and 3 as uncle headers (with modified extra data).
- b2 := block.PrevBlock(1).Header()
- b2.Extra = []byte("foo")
- block.AddUncle(b2)
- b3 := block.PrevBlock(2).Header()
- b3.Extra = []byte("foo")
- block.AddUncle(b3)
- }
- }
- // Assemble the test environment
- pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil)
- peer, _ := newTestPeer("peer", protocol, pm, true)
- defer peer.close()
-
- // Collect the hashes to request, and the response to expect
- hashes, receipts := []common.Hash{}, []types.Receipts{}
- for i := uint64(0); i <= pm.blockchain.CurrentBlock().NumberU64(); i++ {
- block := pm.blockchain.GetBlockByNumber(i)
-
- hashes = append(hashes, block.Hash())
- receipts = append(receipts, pm.blockchain.GetReceiptsByHash(block.Hash()))
- }
- // Send the hash request and verify the response
- p2p.Send(peer.app, 0x0f, hashes)
- if err := p2p.ExpectMsg(peer.app, 0x10, receipts); err != nil {
- t.Errorf("receipts mismatch: %v", err)
- }
-}
-
-// Tests that post eth protocol handshake, clients perform a mutual checkpoint
-// challenge to validate each other's chains. Hash mismatches, or missing ones
-// during a fast sync should lead to the peer getting dropped.
-func TestCheckpointChallenge(t *testing.T) {
- tests := []struct {
- syncmode downloader.SyncMode
- checkpoint bool
- timeout bool
- empty bool
- match bool
- drop bool
- }{
- // If checkpointing is not enabled locally, don't challenge and don't drop
- {downloader.FullSync, false, false, false, false, false},
- {downloader.FastSync, false, false, false, false, false},
-
- // If checkpointing is enabled locally and remote response is empty, only drop during fast sync
- {downloader.FullSync, true, false, true, false, false},
- {downloader.FastSync, true, false, true, false, true}, // Special case, fast sync, unsynced peer
-
- // If checkpointing is enabled locally and remote response mismatches, always drop
- {downloader.FullSync, true, false, false, false, true},
- {downloader.FastSync, true, false, false, false, true},
-
- // If checkpointing is enabled locally and remote response matches, never drop
- {downloader.FullSync, true, false, false, true, false},
- {downloader.FastSync, true, false, false, true, false},
-
- // If checkpointing is enabled locally and remote times out, always drop
- {downloader.FullSync, true, true, false, true, true},
- {downloader.FastSync, true, true, false, true, true},
- }
- for _, tt := range tests {
- t.Run(fmt.Sprintf("sync %v checkpoint %v timeout %v empty %v match %v", tt.syncmode, tt.checkpoint, tt.timeout, tt.empty, tt.match), func(t *testing.T) {
- testCheckpointChallenge(t, tt.syncmode, tt.checkpoint, tt.timeout, tt.empty, tt.match, tt.drop)
- })
- }
-}
-
-func testCheckpointChallenge(t *testing.T, syncmode downloader.SyncMode, checkpoint bool, timeout bool, empty bool, match bool, drop bool) {
- // Reduce the checkpoint handshake challenge timeout
- defer func(old time.Duration) { syncChallengeTimeout = old }(syncChallengeTimeout)
- syncChallengeTimeout = 250 * time.Millisecond
-
- // Initialize a chain and generate a fake CHT if checkpointing is enabled
- var (
- db = rawdb.NewMemoryDatabase()
- config = new(params.ChainConfig)
- )
- (&core.Genesis{Config: config}).MustCommit(db) // Commit genesis block
- // If checkpointing is enabled, create and inject a fake CHT and the corresponding
- // chllenge response.
- var response *types.Header
- var cht *params.TrustedCheckpoint
- if checkpoint {
- index := uint64(rand.Intn(500))
- number := (index+1)*params.CHTFrequency - 1
- response = &types.Header{Number: big.NewInt(int64(number)), Extra: []byte("valid")}
-
- cht = &params.TrustedCheckpoint{
- SectionIndex: index,
- SectionHead: response.Hash(),
- }
- }
- // Create a checkpoint aware protocol manager
- blockchain, err := core.NewBlockChain(db, nil, config, ethash.NewFaker(), vm.Config{}, nil, nil)
- if err != nil {
- t.Fatalf("failed to create new blockchain: %v", err)
- }
- pm, err := NewProtocolManager(config, cht, syncmode, DefaultConfig.NetworkId, new(event.TypeMux), &testTxPool{pool: make(map[common.Hash]*types.Transaction)}, ethash.NewFaker(), blockchain, db, 1, nil)
- if err != nil {
- t.Fatalf("failed to start test protocol manager: %v", err)
- }
- pm.Start(1000)
- defer pm.Stop()
-
- // Connect a new peer and check that we receive the checkpoint challenge
- peer, _ := newTestPeer("peer", eth63, pm, true)
- defer peer.close()
-
- if checkpoint {
- challenge := &getBlockHeadersData{
- Origin: hashOrNumber{Number: response.Number.Uint64()},
- Amount: 1,
- Skip: 0,
- Reverse: false,
- }
- if err := p2p.ExpectMsg(peer.app, GetBlockHeadersMsg, challenge); err != nil {
- t.Fatalf("challenge mismatch: %v", err)
- }
- // Create a block to reply to the challenge if no timeout is simulated
- if !timeout {
- if empty {
- if err := p2p.Send(peer.app, BlockHeadersMsg, []*types.Header{}); err != nil {
- t.Fatalf("failed to answer challenge: %v", err)
- }
- } else if match {
- if err := p2p.Send(peer.app, BlockHeadersMsg, []*types.Header{response}); err != nil {
- t.Fatalf("failed to answer challenge: %v", err)
- }
- } else {
- if err := p2p.Send(peer.app, BlockHeadersMsg, []*types.Header{{Number: response.Number}}); err != nil {
- t.Fatalf("failed to answer challenge: %v", err)
- }
- }
- }
- }
- // Wait until the test timeout passes to ensure proper cleanup
- time.Sleep(syncChallengeTimeout + 300*time.Millisecond)
-
- // Verify that the remote peer is maintained or dropped
- if drop {
- if peers := pm.peers.Len(); peers != 0 {
- t.Fatalf("peer count mismatch: have %d, want %d", peers, 0)
- }
- } else {
- if peers := pm.peers.Len(); peers != 1 {
- t.Fatalf("peer count mismatch: have %d, want %d", peers, 1)
- }
- }
-}
-
-func TestBroadcastBlock(t *testing.T) {
- var tests = []struct {
- totalPeers int
- broadcastExpected int
- }{
- {1, 1},
- {2, 1},
- {3, 1},
- {4, 2},
- {5, 2},
- {9, 3},
- {12, 3},
- {16, 4},
- {26, 5},
- {100, 10},
- }
- for _, test := range tests {
- testBroadcastBlock(t, test.totalPeers, test.broadcastExpected)
- }
-}
-
-func testBroadcastBlock(t *testing.T, totalPeers, broadcastExpected int) {
- var (
- evmux = new(event.TypeMux)
- pow = ethash.NewFaker()
- db = rawdb.NewMemoryDatabase()
- config = &params.ChainConfig{}
- gspec = &core.Genesis{Config: config}
- genesis = gspec.MustCommit(db)
- )
- blockchain, err := core.NewBlockChain(db, nil, config, pow, vm.Config{}, nil, nil)
- if err != nil {
- t.Fatalf("failed to create new blockchain: %v", err)
- }
- pm, err := NewProtocolManager(config, nil, downloader.FullSync, DefaultConfig.NetworkId, evmux, &testTxPool{pool: make(map[common.Hash]*types.Transaction)}, pow, blockchain, db, 1, nil)
- if err != nil {
- t.Fatalf("failed to start test protocol manager: %v", err)
- }
- pm.Start(1000)
- defer pm.Stop()
- var peers []*testPeer
- for i := 0; i < totalPeers; i++ {
- peer, _ := newTestPeer(fmt.Sprintf("peer %d", i), eth63, pm, true)
- defer peer.close()
-
- peers = append(peers, peer)
- }
- chain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 1, func(i int, gen *core.BlockGen) {})
- pm.BroadcastBlock(chain[0], true /*propagate*/)
-
- errCh := make(chan error, totalPeers)
- doneCh := make(chan struct{}, totalPeers)
- for _, peer := range peers {
- go func(p *testPeer) {
- if err := p2p.ExpectMsg(p.app, NewBlockMsg, &newBlockData{Block: chain[0], TD: big.NewInt(131136)}); err != nil {
- errCh <- err
- } else {
- doneCh <- struct{}{}
- }
- }(peer)
- }
- var received int
- for {
- select {
- case <-doneCh:
- received++
- if received > broadcastExpected {
- // We can bail early here
- t.Errorf("broadcast count mismatch: have %d > want %d", received, broadcastExpected)
- return
- }
- case <-time.After(2 * time.Second):
- if received != broadcastExpected {
- t.Errorf("broadcast count mismatch: have %d, want %d", received, broadcastExpected)
- }
- return
- case err = <-errCh:
- t.Fatalf("broadcast failed: %v", err)
- }
- }
-
-}
-
-// Tests that a propagated malformed block (uncles or transactions don't match
-// with the hashes in the header) gets discarded and not broadcast forward.
-func TestBroadcastMalformedBlock(t *testing.T) {
- // Create a live node to test propagation with
- var (
- engine = ethash.NewFaker()
- db = rawdb.NewMemoryDatabase()
- config = &params.ChainConfig{}
- gspec = &core.Genesis{Config: config}
- genesis = gspec.MustCommit(db)
- )
- blockchain, err := core.NewBlockChain(db, nil, config, engine, vm.Config{}, nil, nil)
- if err != nil {
- t.Fatalf("failed to create new blockchain: %v", err)
- }
- pm, err := NewProtocolManager(config, nil, downloader.FullSync, DefaultConfig.NetworkId, new(event.TypeMux), new(testTxPool), engine, blockchain, db, 1, nil)
- if err != nil {
- t.Fatalf("failed to start test protocol manager: %v", err)
- }
- pm.Start(2)
- defer pm.Stop()
-
- // Create two peers, one to send the malformed block with and one to check
- // propagation
- source, _ := newTestPeer("source", eth63, pm, true)
- defer source.close()
-
- sink, _ := newTestPeer("sink", eth63, pm, true)
- defer sink.close()
-
- // Create various combinations of malformed blocks
- chain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 1, func(i int, gen *core.BlockGen) {})
-
- malformedUncles := chain[0].Header()
- malformedUncles.UncleHash[0]++
- malformedTransactions := chain[0].Header()
- malformedTransactions.TxHash[0]++
- malformedEverything := chain[0].Header()
- malformedEverything.UncleHash[0]++
- malformedEverything.TxHash[0]++
-
- // Keep listening to broadcasts and notify if any arrives
- notify := make(chan struct{}, 1)
- go func() {
- if _, err := sink.app.ReadMsg(); err == nil {
- notify <- struct{}{}
- }
- }()
- // Try to broadcast all malformations and ensure they all get discarded
- for _, header := range []*types.Header{malformedUncles, malformedTransactions, malformedEverything} {
- block := types.NewBlockWithHeader(header).WithBody(chain[0].Transactions(), chain[0].Uncles())
- if err := p2p.Send(source.app, NewBlockMsg, []interface{}{block, big.NewInt(131136)}); err != nil {
- t.Fatalf("failed to broadcast block: %v", err)
- }
- select {
- case <-notify:
- t.Fatalf("malformed block forwarded")
- case <-time.After(100 * time.Millisecond):
- }
- }
+// close tears down the handler and all its internal constructs.
+func (b *testHandler) close() {
+ b.handler.Stop()
+ b.chain.Stop()
}
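A minimal usage sketch for the new helper (illustrative only — it assumes the
`newTestHandler` constructor that accompanies `testHandler` elsewhere in this
change, and lives alongside it in the eth package's tests; it is not part of
the diff itself):

	func TestHandlerLifecycle(t *testing.T) {
		handler := newTestHandler() // assumed constructor from this change
		defer handler.close()       // tears down the handler and the backing chain

		if handler.chain.CurrentBlock() == nil {
			t.Fatal("expected the test handler to expose a backing chain")
		}
	}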
diff --git a/eth/helper_test.go b/eth/helper_test.go
deleted file mode 100644
index c0bda181e..000000000
--- a/eth/helper_test.go
+++ /dev/null
@@ -1,231 +0,0 @@
-// Copyright 2015 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-// This file contains some shares testing functionality, common to multiple
-// different files and modules being tested.
-
-package eth
-
-import (
- "crypto/ecdsa"
- "crypto/rand"
- "fmt"
- "math/big"
- "sort"
- "sync"
- "testing"
-
- "github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/consensus/ethash"
- "github.com/ethereum/go-ethereum/core"
- "github.com/ethereum/go-ethereum/core/forkid"
- "github.com/ethereum/go-ethereum/core/rawdb"
- "github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/core/vm"
- "github.com/ethereum/go-ethereum/crypto"
- "github.com/ethereum/go-ethereum/eth/downloader"
- "github.com/ethereum/go-ethereum/ethdb"
- "github.com/ethereum/go-ethereum/event"
- "github.com/ethereum/go-ethereum/p2p"
- "github.com/ethereum/go-ethereum/p2p/enode"
- "github.com/ethereum/go-ethereum/params"
-)
-
-var (
- testBankKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
- testBank = crypto.PubkeyToAddress(testBankKey.PublicKey)
-)
-
-// newTestProtocolManager creates a new protocol manager for testing purposes,
-// with the given number of blocks already known, and potential notification
-// channels for different events.
-func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, ethdb.Database, error) {
- var (
- evmux = new(event.TypeMux)
- engine = ethash.NewFaker()
- db = rawdb.NewMemoryDatabase()
- gspec = &core.Genesis{
- Config: params.TestChainConfig,
- Alloc: core.GenesisAlloc{testBank: {Balance: big.NewInt(1000000)}},
- }
- genesis = gspec.MustCommit(db)
- blockchain, _ = core.NewBlockChain(db, nil, gspec.Config, engine, vm.Config{}, nil, nil)
- )
- chain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, blocks, generator)
- if _, err := blockchain.InsertChain(chain); err != nil {
- panic(err)
- }
- pm, err := NewProtocolManager(gspec.Config, nil, mode, DefaultConfig.NetworkId, evmux, &testTxPool{added: newtx, pool: make(map[common.Hash]*types.Transaction)}, engine, blockchain, db, 1, nil)
- if err != nil {
- return nil, nil, err
- }
- pm.Start(1000)
- return pm, db, nil
-}
-
-// newTestProtocolManagerMust creates a new protocol manager for testing purposes,
-// with the given number of blocks already known, and potential notification
-// channels for different events. In case of an error, the constructor force-
-// fails the test.
-func newTestProtocolManagerMust(t *testing.T, mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, ethdb.Database) {
- pm, db, err := newTestProtocolManager(mode, blocks, generator, newtx)
- if err != nil {
- t.Fatalf("Failed to create protocol manager: %v", err)
- }
- return pm, db
-}
-
-// testTxPool is a fake, helper transaction pool for testing purposes
-type testTxPool struct {
- txFeed event.Feed
- pool map[common.Hash]*types.Transaction // Hash map of collected transactions
- added chan<- []*types.Transaction // Notification channel for new transactions
-
- lock sync.RWMutex // Protects the transaction pool
-}
-
-// Has returns an indicator whether txpool has a transaction
-// cached with the given hash.
-func (p *testTxPool) Has(hash common.Hash) bool {
- p.lock.Lock()
- defer p.lock.Unlock()
-
- return p.pool[hash] != nil
-}
-
-// Get retrieves the transaction from local txpool with given
-// tx hash.
-func (p *testTxPool) Get(hash common.Hash) *types.Transaction {
- p.lock.Lock()
- defer p.lock.Unlock()
-
- return p.pool[hash]
-}
-
-// AddRemotes appends a batch of transactions to the pool, and notifies any
-// listeners if the addition channel is non nil
-func (p *testTxPool) AddRemotes(txs []*types.Transaction) []error {
- p.lock.Lock()
- defer p.lock.Unlock()
-
- for _, tx := range txs {
- p.pool[tx.Hash()] = tx
- }
- if p.added != nil {
- p.added <- txs
- }
- p.txFeed.Send(core.NewTxsEvent{Txs: txs})
- return make([]error, len(txs))
-}
-
-// Pending returns all the transactions known to the pool
-func (p *testTxPool) Pending() (map[common.Address]types.Transactions, error) {
- p.lock.RLock()
- defer p.lock.RUnlock()
-
- batches := make(map[common.Address]types.Transactions)
- for _, tx := range p.pool {
- from, _ := types.Sender(types.HomesteadSigner{}, tx)
- batches[from] = append(batches[from], tx)
- }
- for _, batch := range batches {
- sort.Sort(types.TxByNonce(batch))
- }
- return batches, nil
-}
-
-func (p *testTxPool) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription {
- return p.txFeed.Subscribe(ch)
-}
-
-// newTestTransaction create a new dummy transaction.
-func newTestTransaction(from *ecdsa.PrivateKey, nonce uint64, datasize int) *types.Transaction {
- tx := types.NewTransaction(nonce, common.Address{}, big.NewInt(0), 100000, big.NewInt(0), make([]byte, datasize))
- tx, _ = types.SignTx(tx, types.HomesteadSigner{}, from)
- return tx
-}
-
-// testPeer is a simulated peer to allow testing direct network calls.
-type testPeer struct {
- net p2p.MsgReadWriter // Network layer reader/writer to simulate remote messaging
- app *p2p.MsgPipeRW // Application layer reader/writer to simulate the local side
- *peer
-}
-
-// newTestPeer creates a new peer registered at the given protocol manager.
-func newTestPeer(name string, version int, pm *ProtocolManager, shake bool) (*testPeer, <-chan error) {
- // Create a message pipe to communicate through
- app, net := p2p.MsgPipe()
-
- // Start the peer on a new thread
- var id enode.ID
- rand.Read(id[:])
- peer := pm.newPeer(version, p2p.NewPeer(id, name, nil), net, pm.txpool.Get)
- errc := make(chan error, 1)
- go func() { errc <- pm.runPeer(peer) }()
- tp := &testPeer{app: app, net: net, peer: peer}
-
- // Execute any implicitly requested handshakes and return
- if shake {
- var (
- genesis = pm.blockchain.Genesis()
- head = pm.blockchain.CurrentHeader()
- td = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64())
- )
- forkID := forkid.NewID(pm.blockchain.Config(), pm.blockchain.Genesis().Hash(), pm.blockchain.CurrentHeader().Number.Uint64())
- tp.handshake(nil, td, head.Hash(), genesis.Hash(), forkID, forkid.NewFilter(pm.blockchain))
- }
- return tp, errc
-}
-
-// handshake simulates a trivial handshake that expects the same state from the
-// remote side as we are simulating locally.
-func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter) {
- var msg interface{}
- switch {
- case p.version == eth63:
- msg = &statusData63{
- ProtocolVersion: uint32(p.version),
- NetworkId: DefaultConfig.NetworkId,
- TD: td,
- CurrentBlock: head,
- GenesisBlock: genesis,
- }
- case p.version >= eth64:
- msg = &statusData{
- ProtocolVersion: uint32(p.version),
- NetworkID: DefaultConfig.NetworkId,
- TD: td,
- Head: head,
- Genesis: genesis,
- ForkID: forkID,
- }
- default:
- panic(fmt.Sprintf("unsupported eth protocol version: %d", p.version))
- }
- if err := p2p.ExpectMsg(p.app, StatusMsg, msg); err != nil {
- t.Fatalf("status recv: %v", err)
- }
- if err := p2p.Send(p.app, StatusMsg, msg); err != nil {
- t.Fatalf("status send: %v", err)
- }
-}
-
-// close terminates the local side of the peer, notifying the remote protocol
-// manager of termination.
-func (p *testPeer) close() {
- p.app.Close()
-}
diff --git a/eth/peer.go b/eth/peer.go
index 21b82a19c..6970c8afd 100644
--- a/eth/peer.go
+++ b/eth/peer.go
@@ -17,806 +17,58 @@
package eth
import (
- "errors"
- "fmt"
"math/big"
"sync"
"time"
- mapset "github.com/deckarep/golang-set"
- "github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/core/forkid"
- "github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/p2p"
- "github.com/ethereum/go-ethereum/rlp"
+ "github.com/ethereum/go-ethereum/eth/protocols/eth"
+ "github.com/ethereum/go-ethereum/eth/protocols/snap"
)
-var (
- errClosed = errors.New("peer set is closed")
- errAlreadyRegistered = errors.New("peer is already registered")
- errNotRegistered = errors.New("peer is not registered")
-)
-
-const (
- maxKnownTxs = 32768 // Maximum transactions hashes to keep in the known list (prevent DOS)
- maxKnownBlocks = 1024 // Maximum block hashes to keep in the known list (prevent DOS)
-
- // maxQueuedTxs is the maximum number of transactions to queue up before dropping
- // older broadcasts.
- maxQueuedTxs = 4096
-
- // maxQueuedTxAnns is the maximum number of transaction announcements to queue up
- // before dropping older announcements.
- maxQueuedTxAnns = 4096
-
- // maxQueuedBlocks is the maximum number of block propagations to queue up before
- // dropping broadcasts. There's not much point in queueing stale blocks, so a few
- // that might cover uncles should be enough.
- maxQueuedBlocks = 4
-
- // maxQueuedBlockAnns is the maximum number of block announcements to queue up before
- // dropping broadcasts. Similarly to block propagations, there's no point to queue
- // above some healthy uncle limit, so use that.
- maxQueuedBlockAnns = 4
-
- handshakeTimeout = 5 * time.Second
-)
-
-// max is a helper function which returns the larger of the two given integers.
-func max(a, b int) int {
- if a > b {
- return a
- }
- return b
-}
-
-// PeerInfo represents a short summary of the Ethereum sub-protocol metadata known
+// ethPeerInfo represents a short summary of the `eth` sub-protocol metadata known
// about a connected peer.
-type PeerInfo struct {
- Version int `json:"version"` // Ethereum protocol version negotiated
+type ethPeerInfo struct {
+ Version uint `json:"version"` // Ethereum protocol version negotiated
Difficulty *big.Int `json:"difficulty"` // Total difficulty of the peer's blockchain
- Head string `json:"head"` // SHA3 hash of the peer's best owned block
+ Head string `json:"head"` // Hex hash of the peer's best owned block
}
-// propEvent is a block propagation, waiting for its turn in the broadcast queue.
-type propEvent struct {
- block *types.Block
- td *big.Int
+// ethPeer is a wrapper around eth.Peer to maintain some extra metadata.
+type ethPeer struct {
+ *eth.Peer
+
+ syncDrop *time.Timer // Connection dropper if `eth` sync progress isn't validated in time
+ lock sync.RWMutex // Mutex protecting the internal fields
}
-type peer struct {
- id string
-
- *p2p.Peer
- rw p2p.MsgReadWriter
-
- version int // Protocol version negotiated
- syncDrop *time.Timer // Timed connection dropper if sync progress isn't validated in time
-
- head common.Hash
- td *big.Int
- lock sync.RWMutex
-
- knownBlocks mapset.Set // Set of block hashes known to be known by this peer
- queuedBlocks chan *propEvent // Queue of blocks to broadcast to the peer
- queuedBlockAnns chan *types.Block // Queue of blocks to announce to the peer
-
- knownTxs mapset.Set // Set of transaction hashes known to be known by this peer
- txBroadcast chan []common.Hash // Channel used to queue transaction propagation requests
- txAnnounce chan []common.Hash // Channel used to queue transaction announcement requests
- getPooledTx func(common.Hash) *types.Transaction // Callback used to retrieve transaction from txpool
-
- term chan struct{} // Termination channel to stop the broadcaster
-}
-
-func newPeer(version int, p *p2p.Peer, rw p2p.MsgReadWriter, getPooledTx func(hash common.Hash) *types.Transaction) *peer {
- return &peer{
- Peer: p,
- rw: rw,
- version: version,
- id: fmt.Sprintf("%x", p.ID().Bytes()[:8]),
- knownTxs: mapset.NewSet(),
- knownBlocks: mapset.NewSet(),
- queuedBlocks: make(chan *propEvent, maxQueuedBlocks),
- queuedBlockAnns: make(chan *types.Block, maxQueuedBlockAnns),
- txBroadcast: make(chan []common.Hash),
- txAnnounce: make(chan []common.Hash),
- getPooledTx: getPooledTx,
- term: make(chan struct{}),
- }
-}
-
-// broadcastBlocks is a write loop that multiplexes blocks and block accouncements
-// to the remote peer. The goal is to have an async writer that does not lock up
-// node internals and at the same time rate limits queued data.
-func (p *peer) broadcastBlocks(removePeer func(string)) {
- for {
- select {
- case prop := <-p.queuedBlocks:
- if err := p.SendNewBlock(prop.block, prop.td); err != nil {
- removePeer(p.id)
- return
- }
- p.Log().Trace("Propagated block", "number", prop.block.Number(), "hash", prop.block.Hash(), "td", prop.td)
-
- case block := <-p.queuedBlockAnns:
- if err := p.SendNewBlockHashes([]common.Hash{block.Hash()}, []uint64{block.NumberU64()}); err != nil {
- removePeer(p.id)
- return
- }
- p.Log().Trace("Announced block", "number", block.Number(), "hash", block.Hash())
-
- case <-p.term:
- return
- }
- }
-}
-
-// broadcastTransactions is a write loop that schedules transaction broadcasts
-// to the remote peer. The goal is to have an async writer that does not lock up
-// node internals and at the same time rate limits queued data.
-func (p *peer) broadcastTransactions(removePeer func(string)) {
- var (
- queue []common.Hash // Queue of hashes to broadcast as full transactions
- done chan struct{} // Non-nil if background broadcaster is running
- fail = make(chan error, 1) // Channel used to receive network error
- )
- for {
- // If there's no in-flight broadcast running, check if a new one is needed
- if done == nil && len(queue) > 0 {
- // Pile transaction until we reach our allowed network limit
- var (
- hashes []common.Hash
- txs []*types.Transaction
- size common.StorageSize
- )
- for i := 0; i < len(queue) && size < txsyncPackSize; i++ {
- if tx := p.getPooledTx(queue[i]); tx != nil {
- txs = append(txs, tx)
- size += tx.Size()
- }
- hashes = append(hashes, queue[i])
- }
- queue = queue[:copy(queue, queue[len(hashes):])]
-
- // If there's anything available to transfer, fire up an async writer
- if len(txs) > 0 {
- done = make(chan struct{})
- go func() {
- if err := p.sendTransactions(txs); err != nil {
- fail <- err
- return
- }
- close(done)
- p.Log().Trace("Sent transactions", "count", len(txs))
- }()
- }
- }
- // Transfer goroutine may or may not have been started, listen for events
- select {
- case hashes := <-p.txBroadcast:
- // New batch of transactions to be broadcast, queue them (with cap)
- queue = append(queue, hashes...)
- if len(queue) > maxQueuedTxs {
- // Fancy copy and resize to ensure buffer doesn't grow indefinitely
- queue = queue[:copy(queue, queue[len(queue)-maxQueuedTxs:])]
- }
-
- case <-done:
- done = nil
-
- case <-fail:
- removePeer(p.id)
- return
-
- case <-p.term:
- return
- }
- }
-}
-
-// announceTransactions is a write loop that schedules transaction broadcasts
-// to the remote peer. The goal is to have an async writer that does not lock up
-// node internals and at the same time rate limits queued data.
-func (p *peer) announceTransactions(removePeer func(string)) {
- var (
- queue []common.Hash // Queue of hashes to announce as transaction stubs
- done chan struct{} // Non-nil if background announcer is running
- fail = make(chan error, 1) // Channel used to receive network error
- )
- for {
- // If there's no in-flight announce running, check if a new one is needed
- if done == nil && len(queue) > 0 {
- // Pile transaction hashes until we reach our allowed network limit
- var (
- hashes []common.Hash
- pending []common.Hash
- size common.StorageSize
- )
- for i := 0; i < len(queue) && size < txsyncPackSize; i++ {
- if p.getPooledTx(queue[i]) != nil {
- pending = append(pending, queue[i])
- size += common.HashLength
- }
- hashes = append(hashes, queue[i])
- }
- queue = queue[:copy(queue, queue[len(hashes):])]
-
- // If there's anything available to transfer, fire up an async writer
- if len(pending) > 0 {
- done = make(chan struct{})
- go func() {
- if err := p.sendPooledTransactionHashes(pending); err != nil {
- fail <- err
- return
- }
- close(done)
- p.Log().Trace("Sent transaction announcements", "count", len(pending))
- }()
- }
- }
- // Transfer goroutine may or may not have been started, listen for events
- select {
- case hashes := <-p.txAnnounce:
- // New batch of transactions to be broadcast, queue them (with cap)
- queue = append(queue, hashes...)
- if len(queue) > maxQueuedTxAnns {
- // Fancy copy and resize to ensure buffer doesn't grow indefinitely
- queue = queue[:copy(queue, queue[len(queue)-maxQueuedTxAnns:])]
- }
-
- case <-done:
- done = nil
-
- case <-fail:
- removePeer(p.id)
- return
-
- case <-p.term:
- return
- }
- }
-}
-
-// close signals the broadcast goroutine to terminate.
-func (p *peer) close() {
- close(p.term)
-}
-
-// Info gathers and returns a collection of metadata known about a peer.
-func (p *peer) Info() *PeerInfo {
+// info gathers and returns some `eth` protocol metadata known about a peer.
+func (p *ethPeer) info() *ethPeerInfo {
hash, td := p.Head()
- return &PeerInfo{
- Version: p.version,
+ return &ethPeerInfo{
+ Version: p.Version(),
Difficulty: td,
Head: hash.Hex(),
}
}
-// Head retrieves a copy of the current head hash and total difficulty of the
-// peer.
-func (p *peer) Head() (hash common.Hash, td *big.Int) {
- p.lock.RLock()
- defer p.lock.RUnlock()
-
- copy(hash[:], p.head[:])
- return hash, new(big.Int).Set(p.td)
+// snapPeerInfo represents a short summary of the `snap` sub-protocol metadata known
+// about a connected peer.
+type snapPeerInfo struct {
+ Version uint `json:"version"` // Snapshot protocol version negotiated
}
-// SetHead updates the head hash and total difficulty of the peer.
-func (p *peer) SetHead(hash common.Hash, td *big.Int) {
- p.lock.Lock()
- defer p.lock.Unlock()
+// snapPeer is a wrapper around snap.Peer to maintain some extra metadata.
+type snapPeer struct {
+ *snap.Peer
- copy(p.head[:], hash[:])
- p.td.Set(td)
+ ethDrop *time.Timer // Connection dropper if `eth` doesn't connect in time
+ lock sync.RWMutex // Mutex protecting the internal fields
}
-// MarkBlock marks a block as known for the peer, ensuring that the block will
-// never be propagated to this particular peer.
-func (p *peer) MarkBlock(hash common.Hash) {
- // If we reached the memory allowance, drop a previously known block hash
- for p.knownBlocks.Cardinality() >= maxKnownBlocks {
- p.knownBlocks.Pop()
- }
- p.knownBlocks.Add(hash)
-}
-
-// MarkTransaction marks a transaction as known for the peer, ensuring that it
-// will never be propagated to this particular peer.
-func (p *peer) MarkTransaction(hash common.Hash) {
- // If we reached the memory allowance, drop a previously known transaction hash
- for p.knownTxs.Cardinality() >= maxKnownTxs {
- p.knownTxs.Pop()
- }
- p.knownTxs.Add(hash)
-}
-
-// SendTransactions64 sends transactions to the peer and includes the hashes
-// in its transaction hash set for future reference.
-//
-// This method is legacy support for initial transaction exchange in eth/64 and
-// prior. For eth/65 and higher use SendPooledTransactionHashes.
-func (p *peer) SendTransactions64(txs types.Transactions) error {
- return p.sendTransactions(txs)
-}
-
-// sendTransactions sends transactions to the peer and includes the hashes
-// in its transaction hash set for future reference.
-//
-// This method is a helper used by the async transaction sender. Don't call it
-// directly as the queueing (memory) and transmission (bandwidth) costs should
-// not be managed directly.
-func (p *peer) sendTransactions(txs types.Transactions) error {
- // Mark all the transactions as known, but ensure we don't overflow our limits
- for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(txs)) {
- p.knownTxs.Pop()
- }
- for _, tx := range txs {
- p.knownTxs.Add(tx.Hash())
- }
- return p2p.Send(p.rw, TransactionMsg, txs)
-}
-
-// AsyncSendTransactions queues a list of transactions (by hash) to eventually
-// propagate to a remote peer. The number of pending sends are capped (new ones
-// will force old sends to be dropped)
-func (p *peer) AsyncSendTransactions(hashes []common.Hash) {
- select {
- case p.txBroadcast <- hashes:
- // Mark all the transactions as known, but ensure we don't overflow our limits
- for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
- p.knownTxs.Pop()
- }
- for _, hash := range hashes {
- p.knownTxs.Add(hash)
- }
- case <-p.term:
- p.Log().Debug("Dropping transaction propagation", "count", len(hashes))
+// info gathers and returns some `snap` protocol metadata known about a peer.
+func (p *snapPeer) info() *snapPeerInfo {
+ return &snapPeerInfo{
+ Version: p.Version(),
}
}
-
-// sendPooledTransactionHashes sends transaction hashes to the peer and includes
-// them in its transaction hash set for future reference.
-//
-// This method is a helper used by the async transaction announcer. Don't call it
-// directly as the queueing (memory) and transmission (bandwidth) costs should
-// not be managed directly.
-func (p *peer) sendPooledTransactionHashes(hashes []common.Hash) error {
- // Mark all the transactions as known, but ensure we don't overflow our limits
- for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
- p.knownTxs.Pop()
- }
- for _, hash := range hashes {
- p.knownTxs.Add(hash)
- }
- return p2p.Send(p.rw, NewPooledTransactionHashesMsg, hashes)
-}
-
-// AsyncSendPooledTransactionHashes queues a list of transactions hashes to eventually
-// announce to a remote peer. The number of pending sends are capped (new ones
-// will force old sends to be dropped)
-func (p *peer) AsyncSendPooledTransactionHashes(hashes []common.Hash) {
- select {
- case p.txAnnounce <- hashes:
- // Mark all the transactions as known, but ensure we don't overflow our limits
- for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
- p.knownTxs.Pop()
- }
- for _, hash := range hashes {
- p.knownTxs.Add(hash)
- }
- case <-p.term:
- p.Log().Debug("Dropping transaction announcement", "count", len(hashes))
- }
-}
-
-// SendPooledTransactionsRLP sends requested transactions to the peer and adds the
-// hashes in its transaction hash set for future reference.
-//
-// Note, the method assumes the hashes are correct and correspond to the list of
-// transactions being sent.
-func (p *peer) SendPooledTransactionsRLP(hashes []common.Hash, txs []rlp.RawValue) error {
- // Mark all the transactions as known, but ensure we don't overflow our limits
- for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
- p.knownTxs.Pop()
- }
- for _, hash := range hashes {
- p.knownTxs.Add(hash)
- }
- return p2p.Send(p.rw, PooledTransactionsMsg, txs)
-}
-
-// SendNewBlockHashes announces the availability of a number of blocks through
-// a hash notification.
-func (p *peer) SendNewBlockHashes(hashes []common.Hash, numbers []uint64) error {
- // Mark all the block hashes as known, but ensure we don't overflow our limits
- for p.knownBlocks.Cardinality() > max(0, maxKnownBlocks-len(hashes)) {
- p.knownBlocks.Pop()
- }
- for _, hash := range hashes {
- p.knownBlocks.Add(hash)
- }
- request := make(newBlockHashesData, len(hashes))
- for i := 0; i < len(hashes); i++ {
- request[i].Hash = hashes[i]
- request[i].Number = numbers[i]
- }
- return p2p.Send(p.rw, NewBlockHashesMsg, request)
-}
-
-// AsyncSendNewBlockHash queues the availability of a block for propagation to a
-// remote peer. If the peer's broadcast queue is full, the event is silently
-// dropped.
-func (p *peer) AsyncSendNewBlockHash(block *types.Block) {
- select {
- case p.queuedBlockAnns <- block:
- // Mark all the block hash as known, but ensure we don't overflow our limits
- for p.knownBlocks.Cardinality() >= maxKnownBlocks {
- p.knownBlocks.Pop()
- }
- p.knownBlocks.Add(block.Hash())
- default:
- p.Log().Debug("Dropping block announcement", "number", block.NumberU64(), "hash", block.Hash())
- }
-}
-
-// SendNewBlock propagates an entire block to a remote peer.
-func (p *peer) SendNewBlock(block *types.Block, td *big.Int) error {
- // Mark all the block hash as known, but ensure we don't overflow our limits
- for p.knownBlocks.Cardinality() >= maxKnownBlocks {
- p.knownBlocks.Pop()
- }
- p.knownBlocks.Add(block.Hash())
- return p2p.Send(p.rw, NewBlockMsg, []interface{}{block, td})
-}
-
-// AsyncSendNewBlock queues an entire block for propagation to a remote peer. If
-// the peer's broadcast queue is full, the event is silently dropped.
-func (p *peer) AsyncSendNewBlock(block *types.Block, td *big.Int) {
- select {
- case p.queuedBlocks <- &propEvent{block: block, td: td}:
- // Mark all the block hash as known, but ensure we don't overflow our limits
- for p.knownBlocks.Cardinality() >= maxKnownBlocks {
- p.knownBlocks.Pop()
- }
- p.knownBlocks.Add(block.Hash())
- default:
- p.Log().Debug("Dropping block propagation", "number", block.NumberU64(), "hash", block.Hash())
- }
-}
-
-// SendBlockHeaders sends a batch of block headers to the remote peer.
-func (p *peer) SendBlockHeaders(headers []*types.Header) error {
- return p2p.Send(p.rw, BlockHeadersMsg, headers)
-}
-
-// SendBlockBodies sends a batch of block contents to the remote peer.
-func (p *peer) SendBlockBodies(bodies []*blockBody) error {
- return p2p.Send(p.rw, BlockBodiesMsg, blockBodiesData(bodies))
-}
-
-// SendBlockBodiesRLP sends a batch of block contents to the remote peer from
-// an already RLP encoded format.
-func (p *peer) SendBlockBodiesRLP(bodies []rlp.RawValue) error {
- return p2p.Send(p.rw, BlockBodiesMsg, bodies)
-}
-
-// SendNodeDataRLP sends a batch of arbitrary internal data, corresponding to the
-// hashes requested.
-func (p *peer) SendNodeData(data [][]byte) error {
- return p2p.Send(p.rw, NodeDataMsg, data)
-}
-
-// SendReceiptsRLP sends a batch of transaction receipts, corresponding to the
-// ones requested from an already RLP encoded format.
-func (p *peer) SendReceiptsRLP(receipts []rlp.RawValue) error {
- return p2p.Send(p.rw, ReceiptsMsg, receipts)
-}
-
-// RequestOneHeader is a wrapper around the header query functions to fetch a
-// single header. It is used solely by the fetcher.
-func (p *peer) RequestOneHeader(hash common.Hash) error {
- p.Log().Debug("Fetching single header", "hash", hash)
- return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: uint64(1), Skip: uint64(0), Reverse: false})
-}
-
-// RequestHeadersByHash fetches a batch of blocks' headers corresponding to the
-// specified header query, based on the hash of an origin block.
-func (p *peer) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error {
- p.Log().Debug("Fetching batch of headers", "count", amount, "fromhash", origin, "skip", skip, "reverse", reverse)
- return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse})
-}
-
-// RequestHeadersByNumber fetches a batch of blocks' headers corresponding to the
-// specified header query, based on the number of an origin block.
-func (p *peer) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error {
- p.Log().Debug("Fetching batch of headers", "count", amount, "fromnum", origin, "skip", skip, "reverse", reverse)
- return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse})
-}
-
-// RequestBodies fetches a batch of blocks' bodies corresponding to the hashes
-// specified.
-func (p *peer) RequestBodies(hashes []common.Hash) error {
- p.Log().Debug("Fetching batch of block bodies", "count", len(hashes))
- return p2p.Send(p.rw, GetBlockBodiesMsg, hashes)
-}
-
-// RequestNodeData fetches a batch of arbitrary data from a node's known state
-// data, corresponding to the specified hashes.
-func (p *peer) RequestNodeData(hashes []common.Hash) error {
- p.Log().Debug("Fetching batch of state data", "count", len(hashes))
- return p2p.Send(p.rw, GetNodeDataMsg, hashes)
-}
-
-// RequestReceipts fetches a batch of transaction receipts from a remote node.
-func (p *peer) RequestReceipts(hashes []common.Hash) error {
- p.Log().Debug("Fetching batch of receipts", "count", len(hashes))
- return p2p.Send(p.rw, GetReceiptsMsg, hashes)
-}
-
-// RequestTxs fetches a batch of transactions from a remote node.
-func (p *peer) RequestTxs(hashes []common.Hash) error {
- p.Log().Debug("Fetching batch of transactions", "count", len(hashes))
- return p2p.Send(p.rw, GetPooledTransactionsMsg, hashes)
-}
-
-// Handshake executes the eth protocol handshake, negotiating version number,
-// network IDs, difficulties, head and genesis blocks.
-func (p *peer) Handshake(network uint64, td *big.Int, head common.Hash, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter) error {
- // Send out own handshake in a new thread
- errc := make(chan error, 2)
-
- var (
- status63 statusData63 // safe to read after two values have been received from errc
- status statusData // safe to read after two values have been received from errc
- )
- go func() {
- switch {
- case p.version == eth63:
- errc <- p2p.Send(p.rw, StatusMsg, &statusData63{
- ProtocolVersion: uint32(p.version),
- NetworkId: network,
- TD: td,
- CurrentBlock: head,
- GenesisBlock: genesis,
- })
- case p.version >= eth64:
- errc <- p2p.Send(p.rw, StatusMsg, &statusData{
- ProtocolVersion: uint32(p.version),
- NetworkID: network,
- TD: td,
- Head: head,
- Genesis: genesis,
- ForkID: forkID,
- })
- default:
- panic(fmt.Sprintf("unsupported eth protocol version: %d", p.version))
- }
- }()
- go func() {
- switch {
- case p.version == eth63:
- errc <- p.readStatusLegacy(network, &status63, genesis)
- case p.version >= eth64:
- errc <- p.readStatus(network, &status, genesis, forkFilter)
- default:
- panic(fmt.Sprintf("unsupported eth protocol version: %d", p.version))
- }
- }()
- timeout := time.NewTimer(handshakeTimeout)
- defer timeout.Stop()
- for i := 0; i < 2; i++ {
- select {
- case err := <-errc:
- if err != nil {
- return err
- }
- case <-timeout.C:
- return p2p.DiscReadTimeout
- }
- }
- switch {
- case p.version == eth63:
- p.td, p.head = status63.TD, status63.CurrentBlock
- case p.version >= eth64:
- p.td, p.head = status.TD, status.Head
- default:
- panic(fmt.Sprintf("unsupported eth protocol version: %d", p.version))
- }
- return nil
-}
-
-func (p *peer) readStatusLegacy(network uint64, status *statusData63, genesis common.Hash) error {
- msg, err := p.rw.ReadMsg()
- if err != nil {
- return err
- }
- if msg.Code != StatusMsg {
- return errResp(ErrNoStatusMsg, "first msg has code %x (!= %x)", msg.Code, StatusMsg)
- }
- if msg.Size > protocolMaxMsgSize {
- return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, protocolMaxMsgSize)
- }
- // Decode the handshake and make sure everything matches
- if err := msg.Decode(&status); err != nil {
- return errResp(ErrDecode, "msg %v: %v", msg, err)
- }
- if status.GenesisBlock != genesis {
- return errResp(ErrGenesisMismatch, "%x (!= %x)", status.GenesisBlock[:8], genesis[:8])
- }
- if status.NetworkId != network {
- return errResp(ErrNetworkIDMismatch, "%d (!= %d)", status.NetworkId, network)
- }
- if int(status.ProtocolVersion) != p.version {
- return errResp(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, p.version)
- }
- return nil
-}
-
-func (p *peer) readStatus(network uint64, status *statusData, genesis common.Hash, forkFilter forkid.Filter) error {
- msg, err := p.rw.ReadMsg()
- if err != nil {
- return err
- }
- if msg.Code != StatusMsg {
- return errResp(ErrNoStatusMsg, "first msg has code %x (!= %x)", msg.Code, StatusMsg)
- }
- if msg.Size > protocolMaxMsgSize {
- return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, protocolMaxMsgSize)
- }
- // Decode the handshake and make sure everything matches
- if err := msg.Decode(&status); err != nil {
- return errResp(ErrDecode, "msg %v: %v", msg, err)
- }
- if status.NetworkID != network {
- return errResp(ErrNetworkIDMismatch, "%d (!= %d)", status.NetworkID, network)
- }
- if int(status.ProtocolVersion) != p.version {
- return errResp(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, p.version)
- }
- if status.Genesis != genesis {
- return errResp(ErrGenesisMismatch, "%x (!= %x)", status.Genesis, genesis)
- }
- if err := forkFilter(status.ForkID); err != nil {
- return errResp(ErrForkIDRejected, "%v", err)
- }
- return nil
-}
-
-// String implements fmt.Stringer.
-func (p *peer) String() string {
- return fmt.Sprintf("Peer %s [%s]", p.id,
- fmt.Sprintf("eth/%2d", p.version),
- )
-}
-
-// peerSet represents the collection of active peers currently participating in
-// the Ethereum sub-protocol.
-type peerSet struct {
- peers map[string]*peer
- lock sync.RWMutex
- closed bool
-}
-
-// newPeerSet creates a new peer set to track the active participants.
-func newPeerSet() *peerSet {
- return &peerSet{
- peers: make(map[string]*peer),
- }
-}
-
-// Register injects a new peer into the working set, or returns an error if the
-// peer is already known. If a new peer it registered, its broadcast loop is also
-// started.
-func (ps *peerSet) Register(p *peer, removePeer func(string)) error {
- ps.lock.Lock()
- defer ps.lock.Unlock()
-
- if ps.closed {
- return errClosed
- }
- if _, ok := ps.peers[p.id]; ok {
- return errAlreadyRegistered
- }
- ps.peers[p.id] = p
-
- go p.broadcastBlocks(removePeer)
- go p.broadcastTransactions(removePeer)
- if p.version >= eth65 {
- go p.announceTransactions(removePeer)
- }
- return nil
-}
-
-// Unregister removes a remote peer from the active set, disabling any further
-// actions to/from that particular entity.
-func (ps *peerSet) Unregister(id string) error {
- ps.lock.Lock()
- defer ps.lock.Unlock()
-
- p, ok := ps.peers[id]
- if !ok {
- return errNotRegistered
- }
- delete(ps.peers, id)
- p.close()
-
- return nil
-}
-
-// Peer retrieves the registered peer with the given id.
-func (ps *peerSet) Peer(id string) *peer {
- ps.lock.RLock()
- defer ps.lock.RUnlock()
-
- return ps.peers[id]
-}
-
-// Len returns if the current number of peers in the set.
-func (ps *peerSet) Len() int {
- ps.lock.RLock()
- defer ps.lock.RUnlock()
-
- return len(ps.peers)
-}
-
-// PeersWithoutBlock retrieves a list of peers that do not have a given block in
-// their set of known hashes.
-func (ps *peerSet) PeersWithoutBlock(hash common.Hash) []*peer {
- ps.lock.RLock()
- defer ps.lock.RUnlock()
-
- list := make([]*peer, 0, len(ps.peers))
- for _, p := range ps.peers {
- if !p.knownBlocks.Contains(hash) {
- list = append(list, p)
- }
- }
- return list
-}
-
-// PeersWithoutTx retrieves a list of peers that do not have a given transaction
-// in their set of known hashes.
-func (ps *peerSet) PeersWithoutTx(hash common.Hash) []*peer {
- ps.lock.RLock()
- defer ps.lock.RUnlock()
-
- list := make([]*peer, 0, len(ps.peers))
- for _, p := range ps.peers {
- if !p.knownTxs.Contains(hash) {
- list = append(list, p)
- }
- }
- return list
-}
-
-// BestPeer retrieves the known peer with the currently highest total difficulty.
-func (ps *peerSet) BestPeer() *peer {
- ps.lock.RLock()
- defer ps.lock.RUnlock()
-
- var (
- bestPeer *peer
- bestTd *big.Int
- )
- for _, p := range ps.peers {
- if _, td := p.Head(); bestPeer == nil || td.Cmp(bestTd) > 0 {
- bestPeer, bestTd = p, td
- }
- }
- return bestPeer
-}
-
-// Close disconnects all peers.
-// No new peers can be registered after Close has returned.
-func (ps *peerSet) Close() {
- ps.lock.Lock()
- defer ps.lock.Unlock()
-
- for _, p := range ps.peers {
- p.Disconnect(p2p.DiscQuitting)
- }
- ps.closed = true
-}
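For reference, a self-contained sketch of what the trimmed-down `ethPeerInfo`
summary serializes to (the struct is copied from the new eth/peer.go above;
the concrete values are made up for illustration):

	package main

	import (
		"encoding/json"
		"fmt"
		"math/big"
	)

	// ethPeerInfo mirrors the reduced metadata struct from eth/peer.go.
	type ethPeerInfo struct {
		Version    uint     `json:"version"`
		Difficulty *big.Int `json:"difficulty"`
		Head       string   `json:"head"`
	}

	func main() {
		info := ethPeerInfo{Version: 65, Difficulty: big.NewInt(131136), Head: "0x8f5b..."}
		out, _ := json.Marshal(info)
		fmt.Println(string(out)) // {"version":65,"difficulty":131136,"head":"0x8f5b..."}
	}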
diff --git a/eth/peerset.go b/eth/peerset.go
new file mode 100644
index 000000000..9b584ec32
--- /dev/null
+++ b/eth/peerset.go
@@ -0,0 +1,301 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+ "errors"
+ "math/big"
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/eth/protocols/eth"
+ "github.com/ethereum/go-ethereum/eth/protocols/snap"
+ "github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/p2p"
+)
+
+var (
+ // errPeerSetClosed is returned when a peer is added to or removed from the
+ // peer set after it has been terminated.
+ errPeerSetClosed = errors.New("peerset closed")
+
+ // errPeerAlreadyRegistered is returned when a peer is added to the peer set,
+ // but one with the same id already exists.
+ errPeerAlreadyRegistered = errors.New("peer already registered")
+
+ // errPeerNotRegistered is returned when a peer is removed from the peer set,
+ // but no peer with the given id exists.
+ errPeerNotRegistered = errors.New("peer not registered")
+
+ // ethConnectTimeout is how long a `snap` peer is given for its `eth`
+ // counterpart to also connect before being dropped.
+ ethConnectTimeout = 3 * time.Second
+)
+
+// peerSet represents the collection of active peers currently participating in
+// the `eth` or `snap` protocols.
+type peerSet struct {
+ ethPeers map[string]*ethPeer // Peers connected on the `eth` protocol
+ snapPeers map[string]*snapPeer // Peers connected on the `snap` protocol
+
+ ethJoinFeed event.Feed // Events when an `eth` peer successfully joins
+ ethDropFeed event.Feed // Events when an `eth` peer gets dropped
+ snapJoinFeed event.Feed // Events when a `snap` peer joins on both `eth` and `snap`
+ snapDropFeed event.Feed // Events when a `snap` peer gets dropped (only if fully joined)
+
+ scope event.SubscriptionScope // Subscription group to unsubscribe everyone at once
+
+ lock sync.RWMutex
+ closed bool
+}
+
+// newPeerSet creates a new peer set to track the active participants.
+func newPeerSet() *peerSet {
+ return &peerSet{
+ ethPeers: make(map[string]*ethPeer),
+ snapPeers: make(map[string]*snapPeer),
+ }
+}
+
+// subscribeEthJoin registers a subscription for peers joining (and completing
+// the handshake) on the `eth` protocol.
+func (ps *peerSet) subscribeEthJoin(ch chan<- *eth.Peer) event.Subscription {
+ return ps.scope.Track(ps.ethJoinFeed.Subscribe(ch))
+}
+
+// subscribeEthDrop registers a subscription for peers being dropped from the
+// `eth` protocol.
+func (ps *peerSet) subscribeEthDrop(ch chan<- *eth.Peer) event.Subscription {
+ return ps.scope.Track(ps.ethDropFeed.Subscribe(ch))
+}
+
+// subscribeSnapJoin registers a subscription for peers joining (and completing
+// the `eth` join) on the `snap` protocol.
+func (ps *peerSet) subscribeSnapJoin(ch chan<- *snap.Peer) event.Subscription {
+ return ps.scope.Track(ps.snapJoinFeed.Subscribe(ch))
+}
+
+// subscribeSnapDrop registers a subscription for peers being dropped from the
+// `snap` protocol.
+func (ps *peerSet) subscribeSnapDrop(ch chan<- *snap.Peer) event.Subscription {
+ return ps.scope.Track(ps.snapDropFeed.Subscribe(ch))
+}
+
+// registerEthPeer injects a new `eth` peer into the working set, or returns an
+// error if the peer is already known. The peer is announced on the `eth` join
+// feed and, if it completes a pending `snap` peer, on the `snap` join feed too.
+func (ps *peerSet) registerEthPeer(peer *eth.Peer) error {
+ ps.lock.Lock()
+ if ps.closed {
+ ps.lock.Unlock()
+ return errPeerSetClosed
+ }
+ id := peer.ID()
+ if _, ok := ps.ethPeers[id]; ok {
+ ps.lock.Unlock()
+ return errPeerAlreadyRegistered
+ }
+ ps.ethPeers[id] = &ethPeer{Peer: peer}
+
+ snap, ok := ps.snapPeers[id]
+ ps.lock.Unlock()
+
+ if ok {
+ // Previously dangling `snap` peer; stop its timer since `eth` connected
+ snap.lock.Lock()
+ if snap.ethDrop != nil {
+ snap.ethDrop.Stop()
+ snap.ethDrop = nil
+ }
+ snap.lock.Unlock()
+ }
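+ // Notify the subscribers: the `eth` join fires unconditionally, while the
+ // `snap` join only fires once both halves of the pairing are live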
+ ps.ethJoinFeed.Send(peer)
+ if ok {
+ ps.snapJoinFeed.Send(snap.Peer)
+ }
+ return nil
+}
+
+// unregisterEthPeer removes a remote peer from the active set, disabling any further
+// actions to/from that particular entity. The drop is announced on the `eth` drop
+// feed and, if it breaks up an eth/snap pairing, on the `snap` drop feed too.
+func (ps *peerSet) unregisterEthPeer(id string) error {
+ ps.lock.Lock()
+ eth, ok := ps.ethPeers[id]
+ if !ok {
+ ps.lock.Unlock()
+ return errPeerNotRegistered
+ }
+ delete(ps.ethPeers, id)
+
+ snap, ok := ps.snapPeers[id]
+ ps.lock.Unlock()
+
+ ps.ethDropFeed.Send(eth.Peer)
+ if ok {
+ ps.snapDropFeed.Send(snap.Peer)
+ }
+ return nil
+}
+
+// registerSnapPeer injects a new `snap` peer into the working set, or returns
+// an error if the peer is already known. The peer is announced on the `snap`
+// join feed if it completes an existing `eth` peer.
+//
+// If the peer isn't yet connected on `eth` and fails to do so within a given
+// amount of time, it is dropped. This enforces that `snap` is an extension to
+// `eth`, not a standalone leeching protocol.
+func (ps *peerSet) registerSnapPeer(peer *snap.Peer) error {
+ ps.lock.Lock()
+ if ps.closed {
+ ps.lock.Unlock()
+ return errPeerSetClosed
+ }
+ id := peer.ID()
+ if _, ok := ps.snapPeers[id]; ok {
+ ps.lock.Unlock()
+ return errPeerAlreadyRegistered
+ }
+ ps.snapPeers[id] = &snapPeer{Peer: peer}
+
+ _, ok := ps.ethPeers[id]
+ if !ok {
+ // Dangling `snap` peer, start a timer to drop if `eth` doesn't connect
+ ps.snapPeers[id].ethDrop = time.AfterFunc(ethConnectTimeout, func() {
+ peer.Log().Warn("Snapshot peer missing eth, dropping", "addr", peer.RemoteAddr(), "type", peer.Name())
+ peer.Disconnect(p2p.DiscUselessPeer)
+ })
+ }
+ ps.lock.Unlock()
+
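+ // Only announce the `snap` peer if its `eth` counterpart is already live;
+ // dangling registrations get announced when `eth` eventually connects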
+ if ok {
+ ps.snapJoinFeed.Send(peer)
+ }
+ return nil
+}
+
+// unregisterSnapPeer removes a remote peer from the active set, disabling any
+// further actions to/from that particular entity. The drop is announced on the
+// `snap` drop feed.
+func (ps *peerSet) unregisterSnapPeer(id string) error {
+ ps.lock.Lock()
+ peer, ok := ps.snapPeers[id]
+ if !ok {
+ ps.lock.Unlock()
+ return errPeerNotRegistered
+ }
+ delete(ps.snapPeers, id)
+ ps.lock.Unlock()
+
+ peer.lock.Lock()
+ if peer.ethDrop != nil {
+ peer.ethDrop.Stop()
+ peer.ethDrop = nil
+ }
+ peer.lock.Unlock()
+
+ ps.snapDropFeed.Send(peer.Peer)
+ return nil
+}
+
+// ethPeer retrieves the registered `eth` peer with the given id.
+func (ps *peerSet) ethPeer(id string) *ethPeer {
+ ps.lock.RLock()
+ defer ps.lock.RUnlock()
+
+ return ps.ethPeers[id]
+}
+
+// snapPeer retrieves the registered `snap` peer with the given id.
+func (ps *peerSet) snapPeer(id string) *snapPeer {
+ ps.lock.RLock()
+ defer ps.lock.RUnlock()
+
+ return ps.snapPeers[id]
+}
+
+// ethPeersWithoutBlock retrieves a list of `eth` peers that do not have a given
+// block in their set of known hashes so it might be propagated to them.
+func (ps *peerSet) ethPeersWithoutBlock(hash common.Hash) []*ethPeer {
+ ps.lock.RLock()
+ defer ps.lock.RUnlock()
+
+ list := make([]*ethPeer, 0, len(ps.ethPeers))
+ for _, p := range ps.ethPeers {
+ if !p.KnownBlock(hash) {
+ list = append(list, p)
+ }
+ }
+ return list
+}
+
+// ethPeersWithoutTransaction retrieves a list of `eth` peers that do not have a
+// given transaction in their set of known hashes.
+func (ps *peerSet) ethPeersWithoutTransaction(hash common.Hash) []*ethPeer {
+ ps.lock.RLock()
+ defer ps.lock.RUnlock()
+
+ list := make([]*ethPeer, 0, len(ps.ethPeers))
+ for _, p := range ps.ethPeers {
+ if !p.KnownTransaction(hash) {
+ list = append(list, p)
+ }
+ }
+ return list
+}
+
+// Len returns the current number of `eth` peers in the set. Since the `snap`
+// peers are tied to the existence of an `eth` connection, that will always be a
+// subset of `eth`.
+func (ps *peerSet) Len() int {
+ ps.lock.RLock()
+ defer ps.lock.RUnlock()
+
+ return len(ps.ethPeers)
+}
+
+// ethPeerWithHighestTD retrieves the known peer with the currently highest total
+// difficulty.
+func (ps *peerSet) ethPeerWithHighestTD() *eth.Peer {
+ ps.lock.RLock()
+ defer ps.lock.RUnlock()
+
+ var (
+ bestPeer *eth.Peer
+ bestTd *big.Int
+ )
+ for _, p := range ps.ethPeers {
+ if _, td := p.Head(); bestPeer == nil || td.Cmp(bestTd) > 0 {
+ bestPeer, bestTd = p.Peer, td
+ }
+ }
+ return bestPeer
+}
+
+// close disconnects all peers.
+func (ps *peerSet) close() {
+ ps.lock.Lock()
+ defer ps.lock.Unlock()
+
+ for _, p := range ps.ethPeers {
+ p.Disconnect(p2p.DiscQuitting)
+ }
+ for _, p := range ps.snapPeers {
+ p.Disconnect(p2p.DiscQuitting)
+ }
+ ps.closed = true
+}
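A short sketch of how a consumer is expected to drive the join/drop feeds
(illustrative only — the wiring below is an assumption layered on the API
above, not code from this change; it would sit in package eth, which already
imports the `eth` protocol package as shown in eth/peer.go):

	// trackEthPeers follows `eth` peer lifecycle events through the new feeds.
	func trackEthPeers(ps *peerSet, quit <-chan struct{}) {
		joins := make(chan *eth.Peer)
		drops := make(chan *eth.Peer)

		joinSub := ps.subscribeEthJoin(joins)
		dropSub := ps.subscribeEthDrop(drops)
		defer joinSub.Unsubscribe()
		defer dropSub.Unsubscribe()

		for {
			select {
			case peer := <-joins:
				peer.Log().Debug("Tracking new eth peer", "version", peer.Version())
			case peer := <-drops:
				peer.Log().Debug("Untracking dropped eth peer")
			case <-quit:
				return
			}
		}
	}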
diff --git a/eth/protocol.go b/eth/protocol.go
deleted file mode 100644
index dc75d6b31..000000000
--- a/eth/protocol.go
+++ /dev/null
@@ -1,221 +0,0 @@
-// Copyright 2014 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package eth
-
-import (
- "fmt"
- "io"
- "math/big"
-
- "github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/core"
- "github.com/ethereum/go-ethereum/core/forkid"
- "github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/event"
- "github.com/ethereum/go-ethereum/rlp"
-)
-
-// Constants to match up protocol versions and messages
-const (
- eth63 = 63
- eth64 = 64
- eth65 = 65
-)
-
-// protocolName is the official short name of the protocol used during capability negotiation.
-const protocolName = "eth"
-
-// ProtocolVersions are the supported versions of the eth protocol (first is primary).
-var ProtocolVersions = []uint{eth65, eth64, eth63}
-
-// protocolLengths are the number of implemented message corresponding to different protocol versions.
-var protocolLengths = map[uint]uint64{eth65: 17, eth64: 17, eth63: 17}
-
-const protocolMaxMsgSize = 10 * 1024 * 1024 // Maximum cap on the size of a protocol message
-
-// eth protocol message codes
-const (
- StatusMsg = 0x00
- NewBlockHashesMsg = 0x01
- TransactionMsg = 0x02
- GetBlockHeadersMsg = 0x03
- BlockHeadersMsg = 0x04
- GetBlockBodiesMsg = 0x05
- BlockBodiesMsg = 0x06
- NewBlockMsg = 0x07
- GetNodeDataMsg = 0x0d
- NodeDataMsg = 0x0e
- GetReceiptsMsg = 0x0f
- ReceiptsMsg = 0x10
-
- // New protocol message codes introduced in eth65
- //
- // Previously these message ids were used by some legacy and unsupported
- // eth protocols, reown them here.
- NewPooledTransactionHashesMsg = 0x08
- GetPooledTransactionsMsg = 0x09
- PooledTransactionsMsg = 0x0a
-)
-
-type errCode int
-
-const (
- ErrMsgTooLarge = iota
- ErrDecode
- ErrInvalidMsgCode
- ErrProtocolVersionMismatch
- ErrNetworkIDMismatch
- ErrGenesisMismatch
- ErrForkIDRejected
- ErrNoStatusMsg
- ErrExtraStatusMsg
-)
-
-func (e errCode) String() string {
- return errorToString[int(e)]
-}
-
-// XXX change once legacy code is out
-var errorToString = map[int]string{
- ErrMsgTooLarge: "Message too long",
- ErrDecode: "Invalid message",
- ErrInvalidMsgCode: "Invalid message code",
- ErrProtocolVersionMismatch: "Protocol version mismatch",
- ErrNetworkIDMismatch: "Network ID mismatch",
- ErrGenesisMismatch: "Genesis mismatch",
- ErrForkIDRejected: "Fork ID rejected",
- ErrNoStatusMsg: "No status message",
- ErrExtraStatusMsg: "Extra status message",
-}
-
-type txPool interface {
- // Has returns an indicator whether txpool has a transaction
- // cached with the given hash.
- Has(hash common.Hash) bool
-
- // Get retrieves the transaction from local txpool with given
- // tx hash.
- Get(hash common.Hash) *types.Transaction
-
- // AddRemotes should add the given transactions to the pool.
- AddRemotes([]*types.Transaction) []error
-
- // Pending should return pending transactions.
- // The slice should be modifiable by the caller.
- Pending() (map[common.Address]types.Transactions, error)
-
- // SubscribeNewTxsEvent should return an event subscription of
- // NewTxsEvent and send events to the given channel.
- SubscribeNewTxsEvent(chan<- core.NewTxsEvent) event.Subscription
-}
-
-// statusData63 is the network packet for the status message for eth/63.
-type statusData63 struct {
- ProtocolVersion uint32
- NetworkId uint64
- TD *big.Int
- CurrentBlock common.Hash
- GenesisBlock common.Hash
-}
-
-// statusData is the network packet for the status message for eth/64 and later.
-type statusData struct {
- ProtocolVersion uint32
- NetworkID uint64
- TD *big.Int
- Head common.Hash
- Genesis common.Hash
- ForkID forkid.ID
-}
-
-// newBlockHashesData is the network packet for the block announcements.
-type newBlockHashesData []struct {
- Hash common.Hash // Hash of one particular block being announced
- Number uint64 // Number of one particular block being announced
-}
-
-// getBlockHeadersData represents a block header query.
-type getBlockHeadersData struct {
- Origin hashOrNumber // Block from which to retrieve headers
- Amount uint64 // Maximum number of headers to retrieve
- Skip uint64 // Blocks to skip between consecutive headers
- Reverse bool // Query direction (false = rising towards latest, true = falling towards genesis)
-}
-
-// hashOrNumber is a combined field for specifying an origin block.
-type hashOrNumber struct {
- Hash common.Hash // Block hash from which to retrieve headers (excludes Number)
- Number uint64 // Block hash from which to retrieve headers (excludes Hash)
-}
-
-// EncodeRLP is a specialized encoder for hashOrNumber to encode only one of the
-// two contained union fields.
-func (hn *hashOrNumber) EncodeRLP(w io.Writer) error {
- if hn.Hash == (common.Hash{}) {
- return rlp.Encode(w, hn.Number)
- }
- if hn.Number != 0 {
- return fmt.Errorf("both origin hash (%x) and number (%d) provided", hn.Hash, hn.Number)
- }
- return rlp.Encode(w, hn.Hash)
-}
-
-// DecodeRLP is a specialized decoder for hashOrNumber to decode the contents
-// into either a block hash or a block number.
-func (hn *hashOrNumber) DecodeRLP(s *rlp.Stream) error {
- _, size, _ := s.Kind()
- origin, err := s.Raw()
- if err == nil {
- switch {
- case size == 32:
- err = rlp.DecodeBytes(origin, &hn.Hash)
- case size <= 8:
- err = rlp.DecodeBytes(origin, &hn.Number)
- default:
- err = fmt.Errorf("invalid input size %d for origin", size)
- }
- }
- return err
-}
-
-// newBlockData is the network packet for the block propagation message.
-type newBlockData struct {
- Block *types.Block
- TD *big.Int
-}
-
-// sanityCheck verifies that the values are reasonable, as a DoS protection
-func (request *newBlockData) sanityCheck() error {
- if err := request.Block.SanityCheck(); err != nil {
- return err
- }
- //TD at mainnet block #7753254 is 76 bits. If it becomes 100 million times
- // larger, it will still fit within 100 bits
- if tdlen := request.TD.BitLen(); tdlen > 100 {
- return fmt.Errorf("too large block TD: bitlen %d", tdlen)
- }
- return nil
-}
-
-// blockBody represents the data content of a single block.
-type blockBody struct {
- Transactions []*types.Transaction // Transactions contained within a block
- Uncles []*types.Header // Uncles contained within a block
-}
-
-// blockBodiesData is the network packet for block content distribution.
-type blockBodiesData []*blockBody
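The size-based discrimination in the deleted `DecodeRLP` above (32 bytes means
a hash, at most 8 bytes means a number) is easy to verify with a standalone
snippet against the public `rlp` package (illustrative, not part of this diff):

	package main

	import (
		"fmt"

		"github.com/ethereum/go-ethereum/rlp"
	)

	func main() {
		// A numeric origin encodes as a plain RLP integer (content size <= 8)...
		num, _ := rlp.EncodeToBytes(uint64(1024))
		fmt.Printf("number origin: %x\n", num) // 820400

		// ...while a hash origin encodes as a 32-byte RLP string (content size
		// == 32), which is how the union decoder told the two variants apart.
		hash := make([]byte, 32)
		hash[31] = 0x01
		enc, _ := rlp.EncodeToBytes(hash)
		fmt.Printf("hash origin: %x\n", enc)
	}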
diff --git a/eth/protocol_test.go b/eth/protocol_test.go
deleted file mode 100644
index 331dd05ce..000000000
--- a/eth/protocol_test.go
+++ /dev/null
@@ -1,459 +0,0 @@
-// Copyright 2014 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package eth
-
-import (
- "fmt"
- "math/big"
- "sync"
- "sync/atomic"
- "testing"
- "time"
-
- "github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/consensus/ethash"
- "github.com/ethereum/go-ethereum/core"
- "github.com/ethereum/go-ethereum/core/forkid"
- "github.com/ethereum/go-ethereum/core/rawdb"
- "github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/core/vm"
- "github.com/ethereum/go-ethereum/crypto"
- "github.com/ethereum/go-ethereum/eth/downloader"
- "github.com/ethereum/go-ethereum/event"
- "github.com/ethereum/go-ethereum/p2p"
- "github.com/ethereum/go-ethereum/p2p/enode"
- "github.com/ethereum/go-ethereum/params"
- "github.com/ethereum/go-ethereum/rlp"
-)
-
-func init() {
- // log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(false))))
-}
-
-var testAccount, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
-
-// Tests that handshake failures are detected and reported correctly.
-func TestStatusMsgErrors63(t *testing.T) {
- pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
- var (
- genesis = pm.blockchain.Genesis()
- head = pm.blockchain.CurrentHeader()
- td = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64())
- )
- defer pm.Stop()
-
- tests := []struct {
- code uint64
- data interface{}
- wantError error
- }{
- {
- code: TransactionMsg, data: []interface{}{},
- wantError: errResp(ErrNoStatusMsg, "first msg has code 2 (!= 0)"),
- },
- {
- code: StatusMsg, data: statusData63{10, DefaultConfig.NetworkId, td, head.Hash(), genesis.Hash()},
- wantError: errResp(ErrProtocolVersionMismatch, "10 (!= %d)", 63),
- },
- {
- code: StatusMsg, data: statusData63{63, 999, td, head.Hash(), genesis.Hash()},
- wantError: errResp(ErrNetworkIDMismatch, "999 (!= %d)", DefaultConfig.NetworkId),
- },
- {
- code: StatusMsg, data: statusData63{63, DefaultConfig.NetworkId, td, head.Hash(), common.Hash{3}},
- wantError: errResp(ErrGenesisMismatch, "0300000000000000 (!= %x)", genesis.Hash().Bytes()[:8]),
- },
- }
- for i, test := range tests {
- p, errc := newTestPeer("peer", 63, pm, false)
- // The send call might hang until reset because
- // the protocol might not read the payload.
- go p2p.Send(p.app, test.code, test.data)
-
- select {
- case err := <-errc:
- if err == nil {
- t.Errorf("test %d: protocol returned nil error, want %q", i, test.wantError)
- } else if err.Error() != test.wantError.Error() {
- t.Errorf("test %d: wrong error: got %q, want %q", i, err, test.wantError)
- }
- case <-time.After(2 * time.Second):
- t.Errorf("protocol did not shut down within 2 seconds")
- }
- p.close()
- }
-}
-
-func TestStatusMsgErrors64(t *testing.T) {
- pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
- var (
- genesis = pm.blockchain.Genesis()
- head = pm.blockchain.CurrentHeader()
- td = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64())
- forkID = forkid.NewID(pm.blockchain.Config(), pm.blockchain.Genesis().Hash(), pm.blockchain.CurrentHeader().Number.Uint64())
- )
- defer pm.Stop()
-
- tests := []struct {
- code uint64
- data interface{}
- wantError error
- }{
- {
- code: TransactionMsg, data: []interface{}{},
- wantError: errResp(ErrNoStatusMsg, "first msg has code 2 (!= 0)"),
- },
- {
- code: StatusMsg, data: statusData{10, DefaultConfig.NetworkId, td, head.Hash(), genesis.Hash(), forkID},
- wantError: errResp(ErrProtocolVersionMismatch, "10 (!= %d)", 64),
- },
- {
- code: StatusMsg, data: statusData{64, 999, td, head.Hash(), genesis.Hash(), forkID},
- wantError: errResp(ErrNetworkIDMismatch, "999 (!= %d)", DefaultConfig.NetworkId),
- },
- {
- code: StatusMsg, data: statusData{64, DefaultConfig.NetworkId, td, head.Hash(), common.Hash{3}, forkID},
- wantError: errResp(ErrGenesisMismatch, "0300000000000000000000000000000000000000000000000000000000000000 (!= %x)", genesis.Hash()),
- },
- {
- code: StatusMsg, data: statusData{64, DefaultConfig.NetworkId, td, head.Hash(), genesis.Hash(), forkid.ID{Hash: [4]byte{0x00, 0x01, 0x02, 0x03}}},
- wantError: errResp(ErrForkIDRejected, forkid.ErrLocalIncompatibleOrStale.Error()),
- },
- }
- for i, test := range tests {
- p, errc := newTestPeer("peer", 64, pm, false)
- // The send call might hang until reset because
- // the protocol might not read the payload.
- go p2p.Send(p.app, test.code, test.data)
-
- select {
- case err := <-errc:
- if err == nil {
- t.Errorf("test %d: protocol returned nil error, want %q", i, test.wantError)
- } else if err.Error() != test.wantError.Error() {
- t.Errorf("test %d: wrong error: got %q, want %q", i, err, test.wantError)
- }
- case <-time.After(2 * time.Second):
- t.Errorf("protocol did not shut down within 2 seconds")
- }
- p.close()
- }
-}
-
-func TestForkIDSplit(t *testing.T) {
- var (
- engine = ethash.NewFaker()
-
- configNoFork = &params.ChainConfig{HomesteadBlock: big.NewInt(1)}
- configProFork = &params.ChainConfig{
- HomesteadBlock: big.NewInt(1),
- EIP150Block: big.NewInt(2),
- EIP155Block: big.NewInt(2),
- EIP158Block: big.NewInt(2),
- ByzantiumBlock: big.NewInt(3),
- }
- dbNoFork = rawdb.NewMemoryDatabase()
- dbProFork = rawdb.NewMemoryDatabase()
-
- gspecNoFork = &core.Genesis{Config: configNoFork}
- gspecProFork = &core.Genesis{Config: configProFork}
-
- genesisNoFork = gspecNoFork.MustCommit(dbNoFork)
- genesisProFork = gspecProFork.MustCommit(dbProFork)
-
- chainNoFork, _ = core.NewBlockChain(dbNoFork, nil, configNoFork, engine, vm.Config{}, nil, nil)
- chainProFork, _ = core.NewBlockChain(dbProFork, nil, configProFork, engine, vm.Config{}, nil, nil)
-
- blocksNoFork, _ = core.GenerateChain(configNoFork, genesisNoFork, engine, dbNoFork, 2, nil)
- blocksProFork, _ = core.GenerateChain(configProFork, genesisProFork, engine, dbProFork, 2, nil)
-
- ethNoFork, _ = NewProtocolManager(configNoFork, nil, downloader.FullSync, 1, new(event.TypeMux), &testTxPool{pool: make(map[common.Hash]*types.Transaction)}, engine, chainNoFork, dbNoFork, 1, nil)
- ethProFork, _ = NewProtocolManager(configProFork, nil, downloader.FullSync, 1, new(event.TypeMux), &testTxPool{pool: make(map[common.Hash]*types.Transaction)}, engine, chainProFork, dbProFork, 1, nil)
- )
- ethNoFork.Start(1000)
- ethProFork.Start(1000)
-
- // Both nodes should allow the other to connect (same genesis, next fork is the same)
- p2pNoFork, p2pProFork := p2p.MsgPipe()
- peerNoFork := newPeer(64, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil)
- peerProFork := newPeer(64, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil)
-
- errc := make(chan error, 2)
- go func() { errc <- ethNoFork.handle(peerProFork) }()
- go func() { errc <- ethProFork.handle(peerNoFork) }()
-
- select {
- case err := <-errc:
- t.Fatalf("frontier nofork <-> profork failed: %v", err)
- case <-time.After(250 * time.Millisecond):
- p2pNoFork.Close()
- p2pProFork.Close()
- }
- // Progress into Homestead. Forks match, so we don't care what the future holds
- chainNoFork.InsertChain(blocksNoFork[:1])
- chainProFork.InsertChain(blocksProFork[:1])
-
- p2pNoFork, p2pProFork = p2p.MsgPipe()
- peerNoFork = newPeer(64, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil)
- peerProFork = newPeer(64, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil)
-
- errc = make(chan error, 2)
- go func() { errc <- ethNoFork.handle(peerProFork) }()
- go func() { errc <- ethProFork.handle(peerNoFork) }()
-
- select {
- case err := <-errc:
- t.Fatalf("homestead nofork <-> profork failed: %v", err)
- case <-time.After(250 * time.Millisecond):
- p2pNoFork.Close()
- p2pProFork.Close()
- }
- // Progress into Spurious. Forks mismatch, signalling differing chains, reject
- chainNoFork.InsertChain(blocksNoFork[1:2])
- chainProFork.InsertChain(blocksProFork[1:2])
-
- p2pNoFork, p2pProFork = p2p.MsgPipe()
- peerNoFork = newPeer(64, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil)
- peerProFork = newPeer(64, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil)
-
- errc = make(chan error, 2)
- go func() { errc <- ethNoFork.handle(peerProFork) }()
- go func() { errc <- ethProFork.handle(peerNoFork) }()
-
- select {
- case err := <-errc:
- if want := errResp(ErrForkIDRejected, forkid.ErrLocalIncompatibleOrStale.Error()); err.Error() != want.Error() {
- t.Fatalf("fork ID rejection error mismatch: have %v, want %v", err, want)
- }
- case <-time.After(250 * time.Millisecond):
- t.Fatalf("split peers not rejected")
- }
-}
-
-// This test checks that received transactions are added to the local pool.
-func TestRecvTransactions63(t *testing.T) { testRecvTransactions(t, 63) }
-func TestRecvTransactions64(t *testing.T) { testRecvTransactions(t, 64) }
-func TestRecvTransactions65(t *testing.T) { testRecvTransactions(t, 65) }
-
-func testRecvTransactions(t *testing.T, protocol int) {
- txAdded := make(chan []*types.Transaction)
- pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, txAdded)
- pm.acceptTxs = 1 // mark synced to accept transactions
- p, _ := newTestPeer("peer", protocol, pm, true)
- defer pm.Stop()
- defer p.close()
-
- tx := newTestTransaction(testAccount, 0, 0)
- if err := p2p.Send(p.app, TransactionMsg, []interface{}{tx}); err != nil {
- t.Fatalf("send error: %v", err)
- }
- select {
- case added := <-txAdded:
- if len(added) != 1 {
- t.Errorf("wrong number of added transactions: got %d, want 1", len(added))
- } else if added[0].Hash() != tx.Hash() {
- t.Errorf("added wrong tx hash: got %v, want %v", added[0].Hash(), tx.Hash())
- }
- case <-time.After(2 * time.Second):
- t.Errorf("no NewTxsEvent received within 2 seconds")
- }
-}
-
-// This test checks that pending transactions are sent.
-func TestSendTransactions63(t *testing.T) { testSendTransactions(t, 63) }
-func TestSendTransactions64(t *testing.T) { testSendTransactions(t, 64) }
-func TestSendTransactions65(t *testing.T) { testSendTransactions(t, 65) }
-
-func testSendTransactions(t *testing.T, protocol int) {
- pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
- defer pm.Stop()
-
- // Fill the pool with big transactions (use a subscription to wait until all
- // the transactions are announced to avoid spurious events causing extra
- // broadcasts).
- const txsize = txsyncPackSize / 10
- alltxs := make([]*types.Transaction, 100)
- for nonce := range alltxs {
- alltxs[nonce] = newTestTransaction(testAccount, uint64(nonce), txsize)
- }
- pm.txpool.AddRemotes(alltxs)
- time.Sleep(100 * time.Millisecond) // Wait until the new tx event gets out of the system (lame)
-
- // Connect several peers. They should all receive the pending transactions.
- var wg sync.WaitGroup
- checktxs := func(p *testPeer) {
- defer wg.Done()
- defer p.close()
- seen := make(map[common.Hash]bool)
- for _, tx := range alltxs {
- seen[tx.Hash()] = false
- }
- for n := 0; n < len(alltxs) && !t.Failed(); {
- var forAllHashes func(callback func(hash common.Hash))
- switch protocol {
- case 63:
- fallthrough
- case 64:
- msg, err := p.app.ReadMsg()
- if err != nil {
- t.Errorf("%v: read error: %v", p.Peer, err)
- continue
- } else if msg.Code != TransactionMsg {
- t.Errorf("%v: got code %d, want TxMsg", p.Peer, msg.Code)
- continue
- }
- var txs []*types.Transaction
- if err := msg.Decode(&txs); err != nil {
- t.Errorf("%v: %v", p.Peer, err)
- continue
- }
- forAllHashes = func(callback func(hash common.Hash)) {
- for _, tx := range txs {
- callback(tx.Hash())
- }
- }
- case 65:
- msg, err := p.app.ReadMsg()
- if err != nil {
- t.Errorf("%v: read error: %v", p.Peer, err)
- continue
- } else if msg.Code != NewPooledTransactionHashesMsg {
- t.Errorf("%v: got code %d, want NewPooledTransactionHashesMsg", p.Peer, msg.Code)
- continue
- }
- var hashes []common.Hash
- if err := msg.Decode(&hashes); err != nil {
- t.Errorf("%v: %v", p.Peer, err)
- continue
- }
- forAllHashes = func(callback func(hash common.Hash)) {
- for _, h := range hashes {
- callback(h)
- }
- }
- }
- forAllHashes(func(hash common.Hash) {
- seentx, want := seen[hash]
- if seentx {
- t.Errorf("%v: got tx more than once: %x", p.Peer, hash)
- }
- if !want {
- t.Errorf("%v: got unexpected tx: %x", p.Peer, hash)
- }
- seen[hash] = true
- n++
- })
- }
- }
- for i := 0; i < 3; i++ {
- p, _ := newTestPeer(fmt.Sprintf("peer #%d", i), protocol, pm, true)
- wg.Add(1)
- go checktxs(p)
- }
- wg.Wait()
-}
-
-func TestTransactionPropagation(t *testing.T) { testSyncTransaction(t, true) }
-func TestTransactionAnnouncement(t *testing.T) { testSyncTransaction(t, false) }
-
-func testSyncTransaction(t *testing.T, propagation bool) {
- // Create a protocol manager for transaction fetcher and sender
- pmFetcher, _ := newTestProtocolManagerMust(t, downloader.FastSync, 0, nil, nil)
- defer pmFetcher.Stop()
- pmSender, _ := newTestProtocolManagerMust(t, downloader.FastSync, 1024, nil, nil)
- pmSender.broadcastTxAnnouncesOnly = !propagation
- defer pmSender.Stop()
-
- // Sync up the two peers
- io1, io2 := p2p.MsgPipe()
-
- go pmSender.handle(pmSender.newPeer(65, p2p.NewPeer(enode.ID{}, "sender", nil), io2, pmSender.txpool.Get))
- go pmFetcher.handle(pmFetcher.newPeer(65, p2p.NewPeer(enode.ID{}, "fetcher", nil), io1, pmFetcher.txpool.Get))
-
- time.Sleep(250 * time.Millisecond)
- pmFetcher.doSync(peerToSyncOp(downloader.FullSync, pmFetcher.peers.BestPeer()))
- atomic.StoreUint32(&pmFetcher.acceptTxs, 1)
-
- newTxs := make(chan core.NewTxsEvent, 1024)
- sub := pmFetcher.txpool.SubscribeNewTxsEvent(newTxs)
- defer sub.Unsubscribe()
-
- // Fill the pool with new transactions
- alltxs := make([]*types.Transaction, 1024)
- for nonce := range alltxs {
- alltxs[nonce] = newTestTransaction(testAccount, uint64(nonce), 0)
- }
- pmSender.txpool.AddRemotes(alltxs)
-
- var got int
-loop:
- for {
- select {
- case ev := <-newTxs:
- got += len(ev.Txs)
- if got == 1024 {
- break loop
- }
- case <-time.NewTimer(time.Second).C:
- t.Fatal("Failed to retrieve all transaction")
- }
- }
-}
-
-// Tests that the custom union field encoder and decoder works correctly.
-func TestGetBlockHeadersDataEncodeDecode(t *testing.T) {
- // Create a "random" hash for testing
- var hash common.Hash
- for i := range hash {
- hash[i] = byte(i)
- }
- // Assemble some table driven tests
- tests := []struct {
- packet *getBlockHeadersData
- fail bool
- }{
- // Providing the origin as either a hash or a number should both work
- {fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Number: 314}}},
- {fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}}},
-
- // Providing arbitrary query field should also work
- {fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Number: 314}, Amount: 314, Skip: 1, Reverse: true}},
- {fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: 314, Skip: 1, Reverse: true}},
-
- // Providing both the origin hash and origin number must fail
- {fail: true, packet: &getBlockHeadersData{Origin: hashOrNumber{Hash: hash, Number: 314}}},
- }
- // Iterate over each of the tests and try to encode and then decode
- for i, tt := range tests {
- bytes, err := rlp.EncodeToBytes(tt.packet)
- if err != nil && !tt.fail {
- t.Fatalf("test %d: failed to encode packet: %v", i, err)
- } else if err == nil && tt.fail {
- t.Fatalf("test %d: encode should have failed", i)
- }
- if !tt.fail {
- packet := new(getBlockHeadersData)
- if err := rlp.DecodeBytes(bytes, packet); err != nil {
- t.Fatalf("test %d: failed to decode packet: %v", i, err)
- }
- if packet.Origin.Hash != tt.packet.Origin.Hash || packet.Origin.Number != tt.packet.Origin.Number || packet.Amount != tt.packet.Amount ||
- packet.Skip != tt.packet.Skip || packet.Reverse != tt.packet.Reverse {
- t.Fatalf("test %d: encode decode mismatch: have %+v, want %+v", i, packet, tt.packet)
- }
- }
- }
-}
diff --git a/eth/protocols/eth/broadcast.go b/eth/protocols/eth/broadcast.go
new file mode 100644
index 000000000..2349398fa
--- /dev/null
+++ b/eth/protocols/eth/broadcast.go
@@ -0,0 +1,195 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+ "math/big"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/types"
+)
+
+const (
+ // This is the target size for the packs of transactions or announcements. A
+ // pack can get larger than this if a single transaction exceeds this size.
+ maxTxPacketSize = 100 * 1024
+)
+
+// blockPropagation is a block propagation event, waiting for its turn in the
+// broadcast queue.
+type blockPropagation struct {
+ block *types.Block
+ td *big.Int
+}
+
+// broadcastBlocks is a write loop that multiplexes blocks and block announcements
+// to the remote peer. The goal is to have an async writer that does not lock up
+// node internals and at the same time rate limits queued data.
+func (p *Peer) broadcastBlocks() {
+ for {
+ select {
+ case prop := <-p.queuedBlocks:
+ if err := p.SendNewBlock(prop.block, prop.td); err != nil {
+ return
+ }
+ p.Log().Trace("Propagated block", "number", prop.block.Number(), "hash", prop.block.Hash(), "td", prop.td)
+
+ case block := <-p.queuedBlockAnns:
+ if err := p.SendNewBlockHashes([]common.Hash{block.Hash()}, []uint64{block.NumberU64()}); err != nil {
+ return
+ }
+ p.Log().Trace("Announced block", "number", block.Number(), "hash", block.Hash())
+
+ case <-p.term:
+ return
+ }
+ }
+}
+
+// broadcastTransactions is a write loop that schedules transaction broadcasts
+// to the remote peer. The goal is to have an async writer that does not lock up
+// node internals and at the same time rate limits queued data.
+func (p *Peer) broadcastTransactions() {
+ var (
+ queue []common.Hash // Queue of hashes to broadcast as full transactions
+ done chan struct{} // Non-nil if background broadcaster is running
+ fail = make(chan error, 1) // Channel used to receive network error
+ failed bool // Flag whether a send failed, discard everything onward
+ )
+ for {
+ // If there's no in-flight broadcast running, check if a new one is needed
+ if done == nil && len(queue) > 0 {
+ // Pile transactions until we reach our allowed network limit
+ var (
+ hashes []common.Hash
+ txs []*types.Transaction
+ size common.StorageSize
+ )
+ for i := 0; i < len(queue) && size < maxTxPacketSize; i++ {
+ if tx := p.txpool.Get(queue[i]); tx != nil {
+ txs = append(txs, tx)
+ size += tx.Size()
+ }
+ hashes = append(hashes, queue[i])
+ }
+ queue = queue[:copy(queue, queue[len(hashes):])]
+
+ // If there's anything available to transfer, fire up an async writer
+ if len(txs) > 0 {
+ done = make(chan struct{})
+ go func() {
+ if err := p.SendTransactions(txs); err != nil {
+ fail <- err
+ return
+ }
+ close(done)
+ p.Log().Trace("Sent transactions", "count", len(txs))
+ }()
+ }
+ }
+ // Transfer goroutine may or may not have been started, listen for events
+ select {
+ case hashes := <-p.txBroadcast:
+ // If the connection failed, discard all transaction events
+ if failed {
+ continue
+ }
+ // New batch of transactions to be broadcast, queue them (with cap)
+ queue = append(queue, hashes...)
+ if len(queue) > maxQueuedTxs {
+ // Fancy copy and resize to ensure buffer doesn't grow indefinitely
+ queue = queue[:copy(queue, queue[len(queue)-maxQueuedTxs:])]
+ }
+
+ case <-done:
+ done = nil
+
+ case <-fail:
+ failed = true
+
+ case <-p.term:
+ return
+ }
+ }
+}
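+
+// The queue trimming above leans on copy's return value. A minimal sketch,
+// assuming a cap of 3 and hashes h1..h5 already declared:
+//
+//	queue := []common.Hash{h1, h2, h3, h4, h5}
+//	queue = queue[:copy(queue, queue[len(queue)-3:])] // queue is now [h3 h4 h5]
+//
+// copy reports how many elements it moved, so the reslice keeps exactly the
+// newest entries in place without allocating a new backing array.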
+
+// announceTransactions is a write loop that schedules transaction announcements
+// to the remote peer. The goal is to have an async writer that does not lock up
+// node internals and at the same time rate limits queued data.
+func (p *Peer) announceTransactions() {
+ var (
+ queue []common.Hash // Queue of hashes to announce as transaction stubs
+ done chan struct{} // Non-nil if background announcer is running
+ fail = make(chan error, 1) // Channel used to receive network error
+ failed bool // Flag whether a send failed, discard everything onward
+ )
+ for {
+ // If there's no in-flight announce running, check if a new one is needed
+ if done == nil && len(queue) > 0 {
+ // Pile transaction hashes until we reach our allowed network limit
+ var (
+ hashes []common.Hash
+ pending []common.Hash
+ size common.StorageSize
+ )
+ for i := 0; i < len(queue) && size < maxTxPacketSize; i++ {
+ if p.txpool.Get(queue[i]) != nil {
+ pending = append(pending, queue[i])
+ size += common.HashLength
+ }
+ hashes = append(hashes, queue[i])
+ }
+ queue = queue[:copy(queue, queue[len(hashes):])]
+
+ // If there's anything available to transfer, fire up an async writer
+ if len(pending) > 0 {
+ done = make(chan struct{})
+ go func() {
+ if err := p.sendPooledTransactionHashes(pending); err != nil {
+ fail <- err
+ return
+ }
+ close(done)
+ p.Log().Trace("Sent transaction announcements", "count", len(pending))
+ }()
+ }
+ }
+ // Transfer goroutine may or may not have been started, listen for events
+ select {
+ case hashes := <-p.txAnnounce:
+ // If the connection failed, discard all transaction events
+ if failed {
+ continue
+ }
+ // New batch of transactions to be broadcast, queue them (with cap)
+ queue = append(queue, hashes...)
+ if len(queue) > maxQueuedTxAnns {
+ // Fancy copy and resize to ensure buffer doesn't grow indefinitely
+ queue = queue[:copy(queue, queue[len(queue)-maxQueuedTxAnns:])]
+ }
+
+ case <-done:
+ done = nil
+
+ case <-fail:
+ failed = true
+
+ case <-p.term:
+ return
+ }
+ }
+}
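+
+// Note: broadcastTransactions and announceTransactions are deliberately
+// symmetric. The former ships full transaction bodies (the pre-eth/65
+// behaviour), the latter only 32-byte hashes, letting the remote side pull
+// any bodies it is missing via GetPooledTransactionsMsg.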
diff --git a/eth/protocols/eth/discovery.go b/eth/protocols/eth/discovery.go
new file mode 100644
index 000000000..025479b42
--- /dev/null
+++ b/eth/protocols/eth/discovery.go
@@ -0,0 +1,65 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/forkid"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+// enrEntry is the ENR entry which advertises `eth` protocol on the discovery.
+type enrEntry struct {
+ ForkID forkid.ID // Fork identifier per EIP-2124
+
+ // Ignore additional fields (for forward compatibility).
+ Rest []rlp.RawValue `rlp:"tail"`
+}
+
+// ENRKey implements enr.Entry.
+func (e enrEntry) ENRKey() string {
+ return "eth"
+}
+
+// StartENRUpdater starts the `eth` ENR updater loop, which listens for chain
+// head events and updates the requested node record whenever a fork is passed.
+func StartENRUpdater(chain *core.BlockChain, ln *enode.LocalNode) {
+ var newHead = make(chan core.ChainHeadEvent, 10)
+ sub := chain.SubscribeChainHeadEvent(newHead)
+
+ go func() {
+ defer sub.Unsubscribe()
+ for {
+ select {
+ case <-newHead:
+ ln.Set(currentENREntry(chain))
+ case <-sub.Err():
+ // Would be nice to sync with Stop, but there is no
+ // good way to do that.
+ return
+ }
+ }
+ }()
+}
+
+// currentENREntry constructs an `eth` ENR entry based on the current state of the chain.
+func currentENREntry(chain *core.BlockChain) *enrEntry {
+ return &enrEntry{
+ ForkID: forkid.NewID(chain.Config(), chain.Genesis().Hash(), chain.CurrentHeader().Number.Uint64()),
+ }
+}
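+
+// A minimal sketch of reading the entry back off a discovered node (assumes
+// n is an *enode.Node handed out by a discovery iterator):
+//
+//	var entry enrEntry
+//	if err := n.Record().Load(&entry); err == nil {
+//		// entry.ForkID can now be checked against the local chain's fork ID.
+//	}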
diff --git a/eth/protocols/eth/handler.go b/eth/protocols/eth/handler.go
new file mode 100644
index 000000000..25ddcd93e
--- /dev/null
+++ b/eth/protocols/eth/handler.go
@@ -0,0 +1,512 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+ "encoding/json"
+ "fmt"
+ "math/big"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/p2p/enr"
+ "github.com/ethereum/go-ethereum/params"
+ "github.com/ethereum/go-ethereum/rlp"
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+const (
+ // softResponseLimit is the target maximum size of replies to data retrievals.
+ softResponseLimit = 2 * 1024 * 1024
+
+ // estHeaderSize is the approximate size of an RLP encoded block header.
+ estHeaderSize = 500
+
+ // maxHeadersServe is the maximum number of block headers to serve. This number
+ // is there to limit the number of disk lookups.
+ maxHeadersServe = 1024
+
+ // maxBodiesServe is the maximum number of block bodies to serve. This number
+ // is mostly there to limit the number of disk lookups. With 24KB block sizes
+ // nowadays, the practical limit will always be softResponseLimit.
+ maxBodiesServe = 1024
+
+ // maxNodeDataServe is the maximum number of state trie nodes to serve. This
+ // number is there to limit the number of disk lookups.
+ maxNodeDataServe = 1024
+
+ // maxReceiptsServe is the maximum number of block receipts to serve. This
+ // number is mostly there to limit the number of disk lookups. With block
+ // containing 200+ transactions nowadays, the practical limit will always
+ // be softResponseLimit.
+ maxReceiptsServe = 1024
+)
+
+// Handler is a callback to invoke from an outside runner after the boilerplate
+// exchanges have passed.
+type Handler func(peer *Peer) error
+
+// Backend defines the data retrieval methods to serve remote requests and the
+// callback methods to invoke on remote deliveries.
+type Backend interface {
+ // Chain retrieves the blockchain object to serve data.
+ Chain() *core.BlockChain
+
+ // StateBloom retrieves the bloom filter - if any - for state trie nodes.
+ StateBloom() *trie.SyncBloom
+
+ // TxPool retrieves the transaction pool object to serve data.
+ TxPool() TxPool
+
+ // AcceptTxs retrieves whether transaction processing is enabled on the node
+ // or if inbound transactions should simply be dropped.
+ AcceptTxs() bool
+
+ // RunPeer is invoked when a peer joins on the `eth` protocol. The handler
+ // should do any peer maintenance work, handshakes and validations. If all
+ // is passed, control should be given back to the `handler` to process the
+ // inbound messages going forward.
+ RunPeer(peer *Peer, handler Handler) error
+
+ // PeerInfo retrieves all known `eth` information about a peer.
+ PeerInfo(id enode.ID) interface{}
+
+ // Handle is a callback to be invoked when a data packet is received from
+ // the remote peer. Only packets not consumed by the protocol handler will
+ // be forwarded to the backend.
+ Handle(peer *Peer, packet Packet) error
+}
+
+// TxPool defines the methods needed by the protocol handler to serve transactions.
+type TxPool interface {
+ // Get retrieves the transaction from the local txpool with the given hash.
+ Get(hash common.Hash) *types.Transaction
+}
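+
+// A throwaway TxPool sketch for illustration (hypothetical, not part of this
+// package; assumes single-goroutine use):
+//
+//	type mapPool map[common.Hash]*types.Transaction
+//
+//	func (p mapPool) Get(hash common.Hash) *types.Transaction { return p[hash] }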
+
+// MakeProtocols constructs the P2P protocol definitions for `eth`.
+func MakeProtocols(backend Backend, network uint64, dnsdisc enode.Iterator) []p2p.Protocol {
+ protocols := make([]p2p.Protocol, len(protocolVersions))
+ for i, version := range protocolVersions {
+ version := version // Closure
+
+ protocols[i] = p2p.Protocol{
+ Name: protocolName,
+ Version: version,
+ Length: protocolLengths[version],
+ Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
+ peer := NewPeer(version, p, rw, backend.TxPool())
+ defer peer.Close()
+
+ return backend.RunPeer(peer, func(peer *Peer) error {
+ return Handle(backend, peer)
+ })
+ },
+ NodeInfo: func() interface{} {
+ return nodeInfo(backend.Chain(), network)
+ },
+ PeerInfo: func(id enode.ID) interface{} {
+ return backend.PeerInfo(id)
+ },
+ Attributes: []enr.Entry{currentENREntry(backend.Chain())},
+ DialCandidates: dnsdisc,
+ }
+ }
+ return protocols
+}
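+
+// Wiring sketch (a hedged example, assuming srv is a configured p2p.Server,
+// backend implements the Backend interface above, iter is a dnsdisc-backed
+// enode.Iterator and the network id is 1):
+//
+//	srv.Protocols = append(srv.Protocols, MakeProtocols(backend, 1, iter)...)
+//	if err := srv.Start(); err != nil {
+//		log.Crit("Failed to start p2p server", "err", err)
+//	}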
+
+// NodeInfo represents a short summary of the `eth` sub-protocol metadata
+// known about the host peer.
+type NodeInfo struct {
+ Network uint64 `json:"network"` // Ethereum network ID (1=Frontier, 2=Morden, Ropsten=3, Rinkeby=4)
+ Difficulty *big.Int `json:"difficulty"` // Total difficulty of the host's blockchain
+ Genesis common.Hash `json:"genesis"` // SHA3 hash of the host's genesis block
+ Config *params.ChainConfig `json:"config"` // Chain configuration for the fork rules
+ Head common.Hash `json:"head"` // Hex hash of the host's best owned block
+}
+
+// nodeInfo retrieves some `eth` protocol metadata about the running host node.
+func nodeInfo(chain *core.BlockChain, network uint64) *NodeInfo {
+ head := chain.CurrentBlock()
+ return &NodeInfo{
+ Network: network,
+ Difficulty: chain.GetTd(head.Hash(), head.NumberU64()),
+ Genesis: chain.Genesis().Hash(),
+ Config: chain.Config(),
+ Head: head.Hash(),
+ }
+}
+
+// Handle is invoked whenever an `eth` connection is made that successfully passes
+// the protocol handshake. This method will keep processing messages until the
+// connection is torn down.
+func Handle(backend Backend, peer *Peer) error {
+ for {
+ if err := handleMessage(backend, peer); err != nil {
+ peer.Log().Debug("Message handling failed in `eth`", "err", err)
+ return err
+ }
+ }
+}
+
+// handleMessage is invoked whenever an inbound message is received from a remote
+// peer. The remote connection is torn down upon returning any error.
+func handleMessage(backend Backend, peer *Peer) error {
+ // Read the next message from the remote peer, and ensure it's fully consumed
+ msg, err := peer.rw.ReadMsg()
+ if err != nil {
+ return err
+ }
+ if msg.Size > maxMessageSize {
+ return fmt.Errorf("%w: %v > %v", errMsgTooLarge, msg.Size, maxMessageSize)
+ }
+ defer msg.Discard()
+
+ // Handle the message depending on its contents
+ switch {
+ case msg.Code == StatusMsg:
+ // Status messages should never arrive after the handshake
+ return fmt.Errorf("%w: uncontrolled status message", errExtraStatusMsg)
+
+ // Block header query, collect the requested headers and reply
+ case msg.Code == GetBlockHeadersMsg:
+ // Decode the complex header query
+ var query GetBlockHeadersPacket
+ if err := msg.Decode(&query); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ hashMode := query.Origin.Hash != (common.Hash{})
+ first := true
+ maxNonCanonical := uint64(100)
+
+ // Gather headers until the fetch or network limits are reached
+ var (
+ bytes common.StorageSize
+ headers []*types.Header
+ unknown bool
+ lookups int
+ )
+ for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit &&
+ len(headers) < maxHeadersServe && lookups < 2*maxHeadersServe {
+ lookups++
+ // Retrieve the next header satisfying the query
+ var origin *types.Header
+ if hashMode {
+ if first {
+ first = false
+ origin = backend.Chain().GetHeaderByHash(query.Origin.Hash)
+ if origin != nil {
+ query.Origin.Number = origin.Number.Uint64()
+ }
+ } else {
+ origin = backend.Chain().GetHeader(query.Origin.Hash, query.Origin.Number)
+ }
+ } else {
+ origin = backend.Chain().GetHeaderByNumber(query.Origin.Number)
+ }
+ if origin == nil {
+ break
+ }
+ headers = append(headers, origin)
+ bytes += estHeaderSize
+
+ // Advance to the next header of the query
+ switch {
+ case hashMode && query.Reverse:
+ // Hash based traversal towards the genesis block
+ ancestor := query.Skip + 1
+ if ancestor == 0 {
+ unknown = true
+ } else {
+ query.Origin.Hash, query.Origin.Number = backend.Chain().GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
+ unknown = (query.Origin.Hash == common.Hash{})
+ }
+ case hashMode && !query.Reverse:
+ // Hash based traversal towards the leaf block
+ var (
+ current = origin.Number.Uint64()
+ next = current + query.Skip + 1
+ )
+ if next <= current {
+ infos, _ := json.MarshalIndent(peer.Peer.Info(), "", " ")
+ peer.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos)
+ unknown = true
+ } else {
+ if header := backend.Chain().GetHeaderByNumber(next); header != nil {
+ nextHash := header.Hash()
+ expOldHash, _ := backend.Chain().GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical)
+ if expOldHash == query.Origin.Hash {
+ query.Origin.Hash, query.Origin.Number = nextHash, next
+ } else {
+ unknown = true
+ }
+ } else {
+ unknown = true
+ }
+ }
+ case query.Reverse:
+ // Number based traversal towards the genesis block
+ if query.Origin.Number >= query.Skip+1 {
+ query.Origin.Number -= query.Skip + 1
+ } else {
+ unknown = true
+ }
+
+ case !query.Reverse:
+ // Number based traversal towards the leaf block
+ query.Origin.Number += query.Skip + 1
+ }
+ }
+ return peer.SendBlockHeaders(headers)
+
+ case msg.Code == BlockHeadersMsg:
+ // A batch of headers arrived to one of our previous requests
+ res := new(BlockHeadersPacket)
+ if err := msg.Decode(res); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ return backend.Handle(peer, res)
+
+ case msg.Code == GetBlockBodiesMsg:
+ // Decode the block body retrieval message
+ var query GetBlockBodiesPacket
+ if err := msg.Decode(&query); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ // Gather blocks until the fetch or network limits are reached
+ var (
+ bytes int
+ bodies []rlp.RawValue
+ )
+ for lookups, hash := range query {
+ if bytes >= softResponseLimit || len(bodies) >= maxBodiesServe ||
+ lookups >= 2*maxBodiesServe {
+ break
+ }
+ if data := backend.Chain().GetBodyRLP(hash); len(data) != 0 {
+ bodies = append(bodies, data)
+ bytes += len(data)
+ }
+ }
+ return peer.SendBlockBodiesRLP(bodies)
+
+ case msg.Code == BlockBodiesMsg:
+ // A batch of block bodies arrived to one of our previous requests
+ res := new(BlockBodiesPacket)
+ if err := msg.Decode(res); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ return backend.Handle(peer, res)
+
+ case msg.Code == GetNodeDataMsg:
+ // Decode the trie node data retrieval message
+ var query GetNodeDataPacket
+ if err := msg.Decode(&query); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ // Gather state data until the fetch or network limits are reached
+ var (
+ bytes int
+ nodes [][]byte
+ )
+ for lookups, hash := range query {
+ if bytes >= softResponseLimit || len(nodes) >= maxNodeDataServe ||
+ lookups >= 2*maxNodeDataServe {
+ break
+ }
+ // Retrieve the requested state entry
+ if bloom := backend.StateBloom(); bloom != nil && !bloom.Contains(hash[:]) {
+ // Only lookup the trie node if there's chance that we actually have it
+ continue
+ }
+ entry, err := backend.Chain().TrieNode(hash)
+ if len(entry) == 0 || err != nil {
+ // Read the contract code with prefix only to save unnecessary lookups.
+ entry, err = backend.Chain().ContractCodeWithPrefix(hash)
+ }
+ if err == nil && len(entry) > 0 {
+ nodes = append(nodes, entry)
+ bytes += len(entry)
+ }
+ }
+ return peer.SendNodeData(nodes)
+
+ case msg.Code == NodeDataMsg:
+ // A batch of node state data arrived to one of our previous requests
+ res := new(NodeDataPacket)
+ if err := msg.Decode(res); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ return backend.Handle(peer, res)
+
+ case msg.Code == GetReceiptsMsg:
+ // Decode the block receipts retrieval message
+ var query GetReceiptsPacket
+ if err := msg.Decode(&query); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ // Gather receipts until the fetch or network limits are reached
+ var (
+ bytes int
+ receipts []rlp.RawValue
+ )
+ for lookups, hash := range query {
+ if bytes >= softResponseLimit || len(receipts) >= maxReceiptsServe ||
+ lookups >= 2*maxReceiptsServe {
+ break
+ }
+ // Retrieve the requested block's receipts
+ results := backend.Chain().GetReceiptsByHash(hash)
+ if results == nil {
+ if header := backend.Chain().GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
+ continue
+ }
+ }
+ // If known, encode and queue for response packet
+ if encoded, err := rlp.EncodeToBytes(results); err != nil {
+ log.Error("Failed to encode receipt", "err", err)
+ } else {
+ receipts = append(receipts, encoded)
+ bytes += len(encoded)
+ }
+ }
+ return peer.SendReceiptsRLP(receipts)
+
+ case msg.Code == ReceiptsMsg:
+ // A batch of receipts arrived to one of our previous requests
+ res := new(ReceiptsPacket)
+ if err := msg.Decode(res); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ return backend.Handle(peer, res)
+
+ case msg.Code == NewBlockHashesMsg:
+ // A batch of new block announcements just arrived
+ ann := new(NewBlockHashesPacket)
+ if err := msg.Decode(ann); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ // Mark the hashes as present at the remote node
+ for _, block := range *ann {
+ peer.markBlock(block.Hash)
+ }
+ // Deliver them all to the backend for queuing
+ return backend.Handle(peer, ann)
+
+ case msg.Code == NewBlockMsg:
+ // Retrieve and decode the propagated block
+ ann := new(NewBlockPacket)
+ if err := msg.Decode(ann); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ if hash := types.CalcUncleHash(ann.Block.Uncles()); hash != ann.Block.UncleHash() {
+ log.Warn("Propagated block has invalid uncles", "have", hash, "exp", ann.Block.UncleHash())
+ break // TODO(karalabe): return error eventually, but wait a few releases
+ }
+ if hash := types.DeriveSha(ann.Block.Transactions(), trie.NewStackTrie(nil)); hash != ann.Block.TxHash() {
+ log.Warn("Propagated block has invalid body", "have", hash, "exp", ann.Block.TxHash())
+ break // TODO(karalabe): return error eventually, but wait a few releases
+ }
+ if err := ann.sanityCheck(); err != nil {
+ return err
+ }
+ ann.Block.ReceivedAt = msg.ReceivedAt
+ ann.Block.ReceivedFrom = peer
+
+ // Mark the peer as owning the block
+ peer.markBlock(ann.Block.Hash())
+
+ return backend.Handle(peer, ann)
+
+ case msg.Code == NewPooledTransactionHashesMsg && peer.version >= ETH65:
+ // New transaction announcement arrived, make sure we have
+ // a valid and fresh chain to handle them
+ if !backend.AcceptTxs() {
+ break
+ }
+ ann := new(NewPooledTransactionHashesPacket)
+ if err := msg.Decode(ann); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ // Schedule all the unknown hashes for retrieval
+ for _, hash := range *ann {
+ peer.markTransaction(hash)
+ }
+ return backend.Handle(peer, ann)
+
+ case msg.Code == GetPooledTransactionsMsg && peer.version >= ETH65:
+ // Decode the pooled transactions retrieval message
+ var query GetPooledTransactionsPacket
+ if err := msg.Decode(&query); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ // Gather transactions until the fetch or network limits are reached
+ var (
+ bytes int
+ hashes []common.Hash
+ txs []rlp.RawValue
+ )
+ for _, hash := range query {
+ if bytes >= softResponseLimit {
+ break
+ }
+ // Retrieve the requested transaction, skipping if unknown to us
+ tx := backend.TxPool().Get(hash)
+ if tx == nil {
+ continue
+ }
+ // If known, encode and queue for response packet
+ if encoded, err := rlp.EncodeToBytes(tx); err != nil {
+ log.Error("Failed to encode transaction", "err", err)
+ } else {
+ hashes = append(hashes, hash)
+ txs = append(txs, encoded)
+ bytes += len(encoded)
+ }
+ }
+ return peer.SendPooledTransactionsRLP(hashes, txs)
+
+ case msg.Code == TransactionsMsg || (msg.Code == PooledTransactionsMsg && peer.version >= ETH65):
+ // Transactions arrived, make sure we have a valid and fresh chain to handle them
+ if !backend.AcceptTxs() {
+ break
+ }
+ // Transactions can be processed, parse all of them and deliver to the pool
+ var txs []*types.Transaction
+ if err := msg.Decode(&txs); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ for i, tx := range txs {
+ // Validate and mark the remote transaction
+ if tx == nil {
+ return fmt.Errorf("%w: transaction %d is nil", errDecode, i)
+ }
+ peer.markTransaction(tx.Hash())
+ }
+ if msg.Code == PooledTransactionsMsg {
+ return backend.Handle(peer, (*PooledTransactionsPacket)(&txs))
+ }
+ return backend.Handle(peer, (*TransactionsPacket)(&txs))
+
+ default:
+ return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code)
+ }
+ return nil
+}
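+
+// Skip arithmetic, by example: a forward query {Origin: 10, Amount: 3, Skip: 2}
+// serves headers 10, 13 and 16, since each step advances Skip+1 blocks. The
+// overflow guard above (next <= current) is what rejects crafted Skip values
+// such as math.MaxUint64 that would wrap the uint64 addition.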
diff --git a/eth/protocols/eth/handler_test.go b/eth/protocols/eth/handler_test.go
new file mode 100644
index 000000000..65c4a10b0
--- /dev/null
+++ b/eth/protocols/eth/handler_test.go
@@ -0,0 +1,519 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+ "math"
+ "math/big"
+ "math/rand"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/consensus/ethash"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/state"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/core/vm"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/params"
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+var (
+ // testKey is a private key to use for funding a tester account.
+ testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
+
+ // testAddr is the Ethereum address of the tester account.
+ testAddr = crypto.PubkeyToAddress(testKey.PublicKey)
+)
+
+// testBackend is a mock implementation of the live Ethereum message handler. Its
+// purpose is to allow testing the request/reply workflows and wire serialization
+// in the `eth` protocol without actually doing any data processing.
+type testBackend struct {
+ db ethdb.Database
+ chain *core.BlockChain
+ txpool *core.TxPool
+}
+
+// newTestBackend creates an empty chain and wraps it into a mock backend.
+func newTestBackend(blocks int) *testBackend {
+ return newTestBackendWithGenerator(blocks, nil)
+}
+
+// newTestBackendWithGenerator creates a chain with a number of explicitly defined blocks and
+// wraps it into a mock backend.
+func newTestBackendWithGenerator(blocks int, generator func(int, *core.BlockGen)) *testBackend {
+ // Create a database pre-initialized with a genesis block
+ db := rawdb.NewMemoryDatabase()
+ (&core.Genesis{
+ Config: params.TestChainConfig,
+ Alloc: core.GenesisAlloc{testAddr: {Balance: big.NewInt(1000000)}},
+ }).MustCommit(db)
+
+ chain, _ := core.NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{}, nil, nil)
+
+ bs, _ := core.GenerateChain(params.TestChainConfig, chain.Genesis(), ethash.NewFaker(), db, blocks, generator)
+ if _, err := chain.InsertChain(bs); err != nil {
+ panic(err)
+ }
+ txconfig := core.DefaultTxPoolConfig
+ txconfig.Journal = "" // Don't litter the disk with test journals
+
+ return &testBackend{
+ db: db,
+ chain: chain,
+ txpool: core.NewTxPool(txconfig, params.TestChainConfig, chain),
+ }
+}
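+
+// Usage sketch (assumption: called from a test body):
+//
+//	backend := newTestBackendWithGenerator(8, func(i int, g *core.BlockGen) {
+//		g.SetExtra([]byte("test"))
+//	})
+//	defer backend.close()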
+
+// close tears down the transaction pool and chain behind the mock backend.
+func (b *testBackend) close() {
+ b.txpool.Stop()
+ b.chain.Stop()
+}
+
+func (b *testBackend) Chain() *core.BlockChain { return b.chain }
+func (b *testBackend) StateBloom() *trie.SyncBloom { return nil }
+func (b *testBackend) TxPool() TxPool { return b.txpool }
+
+func (b *testBackend) RunPeer(peer *Peer, handler Handler) error {
+ // Normally the backend would do peer maintenance and handshakes. All that
+ // is omitted and we will just give control back to the handler.
+ return handler(peer)
+}
+func (b *testBackend) PeerInfo(enode.ID) interface{} { panic("not implemented") }
+
+func (b *testBackend) AcceptTxs() bool {
+ panic("data processing tests should be done in the handler package")
+}
+func (b *testBackend) Handle(*Peer, Packet) error {
+ panic("data processing tests should be done in the handler package")
+}
+
+// Tests that block headers can be retrieved from a remote chain based on user queries.
+func TestGetBlockHeaders64(t *testing.T) { testGetBlockHeaders(t, 64) }
+func TestGetBlockHeaders65(t *testing.T) { testGetBlockHeaders(t, 65) }
+
+func testGetBlockHeaders(t *testing.T, protocol uint) {
+ t.Parallel()
+
+ backend := newTestBackend(maxHeadersServe + 15)
+ defer backend.close()
+
+ peer, _ := newTestPeer("peer", protocol, backend)
+ defer peer.close()
+
+ // Create a "random" unknown hash for testing
+ var unknown common.Hash
+ for i := range unknown {
+ unknown[i] = byte(i)
+ }
+ // Create a batch of tests for various scenarios
+ limit := uint64(maxHeadersServe)
+ tests := []struct {
+ query *GetBlockHeadersPacket // The query to execute for header retrieval
+ expect []common.Hash // The hashes of the block whose headers are expected
+ }{
+ // A single random block should be retrievable by hash and number too
+ {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Hash: backend.chain.GetBlockByNumber(limit / 2).Hash()}, Amount: 1},
+ []common.Hash{backend.chain.GetBlockByNumber(limit / 2).Hash()},
+ }, {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Number: limit / 2}, Amount: 1},
+ []common.Hash{backend.chain.GetBlockByNumber(limit / 2).Hash()},
+ },
+ // Multiple headers should be retrievable in both directions
+ {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Number: limit / 2}, Amount: 3},
+ []common.Hash{
+ backend.chain.GetBlockByNumber(limit / 2).Hash(),
+ backend.chain.GetBlockByNumber(limit/2 + 1).Hash(),
+ backend.chain.GetBlockByNumber(limit/2 + 2).Hash(),
+ },
+ }, {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Number: limit / 2}, Amount: 3, Reverse: true},
+ []common.Hash{
+ backend.chain.GetBlockByNumber(limit / 2).Hash(),
+ backend.chain.GetBlockByNumber(limit/2 - 1).Hash(),
+ backend.chain.GetBlockByNumber(limit/2 - 2).Hash(),
+ },
+ },
+ // Multiple headers with skip lists should be retrievable
+ {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3},
+ []common.Hash{
+ backend.chain.GetBlockByNumber(limit / 2).Hash(),
+ backend.chain.GetBlockByNumber(limit/2 + 4).Hash(),
+ backend.chain.GetBlockByNumber(limit/2 + 8).Hash(),
+ },
+ }, {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3, Reverse: true},
+ []common.Hash{
+ backend.chain.GetBlockByNumber(limit / 2).Hash(),
+ backend.chain.GetBlockByNumber(limit/2 - 4).Hash(),
+ backend.chain.GetBlockByNumber(limit/2 - 8).Hash(),
+ },
+ },
+ // The chain endpoints should be retrievable
+ {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Number: 0}, Amount: 1},
+ []common.Hash{backend.chain.GetBlockByNumber(0).Hash()},
+ }, {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Number: backend.chain.CurrentBlock().NumberU64()}, Amount: 1},
+ []common.Hash{backend.chain.CurrentBlock().Hash()},
+ },
+ // Ensure protocol limits are honored
+ {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Number: backend.chain.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true},
+ backend.chain.GetBlockHashesFromHash(backend.chain.CurrentBlock().Hash(), limit),
+ },
+ // Check that requesting more than available is handled gracefully
+ {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Number: backend.chain.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3},
+ []common.Hash{
+ backend.chain.GetBlockByNumber(backend.chain.CurrentBlock().NumberU64() - 4).Hash(),
+ backend.chain.GetBlockByNumber(backend.chain.CurrentBlock().NumberU64()).Hash(),
+ },
+ }, {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Number: 4}, Skip: 3, Amount: 3, Reverse: true},
+ []common.Hash{
+ backend.chain.GetBlockByNumber(4).Hash(),
+ backend.chain.GetBlockByNumber(0).Hash(),
+ },
+ },
+ // Check that requesting more than available is handled gracefully, even if mid skip
+ {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Number: backend.chain.CurrentBlock().NumberU64() - 4}, Skip: 2, Amount: 3},
+ []common.Hash{
+ backend.chain.GetBlockByNumber(backend.chain.CurrentBlock().NumberU64() - 4).Hash(),
+ backend.chain.GetBlockByNumber(backend.chain.CurrentBlock().NumberU64() - 1).Hash(),
+ },
+ }, {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Number: 4}, Skip: 2, Amount: 3, Reverse: true},
+ []common.Hash{
+ backend.chain.GetBlockByNumber(4).Hash(),
+ backend.chain.GetBlockByNumber(1).Hash(),
+ },
+ },
+ // Check a corner case where requesting more can iterate past the endpoints
+ {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Number: 2}, Amount: 5, Reverse: true},
+ []common.Hash{
+ backend.chain.GetBlockByNumber(2).Hash(),
+ backend.chain.GetBlockByNumber(1).Hash(),
+ backend.chain.GetBlockByNumber(0).Hash(),
+ },
+ },
+ // Check a corner case where skipping overflow loops back into the chain start
+ {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Hash: backend.chain.GetBlockByNumber(3).Hash()}, Amount: 2, Reverse: false, Skip: math.MaxUint64 - 1},
+ []common.Hash{
+ backend.chain.GetBlockByNumber(3).Hash(),
+ },
+ },
+ // Check a corner case where skipping overflow loops back to the same header
+ {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Hash: backend.chain.GetBlockByNumber(1).Hash()}, Amount: 2, Reverse: false, Skip: math.MaxUint64},
+ []common.Hash{
+ backend.chain.GetBlockByNumber(1).Hash(),
+ },
+ },
+ // Check that non existing headers aren't returned
+ {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Hash: unknown}, Amount: 1},
+ []common.Hash{},
+ }, {
+ &GetBlockHeadersPacket{Origin: HashOrNumber{Number: backend.chain.CurrentBlock().NumberU64() + 1}, Amount: 1},
+ []common.Hash{},
+ },
+ }
+ // Run each of the tests and verify the results against the chain
+ for i, tt := range tests {
+ // Collect the headers to expect in the response
+ var headers []*types.Header
+ for _, hash := range tt.expect {
+ headers = append(headers, backend.chain.GetBlockByHash(hash).Header())
+ }
+ // Send the hash request and verify the response
+ p2p.Send(peer.app, 0x03, tt.query)
+ if err := p2p.ExpectMsg(peer.app, 0x04, headers); err != nil {
+ t.Errorf("test %d: headers mismatch: %v", i, err)
+ }
+ // If the test used number origins, repeat with hashes as well
+ if tt.query.Origin.Hash == (common.Hash{}) {
+ if origin := backend.chain.GetBlockByNumber(tt.query.Origin.Number); origin != nil {
+ tt.query.Origin.Hash, tt.query.Origin.Number = origin.Hash(), 0
+
+ p2p.Send(peer.app, 0x03, tt.query)
+ if err := p2p.ExpectMsg(peer.app, 0x04, headers); err != nil {
+ t.Errorf("test %d: headers mismatch: %v", i, err)
+ }
+ }
+ }
+ }
+}
+
+// Tests that block contents can be retrieved from a remote chain based on their hashes.
+func TestGetBlockBodies64(t *testing.T) { testGetBlockBodies(t, 64) }
+func TestGetBlockBodies65(t *testing.T) { testGetBlockBodies(t, 65) }
+
+func testGetBlockBodies(t *testing.T, protocol uint) {
+ t.Parallel()
+
+ backend := newTestBackend(maxBodiesServe + 15)
+ defer backend.close()
+
+ peer, _ := newTestPeer("peer", protocol, backend)
+ defer peer.close()
+
+ // Create a batch of tests for various scenarios
+ limit := maxBodiesServe
+ tests := []struct {
+ random int // Number of blocks to fetch randomly from the chain
+ explicit []common.Hash // Explicitly requested blocks
+ available []bool // Availability of explicitly requested blocks
+ expected int // Total number of existing blocks to expect
+ }{
+ {1, nil, nil, 1}, // A single random block should be retrievable
+ {10, nil, nil, 10}, // Multiple random blocks should be retrievable
+ {limit, nil, nil, limit}, // The maximum possible blocks should be retrievable
+ {limit + 1, nil, nil, limit}, // No more than the possible block count should be returned
+ {0, []common.Hash{backend.chain.Genesis().Hash()}, []bool{true}, 1}, // The genesis block should be retrievable
+ {0, []common.Hash{backend.chain.CurrentBlock().Hash()}, []bool{true}, 1}, // The chains head block should be retrievable
+ {0, []common.Hash{{}}, []bool{false}, 0}, // A non existent block should not be returned
+
+ // Existing and non-existing blocks interleaved should not cause problems
+ {0, []common.Hash{
+ {},
+ backend.chain.GetBlockByNumber(1).Hash(),
+ {},
+ backend.chain.GetBlockByNumber(10).Hash(),
+ {},
+ backend.chain.GetBlockByNumber(100).Hash(),
+ {},
+ }, []bool{false, true, false, true, false, true, false}, 3},
+ }
+ // Run each of the tests and verify the results against the chain
+ for i, tt := range tests {
+ // Collect the hashes to request, and the response to expect
+ var (
+ hashes []common.Hash
+ bodies []*BlockBody
+ seen = make(map[int64]bool)
+ )
+ for j := 0; j < tt.random; j++ {
+ for {
+ num := rand.Int63n(int64(backend.chain.CurrentBlock().NumberU64()))
+ if !seen[num] {
+ seen[num] = true
+
+ block := backend.chain.GetBlockByNumber(uint64(num))
+ hashes = append(hashes, block.Hash())
+ if len(bodies) < tt.expected {
+ bodies = append(bodies, &BlockBody{Transactions: block.Transactions(), Uncles: block.Uncles()})
+ }
+ break
+ }
+ }
+ }
+ for j, hash := range tt.explicit {
+ hashes = append(hashes, hash)
+ if tt.available[j] && len(bodies) < tt.expected {
+ block := backend.chain.GetBlockByHash(hash)
+ bodies = append(bodies, &BlockBody{Transactions: block.Transactions(), Uncles: block.Uncles()})
+ }
+ }
+ // Send the hash request and verify the response
+ p2p.Send(peer.app, 0x05, hashes)
+ if err := p2p.ExpectMsg(peer.app, 0x06, bodies); err != nil {
+ t.Errorf("test %d: bodies mismatch: %v", i, err)
+ }
+ }
+}
+
+// Tests that the state trie nodes can be retrieved based on hashes.
+func TestGetNodeData64(t *testing.T) { testGetNodeData(t, 64) }
+func TestGetNodeData65(t *testing.T) { testGetNodeData(t, 65) }
+
+func testGetNodeData(t *testing.T, protocol uint) {
+ t.Parallel()
+
+ // Define three accounts to simulate transactions with
+ acc1Key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a")
+ acc2Key, _ := crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee")
+ acc1Addr := crypto.PubkeyToAddress(acc1Key.PublicKey)
+ acc2Addr := crypto.PubkeyToAddress(acc2Key.PublicKey)
+
+ signer := types.HomesteadSigner{}
+ // Create a chain generator with some simple transactions (blatantly stolen from @fjl/chain_markets_test)
+ generator := func(i int, block *core.BlockGen) {
+ switch i {
+ case 0:
+ // In block 1, the test bank sends account #1 some ether.
+ tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testAddr), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil), signer, testKey)
+ block.AddTx(tx)
+ case 1:
+ // In block 2, the test bank sends some more ether to account #1.
+ // acc1Addr passes it on to account #2.
+ tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testAddr), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, testKey)
+ tx2, _ := types.SignTx(types.NewTransaction(block.TxNonce(acc1Addr), acc2Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, acc1Key)
+ block.AddTx(tx1)
+ block.AddTx(tx2)
+ case 2:
+ // Block 3 is empty but was mined by account #2.
+ block.SetCoinbase(acc2Addr)
+ block.SetExtra([]byte("yeehaw"))
+ case 3:
+ // Block 4 includes blocks 2 and 3 as uncle headers (with modified extra data).
+ b2 := block.PrevBlock(1).Header()
+ b2.Extra = []byte("foo")
+ block.AddUncle(b2)
+ b3 := block.PrevBlock(2).Header()
+ b3.Extra = []byte("foo")
+ block.AddUncle(b3)
+ }
+ }
+ // Assemble the test environment
+ backend := newTestBackendWithGenerator(4, generator)
+ defer backend.close()
+
+ peer, _ := newTestPeer("peer", protocol, backend)
+ defer peer.close()
+
+ // For now, fetch the entire chain db
+ var hashes []common.Hash
+
+ it := backend.db.NewIterator(nil, nil)
+ for it.Next() {
+ if key := it.Key(); len(key) == common.HashLength {
+ hashes = append(hashes, common.BytesToHash(key))
+ }
+ }
+ it.Release()
+
+ p2p.Send(peer.app, 0x0d, hashes)
+ msg, err := peer.app.ReadMsg()
+ if err != nil {
+ t.Fatalf("failed to read node data response: %v", err)
+ }
+ if msg.Code != 0x0e {
+ t.Fatalf("response packet code mismatch: have %x, want %x", msg.Code, 0x0c)
+ }
+ var data [][]byte
+ if err := msg.Decode(&data); err != nil {
+ t.Fatalf("failed to decode response node data: %v", err)
+ }
+ // Verify that all hashes correspond to the requested data, and reconstruct a state trie
+ for i, want := range hashes {
+ if hash := crypto.Keccak256Hash(data[i]); hash != want {
+ t.Errorf("data hash mismatch: have %x, want %x", hash, want)
+ }
+ }
+ statedb := rawdb.NewMemoryDatabase()
+ for i := 0; i < len(data); i++ {
+ statedb.Put(hashes[i].Bytes(), data[i])
+ }
+ accounts := []common.Address{testAddr, acc1Addr, acc2Addr}
+ for i := uint64(0); i <= backend.chain.CurrentBlock().NumberU64(); i++ {
+ trie, _ := state.New(backend.chain.GetBlockByNumber(i).Root(), state.NewDatabase(statedb), nil)
+
+ for j, acc := range accounts {
+ state, _ := backend.chain.State()
+ bw := state.GetBalance(acc)
+ bh := trie.GetBalance(acc)
+
+ if (bw != nil && bh == nil) || (bw == nil && bh != nil) {
+ t.Errorf("test %d, account %d: balance mismatch: have %v, want %v", i, j, bh, bw)
+ }
+ if bw != nil && bh != nil && bw.Cmp(bw) != 0 {
+ t.Errorf("test %d, account %d: balance mismatch: have %v, want %v", i, j, bh, bw)
+ }
+ }
+ }
+}
+
+// Tests that the transaction receipts can be retrieved based on hashes.
+func TestGetBlockReceipts64(t *testing.T) { testGetBlockReceipts(t, 64) }
+func TestGetBlockReceipts65(t *testing.T) { testGetBlockReceipts(t, 65) }
+
+func testGetBlockReceipts(t *testing.T, protocol uint) {
+ t.Parallel()
+
+ // Define three accounts to simulate transactions with
+ acc1Key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a")
+ acc2Key, _ := crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee")
+ acc1Addr := crypto.PubkeyToAddress(acc1Key.PublicKey)
+ acc2Addr := crypto.PubkeyToAddress(acc2Key.PublicKey)
+
+ signer := types.HomesteadSigner{}
+ // Create a chain generator with some simple transactions (blatantly stolen from @fjl/chain_markets_test)
+ generator := func(i int, block *core.BlockGen) {
+ switch i {
+ case 0:
+ // In block 1, the test bank sends account #1 some ether.
+ tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testAddr), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil), signer, testKey)
+ block.AddTx(tx)
+ case 1:
+ // In block 2, the test bank sends some more ether to account #1.
+ // acc1Addr passes it on to account #2.
+ tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testAddr), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, testKey)
+ tx2, _ := types.SignTx(types.NewTransaction(block.TxNonce(acc1Addr), acc2Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, acc1Key)
+ block.AddTx(tx1)
+ block.AddTx(tx2)
+ case 2:
+ // Block 3 is empty but was mined by account #2.
+ block.SetCoinbase(acc2Addr)
+ block.SetExtra([]byte("yeehaw"))
+ case 3:
+ // Block 4 includes blocks 2 and 3 as uncle headers (with modified extra data).
+ b2 := block.PrevBlock(1).Header()
+ b2.Extra = []byte("foo")
+ block.AddUncle(b2)
+ b3 := block.PrevBlock(2).Header()
+ b3.Extra = []byte("foo")
+ block.AddUncle(b3)
+ }
+ }
+ // Assemble the test environment
+ backend := newTestBackendWithGenerator(4, generator)
+ defer backend.close()
+
+ peer, _ := newTestPeer("peer", protocol, backend)
+ defer peer.close()
+
+ // Collect the hashes to request, and the response to expect
+ var (
+ hashes []common.Hash
+ receipts []types.Receipts
+ )
+ for i := uint64(0); i <= backend.chain.CurrentBlock().NumberU64(); i++ {
+ block := backend.chain.GetBlockByNumber(i)
+
+ hashes = append(hashes, block.Hash())
+ receipts = append(receipts, backend.chain.GetReceiptsByHash(block.Hash()))
+ }
+ // Send the hash request and verify the response
+ p2p.Send(peer.app, 0x0f, hashes)
+ if err := p2p.ExpectMsg(peer.app, 0x10, receipts); err != nil {
+ t.Errorf("receipts mismatch: %v", err)
+ }
+}
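All three retrieval tests above share the same shape: they drive the handler through the application end of an in-memory message pipe and compare the reply byte-for-byte via p2p.ExpectMsg, which reads one message and checks both the code and the RLP encoding of the expected payload. A minimal, self-contained sketch of that exchange, with a toy code and payload rather than the real packets:

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/p2p"
)

func main() {
	// MsgPipe returns the two synchronous ends of an in-memory connection.
	app, net := p2p.MsgPipe()
	defer app.Close()
	defer net.Close()

	// Sends block until the other end reads, so write from a goroutine.
	go p2p.Send(app, 0x05, []string{"hello"})

	// ExpectMsg reads one message and verifies that the code and the RLP
	// encoding of the expected content both match.
	if err := p2p.ExpectMsg(net, 0x05, []string{"hello"}); err != nil {
		fmt.Println("mismatch:", err)
		return
	}
	fmt.Println("exchange verified")
}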
diff --git a/eth/protocols/eth/handshake.go b/eth/protocols/eth/handshake.go
new file mode 100644
index 000000000..57a4e0bc3
--- /dev/null
+++ b/eth/protocols/eth/handshake.go
@@ -0,0 +1,107 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+ "fmt"
+ "math/big"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/forkid"
+ "github.com/ethereum/go-ethereum/p2p"
+)
+
+const (
+ // handshakeTimeout is the maximum allowed time for the `eth` handshake to
+// complete before dropping the connection as malicious.
+ handshakeTimeout = 5 * time.Second
+)
+
+// Handshake executes the eth protocol handshake, negotiating version number,
+// network IDs, difficulties, head and genesis blocks.
+func (p *Peer) Handshake(network uint64, td *big.Int, head common.Hash, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter) error {
+ // Send out our own handshake in a new thread
+ errc := make(chan error, 2)
+
+ var status StatusPacket // safe to read after two values have been received from errc
+
+ go func() {
+ errc <- p2p.Send(p.rw, StatusMsg, &StatusPacket{
+ ProtocolVersion: uint32(p.version),
+ NetworkID: network,
+ TD: td,
+ Head: head,
+ Genesis: genesis,
+ ForkID: forkID,
+ })
+ }()
+ go func() {
+ errc <- p.readStatus(network, &status, genesis, forkFilter)
+ }()
+ timeout := time.NewTimer(handshakeTimeout)
+ defer timeout.Stop()
+ for i := 0; i < 2; i++ {
+ select {
+ case err := <-errc:
+ if err != nil {
+ return err
+ }
+ case <-timeout.C:
+ return p2p.DiscReadTimeout
+ }
+ }
+ p.td, p.head = status.TD, status.Head
+
+ // TD at mainnet block #7753254 is 76 bits. If it becomes 100 million times
+ // larger, it will still fit within 100 bits
+ if tdlen := p.td.BitLen(); tdlen > 100 {
+ return fmt.Errorf("too large total difficulty: bitlen %d", tdlen)
+ }
+ return nil
+}
+
+// readStatus reads the remote handshake message.
+func (p *Peer) readStatus(network uint64, status *StatusPacket, genesis common.Hash, forkFilter forkid.Filter) error {
+ msg, err := p.rw.ReadMsg()
+ if err != nil {
+ return err
+ }
+ if msg.Code != StatusMsg {
+ return fmt.Errorf("%w: first msg has code %x (!= %x)", errNoStatusMsg, msg.Code, StatusMsg)
+ }
+ if msg.Size > maxMessageSize {
+ return fmt.Errorf("%w: %v > %v", errMsgTooLarge, msg.Size, maxMessageSize)
+ }
+ // Decode the handshake and make sure everything matches
+ if err := msg.Decode(&status); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ if status.NetworkID != network {
+ return fmt.Errorf("%w: %d (!= %d)", errNetworkIDMismatch, status.NetworkID, network)
+ }
+ if uint(status.ProtocolVersion) != p.version {
+ return fmt.Errorf("%w: %d (!= %d)", errProtocolVersionMismatch, status.ProtocolVersion, p.version)
+ }
+ if status.Genesis != genesis {
+ return fmt.Errorf("%w: %x (!= %x)", errGenesisMismatch, status.Genesis, genesis)
+ }
+ if err := forkFilter(status.ForkID); err != nil {
+ return fmt.Errorf("%w: %v", errForkIDRejected, err)
+ }
+ return nil
+}
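Handshake above sends and receives concurrently, funnelling both halves through one buffered error channel and bounding the pair with a single timer. A minimal sketch of the pattern with hypothetical stand-in send/receive functions:

package main

import (
	"errors"
	"fmt"
	"time"
)

// runHandshake runs send and receive concurrently and fails the whole
// exchange if either half errors or the shared timeout fires first.
func runHandshake(send, receive func() error, timeout time.Duration) error {
	errc := make(chan error, 2) // buffered so neither goroutine leaks
	go func() { errc <- send() }()
	go func() { errc <- receive() }()

	timer := time.NewTimer(timeout)
	defer timer.Stop()
	for i := 0; i < 2; i++ {
		select {
		case err := <-errc:
			if err != nil {
				return err
			}
		case <-timer.C:
			return errors.New("handshake timed out")
		}
	}
	return nil
}

func main() {
	err := runHandshake(
		func() error { return nil },                                     // stand-in for sending our status
		func() error { time.Sleep(10 * time.Millisecond); return nil }, // stand-in for reading theirs
		time.Second,
	)
	fmt.Println("handshake:", err)
}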
diff --git a/eth/protocols/eth/handshake_test.go b/eth/protocols/eth/handshake_test.go
new file mode 100644
index 000000000..65f9a0006
--- /dev/null
+++ b/eth/protocols/eth/handshake_test.go
@@ -0,0 +1,91 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/forkid"
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+)
+
+// Tests that handshake failures are detected and reported correctly.
+func TestHandshake64(t *testing.T) { testHandshake(t, 64) }
+func TestHandshake65(t *testing.T) { testHandshake(t, 65) }
+
+func testHandshake(t *testing.T, protocol uint) {
+ t.Parallel()
+
+ // Create a test backend only to have some valid genesis chain
+ backend := newTestBackend(3)
+ defer backend.close()
+
+ var (
+ genesis = backend.chain.Genesis()
+ head = backend.chain.CurrentBlock()
+ td = backend.chain.GetTd(head.Hash(), head.NumberU64())
+ forkID = forkid.NewID(backend.chain.Config(), backend.chain.Genesis().Hash(), backend.chain.CurrentHeader().Number.Uint64())
+ )
+ tests := []struct {
+ code uint64
+ data interface{}
+ want error
+ }{
+ {
+ code: TransactionsMsg, data: []interface{}{},
+ want: errNoStatusMsg,
+ },
+ {
+ code: StatusMsg, data: StatusPacket{10, 1, td, head.Hash(), genesis.Hash(), forkID},
+ want: errProtocolVersionMismatch,
+ },
+ {
+ code: StatusMsg, data: StatusPacket{uint32(protocol), 999, td, head.Hash(), genesis.Hash(), forkID},
+ want: errNetworkIDMismatch,
+ },
+ {
+ code: StatusMsg, data: StatusPacket{uint32(protocol), 1, td, head.Hash(), common.Hash{3}, forkID},
+ want: errGenesisMismatch,
+ },
+ {
+ code: StatusMsg, data: StatusPacket{uint32(protocol), 1, td, head.Hash(), genesis.Hash(), forkid.ID{Hash: [4]byte{0x00, 0x01, 0x02, 0x03}}},
+ want: errForkIDRejected,
+ },
+ }
+ for i, test := range tests {
+ // Create the two peers to shake with each other
+ app, net := p2p.MsgPipe()
+ defer app.Close()
+ defer net.Close()
+
+ peer := NewPeer(protocol, p2p.NewPeer(enode.ID{}, "peer", nil), net, nil)
+ defer peer.Close()
+
+ // Send the junk test with one peer, check the handshake failure
+ go p2p.Send(app, test.code, test.data)
+
+ err := peer.Handshake(1, td, head.Hash(), genesis.Hash(), forkID, forkid.NewFilter(backend.chain))
+ if err == nil {
+ t.Errorf("test %d: protocol returned nil error, want %q", i, test.want)
+ } else if !errors.Is(err, test.want) {
+ t.Errorf("test %d: wrong error: got %q, want %q", i, err, test.want)
+ }
+ }
+}
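The table above can match wrapped errors because readStatus builds every failure with %w around a sentinel, so errors.Is still finds it through the added context. A tiny sketch of the idiom:

package main

import (
	"errors"
	"fmt"
)

var errNetworkIDMismatch = errors.New("network ID mismatch")

func check(have, want uint64) error {
	if have != want {
		// %w keeps the sentinel in the error chain while adding context.
		return fmt.Errorf("%w: %d (!= %d)", errNetworkIDMismatch, have, want)
	}
	return nil
}

func main() {
	err := check(999, 1)
	fmt.Println(errors.Is(err, errNetworkIDMismatch)) // true, despite the wrapping
}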
diff --git a/eth/protocols/eth/peer.go b/eth/protocols/eth/peer.go
new file mode 100644
index 000000000..735ef78ce
--- /dev/null
+++ b/eth/protocols/eth/peer.go
@@ -0,0 +1,429 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+ "math/big"
+ "sync"
+
+ mapset "github.com/deckarep/golang-set"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+const (
+ // maxKnownTxs is the maximum transactions hashes to keep in the known list
+ // before starting to randomly evict them.
+ maxKnownTxs = 32768
+
+ // maxKnownBlocks is the maximum block hashes to keep in the known list
+ // before starting to randomly evict them.
+ maxKnownBlocks = 1024
+
+ // maxQueuedTxs is the maximum number of transactions to queue up before dropping
+ // older broadcasts.
+ maxQueuedTxs = 4096
+
+ // maxQueuedTxAnns is the maximum number of transaction announcements to queue up
+ // before dropping older announcements.
+ maxQueuedTxAnns = 4096
+
+ // maxQueuedBlocks is the maximum number of block propagations to queue up before
+ // dropping broadcasts. There's not much point in queueing stale blocks, so a few
+ // that might cover uncles should be enough.
+ maxQueuedBlocks = 4
+
+ // maxQueuedBlockAnns is the maximum number of block announcements to queue up before
+ // dropping broadcasts. Similarly to block propagations, there's no point to queue
+ // above some healthy uncle limit, so use that.
+ maxQueuedBlockAnns = 4
+)
+
+// max is a helper function which returns the larger of the two given integers.
+func max(a, b int) int {
+ if a > b {
+ return a
+ }
+ return b
+}
+
+// Peer is a collection of relevant information we have about an `eth` peer.
+type Peer struct {
+ id string // Unique ID for the peer, cached
+
+ *p2p.Peer // The embedded P2P package peer
+ rw p2p.MsgReadWriter // Input/output streams for eth
+ version uint // Protocol version negotiated
+
+ head common.Hash // Latest advertised head block hash
+ td *big.Int // Latest advertised head block total difficulty
+
+ knownBlocks mapset.Set // Set of block hashes known to be known by this peer
+ queuedBlocks chan *blockPropagation // Queue of blocks to broadcast to the peer
+ queuedBlockAnns chan *types.Block // Queue of blocks to announce to the peer
+
+ txpool TxPool // Transaction pool used by the broadcasters for liveness checks
+ knownTxs mapset.Set // Set of transaction hashes known to be known by this peer
+ txBroadcast chan []common.Hash // Channel used to queue transaction propagation requests
+ txAnnounce chan []common.Hash // Channel used to queue transaction announcement requests
+
+ term chan struct{} // Termination channel to stop the broadcasters
+ lock sync.RWMutex // Mutex protecting the internal fields
+}
+
+// NewPeer creates a wrapper for a network connection and negotiated protocol
+// version.
+func NewPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter, txpool TxPool) *Peer {
+ peer := &Peer{
+ id: p.ID().String(),
+ Peer: p,
+ rw: rw,
+ version: version,
+ knownTxs: mapset.NewSet(),
+ knownBlocks: mapset.NewSet(),
+ queuedBlocks: make(chan *blockPropagation, maxQueuedBlocks),
+ queuedBlockAnns: make(chan *types.Block, maxQueuedBlockAnns),
+ txBroadcast: make(chan []common.Hash),
+ txAnnounce: make(chan []common.Hash),
+ txpool: txpool,
+ term: make(chan struct{}),
+ }
+ // Start up all the broadcasters
+ go peer.broadcastBlocks()
+ go peer.broadcastTransactions()
+ if version >= ETH65 {
+ go peer.announceTransactions()
+ }
+ return peer
+}
+
+// Close signals the broadcast goroutine to terminate. Only ever call this if
+// you created the peer yourself via NewPeer. Otherwise let whoever created it
+// clean it up!
+func (p *Peer) Close() {
+ close(p.term)
+}
+
+// ID retrieves the peer's unique identifier.
+func (p *Peer) ID() string {
+ return p.id
+}
+
+// Version retrieves the peer's negotiated `eth` protocol version.
+func (p *Peer) Version() uint {
+ return p.version
+}
+
+// Head retrieves the current head hash and total difficulty of the peer.
+func (p *Peer) Head() (hash common.Hash, td *big.Int) {
+ p.lock.RLock()
+ defer p.lock.RUnlock()
+
+ copy(hash[:], p.head[:])
+ return hash, new(big.Int).Set(p.td)
+}
+
+// SetHead updates the head hash and total difficulty of the peer.
+func (p *Peer) SetHead(hash common.Hash, td *big.Int) {
+ p.lock.Lock()
+ defer p.lock.Unlock()
+
+ copy(p.head[:], hash[:])
+ p.td.Set(td)
+}
+
+// KnownBlock returns whether peer is known to already have a block.
+func (p *Peer) KnownBlock(hash common.Hash) bool {
+ return p.knownBlocks.Contains(hash)
+}
+
+// KnownTransaction returns whether peer is known to already have a transaction.
+func (p *Peer) KnownTransaction(hash common.Hash) bool {
+ return p.knownTxs.Contains(hash)
+}
+
+// markBlock marks a block as known for the peer, ensuring that the block will
+// never be propagated to this particular peer.
+func (p *Peer) markBlock(hash common.Hash) {
+ // If we reached the memory allowance, drop a previously known block hash
+ for p.knownBlocks.Cardinality() >= maxKnownBlocks {
+ p.knownBlocks.Pop()
+ }
+ p.knownBlocks.Add(hash)
+}
+
+// markTransaction marks a transaction as known for the peer, ensuring that it
+// will never be propagated to this particular peer.
+func (p *Peer) markTransaction(hash common.Hash) {
+ // If we reached the memory allowance, drop a previously known transaction hash
+ for p.knownTxs.Cardinality() >= maxKnownTxs {
+ p.knownTxs.Pop()
+ }
+ p.knownTxs.Add(hash)
+}
+
+// SendTransactions sends transactions to the peer and includes the hashes
+// in its transaction hash set for future reference.
+//
+// This method is a helper used by the async transaction sender. Don't call it
+// directly as the queueing (memory) and transmission (bandwidth) costs should
+// not be managed directly.
+//
+// The reason this is public is to allow packages using this protocol to write
+// tests that directly send messages without having to do the async queueing.
+func (p *Peer) SendTransactions(txs types.Transactions) error {
+ // Mark all the transactions as known, but ensure we don't overflow our limits
+ for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(txs)) {
+ p.knownTxs.Pop()
+ }
+ for _, tx := range txs {
+ p.knownTxs.Add(tx.Hash())
+ }
+ return p2p.Send(p.rw, TransactionsMsg, txs)
+}
+
+// AsyncSendTransactions queues a list of transactions (by hash) to eventually
+// propagate to a remote peer. The number of pending sends is capped (new ones
+// will force old sends to be dropped).
+func (p *Peer) AsyncSendTransactions(hashes []common.Hash) {
+ select {
+ case p.txBroadcast <- hashes:
+ // Mark all the transactions as known, but ensure we don't overflow our limits
+ for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
+ p.knownTxs.Pop()
+ }
+ for _, hash := range hashes {
+ p.knownTxs.Add(hash)
+ }
+ case <-p.term:
+ p.Log().Debug("Dropping transaction propagation", "count", len(hashes))
+ }
+}
+
+// sendPooledTransactionHashes sends transaction hashes to the peer and includes
+// them in its transaction hash set for future reference.
+//
+// This method is a helper used by the async transaction announcer. Don't call it
+// directly as the queueing (memory) and transmission (bandwidth) costs should
+// not be managed directly.
+func (p *Peer) sendPooledTransactionHashes(hashes []common.Hash) error {
+ // Mark all the transactions as known, but ensure we don't overflow our limits
+ for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
+ p.knownTxs.Pop()
+ }
+ for _, hash := range hashes {
+ p.knownTxs.Add(hash)
+ }
+ return p2p.Send(p.rw, NewPooledTransactionHashesMsg, NewPooledTransactionHashesPacket(hashes))
+}
+
+// AsyncSendPooledTransactionHashes queues a list of transactions hashes to eventually
+// announce to a remote peer. The number of pending sends is capped (new ones
+// will force old sends to be dropped).
+func (p *Peer) AsyncSendPooledTransactionHashes(hashes []common.Hash) {
+ select {
+ case p.txAnnounce <- hashes:
+ // Mark all the transactions as known, but ensure we don't overflow our limits
+ for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
+ p.knownTxs.Pop()
+ }
+ for _, hash := range hashes {
+ p.knownTxs.Add(hash)
+ }
+ case <-p.term:
+ p.Log().Debug("Dropping transaction announcement", "count", len(hashes))
+ }
+}
+
+// SendPooledTransactionsRLP sends requested transactions to the peer and adds the
+// hashes to its transaction hash set for future reference.
+//
+// Note, the method assumes the hashes are correct and correspond to the list of
+// transactions being sent.
+func (p *Peer) SendPooledTransactionsRLP(hashes []common.Hash, txs []rlp.RawValue) error {
+ // Mark all the transactions as known, but ensure we don't overflow our limits
+ for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
+ p.knownTxs.Pop()
+ }
+ for _, hash := range hashes {
+ p.knownTxs.Add(hash)
+ }
+ return p2p.Send(p.rw, PooledTransactionsMsg, txs) // Not packed into PooledTransactionsPacket to avoid RLP decoding
+}
+
+// SendNewBlockHashes announces the availability of a number of blocks through
+// a hash notification.
+func (p *Peer) SendNewBlockHashes(hashes []common.Hash, numbers []uint64) error {
+ // Mark all the block hashes as known, but ensure we don't overflow our limits
+ for p.knownBlocks.Cardinality() > max(0, maxKnownBlocks-len(hashes)) {
+ p.knownBlocks.Pop()
+ }
+ for _, hash := range hashes {
+ p.knownBlocks.Add(hash)
+ }
+ request := make(NewBlockHashesPacket, len(hashes))
+ for i := 0; i < len(hashes); i++ {
+ request[i].Hash = hashes[i]
+ request[i].Number = numbers[i]
+ }
+ return p2p.Send(p.rw, NewBlockHashesMsg, request)
+}
+
+// AsyncSendNewBlockHash queues the availability of a block for propagation to a
+// remote peer. If the peer's broadcast queue is full, the event is silently
+// dropped.
+func (p *Peer) AsyncSendNewBlockHash(block *types.Block) {
+ select {
+ case p.queuedBlockAnns <- block:
+ // Mark the block hash as known, but ensure we don't overflow our limits
+ for p.knownBlocks.Cardinality() >= maxKnownBlocks {
+ p.knownBlocks.Pop()
+ }
+ p.knownBlocks.Add(block.Hash())
+ default:
+ p.Log().Debug("Dropping block announcement", "number", block.NumberU64(), "hash", block.Hash())
+ }
+}
+
+// SendNewBlock propagates an entire block to a remote peer.
+func (p *Peer) SendNewBlock(block *types.Block, td *big.Int) error {
+ // Mark the block hash as known, but ensure we don't overflow our limits
+ for p.knownBlocks.Cardinality() >= maxKnownBlocks {
+ p.knownBlocks.Pop()
+ }
+ p.knownBlocks.Add(block.Hash())
+ return p2p.Send(p.rw, NewBlockMsg, &NewBlockPacket{block, td})
+}
+
+// AsyncSendNewBlock queues an entire block for propagation to a remote peer. If
+// the peer's broadcast queue is full, the event is silently dropped.
+func (p *Peer) AsyncSendNewBlock(block *types.Block, td *big.Int) {
+ select {
+ case p.queuedBlocks <- &blockPropagation{block: block, td: td}:
+ // Mark the block hash as known, but ensure we don't overflow our limits
+ for p.knownBlocks.Cardinality() >= maxKnownBlocks {
+ p.knownBlocks.Pop()
+ }
+ p.knownBlocks.Add(block.Hash())
+ default:
+ p.Log().Debug("Dropping block propagation", "number", block.NumberU64(), "hash", block.Hash())
+ }
+}
+
+// SendBlockHeaders sends a batch of block headers to the remote peer.
+func (p *Peer) SendBlockHeaders(headers []*types.Header) error {
+ return p2p.Send(p.rw, BlockHeadersMsg, BlockHeadersPacket(headers))
+}
+
+// SendBlockBodies sends a batch of block contents to the remote peer.
+func (p *Peer) SendBlockBodies(bodies []*BlockBody) error {
+ return p2p.Send(p.rw, BlockBodiesMsg, BlockBodiesPacket(bodies))
+}
+
+// SendBlockBodiesRLP sends a batch of block contents to the remote peer from
+// an already RLP encoded format.
+func (p *Peer) SendBlockBodiesRLP(bodies []rlp.RawValue) error {
+ return p2p.Send(p.rw, BlockBodiesMsg, bodies) // Not packed into BlockBodiesPacket to avoid RLP decoding
+}
+
+// SendNodeData sends a batch of arbitrary internal data, corresponding to the
+// hashes requested.
+func (p *Peer) SendNodeData(data [][]byte) error {
+ return p2p.Send(p.rw, NodeDataMsg, NodeDataPacket(data))
+}
+
+// SendReceiptsRLP sends a batch of transaction receipts, corresponding to the
+// ones requested from an already RLP encoded format.
+func (p *Peer) SendReceiptsRLP(receipts []rlp.RawValue) error {
+ return p2p.Send(p.rw, ReceiptsMsg, receipts) // Not packed into ReceiptsPacket to avoid RLP decoding
+}
+
+// RequestOneHeader is a wrapper around the header query functions to fetch a
+// single header. It is used solely by the fetcher.
+func (p *Peer) RequestOneHeader(hash common.Hash) error {
+ p.Log().Debug("Fetching single header", "hash", hash)
+ return p2p.Send(p.rw, GetBlockHeadersMsg, &GetBlockHeadersPacket{
+ Origin: HashOrNumber{Hash: hash},
+ Amount: uint64(1),
+ Skip: uint64(0),
+ Reverse: false,
+ })
+}
+
+// RequestHeadersByHash fetches a batch of blocks' headers corresponding to the
+// specified header query, based on the hash of an origin block.
+func (p *Peer) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error {
+ p.Log().Debug("Fetching batch of headers", "count", amount, "fromhash", origin, "skip", skip, "reverse", reverse)
+ return p2p.Send(p.rw, GetBlockHeadersMsg, &GetBlockHeadersPacket{
+ Origin: HashOrNumber{Hash: origin},
+ Amount: uint64(amount),
+ Skip: uint64(skip),
+ Reverse: reverse,
+ })
+}
+
+// RequestHeadersByNumber fetches a batch of blocks' headers corresponding to the
+// specified header query, based on the number of an origin block.
+func (p *Peer) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error {
+ p.Log().Debug("Fetching batch of headers", "count", amount, "fromnum", origin, "skip", skip, "reverse", reverse)
+ return p2p.Send(p.rw, GetBlockHeadersMsg, &GetBlockHeadersPacket{
+ Origin: HashOrNumber{Number: origin},
+ Amount: uint64(amount),
+ Skip: uint64(skip),
+ Reverse: reverse,
+ })
+}
+
+// ExpectRequestHeadersByNumber is a testing method to mirror the recipient side
+// of the RequestHeadersByNumber operation.
+func (p *Peer) ExpectRequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error {
+ req := &GetBlockHeadersPacket{
+ Origin: HashOrNumber{Number: origin},
+ Amount: uint64(amount),
+ Skip: uint64(skip),
+ Reverse: reverse,
+ }
+ return p2p.ExpectMsg(p.rw, GetBlockHeadersMsg, req)
+}
+
+// RequestBodies fetches a batch of blocks' bodies corresponding to the hashes
+// specified.
+func (p *Peer) RequestBodies(hashes []common.Hash) error {
+ p.Log().Debug("Fetching batch of block bodies", "count", len(hashes))
+ return p2p.Send(p.rw, GetBlockBodiesMsg, GetBlockBodiesPacket(hashes))
+}
+
+// RequestNodeData fetches a batch of arbitrary data from a node's known state
+// data, corresponding to the specified hashes.
+func (p *Peer) RequestNodeData(hashes []common.Hash) error {
+ p.Log().Debug("Fetching batch of state data", "count", len(hashes))
+ return p2p.Send(p.rw, GetNodeDataMsg, GetNodeDataPacket(hashes))
+}
+
+// RequestReceipts fetches a batch of transaction receipts from a remote node.
+func (p *Peer) RequestReceipts(hashes []common.Hash) error {
+ p.Log().Debug("Fetching batch of receipts", "count", len(hashes))
+ return p2p.Send(p.rw, GetReceiptsMsg, GetReceiptsPacket(hashes))
+}
+
+// RequestTxs fetches a batch of transactions from a remote node.
+func (p *Peer) RequestTxs(hashes []common.Hash) error {
+ p.Log().Debug("Fetching batch of transactions", "count", len(hashes))
+ return p2p.Send(p.rw, GetPooledTransactionsMsg, GetPooledTransactionsPacket(hashes))
+}
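Every send and announce path above first trims the known set down to at most maxKnown minus the batch size before inserting, so the cap holds even after the largest batch lands. A minimal sketch of that eviction scheme with a hypothetical tiny cap:

package main

import (
	"fmt"

	mapset "github.com/deckarep/golang-set"
)

const maxKnown = 4 // hypothetical cap, far smaller than geth's 32768/1024

// markKnown records a batch of items, randomly evicting old entries so the
// set never exceeds maxKnown even once the whole batch has been added.
func markKnown(known mapset.Set, items []string) {
	for known.Cardinality() > max(0, maxKnown-len(items)) {
		known.Pop() // Pop removes an arbitrary element
	}
	for _, item := range items {
		known.Add(item)
	}
}

func max(a, b int) int {
	if a > b {
		return a
	}
	return b
}

func main() {
	known := mapset.NewSet()
	markKnown(known, []string{"a", "b", "c"})
	markKnown(known, []string{"d", "e"})
	fmt.Println(known.Cardinality() <= maxKnown) // true
}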
diff --git a/eth/protocols/eth/peer_test.go b/eth/protocols/eth/peer_test.go
new file mode 100644
index 000000000..70e9959f8
--- /dev/null
+++ b/eth/protocols/eth/peer_test.go
@@ -0,0 +1,61 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// This file contains some shared testing functionality, common to multiple
+// different files and modules being tested.
+
+package eth
+
+import (
+ "crypto/rand"
+
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+)
+
+// testPeer is a simulated peer to allow testing direct network calls.
+type testPeer struct {
+ *Peer
+
+ net p2p.MsgReadWriter // Network layer reader/writer to simulate remote messaging
+ app *p2p.MsgPipeRW // Application layer reader/writer to simulate the local side
+}
+
+// newTestPeer creates a new peer registered at the given data backend.
+func newTestPeer(name string, version uint, backend Backend) (*testPeer, <-chan error) {
+ // Create a message pipe to communicate through
+ app, net := p2p.MsgPipe()
+
+ // Start the peer on a new thread
+ var id enode.ID
+ rand.Read(id[:])
+
+ peer := NewPeer(version, p2p.NewPeer(id, name, nil), net, backend.TxPool())
+ errc := make(chan error, 1)
+ go func() {
+ errc <- backend.RunPeer(peer, func(peer *Peer) error {
+ return Handle(backend, peer)
+ })
+ }()
+ return &testPeer{app: app, net: net, Peer: peer}, errc
+}
+
+// close terminates the local side of the peer, notifying the remote protocol
+// manager of termination.
+func (p *testPeer) close() {
+ p.Peer.Close()
+ p.app.Close()
+}
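A hypothetical test built on this helper would look roughly as follows: the app side of the pipe plays the remote peer, while RunPeer (spawned inside newTestPeer) drives the local handler. The newTestBackend constructor and chain field are assumptions carried over from the handler tests in this package:

package eth

import (
	"testing"

	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/p2p"
)

func TestExampleHeaderExchange(t *testing.T) {
	backend := newTestBackend(8) // assumed helper from the handler tests
	defer backend.close()

	peer, _ := newTestPeer("peer", ETH65, backend)
	defer peer.close()

	// Request the genesis header by number and expect it straight back.
	p2p.Send(peer.app, GetBlockHeadersMsg, &GetBlockHeadersPacket{
		Origin: HashOrNumber{Number: 0},
		Amount: 1,
	})
	want := []*types.Header{backend.chain.Genesis().Header()}
	if err := p2p.ExpectMsg(peer.app, BlockHeadersMsg, want); err != nil {
		t.Fatalf("headers mismatch: %v", err)
	}
}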
diff --git a/eth/protocols/eth/protocol.go b/eth/protocols/eth/protocol.go
new file mode 100644
index 000000000..63d3494ec
--- /dev/null
+++ b/eth/protocols/eth/protocol.go
@@ -0,0 +1,279 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "math/big"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/forkid"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+// Constants to match up protocol versions and messages
+const (
+ ETH64 = 64
+ ETH65 = 65
+)
+
+// protocolName is the official short name of the `eth` protocol used during
+// devp2p capability negotiation.
+const protocolName = "eth"
+
+// protocolVersions are the supported versions of the `eth` protocol (first
+// is primary).
+var protocolVersions = []uint{ETH65, ETH64}
+
+// protocolLengths are the number of implemented messages corresponding to
+// different protocol versions.
+var protocolLengths = map[uint]uint64{ETH65: 17, ETH64: 17}
+
+// maxMessageSize is the maximum cap on the size of a protocol message.
+const maxMessageSize = 10 * 1024 * 1024
+
+const (
+ // Protocol messages in eth/64
+ StatusMsg = 0x00
+ NewBlockHashesMsg = 0x01
+ TransactionsMsg = 0x02
+ GetBlockHeadersMsg = 0x03
+ BlockHeadersMsg = 0x04
+ GetBlockBodiesMsg = 0x05
+ BlockBodiesMsg = 0x06
+ NewBlockMsg = 0x07
+ GetNodeDataMsg = 0x0d
+ NodeDataMsg = 0x0e
+ GetReceiptsMsg = 0x0f
+ ReceiptsMsg = 0x10
+
+ // Protocol messages overloaded in eth/65
+ NewPooledTransactionHashesMsg = 0x08
+ GetPooledTransactionsMsg = 0x09
+ PooledTransactionsMsg = 0x0a
+)
+
+var (
+ errNoStatusMsg = errors.New("no status message")
+ errMsgTooLarge = errors.New("message too long")
+ errDecode = errors.New("invalid message")
+ errInvalidMsgCode = errors.New("invalid message code")
+ errProtocolVersionMismatch = errors.New("protocol version mismatch")
+ errNetworkIDMismatch = errors.New("network ID mismatch")
+ errGenesisMismatch = errors.New("genesis mismatch")
+ errForkIDRejected = errors.New("fork ID rejected")
+ errExtraStatusMsg = errors.New("extra status message")
+)
+
+// Packet represents a p2p message in the `eth` protocol.
+type Packet interface {
+ Name() string // Name returns a string corresponding to the message type.
+ Kind() byte // Kind returns the message type.
+}
+
+// StatusPacket is the network packet for the status message for eth/64 and later.
+type StatusPacket struct {
+ ProtocolVersion uint32
+ NetworkID uint64
+ TD *big.Int
+ Head common.Hash
+ Genesis common.Hash
+ ForkID forkid.ID
+}
+
+// NewBlockHashesPacket is the network packet for the block announcements.
+type NewBlockHashesPacket []struct {
+ Hash common.Hash // Hash of one particular block being announced
+ Number uint64 // Number of one particular block being announced
+}
+
+// Unpack retrieves the block hashes and numbers from the announcement packet
+// and returns them in a split flat format that's more consistent with the
+// internal data structures.
+func (p *NewBlockHashesPacket) Unpack() ([]common.Hash, []uint64) {
+ var (
+ hashes = make([]common.Hash, len(*p))
+ numbers = make([]uint64, len(*p))
+ )
+ for i, body := range *p {
+ hashes[i], numbers[i] = body.Hash, body.Number
+ }
+ return hashes, numbers
+}
+
+// TransactionsPacket is the network packet for broadcasting new transactions.
+type TransactionsPacket []*types.Transaction
+
+// GetBlockHeadersPacket represents a block header query.
+type GetBlockHeadersPacket struct {
+ Origin HashOrNumber // Block from which to retrieve headers
+ Amount uint64 // Maximum number of headers to retrieve
+ Skip uint64 // Blocks to skip between consecutive headers
+ Reverse bool // Query direction (false = rising towards latest, true = falling towards genesis)
+}
+
+// HashOrNumber is a combined field for specifying an origin block.
+type HashOrNumber struct {
+ Hash common.Hash // Block hash from which to retrieve headers (excludes Number)
+ Number uint64 // Block number from which to retrieve headers (excludes Hash)
+}
+
+// EncodeRLP is a specialized encoder for HashOrNumber to encode only one of the
+// two contained union fields.
+func (hn *HashOrNumber) EncodeRLP(w io.Writer) error {
+ if hn.Hash == (common.Hash{}) {
+ return rlp.Encode(w, hn.Number)
+ }
+ if hn.Number != 0 {
+ return fmt.Errorf("both origin hash (%x) and number (%d) provided", hn.Hash, hn.Number)
+ }
+ return rlp.Encode(w, hn.Hash)
+}
+
+// DecodeRLP is a specialized decoder for HashOrNumber to decode the contents
+// into either a block hash or a block number.
+func (hn *HashOrNumber) DecodeRLP(s *rlp.Stream) error {
+ _, size, _ := s.Kind()
+ origin, err := s.Raw()
+ if err == nil {
+ switch {
+ case size == 32:
+ err = rlp.DecodeBytes(origin, &hn.Hash)
+ case size <= 8:
+ err = rlp.DecodeBytes(origin, &hn.Number)
+ default:
+ err = fmt.Errorf("invalid input size %d for origin", size)
+ }
+ }
+ return err
+}
+
+// BlockHeadersPacket represents a block header response.
+type BlockHeadersPacket []*types.Header
+
+// NewBlockPacket is the network packet for the block propagation message.
+type NewBlockPacket struct {
+ Block *types.Block
+ TD *big.Int
+}
+
+// sanityCheck verifies that the values are reasonable, as a DoS protection
+func (request *NewBlockPacket) sanityCheck() error {
+ if err := request.Block.SanityCheck(); err != nil {
+ return err
+ }
+ // TD at mainnet block #7753254 is 76 bits. If it becomes 100 million times
+ // larger, it will still fit within 100 bits
+ if tdlen := request.TD.BitLen(); tdlen > 100 {
+ return fmt.Errorf("too large block TD: bitlen %d", tdlen)
+ }
+ return nil
+}
+
+// GetBlockBodiesPacket represents a block body query.
+type GetBlockBodiesPacket []common.Hash
+
+// BlockBodiesPacket is the network packet for block content distribution.
+type BlockBodiesPacket []*BlockBody
+
+// BlockBody represents the data content of a single block.
+type BlockBody struct {
+ Transactions []*types.Transaction // Transactions contained within a block
+ Uncles []*types.Header // Uncles contained within a block
+}
+
+// Unpack retrieves the transactions and uncles from the range packet and returns
+// them in a split flat format that's more consistent with the internal data structures.
+func (p *BlockBodiesPacket) Unpack() ([][]*types.Transaction, [][]*types.Header) {
+ var (
+ txset = make([][]*types.Transaction, len(*p))
+ uncleset = make([][]*types.Header, len(*p))
+ )
+ for i, body := range *p {
+ txset[i], uncleset[i] = body.Transactions, body.Uncles
+ }
+ return txset, uncleset
+}
+
+// GetNodeDataPacket represents a trie node data query.
+type GetNodeDataPacket []common.Hash
+
+// NodeDataPacket is the network packet for trie node data distribution.
+type NodeDataPacket [][]byte
+
+// GetReceiptsPacket represents a block receipts query.
+type GetReceiptsPacket []common.Hash
+
+// ReceiptsPacket is the network packet for block receipts distribution.
+type ReceiptsPacket [][]*types.Receipt
+
+// NewPooledTransactionHashesPacket represents a transaction announcement packet.
+type NewPooledTransactionHashesPacket []common.Hash
+
+// GetPooledTransactionsPacket represents a transaction query.
+type GetPooledTransactionsPacket []common.Hash
+
+// PooledTransactionsPacket is the network packet for transaction distribution.
+type PooledTransactionsPacket []*types.Transaction
+
+func (*StatusPacket) Name() string { return "Status" }
+func (*StatusPacket) Kind() byte { return StatusMsg }
+
+func (*NewBlockHashesPacket) Name() string { return "NewBlockHashes" }
+func (*NewBlockHashesPacket) Kind() byte { return NewBlockHashesMsg }
+
+func (*TransactionsPacket) Name() string { return "Transactions" }
+func (*TransactionsPacket) Kind() byte { return TransactionsMsg }
+
+func (*GetBlockHeadersPacket) Name() string { return "GetBlockHeaders" }
+func (*GetBlockHeadersPacket) Kind() byte { return GetBlockHeadersMsg }
+
+func (*BlockHeadersPacket) Name() string { return "BlockHeaders" }
+func (*BlockHeadersPacket) Kind() byte { return BlockHeadersMsg }
+
+func (*GetBlockBodiesPacket) Name() string { return "GetBlockBodies" }
+func (*GetBlockBodiesPacket) Kind() byte { return GetBlockBodiesMsg }
+
+func (*BlockBodiesPacket) Name() string { return "BlockBodies" }
+func (*BlockBodiesPacket) Kind() byte { return BlockBodiesMsg }
+
+func (*NewBlockPacket) Name() string { return "NewBlock" }
+func (*NewBlockPacket) Kind() byte { return NewBlockMsg }
+
+func (*GetNodeDataPacket) Name() string { return "GetNodeData" }
+func (*GetNodeDataPacket) Kind() byte { return GetNodeDataMsg }
+
+func (*NodeDataPacket) Name() string { return "NodeData" }
+func (*NodeDataPacket) Kind() byte { return NodeDataMsg }
+
+func (*GetReceiptsPacket) Name() string { return "GetReceipts" }
+func (*GetReceiptsPacket) Kind() byte { return GetReceiptsMsg }
+
+func (*ReceiptsPacket) Name() string { return "Receipts" }
+func (*ReceiptsPacket) Kind() byte { return ReceiptsMsg }
+
+func (*NewPooledTransactionHashesPacket) Name() string { return "NewPooledTransactionHashes" }
+func (*NewPooledTransactionHashesPacket) Kind() byte { return NewPooledTransactionHashesMsg }
+
+func (*GetPooledTransactionsPacket) Name() string { return "GetPooledTransactions" }
+func (*GetPooledTransactionsPacket) Kind() byte { return GetPooledTransactionsMsg }
+
+func (*PooledTransactionsPacket) Name() string { return "PooledTransactions" }
+func (*PooledTransactionsPacket) Kind() byte { return PooledTransactionsMsg }
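HashOrNumber is a strict union: encoding fails if both fields are set, and decoding picks the arm purely from the RLP item size (32 bytes is a hash, up to 8 bytes a number). A quick round-trip of both arms, assuming this package's final import path:

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/eth/protocols/eth"
	"github.com/ethereum/go-ethereum/rlp"
)

func main() {
	// A number origin encodes as a small RLP integer...
	byNum, _ := rlp.EncodeToBytes(&eth.HashOrNumber{Number: 314})
	// ...while a hash origin is always a 32-byte RLP string.
	byHash, _ := rlp.EncodeToBytes(&eth.HashOrNumber{Hash: common.HexToHash("0x01")})

	// DecodeRLP inspects the item size to pick the right union field.
	var a, b eth.HashOrNumber
	rlp.DecodeBytes(byNum, &a)
	rlp.DecodeBytes(byHash, &b)
	fmt.Println(a.Number, b.Hash) // 314 and the hash, respectively
}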
diff --git a/eth/protocols/eth/protocol_test.go b/eth/protocols/eth/protocol_test.go
new file mode 100644
index 000000000..056ea5648
--- /dev/null
+++ b/eth/protocols/eth/protocol_test.go
@@ -0,0 +1,68 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+// Tests that the custom union field encoder and decoder works correctly.
+func TestGetBlockHeadersDataEncodeDecode(t *testing.T) {
+ // Create a "random" hash for testing
+ var hash common.Hash
+ for i := range hash {
+ hash[i] = byte(i)
+ }
+ // Assemble some table driven tests
+ tests := []struct {
+ packet *GetBlockHeadersPacket
+ fail bool
+ }{
+ // Providing the origin as either a hash or a number should both work
+ {fail: false, packet: &GetBlockHeadersPacket{Origin: HashOrNumber{Number: 314}}},
+ {fail: false, packet: &GetBlockHeadersPacket{Origin: HashOrNumber{Hash: hash}}},
+
+ // Providing arbitrary query field should also work
+ {fail: false, packet: &GetBlockHeadersPacket{Origin: HashOrNumber{Number: 314}, Amount: 314, Skip: 1, Reverse: true}},
+ {fail: false, packet: &GetBlockHeadersPacket{Origin: HashOrNumber{Hash: hash}, Amount: 314, Skip: 1, Reverse: true}},
+
+ // Providing both the origin hash and origin number must fail
+ {fail: true, packet: &GetBlockHeadersPacket{Origin: HashOrNumber{Hash: hash, Number: 314}}},
+ }
+ // Iterate over each of the tests and try to encode and then decode
+ for i, tt := range tests {
+ bytes, err := rlp.EncodeToBytes(tt.packet)
+ if err != nil && !tt.fail {
+ t.Fatalf("test %d: failed to encode packet: %v", i, err)
+ } else if err == nil && tt.fail {
+ t.Fatalf("test %d: encode should have failed", i)
+ }
+ if !tt.fail {
+ packet := new(GetBlockHeadersPacket)
+ if err := rlp.DecodeBytes(bytes, packet); err != nil {
+ t.Fatalf("test %d: failed to decode packet: %v", i, err)
+ }
+ if packet.Origin.Hash != tt.packet.Origin.Hash || packet.Origin.Number != tt.packet.Origin.Number || packet.Amount != tt.packet.Amount ||
+ packet.Skip != tt.packet.Skip || packet.Reverse != tt.packet.Reverse {
+ t.Fatalf("test %d: encode decode mismatch: have %+v, want %+v", i, packet, tt.packet)
+ }
+ }
+ }
+}
diff --git a/eth/protocols/snap/discovery.go b/eth/protocols/snap/discovery.go
new file mode 100644
index 000000000..684ec7e63
--- /dev/null
+++ b/eth/protocols/snap/discovery.go
@@ -0,0 +1,32 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+// enrEntry is the ENR entry which advertises the `snap` protocol on the discovery network.
+type enrEntry struct {
+ // Ignore additional fields (for forward compatibility).
+ Rest []rlp.RawValue `rlp:"tail"`
+}
+
+// ENRKey implements enr.Entry.
+func (e enrEntry) ENRKey() string {
+ return "snap"
+}
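Attaching this entry to a node record is what lets `snap`-capable peers be filtered straight from discovery; the rlp:"tail" field means any payload a future protocol version adds will still decode cleanly on old nodes. A minimal sketch of setting and probing such an entry on a record, using a local copy of the type since the one above is unexported:

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/p2p/enr"
	"github.com/ethereum/go-ethereum/rlp"
)

// snapEntry mirrors the unexported entry above: no payload today, and the
// rlp tail swallows any fields a future protocol version might add.
type snapEntry struct {
	Rest []rlp.RawValue `rlp:"tail"`
}

func (snapEntry) ENRKey() string { return "snap" }

func main() {
	var r enr.Record
	r.Set(snapEntry{}) // advertise the capability in the record

	var probe snapEntry
	fmt.Println(r.Load(&probe) == nil) // true: the "snap" key is present
}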
diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go
new file mode 100644
index 000000000..36322e648
--- /dev/null
+++ b/eth/protocols/snap/handler.go
@@ -0,0 +1,490 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "bytes"
+ "fmt"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/state"
+ "github.com/ethereum/go-ethereum/light"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/p2p/enr"
+ "github.com/ethereum/go-ethereum/rlp"
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+const (
+ // softResponseLimit is the target maximum size of replies to data retrievals.
+ softResponseLimit = 2 * 1024 * 1024
+
+ // maxCodeLookups is the maximum number of bytecodes to serve. This number is
+ // there to limit the number of disk lookups.
+ maxCodeLookups = 1024
+
+ // stateLookupSlack defines the ratio by how much a state response can exceed
+ // the requested limit in order to try and avoid breaking up contracts into
+ // multiple packets and proving them.
+ stateLookupSlack = 0.1
+
+ // maxTrieNodeLookups is the maximum number of state trie nodes to serve. This
+ // number is there to limit the number of disk lookups.
+ maxTrieNodeLookups = 1024
+)
+
+// Handler is a callback to invoke from an outside runner after the boilerplate
+// exchanges have passed.
+type Handler func(peer *Peer) error
+
+// Backend defines the data retrieval methods to serve remote requests and the
+// callback methods to invoke on remote deliveries.
+type Backend interface {
+ // Chain retrieves the blockchain object to serve data.
+ Chain() *core.BlockChain
+
+ // RunPeer is invoked when a peer joins on the `snap` protocol. The handler
+ // should do any peer maintenance work, handshakes and validations. If all
+ // is passed, control should be given back to the `handler` to process the
+ // inbound messages going forward.
+ RunPeer(peer *Peer, handler Handler) error
+
+ // PeerInfo retrieves all known `snap` information about a peer.
+ PeerInfo(id enode.ID) interface{}
+
+ // Handle is a callback to be invoked when a data packet is received from
+ // the remote peer. Only packets not consumed by the protocol handler will
+ // be forwarded to the backend.
+ Handle(peer *Peer, packet Packet) error
+}
+
+// MakeProtocols constructs the P2P protocol definitions for `snap`.
+func MakeProtocols(backend Backend, dnsdisc enode.Iterator) []p2p.Protocol {
+ protocols := make([]p2p.Protocol, len(protocolVersions))
+ for i, version := range protocolVersions {
+ version := version // Closure
+
+ protocols[i] = p2p.Protocol{
+ Name: protocolName,
+ Version: version,
+ Length: protocolLengths[version],
+ Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
+ return backend.RunPeer(newPeer(version, p, rw), func(peer *Peer) error {
+ return handle(backend, peer)
+ })
+ },
+ NodeInfo: func() interface{} {
+ return nodeInfo(backend.Chain())
+ },
+ PeerInfo: func(id enode.ID) interface{} {
+ return backend.PeerInfo(id)
+ },
+ Attributes: []enr.Entry{&enrEntry{}},
+ DialCandidates: dnsdisc,
+ }
+ }
+ return protocols
+}
+
+// handle is the callback invoked to manage the life cycle of a `snap` peer.
+// When this function terminates, the peer is disconnected.
+func handle(backend Backend, peer *Peer) error {
+ for {
+ if err := handleMessage(backend, peer); err != nil {
+ peer.Log().Debug("Message handling failed in `snap`", "err", err)
+ return err
+ }
+ }
+}
+
+// handleMessage is invoked whenever an inbound message is received from a
+// remote peer on the `snap` protocol. The remote connection is torn down upon
+// returning any error.
+func handleMessage(backend Backend, peer *Peer) error {
+ // Read the next message from the remote peer, and ensure it's fully consumed
+ msg, err := peer.rw.ReadMsg()
+ if err != nil {
+ return err
+ }
+ if msg.Size > maxMessageSize {
+ return fmt.Errorf("%w: %v > %v", errMsgTooLarge, msg.Size, maxMessageSize)
+ }
+ defer msg.Discard()
+
+ // Handle the message depending on its contents
+ switch {
+ case msg.Code == GetAccountRangeMsg:
+ // Decode the account retrieval request
+ var req GetAccountRangePacket
+ if err := msg.Decode(&req); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ if req.Bytes > softResponseLimit {
+ req.Bytes = softResponseLimit
+ }
+ // Retrieve the requested state and bail out if non existent
+ tr, err := trie.New(req.Root, backend.Chain().StateCache().TrieDB())
+ if err != nil {
+ return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{ID: req.ID})
+ }
+ it, err := backend.Chain().Snapshots().AccountIterator(req.Root, req.Origin)
+ if err != nil {
+ return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{ID: req.ID})
+ }
+ // Iterate over the requested range and pile accounts up
+ var (
+ accounts []*AccountData
+ size uint64
+ last common.Hash
+ )
+ for it.Next() && size < req.Bytes {
+ hash, account := it.Hash(), common.CopyBytes(it.Account())
+
+ // Track the returned interval for the Merkle proofs
+ last = hash
+
+ // Assemble the reply item
+ size += uint64(common.HashLength + len(account))
+ accounts = append(accounts, &AccountData{
+ Hash: hash,
+ Body: account,
+ })
+ // If we've exceeded the request threshold, abort
+ if bytes.Compare(hash[:], req.Limit[:]) >= 0 {
+ break
+ }
+ }
+ it.Release()
+
+ // Generate the Merkle proofs for the first and last account
+ proof := light.NewNodeSet()
+ if err := tr.Prove(req.Origin[:], 0, proof); err != nil {
+ log.Warn("Failed to prove account range", "origin", req.Origin, "err", err)
+ return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{ID: req.ID})
+ }
+ if last != (common.Hash{}) {
+ if err := tr.Prove(last[:], 0, proof); err != nil {
+ log.Warn("Failed to prove account range", "last", last, "err", err)
+ return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{ID: req.ID})
+ }
+ }
+ var proofs [][]byte
+ for _, blob := range proof.NodeList() {
+ proofs = append(proofs, blob)
+ }
+ // Send back anything accumulated
+ return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{
+ ID: req.ID,
+ Accounts: accounts,
+ Proof: proofs,
+ })
+
+ case msg.Code == AccountRangeMsg:
+ // A range of accounts arrived to one of our previous requests
+ res := new(AccountRangePacket)
+ if err := msg.Decode(res); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ // Ensure the range is monotonically increasing
+ for i := 1; i < len(res.Accounts); i++ {
+ if bytes.Compare(res.Accounts[i-1].Hash[:], res.Accounts[i].Hash[:]) >= 0 {
+ return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, res.Accounts[i-1].Hash[:], i, res.Accounts[i].Hash[:])
+ }
+ }
+ return backend.Handle(peer, res)
+
+ case msg.Code == GetStorageRangesMsg:
+ // Decode the storage retrieval request
+ var req GetStorageRangesPacket
+ if err := msg.Decode(&req); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ if req.Bytes > softResponseLimit {
+ req.Bytes = softResponseLimit
+ }
+ // TODO(karalabe): Do we want to enforce > 0 accounts and 1 account if origin is set?
+ // TODO(karalabe): - Logging locally is not ideal as remote faults annoy the local user
+ // TODO(karalabe): - Dropping the remote peer is less flexible wrt client bugs (slow is better than non-functional)
+
+ // Calculate the hard limit at which to abort, even if mid storage trie
+ hardLimit := uint64(float64(req.Bytes) * (1 + stateLookupSlack))
+
+ // Retrieve storage ranges until the packet limit is reached
+ var (
+ slots [][]*StorageData
+ proofs [][]byte
+ size uint64
+ )
+ for _, account := range req.Accounts {
+ // If we've exceeded the requested data limit, abort without opening
+ // a new storage range (that we'd need to prove due to exceeded size)
+ if size >= req.Bytes {
+ break
+ }
+ // The first account might start from a different origin and end sooner
+ var origin common.Hash
+ if len(req.Origin) > 0 {
+ origin, req.Origin = common.BytesToHash(req.Origin), nil
+ }
+ var limit = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
+ if len(req.Limit) > 0 {
+ limit, req.Limit = common.BytesToHash(req.Limit), nil
+ }
+ // Retrieve the requested state and bail out if non existent
+ it, err := backend.Chain().Snapshots().StorageIterator(req.Root, account, origin)
+ if err != nil {
+ return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
+ }
+ // Iterate over the requested range and pile slots up
+ var (
+ storage []*StorageData
+ last common.Hash
+ )
+ for it.Next() && size < hardLimit {
+ hash, slot := it.Hash(), common.CopyBytes(it.Slot())
+
+ // Track the returned interval for the Merkle proofs
+ last = hash
+
+ // Assemble the reply item
+ size += uint64(common.HashLength + len(slot))
+ storage = append(storage, &StorageData{
+ Hash: hash,
+ Body: slot,
+ })
+ // If we've exceeded the request threshold, abort
+ if bytes.Compare(hash[:], limit[:]) >= 0 {
+ break
+ }
+ }
+ slots = append(slots, storage)
+ it.Release()
+
+ // Generate the Merkle proofs for the first and last storage slot, but
+ // only if the response was capped. If the entire storage trie is
+ // included in the response, there is no need for any proofs.
+ if origin != (common.Hash{}) || size >= hardLimit {
+ // Request started at a non-zero hash or was capped prematurely, add
+ // the endpoint Merkle proofs
+ accTrie, err := trie.New(req.Root, backend.Chain().StateCache().TrieDB())
+ if err != nil {
+ return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
+ }
+ var acc state.Account
+ if err := rlp.DecodeBytes(accTrie.Get(account[:]), &acc); err != nil {
+ return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
+ }
+ stTrie, err := trie.New(acc.Root, backend.Chain().StateCache().TrieDB())
+ if err != nil {
+ return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
+ }
+ proof := light.NewNodeSet()
+ if err := stTrie.Prove(origin[:], 0, proof); err != nil {
+ log.Warn("Failed to prove storage range", "origin", req.Origin, "err", err)
+ return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
+ }
+ if last != (common.Hash{}) {
+ if err := stTrie.Prove(last[:], 0, proof); err != nil {
+ log.Warn("Failed to prove storage range", "last", last, "err", err)
+ return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
+ }
+ }
+ for _, blob := range proof.NodeList() {
+ proofs = append(proofs, blob)
+ }
+ // A proof terminates the reply, as proofs are only added when a node
+ // refuses to serve more data (the exception being a contract fetch
+ // that is just finishing, but that's fine).
+ break
+ }
+ }
+ // Send back anything accumulated
+ return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{
+ ID: req.ID,
+ Slots: slots,
+ Proof: proofs,
+ })
+
+ case msg.Code == StorageRangesMsg:
+ // A range of storage slots arrived in response to one of our previous requests
+ res := new(StorageRangesPacket)
+ if err := msg.Decode(res); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ // Ensure the ranges are monotonically increasing
+ for i, slots := range res.Slots {
+ for j := 1; j < len(slots); j++ {
+ if bytes.Compare(slots[j-1].Hash[:], slots[j].Hash[:]) >= 0 {
+ return fmt.Errorf("storage slots not monotonically increasing for account #%d: #%d [%x] vs #%d [%x]", i, j-1, slots[j-1].Hash[:], j, slots[j].Hash[:])
+ }
+ }
+ }
+ return backend.Handle(peer, res)
+
+ case msg.Code == GetByteCodesMsg:
+ // Decode bytecode retrieval request
+ var req GetByteCodesPacket
+ if err := msg.Decode(&req); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ if req.Bytes > softResponseLimit {
+ req.Bytes = softResponseLimit
+ }
+ if len(req.Hashes) > maxCodeLookups {
+ req.Hashes = req.Hashes[:maxCodeLookups]
+ }
+ // Retrieve bytecodes until the packet size limit is reached
+ var (
+ codes [][]byte
+ bytes uint64
+ )
+ for _, hash := range req.Hashes {
+ if hash == emptyCode {
+ // Peers should not request the empty code, but if they do, at
+ // least send them back a correct response without db lookups
+ codes = append(codes, []byte{})
+ } else if blob, err := backend.Chain().ContractCode(hash); err == nil {
+ codes = append(codes, blob)
+ bytes += uint64(len(blob))
+ }
+ if bytes > req.Bytes {
+ break
+ }
+ }
+ // Send back anything accumulated
+ return p2p.Send(peer.rw, ByteCodesMsg, &ByteCodesPacket{
+ ID: req.ID,
+ Codes: codes,
+ })
+
+ case msg.Code == ByteCodesMsg:
+ // A batch of byte codes arrived in response to one of our previous requests
+ res := new(ByteCodesPacket)
+ if err := msg.Decode(res); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ return backend.Handle(peer, res)
+
+ case msg.Code == GetTrieNodesMsg:
+ // Decode trie node retrieval request
+ var req GetTrieNodesPacket
+ if err := msg.Decode(&req); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ if req.Bytes > softResponseLimit {
+ req.Bytes = softResponseLimit
+ }
+ // Make sure we have the state associated with the request
+ triedb := backend.Chain().StateCache().TrieDB()
+
+ accTrie, err := trie.NewSecure(req.Root, triedb)
+ if err != nil {
+ // We don't have the requested state available, bail out
+ return p2p.Send(peer.rw, TrieNodesMsg, &TrieNodesPacket{ID: req.ID})
+ }
+ snap := backend.Chain().Snapshots().Snapshot(req.Root)
+ if snap == nil {
+ // We don't have the requested state snapshotted yet, bail out.
+ // In reality we could still serve using the account and storage
+ // tries only, but let's protect the node a bit while it's doing
+ // snapshot generation.
+ return p2p.Send(peer.rw, TrieNodesMsg, &TrieNodesPacket{ID: req.ID})
+ }
+ // Retrieve trie nodes until the packet size limit is reached
+ var (
+ nodes [][]byte
+ bytes uint64
+ loads int // Trie hash expansions to count database reads
+ )
+ for _, pathset := range req.Paths {
+ switch len(pathset) {
+ case 0:
+ // Ensure we penalize invalid requests
+ return fmt.Errorf("%w: zero-item pathset requested", errBadRequest)
+
+ case 1:
+ // If we're only retrieving an account trie node, fetch it directly
+ blob, resolved, err := accTrie.TryGetNode(pathset[0])
+ loads += resolved // always account database reads, even for failures
+ if err != nil {
+ break
+ }
+ nodes = append(nodes, blob)
+ bytes += uint64(len(blob))
+
+ default:
+ // Storage slots requested, open the storage trie and retrieve from there
+ account, err := snap.Account(common.BytesToHash(pathset[0]))
+ loads++ // always account database reads, even for failures
+ if err != nil {
+ break
+ }
+ stTrie, err := trie.NewSecure(common.BytesToHash(account.Root), triedb)
+ loads++ // always account database reads, even for failures
+ if err != nil {
+ break
+ }
+ for _, path := range pathset[1:] {
+ blob, resolved, err := stTrie.TryGetNode(path)
+ loads += resolved // always account database reads, even for failures
+ if err != nil {
+ break
+ }
+ nodes = append(nodes, blob)
+ bytes += uint64(len(blob))
+
+ // Sanity check limits to avoid DoS on the storage trie loads
+ if bytes > req.Bytes || loads > maxTrieNodeLookups {
+ break
+ }
+ }
+ }
+ // Abort request processing if we've exceeded our limits
+ if bytes > req.Bytes || loads > maxTrieNodeLookups {
+ break
+ }
+ }
+ // Send back anything accumulated
+ return p2p.Send(peer.rw, TrieNodesMsg, &TrieNodesPacket{
+ ID: req.ID,
+ Nodes: nodes,
+ })
+
+ case msg.Code == TrieNodesMsg:
+ // A batch of trie nodes arrived in response to one of our previous requests
+ res := new(TrieNodesPacket)
+ if err := msg.Decode(res); err != nil {
+ return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+ }
+ return backend.Handle(peer, res)
+
+ default:
+ return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code)
+ }
+}
+
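For reference, the `emptyCode` special case in the bytecode handler above compares against a well-known constant: the Keccak256 hash of empty input. A minimal standalone sketch (using golang.org/x/crypto, not part of the diff itself):

```go
package main

import (
	"fmt"

	"golang.org/x/crypto/sha3"
)

func main() {
	// Hash of the empty bytecode: requesting it is pointless, so the
	// server answers it without any database lookups.
	hasher := sha3.NewLegacyKeccak256()
	hasher.Write(nil)
	fmt.Printf("emptyCode = %x\n", hasher.Sum(nil))
	// emptyCode = c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470
}
```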
+// NodeInfo represents a short summary of the `snap` sub-protocol metadata
+// known about the host peer.
+type NodeInfo struct{}
+
+// nodeInfo retrieves some `snap` protocol metadata about the running host node.
+func nodeInfo(chain *core.BlockChain) *NodeInfo {
+ return &NodeInfo{}
+}
diff --git a/eth/protocols/snap/peer.go b/eth/protocols/snap/peer.go
new file mode 100644
index 000000000..73eaaadd0
--- /dev/null
+++ b/eth/protocols/snap/peer.go
@@ -0,0 +1,111 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/p2p"
+)
+
+// Peer is a collection of relevant information we have about a `snap` peer.
+type Peer struct {
+ id string // Unique ID for the peer, cached
+
+ *p2p.Peer // The embedded P2P package peer
+ rw p2p.MsgReadWriter // Input/output streams for snap
+ version uint // Protocol version negotiated
+
+ logger log.Logger // Contextual logger with the peer id injected
+}
+
+// newPeer creates a wrapper for a network connection and negotiated protocol
+// version.
+func newPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) *Peer {
+ id := p.ID().String()
+ return &Peer{
+ id: id,
+ Peer: p,
+ rw: rw,
+ version: version,
+ logger: log.New("peer", id[:8]),
+ }
+}
+
+// ID retrieves the peer's unique identifier.
+func (p *Peer) ID() string {
+ return p.id
+}
+
+// Version retrieves the peer's negotiated `snap` protocol version.
+func (p *Peer) Version() uint {
+ return p.version
+}
+
+// RequestAccountRange fetches a batch of accounts rooted in a specific account
+// trie, starting with the origin.
+func (p *Peer) RequestAccountRange(id uint64, root common.Hash, origin, limit common.Hash, bytes uint64) error {
+ p.logger.Trace("Fetching range of accounts", "reqid", id, "root", root, "origin", origin, "limit", limit, "bytes", common.StorageSize(bytes))
+ return p2p.Send(p.rw, GetAccountRangeMsg, &GetAccountRangePacket{
+ ID: id,
+ Root: root,
+ Origin: origin,
+ Limit: limit,
+ Bytes: bytes,
+ })
+}
+
+// RequestStorageRanges fetches a batch of storage slots belonging to one or more
+// accounts. If slots from only one account are requested, an origin marker may also
+// be used to retrieve from there.
+func (p *Peer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error {
+ if len(accounts) == 1 && origin != nil {
+ p.logger.Trace("Fetching range of large storage slots", "reqid", id, "root", root, "account", accounts[0], "origin", common.BytesToHash(origin), "limit", common.BytesToHash(limit), "bytes", common.StorageSize(bytes))
+ } else {
+ p.logger.Trace("Fetching ranges of small storage slots", "reqid", id, "root", root, "accounts", len(accounts), "first", accounts[0], "bytes", common.StorageSize(bytes))
+ }
+ return p2p.Send(p.rw, GetStorageRangesMsg, &GetStorageRangesPacket{
+ ID: id,
+ Root: root,
+ Accounts: accounts,
+ Origin: origin,
+ Limit: limit,
+ Bytes: bytes,
+ })
+}
+
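To make the two request shapes concrete, here is a hypothetical usage sketch; the helper name `requestStorage`, the fixed request IDs, and the argument values are illustrative only, and `maxRequestSize` is assumed from the package's sync.go:

```go
// requestStorage illustrates the two query shapes accepted by
// RequestStorageRanges: a batch of small accounts with no slot window, or
// a single large account constrained to the [origin, limit] slot range.
func requestStorage(p *Peer, root, huge common.Hash, small []common.Hash, origin, limit common.Hash) error {
	// Batch mode: several (small) accounts, origin/limit left nil.
	if err := p.RequestStorageRanges(1, root, small, nil, nil, maxRequestSize); err != nil {
		return err
	}
	// Large-contract mode: exactly one account plus a slot window.
	return p.RequestStorageRanges(2, root, []common.Hash{huge}, origin[:], limit[:], maxRequestSize)
}
```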
+// RequestByteCodes fetches a batch of bytecodes by hash.
+func (p *Peer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error {
+ p.logger.Trace("Fetching set of byte codes", "reqid", id, "hashes", len(hashes), "bytes", common.StorageSize(bytes))
+ return p2p.Send(p.rw, GetByteCodesMsg, &GetByteCodesPacket{
+ ID: id,
+ Hashes: hashes,
+ Bytes: bytes,
+ })
+}
+
+// RequestTrieNodes fetches a batch of account or storage trie nodes rooted in
+// a specific state trie.
+func (p *Peer) RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error {
+ p.logger.Trace("Fetching set of trie nodes", "reqid", id, "root", root, "pathsets", len(paths), "bytes", common.StorageSize(bytes))
+ return p2p.Send(p.rw, GetTrieNodesMsg, &GetTrieNodesPacket{
+ ID: id,
+ Root: root,
+ Paths: paths,
+ Bytes: bytes,
+ })
+}
diff --git a/eth/protocols/snap/protocol.go b/eth/protocols/snap/protocol.go
new file mode 100644
index 000000000..a1e434969
--- /dev/null
+++ b/eth/protocols/snap/protocol.go
@@ -0,0 +1,218 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/state/snapshot"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+// Constants to match up protocol versions and messages
+const (
+ snap1 = 1
+)
+
+// protocolName is the official short name of the `snap` protocol used during
+// devp2p capability negotiation.
+const protocolName = "snap"
+
+// protocolVersions are the supported versions of the `snap` protocol (first
+// is primary).
+var protocolVersions = []uint{snap1}
+
+// protocolLengths are the number of implemented messages corresponding to
+// different protocol versions.
+var protocolLengths = map[uint]uint64{snap1: 8}
+
+// maxMessageSize is the maximum cap on the size of a protocol message.
+const maxMessageSize = 10 * 1024 * 1024
+
+const (
+ GetAccountRangeMsg = 0x00
+ AccountRangeMsg = 0x01
+ GetStorageRangesMsg = 0x02
+ StorageRangesMsg = 0x03
+ GetByteCodesMsg = 0x04
+ ByteCodesMsg = 0x05
+ GetTrieNodesMsg = 0x06
+ TrieNodesMsg = 0x07
+)
+
+var (
+ errMsgTooLarge = errors.New("message too long")
+ errDecode = errors.New("invalid message")
+ errInvalidMsgCode = errors.New("invalid message code")
+ errBadRequest = errors.New("bad request")
+)
+
+// Packet represents a p2p message in the `snap` protocol.
+type Packet interface {
+ Name() string // Name returns a string corresponding to the message type.
+ Kind() byte // Kind returns the message type.
+}
+
+// GetAccountRangePacket represents an account query.
+type GetAccountRangePacket struct {
+ ID uint64 // Request ID to match up responses with
+ Root common.Hash // Root hash of the account trie to serve
+ Origin common.Hash // Hash of the first account to retrieve
+ Limit common.Hash // Hash of the last account to retrieve
+ Bytes uint64 // Soft limit at which to stop returning data
+}
+
+// AccountRangePacket represents an account query response.
+type AccountRangePacket struct {
+ ID uint64 // ID of the request this is a response for
+ Accounts []*AccountData // List of consecutive accounts from the trie
+ Proof [][]byte // List of trie nodes proving the account range
+}
+
+// AccountData represents a single account in a query response.
+type AccountData struct {
+ Hash common.Hash // Hash of the account
+ Body rlp.RawValue // Account body in slim format
+}
+
+// Unpack retrieves the accounts from the range packet and converts from slim
+// wire representation to consensus format. The returned data is RLP encoded
+// since it's expected to be serialized to disk without further interpretation.
+//
+// Note, this method does a round of RLP decoding and reencoding, so only use it
+// once and cache the results if need be. Ideally discard the packet afterwards
+// to not double the memory use.
+func (p *AccountRangePacket) Unpack() ([]common.Hash, [][]byte, error) {
+ var (
+ hashes = make([]common.Hash, len(p.Accounts))
+ accounts = make([][]byte, len(p.Accounts))
+ )
+ for i, acc := range p.Accounts {
+ val, err := snapshot.FullAccountRLP(acc.Body)
+ if err != nil {
+ return nil, nil, fmt.Errorf("invalid account %x: %v", acc.Body, err)
+ }
+ hashes[i], accounts[i] = acc.Hash, val
+ }
+ return hashes, accounts, nil
+}
+
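As a usage sketch for `Unpack` (the helper `persistAccounts` and its `store` callback are hypothetical, assuming the same package context): decode once, consume the parallel slices, then discard the packet:

```go
// persistAccounts demonstrates the intended use of Unpack: convert the slim
// bodies to consensus RLP exactly once, then hand the parallel hash/body
// slices to storage and drop the packet to avoid keeping both encodings.
func persistAccounts(res *AccountRangePacket, store func(common.Hash, []byte)) error {
	hashes, accounts, err := res.Unpack()
	if err != nil {
		return err // malformed slim encoding; callers should drop the peer
	}
	for i := range hashes {
		store(hashes[i], accounts[i]) // consensus-RLP body keyed by account hash
	}
	return nil
}
```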
+// GetStorageRangesPacket represents a storage slot query.
+type GetStorageRangesPacket struct {
+ ID uint64 // Request ID to match up responses with
+ Root common.Hash // Root hash of the account trie to serve
+ Accounts []common.Hash // Account hashes of the storage tries to serve
+ Origin []byte // Hash of the first storage slot to retrieve (large contract mode)
+ Limit []byte // Hash of the last storage slot to retrieve (large contract mode)
+ Bytes uint64 // Soft limit at which to stop returning data
+}
+
+// StorageRangesPacket represents a storage slot query response.
+type StorageRangesPacket struct {
+ ID uint64 // ID of the request this is a response for
+ Slots [][]*StorageData // Lists of consecutive storage slots for the requested accounts
+ Proof [][]byte // Merkle proofs for the *last* slot range, if it's incomplete
+}
+
+// StorageData represents a single storage slot in a query response.
+type StorageData struct {
+ Hash common.Hash // Hash of the storage slot
+ Body []byte // Data content of the slot
+}
+
+// Unpack retrieves the storage slots from the range packet and returns them in
+// a split flat format that's more consistent with the internal data structures.
+func (p *StorageRangesPacket) Unpack() ([][]common.Hash, [][][]byte) {
+ var (
+ hashset = make([][]common.Hash, len(p.Slots))
+ slotset = make([][][]byte, len(p.Slots))
+ )
+ for i, slots := range p.Slots {
+ hashset[i] = make([]common.Hash, len(slots))
+ slotset[i] = make([][]byte, len(slots))
+ for j, slot := range slots {
+ hashset[i][j] = slot.Hash
+ slotset[i][j] = slot.Body
+ }
+ }
+ return hashset, slotset
+}
+
+// GetByteCodesPacket represents a contract bytecode query.
+type GetByteCodesPacket struct {
+ ID uint64 // Request ID to match up responses with
+ Hashes []common.Hash // Code hashes to retrieve the code for
+ Bytes uint64 // Soft limit at which to stop returning data
+}
+
+// ByteCodesPacket represents a contract bytecode query response.
+type ByteCodesPacket struct {
+ ID uint64 // ID of the request this is a response for
+ Codes [][]byte // Requested contract bytecodes
+}
+
+// GetTrieNodesPacket represents a state trie node query.
+type GetTrieNodesPacket struct {
+ ID uint64 // Request ID to match up responses with
+ Root common.Hash // Root hash of the account trie to serve
+ Paths []TrieNodePathSet // Trie node hashes to retrieve the nodes for
+ Bytes uint64 // Soft limit at which to stop returning data
+}
+
+// TrieNodePathSet is a list of trie node paths to retrieve. A naive way to
+// represent trie nodes would be a simple list of `account || storage` path
+// segments concatenated, but that would be very wasteful on the network.
+//
+// Instead, this array special cases the first element as the path in the
+// account trie and the remaining elements as paths in the storage trie. To
+// address an account node, the slice should have a length of 1 consisting
+// of only the account path. There's no need to be able to address both an
+// account node and a storage node in the same request as it cannot happen
+// that a slot is accessed before the account path is fully expanded.
+type TrieNodePathSet [][]byte
+
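A standalone sketch of the two shapes this encoding takes (the path bytes are hypothetical compact-encoded trie paths, not real values):

```go
package main

import "fmt"

// TrieNodePathSet mirrors the wire type above for illustration.
type TrieNodePathSet [][]byte

func main() {
	// One element: addresses a node in the account trie.
	accountNode := TrieNodePathSet{{0x13, 0x37}}

	// Several elements: the first is the account's path, the rest address
	// nodes inside that account's storage trie.
	storageNodes := TrieNodePathSet{
		{0x13, 0x37}, // account path
		{0x0a},       // storage node path #1
		{0x0b, 0x0c}, // storage node path #2
	}
	fmt.Println(len(accountNode), len(storageNodes)) // 1 3
}
```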
+// TrieNodesPacket represents a state trie node query response.
+type TrieNodesPacket struct {
+ ID uint64 // ID of the request this is a response for
+ Nodes [][]byte // Requested state trie nodes
+}
+
+func (*GetAccountRangePacket) Name() string { return "GetAccountRange" }
+func (*GetAccountRangePacket) Kind() byte { return GetAccountRangeMsg }
+
+func (*AccountRangePacket) Name() string { return "AccountRange" }
+func (*AccountRangePacket) Kind() byte { return AccountRangeMsg }
+
+func (*GetStorageRangesPacket) Name() string { return "GetStorageRanges" }
+func (*GetStorageRangesPacket) Kind() byte { return GetStorageRangesMsg }
+
+func (*StorageRangesPacket) Name() string { return "StorageRanges" }
+func (*StorageRangesPacket) Kind() byte { return StorageRangesMsg }
+
+func (*GetByteCodesPacket) Name() string { return "GetByteCodes" }
+func (*GetByteCodesPacket) Kind() byte { return GetByteCodesMsg }
+
+func (*ByteCodesPacket) Name() string { return "ByteCodes" }
+func (*ByteCodesPacket) Kind() byte { return ByteCodesMsg }
+
+func (*GetTrieNodesPacket) Name() string { return "GetTrieNodes" }
+func (*GetTrieNodesPacket) Kind() byte { return GetTrieNodesMsg }
+
+func (*TrieNodesPacket) Name() string { return "TrieNodes" }
+func (*TrieNodesPacket) Kind() byte { return TrieNodesMsg }
diff --git a/eth/protocols/snap/sync.go b/eth/protocols/snap/sync.go
new file mode 100644
index 000000000..679b32828
--- /dev/null
+++ b/eth/protocols/snap/sync.go
@@ -0,0 +1,2481 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "math/big"
+ "math/rand"
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/state"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/light"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/rlp"
+ "github.com/ethereum/go-ethereum/trie"
+ "golang.org/x/crypto/sha3"
+)
+
+var (
+ // emptyRoot is the known root hash of an empty trie.
+ emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
+
+ // emptyCode is the known hash of the empty EVM bytecode.
+ emptyCode = crypto.Keccak256Hash(nil)
+)
+
+const (
+ // maxRequestSize is the maximum number of bytes to request from a remote peer.
+ maxRequestSize = 512 * 1024
+
+ // maxStorageSetRequestCount is the maximum number of contracts to request the
+ // storage of in a single query. If this number is too low, we're not filling
+ // responses fully and waste round trip times. If it's too high, we're capping
+ // responses and waste bandwidth.
+ maxStorageSetRequestCount = maxRequestSize / 1024
+
+ // maxCodeRequestCount is the maximum number of bytecode blobs to request in a
+ // single query. If this number is too low, we're not filling responses fully
+ // and waste round trip times. If it's too high, we're capping responses and
+ // waste bandwidth.
+ //
+ // Deployed bytecodes are currently capped at 24KB, so the minimum number
+ // of requested codes should be maxRequestSize / 24K. Assuming that most
+ // contracts do not come close to that, requesting 4x should be a good
+ // approximation.
+ maxCodeRequestCount = maxRequestSize / (24 * 1024) * 4
+
+ // maxTrieRequestCount is the maximum number of trie node blobs to request in
+ // a single query. If this number is too low, we're not filling responses fully
+ // and waste round trip times. If it's too high, we're capping responses and
+ // waste bandwidth.
+ maxTrieRequestCount = 512
+
+ // requestTimeout is the maximum time a peer is allowed to spend on serving
+ // a single network request.
+ requestTimeout = 10 * time.Second // TODO(karalabe): Make it dynamic ala fast-sync?
+
+ // accountConcurrency is the number of chunks to split the account trie into
+ // to allow concurrent retrievals.
+ accountConcurrency = 16
+
+ // storageConcurrency is the number of chunks to split a large contract's
+ // storage trie into, to allow concurrent retrievals.
+ storageConcurrency = 16
+)
+
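A quick sanity check of the per-query budgets these constants imply, as a standalone sketch; the printed numbers follow directly from the definitions above:

```go
package main

import "fmt"

func main() {
	const maxRequestSize = 512 * 1024 // bytes per query

	fmt.Println(maxRequestSize / 1024)            // 512 storage sets per query
	fmt.Println(maxRequestSize / (24 * 1024) * 4) // 84 bytecodes per query
}
```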
+// accountRequest tracks a pending account range request to ensure responses are
+// to actual requests and to validate any security constraints.
+//
+// Concurrency note: account requests and responses are handled concurrently from
+// the main runloop to allow Merkle proof verifications on the peer's thread and
+// to drop on invalid response. The request struct must contain all the data to
+// construct the response without accessing runloop internals (i.e. task). That
+// is only included to allow the runloop to match a response to the task being
+// synced without having yet another set of maps.
+type accountRequest struct {
+ peer string // Peer to which this request is assigned
+ id uint64 // Request ID of this request
+
+ cancel chan struct{} // Channel to track sync cancellation
+ timeout *time.Timer // Timer to track delivery timeout
+ stale chan struct{} // Channel to signal the request was dropped
+
+ origin common.Hash // First account requested to allow continuation checks
+ limit common.Hash // Last account requested to allow non-overlapping chunking
+
+ task *accountTask // Task which this request is filling (only access fields through the runloop!!)
+}
+
+// accountResponse is an already Merkle-verified remote response to an account
+// range request. It contains the subtrie for the requested account range and
+// the database that's going to be filled with the internal nodes on commit.
+type accountResponse struct {
+ task *accountTask // Task which this request is filling
+
+ hashes []common.Hash // Account hashes in the returned range
+ accounts []*state.Account // Expanded accounts in the returned range
+
+ nodes ethdb.KeyValueStore // Database containing the reconstructed trie nodes
+ trie *trie.Trie // Reconstructed trie to reject incomplete account paths
+
+ bounds map[common.Hash]struct{} // Boundary nodes to avoid persisting incomplete accounts
+ overflow *light.NodeSet // Overflow nodes to avoid persisting across chunk boundaries
+
+ cont bool // Whether the account range has a continuation
+}
+
+// bytecodeRequest tracks a pending bytecode request to ensure responses are to
+// actual requests and to validate any security constraints.
+//
+// Concurrency note: bytecode requests and responses are handled concurrently from
+// the main runloop to allow Keccak256 hash verifications on the peer's thread and
+// to drop on invalid response. The request struct must contain all the data to
+// construct the response without accessing runloop internals (i.e. task). That
+// is only included to allow the runloop to match a response to the task being
+// synced without having yet another set of maps.
+type bytecodeRequest struct {
+ peer string // Peer to which this request is assigned
+ id uint64 // Request ID of this request
+
+ cancel chan struct{} // Channel to track sync cancellation
+ timeout *time.Timer // Timer to track delivery timeout
+ stale chan struct{} // Channel to signal the request was dropped
+
+ hashes []common.Hash // Bytecode hashes to validate responses
+ task *accountTask // Task which this request is filling (only access fields through the runloop!!)
+}
+
+// bytecodeResponse is an already verified remote response to a bytecode request.
+type bytecodeResponse struct {
+ task *accountTask // Task which this request is filling
+
+ hashes []common.Hash // Hashes of the bytecode to avoid double hashing
+ codes [][]byte // Actual bytecodes to store into the database (nil = missing)
+}
+
+// storageRequest tracks a pending storage ranges request to ensure responses are
+// to actual requests and to validate any security constraints.
+//
+// Concurrency note: storage requests and responses are handled concurrently from
+// the main runloop to allow Merkle proof verifications on the peer's thread and
+// to drop on invalid response. The request struct must contain all the data to
+// construct the response without accessing runloop internals (i.e. tasks). That
+// is only included to allow the runloop to match a response to the task being
+// synced without having yet another set of maps.
+type storageRequest struct {
+ peer string // Peer to which this request is assigned
+ id uint64 // Request ID of this request
+
+ cancel chan struct{} // Channel to track sync cancellation
+ timeout *time.Timer // Timer to track delivery timeout
+ stale chan struct{} // Channel to signal the request was dropped
+
+ accounts []common.Hash // Account hashes to validate responses
+ roots []common.Hash // Storage roots to validate responses
+
+ origin common.Hash // First storage slot requested to allow continuation checks
+ limit common.Hash // Last storage slot requested to allow non-overlapping chunking
+
+ mainTask *accountTask // Task which this response belongs to (only access fields through the runloop!!)
+ subTask *storageTask // Task which this response is filling (only access fields through the runloop!!)
+}
+
+// storageResponse is an already Merkle-verified remote response to a storage
+// range request. It contains the subtries for the requested storage ranges and
+// the databases that are going to be filled with the internal nodes on commit.
+type storageResponse struct {
+ mainTask *accountTask // Task which this response belongs to
+ subTask *storageTask // Task which this response is filling
+
+ accounts []common.Hash // Account hashes requested, may be only partially filled
+ roots []common.Hash // Storage roots requested, may be only partially filled
+
+ hashes [][]common.Hash // Storage slot hashes in the returned range
+ slots [][][]byte // Storage slot values in the returned range
+ nodes []ethdb.KeyValueStore // Databases containing the reconstructed trie nodes
+ tries []*trie.Trie // Reconstructed tries to reject overflowed slots
+
+ // Fields relevant for the last account only
+ bounds map[common.Hash]struct{} // Boundary nodes to avoid persisting (incomplete)
+ overflow *light.NodeSet // Overflow nodes to avoid persisting across chunk boundaries
+ cont bool // Whether the last storage range has a continuation
+}
+
+// trienodeHealRequest tracks a pending state trie request to ensure responses
+// are to actual requests and to validate any security constraints.
+//
+// Concurrency note: trie node requests and responses are handled concurrently from
+// the main runloop to allow Keccak256 hash verifications on the peer's thread and
+// to drop on invalid response. The request struct must contain all the data to
+// construct the response without accessing runloop internals (i.e. task). That
+// is only included to allow the runloop to match a response to the task being
+// synced without having yet another set of maps.
+type trienodeHealRequest struct {
+ peer string // Peer to which this request is assigned
+ id uint64 // Request ID of this request
+
+ cancel chan struct{} // Channel to track sync cancellation
+ timeout *time.Timer // Timer to track delivery timeout
+ stale chan struct{} // Channel to signal the request was dropped
+
+ hashes []common.Hash // Trie node hashes to validate responses
+ paths []trie.SyncPath // Trie node paths requested for rescheduling
+
+ task *healTask // Task which this request is filling (only access fields through the runloop!!)
+}
+
+// trienodeHealResponse is an already verified remote response to a trie node request.
+type trienodeHealResponse struct {
+ task *healTask // Task which this request is filling
+
+ hashes []common.Hash // Hashes of the trie nodes to avoid double hashing
+ paths []trie.SyncPath // Trie node paths requested for rescheduling missing ones
+ nodes [][]byte // Actual trie nodes to store into the database (nil = missing)
+}
+
+// bytecodeHealRequest tracks a pending bytecode request to ensure responses are to
+// actual requests and to validate any security constraints.
+//
+// Concurrency note: bytecode requests and responses are handled concurrently from
+// the main runloop to allow Keccak256 hash verifications on the peer's thread and
+// to drop on invalid response. The request struct must contain all the data to
+// construct the response without accessing runloop internals (i.e. task). That
+// is only included to allow the runloop to match a response to the task being
+// synced without having yet another set of maps.
+type bytecodeHealRequest struct {
+ peer string // Peer to which this request is assigned
+ id uint64 // Request ID of this request
+
+ cancel chan struct{} // Channel to track sync cancellation
+ timeout *time.Timer // Timer to track delivery timeout
+ stale chan struct{} // Channel to signal the request was dropped
+
+ hashes []common.Hash // Bytecode hashes to validate responses
+ task *healTask // Task which this request is filling (only access fields through the runloop!!)
+}
+
+// bytecodeHealResponse is an already verified remote response to a bytecode request.
+type bytecodeHealResponse struct {
+ task *healTask // Task which this request is filling
+
+ hashes []common.Hash // Hashes of the bytecode to avoid double hashing
+ codes [][]byte // Actual bytecodes to store into the database (nil = missing)
+}
+
+// accountTask represents the sync task for a chunk of the account snapshot.
+type accountTask struct {
+ // These fields get serialized to leveldb on shutdown
+ Next common.Hash // Next account to sync in this interval
+ Last common.Hash // Last account to sync in this interval
+ SubTasks map[common.Hash][]*storageTask // Storage intervals needing fetching for large contracts
+
+ // These fields are internals used during runtime
+ req *accountRequest // Pending request to fill this task
+ res *accountResponse // Validate response filling this task
+ pend int // Number of pending subtasks for this round
+
+ needCode []bool // Flags whether the filling accounts need code retrieval
+ needState []bool // Flags whether the filling accounts need storage retrieval
+ needHeal []bool // Flags whether the filling account's state was chunked and needs healing
+
+ codeTasks map[common.Hash]struct{} // Code hashes that need retrieval
+ stateTasks map[common.Hash]common.Hash // Account hashes->roots that need full state retrieval
+
+ done bool // Flag whether the task can be removed
+}
+
+// storageTask represents the sync task for a chunk of the storage snapshot.
+type storageTask struct {
+ Next common.Hash // Next account to sync in this interval
+ Last common.Hash // Last account to sync in this interval
+
+ // These fields are internals used during runtime
+ root common.Hash // Storage root hash for this instance
+ req *storageRequest // Pending request to fill this task
+ done bool // Flag whether the task can be removed
+}
+
+// healTask represents the sync task for healing the snap-synced chunk boundaries.
+type healTask struct {
+ scheduler *trie.Sync // State trie sync scheduler defining the tasks
+
+ trieTasks map[common.Hash]trie.SyncPath // Set of trie node tasks currently queued for retrieval
+ codeTasks map[common.Hash]struct{} // Set of byte code tasks currently queued for retrieval
+}
+
+// syncProgress is a database entry to allow suspending and resuming a snapshot
+// state sync. As opposed to full and fast sync, there is no way to restart a
+// suspended snap sync without prior knowledge of the suspension point.
+type syncProgress struct {
+ Tasks []*accountTask // The suspended account tasks (contract tasks within)
+
+ // Status report during syncing phase
+ AccountSynced uint64 // Number of accounts downloaded
+ AccountBytes common.StorageSize // Number of account trie bytes persisted to disk
+ BytecodeSynced uint64 // Number of bytecodes downloaded
+ BytecodeBytes common.StorageSize // Number of bytecode bytes downloaded
+ StorageSynced uint64 // Number of storage slots downloaded
+ StorageBytes common.StorageSize // Number of storage trie bytes persisted to disk
+
+ // Status report during healing phase
+ TrienodeHealSynced uint64 // Number of state trie nodes downloaded
+ TrienodeHealBytes common.StorageSize // Number of state trie bytes persisted to disk
+ TrienodeHealDups uint64 // Number of state trie nodes already processed
+ TrienodeHealNops uint64 // Number of state trie nodes not requested
+ BytecodeHealSynced uint64 // Number of bytecodes downloaded
+ BytecodeHealBytes common.StorageSize // Number of bytecodes persisted to disk
+ BytecodeHealDups uint64 // Number of bytecodes already processed
+ BytecodeHealNops uint64 // Number of bytecodes not requested
+}
+
+// Syncer is an Ethereum account and storage trie syncer based on snapshots and
+// the snap protocol. It's purpose is to download all the accounts and storage
+// slots from remote peers and reassemble chunks of the state trie, on top of
+// which a state sync can be run to fix any gaps / overlaps.
+//
+// Every network request has a variety of failure events:
+// - The peer disconnects after task assignment, failing to send the request
+// - The peer disconnects after sending the request, before delivering on it
+// - The peer remains connected, but does not deliver a response in time
+// - The peer delivers a stale response after a previous timeout
+// - The peer delivers a refusal to serve the requested state
+type Syncer struct {
+ db ethdb.KeyValueStore // Database to store the trie nodes into (and dedup)
+ bloom *trie.SyncBloom // Bloom filter to deduplicate nodes for state fixup
+
+ root common.Hash // Current state trie root being synced
+ tasks []*accountTask // Current account task set being synced
+ healer *healTask // Current state healing task being executed
+ update chan struct{} // Notification channel for possible sync progression
+
+ peers map[string]*Peer // Currently active peers to download from
+ peerJoin *event.Feed // Event feed to react to peers joining
+ peerDrop *event.Feed // Event feed to react to peers dropping
+
+ // Request tracking during syncing phase
+ statelessPeers map[string]struct{} // Peers that failed to deliver state data
+ accountIdlers map[string]struct{} // Peers that aren't serving account requests
+ bytecodeIdlers map[string]struct{} // Peers that aren't serving bytecode requests
+ storageIdlers map[string]struct{} // Peers that aren't serving storage requests
+
+ accountReqs map[uint64]*accountRequest // Account requests currently running
+ bytecodeReqs map[uint64]*bytecodeRequest // Bytecode requests currently running
+ storageReqs map[uint64]*storageRequest // Storage requests currently running
+
+ accountReqFails chan *accountRequest // Failed account range requests to revert
+ bytecodeReqFails chan *bytecodeRequest // Failed bytecode requests to revert
+ storageReqFails chan *storageRequest // Failed storage requests to revert
+
+ accountResps chan *accountResponse // Account sub-tries to integrate into the database
+ bytecodeResps chan *bytecodeResponse // Bytecodes to integrate into the database
+ storageResps chan *storageResponse // Storage sub-tries to integrate into the database
+
+ accountSynced uint64 // Number of accounts downloaded
+ accountBytes common.StorageSize // Number of account trie bytes persisted to disk
+ bytecodeSynced uint64 // Number of bytecodes downloaded
+ bytecodeBytes common.StorageSize // Number of bytecode bytes downloaded
+ storageSynced uint64 // Number of storage slots downloaded
+ storageBytes common.StorageSize // Number of storage trie bytes persisted to disk
+
+ // Request tracking during healing phase
+ trienodeHealIdlers map[string]struct{} // Peers that aren't serving trie node requests
+ bytecodeHealIdlers map[string]struct{} // Peers that aren't serving bytecode requests
+
+ trienodeHealReqs map[uint64]*trienodeHealRequest // Trie node requests currently running
+ bytecodeHealReqs map[uint64]*bytecodeHealRequest // Bytecode requests currently running
+
+ trienodeHealReqFails chan *trienodeHealRequest // Failed trienode requests to revert
+ bytecodeHealReqFails chan *bytecodeHealRequest // Failed bytecode requests to revert
+
+ trienodeHealResps chan *trienodeHealResponse // Trie nodes to integrate into the database
+ bytecodeHealResps chan *bytecodeHealResponse // Bytecodes to integrate into the database
+
+ trienodeHealSynced uint64 // Number of state trie nodes downloaded
+ trienodeHealBytes common.StorageSize // Number of state trie bytes persisted to disk
+ trienodeHealDups uint64 // Number of state trie nodes already processed
+ trienodeHealNops uint64 // Number of state trie nodes not requested
+ bytecodeHealSynced uint64 // Number of bytecodes downloaded
+ bytecodeHealBytes common.StorageSize // Number of bytecodes persisted to disk
+ bytecodeHealDups uint64 // Number of bytecodes already processed
+ bytecodeHealNops uint64 // Number of bytecodes not requested
+
+ startTime time.Time // Time instance when snapshot sync started
+ startAcc common.Hash // Account hash where sync started from
+ logTime time.Time // Time instance when status was last reported
+
+ pend sync.WaitGroup // Tracks network request goroutines for graceful shutdown
+ lock sync.RWMutex // Protects fields that can change outside of sync (peers, reqs, root)
+}
+
+func NewSyncer(db ethdb.KeyValueStore, bloom *trie.SyncBloom) *Syncer {
+ return &Syncer{
+ db: db,
+ bloom: bloom,
+
+ peers: make(map[string]*Peer),
+ peerJoin: new(event.Feed),
+ peerDrop: new(event.Feed),
+ update: make(chan struct{}, 1),
+
+ accountIdlers: make(map[string]struct{}),
+ storageIdlers: make(map[string]struct{}),
+ bytecodeIdlers: make(map[string]struct{}),
+
+ accountReqs: make(map[uint64]*accountRequest),
+ storageReqs: make(map[uint64]*storageRequest),
+ bytecodeReqs: make(map[uint64]*bytecodeRequest),
+ accountReqFails: make(chan *accountRequest),
+ storageReqFails: make(chan *storageRequest),
+ bytecodeReqFails: make(chan *bytecodeRequest),
+ accountResps: make(chan *accountResponse),
+ storageResps: make(chan *storageResponse),
+ bytecodeResps: make(chan *bytecodeResponse),
+
+ trienodeHealIdlers: make(map[string]struct{}),
+ bytecodeHealIdlers: make(map[string]struct{}),
+
+ trienodeHealReqs: make(map[uint64]*trienodeHealRequest),
+ bytecodeHealReqs: make(map[uint64]*bytecodeHealRequest),
+ trienodeHealReqFails: make(chan *trienodeHealRequest),
+ bytecodeHealReqFails: make(chan *bytecodeHealRequest),
+ trienodeHealResps: make(chan *trienodeHealResponse),
+ bytecodeHealResps: make(chan *bytecodeHealResponse),
+ }
+}
+
+// Register injects a new data source into the syncer's peerset.
+func (s *Syncer) Register(peer *Peer) error {
+ // Make sure the peer is not registered yet
+ s.lock.Lock()
+ if _, ok := s.peers[peer.id]; ok {
+ log.Error("Snap peer already registered", "id", peer.id)
+
+ s.lock.Unlock()
+ return errors.New("already registered")
+ }
+ s.peers[peer.id] = peer
+
+ // Mark the peer as idle, even if no sync is running
+ s.accountIdlers[peer.id] = struct{}{}
+ s.storageIdlers[peer.id] = struct{}{}
+ s.bytecodeIdlers[peer.id] = struct{}{}
+ s.trienodeHealIdlers[peer.id] = struct{}{}
+ s.bytecodeHealIdlers[peer.id] = struct{}{}
+ s.lock.Unlock()
+
+ // Notify any active syncs that a new peer can be assigned data
+ s.peerJoin.Send(peer.id)
+ return nil
+}
+
+// Unregister severs a data source from the syncer's peerset.
+func (s *Syncer) Unregister(id string) error {
+ // Remove all traces of the peer from the registry
+ s.lock.Lock()
+ if _, ok := s.peers[id]; !ok {
+ log.Error("Snap peer not registered", "id", id)
+
+ s.lock.Unlock()
+ return errors.New("not registered")
+ }
+ delete(s.peers, id)
+
+ // Remove status markers, even if no sync is running
+ delete(s.statelessPeers, id)
+
+ delete(s.accountIdlers, id)
+ delete(s.storageIdlers, id)
+ delete(s.bytecodeIdlers, id)
+ delete(s.trienodeHealIdlers, id)
+ delete(s.bytecodeHealIdlers, id)
+ s.lock.Unlock()
+
+ // Notify any active syncs that pending requests need to be reverted
+ s.peerDrop.Send(id)
+ return nil
+}
+
+// Sync starts (or resumes a previously suspended) sync cycle to iterate over a
+// state trie with the given root and reconstruct the nodes based on the snapshot
+// leaves. Previously downloaded segments will not be redownloaded or fixed;
+// rather, any errors will be healed after the leaves are fully accumulated.
+func (s *Syncer) Sync(root common.Hash, cancel chan struct{}) error {
+ // Move the trie root from any previous value, revert stateless markers for
+ // any peers and initialize the syncer if it was not yet run
+ s.lock.Lock()
+ s.root = root
+ s.healer = &healTask{
+ scheduler: state.NewStateSync(root, s.db, s.bloom),
+ trieTasks: make(map[common.Hash]trie.SyncPath),
+ codeTasks: make(map[common.Hash]struct{}),
+ }
+ s.statelessPeers = make(map[string]struct{})
+ s.lock.Unlock()
+
+ if s.startTime == (time.Time{}) {
+ s.startTime = time.Now()
+ }
+ // Retrieve the previous sync status from LevelDB and abort if already synced
+ s.loadSyncStatus()
+ if len(s.tasks) == 0 && s.healer.scheduler.Pending() == 0 {
+ log.Debug("Snapshot sync already completed")
+ return nil
+ }
+ defer func() { // Persist any progress, independent of failure
+ for _, task := range s.tasks {
+ s.forwardAccountTask(task)
+ }
+ s.cleanAccountTasks()
+ s.saveSyncStatus()
+ }()
+
+ log.Debug("Starting snapshot sync cycle", "root", root)
+ defer s.report(true)
+
+ // Whether sync completed or not, disregard any future packets
+ defer func() {
+ log.Debug("Terminating snapshot sync cycle", "root", root)
+ s.lock.Lock()
+ s.accountReqs = make(map[uint64]*accountRequest)
+ s.storageReqs = make(map[uint64]*storageRequest)
+ s.bytecodeReqs = make(map[uint64]*bytecodeRequest)
+ s.trienodeHealReqs = make(map[uint64]*trienodeHealRequest)
+ s.bytecodeHealReqs = make(map[uint64]*bytecodeHealRequest)
+ s.lock.Unlock()
+ }()
+ // Keep scheduling sync tasks
+ peerJoin := make(chan string, 16)
+ peerJoinSub := s.peerJoin.Subscribe(peerJoin)
+ defer peerJoinSub.Unsubscribe()
+
+ peerDrop := make(chan string, 16)
+ peerDropSub := s.peerDrop.Subscribe(peerDrop)
+ defer peerDropSub.Unsubscribe()
+
+ for {
+ // Remove all completed tasks and terminate sync if everything's done
+ s.cleanStorageTasks()
+ s.cleanAccountTasks()
+ if len(s.tasks) == 0 && s.healer.scheduler.Pending() == 0 {
+ return nil
+ }
+ // Assign all the data retrieval tasks to any free peers
+ s.assignAccountTasks(cancel)
+ s.assignBytecodeTasks(cancel)
+ s.assignStorageTasks(cancel)
+ if len(s.tasks) == 0 {
+ // Sync phase done, run heal phase
+ s.assignTrienodeHealTasks(cancel)
+ s.assignBytecodeHealTasks(cancel)
+ }
+ // Wait for something to happen
+ select {
+ case <-s.update:
+ // Something happened (new peer, delivery, timeout), recheck tasks
+ case <-peerJoin:
+ // A new peer joined, try to schedule it new tasks
+ case id := <-peerDrop:
+ s.revertRequests(id)
+ case <-cancel:
+ return nil
+
+ case req := <-s.accountReqFails:
+ s.revertAccountRequest(req)
+ case req := <-s.bytecodeReqFails:
+ s.revertBytecodeRequest(req)
+ case req := <-s.storageReqFails:
+ s.revertStorageRequest(req)
+ case req := <-s.trienodeHealReqFails:
+ s.revertTrienodeHealRequest(req)
+ case req := <-s.bytecodeHealReqFails:
+ s.revertBytecodeHealRequest(req)
+
+ case res := <-s.accountResps:
+ s.processAccountResponse(res)
+ case res := <-s.bytecodeResps:
+ s.processBytecodeResponse(res)
+ case res := <-s.storageResps:
+ s.processStorageResponse(res)
+ case res := <-s.trienodeHealResps:
+ s.processTrienodeHealResponse(res)
+ case res := <-s.bytecodeHealResps:
+ s.processBytecodeHealResponse(res)
+ }
+ // Report stats if something meaningful happened
+ s.report(false)
+ }
+}
+
+// loadSyncStatus retrieves a previously aborted sync status from the database,
+// or generates a fresh one if none is available.
+func (s *Syncer) loadSyncStatus() {
+ var progress syncProgress
+
+ if status := rawdb.ReadSnapshotSyncStatus(s.db); status != nil {
+ if err := json.Unmarshal(status, &progress); err != nil {
+ log.Error("Failed to decode snap sync status", "err", err)
+ } else {
+ for _, task := range progress.Tasks {
+ log.Debug("Scheduled account sync task", "from", task.Next, "last", task.Last)
+ }
+ s.tasks = progress.Tasks
+
+ s.accountSynced = progress.AccountSynced
+ s.accountBytes = progress.AccountBytes
+ s.bytecodeSynced = progress.BytecodeSynced
+ s.bytecodeBytes = progress.BytecodeBytes
+ s.storageSynced = progress.StorageSynced
+ s.storageBytes = progress.StorageBytes
+
+ s.trienodeHealSynced = progress.TrienodeHealSynced
+ s.trienodeHealBytes = progress.TrienodeHealBytes
+ s.bytecodeHealSynced = progress.BytecodeHealSynced
+ s.bytecodeHealBytes = progress.BytecodeHealBytes
+ return
+ }
+ }
+ // Either we've failed to decode the previous state, or there was none.
+ // Start a fresh sync by chunking up the account range and scheduling
+ // them for retrieval.
+ s.tasks = nil
+ s.accountSynced, s.accountBytes = 0, 0
+ s.bytecodeSynced, s.bytecodeBytes = 0, 0
+ s.storageSynced, s.storageBytes = 0, 0
+ s.trienodeHealSynced, s.trienodeHealBytes = 0, 0
+ s.bytecodeHealSynced, s.bytecodeHealBytes = 0, 0
+
+ var next common.Hash
+ step := new(big.Int).Sub(
+ new(big.Int).Div(
+ new(big.Int).Exp(common.Big2, common.Big256, nil),
+ big.NewInt(accountConcurrency),
+ ), common.Big1,
+ )
+ for i := 0; i < accountConcurrency; i++ {
+ last := common.BigToHash(new(big.Int).Add(next.Big(), step))
+ if i == accountConcurrency-1 {
+ // Make sure we don't overflow if the step is not a proper divisor
+ last = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
+ }
+ s.tasks = append(s.tasks, &accountTask{
+ Next: next,
+ Last: last,
+ SubTasks: make(map[common.Hash][]*storageTask),
+ })
+ log.Debug("Created account sync task", "from", next, "last", last)
+ next = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1))
+ }
+}
+
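For intuition, a standalone sketch of the chunking arithmetic used above: with `accountConcurrency = 16`, the 2^256 hash space splits into equal adjacent ranges (the real code additionally clamps the final boundary to 0xff..ff in case the step is not a proper divisor):

```go
package main

import (
	"fmt"
	"math/big"
)

func main() {
	const accountConcurrency = 16

	one := big.NewInt(1)
	space := new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil)
	step := new(big.Int).Sub(new(big.Int).Div(space, big.NewInt(accountConcurrency)), one)

	next := new(big.Int)
	for i := 0; i < accountConcurrency; i++ {
		// Each task covers [next, next+step]; the next task starts 1 above.
		last := new(big.Int).Add(next, step)
		fmt.Printf("task %2d: 0x%064x..0x%064x\n", i, next, last)
		next = new(big.Int).Add(last, one)
	}
}
```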
+// saveSyncStatus marshals the remaining sync tasks into leveldb.
+func (s *Syncer) saveSyncStatus() {
+ progress := &syncProgress{
+ Tasks: s.tasks,
+ AccountSynced: s.accountSynced,
+ AccountBytes: s.accountBytes,
+ BytecodeSynced: s.bytecodeSynced,
+ BytecodeBytes: s.bytecodeBytes,
+ StorageSynced: s.storageSynced,
+ StorageBytes: s.storageBytes,
+ TrienodeHealSynced: s.trienodeHealSynced,
+ TrienodeHealBytes: s.trienodeHealBytes,
+ BytecodeHealSynced: s.bytecodeHealSynced,
+ BytecodeHealBytes: s.bytecodeHealBytes,
+ }
+ status, err := json.Marshal(progress)
+ if err != nil {
+ panic(err) // This can only fail during implementation
+ }
+ rawdb.WriteSnapshotSyncStatus(s.db, status)
+}
+
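Since the progress record is plain JSON, resuming is the exact inverse of saving. A tiny standalone round-trip sketch, showing only two of the fields for brevity:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// progress mirrors a small slice of the syncProgress fields above.
type progress struct {
	AccountSynced  uint64
	BytecodeSynced uint64
}

func main() {
	blob, err := json.Marshal(progress{AccountSynced: 42, BytecodeSynced: 7})
	if err != nil {
		panic(err) // mirrors saveSyncStatus: can only fail during implementation
	}
	var restored progress
	if err := json.Unmarshal(blob, &restored); err != nil {
		panic(err)
	}
	fmt.Printf("%s -> %+v\n", blob, restored)
}
```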
+// cleanAccountTasks removes account range retrieval tasks that have already been
+// completed.
+func (s *Syncer) cleanAccountTasks() {
+ for i := 0; i < len(s.tasks); i++ {
+ if s.tasks[i].done {
+ s.tasks = append(s.tasks[:i], s.tasks[i+1:]...)
+ i--
+ }
+ }
+}
+
+// cleanStorageTasks iterates over all the account tasks and storage sub-tasks
+// within, cleaning any that have been completed.
+func (s *Syncer) cleanStorageTasks() {
+ for _, task := range s.tasks {
+ for account, subtasks := range task.SubTasks {
+ // Remove storage range retrieval tasks that completed
+ for j := 0; j < len(subtasks); j++ {
+ if subtasks[j].done {
+ subtasks = append(subtasks[:j], subtasks[j+1:]...)
+ j--
+ }
+ }
+ if len(subtasks) > 0 {
+ task.SubTasks[account] = subtasks
+ continue
+ }
+ // If all storage chunks are done, mark the account as done too
+ for j, hash := range task.res.hashes {
+ if hash == account {
+ task.needState[j] = false
+ }
+ }
+ delete(task.SubTasks, account)
+ task.pend--
+
+ // If this was the last pending task, forward the account task
+ if task.pend == 0 {
+ s.forwardAccountTask(task)
+ }
+ }
+ }
+}
+
+// assignAccountTasks attempts to match idle peers to pending account range
+// retrievals.
+func (s *Syncer) assignAccountTasks(cancel chan struct{}) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ // If there are no idle peers, short circuit assignment
+ if len(s.accountIdlers) == 0 {
+ return
+ }
+ // Iterate over all the tasks and try to find a pending one
+ for _, task := range s.tasks {
+ // Skip any tasks already filling
+ if task.req != nil || task.res != nil {
+ continue
+ }
+ // Task pending retrieval, try to find an idle peer. If no such peer
+ // exists, we probably assigned tasks for all (or they are stateless).
+ // Abort the entire assignment mechanism.
+ var idle string
+ for id := range s.accountIdlers {
+ // If the peer rejected a query in this sync cycle, don't bother asking
+ // again for anything, it's either out of sync or already pruned
+ if _, ok := s.statelessPeers[id]; ok {
+ continue
+ }
+ idle = id
+ break
+ }
+ if idle == "" {
+ return
+ }
+ // Matched a pending task to an idle peer, allocate a unique request id
+ var reqid uint64
+ for {
+ reqid = uint64(rand.Int63())
+ if reqid == 0 {
+ continue
+ }
+ if _, ok := s.accountReqs[reqid]; ok {
+ continue
+ }
+ break
+ }
+ // Generate the network query and send it to the peer
+ req := &accountRequest{
+ peer: idle,
+ id: reqid,
+ cancel: cancel,
+ stale: make(chan struct{}),
+ origin: task.Next,
+ limit: task.Last,
+ task: task,
+ }
+ req.timeout = time.AfterFunc(requestTimeout, func() {
+ log.Debug("Account range request timed out")
+ select {
+ case s.accountReqFails <- req:
+ default:
+ }
+ })
+ s.accountReqs[reqid] = req
+ delete(s.accountIdlers, idle)
+
+ s.pend.Add(1)
+ go func(peer *Peer, root common.Hash) {
+ defer s.pend.Done()
+
+ // Attempt to send the remote request and revert if it fails
+ if err := peer.RequestAccountRange(reqid, root, req.origin, req.limit, maxRequestSize); err != nil {
+ peer.Log().Debug("Failed to request account range", "err", err)
+ select {
+ case s.accountReqFails <- req:
+ default:
+ }
+ }
+ // Request successfully sent
+ }(s.peers[idle], s.root) // We're in the lock, peers[id] surely exists
+
+ // Inject the request into the task to block further assignments
+ task.req = req
+ }
+}
+
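The delivery plumbing above repeats for every request kind: arm a timer before dispatch, and report failure with a non-blocking send so that a timeout and a concurrent success can race without wedging the runloop. A distilled standalone sketch of that pattern:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	fails := make(chan uint64, 1)

	// Arm the timeout before sending the (omitted) network request.
	reqid := uint64(1)
	timeout := time.AfterFunc(100*time.Millisecond, func() {
		select {
		case fails <- reqid: // schedule the request for reverting
		default: // runloop already handled a failure, don't block
		}
	})
	defer timeout.Stop()

	fmt.Println("reverting request", <-fails) // fires once the timer expires
}
```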
+// assignBytecodeTasks attempts to match idle peers to pending code retrievals.
+func (s *Syncer) assignBytecodeTasks(cancel chan struct{}) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ // If there are no idle peers, short circuit assignment
+ if len(s.bytecodeIdlers) == 0 {
+ return
+ }
+ // Iterate over all the tasks and try to find a pending one
+ for _, task := range s.tasks {
+ // Skip any tasks not in the bytecode retrieval phase
+ if task.res == nil {
+ continue
+ }
+ // Skip tasks that are already retrieving (or done with) all codes
+ if len(task.codeTasks) == 0 {
+ continue
+ }
+ // Task pending retrieval, try to find an idle peer. If no such peer
+ // exists, we probably assigned tasks for all (or they are stateless).
+ // Abort the entire assignment mechanism.
+ var idle string
+ for id := range s.bytecodeIdlers {
+ // If the peer rejected a query in this sync cycle, don't bother asking
+ // again for anything, it's either out of sync or already pruned
+ if _, ok := s.statelessPeers[id]; ok {
+ continue
+ }
+ idle = id
+ break
+ }
+ if idle == "" {
+ return
+ }
+ // Matched a pending task to an idle peer, allocate a unique request id
+ var reqid uint64
+ for {
+ reqid = uint64(rand.Int63())
+ if reqid == 0 {
+ continue
+ }
+ if _, ok := s.bytecodeReqs[reqid]; ok {
+ continue
+ }
+ break
+ }
+ // Generate the network query and send it to the peer
+ hashes := make([]common.Hash, 0, maxCodeRequestCount)
+ for hash := range task.codeTasks {
+ delete(task.codeTasks, hash)
+ hashes = append(hashes, hash)
+ if len(hashes) >= maxCodeRequestCount {
+ break
+ }
+ }
+ req := &bytecodeRequest{
+ peer: idle,
+ id: reqid,
+ cancel: cancel,
+ stale: make(chan struct{}),
+ hashes: hashes,
+ task: task,
+ }
+ req.timeout = time.AfterFunc(requestTimeout, func() {
+ log.Debug("Bytecode request timed out")
+ select {
+ case s.bytecodeReqFails <- req:
+ default:
+ }
+ })
+ s.bytecodeReqs[reqid] = req
+ delete(s.bytecodeIdlers, idle)
+
+ s.pend.Add(1)
+ go func(peer *Peer) {
+ defer s.pend.Done()
+
+ // Attempt to send the remote request and revert if it fails
+ if err := peer.RequestByteCodes(reqid, hashes, maxRequestSize); err != nil {
+ log.Debug("Failed to request bytecodes", "err", err)
+ select {
+ case s.bytecodeReqFails <- req:
+ default:
+ }
+ }
+ // Request successfully sent
+ }(s.peers[idle]) // We're in the lock, peers[id] surely exists
+ }
+}
+
+// assignStorageTasks attempts to match idle peers to pending storage range
+// retrievals.
+func (s *Syncer) assignStorageTasks(cancel chan struct{}) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ // If there are no idle peers, short circuit assignment
+ if len(s.storageIdlers) == 0 {
+ return
+ }
+ // Iterate over all the tasks and try to find a pending one
+ for _, task := range s.tasks {
+ // Skip any tasks not in the storage retrieval phase
+ if task.res == nil {
+ continue
+ }
+ // Skip tasks that are already retrieving (or done with) all small states
+ if len(task.SubTasks) == 0 && len(task.stateTasks) == 0 {
+ continue
+ }
+ // Task pending retrieval, try to find an idle peer. If no such peer
+ // exists, we probably assigned tasks for all (or they are stateless).
+ // Abort the entire assignment mechanism.
+ var idle string
+ for id := range s.storageIdlers {
+ // If the peer rejected a query in this sync cycle, don't bother asking
+ // again for anything, it's either out of sync or already pruned
+ if _, ok := s.statelessPeers[id]; ok {
+ continue
+ }
+ idle = id
+ break
+ }
+ if idle == "" {
+ return
+ }
+ // Matched a pending task to an idle peer, allocate a unique request id
+ var reqid uint64
+ for {
+ reqid = uint64(rand.Int63())
+ if reqid == 0 {
+ continue
+ }
+ if _, ok := s.storageReqs[reqid]; ok {
+ continue
+ }
+ break
+ }
+ // Generate the network query and send it to the peer. If there are
+ // large contract tasks pending, complete those before diving into
+ // even more new contracts.
+ var (
+ accounts = make([]common.Hash, 0, maxStorageSetRequestCount)
+ roots = make([]common.Hash, 0, maxStorageSetRequestCount)
+ subtask *storageTask
+ )
+ for account, subtasks := range task.SubTasks {
+ for _, st := range subtasks {
+ // Skip any subtasks already filling
+ if st.req != nil {
+ continue
+ }
+ // Found an incomplete storage chunk, schedule it
+ accounts = append(accounts, account)
+ roots = append(roots, st.root)
+
+ subtask = st
+ break // Large contract chunks are downloaded individually
+ }
+ if subtask != nil {
+ break // Large contract chunks are downloaded individually
+ }
+ }
+ if subtask == nil {
+ // No large contract requires retrieval, but small ones are available
+ for account, root := range task.stateTasks {
+ delete(task.stateTasks, account)
+
+ accounts = append(accounts, account)
+ roots = append(roots, root)
+
+ if len(accounts) >= maxStorageSetRequestCount {
+ break
+ }
+ }
+ }
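+ // As with bytecodes, the small-state tasks were removed optimistically;
+ // revertStorageRequest re-adds them if the request cannot be fulfilled.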
+ // If nothing was found, it means this task is actually already fully
+ // retrieving, but large contracts are hard to detect. Skip to the next.
+ if len(accounts) == 0 {
+ continue
+ }
+ req := &storageRequest{
+ peer: idle,
+ id: reqid,
+ cancel: cancel,
+ stale: make(chan struct{}),
+ accounts: accounts,
+ roots: roots,
+ mainTask: task,
+ subTask: subtask,
+ }
+ if subtask != nil {
+ req.origin = subtask.Next
+ req.limit = subtask.Last
+ }
+ req.timeout = time.AfterFunc(requestTimeout, func() {
+ log.Debug("Storage request timed out")
+ select {
+ case s.storageReqFails <- req:
+ default:
+ }
+ })
+ s.storageReqs[reqid] = req
+ delete(s.storageIdlers, idle)
+
+ s.pend.Add(1)
+ go func(peer *Peer, root common.Hash) {
+ defer s.pend.Done()
+
+ // Attempt to send the remote request and revert if it fails
+ var origin, limit []byte
+ if subtask != nil {
+ origin, limit = req.origin[:], req.limit[:]
+ }
+ if err := peer.RequestStorageRanges(reqid, root, accounts, origin, limit, maxRequestSize); err != nil {
+ log.Debug("Failed to request storage", "err", err)
+ select {
+ case s.storageReqFails <- req:
+ default:
+ }
+ }
+ // Request successfully sent; the response will arrive via OnStorage
+ }(s.peers[idle], s.root) // We're in the lock, peers[id] surely exists
+
+ // Inject the request into the subtask to block further assignments
+ if subtask != nil {
+ subtask.req = req
+ }
+ }
+}
+
+// assignTrienodeHealTasks attempts to match idle peers to trie node requests to
+// heal any trie errors caused by the snap sync's chunked retrieval model.
+func (s *Syncer) assignTrienodeHealTasks(cancel chan struct{}) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ // If there are no idle peers, short circuit assignment
+ if len(s.trienodeHealIdlers) == 0 {
+ return
+ }
+ // Iterate over pending tasks and try to find a peer to retrieve with
+ for len(s.healer.trieTasks) > 0 || s.healer.scheduler.Pending() > 0 {
+ // If there are not enough trie tasks queued to fully assign, fill the
+ // queue from the state sync scheduler. The trie syncer schedules these
+ // together with bytecodes, so we need to queue them combined.
+ var (
+ have = len(s.healer.trieTasks) + len(s.healer.codeTasks)
+ want = maxTrieRequestCount + maxCodeRequestCount
+ )
+ if have < want {
+ nodes, paths, codes := s.healer.scheduler.Missing(want - have)
+ for i, hash := range nodes {
+ s.healer.trieTasks[hash] = paths[i]
+ }
+ for _, hash := range codes {
+ s.healer.codeTasks[hash] = struct{}{}
+ }
+ }
+ // If all the heal tasks are bytecodes or already downloading, bail
+ if len(s.healer.trieTasks) == 0 {
+ return
+ }
+ // Task pending retrieval, try to find an idle peer. If no such peer
+ // exists, we probably assigned tasks for all (or they are stateless).
+ // Abort the entire assignment mechanism.
+ var idle string
+ for id := range s.trienodeHealIdlers {
+ // If the peer rejected a query in this sync cycle, don't bother asking
+ // again for anything, it's either out of sync or already pruned
+ if _, ok := s.statelessPeers[id]; ok {
+ continue
+ }
+ idle = id
+ break
+ }
+ if idle == "" {
+ return
+ }
+ // Matched a pending task to an idle peer, allocate a unique request id
+ var reqid uint64
+ for {
+ reqid = uint64(rand.Int63())
+ if reqid == 0 {
+ continue
+ }
+ if _, ok := s.trienodeHealReqs[reqid]; ok {
+ continue
+ }
+ break
+ }
+ // Generate the network query and send it to the peer
+ var (
+ hashes = make([]common.Hash, 0, maxTrieRequestCount)
+ paths = make([]trie.SyncPath, 0, maxTrieRequestCount)
+ pathsets = make([]TrieNodePathSet, 0, maxTrieRequestCount)
+ )
+ for hash, pathset := range s.healer.trieTasks {
+ delete(s.healer.trieTasks, hash)
+
+ hashes = append(hashes, hash)
+ paths = append(paths, pathset)
+ pathsets = append(pathsets, [][]byte(pathset)) // TODO(karalabe): group requests by account hash
+
+ if len(hashes) >= maxTrieRequestCount {
+ break
+ }
+ }
+ req := &trienodeHealRequest{
+ peer: idle,
+ id: reqid,
+ cancel: cancel,
+ stale: make(chan struct{}),
+ hashes: hashes,
+ paths: paths,
+ task: s.healer,
+ }
+ req.timeout = time.AfterFunc(requestTimeout, func() {
+ log.Debug("Trienode heal request timed out")
+ select {
+ case s.trienodeHealReqFails <- req:
+ default:
+ }
+ })
+ s.trienodeHealReqs[reqid] = req
+ delete(s.trienodeHealIdlers, idle)
+
+ s.pend.Add(1)
+ go func(peer *Peer, root common.Hash) {
+ defer s.pend.Done()
+
+ // Attempt to send the remote request and revert if it fails
+ if err := peer.RequestTrieNodes(reqid, root, pathsets, maxRequestSize); err != nil {
+ log.Debug("Failed to request trienode healers", "err", err)
+ select {
+ case s.trienodeHealReqFails <- req:
+ default:
+ }
+ }
+ // Request successfully sent; the response will arrive via OnTrieNodes
+ }(s.peers[idle], s.root) // We're in the lock, peers[id] surely exists
+ }
+}
+
+// assignBytecodeHealTasks attempts to match idle peers to bytecode requests to
+// heal any trie errors caused by the snap sync's chunked retrieval model.
+func (s *Syncer) assignBytecodeHealTasks(cancel chan struct{}) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ // If there are no idle peers, short circuit assignment
+ if len(s.bytecodeHealIdlers) == 0 {
+ return
+ }
+ // Iterate over pending tasks and try to find a peer to retrieve with
+ for len(s.healer.codeTasks) > 0 || s.healer.scheduler.Pending() > 0 {
+ // If there are not enough trie tasks queued to fully assign, fill the
+ // queue from the state sync scheduler. The trie syncer schedules these
+ // together with trie nodes, so we need to queue them combined.
+ var (
+ have = len(s.healer.trieTasks) + len(s.healer.codeTasks)
+ want = maxTrieRequestCount + maxCodeRequestCount
+ )
+ if have < want {
+ nodes, paths, codes := s.healer.scheduler.Missing(want - have)
+ for i, hash := range nodes {
+ s.healer.trieTasks[hash] = paths[i]
+ }
+ for _, hash := range codes {
+ s.healer.codeTasks[hash] = struct{}{}
+ }
+ }
+ // If all the heal tasks are trienodes or already downloading, bail
+ if len(s.healer.codeTasks) == 0 {
+ return
+ }
+ // Task pending retrieval, try to find an idle peer. If no such peer
+ // exists, we probably assigned tasks for all (or they are stateless).
+ // Abort the entire assignment mechanism.
+ var idle string
+ for id := range s.bytecodeHealIdlers {
+ // If the peer rejected a query in this sync cycle, don't bother asking
+ // again for anything, it's either out of sync or already pruned
+ if _, ok := s.statelessPeers[id]; ok {
+ continue
+ }
+ idle = id
+ break
+ }
+ if idle == "" {
+ return
+ }
+ // Matched a pending task to an idle peer, allocate a unique request id
+ var reqid uint64
+ for {
+ reqid = uint64(rand.Int63())
+ if reqid == 0 {
+ continue
+ }
+ if _, ok := s.bytecodeHealReqs[reqid]; ok {
+ continue
+ }
+ break
+ }
+ // Generate the network query and send it to the peer
+ hashes := make([]common.Hash, 0, maxCodeRequestCount)
+ for hash := range s.healer.codeTasks {
+ delete(s.healer.codeTasks, hash)
+
+ hashes = append(hashes, hash)
+ if len(hashes) >= maxCodeRequestCount {
+ break
+ }
+ }
+ req := &bytecodeHealRequest{
+ peer: idle,
+ id: reqid,
+ cancel: cancel,
+ stale: make(chan struct{}),
+ hashes: hashes,
+ task: s.healer,
+ }
+ req.timeout = time.AfterFunc(requestTimeout, func() {
+ log.Debug("Bytecode heal request timed out")
+ select {
+ case s.bytecodeHealReqFails <- req:
+ default:
+ }
+ })
+ s.bytecodeHealReqs[reqid] = req
+ delete(s.bytecodeHealIdlers, idle)
+
+ s.pend.Add(1)
+ go func(peer *Peer) {
+ defer s.pend.Done()
+
+ // Attempt to send the remote request and revert if it fails
+ if err := peer.RequestByteCodes(reqid, hashes, maxRequestSize); err != nil {
+ log.Debug("Failed to request bytecode healers", "err", err)
+ select {
+ case s.bytecodeHealReqFails <- req:
+ default:
+ }
+ }
+ // Request successfully sent; the response will arrive via OnByteCodes
+ }(s.peers[idle]) // We're in the lock, peers[id] surely exists
+ }
+}
+
+// revertRequests locates all the currently pending requests from a particular
+// peer and reverts them, rescheduling for others to fulfill.
+func (s *Syncer) revertRequests(peer string) {
+ // Gather the requests first, as the reversals need the lock too
+ s.lock.Lock()
+ var accountReqs []*accountRequest
+ for _, req := range s.accountReqs {
+ if req.peer == peer {
+ accountReqs = append(accountReqs, req)
+ }
+ }
+ var bytecodeReqs []*bytecodeRequest
+ for _, req := range s.bytecodeReqs {
+ if req.peer == peer {
+ bytecodeReqs = append(bytecodeReqs, req)
+ }
+ }
+ var storageReqs []*storageRequest
+ for _, req := range s.storageReqs {
+ if req.peer == peer {
+ storageReqs = append(storageReqs, req)
+ }
+ }
+ var trienodeHealReqs []*trienodeHealRequest
+ for _, req := range s.trienodeHealReqs {
+ if req.peer == peer {
+ trienodeHealReqs = append(trienodeHealReqs, req)
+ }
+ }
+ var bytecodeHealReqs []*bytecodeHealRequest
+ for _, req := range s.bytecodeHealReqs {
+ if req.peer == peer {
+ bytecodeHealReqs = append(bytecodeHealReqs, req)
+ }
+ }
+ s.lock.Unlock()
+
+ // Revert all the requests matching the peer
+ for _, req := range accountReqs {
+ s.revertAccountRequest(req)
+ }
+ for _, req := range bytecodeReqs {
+ s.revertBytecodeRequest(req)
+ }
+ for _, req := range storageReqs {
+ s.revertStorageRequest(req)
+ }
+ for _, req := range trienodeHealReqs {
+ s.revertTrienodeHealRequest(req)
+ }
+ for _, req := range bytecodeHealReqs {
+ s.revertBytecodeHealRequest(req)
+ }
+}
+
+// revertAccountRequest cleans up an account range request and returns all failed
+// retrieval tasks to the scheduler for reassignment.
+func (s *Syncer) revertAccountRequest(req *accountRequest) {
+ log.Trace("Reverting account request", "peer", req.peer, "reqid", req.id)
+ select {
+ case <-req.stale:
+ log.Trace("Account request already reverted", "peer", req.peer, "reqid", req.id)
+ return
+ default:
+ }
+ close(req.stale)
+
+ // Remove the request from the tracked set
+ s.lock.Lock()
+ delete(s.accountReqs, req.id)
+ s.lock.Unlock()
+
+ // If there's a timeout timer still running, abort it and mark the account
+ // task as not-pending, ready for rescheduling
+ req.timeout.Stop()
+ if req.task.req == req {
+ req.task.req = nil
+ }
+}
+
+// revertBytecodeRequest cleans up a bytecode request and returns all failed
+// retrieval tasks to the scheduler for reassignment.
+func (s *Syncer) revertBytecodeRequest(req *bytecodeRequest) {
+ log.Trace("Reverting bytecode request", "peer", req.peer)
+ select {
+ case <-req.stale:
+ log.Trace("Bytecode request already reverted", "peer", req.peer, "reqid", req.id)
+ return
+ default:
+ }
+ close(req.stale)
+
+ // Remove the request from the tracked set
+ s.lock.Lock()
+ delete(s.bytecodeReqs, req.id)
+ s.lock.Unlock()
+
+ // If there's a timeout timer still running, abort it and mark the code
+ // retrievals as not-pending, ready for rescheduling
+ req.timeout.Stop()
+ for _, hash := range req.hashes {
+ req.task.codeTasks[hash] = struct{}{}
+ }
+}
+
+// revertStorageRequest cleans up a storage range request and returns all failed
+// retrieval tasks to the scheduler for reassignment.
+func (s *Syncer) revertStorageRequest(req *storageRequest) {
+ log.Trace("Reverting storage request", "peer", req.peer)
+ select {
+ case <-req.stale:
+ log.Trace("Storage request already reverted", "peer", req.peer, "reqid", req.id)
+ return
+ default:
+ }
+ close(req.stale)
+
+ // Remove the request from the tracked set
+ s.lock.Lock()
+ delete(s.storageReqs, req.id)
+ s.lock.Unlock()
+
+ // If there's a timeout timer still running, abort it and mark the storage
+ // task as not-pending, ready for rescheduling
+ req.timeout.Stop()
+ if req.subTask != nil {
+ req.subTask.req = nil
+ } else {
+ for i, account := range req.accounts {
+ req.mainTask.stateTasks[account] = req.roots[i]
+ }
+ }
+}
+
+// revertTrienodeHealRequest cleans up a trienode heal request and returns all
+// failed retrieval tasks to the scheduler for reassignment.
+func (s *Syncer) revertTrienodeHealRequest(req *trienodeHealRequest) {
+ log.Trace("Reverting trienode heal request", "peer", req.peer)
+ select {
+ case <-req.stale:
+ log.Trace("Trienode heal request already reverted", "peer", req.peer, "reqid", req.id)
+ return
+ default:
+ }
+ close(req.stale)
+
+ // Remove the request from the tracked set
+ s.lock.Lock()
+ delete(s.trienodeHealReqs, req.id)
+ s.lock.Unlock()
+
+ // If there's a timeout timer still running, abort it and mark the trie node
+ // retrievals as not-pending, ready for rescheduling
+ req.timeout.Stop()
+ for i, hash := range req.hashes {
+ req.task.trieTasks[hash] = [][]byte(req.paths[i])
+ }
+}
+
+// revertBytecodeHealRequest cleans up a bytecode heal request and returns all failed
+// retrieval tasks to the scheduler for reassignment.
+func (s *Syncer) revertBytecodeHealRequest(req *bytecodeHealRequest) {
+ log.Trace("Reverting bytecode heal request", "peer", req.peer)
+ select {
+ case <-req.stale:
+ log.Trace("Bytecode heal request already reverted", "peer", req.peer, "reqid", req.id)
+ return
+ default:
+ }
+ close(req.stale)
+
+ // Remove the request from the tracked set
+ s.lock.Lock()
+ delete(s.bytecodeHealReqs, req.id)
+ s.lock.Unlock()
+
+ // If there's a timeout timer still running, abort it and mark the code
+ // retrievals as not-pending, ready for rescheduling
+ req.timeout.Stop()
+ for _, hash := range req.hashes {
+ req.task.codeTasks[hash] = struct{}{}
+ }
+}
+
+// processAccountResponse integrates an already validated account range response
+// into the account tasks.
+func (s *Syncer) processAccountResponse(res *accountResponse) {
+ // Switch the task from pending to filling
+ res.task.req = nil
+ res.task.res = res
+
+ // Ensure that the response doesn't overflow into the subsequent task
+ last := res.task.Last.Big()
+ for i, hash := range res.hashes {
+ if hash.Big().Cmp(last) > 0 {
+ // Chunk overflown, cut off excess, but also update the boundary nodes
+ for j := i; j < len(res.hashes); j++ {
+ if err := res.trie.Prove(res.hashes[j][:], 0, res.overflow); err != nil {
+ panic(err) // Account range was already proven, what happened
+ }
+ }
+ res.hashes = res.hashes[:i]
+ res.accounts = res.accounts[:i]
+ res.cont = false // Mark range completed
+ break
+ }
+ }
+ // Iterate over all the accounts and assemble which ones need further sub-
+ // filling before the entire account range can be persisted.
+ res.task.needCode = make([]bool, len(res.accounts))
+ res.task.needState = make([]bool, len(res.accounts))
+ res.task.needHeal = make([]bool, len(res.accounts))
+
+ res.task.codeTasks = make(map[common.Hash]struct{})
+ res.task.stateTasks = make(map[common.Hash]common.Hash)
+
+ resumed := make(map[common.Hash]struct{})
+
+ res.task.pend = 0
+ for i, account := range res.accounts {
+ // Check if the account is a contract with an unknown code
+ if !bytes.Equal(account.CodeHash, emptyCode[:]) {
+ if code := rawdb.ReadCodeWithPrefix(s.db, common.BytesToHash(account.CodeHash)); code == nil {
+ res.task.codeTasks[common.BytesToHash(account.CodeHash)] = struct{}{}
+ res.task.needCode[i] = true
+ res.task.pend++
+ }
+ }
+ // Check if the account is a contract with an unknown storage trie
+ if account.Root != emptyRoot {
+ if node, err := s.db.Get(account.Root[:]); err != nil || node == nil {
+ // If there was a previous large state retrieval in progress,
+ // don't restart it from scratch. This happens if a sync cycle
+ // is interrupted and resumed later. However, *do* update the
+ // previous root hash.
+ if subtasks, ok := res.task.SubTasks[res.hashes[i]]; ok {
+ log.Error("Resuming large storage retrieval", "account", res.hashes[i], "root", account.Root)
+ for _, subtask := range subtasks {
+ subtask.root = account.Root
+ }
+ res.task.needHeal[i] = true
+ resumed[res.hashes[i]] = struct{}{}
+ } else {
+ res.task.stateTasks[res.hashes[i]] = account.Root
+ }
+ res.task.needState[i] = true
+ res.task.pend++
+ }
+ }
+ }
+ // Delete any subtasks that have been aborted but not resumed. This may undo
+ // some progress if a new peer gives us fewer accounts than an old one, but for
+ // now we have to live with that.
+ for hash := range res.task.SubTasks {
+ if _, ok := resumed[hash]; !ok {
+ log.Error("Aborting suspended storage retrieval", "account", hash)
+ delete(res.task.SubTasks, hash)
+ }
+ }
+ // If the account range contained no contracts, or all have been fully filled
+ // beforehand, short circuit storage filling and forward to the next task
+ if res.task.pend == 0 {
+ s.forwardAccountTask(res.task)
+ return
+ }
+ // Some accounts are incomplete, leave as is for the storage and contract
+ // task assigners to pick up and fill.
+}
+
+// processBytecodeResponse integrates an already validated bytecode response
+// into the account tasks.
+func (s *Syncer) processBytecodeResponse(res *bytecodeResponse) {
+ batch := s.db.NewBatch()
+
+ var (
+ codes uint64
+ bytes common.StorageSize
+ )
+ for i, hash := range res.hashes {
+ code := res.codes[i]
+
+ // If the bytecode was not delivered, reschedule it
+ if code == nil {
+ res.task.codeTasks[hash] = struct{}{}
+ continue
+ }
+ // Code was delivered, mark it not needed any more
+ for j, account := range res.task.res.accounts {
+ if res.task.needCode[j] && hash == common.BytesToHash(account.CodeHash) {
+ res.task.needCode[j] = false
+ res.task.pend--
+ }
+ }
+ // Push the bytecode into a database batch
+ s.bytecodeSynced++
+ s.bytecodeBytes += common.StorageSize(len(code))
+
+ codes++
+ bytes += common.StorageSize(len(code))
+
+ rawdb.WriteCode(batch, hash, code)
+ s.bloom.Add(hash[:])
+ }
+ if err := batch.Write(); err != nil {
+ log.Crit("Failed to persist bytecodes", "err", err)
+ }
+ log.Debug("Persisted set of bytecodes", "count", codes, "bytes", bytes)
+
+ // If this delivery completed the last pending task, forward the account task
+ // to the next chunk
+ if res.task.pend == 0 {
+ s.forwardAccountTask(res.task)
+ return
+ }
+ // Some accounts are still incomplete, leave as is for the storage and contract
+ // task assigners to pick up and fill.
+}
+
+// processStorageResponse integrates an already validated storage response
+// into the account tasks.
+func (s *Syncer) processStorageResponse(res *storageResponse) {
+ // Switch the subtask from pending to idle
+ if res.subTask != nil {
+ res.subTask.req = nil
+ }
+ batch := s.db.NewBatch()
+
+ var (
+ slots int
+ nodes int
+ skipped int
+ bytes common.StorageSize
+ )
+ // Iterate over all the accounts and reconstruct their storage tries from the
+ // delivered slots
+ delivered := make(map[common.Hash]bool)
+ for i := 0; i < len(res.hashes); i++ {
+ delivered[res.roots[i]] = true
+ }
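+ // A storage root delivered for one account also satisfies every other
+ // account sharing that root, so those need not be rescheduled below.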
+ for i, account := range res.accounts {
+ // If the account was not delivered, reschedule it
+ if i >= len(res.hashes) {
+ if !delivered[res.roots[i]] {
+ res.mainTask.stateTasks[account] = res.roots[i]
+ }
+ continue
+ }
+ // State was delivered, if complete mark as not needed any more, otherwise
+ // mark the account as needing healing
+ for j, acc := range res.mainTask.res.accounts {
+ if res.roots[i] == acc.Root {
+ // If the packet contains multiple contract storage slots, all
+ // but the last are surely complete. The last contract may be
+ // chunked, so check its continuation flag.
+ if res.subTask == nil && res.mainTask.needState[j] && (i < len(res.hashes)-1 || !res.cont) {
+ res.mainTask.needState[j] = false
+ res.mainTask.pend--
+ }
+ // If the last contract was chunked, mark it as needing healing
+ // to avoid writing it out to disk prematurely.
+ if res.subTask == nil && !res.mainTask.needHeal[j] && i == len(res.hashes)-1 && res.cont {
+ res.mainTask.needHeal[j] = true
+ }
+ // If the last contract was chunked, we need to switch to large
+ // contract handling mode
+ if res.subTask == nil && i == len(res.hashes)-1 && res.cont {
+ // If we haven't yet started a large-contract retrieval, create
+ // the subtasks for it within the main account task
+ if tasks, ok := res.mainTask.SubTasks[account]; !ok {
+ var next common.Hash
+ step := new(big.Int).Sub(
+ new(big.Int).Div(
+ new(big.Int).Exp(common.Big2, common.Big256, nil),
+ big.NewInt(storageConcurrency),
+ ), common.Big1,
+ )
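+ // The step splits the 2^256 hash space into storageConcurrency equally
+ // sized, inclusive chunks (hence the minus one).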
+ for k := 0; k < storageConcurrency; k++ {
+ last := common.BigToHash(new(big.Int).Add(next.Big(), step))
+ if k == storageConcurrency-1 {
+ // Make sure we don't overflow if the step is not a proper divisor
+ last = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
+ }
+ tasks = append(tasks, &storageTask{
+ Next: next,
+ Last: last,
+ root: acc.Root,
+ })
+ log.Debug("Created storage sync task", "account", account, "root", acc.Root, "from", next, "last", last)
+ next = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1))
+ }
+ res.mainTask.SubTasks[account] = tasks
+
+ // Since we've just created the sub-tasks, this response
+ // is surely for the first one (zero origin)
+ res.subTask = tasks[0]
+ }
+ }
+ // If we're in large contract delivery mode, forward the subtask
+ if res.subTask != nil {
+ // Ensure the response doesn't overflow into the subsequent task
+ last := res.subTask.Last.Big()
+ for k, hash := range res.hashes[i] {
+ if hash.Big().Cmp(last) > 0 {
+ // Chunk overflown, cut off excess, but also update the boundary
+ for l := k; l < len(res.hashes[i]); l++ {
+ if err := res.tries[i].Prove(res.hashes[i][l][:], 0, res.overflow); err != nil {
+ panic(err) // Account range was already proven, what happened
+ }
+ }
+ res.hashes[i] = res.hashes[i][:k]
+ res.slots[i] = res.slots[i][:k]
+ res.cont = false // Mark range completed
+ break
+ }
+ }
+ // Forward the relevant storage chunk (even if created just now)
+ if res.cont {
+ res.subTask.Next = common.BigToHash(new(big.Int).Add(res.hashes[i][len(res.hashes[i])-1].Big(), big.NewInt(1)))
+ } else {
+ res.subTask.done = true
+ }
+ }
+ }
+ }
+ // Iterate over all the reconstructed trie nodes and push them to disk
+ slots += len(res.hashes[i])
+
+ it := res.nodes[i].NewIterator(nil, nil)
+ for it.Next() {
+ // Boundary nodes are not written for the last result, since they are incomplete
+ if i == len(res.hashes)-1 {
+ if _, ok := res.bounds[common.BytesToHash(it.Key())]; ok {
+ skipped++
+ continue
+ }
+ }
+ // Node is not a boundary, persist to disk
+ batch.Put(it.Key(), it.Value())
+ s.bloom.Add(it.Key())
+
+ bytes += common.StorageSize(common.HashLength + len(it.Value()))
+ nodes++
+ }
+ it.Release()
+ }
+ if err := batch.Write(); err != nil {
+ log.Crit("Failed to persist storage slots", "err", err)
+ }
+ s.storageSynced += uint64(slots)
+ s.storageBytes += bytes
+
+ log.Debug("Persisted set of storage slots", "accounts", len(res.hashes), "slots", slots, "nodes", nodes, "skipped", skipped, "bytes", bytes)
+
+ // If this delivery completed the last pending task, forward the account task
+ // to the next chunk
+ if res.mainTask.pend == 0 {
+ s.forwardAccountTask(res.mainTask)
+ return
+ }
+ // Some accounts are still incomplete, leave as is for the storage and contract
+ // task assigners to pick up and fill.
+}
+
+// processTrienodeHealResponse integrates an already validated trienode response
+// into the healer tasks.
+func (s *Syncer) processTrienodeHealResponse(res *trienodeHealResponse) {
+ for i, hash := range res.hashes {
+ node := res.nodes[i]
+
+ // If the trie node was not delivered, reschedule it
+ if node == nil {
+ res.task.trieTasks[hash] = res.paths[i]
+ continue
+ }
+ // Push the trie node into the state syncer
+ s.trienodeHealSynced++
+ s.trienodeHealBytes += common.StorageSize(len(node))
+
+ err := s.healer.scheduler.Process(trie.SyncResult{Hash: hash, Data: node})
+ switch err {
+ case nil:
+ case trie.ErrAlreadyProcessed:
+ s.trienodeHealDups++
+ case trie.ErrNotRequested:
+ s.trienodeHealNops++
+ default:
+ log.Error("Invalid trienode processed", "hash", hash, "err", err)
+ }
+ }
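+ // Duplicate and unrequested nodes are tolerated; they are only tallied
+ // into the sync statistics, not treated as protocol violations.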
+ batch := s.db.NewBatch()
+ if err := s.healer.scheduler.Commit(batch); err != nil {
+ log.Error("Failed to commit healing data", "err", err)
+ }
+ if err := batch.Write(); err != nil {
+ log.Crit("Failed to persist healing data", "err", err)
+ }
+ log.Debug("Persisted set of healing data", "bytes", common.StorageSize(batch.ValueSize()))
+}
+
+// processBytecodeHealResponse integrates an already validated bytecode response
+// into the healer tasks.
+func (s *Syncer) processBytecodeHealResponse(res *bytecodeHealResponse) {
+ for i, hash := range res.hashes {
+ node := res.codes[i]
+
+ // If the trie node was not delivered, reschedule it
+ if node == nil {
+ res.task.codeTasks[hash] = struct{}{}
+ continue
+ }
+ // Push the trie node into the state syncer
+ s.bytecodeHealSynced++
+ s.bytecodeHealBytes += common.StorageSize(len(node))
+
+ err := s.healer.scheduler.Process(trie.SyncResult{Hash: hash, Data: node})
+ switch err {
+ case nil:
+ case trie.ErrAlreadyProcessed:
+ s.bytecodeHealDups++
+ case trie.ErrNotRequested:
+ s.bytecodeHealNops++
+ default:
+ log.Error("Invalid bytecode processed", "hash", hash, "err", err)
+ }
+ }
+ batch := s.db.NewBatch()
+ if err := s.healer.scheduler.Commit(batch); err != nil {
+ log.Error("Failed to commit healing data", "err", err)
+ }
+ if err := batch.Write(); err != nil {
+ log.Crit("Failed to persist healing data", "err", err)
+ }
+ log.Debug("Persisted set of healing data", "bytes", common.StorageSize(batch.ValueSize()))
+}
+
+// forwardAccountTask takes a filled account task and persists anything available
+// into the database, after which it forwards the next account marker so that the
+// task's next chunk may be filled.
+func (s *Syncer) forwardAccountTask(task *accountTask) {
+ // Remove any pending delivery
+ res := task.res
+ if res == nil {
+ return // nothing to forward
+ }
+ task.res = nil
+
+ // Iterate over all the accounts and gather all the incomplete trie nodes. A
+ // node is incomplete if we haven't yet filled it (sync was interrupted), or
+ // if we filled it in multiple chunks (storage trie), in which case the few
+ // nodes on the chunk boundaries are missing.
+ incompletes := light.NewNodeSet()
+ for i := range res.accounts {
+ // If the filling was interrupted, mark everything after as incomplete
+ if task.needCode[i] || task.needState[i] {
+ for j := i; j < len(res.accounts); j++ {
+ if err := res.trie.Prove(res.hashes[j][:], 0, incompletes); err != nil {
+ panic(err) // Account range was already proven, what happened
+ }
+ }
+ break
+ }
+ // Filling not interrupted until this point, mark incomplete if needs healing
+ if task.needHeal[i] {
+ if err := res.trie.Prove(res.hashes[i][:], 0, incompletes); err != nil {
+ panic(err) // Account range was already proven, what happened
+ }
+ }
+ }
+ // Persist every finalized trie node that's not on the boundary
+ batch := s.db.NewBatch()
+
+ var (
+ nodes int
+ skipped int
+ bytes common.StorageSize
+ )
+ it := res.nodes.NewIterator(nil, nil)
+ for it.Next() {
+ // Boundary nodes are not written, since they are incomplete
+ if _, ok := res.bounds[common.BytesToHash(it.Key())]; ok {
+ skipped++
+ continue
+ }
+ // Overflow nodes are not written, since they mess with another task
+ if _, err := res.overflow.Get(it.Key()); err == nil {
+ skipped++
+ continue
+ }
+ // Accounts with split storage requests are incomplete
+ if _, err := incompletes.Get(it.Key()); err == nil {
+ skipped++
+ continue
+ }
+ // Node is neither a boundary nor an incomplete account, persist it to disk
+ batch.Put(it.Key(), it.Value())
+ s.bloom.Add(it.Key())
+
+ bytes += common.StorageSize(common.HashLength + len(it.Value()))
+ nodes++
+ }
+ it.Release()
+
+ if err := batch.Write(); err != nil {
+ log.Crit("Failed to persist accounts", "err", err)
+ }
+ s.accountBytes += bytes
+ s.accountSynced += uint64(len(res.accounts))
+
+ log.Debug("Persisted range of accounts", "accounts", len(res.accounts), "nodes", nodes, "skipped", skipped, "bytes", bytes)
+
+ // Task filling persisted, push the chunk marker forward to the first
+ // account still missing data.
+ for i, hash := range res.hashes {
+ if task.needCode[i] || task.needState[i] {
+ return
+ }
+ task.Next = common.BigToHash(new(big.Int).Add(hash.Big(), big.NewInt(1)))
+ }
+ // All accounts marked as complete, track if the entire task is done
+ task.done = !res.cont
+}
+
+// OnAccounts is a callback method to invoke when a range of accounts is
+// received from a remote peer.
+func (s *Syncer) OnAccounts(peer *Peer, id uint64, hashes []common.Hash, accounts [][]byte, proof [][]byte) error {
+ size := common.StorageSize(len(hashes) * common.HashLength)
+ for _, account := range accounts {
+ size += common.StorageSize(len(account))
+ }
+ for _, node := range proof {
+ size += common.StorageSize(len(node))
+ }
+ logger := peer.logger.New("reqid", id)
+ logger.Trace("Delivering range of accounts", "hashes", len(hashes), "accounts", len(accounts), "proofs", len(proof), "bytes", size)
+
+ // Whether or not the response is valid, we can mark the peer as idle and
+ // notify the scheduler to assign a new task. If the response is invalid,
+ // we'll drop the peer in a bit.
+ s.lock.Lock()
+ if _, ok := s.peers[peer.id]; ok {
+ s.accountIdlers[peer.id] = struct{}{}
+ }
+ select {
+ case s.update <- struct{}{}:
+ default:
+ }
+ // Ensure the response is for a valid request
+ req, ok := s.accountReqs[id]
+ if !ok {
+ // Request stale, perhaps the peer timed out but came through in the end
+ logger.Warn("Unexpected account range packet")
+ s.lock.Unlock()
+ return nil
+ }
+ delete(s.accountReqs, id)
+
+ // Clean up the request timeout timer, we'll see how to proceed further based
+ // on the actual delivered content
+ req.timeout.Stop()
+
+ // Response is valid, but check if peer is signalling that it does not have
+ // the requested data. For account range queries that means the state being
+ // retrieved was either already pruned remotely, or the peer is not yet
+ // synced to our head.
+ if len(hashes) == 0 && len(accounts) == 0 && len(proof) == 0 {
+ logger.Debug("Peer rejected account range request", "root", s.root)
+ s.statelessPeers[peer.id] = struct{}{}
+ s.lock.Unlock()
+ return nil
+ }
+ root := s.root
+ s.lock.Unlock()
+
+ // Reconstruct a partial trie from the response and verify it
+ keys := make([][]byte, len(hashes))
+ for i, key := range hashes {
+ keys[i] = common.CopyBytes(key[:])
+ }
+ nodes := make(light.NodeList, len(proof))
+ for i, node := range proof {
+ nodes[i] = node
+ }
+ proofdb := nodes.NodeSet()
+
+ var end []byte
+ if len(keys) > 0 {
+ end = keys[len(keys)-1]
+ }
+ db, tr, notary, cont, err := trie.VerifyRangeProof(root, req.origin[:], end, keys, accounts, proofdb)
+ if err != nil {
+ logger.Warn("Account range failed proof", "err", err)
+ return err
+ }
+ // Partial trie reconstructed, send it to the scheduler for storage filling
+ bounds := make(map[common.Hash]struct{})
+
+ it := notary.Accessed().NewIterator(nil, nil)
+ for it.Next() {
+ bounds[common.BytesToHash(it.Key())] = struct{}{}
+ }
+ it.Release()
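+ // The notary recorded every proof node touched during verification; those
+ // boundary nodes may be incomplete and are skipped when persisting.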
+
+ accs := make([]*state.Account, len(accounts))
+ for i, account := range accounts {
+ acc := new(state.Account)
+ if err := rlp.DecodeBytes(account, acc); err != nil {
+ panic(err) // We created these blobs, we must be able to decode them
+ }
+ accs[i] = acc
+ }
+ response := &accountResponse{
+ task: req.task,
+ hashes: hashes,
+ accounts: accs,
+ nodes: db,
+ trie: tr,
+ bounds: bounds,
+ overflow: light.NewNodeSet(),
+ cont: cont,
+ }
+ select {
+ case s.accountResps <- response:
+ case <-req.cancel:
+ case <-req.stale:
+ }
+ return nil
+}
+
+// OnByteCodes is a callback method to invoke when a batch of contract
+// bytecodes is received from a remote peer.
+func (s *Syncer) OnByteCodes(peer *Peer, id uint64, bytecodes [][]byte) error {
+ s.lock.RLock()
+ syncing := len(s.tasks) > 0
+ s.lock.RUnlock()
+
+ if syncing {
+ return s.onByteCodes(peer, id, bytecodes)
+ }
+ return s.onHealByteCodes(peer, id, bytecodes)
+}
+
+// onByteCodes is a callback method to invoke when a batch of contract
+// bytecodes is received from a remote peer in the syncing phase.
+func (s *Syncer) onByteCodes(peer *Peer, id uint64, bytecodes [][]byte) error {
+ var size common.StorageSize
+ for _, code := range bytecodes {
+ size += common.StorageSize(len(code))
+ }
+ logger := peer.logger.New("reqid", id)
+ logger.Trace("Delivering set of bytecodes", "bytecodes", len(bytecodes), "bytes", size)
+
+ // Whether or not the response is valid, we can mark the peer as idle and
+ // notify the scheduler to assign a new task. If the response is invalid,
+ // we'll drop the peer in a bit.
+ s.lock.Lock()
+ if _, ok := s.peers[peer.id]; ok {
+ s.bytecodeIdlers[peer.id] = struct{}{}
+ }
+ select {
+ case s.update <- struct{}{}:
+ default:
+ }
+ // Ensure the response is for a valid request
+ req, ok := s.bytecodeReqs[id]
+ if !ok {
+ // Request stale, perhaps the peer timed out but came through in the end
+ logger.Warn("Unexpected bytecode packet")
+ s.lock.Unlock()
+ return nil
+ }
+ delete(s.bytecodeReqs, id)
+
+ // Clean up the request timeout timer, we'll see how to proceed further based
+ // on the actual delivered content
+ req.timeout.Stop()
+
+ // Response is valid, but check if peer is signalling that it does not have
+ // the requested data. For bytecode queries that means the peer is not
+ // yet synced.
+ if len(bytecodes) == 0 {
+ logger.Debug("Peer rejected bytecode request")
+ s.statelessPeers[peer.id] = struct{}{}
+ s.lock.Unlock()
+ return nil
+ }
+ s.lock.Unlock()
+
+ // Cross reference the requested bytecodes with the response to find gaps
+ // that the serving node is missing
+ hasher := sha3.NewLegacyKeccak256()
+
+ codes := make([][]byte, len(req.hashes))
+ for i, j := 0, 0; i < len(bytecodes); i++ {
+ // Find the next hash that we've been served, leaving misses with nils
+ hasher.Reset()
+ hasher.Write(bytecodes[i])
+ hash := hasher.Sum(nil)
+
+ for j < len(req.hashes) && !bytes.Equal(hash, req.hashes[j][:]) {
+ j++
+ }
+ if j < len(req.hashes) {
+ codes[j] = bytecodes[i]
+ j++
+ continue
+ }
+ // We've either run out of hashes, or got unrequested data
+ logger.Warn("Unexpected bytecodes", "count", len(bytecodes)-i)
+ return errors.New("unexpected bytecode")
+ }
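+ // Note: the matching above relies on the peer returning bytecodes in the
+ // order they were requested (subsets are fine, reordering is not).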
+ // Response validated, send it to the scheduler for filling
+ response := &bytecodeResponse{
+ task: req.task,
+ hashes: req.hashes,
+ codes: codes,
+ }
+ select {
+ case s.bytecodeResps <- response:
+ case <-req.cancel:
+ case <-req.stale:
+ }
+ return nil
+}
+
+// OnStorage is a callback method to invoke when ranges of storage slots
+// are received from a remote peer.
+func (s *Syncer) OnStorage(peer *Peer, id uint64, hashes [][]common.Hash, slots [][][]byte, proof [][]byte) error {
+ // Gather some trace stats to aid in debugging issues
+ var (
+ hashCount int
+ slotCount int
+ size common.StorageSize
+ )
+ for _, hashset := range hashes {
+ size += common.StorageSize(common.HashLength * len(hashset))
+ hashCount += len(hashset)
+ }
+ for _, slotset := range slots {
+ for _, slot := range slotset {
+ size += common.StorageSize(len(slot))
+ }
+ slotCount += len(slotset)
+ }
+ for _, node := range proof {
+ size += common.StorageSize(len(node))
+ }
+ logger := peer.logger.New("reqid", id)
+ logger.Trace("Delivering ranges of storage slots", "accounts", len(hashes), "hashes", hashCount, "slots", slotCount, "proofs", len(proof), "size", size)
+
+ // Whether or not the response is valid, we can mark the peer as idle and
+ // notify the scheduler to assign a new task. If the response is invalid,
+ // we'll drop the peer in a bit.
+ s.lock.Lock()
+ if _, ok := s.peers[peer.id]; ok {
+ s.storageIdlers[peer.id] = struct{}{}
+ }
+ select {
+ case s.update <- struct{}{}:
+ default:
+ }
+ // Ensure the response is for a valid request
+ req, ok := s.storageReqs[id]
+ if !ok {
+ // Request stale, perhaps the peer timed out but came through in the end
+ logger.Warn("Unexpected storage ranges packet")
+ s.lock.Unlock()
+ return nil
+ }
+ delete(s.storageReqs, id)
+
+ // Clean up the request timeout timer, we'll see how to proceed further based
+ // on the actual delivered content
+ req.timeout.Stop()
+
+ // Reject the response if the hash sets and slot sets don't match, or if the
+ // peer sent more data than requested.
+ if len(hashes) != len(slots) {
+ s.lock.Unlock()
+ logger.Warn("Hash and slot set size mismatch", "hashset", len(hashes), "slotset", len(slots))
+ return errors.New("hash and slot set size mismatch")
+ }
+ if len(hashes) > len(req.accounts) {
+ s.lock.Unlock()
+ logger.Warn("Hash set larger than requested", "hashset", len(hashes), "requested", len(req.accounts))
+ return errors.New("hash set larger than requested")
+ }
+ // Response is valid, but check if peer is signalling that it does not have
+ // the requested data. For storage range queries that means the state being
+ // retrieved was either already pruned remotely, or the peer is not yet
+ // synced to our head.
+ if len(hashes) == 0 {
+ logger.Debug("Peer rejected storage request")
+ s.statelessPeers[peer.id] = struct{}{}
+ s.lock.Unlock()
+ return nil
+ }
+ s.lock.Unlock()
+
+ // Reconstruct the partial tries from the response and verify them
+ var (
+ dbs = make([]ethdb.KeyValueStore, len(hashes))
+ tries = make([]*trie.Trie, len(hashes))
+ notary *trie.KeyValueNotary
+ cont bool
+ )
+ for i := 0; i < len(hashes); i++ {
+ // Convert the keys and proofs into an internal format
+ keys := make([][]byte, len(hashes[i]))
+ for j, key := range hashes[i] {
+ keys[j] = common.CopyBytes(key[:])
+ }
+ nodes := make(light.NodeList, 0, len(proof))
+ if i == len(hashes)-1 {
+ for _, node := range proof {
+ nodes = append(nodes, node)
+ }
+ }
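+ // Only the last account's storage range may be partial, so any attached
+ // proof applies solely to the final hash set.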
+ var err error
+ if len(nodes) == 0 {
+ // No proof has been attached, the response must cover the entire key
+ // space and hash to the origin root.
+ dbs[i], tries[i], _, _, err = trie.VerifyRangeProof(req.roots[i], nil, nil, keys, slots[i], nil)
+ if err != nil {
+ logger.Warn("Storage slots failed proof", "err", err)
+ return err
+ }
+ } else {
+ // A proof was attached, the response is only partial, check that the
+ // returned data is indeed part of the storage trie
+ proofdb := nodes.NodeSet()
+
+ var end []byte
+ if len(keys) > 0 {
+ end = keys[len(keys)-1]
+ }
+ dbs[i], tries[i], notary, cont, err = trie.VerifyRangeProof(req.roots[i], req.origin[:], end, keys, slots[i], proofdb)
+ if err != nil {
+ logger.Warn("Storage range failed proof", "err", err)
+ return err
+ }
+ }
+ }
+ // Partial tries reconstructed, send them to the scheduler for storage filling
+ bounds := make(map[common.Hash]struct{})
+
+ if notary != nil { // if all contract storages are delivered in full, no notary will be created
+ it := notary.Accessed().NewIterator(nil, nil)
+ for it.Next() {
+ bounds[common.BytesToHash(it.Key())] = struct{}{}
+ }
+ it.Release()
+ }
+ response := &storageResponse{
+ mainTask: req.mainTask,
+ subTask: req.subTask,
+ accounts: req.accounts,
+ roots: req.roots,
+ hashes: hashes,
+ slots: slots,
+ nodes: dbs,
+ tries: tries,
+ bounds: bounds,
+ overflow: light.NewNodeSet(),
+ cont: cont,
+ }
+ select {
+ case s.storageResps <- response:
+ case <-req.cancel:
+ case <-req.stale:
+ }
+ return nil
+}
+
+// OnTrieNodes is a callback method to invoke when a batch of trie nodes
+// is received from a remote peer.
+func (s *Syncer) OnTrieNodes(peer *Peer, id uint64, trienodes [][]byte) error {
+ var size common.StorageSize
+ for _, node := range trienodes {
+ size += common.StorageSize(len(node))
+ }
+ logger := peer.logger.New("reqid", id)
+ logger.Trace("Delivering set of healing trienodes", "trienodes", len(trienodes), "bytes", size)
+
+ // Whether or not the response is valid, we can mark the peer as idle and
+ // notify the scheduler to assign a new task. If the response is invalid,
+ // we'll drop the peer in a bit.
+ s.lock.Lock()
+ if _, ok := s.peers[peer.id]; ok {
+ s.trienodeHealIdlers[peer.id] = struct{}{}
+ }
+ select {
+ case s.update <- struct{}{}:
+ default:
+ }
+ // Ensure the response is for a valid request
+ req, ok := s.trienodeHealReqs[id]
+ if !ok {
+ // Request stale, perhaps the peer timed out but came through in the end
+ logger.Warn("Unexpected trienode heal packet")
+ s.lock.Unlock()
+ return nil
+ }
+ delete(s.trienodeHealReqs, id)
+
+ // Clean up the request timeout timer, we'll see how to proceed further based
+ // on the actual delivered content
+ req.timeout.Stop()
+
+ // Response is valid, but check if peer is signalling that it does not have
+ // the requested data. For trienode queries that means the peer is not
+ // yet synced.
+ if len(trienodes) == 0 {
+ logger.Debug("Peer rejected trienode heal request")
+ s.statelessPeers[peer.id] = struct{}{}
+ s.lock.Unlock()
+ return nil
+ }
+ s.lock.Unlock()
+
+ // Cross reference the requested trienodes with the response to find gaps
+ // that the serving node is missing
+ hasher := sha3.NewLegacyKeccak256()
+
+ nodes := make([][]byte, len(req.hashes))
+ for i, j := 0, 0; i < len(trienodes); i++ {
+ // Find the next hash that we've been served, leaving misses with nils
+ hasher.Reset()
+ hasher.Write(trienodes[i])
+ hash := hasher.Sum(nil)
+
+ for j < len(req.hashes) && !bytes.Equal(hash, req.hashes[j][:]) {
+ j++
+ }
+ if j < len(req.hashes) {
+ nodes[j] = trienodes[i]
+ j++
+ continue
+ }
+ // We've either run out of hashes, or got unrequested data
+ logger.Warn("Unexpected healing trienodes", "count", len(trienodes)-i)
+ return errors.New("unexpected healing trienode")
+ }
+ // Response validated, send it to the scheduler for filling
+ response := &trienodeHealResponse{
+ task: req.task,
+ hashes: req.hashes,
+ paths: req.paths,
+ nodes: nodes,
+ }
+ select {
+ case s.trienodeHealResps <- response:
+ case <-req.cancel:
+ case <-req.stale:
+ }
+ return nil
+}
+
+// onHealByteCodes is a callback method to invoke when a batch of contract
+// bytecodes is received from a remote peer in the healing phase.
+func (s *Syncer) onHealByteCodes(peer *Peer, id uint64, bytecodes [][]byte) error {
+ var size common.StorageSize
+ for _, code := range bytecodes {
+ size += common.StorageSize(len(code))
+ }
+ logger := peer.logger.New("reqid", id)
+ logger.Trace("Delivering set of healing bytecodes", "bytecodes", len(bytecodes), "bytes", size)
+
+ // Whether or not the response is valid, we can mark the peer as idle and
+ // notify the scheduler to assign a new task. If the response is invalid,
+ // we'll drop the peer in a bit.
+ s.lock.Lock()
+ if _, ok := s.peers[peer.id]; ok {
+ s.bytecodeHealIdlers[peer.id] = struct{}{}
+ }
+ select {
+ case s.update <- struct{}{}:
+ default:
+ }
+ // Ensure the response is for a valid request
+ req, ok := s.bytecodeHealReqs[id]
+ if !ok {
+ // Request stale, perhaps the peer timed out but came through in the end
+ logger.Warn("Unexpected bytecode heal packet")
+ s.lock.Unlock()
+ return nil
+ }
+ delete(s.bytecodeHealReqs, id)
+
+ // Clean up the request timeout timer, we'll see how to proceed further based
+ // on the actual delivered content
+ req.timeout.Stop()
+
+ // Response is valid, but check if peer is signalling that it does not have
+ // the requested data. For bytecode queries that means the peer is not
+ // yet synced.
+ if len(bytecodes) == 0 {
+ logger.Debug("Peer rejected bytecode heal request")
+ s.statelessPeers[peer.id] = struct{}{}
+ s.lock.Unlock()
+ return nil
+ }
+ s.lock.Unlock()
+
+ // Cross reference the requested bytecodes with the response to find gaps
+ // that the serving node is missing
+ hasher := sha3.NewLegacyKeccak256()
+
+ codes := make([][]byte, len(req.hashes))
+ for i, j := 0, 0; i < len(bytecodes); i++ {
+ // Find the next hash that we've been served, leaving misses with nils
+ hasher.Reset()
+ hasher.Write(bytecodes[i])
+ hash := hasher.Sum(nil)
+
+ for j < len(req.hashes) && !bytes.Equal(hash, req.hashes[j][:]) {
+ j++
+ }
+ if j < len(req.hashes) {
+ codes[j] = bytecodes[i]
+ j++
+ continue
+ }
+ // We've either run out of hashes, or got unrequested data
+ logger.Warn("Unexpected healing bytecodes", "count", len(bytecodes)-i)
+ return errors.New("unexpected healing bytecode")
+ }
+ // Response validated, send it to the scheduler for filling
+ response := &bytecodeHealResponse{
+ task: req.task,
+ hashes: req.hashes,
+ codes: codes,
+ }
+ select {
+ case s.bytecodeHealResps <- response:
+ case <-req.cancel:
+ case <-req.stale:
+ }
+ return nil
+}
+
+// hashSpace is the total size of the 256 bit hash space for accounts.
+var hashSpace = new(big.Int).Exp(common.Big2, common.Big256, nil)
+
+// report calculates various status reports and provides them to the user.
+func (s *Syncer) report(force bool) {
+ if len(s.tasks) > 0 {
+ s.reportSyncProgress(force)
+ return
+ }
+ s.reportHealProgress(force)
+}
+
+// reportSyncProgress calculates various status reports and provides them to the user.
+func (s *Syncer) reportSyncProgress(force bool) {
+ // Don't report all the events, just occasionally
+ if !force && time.Since(s.logTime) < 3*time.Second {
+ return
+ }
+ // Don't report anything until we have meaningful progress
+ synced := s.accountBytes + s.bytecodeBytes + s.storageBytes
+ if synced == 0 {
+ return
+ }
+ accountGaps := new(big.Int)
+ for _, task := range s.tasks {
+ accountGaps.Add(accountGaps, new(big.Int).Sub(task.Last.Big(), task.Next.Big()))
+ }
+ accountFills := new(big.Int).Sub(hashSpace, accountGaps)
+ if accountFills.BitLen() == 0 {
+ return
+ }
+ s.logTime = time.Now()
+ estBytes := float64(new(big.Int).Div(
+ new(big.Int).Mul(new(big.Int).SetUint64(uint64(synced)), hashSpace),
+ accountFills,
+ ).Uint64())
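+ // estBytes extrapolates the bytes downloaded so far over the whole account
+ // hash space, using the filled fraction as the scaling factor.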
+
+ elapsed := time.Since(s.startTime)
+ estTime := elapsed / time.Duration(synced) * time.Duration(estBytes)
+
+ // Create a mega progress report
+ var (
+ progress = fmt.Sprintf("%.2f%%", float64(synced)*100/estBytes)
+ accounts = fmt.Sprintf("%d@%v", s.accountSynced, s.accountBytes.TerminalString())
+ storage = fmt.Sprintf("%d@%v", s.storageSynced, s.storageBytes.TerminalString())
+ bytecode = fmt.Sprintf("%d@%v", s.bytecodeSynced, s.bytecodeBytes.TerminalString())
+ )
+ log.Info("State sync in progress", "synced", progress, "state", synced,
+ "accounts", accounts, "slots", storage, "codes", bytecode, "eta", common.PrettyDuration(estTime-elapsed))
+}
+
+// reportHealProgress calculates various status reports and provides them to the user.
+func (s *Syncer) reportHealProgress(force bool) {
+ // Don't report all the events, just occasionally
+ if !force && time.Since(s.logTime) < 3*time.Second {
+ return
+ }
+ s.logTime = time.Now()
+
+ // Create a mega progress report
+ var (
+ trienode = fmt.Sprintf("%d@%v", s.trienodeHealSynced, s.trienodeHealBytes.TerminalString())
+ bytecode = fmt.Sprintf("%d@%v", s.bytecodeHealSynced, s.bytecodeHealBytes.TerminalString())
+ )
+ log.Info("State heal in progress", "nodes", trienode, "codes", bytecode,
+ "pending", s.healer.scheduler.Pending())
+}
diff --git a/eth/sync.go b/eth/sync.go
index 26badd1e2..03a516524 100644
--- a/eth/sync.go
+++ b/eth/sync.go
@@ -26,6 +26,7 @@ import (
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth/downloader"
+ "github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode"
)
@@ -40,12 +41,12 @@ const (
)
type txsync struct {
- p *peer
+ p *eth.Peer
txs []*types.Transaction
}
// syncTransactions starts sending all currently pending transactions to the given peer.
-func (pm *ProtocolManager) syncTransactions(p *peer) {
+func (h *handler) syncTransactions(p *eth.Peer) {
// Assemble the set of transaction to broadcast or announce to the remote
// peer. Fun fact, this is quite an expensive operation as it needs to sort
// the transactions if the sorting is not cached yet. However, with a random
@@ -53,7 +54,7 @@ func (pm *ProtocolManager) syncTransactions(p *peer) {
//
// TODO(karalabe): Figure out if we could get away with random order somehow
var txs types.Transactions
- pending, _ := pm.txpool.Pending()
+ pending, _ := h.txpool.Pending()
for _, batch := range pending {
txs = append(txs, batch...)
}
@@ -63,7 +64,7 @@ func (pm *ProtocolManager) syncTransactions(p *peer) {
// The eth/65 protocol introduces proper transaction announcements, so instead
// of dripping transactions across multiple peers, just send the entire list as
// an announcement and let the remote side decide what they need (likely nothing).
- if p.version >= eth65 {
+ if p.Version() >= eth.ETH65 {
hashes := make([]common.Hash, len(txs))
for i, tx := range txs {
hashes[i] = tx.Hash()
@@ -73,8 +74,8 @@ func (pm *ProtocolManager) syncTransactions(p *peer) {
}
// Out of luck, peer is running legacy protocols, drop the txs over
select {
- case pm.txsyncCh <- &txsync{p: p, txs: txs}:
- case <-pm.quitSync:
+ case h.txsyncCh <- &txsync{p: p, txs: txs}:
+ case <-h.quitSync:
}
}
@@ -82,8 +83,8 @@ func (pm *ProtocolManager) syncTransactions(p *peer) {
// connection. When a new peer appears, we relay all currently pending
// transactions. In order to minimise egress bandwidth usage, we send
// the transactions in small packs to one peer at a time.
-func (pm *ProtocolManager) txsyncLoop64() {
- defer pm.wg.Done()
+func (h *handler) txsyncLoop64() {
+ defer h.wg.Done()
var (
pending = make(map[enode.ID]*txsync)
@@ -94,7 +95,7 @@ func (pm *ProtocolManager) txsyncLoop64() {
// send starts a sending a pack of transactions from the sync.
send := func(s *txsync) {
- if s.p.version >= eth65 {
+ if s.p.Version() >= eth.ETH65 {
panic("initial transaction syncer running on eth/65+")
}
// Fill pack with transactions up to the target size.
@@ -108,14 +109,13 @@ func (pm *ProtocolManager) txsyncLoop64() {
// Remove the transactions that will be sent.
s.txs = s.txs[:copy(s.txs, s.txs[len(pack.txs):])]
if len(s.txs) == 0 {
- delete(pending, s.p.ID())
+ delete(pending, s.p.Peer.ID())
}
// Send the pack in the background.
s.p.Log().Trace("Sending batch of transactions", "count", len(pack.txs), "bytes", size)
sending = true
- go func() { done <- pack.p.SendTransactions64(pack.txs) }()
+ go func() { done <- pack.p.SendTransactions(pack.txs) }()
}
-
// pick chooses the next pending sync.
pick := func() *txsync {
if len(pending) == 0 {
@@ -132,8 +132,8 @@ func (pm *ProtocolManager) txsyncLoop64() {
for {
select {
- case s := <-pm.txsyncCh:
- pending[s.p.ID()] = s
+ case s := <-h.txsyncCh:
+ pending[s.p.Peer.ID()] = s
if !sending {
send(s)
}
@@ -142,13 +142,13 @@ func (pm *ProtocolManager) txsyncLoop64() {
// Stop tracking peers that cause send failures.
if err != nil {
pack.p.Log().Debug("Transaction send failed", "err", err)
- delete(pending, pack.p.ID())
+ delete(pending, pack.p.Peer.ID())
}
// Schedule the next send.
if s := pick(); s != nil {
send(s)
}
- case <-pm.quitSync:
+ case <-h.quitSync:
return
}
}
@@ -156,7 +156,7 @@ func (pm *ProtocolManager) txsyncLoop64() {
// chainSyncer coordinates blockchain sync components.
type chainSyncer struct {
- pm *ProtocolManager
+ handler *handler
force *time.Timer
forced bool // true when force timer fired
peerEventCh chan struct{}
@@ -166,15 +166,15 @@ type chainSyncer struct {
// chainSyncOp is a scheduled sync operation.
type chainSyncOp struct {
mode downloader.SyncMode
- peer *peer
+ peer *eth.Peer
td *big.Int
head common.Hash
}
// newChainSyncer creates a chainSyncer.
-func newChainSyncer(pm *ProtocolManager) *chainSyncer {
+func newChainSyncer(handler *handler) *chainSyncer {
return &chainSyncer{
- pm: pm,
+ handler: handler,
peerEventCh: make(chan struct{}),
}
}
@@ -182,23 +182,24 @@ func newChainSyncer(pm *ProtocolManager) *chainSyncer {
// handlePeerEvent notifies the syncer about a change in the peer set.
// This is called for new peers and every time a peer announces a new
// chain head.
-func (cs *chainSyncer) handlePeerEvent(p *peer) bool {
+func (cs *chainSyncer) handlePeerEvent(peer *eth.Peer) bool {
select {
case cs.peerEventCh <- struct{}{}:
return true
- case <-cs.pm.quitSync:
+ case <-cs.handler.quitSync:
return false
}
}
// loop runs in its own goroutine and launches the sync when necessary.
func (cs *chainSyncer) loop() {
- defer cs.pm.wg.Done()
+ defer cs.handler.wg.Done()
- cs.pm.blockFetcher.Start()
- cs.pm.txFetcher.Start()
- defer cs.pm.blockFetcher.Stop()
- defer cs.pm.txFetcher.Stop()
+ cs.handler.blockFetcher.Start()
+ cs.handler.txFetcher.Start()
+ defer cs.handler.blockFetcher.Stop()
+ defer cs.handler.txFetcher.Stop()
+ defer cs.handler.downloader.Terminate()
// The force timer lowers the peer count threshold down to one when it fires.
// This ensures we'll always start sync even if there aren't enough peers.
@@ -209,7 +210,6 @@ func (cs *chainSyncer) loop() {
if op := cs.nextSyncOp(); op != nil {
cs.startSync(op)
}
-
select {
case <-cs.peerEventCh:
// Peer information changed, recheck.
@@ -220,14 +220,13 @@ func (cs *chainSyncer) loop() {
case <-cs.force.C:
cs.forced = true
- case <-cs.pm.quitSync:
+ case <-cs.handler.quitSync:
// Disable all insertion on the blockchain. This needs to happen before
// terminating the downloader because the downloader waits for blockchain
// inserts, and these can take a long time to finish.
- cs.pm.blockchain.StopInsert()
- cs.pm.downloader.Terminate()
+ cs.handler.chain.StopInsert()
+ cs.handler.downloader.Terminate()
if cs.doneCh != nil {
- // Wait for the current sync to end.
<-cs.doneCh
}
return
@@ -245,19 +244,22 @@ func (cs *chainSyncer) nextSyncOp() *chainSyncOp {
minPeers := defaultMinSyncPeers
if cs.forced {
minPeers = 1
- } else if minPeers > cs.pm.maxPeers {
- minPeers = cs.pm.maxPeers
+ } else if minPeers > cs.handler.maxPeers {
+ minPeers = cs.handler.maxPeers
}
- if cs.pm.peers.Len() < minPeers {
+ if cs.handler.peers.Len() < minPeers {
return nil
}
-
- // We have enough peers, check TD.
- peer := cs.pm.peers.BestPeer()
+ // We have enough peers, check TD
+ peer := cs.handler.peers.ethPeerWithHighestTD()
if peer == nil {
return nil
}
mode, ourTD := cs.modeAndLocalHead()
+ if mode == downloader.FastSync && atomic.LoadUint32(&cs.handler.snapSync) == 1 {
+ // Fast sync via the snap protocol
+ mode = downloader.SnapSync
+ }
op := peerToSyncOp(mode, peer)
if op.td.Cmp(ourTD) <= 0 {
return nil // We're in sync.
@@ -265,42 +267,42 @@ func (cs *chainSyncer) nextSyncOp() *chainSyncOp {
return op
}
-func peerToSyncOp(mode downloader.SyncMode, p *peer) *chainSyncOp {
+func peerToSyncOp(mode downloader.SyncMode, p *eth.Peer) *chainSyncOp {
peerHead, peerTD := p.Head()
return &chainSyncOp{mode: mode, peer: p, td: peerTD, head: peerHead}
}
func (cs *chainSyncer) modeAndLocalHead() (downloader.SyncMode, *big.Int) {
// If we're in fast sync mode, return that directly
- if atomic.LoadUint32(&cs.pm.fastSync) == 1 {
- block := cs.pm.blockchain.CurrentFastBlock()
- td := cs.pm.blockchain.GetTdByHash(block.Hash())
+ if atomic.LoadUint32(&cs.handler.fastSync) == 1 {
+ block := cs.handler.chain.CurrentFastBlock()
+ td := cs.handler.chain.GetTdByHash(block.Hash())
return downloader.FastSync, td
}
// We are probably in full sync, but we might have rewound to before the
// fast sync pivot, check if we should reenable
- if pivot := rawdb.ReadLastPivotNumber(cs.pm.chaindb); pivot != nil {
- if head := cs.pm.blockchain.CurrentBlock(); head.NumberU64() < *pivot {
- block := cs.pm.blockchain.CurrentFastBlock()
- td := cs.pm.blockchain.GetTdByHash(block.Hash())
+ if pivot := rawdb.ReadLastPivotNumber(cs.handler.database); pivot != nil {
+ if head := cs.handler.chain.CurrentBlock(); head.NumberU64() < *pivot {
+ block := cs.handler.chain.CurrentFastBlock()
+ td := cs.handler.chain.GetTdByHash(block.Hash())
return downloader.FastSync, td
}
}
// Nope, we're really full syncing
- head := cs.pm.blockchain.CurrentHeader()
- td := cs.pm.blockchain.GetTd(head.Hash(), head.Number.Uint64())
+ head := cs.handler.chain.CurrentHeader()
+ td := cs.handler.chain.GetTd(head.Hash(), head.Number.Uint64())
return downloader.FullSync, td
}
// startSync launches doSync in a new goroutine.
func (cs *chainSyncer) startSync(op *chainSyncOp) {
cs.doneCh = make(chan error, 1)
- go func() { cs.doneCh <- cs.pm.doSync(op) }()
+ go func() { cs.doneCh <- cs.handler.doSync(op) }()
}
// doSync synchronizes the local blockchain with a remote peer.
-func (pm *ProtocolManager) doSync(op *chainSyncOp) error {
- if op.mode == downloader.FastSync {
+func (h *handler) doSync(op *chainSyncOp) error {
+ if op.mode == downloader.FastSync || op.mode == downloader.SnapSync {
// Before launch the fast sync, we have to ensure user uses the same
// txlookup limit.
// The main concern here is: during the fast sync Geth won't index the
@@ -310,35 +312,33 @@ func (pm *ProtocolManager) doSync(op *chainSyncOp) error {
// has been indexed. So from a user-experience standpoint, it's non-optimal
// that the user can't change the limit during fast sync. If changed, Geth
// will just blindly use the original one.
- limit := pm.blockchain.TxLookupLimit()
- if stored := rawdb.ReadFastTxLookupLimit(pm.chaindb); stored == nil {
- rawdb.WriteFastTxLookupLimit(pm.chaindb, limit)
+ limit := h.chain.TxLookupLimit()
+ if stored := rawdb.ReadFastTxLookupLimit(h.database); stored == nil {
+ rawdb.WriteFastTxLookupLimit(h.database, limit)
} else if *stored != limit {
- pm.blockchain.SetTxLookupLimit(*stored)
+ h.chain.SetTxLookupLimit(*stored)
log.Warn("Update txLookup limit", "provided", limit, "updated", *stored)
}
}
// Run the sync cycle, and disable fast sync if we're past the pivot block
- err := pm.downloader.Synchronise(op.peer.id, op.head, op.td, op.mode)
+ err := h.downloader.Synchronise(op.peer.ID(), op.head, op.td, op.mode)
if err != nil {
return err
}
- if atomic.LoadUint32(&pm.fastSync) == 1 {
+ if atomic.LoadUint32(&h.fastSync) == 1 {
log.Info("Fast sync complete, auto disabling")
- atomic.StoreUint32(&pm.fastSync, 0)
+ atomic.StoreUint32(&h.fastSync, 0)
}
-
// If we've successfully finished a sync cycle and passed any required checkpoint,
// enable accepting transactions from the network.
- head := pm.blockchain.CurrentBlock()
- if head.NumberU64() >= pm.checkpointNumber {
+ head := h.chain.CurrentBlock()
+ if head.NumberU64() >= h.checkpointNumber {
// Checkpoint passed, sanity check the timestamp to have a fallback mechanism
// for non-checkpointed (number = 0) private networks.
if head.Time() >= uint64(time.Now().AddDate(0, -1, 0).Unix()) {
- atomic.StoreUint32(&pm.acceptTxs, 1)
+ atomic.StoreUint32(&h.acceptTxs, 1)
}
}
-
if head.NumberU64() > 0 {
// We've completed a sync cycle, notify all peers of new state. This path is
// essential in star-topology networks where a gateway node needs to notify
@@ -346,8 +346,7 @@ func (pm *ProtocolManager) doSync(op *chainSyncOp) error {
// scenario will most often crop up in private and hackathon networks with
// degenerate connectivity, but it should be healthy for the mainnet too to
// more reliably update peers or the local TD state.
- pm.BroadcastBlock(head, false)
+ h.BroadcastBlock(head, false)
}
-
return nil
}
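The net effect of the handler rename plus the snap upgrade is easiest to see in isolation. Below is a minimal standalone sketch; the selectSyncMode helper is hypothetical and does not exist in the PR, it only mirrors the decision order that nextSyncOp and modeAndLocalHead implement together: full sync is the default, the fastSync flag overrides it, and the snapSync flag only upgrades an already-chosen fast sync.

package main

import (
	"fmt"
	"sync/atomic"

	"github.com/ethereum/go-ethereum/eth/downloader"
)

// selectSyncMode is a hypothetical helper mirroring the logic above: fast
// sync wins over full sync, and the snap flag only upgrades an already
// chosen fast sync to run over the snap protocol.
func selectSyncMode(fastSync, snapSync *uint32) downloader.SyncMode {
	mode := downloader.FullSync
	if atomic.LoadUint32(fastSync) == 1 {
		mode = downloader.FastSync
	}
	if mode == downloader.FastSync && atomic.LoadUint32(snapSync) == 1 {
		mode = downloader.SnapSync // fast sync, but via the snap protocol
	}
	return mode
}

func main() {
	fast, snap := uint32(1), uint32(1)
	fmt.Println(selectSyncMode(&fast, &snap)) // "snap"
}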
diff --git a/eth/sync_test.go b/eth/sync_test.go
index ac1e5fad1..473e19518 100644
--- a/eth/sync_test.go
+++ b/eth/sync_test.go
@@ -22,43 +22,59 @@ import (
"time"
"github.com/ethereum/go-ethereum/eth/downloader"
+ "github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
)
-func TestFastSyncDisabling63(t *testing.T) { testFastSyncDisabling(t, 63) }
+// Tests that fast sync is disabled after a successful sync cycle.
func TestFastSyncDisabling64(t *testing.T) { testFastSyncDisabling(t, 64) }
func TestFastSyncDisabling65(t *testing.T) { testFastSyncDisabling(t, 65) }
// Tests that fast sync gets disabled as soon as a real block is successfully
// imported into the blockchain.
-func testFastSyncDisabling(t *testing.T, protocol int) {
+func testFastSyncDisabling(t *testing.T, protocol uint) {
t.Parallel()
- // Create a pristine protocol manager, check that fast sync is left enabled
- pmEmpty, _ := newTestProtocolManagerMust(t, downloader.FastSync, 0, nil, nil)
- if atomic.LoadUint32(&pmEmpty.fastSync) == 0 {
+ // Create an empty handler and ensure it's in fast sync mode
+ empty := newTestHandler()
+ if atomic.LoadUint32(&empty.handler.fastSync) == 0 {
t.Fatalf("fast sync disabled on pristine blockchain")
}
- // Create a full protocol manager, check that fast sync gets disabled
- pmFull, _ := newTestProtocolManagerMust(t, downloader.FastSync, 1024, nil, nil)
- if atomic.LoadUint32(&pmFull.fastSync) == 1 {
+ defer empty.close()
+
+ // Create a full handler and ensure fast sync ends up disabled
+ full := newTestHandlerWithBlocks(1024)
+ if atomic.LoadUint32(&full.handler.fastSync) == 1 {
t.Fatalf("fast sync not disabled on non-empty blockchain")
}
+ defer full.close()
- // Sync up the two peers
- io1, io2 := p2p.MsgPipe()
- go pmFull.handle(pmFull.newPeer(protocol, p2p.NewPeer(enode.ID{}, "empty", nil), io2, pmFull.txpool.Get))
- go pmEmpty.handle(pmEmpty.newPeer(protocol, p2p.NewPeer(enode.ID{}, "full", nil), io1, pmEmpty.txpool.Get))
+ // Sync up the two handlers
+ emptyPipe, fullPipe := p2p.MsgPipe()
+ defer emptyPipe.Close()
+ defer fullPipe.Close()
+ emptyPeer := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), emptyPipe, empty.txpool)
+ fullPeer := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), fullPipe, full.txpool)
+ defer emptyPeer.Close()
+ defer fullPeer.Close()
+
+ go empty.handler.runEthPeer(emptyPeer, func(peer *eth.Peer) error {
+ return eth.Handle((*ethHandler)(empty.handler), peer)
+ })
+ go full.handler.runEthPeer(fullPeer, func(peer *eth.Peer) error {
+ return eth.Handle((*ethHandler)(full.handler), peer)
+ })
+ // Wait a bit for the above handlers to start
time.Sleep(250 * time.Millisecond)
- op := peerToSyncOp(downloader.FastSync, pmEmpty.peers.BestPeer())
- if err := pmEmpty.doSync(op); err != nil {
- t.Fatal("sync failed:", err)
- }
// Check that fast sync was disabled
- if atomic.LoadUint32(&pmEmpty.fastSync) == 1 {
+ op := peerToSyncOp(downloader.FastSync, empty.handler.peers.ethPeerWithHighestTD())
+ if err := empty.handler.doSync(op); err != nil {
+ t.Fatal("sync failed:", err)
+ }
+ if atomic.LoadUint32(&empty.handler.fastSync) == 1 {
t.Fatalf("fast sync not disabled after successful synchronisation")
}
}
diff --git a/ethstats/ethstats.go b/ethstats/ethstats.go
index 1828ad70f..e0f4f95ff 100644
--- a/ethstats/ethstats.go
+++ b/ethstats/ethstats.go
@@ -36,8 +36,8 @@ import (
"github.com/ethereum/go-ethereum/consensus"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/eth/downloader"
+ ethproto "github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/les"
"github.com/ethereum/go-ethereum/log"
@@ -444,13 +444,15 @@ func (s *Service) login(conn *connWrapper) error {
// Construct and send the login authentication
infos := s.server.NodeInfo()
- var network, protocol string
+ var protocols []string
+ for _, proto := range s.server.Protocols {
+ protocols = append(protocols, fmt.Sprintf("%s/%d", proto.Name, proto.Version))
+ }
+ var network string
if info := infos.Protocols["eth"]; info != nil {
- network = fmt.Sprintf("%d", info.(*eth.NodeInfo).Network)
- protocol = fmt.Sprintf("eth/%d", eth.ProtocolVersions[0])
+ network = fmt.Sprintf("%d", info.(*ethproto.NodeInfo).Network)
} else {
network = fmt.Sprintf("%d", infos.Protocols["les"].(*les.NodeInfo).Network)
- protocol = fmt.Sprintf("les/%d", les.ClientProtocolVersions[0])
}
auth := &authMsg{
ID: s.node,
@@ -459,7 +461,7 @@ func (s *Service) login(conn *connWrapper) error {
Node: infos.Name,
Port: infos.Ports.Listener,
Network: network,
- Protocol: protocol,
+ Protocol: strings.Join(protocols, ", "),
API: "No",
Os: runtime.GOOS,
OsVer: runtime.GOARCH,
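With the hard-coded eth.ProtocolVersions reference gone, the reported protocol string is now derived from whatever the p2p server actually advertises. A minimal standalone sketch of the formatting; the trimmed proto struct stands in for p2p.Protocol, and the advertisement set is made up:

package main

import (
	"fmt"
	"strings"
)

// proto is a stand-in for the two p2p.Protocol fields the login loop uses.
type proto struct {
	Name    string
	Version uint
}

func main() {
	// A made-up advertisement set; a snap-enabled node would expose
	// something along these lines.
	advertised := []proto{{"eth", 64}, {"eth", 65}, {"snap", 1}}

	var protocols []string
	for _, p := range advertised {
		protocols = append(protocols, fmt.Sprintf("%s/%d", p.Name, p.Version))
	}
	fmt.Println(strings.Join(protocols, ", ")) // eth/64, eth/65, snap/1
}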
diff --git a/graphql/graphql.go b/graphql/graphql.go
index 16a74b403..22cfcf663 100644
--- a/graphql/graphql.go
+++ b/graphql/graphql.go
@@ -1040,10 +1040,6 @@ func (r *Resolver) GasPrice(ctx context.Context) (hexutil.Big, error) {
return hexutil.Big(*price), err
}
-func (r *Resolver) ProtocolVersion(ctx context.Context) (int32, error) {
- return int32(r.backend.ProtocolVersion()), nil
-}
-
func (r *Resolver) ChainID(ctx context.Context) (hexutil.Big, error) {
return hexutil.Big(*r.backend.ChainConfig().ChainID), nil
}
diff --git a/graphql/schema.go b/graphql/schema.go
index d7b253f22..1fdc37004 100644
--- a/graphql/schema.go
+++ b/graphql/schema.go
@@ -310,8 +310,6 @@ const schema string = `
# GasPrice returns the node's estimate of a gas price sufficient to
# ensure a transaction is mined in a timely fashion.
gasPrice: BigInt!
- # ProtocolVersion returns the current wire protocol version number.
- protocolVersion: Int!
# Syncing returns information on the current synchronisation state.
syncing: SyncState
# ChainID returns the current chain ID for transaction replay protection.
diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go
index 030bdb37a..9ff1781e4 100644
--- a/internal/ethapi/api.go
+++ b/internal/ethapi/api.go
@@ -64,11 +64,6 @@ func (s *PublicEthereumAPI) GasPrice(ctx context.Context) (*hexutil.Big, error)
return (*hexutil.Big)(price), err
}
-// ProtocolVersion returns the current Ethereum protocol version this node supports
-func (s *PublicEthereumAPI) ProtocolVersion() hexutil.Uint {
- return hexutil.Uint(s.b.ProtocolVersion())
-}
-
// Syncing returns false in case the node is currently not syncing with the network. It can be up to date or has not
// yet received the latest block headers from its peers. In case it is synchronizing:
// - startingBlock: block number this node started to synchronise from
@@ -1905,13 +1900,12 @@ func (api *PrivateDebugAPI) SetHead(number hexutil.Uint64) {
// PublicNetAPI offers network related RPC methods
type PublicNetAPI struct {
- net *p2p.Server
- networkVersion uint64
+ net *p2p.Server
}
// NewPublicNetAPI creates a new net API instance.
-func NewPublicNetAPI(net *p2p.Server, networkVersion uint64) *PublicNetAPI {
- return &PublicNetAPI{net, networkVersion}
+func NewPublicNetAPI(net *p2p.Server) *PublicNetAPI {
+ return &PublicNetAPI{net}
}
// Listening returns an indication if the node is listening for network connections.
@@ -1924,11 +1918,6 @@ func (s *PublicNetAPI) PeerCount() hexutil.Uint {
return hexutil.Uint(s.net.PeerCount())
}
-// Version returns the current ethereum protocol version.
-func (s *PublicNetAPI) Version() string {
- return fmt.Sprintf("%d", s.networkVersion)
-}
-
// checkTxFee is an internal function used to check whether the fee of
// the given transaction is _reasonable_(under the cap).
func checkTxFee(gasPrice *big.Int, gas uint64, cap float64) error {
diff --git a/internal/ethapi/backend.go b/internal/ethapi/backend.go
index 10e716bf2..f0a4c0493 100644
--- a/internal/ethapi/backend.go
+++ b/internal/ethapi/backend.go
@@ -41,7 +41,6 @@ import (
type Backend interface {
// General Ethereum API
Downloader() *downloader.Downloader
- ProtocolVersion() int
SuggestPrice(ctx context.Context) (*big.Int, error)
ChainDb() ethdb.Database
AccountManager() *accounts.Manager
diff --git a/les/client.go b/les/client.go
index 47997a098..198255dc5 100644
--- a/les/client.go
+++ b/les/client.go
@@ -171,7 +171,7 @@ func New(stack *node.Node, config *eth.Config) (*LightEthereum, error) {
leth.blockchain.DisableCheckFreq()
}
- leth.netRPCService = ethapi.NewPublicNetAPI(leth.p2pServer, leth.config.NetworkId)
+ leth.netRPCService = ethapi.NewPublicNetAPI(leth.p2pServer)
// Register the backend on the node
stack.RegisterAPIs(leth.APIs())
diff --git a/les/enr_entry.go b/les/enr_entry.go
index 8f0169bee..a357f689d 100644
--- a/les/enr_entry.go
+++ b/les/enr_entry.go
@@ -35,9 +35,9 @@ func (e lesEntry) ENRKey() string {
// setupDiscovery creates the node discovery source for the eth protocol.
func (eth *LightEthereum) setupDiscovery() (enode.Iterator, error) {
- if len(eth.config.DiscoveryURLs) == 0 {
+ if len(eth.config.EthDiscoveryURLs) == 0 {
return nil, nil
}
client := dnsdisc.NewClient(dnsdisc.Config{})
- return client.NewIterator(eth.config.DiscoveryURLs...)
+ return client.NewIterator(eth.config.EthDiscoveryURLs...)
}
diff --git a/les/handler_test.go b/les/handler_test.go
index 1612caf42..04277f661 100644
--- a/les/handler_test.go
+++ b/les/handler_test.go
@@ -51,7 +51,7 @@ func TestGetBlockHeadersLes2(t *testing.T) { testGetBlockHeaders(t, 2) }
func TestGetBlockHeadersLes3(t *testing.T) { testGetBlockHeaders(t, 3) }
func testGetBlockHeaders(t *testing.T, protocol int) {
- server, tearDown := newServerEnv(t, downloader.MaxHashFetch+15, protocol, nil, false, true, 0)
+ server, tearDown := newServerEnv(t, downloader.MaxHeaderFetch+15, protocol, nil, false, true, 0)
defer tearDown()
bc := server.handler.blockchain
diff --git a/les/peer.go b/les/peer.go
index 31ee8f7f0..6004af03f 100644
--- a/les/peer.go
+++ b/les/peer.go
@@ -31,7 +31,6 @@ import (
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/les/flowcontrol"
lpc "github.com/ethereum/go-ethereum/les/lespay/client"
lps "github.com/ethereum/go-ethereum/les/lespay/server"
@@ -162,9 +161,17 @@ func (p *peerCommons) String() string {
return fmt.Sprintf("Peer %s [%s]", p.id, fmt.Sprintf("les/%d", p.version))
}
+// PeerInfo represents a short summary of the `eth` sub-protocol metadata known
+// about a connected peer.
+type PeerInfo struct {
+ Version int `json:"version"` // Ethereum protocol version negotiated
+ Difficulty *big.Int `json:"difficulty"` // Total difficulty of the peer's blockchain
+ Head string `json:"head"` // SHA3 hash of the peer's best owned block
+}
+
// Info gathers and returns a collection of metadata known about a peer.
-func (p *peerCommons) Info() *eth.PeerInfo {
- return &eth.PeerInfo{
+func (p *peerCommons) Info() *PeerInfo {
+ return &PeerInfo{
Version: p.version,
Difficulty: p.Td(),
Head: fmt.Sprintf("%x", p.Head()),
diff --git a/les/server_handler.go b/les/server_handler.go
index f965e3fc6..2316c9c5a 100644
--- a/les/server_handler.go
+++ b/les/server_handler.go
@@ -47,7 +47,7 @@ import (
const (
softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data.
estHeaderRlpSize = 500 // Approximate size of an RLP encoded block header
- ethVersion = 63 // equivalent eth version for the downloader
+ ethVersion = 64 // equivalent eth version for the downloader
MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request
MaxBodyFetch = 32 // Amount of block bodies to be fetched per retrieval request
diff --git a/tests/block_test_util.go b/tests/block_test_util.go
index be9cdb70c..c043f0b3e 100644
--- a/tests/block_test_util.go
+++ b/tests/block_test_util.go
@@ -147,7 +147,7 @@ func (t *BlockTest) Run(snapshotter bool) error {
}
// Cross-check the snapshot-to-hash against the trie hash
if snapshotter {
- if err := snapshot.VerifyState(chain.Snapshot(), chain.CurrentBlock().Root()); err != nil {
+ if err := snapshot.VerifyState(chain.Snapshots(), chain.CurrentBlock().Root()); err != nil {
return err
}
}
diff --git a/tests/fuzzers/rangeproof/corpus/1c14030f26872e57bf1481084f151d71eed8161c-1 b/tests/fuzzers/rangeproof/corpus/1c14030f26872e57bf1481084f151d71eed8161c-1
new file mode 100644
index 000000000..31c08bafa
Binary files /dev/null and b/tests/fuzzers/rangeproof/corpus/1c14030f26872e57bf1481084f151d71eed8161c-1 differ
diff --git a/tests/fuzzers/rangeproof/corpus/27e54254422543060a13ea8a4bc913d768e4adb6-2 b/tests/fuzzers/rangeproof/corpus/27e54254422543060a13ea8a4bc913d768e4adb6-2
new file mode 100644
index 000000000..7bce13ef8
Binary files /dev/null and b/tests/fuzzers/rangeproof/corpus/27e54254422543060a13ea8a4bc913d768e4adb6-2 differ
diff --git a/tests/fuzzers/rangeproof/corpus/6bfc2cbe2d7a43361e240118439785445a0fdfb7-5 b/tests/fuzzers/rangeproof/corpus/6bfc2cbe2d7a43361e240118439785445a0fdfb7-5
new file mode 100644
index 000000000..613e76a02
Binary files /dev/null and b/tests/fuzzers/rangeproof/corpus/6bfc2cbe2d7a43361e240118439785445a0fdfb7-5 differ
diff --git a/tests/fuzzers/rangeproof/corpus/a67e63bc0c0004bd009944a6061297cb7d4ac238-1 b/tests/fuzzers/rangeproof/corpus/a67e63bc0c0004bd009944a6061297cb7d4ac238-1
new file mode 100644
index 000000000..805ad8df7
Binary files /dev/null and b/tests/fuzzers/rangeproof/corpus/a67e63bc0c0004bd009944a6061297cb7d4ac238-1 differ
diff --git a/tests/fuzzers/rangeproof/corpus/ae892bbae0a843950bc8316496e595b1a194c009-4 b/tests/fuzzers/rangeproof/corpus/ae892bbae0a843950bc8316496e595b1a194c009-4
new file mode 100644
index 000000000..605acf81c
Binary files /dev/null and b/tests/fuzzers/rangeproof/corpus/ae892bbae0a843950bc8316496e595b1a194c009-4 differ
diff --git a/tests/fuzzers/rangeproof/corpus/ee05d0d813f6261b3dba16506f9ea03d9c5e993d-2 b/tests/fuzzers/rangeproof/corpus/ee05d0d813f6261b3dba16506f9ea03d9c5e993d-2
new file mode 100644
index 000000000..8f32dd775
Binary files /dev/null and b/tests/fuzzers/rangeproof/corpus/ee05d0d813f6261b3dba16506f9ea03d9c5e993d-2 differ
diff --git a/tests/fuzzers/rangeproof/corpus/f50a6d57a46d30184aa294af5b252ab9701af7c9-2 b/tests/fuzzers/rangeproof/corpus/f50a6d57a46d30184aa294af5b252ab9701af7c9-2
new file mode 100644
index 000000000..af96210f2
Binary files /dev/null and b/tests/fuzzers/rangeproof/corpus/f50a6d57a46d30184aa294af5b252ab9701af7c9-2 differ
diff --git a/tests/fuzzers/rangeproof/corpus/random.dat b/tests/fuzzers/rangeproof/corpus/random.dat
new file mode 100644
index 000000000..2c998ad81
Binary files /dev/null and b/tests/fuzzers/rangeproof/corpus/random.dat differ
diff --git a/tests/fuzzers/rangeproof/debug/main.go b/tests/fuzzers/rangeproof/debug/main.go
new file mode 100644
index 000000000..a81c69fea
--- /dev/null
+++ b/tests/fuzzers/rangeproof/debug/main.go
@@ -0,0 +1,41 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package main
+
+import (
+ "fmt"
+ "io/ioutil"
+ "os"
+
+ "github.com/ethereum/go-ethereum/tests/fuzzers/rangeproof"
+)
+
+func main() {
+ if len(os.Args) != 2 {
+ fmt.Fprintf(os.Stderr, "Usage: debug \n")
+ fmt.Fprintf(os.Stderr, "Example\n")
+ fmt.Fprintf(os.Stderr, " $ debug ../crashers/4bbef6857c733a87ecf6fd8b9e7238f65eb9862a\n")
+ os.Exit(1)
+ }
+ crasher := os.Args[1]
+ data, err := ioutil.ReadFile(crasher)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "error loading crasher %v: %v", crasher, err)
+ os.Exit(1)
+ }
+ rangeproof.Fuzz(data)
+}
diff --git a/tests/fuzzers/rangeproof/rangeproof-fuzzer.go b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go
new file mode 100644
index 000000000..b82a38072
--- /dev/null
+++ b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go
@@ -0,0 +1,218 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rangeproof
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "sort"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/ethdb/memorydb"
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+type kv struct {
+ k, v []byte
+ t bool
+}
+
+type entrySlice []*kv
+
+func (p entrySlice) Len() int { return len(p) }
+func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 }
+func (p entrySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
+
+type fuzzer struct {
+ input io.Reader
+ exhausted bool
+}
+
+func (f *fuzzer) randBytes(n int) []byte {
+ r := make([]byte, n)
+ if _, err := f.input.Read(r); err != nil {
+ f.exhausted = true
+ }
+ return r
+}
+
+func (f *fuzzer) readInt() uint64 {
+ var x uint64
+ if err := binary.Read(f.input, binary.LittleEndian, &x); err != nil {
+ f.exhausted = true
+ }
+ return x
+}
+
+func (f *fuzzer) randomTrie(n int) (*trie.Trie, map[string]*kv) {
+ trie := new(trie.Trie)
+ vals := make(map[string]*kv)
+ size := f.readInt()
+ // Fill it with some fluff
+ for i := byte(0); i < byte(size); i++ {
+ value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
+ value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false}
+ trie.Update(value.k, value.v)
+ trie.Update(value2.k, value2.v)
+ vals[string(value.k)] = value
+ vals[string(value2.k)] = value2
+ }
+ if f.exhausted {
+ return nil, nil
+ }
+ // And now fill it with some random entries
+ for i := 0; i < n; i++ {
+ k := f.randBytes(32)
+ v := f.randBytes(20)
+ value := &kv{k, v, false}
+ trie.Update(k, v)
+ vals[string(k)] = value
+ if f.exhausted {
+ return nil, nil
+ }
+ }
+ return trie, vals
+}
+
+func (f *fuzzer) fuzz() int {
+ maxSize := 200
+ tr, vals := f.randomTrie(1 + int(f.readInt())%maxSize)
+ if f.exhausted {
+ return 0 // input too short
+ }
+ var entries entrySlice
+ for _, kv := range vals {
+ entries = append(entries, kv)
+ }
+ if len(entries) <= 1 {
+ return 0
+ }
+ sort.Sort(entries)
+
+ var ok = 0
+ for {
+ start := int(f.readInt() % uint64(len(entries)))
+ end := 1 + int(f.readInt()%uint64(len(entries)-1))
+ testcase := int(f.readInt() % uint64(6))
+ index := int(f.readInt() & 0xFFFFFFFF)
+ index2 := int(f.readInt() & 0xFFFFFFFF)
+ if f.exhausted {
+ break
+ }
+ proof := memorydb.New()
+ if err := tr.Prove(entries[start].k, 0, proof); err != nil {
+ panic(fmt.Sprintf("Failed to prove the first node %v", err))
+ }
+ if err := tr.Prove(entries[end-1].k, 0, proof); err != nil {
+ panic(fmt.Sprintf("Failed to prove the last node %v", err))
+ }
+ var keys [][]byte
+ var vals [][]byte
+ for i := start; i < end; i++ {
+ keys = append(keys, entries[i].k)
+ vals = append(vals, entries[i].v)
+ }
+ if len(keys) == 0 {
+ return 0
+ }
+ var first, last = keys[0], keys[len(keys)-1]
+ testcase %= 6
+ switch testcase {
+ case 0:
+ // Modified key
+ keys[index%len(keys)] = f.randBytes(32) // In theory it can't be the same
+ case 1:
+ // Modified val
+ vals[index%len(vals)] = f.randBytes(20) // In theory it can't be the same
+ case 2:
+ // Gapped entry slice
+ index = index % len(keys)
+ keys = append(keys[:index], keys[index+1:]...)
+ vals = append(vals[:index], vals[index+1:]...)
+ case 3:
+ // Out of order
+ index1 := index % len(keys)
+ index2 := index2 % len(keys)
+ keys[index1], keys[index2] = keys[index2], keys[index1]
+ vals[index1], vals[index2] = vals[index2], vals[index1]
+ case 4:
+ // Set random key to nil, do nothing
+ keys[index%len(keys)] = nil
+ case 5:
+ // Set random value to nil, deletion
+ vals[index%len(vals)] = nil
+
+ // Other cases:
+ // Modify something in the proof db
+ // add stuff to proof db
+ // drop stuff from proof db
+
+ }
+ if f.exhausted {
+ break
+ }
+ ok = 1
+ nodes, subtrie, notary, hasMore, err := trie.VerifyRangeProof(tr.Hash(), first, last, keys, vals, proof)
+ if err != nil {
+ if nodes != nil {
+ panic("err != nil && nodes != nil")
+ }
+ if subtrie != nil {
+ panic("err != nil && subtrie != nil")
+ }
+ if notary != nil {
+ panic("err != nil && notary != nil")
+ }
+ if hasMore {
+ panic("err != nil && hasMore == true")
+ }
+ } else {
+ if nodes == nil {
+ panic("err == nil && nodes == nil")
+ }
+ if subtrie == nil {
+ panic("err == nil && subtrie == nil")
+ }
+ if notary == nil {
+ panic("err == nil && subtrie == nil")
+ }
+ }
+ }
+ return ok
+}
+
+// The function must return
+// 1 if the fuzzer should increase priority of the
+// given input during subsequent fuzzing (for example, the input is lexically
+// correct and was parsed successfully);
+// -1 if the input must not be added to corpus even if it gives new coverage; and
+// 0 otherwise; other values are reserved for future use.
+func Fuzz(input []byte) int {
+ if len(input) < 100 {
+ return 0
+ }
+ r := bytes.NewReader(input)
+ f := fuzzer{
+ input: r,
+ exhausted: false,
+ }
+ return f.fuzz()
+}
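The harness above is meant to be driven by go-fuzz, but it can also be smoke-tested directly. A hypothetical in-package test, not part of the PR, that checks the 100-byte gate and one random iteration:

package rangeproof

import (
	"math/rand"
	"testing"
)

// TestFuzzSmoke only checks that Fuzz honours the 100-byte minimum and
// survives arbitrary input; it is no substitute for coverage-guided
// fuzzing with go-fuzz.
func TestFuzzSmoke(t *testing.T) {
	if Fuzz([]byte("too short")) != 0 {
		t.Fatal("inputs under 100 bytes should be ignored with priority 0")
	}
	data := make([]byte, 4096)
	rand.Read(data)
	Fuzz(data) // must not panic, whatever the verdict
}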
diff --git a/trie/notary.go b/trie/notary.go
new file mode 100644
index 000000000..5a64727aa
--- /dev/null
+++ b/trie/notary.go
@@ -0,0 +1,57 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+import (
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/ethdb/memorydb"
+)
+
+// KeyValueNotary tracks which keys have been accessed through a key-value reader
+// with the scope of verifying whether certain proof datasets are maliciously bloated.
+type KeyValueNotary struct {
+ ethdb.KeyValueReader
+ reads map[string]struct{}
+}
+
+// NewKeyValueNotary wraps a key-value database with an access notary to track
+// which items have been accessed.
+func NewKeyValueNotary(db ethdb.KeyValueReader) *KeyValueNotary {
+ return &KeyValueNotary{
+ KeyValueReader: db,
+ reads: make(map[string]struct{}),
+ }
+}
+
+// Get retrieves an item from the underlying database, but also tracks it as an
+// accessed slot for bloat checks.
+func (k *KeyValueNotary) Get(key []byte) ([]byte, error) {
+ k.reads[string(key)] = struct{}{}
+ return k.KeyValueReader.Get(key)
+}
+
+// Accessed returns a snapshot of the original key-value store containing only the
+// data accessed through the notary.
+func (k *KeyValueNotary) Accessed() ethdb.KeyValueStore {
+ db := memorydb.New()
+ for keystr := range k.reads {
+ key := []byte(keystr)
+ val, _ := k.KeyValueReader.Get(key)
+ db.Put(key, val)
+ }
+ return db
+}
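Wiring the notary into client code is a one-line wrap. A minimal standalone sketch (the keys and values are made up) showing that only entries actually read survive into the Accessed() snapshot:

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/ethdb/memorydb"
	"github.com/ethereum/go-ethereum/trie"
)

func main() {
	// Back a notary with a deliberately bloated store.
	db := memorydb.New()
	db.Put([]byte("needed"), []byte{1})
	db.Put([]byte("bloat-1"), []byte{2})
	db.Put([]byte("bloat-2"), []byte{3})

	notary := trie.NewKeyValueNotary(db)
	if _, err := notary.Get([]byte("needed")); err != nil {
		panic(err)
	}
	// Only the single key actually read survives into the snapshot.
	used := notary.Accessed().(*memorydb.Database)
	fmt.Println(used.Len()) // 1
}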
diff --git a/trie/proof.go b/trie/proof.go
index 2f52438f9..e7102f12b 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -426,7 +426,7 @@ func hasRightElement(node node, key []byte) bool {
// VerifyRangeProof checks whether the given leaf nodes and edge proof
// can prove the given trie leaves range is matched with the specific root.
-// Besides, the range should be consecutive(no gap inside) and monotonic
+// Besides, the range should be consecutive (no gap inside) and monotonic
// increasing.
//
// Note the given proof actually contains two edge proofs. Both of them can
@@ -454,96 +454,136 @@ func hasRightElement(node node, key []byte) bool {
//
// Besides returning an error to indicate whether the proof is valid or not, the function
// will also return a flag to indicate whether there exist more accounts/slots in the trie.
-func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, keys [][]byte, values [][]byte, proof ethdb.KeyValueReader) (error, bool) {
+func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, keys [][]byte, values [][]byte, proof ethdb.KeyValueReader) (ethdb.KeyValueStore, *Trie, *KeyValueNotary, bool, error) {
if len(keys) != len(values) {
- return fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values)), false
+ return nil, nil, nil, false, fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values))
}
// Ensure the received batch is monotonically increasing.
for i := 0; i < len(keys)-1; i++ {
if bytes.Compare(keys[i], keys[i+1]) >= 0 {
- return errors.New("range is not monotonically increasing"), false
+ return nil, nil, nil, false, errors.New("range is not monotonically increasing")
}
}
+ // Create a key-value notary to track which items from the given proof the
+ // range prover actually needed to verify the data
+ notary := NewKeyValueNotary(proof)
+
// Special case, there is no edge proof at all. The given range is expected
// to be the whole leaf-set in the trie.
if proof == nil {
- emptytrie, err := New(common.Hash{}, NewDatabase(memorydb.New()))
+ var (
+ diskdb = memorydb.New()
+ triedb = NewDatabase(diskdb)
+ )
+ tr, err := New(common.Hash{}, triedb)
if err != nil {
- return err, false
+ return nil, nil, nil, false, err
}
for index, key := range keys {
- emptytrie.TryUpdate(key, values[index])
+ tr.TryUpdate(key, values[index])
}
- if emptytrie.Hash() != rootHash {
- return fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, emptytrie.Hash()), false
+ if tr.Hash() != rootHash {
+ return nil, nil, nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash())
}
- return nil, false // no more element.
+ // Proof seems valid, serialize all the nodes into the database
+ if _, err := tr.Commit(nil); err != nil {
+ return nil, nil, nil, false, err
+ }
+ if err := triedb.Commit(rootHash, false, nil); err != nil {
+ return nil, nil, nil, false, err
+ }
+ return diskdb, tr, notary, false, nil // No more elements
}
// Special case, there is a provided edge proof but zero key/value
// pairs, ensure there are no more accounts / slots in the trie.
if len(keys) == 0 {
- root, val, err := proofToPath(rootHash, nil, firstKey, proof, true)
+ root, val, err := proofToPath(rootHash, nil, firstKey, notary, true)
if err != nil {
- return err, false
+ return nil, nil, nil, false, err
}
if val != nil || hasRightElement(root, firstKey) {
- return errors.New("more entries available"), false
+ return nil, nil, nil, false, errors.New("more entries available")
}
- return nil, false
+ // Since the entire proof is a single path, we can construct a trie and a
+ // node database directly out of the inputs, no need to generate them
+ diskdb := notary.Accessed()
+ tr := &Trie{
+ db: NewDatabase(diskdb),
+ root: root,
+ }
+ return diskdb, tr, notary, hasRightElement(root, firstKey), nil
}
// Special case, there is only one element and the two edge keys are the same.
// In this case, we can't construct two edge paths. So handle it here.
if len(keys) == 1 && bytes.Equal(firstKey, lastKey) {
- root, val, err := proofToPath(rootHash, nil, firstKey, proof, false)
+ root, val, err := proofToPath(rootHash, nil, firstKey, notary, false)
if err != nil {
- return err, false
+ return nil, nil, nil, false, err
}
if !bytes.Equal(firstKey, keys[0]) {
- return errors.New("correct proof but invalid key"), false
+ return nil, nil, nil, false, errors.New("correct proof but invalid key")
}
if !bytes.Equal(val, values[0]) {
- return errors.New("correct proof but invalid data"), false
+ return nil, nil, nil, false, errors.New("correct proof but invalid data")
}
- return nil, hasRightElement(root, firstKey)
+ // Since the entire proof is a single path, we can construct a trie and a
+ // node database directly out of the inputs, no need to generate them
+ diskdb := notary.Accessed()
+ tr := &Trie{
+ db: NewDatabase(diskdb),
+ root: root,
+ }
+ return diskdb, tr, notary, hasRightElement(root, firstKey), nil
}
// Ok, in all other cases, we require two edge paths available.
// First check the validity of edge keys.
if bytes.Compare(firstKey, lastKey) >= 0 {
- return errors.New("invalid edge keys"), false
+ return nil, nil, nil, false, errors.New("invalid edge keys")
}
// todo(rjl493456442) different length edge keys should be supported
if len(firstKey) != len(lastKey) {
- return errors.New("inconsistent edge keys"), false
+ return nil, nil, nil, false, errors.New("inconsistent edge keys")
}
// Convert the edge proofs to edge trie paths. Then we can
// have the same tree architecture with the original one.
// For the first edge proof, non-existent proof is allowed.
- root, _, err := proofToPath(rootHash, nil, firstKey, proof, true)
+ root, _, err := proofToPath(rootHash, nil, firstKey, notary, true)
if err != nil {
- return err, false
+ return nil, nil, nil, false, err
}
// Pass the root node here, the second path will be merged
// with the first one. For the last edge proof, non-existent
// proof is also allowed.
- root, _, err = proofToPath(rootHash, root, lastKey, proof, true)
+ root, _, err = proofToPath(rootHash, root, lastKey, notary, true)
if err != nil {
- return err, false
+ return nil, nil, nil, false, err
}
// Remove all internal references. All the removed parts should
// be re-filled (or re-constructed) by the given leaf range.
if err := unsetInternal(root, firstKey, lastKey); err != nil {
- return err, false
+ return nil, nil, nil, false, err
}
- // Rebuild the trie with the leave stream, the shape of trie
+ // Rebuild the trie with the leaf stream, the shape of trie
// should be same with the original one.
- newtrie := &Trie{root: root, db: NewDatabase(memorydb.New())}
+ var (
+ diskdb = memorydb.New()
+ triedb = NewDatabase(diskdb)
+ )
+ tr := &Trie{root: root, db: triedb}
for index, key := range keys {
- newtrie.TryUpdate(key, values[index])
+ tr.TryUpdate(key, values[index])
}
- if newtrie.Hash() != rootHash {
- return fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, newtrie.Hash()), false
+ if tr.Hash() != rootHash {
+ return nil, nil, nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash())
}
- return nil, hasRightElement(root, keys[len(keys)-1])
+ // Proof seems valid, serialize all the nodes into the database
+ if _, err := tr.Commit(nil); err != nil {
+ return nil, nil, nil, false, err
+ }
+ if err := triedb.Commit(rootHash, false, nil); err != nil {
+ return nil, nil, nil, false, err
+ }
+ return diskdb, tr, notary, hasRightElement(root, keys[len(keys)-1]), nil
}
// get returns the child of the given node. Return nil if the
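Callers now receive the reconstructed state alongside the verdict. A minimal standalone sketch of the widened signature; the two-leaf trie is made up, and it assumes a zero-value Trie works outside the package the same way the in-package tests use it:

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/ethdb/memorydb"
	"github.com/ethereum/go-ethereum/trie"
)

func main() {
	// A made-up two-leaf trie; keys are 32 bytes and monotonically increasing.
	tr := new(trie.Trie)
	keys := [][]byte{make([]byte, 32), make([]byte, 32)}
	keys[1][31] = 1
	vals := [][]byte{{0xde}, {0xad}}
	for i := range keys {
		tr.Update(keys[i], vals[i])
	}
	// Prove both edges of the range into a single proof set.
	proof := memorydb.New()
	if err := tr.Prove(keys[0], 0, proof); err != nil {
		panic(err)
	}
	if err := tr.Prove(keys[1], 0, proof); err != nil {
		panic(err)
	}
	// The widened signature: the reconstructed database, sub-trie and notary
	// come back alongside the has-more flag and the verdict.
	diskdb, subtrie, notary, hasMore, err := trie.VerifyRangeProof(tr.Hash(), keys[0], keys[1], keys, vals, proof)
	fmt.Println(diskdb != nil, subtrie != nil, notary != nil, hasMore, err)
	// true true true false <nil>
}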
diff --git a/trie/proof_test.go b/trie/proof_test.go
index 6cdc242d9..3ecd31888 100644
--- a/trie/proof_test.go
+++ b/trie/proof_test.go
@@ -19,6 +19,7 @@ package trie
import (
"bytes"
crand "crypto/rand"
+ "encoding/binary"
mrand "math/rand"
"sort"
"testing"
@@ -181,7 +182,7 @@ func TestRangeProof(t *testing.T) {
keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v)
}
- err, _ := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
+ _, _, _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
if err != nil {
t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
}
@@ -232,7 +233,7 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v)
}
- err, _ := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
+ _, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
if err != nil {
t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
}
@@ -253,7 +254,7 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
k = append(k, entries[i].k)
v = append(v, entries[i].v)
}
- err, _ := VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
+ _, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
if err != nil {
t.Fatal("Failed to verify whole rang with non-existent edges")
}
@@ -288,7 +289,7 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
k = append(k, entries[i].k)
v = append(v, entries[i].v)
}
- err, _ := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof)
+ _, _, _, _, err := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof)
if err == nil {
t.Fatalf("Expected to detect the error, got nil")
}
@@ -310,7 +311,7 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
k = append(k, entries[i].k)
v = append(v, entries[i].v)
}
- err, _ = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof)
+ _, _, _, _, err = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof)
if err == nil {
t.Fatalf("Expected to detect the error, got nil")
}
@@ -334,7 +335,7 @@ func TestOneElementRangeProof(t *testing.T) {
if err := trie.Prove(entries[start].k, 0, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- err, _ := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+ _, _, _, _, err := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -349,7 +350,7 @@ func TestOneElementRangeProof(t *testing.T) {
if err := trie.Prove(entries[start].k, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
- err, _ = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+ _, _, _, _, err = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -364,7 +365,7 @@ func TestOneElementRangeProof(t *testing.T) {
if err := trie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
- err, _ = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+ _, _, _, _, err = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -379,7 +380,7 @@ func TestOneElementRangeProof(t *testing.T) {
if err := trie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
- err, _ = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+ _, _, _, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -401,7 +402,7 @@ func TestAllElementsProof(t *testing.T) {
k = append(k, entries[i].k)
v = append(v, entries[i].v)
}
- err, _ := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil)
+ _, _, _, _, err := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -414,7 +415,7 @@ func TestAllElementsProof(t *testing.T) {
if err := trie.Prove(entries[len(entries)-1].k, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
- err, _ = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof)
+ _, _, _, _, err = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -429,7 +430,7 @@ func TestAllElementsProof(t *testing.T) {
if err := trie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
- err, _ = VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
+ _, _, _, _, err = VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -462,7 +463,7 @@ func TestSingleSideRangeProof(t *testing.T) {
k = append(k, entries[i].k)
v = append(v, entries[i].v)
}
- err, _ := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof)
+ _, _, _, _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -498,7 +499,7 @@ func TestReverseSingleSideRangeProof(t *testing.T) {
k = append(k, entries[i].k)
v = append(v, entries[i].v)
}
- err, _ := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof)
+ _, _, _, _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -570,7 +571,7 @@ func TestBadRangeProof(t *testing.T) {
index = mrand.Intn(end - start)
vals[index] = nil
}
- err, _ := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
+ _, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
if err == nil {
t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1)
}
@@ -604,7 +605,7 @@ func TestGappedRangeProof(t *testing.T) {
keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v)
}
- err, _ := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
+ _, _, _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
if err == nil {
t.Fatal("expect error, got nil")
}
@@ -631,7 +632,7 @@ func TestSameSideProofs(t *testing.T) {
if err := trie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
- err, _ := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
+ _, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
if err == nil {
t.Fatalf("Expected error, got nil")
}
@@ -647,7 +648,7 @@ func TestSameSideProofs(t *testing.T) {
if err := trie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
- err, _ = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
+ _, _, _, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
if err == nil {
t.Fatalf("Expected error, got nil")
}
@@ -715,7 +716,7 @@ func TestHasRightElement(t *testing.T) {
k = append(k, entries[i].k)
v = append(v, entries[i].v)
}
- err, hasMore := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof)
+ _, _, _, hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -748,13 +749,57 @@ func TestEmptyRangeProof(t *testing.T) {
if err := trie.Prove(first, 0, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- err, _ := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof)
+ db, tr, not, _, err := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof)
if c.err && err == nil {
t.Fatalf("Expected error, got nil")
}
if !c.err && err != nil {
t.Fatalf("Expected no error, got %v", err)
}
+ // If no error was returned, ensure the returned trie and database contain
+ // the entire proof, since there's no value
+ if !c.err {
+ if err := tr.Prove(first, 0, memorydb.New()); err != nil {
+ t.Errorf("returned trie doesn't contain original proof: %v", err)
+ }
+ if memdb := db.(*memorydb.Database); memdb.Len() != proof.Len() {
+ t.Errorf("database entry count mismatch: have %d, want %d", memdb.Len(), proof.Len())
+ }
+ if not == nil {
+ t.Errorf("missing notary")
+ }
+ }
+ }
+}
+
+// TestBloatedProof tests a malicious proof, where the proof is more or less the
+// whole trie.
+func TestBloatedProof(t *testing.T) {
+ // Use a small trie
+ trie, kvs := nonRandomTrie(100)
+ var entries entrySlice
+ for _, kv := range kvs {
+ entries = append(entries, kv)
+ }
+ sort.Sort(entries)
+ var keys [][]byte
+ var vals [][]byte
+
+ proof := memorydb.New()
+ for i, entry := range entries {
+ trie.Prove(entry.k, 0, proof)
+ if i == 50 {
+ keys = append(keys, entry.k)
+ vals = append(vals, entry.v)
+ }
+ }
+ want := memorydb.New()
+ trie.Prove(keys[0], 0, want)
+ trie.Prove(keys[len(keys)-1], 0, want)
+
+ _, _, notary, _, _ := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
+ if used := notary.Accessed().(*memorydb.Database); used.Len() != want.Len() {
+ t.Fatalf("notary proof size mismatch: have %d, want %d", used.Len(), want.Len())
}
}
@@ -858,7 +903,7 @@ func benchmarkVerifyRangeProof(b *testing.B, size int) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
- err, _ := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof)
+ _, _, _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof)
if err != nil {
b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
}
@@ -889,3 +934,20 @@ func randBytes(n int) []byte {
crand.Read(r)
return r
}
+
+func nonRandomTrie(n int) (*Trie, map[string]*kv) {
+ trie := new(Trie)
+ vals := make(map[string]*kv)
+ max := uint64(0xffffffffffffffff)
+ for i := uint64(0); i < uint64(n); i++ {
+ value := make([]byte, 32)
+ key := make([]byte, 32)
+ binary.LittleEndian.PutUint64(key, i)
+ binary.LittleEndian.PutUint64(value, i-max)
+ elem := &kv{key, value, false}
+ trie.Update(elem.k, elem.v)
+ vals[string(elem.k)] = elem
+ }
+ return trie, vals
+}
diff --git a/trie/sync_bloom.go b/trie/sync_bloom.go
index 89f61d66d..979f4748f 100644
--- a/trie/sync_bloom.go
+++ b/trie/sync_bloom.go
@@ -125,14 +125,14 @@ func (b *SyncBloom) init(database ethdb.Iteratee) {
it.Release()
it = database.NewIterator(nil, key)
- log.Info("Initializing fast sync bloom", "items", b.bloom.N(), "errorrate", b.errorRate(), "elapsed", common.PrettyDuration(time.Since(start)))
+ log.Info("Initializing state bloom", "items", b.bloom.N(), "errorrate", b.errorRate(), "elapsed", common.PrettyDuration(time.Since(start)))
swap = time.Now()
}
}
it.Release()
// Mark the bloom filter inited and return
- log.Info("Initialized fast sync bloom", "items", b.bloom.N(), "errorrate", b.errorRate(), "elapsed", common.PrettyDuration(time.Since(start)))
+ log.Info("Initialized state bloom", "items", b.bloom.N(), "errorrate", b.errorRate(), "elapsed", common.PrettyDuration(time.Since(start)))
atomic.StoreUint32(&b.inited, 1)
}
@@ -162,7 +162,7 @@ func (b *SyncBloom) Close() error {
b.pend.Wait()
// Wipe the bloom, but mark it "uninited" just in case someone attempts an access
- log.Info("Deallocated fast sync bloom", "items", b.bloom.N(), "errorrate", b.errorRate())
+ log.Info("Deallocated state bloom", "items", b.bloom.N(), "errorrate", b.errorRate())
atomic.StoreUint32(&b.inited, 0)
b.bloom = nil
diff --git a/trie/trie.go b/trie/trie.go
index 6ddbbd78d..87b72ecf1 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -19,13 +19,13 @@ package trie
import (
"bytes"
+ "errors"
"fmt"
"sync"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
- "github.com/ethereum/go-ethereum/rlp"
)
var (
@@ -159,29 +159,26 @@ func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) {
if item == nil {
return nil, resolved, nil
}
- enc, err := rlp.EncodeToBytes(item)
- if err != nil {
- log.Error("Encoding existing trie node failed", "err", err)
- return nil, resolved, err
- }
- return enc, resolved, err
+ return item, resolved, err
}
-func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item node, newnode node, resolved int, err error) {
+func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, newnode node, resolved int, err error) {
// If we reached the requested path, return the current node
if pos >= len(path) {
- // Don't return collapsed hash nodes though
- if _, ok := origNode.(hashNode); !ok {
- // Short nodes have expanded keys, compact them before returning
- item := origNode
- if sn, ok := item.(*shortNode); ok {
- item = &shortNode{
- Key: hexToCompact(sn.Key),
- Val: sn.Val,
- }
- }
- return item, origNode, 0, nil
+ // Although we most probably have the original node expanded, encoding
+ // that into consensus form can be nasty (needs to cascade down) and
+ // time consuming. Instead, just pull the hash up from disk directly.
+ var hash hashNode
+ if node, ok := origNode.(hashNode); ok {
+ hash = node
+ } else {
+ hash, _ = origNode.cache()
}
+ if hash == nil {
+ return nil, origNode, 0, errors.New("non-consensus node")
+ }
+ blob, err := t.db.Node(common.BytesToHash(hash))
+ return blob, origNode, 1, err
}
// Path still needs to be traversed, descend into children
switch n := (origNode).(type) {
@@ -491,7 +488,7 @@ func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) {
// Hash returns the root hash of the trie. It does not write to the
// database and can be used even if the trie doesn't have one.
func (t *Trie) Hash() common.Hash {
- hash, cached, _ := t.hashRoot(nil)
+ hash, cached, _ := t.hashRoot()
t.root = cached
return common.BytesToHash(hash.(hashNode))
}
@@ -545,7 +542,7 @@ func (t *Trie) Commit(onleaf LeafCallback) (root common.Hash, err error) {
}
// hashRoot calculates the root hash of the given trie
-func (t *Trie) hashRoot(db *Database) (node, node, error) {
+func (t *Trie) hashRoot() (node, node, error) {
if t.root == nil {
return hashNode(emptyRoot.Bytes()), nil, nil
}
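Under the new scheme, TryGetNode hands back the node's consensus RLP pulled straight from the database instead of re-encoding an expanded node. A minimal standalone sketch; it assumes an empty compact path addresses the root node, as the traversal above suggests:

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/ethdb/memorydb"
	"github.com/ethereum/go-ethereum/trie"
)

func main() {
	// Build and persist a one-entry trie.
	triedb := trie.NewDatabase(memorydb.New())
	tr, _ := trie.New(common.Hash{}, triedb)
	tr.Update([]byte("some key"), []byte("some value"))
	root, _ := tr.Commit(nil)
	if err := triedb.Commit(root, false, nil); err != nil {
		panic(err)
	}
	// An empty compact path addresses the root node; the returned blob is
	// its consensus RLP, read back from the database rather than re-encoded.
	blob, resolved, err := tr.TryGetNode(nil)
	fmt.Println(len(blob) > 0, resolved, err) // true 1 <nil>
}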