core, eth: split eth package, implement snap protocol (#21482)

This commit splits the eth package, separating the handling of eth and snap protocols. It also includes the capability to run snap sync (https://github.com/ethereum/devp2p/blob/master/caps/snap.md) , but does not enable it by default. 

Co-authored-by: Marius van der Wijden <m.vanderwijden@live.de>
Co-authored-by: Martin Holst Swende <martin@swende.se>
This commit is contained in:
Péter Szilágyi 2020-12-14 11:27:15 +02:00 committed by GitHub
parent 00d10e610f
commit 017831dd5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
74 changed files with 8246 additions and 3411 deletions

View File

@ -25,7 +25,6 @@ import (
"github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/cmd/utils"
"github.com/ethereum/go-ethereum/consensus/ethash" "github.com/ethereum/go-ethereum/consensus/ethash"
"github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
"gopkg.in/urfave/cli.v1" "gopkg.in/urfave/cli.v1"
) )
@ -143,7 +142,6 @@ func version(ctx *cli.Context) error {
fmt.Println("Git Commit Date:", gitDate) fmt.Println("Git Commit Date:", gitDate)
} }
fmt.Println("Architecture:", runtime.GOARCH) fmt.Println("Architecture:", runtime.GOARCH)
fmt.Println("Protocol Versions:", eth.ProtocolVersions)
fmt.Println("Go Version:", runtime.Version()) fmt.Println("Go Version:", runtime.Version())
fmt.Println("Operating System:", runtime.GOOS) fmt.Println("Operating System:", runtime.GOOS)
fmt.Printf("GOPATH=%s\n", os.Getenv("GOPATH")) fmt.Printf("GOPATH=%s\n", os.Getenv("GOPATH"))

View File

@ -187,7 +187,7 @@ var (
defaultSyncMode = eth.DefaultConfig.SyncMode defaultSyncMode = eth.DefaultConfig.SyncMode
SyncModeFlag = TextMarshalerFlag{ SyncModeFlag = TextMarshalerFlag{
Name: "syncmode", Name: "syncmode",
Usage: `Blockchain sync mode ("fast", "full", or "light")`, Usage: `Blockchain sync mode ("fast", "full", "snap" or "light")`,
Value: &defaultSyncMode, Value: &defaultSyncMode,
} }
GCModeFlag = cli.StringFlag{ GCModeFlag = cli.StringFlag{
@ -1555,9 +1555,15 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) {
cfg.SnapshotCache = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheSnapshotFlag.Name) / 100 cfg.SnapshotCache = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheSnapshotFlag.Name) / 100
} }
if !ctx.GlobalIsSet(SnapshotFlag.Name) { if !ctx.GlobalIsSet(SnapshotFlag.Name) {
// If snap-sync is requested, this flag is also required
if cfg.SyncMode == downloader.SnapSync {
log.Info("Snap sync requested, enabling --snapshot")
ctx.Set(SnapshotFlag.Name, "true")
} else {
cfg.TrieCleanCache += cfg.SnapshotCache cfg.TrieCleanCache += cfg.SnapshotCache
cfg.SnapshotCache = 0 // Disabled cfg.SnapshotCache = 0 // Disabled
} }
}
if ctx.GlobalIsSet(DocRootFlag.Name) { if ctx.GlobalIsSet(DocRootFlag.Name) {
cfg.DocRoot = ctx.GlobalString(DocRootFlag.Name) cfg.DocRoot = ctx.GlobalString(DocRootFlag.Name)
} }
@ -1585,16 +1591,15 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) {
cfg.RPCTxFeeCap = ctx.GlobalFloat64(RPCGlobalTxFeeCapFlag.Name) cfg.RPCTxFeeCap = ctx.GlobalFloat64(RPCGlobalTxFeeCapFlag.Name)
} }
if ctx.GlobalIsSet(NoDiscoverFlag.Name) { if ctx.GlobalIsSet(NoDiscoverFlag.Name) {
cfg.DiscoveryURLs = []string{} cfg.EthDiscoveryURLs, cfg.SnapDiscoveryURLs = []string{}, []string{}
} else if ctx.GlobalIsSet(DNSDiscoveryFlag.Name) { } else if ctx.GlobalIsSet(DNSDiscoveryFlag.Name) {
urls := ctx.GlobalString(DNSDiscoveryFlag.Name) urls := ctx.GlobalString(DNSDiscoveryFlag.Name)
if urls == "" { if urls == "" {
cfg.DiscoveryURLs = []string{} cfg.EthDiscoveryURLs = []string{}
} else { } else {
cfg.DiscoveryURLs = SplitAndTrim(urls) cfg.EthDiscoveryURLs = SplitAndTrim(urls)
} }
} }
// Override any default configs for hard coded networks. // Override any default configs for hard coded networks.
switch { switch {
case ctx.GlobalBool(LegacyTestnetFlag.Name) || ctx.GlobalBool(RopstenFlag.Name): case ctx.GlobalBool(LegacyTestnetFlag.Name) || ctx.GlobalBool(RopstenFlag.Name):
@ -1676,16 +1681,20 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) {
// SetDNSDiscoveryDefaults configures DNS discovery with the given URL if // SetDNSDiscoveryDefaults configures DNS discovery with the given URL if
// no URLs are set. // no URLs are set.
func SetDNSDiscoveryDefaults(cfg *eth.Config, genesis common.Hash) { func SetDNSDiscoveryDefaults(cfg *eth.Config, genesis common.Hash) {
if cfg.DiscoveryURLs != nil { if cfg.EthDiscoveryURLs != nil {
return // already set through flags/config return // already set through flags/config
} }
protocol := "all" protocol := "all"
if cfg.SyncMode == downloader.LightSync { if cfg.SyncMode == downloader.LightSync {
protocol = "les" protocol = "les"
} }
if url := params.KnownDNSNetwork(genesis, protocol); url != "" { if url := params.KnownDNSNetwork(genesis, protocol); url != "" {
cfg.DiscoveryURLs = []string{url} cfg.EthDiscoveryURLs = []string{url}
}
if cfg.SyncMode == downloader.SnapSync {
if url := params.KnownDNSNetwork(genesis, "snap"); url != "" {
cfg.SnapDiscoveryURLs = []string{url}
}
} }
} }

View File

@ -659,12 +659,8 @@ func (bc *BlockChain) CurrentBlock() *types.Block {
return bc.currentBlock.Load().(*types.Block) return bc.currentBlock.Load().(*types.Block)
} }
// Snapshot returns the blockchain snapshot tree. This method is mainly used for // Snapshots returns the blockchain snapshot tree.
// testing, to make it possible to verify the snapshot after execution. func (bc *BlockChain) Snapshots() *snapshot.Tree {
//
// Warning: There are no guarantees about the safety of using the returned 'snap' if the
// blockchain is simultaneously importing blocks, so take care.
func (bc *BlockChain) Snapshot() *snapshot.Tree {
return bc.snaps return bc.snaps
} }

View File

@ -751,7 +751,7 @@ func testSnapshot(t *testing.T, tt *snapshotTest) {
t.Fatalf("Failed to recreate chain: %v", err) t.Fatalf("Failed to recreate chain: %v", err)
} }
chain.InsertChain(newBlocks) chain.InsertChain(newBlocks)
chain.Snapshot().Cap(newBlocks[len(newBlocks)-1].Root(), 0) chain.Snapshots().Cap(newBlocks[len(newBlocks)-1].Root(), 0)
// Simulate the blockchain crash // Simulate the blockchain crash
// Don't call chain.Stop here, so that no snapshot // Don't call chain.Stop here, so that no snapshot

View File

@ -84,6 +84,15 @@ func NewID(config *params.ChainConfig, genesis common.Hash, head uint64) ID {
return ID{Hash: checksumToBytes(hash), Next: next} return ID{Hash: checksumToBytes(hash), Next: next}
} }
// NewIDWithChain calculates the Ethereum fork ID from an existing chain instance.
func NewIDWithChain(chain Blockchain) ID {
return NewID(
chain.Config(),
chain.Genesis().Hash(),
chain.CurrentHeader().Number.Uint64(),
)
}
// NewFilter creates a filter that returns if a fork ID should be rejected or not // NewFilter creates a filter that returns if a fork ID should be rejected or not
// based on the local chain's status. // based on the local chain's status.
func NewFilter(chain Blockchain) Filter { func NewFilter(chain Blockchain) Filter {

View File

@ -175,3 +175,24 @@ func DeleteSnapshotRecoveryNumber(db ethdb.KeyValueWriter) {
log.Crit("Failed to remove snapshot recovery number", "err", err) log.Crit("Failed to remove snapshot recovery number", "err", err)
} }
} }
// ReadSanpshotSyncStatus retrieves the serialized sync status saved at shutdown.
func ReadSanpshotSyncStatus(db ethdb.KeyValueReader) []byte {
data, _ := db.Get(snapshotSyncStatusKey)
return data
}
// WriteSnapshotSyncStatus stores the serialized sync status to save at shutdown.
func WriteSnapshotSyncStatus(db ethdb.KeyValueWriter, status []byte) {
if err := db.Put(snapshotSyncStatusKey, status); err != nil {
log.Crit("Failed to store snapshot sync status", "err", err)
}
}
// DeleteSnapshotSyncStatus deletes the serialized sync status saved at the last
// shutdown
func DeleteSnapshotSyncStatus(db ethdb.KeyValueWriter) {
if err := db.Delete(snapshotSyncStatusKey); err != nil {
log.Crit("Failed to remove snapshot sync status", "err", err)
}
}

View File

@ -57,6 +57,9 @@ var (
// snapshotRecoveryKey tracks the snapshot recovery marker across restarts. // snapshotRecoveryKey tracks the snapshot recovery marker across restarts.
snapshotRecoveryKey = []byte("SnapshotRecovery") snapshotRecoveryKey = []byte("SnapshotRecovery")
// snapshotSyncStatusKey tracks the snapshot sync status across restarts.
snapshotSyncStatusKey = []byte("SnapshotSyncStatus")
// txIndexTailKey tracks the oldest block whose transactions have been indexed. // txIndexTailKey tracks the oldest block whose transactions have been indexed.
txIndexTailKey = []byte("TransactionIndexTail") txIndexTailKey = []byte("TransactionIndexTail")

View File

@ -241,7 +241,7 @@ func (dl *diskLayer) generate(stats *generatorStats) {
if acc.Root != emptyRoot { if acc.Root != emptyRoot {
storeTrie, err := trie.NewSecure(acc.Root, dl.triedb) storeTrie, err := trie.NewSecure(acc.Root, dl.triedb)
if err != nil { if err != nil {
log.Error("Generator failed to access storage trie", "accroot", dl.root, "acchash", common.BytesToHash(accIt.Key), "stroot", acc.Root, "err", err) log.Error("Generator failed to access storage trie", "root", dl.root, "account", accountHash, "stroot", acc.Root, "err", err)
abort := <-dl.genAbort abort := <-dl.genAbort
abort <- stats abort <- stats
return return

View File

@ -314,14 +314,19 @@ func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash {
return common.Hash{} return common.Hash{}
} }
// GetProof returns the MerkleProof for a given Account // GetProof returns the Merkle proof for a given account.
func (s *StateDB) GetProof(a common.Address) ([][]byte, error) { func (s *StateDB) GetProof(addr common.Address) ([][]byte, error) {
return s.GetProofByHash(crypto.Keccak256Hash(addr.Bytes()))
}
// GetProofByHash returns the Merkle proof for a given account.
func (s *StateDB) GetProofByHash(addrHash common.Hash) ([][]byte, error) {
var proof proofList var proof proofList
err := s.trie.Prove(crypto.Keccak256(a.Bytes()), 0, &proof) err := s.trie.Prove(addrHash[:], 0, &proof)
return proof, err return proof, err
} }
// GetStorageProof returns the StorageProof for given key // GetStorageProof returns the Merkle proof for given storage slot.
func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, error) { func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, error) {
var proof proofList var proof proofList
trie := s.StorageTrie(a) trie := s.StorageTrie(a)
@ -332,6 +337,17 @@ func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte,
return proof, err return proof, err
} }
// GetStorageProofByHash returns the Merkle proof for given storage slot.
func (s *StateDB) GetStorageProofByHash(a common.Address, key common.Hash) ([][]byte, error) {
var proof proofList
trie := s.StorageTrie(a)
if trie == nil {
return proof, errors.New("storage trie for requested address does not exist")
}
err := trie.Prove(crypto.Keccak256(key.Bytes()), 0, &proof)
return proof, err
}
// GetCommittedState retrieves a value from the given account's committed storage trie. // GetCommittedState retrieves a value from the given account's committed storage trie.
func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash { func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash {
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)

View File

@ -56,7 +56,7 @@ func (b *EthAPIBackend) CurrentBlock() *types.Block {
} }
func (b *EthAPIBackend) SetHead(number uint64) { func (b *EthAPIBackend) SetHead(number uint64) {
b.eth.protocolManager.downloader.Cancel() b.eth.handler.downloader.Cancel()
b.eth.blockchain.SetHead(number) b.eth.blockchain.SetHead(number)
} }
@ -272,10 +272,6 @@ func (b *EthAPIBackend) Downloader() *downloader.Downloader {
return b.eth.Downloader() return b.eth.Downloader()
} }
func (b *EthAPIBackend) ProtocolVersion() int {
return b.eth.EthVersion()
}
func (b *EthAPIBackend) SuggestPrice(ctx context.Context) (*big.Int, error) { func (b *EthAPIBackend) SuggestPrice(ctx context.Context) (*big.Int, error) {
return b.gpo.SuggestPrice(ctx) return b.gpo.SuggestPrice(ctx)
} }

View File

@ -57,6 +57,8 @@ func (h resultHash) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h resultHash) Less(i, j int) bool { return bytes.Compare(h[i].Bytes(), h[j].Bytes()) < 0 } func (h resultHash) Less(i, j int) bool { return bytes.Compare(h[i].Bytes(), h[j].Bytes()) < 0 }
func TestAccountRange(t *testing.T) { func TestAccountRange(t *testing.T) {
t.Parallel()
var ( var (
statedb = state.NewDatabaseWithConfig(rawdb.NewMemoryDatabase(), nil) statedb = state.NewDatabaseWithConfig(rawdb.NewMemoryDatabase(), nil)
state, _ = state.New(common.Hash{}, statedb, nil) state, _ = state.New(common.Hash{}, statedb, nil)
@ -126,6 +128,8 @@ func TestAccountRange(t *testing.T) {
} }
func TestEmptyAccountRange(t *testing.T) { func TestEmptyAccountRange(t *testing.T) {
t.Parallel()
var ( var (
statedb = state.NewDatabase(rawdb.NewMemoryDatabase()) statedb = state.NewDatabase(rawdb.NewMemoryDatabase())
state, _ = state.New(common.Hash{}, statedb, nil) state, _ = state.New(common.Hash{}, statedb, nil)
@ -142,6 +146,8 @@ func TestEmptyAccountRange(t *testing.T) {
} }
func TestStorageRangeAt(t *testing.T) { func TestStorageRangeAt(t *testing.T) {
t.Parallel()
// Create a state where account 0x010000... has a few storage entries. // Create a state where account 0x010000... has a few storage entries.
var ( var (
state, _ = state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil) state, _ = state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil)

View File

@ -40,6 +40,8 @@ import (
"github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/eth/filters" "github.com/ethereum/go-ethereum/eth/filters"
"github.com/ethereum/go-ethereum/eth/gasprice" "github.com/ethereum/go-ethereum/eth/gasprice"
"github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/eth/protocols/snap"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ethereum/go-ethereum/internal/ethapi"
@ -48,7 +50,6 @@ import (
"github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
@ -61,8 +62,9 @@ type Ethereum struct {
// Handlers // Handlers
txPool *core.TxPool txPool *core.TxPool
blockchain *core.BlockChain blockchain *core.BlockChain
protocolManager *ProtocolManager handler *handler
dialCandidates enode.Iterator ethDialCandidates enode.Iterator
snapDialCandidates enode.Iterator
// DB interfaces // DB interfaces
chainDb ethdb.Database // Block chain database chainDb ethdb.Database // Block chain database
@ -145,7 +147,7 @@ func New(stack *node.Node, config *Config) (*Ethereum, error) {
if bcVersion != nil { if bcVersion != nil {
dbVer = fmt.Sprintf("%d", *bcVersion) dbVer = fmt.Sprintf("%d", *bcVersion)
} }
log.Info("Initialising Ethereum protocol", "versions", ProtocolVersions, "network", config.NetworkId, "dbversion", dbVer) log.Info("Initialising Ethereum protocol", "network", config.NetworkId, "dbversion", dbVer)
if !config.SkipBcVersionCheck { if !config.SkipBcVersionCheck {
if bcVersion != nil && *bcVersion > core.BlockChainVersion { if bcVersion != nil && *bcVersion > core.BlockChainVersion {
@ -196,7 +198,17 @@ func New(stack *node.Node, config *Config) (*Ethereum, error) {
if checkpoint == nil { if checkpoint == nil {
checkpoint = params.TrustedCheckpoints[genesisHash] checkpoint = params.TrustedCheckpoints[genesisHash]
} }
if eth.protocolManager, err = NewProtocolManager(chainConfig, checkpoint, config.SyncMode, config.NetworkId, eth.eventMux, eth.txPool, eth.engine, eth.blockchain, chainDb, cacheLimit, config.Whitelist); err != nil { if eth.handler, err = newHandler(&handlerConfig{
Database: chainDb,
Chain: eth.blockchain,
TxPool: eth.txPool,
Network: config.NetworkId,
Sync: config.SyncMode,
BloomCache: uint64(cacheLimit),
EventMux: eth.eventMux,
Checkpoint: checkpoint,
Whitelist: config.Whitelist,
}); err != nil {
return nil, err return nil, err
} }
eth.miner = miner.New(eth, &config.Miner, chainConfig, eth.EventMux(), eth.engine, eth.isLocalBlock) eth.miner = miner.New(eth, &config.Miner, chainConfig, eth.EventMux(), eth.engine, eth.isLocalBlock)
@ -209,13 +221,16 @@ func New(stack *node.Node, config *Config) (*Ethereum, error) {
} }
eth.APIBackend.gpo = gasprice.NewOracle(eth.APIBackend, gpoParams) eth.APIBackend.gpo = gasprice.NewOracle(eth.APIBackend, gpoParams)
eth.dialCandidates, err = eth.setupDiscovery() eth.ethDialCandidates, err = setupDiscovery(eth.config.EthDiscoveryURLs)
if err != nil {
return nil, err
}
eth.snapDialCandidates, err = setupDiscovery(eth.config.SnapDiscoveryURLs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Start the RPC service // Start the RPC service
eth.netRPCService = ethapi.NewPublicNetAPI(eth.p2pServer, eth.NetVersion()) eth.netRPCService = ethapi.NewPublicNetAPI(eth.p2pServer)
// Register the backend on the node // Register the backend on the node
stack.RegisterAPIs(eth.APIs()) stack.RegisterAPIs(eth.APIs())
@ -310,7 +325,7 @@ func (s *Ethereum) APIs() []rpc.API {
}, { }, {
Namespace: "eth", Namespace: "eth",
Version: "1.0", Version: "1.0",
Service: downloader.NewPublicDownloaderAPI(s.protocolManager.downloader, s.eventMux), Service: downloader.NewPublicDownloaderAPI(s.handler.downloader, s.eventMux),
Public: true, Public: true,
}, { }, {
Namespace: "miner", Namespace: "miner",
@ -473,7 +488,7 @@ func (s *Ethereum) StartMining(threads int) error {
} }
// If mining is started, we can disable the transaction rejection mechanism // If mining is started, we can disable the transaction rejection mechanism
// introduced to speed sync times. // introduced to speed sync times.
atomic.StoreUint32(&s.protocolManager.acceptTxs, 1) atomic.StoreUint32(&s.handler.acceptTxs, 1)
go s.miner.Start(eb) go s.miner.Start(eb)
} }
@ -504,21 +519,17 @@ func (s *Ethereum) EventMux() *event.TypeMux { return s.eventMux }
func (s *Ethereum) Engine() consensus.Engine { return s.engine } func (s *Ethereum) Engine() consensus.Engine { return s.engine }
func (s *Ethereum) ChainDb() ethdb.Database { return s.chainDb } func (s *Ethereum) ChainDb() ethdb.Database { return s.chainDb }
func (s *Ethereum) IsListening() bool { return true } // Always listening func (s *Ethereum) IsListening() bool { return true } // Always listening
func (s *Ethereum) EthVersion() int { return int(ProtocolVersions[0]) } func (s *Ethereum) Downloader() *downloader.Downloader { return s.handler.downloader }
func (s *Ethereum) NetVersion() uint64 { return s.networkID } func (s *Ethereum) Synced() bool { return atomic.LoadUint32(&s.handler.acceptTxs) == 1 }
func (s *Ethereum) Downloader() *downloader.Downloader { return s.protocolManager.downloader }
func (s *Ethereum) Synced() bool { return atomic.LoadUint32(&s.protocolManager.acceptTxs) == 1 }
func (s *Ethereum) ArchiveMode() bool { return s.config.NoPruning } func (s *Ethereum) ArchiveMode() bool { return s.config.NoPruning }
func (s *Ethereum) BloomIndexer() *core.ChainIndexer { return s.bloomIndexer } func (s *Ethereum) BloomIndexer() *core.ChainIndexer { return s.bloomIndexer }
// Protocols returns all the currently configured // Protocols returns all the currently configured
// network protocols to start. // network protocols to start.
func (s *Ethereum) Protocols() []p2p.Protocol { func (s *Ethereum) Protocols() []p2p.Protocol {
protos := make([]p2p.Protocol, len(ProtocolVersions)) protos := eth.MakeProtocols((*ethHandler)(s.handler), s.networkID, s.ethDialCandidates)
for i, vsn := range ProtocolVersions { if s.config.SnapshotCache > 0 {
protos[i] = s.protocolManager.makeProtocol(vsn) protos = append(protos, snap.MakeProtocols((*snapHandler)(s.handler), s.snapDialCandidates)...)
protos[i].Attributes = []enr.Entry{s.currentEthEntry()}
protos[i].DialCandidates = s.dialCandidates
} }
return protos return protos
} }
@ -526,7 +537,7 @@ func (s *Ethereum) Protocols() []p2p.Protocol {
// Start implements node.Lifecycle, starting all internal goroutines needed by the // Start implements node.Lifecycle, starting all internal goroutines needed by the
// Ethereum protocol implementation. // Ethereum protocol implementation.
func (s *Ethereum) Start() error { func (s *Ethereum) Start() error {
s.startEthEntryUpdate(s.p2pServer.LocalNode()) eth.StartENRUpdater(s.blockchain, s.p2pServer.LocalNode())
// Start the bloom bits servicing goroutines // Start the bloom bits servicing goroutines
s.startBloomHandlers(params.BloomBitsBlocks) s.startBloomHandlers(params.BloomBitsBlocks)
@ -540,7 +551,7 @@ func (s *Ethereum) Start() error {
maxPeers -= s.config.LightPeers maxPeers -= s.config.LightPeers
} }
// Start the networking layer and the light server if requested // Start the networking layer and the light server if requested
s.protocolManager.Start(maxPeers) s.handler.Start(maxPeers)
return nil return nil
} }
@ -548,7 +559,7 @@ func (s *Ethereum) Start() error {
// Ethereum protocol. // Ethereum protocol.
func (s *Ethereum) Stop() error { func (s *Ethereum) Stop() error {
// Stop all the peer-related stuff first. // Stop all the peer-related stuff first.
s.protocolManager.Stop() s.handler.Stop()
// Then stop everything else. // Then stop everything else.
s.bloomIndexer.Close() s.bloomIndexer.Close()
@ -560,5 +571,6 @@ func (s *Ethereum) Stop() error {
rawdb.PopUncleanShutdownMarker(s.chainDb) rawdb.PopUncleanShutdownMarker(s.chainDb)
s.chainDb.Close() s.chainDb.Close()
s.eventMux.Stop() s.eventMux.Stop()
return nil return nil
} }

View File

@ -115,7 +115,8 @@ type Config struct {
// This can be set to list of enrtree:// URLs which will be queried for // This can be set to list of enrtree:// URLs which will be queried for
// for nodes to connect to. // for nodes to connect to.
DiscoveryURLs []string EthDiscoveryURLs []string
SnapDiscoveryURLs []string
NoPruning bool // Whether to disable pruning and flush everything to disk NoPruning bool // Whether to disable pruning and flush everything to disk
NoPrefetch bool // Whether to disable prefetching and only load state on demand NoPrefetch bool // Whether to disable prefetching and only load state on demand

View File

@ -63,11 +63,12 @@ func (eth *Ethereum) currentEthEntry() *ethEntry {
eth.blockchain.CurrentHeader().Number.Uint64())} eth.blockchain.CurrentHeader().Number.Uint64())}
} }
// setupDiscovery creates the node discovery source for the eth protocol. // setupDiscovery creates the node discovery source for the `eth` and `snap`
func (eth *Ethereum) setupDiscovery() (enode.Iterator, error) { // protocols.
if len(eth.config.DiscoveryURLs) == 0 { func setupDiscovery(urls []string) (enode.Iterator, error) {
if len(urls) == 0 {
return nil, nil return nil, nil
} }
client := dnsdisc.NewClient(dnsdisc.Config{}) client := dnsdisc.NewClient(dnsdisc.Config{})
return client.NewIterator(eth.config.DiscoveryURLs...) return client.NewIterator(urls...)
} }

View File

@ -29,6 +29,7 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth/protocols/snap"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
@ -38,7 +39,6 @@ import (
) )
var ( var (
MaxHashFetch = 512 // Amount of hashes to be fetched per retrieval request
MaxBlockFetch = 128 // Amount of blocks to be fetched per retrieval request MaxBlockFetch = 128 // Amount of blocks to be fetched per retrieval request
MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request
MaxSkeletonSize = 128 // Number of header fetches to need for a skeleton assembly MaxSkeletonSize = 128 // Number of header fetches to need for a skeleton assembly
@ -89,7 +89,7 @@ var (
errCancelContentProcessing = errors.New("content processing canceled (requested)") errCancelContentProcessing = errors.New("content processing canceled (requested)")
errCanceled = errors.New("syncing canceled (requested)") errCanceled = errors.New("syncing canceled (requested)")
errNoSyncActive = errors.New("no sync active") errNoSyncActive = errors.New("no sync active")
errTooOld = errors.New("peer doesn't speak recent enough protocol version (need version >= 63)") errTooOld = errors.New("peer doesn't speak recent enough protocol version (need version >= 64)")
) )
type Downloader struct { type Downloader struct {
@ -131,20 +131,22 @@ type Downloader struct {
ancientLimit uint64 // The maximum block number which can be regarded as ancient data. ancientLimit uint64 // The maximum block number which can be regarded as ancient data.
// Channels // Channels
headerCh chan dataPack // [eth/62] Channel receiving inbound block headers headerCh chan dataPack // Channel receiving inbound block headers
bodyCh chan dataPack // [eth/62] Channel receiving inbound block bodies bodyCh chan dataPack // Channel receiving inbound block bodies
receiptCh chan dataPack // [eth/63] Channel receiving inbound receipts receiptCh chan dataPack // Channel receiving inbound receipts
bodyWakeCh chan bool // [eth/62] Channel to signal the block body fetcher of new tasks bodyWakeCh chan bool // Channel to signal the block body fetcher of new tasks
receiptWakeCh chan bool // [eth/63] Channel to signal the receipt fetcher of new tasks receiptWakeCh chan bool // Channel to signal the receipt fetcher of new tasks
headerProcCh chan []*types.Header // [eth/62] Channel to feed the header processor new tasks headerProcCh chan []*types.Header // Channel to feed the header processor new tasks
// State sync // State sync
pivotHeader *types.Header // Pivot block header to dynamically push the syncing state root pivotHeader *types.Header // Pivot block header to dynamically push the syncing state root
pivotLock sync.RWMutex // Lock protecting pivot header reads from updates pivotLock sync.RWMutex // Lock protecting pivot header reads from updates
snapSync bool // Whether to run state sync over the snap protocol
SnapSyncer *snap.Syncer // TODO(karalabe): make private! hack for now
stateSyncStart chan *stateSync stateSyncStart chan *stateSync
trackStateReq chan *stateReq trackStateReq chan *stateReq
stateCh chan dataPack // [eth/63] Channel receiving inbound node state data stateCh chan dataPack // Channel receiving inbound node state data
// Cancellation and termination // Cancellation and termination
cancelPeer string // Identifier of the peer currently being used as the master (cancel on drop) cancelPeer string // Identifier of the peer currently being used as the master (cancel on drop)
@ -237,6 +239,7 @@ func New(checkpoint uint64, stateDb ethdb.Database, stateBloom *trie.SyncBloom,
headerProcCh: make(chan []*types.Header, 1), headerProcCh: make(chan []*types.Header, 1),
quitCh: make(chan struct{}), quitCh: make(chan struct{}),
stateCh: make(chan dataPack), stateCh: make(chan dataPack),
SnapSyncer: snap.NewSyncer(stateDb, stateBloom),
stateSyncStart: make(chan *stateSync), stateSyncStart: make(chan *stateSync),
syncStatsState: stateSyncStats{ syncStatsState: stateSyncStats{
processed: rawdb.ReadFastTrieProgress(stateDb), processed: rawdb.ReadFastTrieProgress(stateDb),
@ -286,19 +289,16 @@ func (d *Downloader) Synchronising() bool {
return atomic.LoadInt32(&d.synchronising) > 0 return atomic.LoadInt32(&d.synchronising) > 0
} }
// SyncBloomContains tests if the syncbloom filter contains the given hash:
// - false: the bloom definitely does not contain hash
// - true: the bloom maybe contains hash
//
// While the bloom is being initialized (or is closed), all queries will return true.
func (d *Downloader) SyncBloomContains(hash []byte) bool {
return d.stateBloom == nil || d.stateBloom.Contains(hash)
}
// RegisterPeer injects a new download peer into the set of block source to be // RegisterPeer injects a new download peer into the set of block source to be
// used for fetching hashes and blocks from. // used for fetching hashes and blocks from.
func (d *Downloader) RegisterPeer(id string, version int, peer Peer) error { func (d *Downloader) RegisterPeer(id string, version uint, peer Peer) error {
logger := log.New("peer", id) var logger log.Logger
if len(id) < 16 {
// Tests use short IDs, don't choke on them
logger = log.New("peer", id)
} else {
logger = log.New("peer", id[:16])
}
logger.Trace("Registering sync peer") logger.Trace("Registering sync peer")
if err := d.peers.Register(newPeerConnection(id, version, peer, logger)); err != nil { if err := d.peers.Register(newPeerConnection(id, version, peer, logger)); err != nil {
logger.Error("Failed to register sync peer", "err", err) logger.Error("Failed to register sync peer", "err", err)
@ -310,7 +310,7 @@ func (d *Downloader) RegisterPeer(id string, version int, peer Peer) error {
} }
// RegisterLightPeer injects a light client peer, wrapping it so it appears as a regular peer. // RegisterLightPeer injects a light client peer, wrapping it so it appears as a regular peer.
func (d *Downloader) RegisterLightPeer(id string, version int, peer LightPeer) error { func (d *Downloader) RegisterLightPeer(id string, version uint, peer LightPeer) error {
return d.RegisterPeer(id, version, &lightPeerWrapper{peer}) return d.RegisterPeer(id, version, &lightPeerWrapper{peer})
} }
@ -319,7 +319,13 @@ func (d *Downloader) RegisterLightPeer(id string, version int, peer LightPeer) e
// the queue. // the queue.
func (d *Downloader) UnregisterPeer(id string) error { func (d *Downloader) UnregisterPeer(id string) error {
// Unregister the peer from the active peer set and revoke any fetch tasks // Unregister the peer from the active peer set and revoke any fetch tasks
logger := log.New("peer", id) var logger log.Logger
if len(id) < 16 {
// Tests use short IDs, don't choke on them
logger = log.New("peer", id)
} else {
logger = log.New("peer", id[:16])
}
logger.Trace("Unregistering sync peer") logger.Trace("Unregistering sync peer")
if err := d.peers.Unregister(id); err != nil { if err := d.peers.Unregister(id); err != nil {
logger.Error("Failed to unregister sync peer", "err", err) logger.Error("Failed to unregister sync peer", "err", err)
@ -381,6 +387,16 @@ func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int, mode
if mode == FullSync && d.stateBloom != nil { if mode == FullSync && d.stateBloom != nil {
d.stateBloom.Close() d.stateBloom.Close()
} }
// If snap sync was requested, create the snap scheduler and switch to fast
// sync mode. Long term we could drop fast sync or merge the two together,
// but until snap becomes prevalent, we should support both. TODO(karalabe).
if mode == SnapSync {
if !d.snapSync {
log.Warn("Enabling snapshot sync prototype")
d.snapSync = true
}
mode = FastSync
}
// Reset the queue, peer set and wake channels to clean any internal leftover state // Reset the queue, peer set and wake channels to clean any internal leftover state
d.queue.Reset(blockCacheMaxItems, blockCacheInitialItems) d.queue.Reset(blockCacheMaxItems, blockCacheInitialItems)
d.peers.Reset() d.peers.Reset()
@ -443,8 +459,8 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.I
d.mux.Post(DoneEvent{latest}) d.mux.Post(DoneEvent{latest})
} }
}() }()
if p.version < 63 { if p.version < 64 {
return errTooOld return fmt.Errorf("%w, peer version: %d", errTooOld, p.version)
} }
mode := d.getMode() mode := d.getMode()
@ -1910,27 +1926,53 @@ func (d *Downloader) commitPivotBlock(result *fetchResult) error {
// DeliverHeaders injects a new batch of block headers received from a remote // DeliverHeaders injects a new batch of block headers received from a remote
// node into the download schedule. // node into the download schedule.
func (d *Downloader) DeliverHeaders(id string, headers []*types.Header) (err error) { func (d *Downloader) DeliverHeaders(id string, headers []*types.Header) error {
return d.deliver(id, d.headerCh, &headerPack{id, headers}, headerInMeter, headerDropMeter) return d.deliver(d.headerCh, &headerPack{id, headers}, headerInMeter, headerDropMeter)
} }
// DeliverBodies injects a new batch of block bodies received from a remote node. // DeliverBodies injects a new batch of block bodies received from a remote node.
func (d *Downloader) DeliverBodies(id string, transactions [][]*types.Transaction, uncles [][]*types.Header) (err error) { func (d *Downloader) DeliverBodies(id string, transactions [][]*types.Transaction, uncles [][]*types.Header) error {
return d.deliver(id, d.bodyCh, &bodyPack{id, transactions, uncles}, bodyInMeter, bodyDropMeter) return d.deliver(d.bodyCh, &bodyPack{id, transactions, uncles}, bodyInMeter, bodyDropMeter)
} }
// DeliverReceipts injects a new batch of receipts received from a remote node. // DeliverReceipts injects a new batch of receipts received from a remote node.
func (d *Downloader) DeliverReceipts(id string, receipts [][]*types.Receipt) (err error) { func (d *Downloader) DeliverReceipts(id string, receipts [][]*types.Receipt) error {
return d.deliver(id, d.receiptCh, &receiptPack{id, receipts}, receiptInMeter, receiptDropMeter) return d.deliver(d.receiptCh, &receiptPack{id, receipts}, receiptInMeter, receiptDropMeter)
} }
// DeliverNodeData injects a new batch of node state data received from a remote node. // DeliverNodeData injects a new batch of node state data received from a remote node.
func (d *Downloader) DeliverNodeData(id string, data [][]byte) (err error) { func (d *Downloader) DeliverNodeData(id string, data [][]byte) error {
return d.deliver(id, d.stateCh, &statePack{id, data}, stateInMeter, stateDropMeter) return d.deliver(d.stateCh, &statePack{id, data}, stateInMeter, stateDropMeter)
}
// DeliverSnapPacket is invoked from a peer's message handler when it transmits a
// data packet for the local node to consume.
func (d *Downloader) DeliverSnapPacket(peer *snap.Peer, packet snap.Packet) error {
switch packet := packet.(type) {
case *snap.AccountRangePacket:
hashes, accounts, err := packet.Unpack()
if err != nil {
return err
}
return d.SnapSyncer.OnAccounts(peer, packet.ID, hashes, accounts, packet.Proof)
case *snap.StorageRangesPacket:
hashset, slotset := packet.Unpack()
return d.SnapSyncer.OnStorage(peer, packet.ID, hashset, slotset, packet.Proof)
case *snap.ByteCodesPacket:
return d.SnapSyncer.OnByteCodes(peer, packet.ID, packet.Codes)
case *snap.TrieNodesPacket:
return d.SnapSyncer.OnTrieNodes(peer, packet.ID, packet.Nodes)
default:
return fmt.Errorf("unexpected snap packet type: %T", packet)
}
} }
// deliver injects a new batch of data received from a remote node. // deliver injects a new batch of data received from a remote node.
func (d *Downloader) deliver(id string, destCh chan dataPack, packet dataPack, inMeter, dropMeter metrics.Meter) (err error) { func (d *Downloader) deliver(destCh chan dataPack, packet dataPack, inMeter, dropMeter metrics.Meter) (err error) {
// Update the delivery metrics for both good and failed deliveries // Update the delivery metrics for both good and failed deliveries
inMeter.Mark(int64(packet.Items())) inMeter.Mark(int64(packet.Items()))
defer func() { defer func() {

View File

@ -390,7 +390,7 @@ func (dl *downloadTester) Rollback(hashes []common.Hash) {
} }
// newPeer registers a new block download source into the downloader. // newPeer registers a new block download source into the downloader.
func (dl *downloadTester) newPeer(id string, version int, chain *testChain) error { func (dl *downloadTester) newPeer(id string, version uint, chain *testChain) error {
dl.lock.Lock() dl.lock.Lock()
defer dl.lock.Unlock() defer dl.lock.Unlock()
@ -518,8 +518,6 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng
// Tests that simple synchronization against a canonical chain works correctly. // Tests that simple synchronization against a canonical chain works correctly.
// In this test common ancestor lookup should be short circuited and not require // In this test common ancestor lookup should be short circuited and not require
// binary searching. // binary searching.
func TestCanonicalSynchronisation63Full(t *testing.T) { testCanonicalSynchronisation(t, 63, FullSync) }
func TestCanonicalSynchronisation63Fast(t *testing.T) { testCanonicalSynchronisation(t, 63, FastSync) }
func TestCanonicalSynchronisation64Full(t *testing.T) { testCanonicalSynchronisation(t, 64, FullSync) } func TestCanonicalSynchronisation64Full(t *testing.T) { testCanonicalSynchronisation(t, 64, FullSync) }
func TestCanonicalSynchronisation64Fast(t *testing.T) { testCanonicalSynchronisation(t, 64, FastSync) } func TestCanonicalSynchronisation64Fast(t *testing.T) { testCanonicalSynchronisation(t, 64, FastSync) }
func TestCanonicalSynchronisation65Full(t *testing.T) { testCanonicalSynchronisation(t, 65, FullSync) } func TestCanonicalSynchronisation65Full(t *testing.T) { testCanonicalSynchronisation(t, 65, FullSync) }
@ -528,7 +526,7 @@ func TestCanonicalSynchronisation65Light(t *testing.T) {
testCanonicalSynchronisation(t, 65, LightSync) testCanonicalSynchronisation(t, 65, LightSync)
} }
func testCanonicalSynchronisation(t *testing.T, protocol int, mode SyncMode) { func testCanonicalSynchronisation(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
@ -547,14 +545,12 @@ func testCanonicalSynchronisation(t *testing.T, protocol int, mode SyncMode) {
// Tests that if a large batch of blocks are being downloaded, it is throttled // Tests that if a large batch of blocks are being downloaded, it is throttled
// until the cached blocks are retrieved. // until the cached blocks are retrieved.
func TestThrottling63Full(t *testing.T) { testThrottling(t, 63, FullSync) }
func TestThrottling63Fast(t *testing.T) { testThrottling(t, 63, FastSync) }
func TestThrottling64Full(t *testing.T) { testThrottling(t, 64, FullSync) } func TestThrottling64Full(t *testing.T) { testThrottling(t, 64, FullSync) }
func TestThrottling64Fast(t *testing.T) { testThrottling(t, 64, FastSync) } func TestThrottling64Fast(t *testing.T) { testThrottling(t, 64, FastSync) }
func TestThrottling65Full(t *testing.T) { testThrottling(t, 65, FullSync) } func TestThrottling65Full(t *testing.T) { testThrottling(t, 65, FullSync) }
func TestThrottling65Fast(t *testing.T) { testThrottling(t, 65, FastSync) } func TestThrottling65Fast(t *testing.T) { testThrottling(t, 65, FastSync) }
func testThrottling(t *testing.T, protocol int, mode SyncMode) { func testThrottling(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
@ -632,15 +628,13 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
// Tests that simple synchronization against a forked chain works correctly. In // Tests that simple synchronization against a forked chain works correctly. In
// this test common ancestor lookup should *not* be short circuited, and a full // this test common ancestor lookup should *not* be short circuited, and a full
// binary search should be executed. // binary search should be executed.
func TestForkedSync63Full(t *testing.T) { testForkedSync(t, 63, FullSync) }
func TestForkedSync63Fast(t *testing.T) { testForkedSync(t, 63, FastSync) }
func TestForkedSync64Full(t *testing.T) { testForkedSync(t, 64, FullSync) } func TestForkedSync64Full(t *testing.T) { testForkedSync(t, 64, FullSync) }
func TestForkedSync64Fast(t *testing.T) { testForkedSync(t, 64, FastSync) } func TestForkedSync64Fast(t *testing.T) { testForkedSync(t, 64, FastSync) }
func TestForkedSync65Full(t *testing.T) { testForkedSync(t, 65, FullSync) } func TestForkedSync65Full(t *testing.T) { testForkedSync(t, 65, FullSync) }
func TestForkedSync65Fast(t *testing.T) { testForkedSync(t, 65, FastSync) } func TestForkedSync65Fast(t *testing.T) { testForkedSync(t, 65, FastSync) }
func TestForkedSync65Light(t *testing.T) { testForkedSync(t, 65, LightSync) } func TestForkedSync65Light(t *testing.T) { testForkedSync(t, 65, LightSync) }
func testForkedSync(t *testing.T, protocol int, mode SyncMode) { func testForkedSync(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
@ -665,15 +659,13 @@ func testForkedSync(t *testing.T, protocol int, mode SyncMode) {
// Tests that synchronising against a much shorter but much heavyer fork works // Tests that synchronising against a much shorter but much heavyer fork works
// corrently and is not dropped. // corrently and is not dropped.
func TestHeavyForkedSync63Full(t *testing.T) { testHeavyForkedSync(t, 63, FullSync) }
func TestHeavyForkedSync63Fast(t *testing.T) { testHeavyForkedSync(t, 63, FastSync) }
func TestHeavyForkedSync64Full(t *testing.T) { testHeavyForkedSync(t, 64, FullSync) } func TestHeavyForkedSync64Full(t *testing.T) { testHeavyForkedSync(t, 64, FullSync) }
func TestHeavyForkedSync64Fast(t *testing.T) { testHeavyForkedSync(t, 64, FastSync) } func TestHeavyForkedSync64Fast(t *testing.T) { testHeavyForkedSync(t, 64, FastSync) }
func TestHeavyForkedSync65Full(t *testing.T) { testHeavyForkedSync(t, 65, FullSync) } func TestHeavyForkedSync65Full(t *testing.T) { testHeavyForkedSync(t, 65, FullSync) }
func TestHeavyForkedSync65Fast(t *testing.T) { testHeavyForkedSync(t, 65, FastSync) } func TestHeavyForkedSync65Fast(t *testing.T) { testHeavyForkedSync(t, 65, FastSync) }
func TestHeavyForkedSync65Light(t *testing.T) { testHeavyForkedSync(t, 65, LightSync) } func TestHeavyForkedSync65Light(t *testing.T) { testHeavyForkedSync(t, 65, LightSync) }
func testHeavyForkedSync(t *testing.T, protocol int, mode SyncMode) { func testHeavyForkedSync(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
@ -700,15 +692,13 @@ func testHeavyForkedSync(t *testing.T, protocol int, mode SyncMode) {
// Tests that chain forks are contained within a certain interval of the current // Tests that chain forks are contained within a certain interval of the current
// chain head, ensuring that malicious peers cannot waste resources by feeding // chain head, ensuring that malicious peers cannot waste resources by feeding
// long dead chains. // long dead chains.
func TestBoundedForkedSync63Full(t *testing.T) { testBoundedForkedSync(t, 63, FullSync) }
func TestBoundedForkedSync63Fast(t *testing.T) { testBoundedForkedSync(t, 63, FastSync) }
func TestBoundedForkedSync64Full(t *testing.T) { testBoundedForkedSync(t, 64, FullSync) } func TestBoundedForkedSync64Full(t *testing.T) { testBoundedForkedSync(t, 64, FullSync) }
func TestBoundedForkedSync64Fast(t *testing.T) { testBoundedForkedSync(t, 64, FastSync) } func TestBoundedForkedSync64Fast(t *testing.T) { testBoundedForkedSync(t, 64, FastSync) }
func TestBoundedForkedSync65Full(t *testing.T) { testBoundedForkedSync(t, 65, FullSync) } func TestBoundedForkedSync65Full(t *testing.T) { testBoundedForkedSync(t, 65, FullSync) }
func TestBoundedForkedSync65Fast(t *testing.T) { testBoundedForkedSync(t, 65, FastSync) } func TestBoundedForkedSync65Fast(t *testing.T) { testBoundedForkedSync(t, 65, FastSync) }
func TestBoundedForkedSync65Light(t *testing.T) { testBoundedForkedSync(t, 65, LightSync) } func TestBoundedForkedSync65Light(t *testing.T) { testBoundedForkedSync(t, 65, LightSync) }
func testBoundedForkedSync(t *testing.T, protocol int, mode SyncMode) { func testBoundedForkedSync(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
@ -734,15 +724,13 @@ func testBoundedForkedSync(t *testing.T, protocol int, mode SyncMode) {
// Tests that chain forks are contained within a certain interval of the current // Tests that chain forks are contained within a certain interval of the current
// chain head for short but heavy forks too. These are a bit special because they // chain head for short but heavy forks too. These are a bit special because they
// take different ancestor lookup paths. // take different ancestor lookup paths.
func TestBoundedHeavyForkedSync63Full(t *testing.T) { testBoundedHeavyForkedSync(t, 63, FullSync) }
func TestBoundedHeavyForkedSync63Fast(t *testing.T) { testBoundedHeavyForkedSync(t, 63, FastSync) }
func TestBoundedHeavyForkedSync64Full(t *testing.T) { testBoundedHeavyForkedSync(t, 64, FullSync) } func TestBoundedHeavyForkedSync64Full(t *testing.T) { testBoundedHeavyForkedSync(t, 64, FullSync) }
func TestBoundedHeavyForkedSync64Fast(t *testing.T) { testBoundedHeavyForkedSync(t, 64, FastSync) } func TestBoundedHeavyForkedSync64Fast(t *testing.T) { testBoundedHeavyForkedSync(t, 64, FastSync) }
func TestBoundedHeavyForkedSync65Full(t *testing.T) { testBoundedHeavyForkedSync(t, 65, FullSync) } func TestBoundedHeavyForkedSync65Full(t *testing.T) { testBoundedHeavyForkedSync(t, 65, FullSync) }
func TestBoundedHeavyForkedSync65Fast(t *testing.T) { testBoundedHeavyForkedSync(t, 65, FastSync) } func TestBoundedHeavyForkedSync65Fast(t *testing.T) { testBoundedHeavyForkedSync(t, 65, FastSync) }
func TestBoundedHeavyForkedSync65Light(t *testing.T) { testBoundedHeavyForkedSync(t, 65, LightSync) } func TestBoundedHeavyForkedSync65Light(t *testing.T) { testBoundedHeavyForkedSync(t, 65, LightSync) }
func testBoundedHeavyForkedSync(t *testing.T, protocol int, mode SyncMode) { func testBoundedHeavyForkedSync(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
@ -786,15 +774,13 @@ func TestInactiveDownloader63(t *testing.T) {
} }
// Tests that a canceled download wipes all previously accumulated state. // Tests that a canceled download wipes all previously accumulated state.
func TestCancel63Full(t *testing.T) { testCancel(t, 63, FullSync) }
func TestCancel63Fast(t *testing.T) { testCancel(t, 63, FastSync) }
func TestCancel64Full(t *testing.T) { testCancel(t, 64, FullSync) } func TestCancel64Full(t *testing.T) { testCancel(t, 64, FullSync) }
func TestCancel64Fast(t *testing.T) { testCancel(t, 64, FastSync) } func TestCancel64Fast(t *testing.T) { testCancel(t, 64, FastSync) }
func TestCancel65Full(t *testing.T) { testCancel(t, 65, FullSync) } func TestCancel65Full(t *testing.T) { testCancel(t, 65, FullSync) }
func TestCancel65Fast(t *testing.T) { testCancel(t, 65, FastSync) } func TestCancel65Fast(t *testing.T) { testCancel(t, 65, FastSync) }
func TestCancel65Light(t *testing.T) { testCancel(t, 65, LightSync) } func TestCancel65Light(t *testing.T) { testCancel(t, 65, LightSync) }
func testCancel(t *testing.T, protocol int, mode SyncMode) { func testCancel(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
@ -819,15 +805,13 @@ func testCancel(t *testing.T, protocol int, mode SyncMode) {
} }
// Tests that synchronisation from multiple peers works as intended (multi thread sanity test). // Tests that synchronisation from multiple peers works as intended (multi thread sanity test).
func TestMultiSynchronisation63Full(t *testing.T) { testMultiSynchronisation(t, 63, FullSync) }
func TestMultiSynchronisation63Fast(t *testing.T) { testMultiSynchronisation(t, 63, FastSync) }
func TestMultiSynchronisation64Full(t *testing.T) { testMultiSynchronisation(t, 64, FullSync) } func TestMultiSynchronisation64Full(t *testing.T) { testMultiSynchronisation(t, 64, FullSync) }
func TestMultiSynchronisation64Fast(t *testing.T) { testMultiSynchronisation(t, 64, FastSync) } func TestMultiSynchronisation64Fast(t *testing.T) { testMultiSynchronisation(t, 64, FastSync) }
func TestMultiSynchronisation65Full(t *testing.T) { testMultiSynchronisation(t, 65, FullSync) } func TestMultiSynchronisation65Full(t *testing.T) { testMultiSynchronisation(t, 65, FullSync) }
func TestMultiSynchronisation65Fast(t *testing.T) { testMultiSynchronisation(t, 65, FastSync) } func TestMultiSynchronisation65Fast(t *testing.T) { testMultiSynchronisation(t, 65, FastSync) }
func TestMultiSynchronisation65Light(t *testing.T) { testMultiSynchronisation(t, 65, LightSync) } func TestMultiSynchronisation65Light(t *testing.T) { testMultiSynchronisation(t, 65, LightSync) }
func testMultiSynchronisation(t *testing.T, protocol int, mode SyncMode) { func testMultiSynchronisation(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
@ -849,15 +833,13 @@ func testMultiSynchronisation(t *testing.T, protocol int, mode SyncMode) {
// Tests that synchronisations behave well in multi-version protocol environments // Tests that synchronisations behave well in multi-version protocol environments
// and not wreak havoc on other nodes in the network. // and not wreak havoc on other nodes in the network.
func TestMultiProtoSynchronisation63Full(t *testing.T) { testMultiProtoSync(t, 63, FullSync) }
func TestMultiProtoSynchronisation63Fast(t *testing.T) { testMultiProtoSync(t, 63, FastSync) }
func TestMultiProtoSynchronisation64Full(t *testing.T) { testMultiProtoSync(t, 64, FullSync) } func TestMultiProtoSynchronisation64Full(t *testing.T) { testMultiProtoSync(t, 64, FullSync) }
func TestMultiProtoSynchronisation64Fast(t *testing.T) { testMultiProtoSync(t, 64, FastSync) } func TestMultiProtoSynchronisation64Fast(t *testing.T) { testMultiProtoSync(t, 64, FastSync) }
func TestMultiProtoSynchronisation65Full(t *testing.T) { testMultiProtoSync(t, 65, FullSync) } func TestMultiProtoSynchronisation65Full(t *testing.T) { testMultiProtoSync(t, 65, FullSync) }
func TestMultiProtoSynchronisation65Fast(t *testing.T) { testMultiProtoSync(t, 65, FastSync) } func TestMultiProtoSynchronisation65Fast(t *testing.T) { testMultiProtoSync(t, 65, FastSync) }
func TestMultiProtoSynchronisation65Light(t *testing.T) { testMultiProtoSync(t, 65, LightSync) } func TestMultiProtoSynchronisation65Light(t *testing.T) { testMultiProtoSync(t, 65, LightSync) }
func testMultiProtoSync(t *testing.T, protocol int, mode SyncMode) { func testMultiProtoSync(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
@ -888,15 +870,13 @@ func testMultiProtoSync(t *testing.T, protocol int, mode SyncMode) {
// Tests that if a block is empty (e.g. header only), no body request should be // Tests that if a block is empty (e.g. header only), no body request should be
// made, and instead the header should be assembled into a whole block in itself. // made, and instead the header should be assembled into a whole block in itself.
func TestEmptyShortCircuit63Full(t *testing.T) { testEmptyShortCircuit(t, 63, FullSync) }
func TestEmptyShortCircuit63Fast(t *testing.T) { testEmptyShortCircuit(t, 63, FastSync) }
func TestEmptyShortCircuit64Full(t *testing.T) { testEmptyShortCircuit(t, 64, FullSync) } func TestEmptyShortCircuit64Full(t *testing.T) { testEmptyShortCircuit(t, 64, FullSync) }
func TestEmptyShortCircuit64Fast(t *testing.T) { testEmptyShortCircuit(t, 64, FastSync) } func TestEmptyShortCircuit64Fast(t *testing.T) { testEmptyShortCircuit(t, 64, FastSync) }
func TestEmptyShortCircuit65Full(t *testing.T) { testEmptyShortCircuit(t, 65, FullSync) } func TestEmptyShortCircuit65Full(t *testing.T) { testEmptyShortCircuit(t, 65, FullSync) }
func TestEmptyShortCircuit65Fast(t *testing.T) { testEmptyShortCircuit(t, 65, FastSync) } func TestEmptyShortCircuit65Fast(t *testing.T) { testEmptyShortCircuit(t, 65, FastSync) }
func TestEmptyShortCircuit65Light(t *testing.T) { testEmptyShortCircuit(t, 65, LightSync) } func TestEmptyShortCircuit65Light(t *testing.T) { testEmptyShortCircuit(t, 65, LightSync) }
func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) { func testEmptyShortCircuit(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
@ -942,15 +922,13 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) {
// Tests that headers are enqueued continuously, preventing malicious nodes from // Tests that headers are enqueued continuously, preventing malicious nodes from
// stalling the downloader by feeding gapped header chains. // stalling the downloader by feeding gapped header chains.
func TestMissingHeaderAttack63Full(t *testing.T) { testMissingHeaderAttack(t, 63, FullSync) }
func TestMissingHeaderAttack63Fast(t *testing.T) { testMissingHeaderAttack(t, 63, FastSync) }
func TestMissingHeaderAttack64Full(t *testing.T) { testMissingHeaderAttack(t, 64, FullSync) } func TestMissingHeaderAttack64Full(t *testing.T) { testMissingHeaderAttack(t, 64, FullSync) }
func TestMissingHeaderAttack64Fast(t *testing.T) { testMissingHeaderAttack(t, 64, FastSync) } func TestMissingHeaderAttack64Fast(t *testing.T) { testMissingHeaderAttack(t, 64, FastSync) }
func TestMissingHeaderAttack65Full(t *testing.T) { testMissingHeaderAttack(t, 65, FullSync) } func TestMissingHeaderAttack65Full(t *testing.T) { testMissingHeaderAttack(t, 65, FullSync) }
func TestMissingHeaderAttack65Fast(t *testing.T) { testMissingHeaderAttack(t, 65, FastSync) } func TestMissingHeaderAttack65Fast(t *testing.T) { testMissingHeaderAttack(t, 65, FastSync) }
func TestMissingHeaderAttack65Light(t *testing.T) { testMissingHeaderAttack(t, 65, LightSync) } func TestMissingHeaderAttack65Light(t *testing.T) { testMissingHeaderAttack(t, 65, LightSync) }
func testMissingHeaderAttack(t *testing.T, protocol int, mode SyncMode) { func testMissingHeaderAttack(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
@ -974,15 +952,13 @@ func testMissingHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
// Tests that if requested headers are shifted (i.e. first is missing), the queue // Tests that if requested headers are shifted (i.e. first is missing), the queue
// detects the invalid numbering. // detects the invalid numbering.
func TestShiftedHeaderAttack63Full(t *testing.T) { testShiftedHeaderAttack(t, 63, FullSync) }
func TestShiftedHeaderAttack63Fast(t *testing.T) { testShiftedHeaderAttack(t, 63, FastSync) }
func TestShiftedHeaderAttack64Full(t *testing.T) { testShiftedHeaderAttack(t, 64, FullSync) } func TestShiftedHeaderAttack64Full(t *testing.T) { testShiftedHeaderAttack(t, 64, FullSync) }
func TestShiftedHeaderAttack64Fast(t *testing.T) { testShiftedHeaderAttack(t, 64, FastSync) } func TestShiftedHeaderAttack64Fast(t *testing.T) { testShiftedHeaderAttack(t, 64, FastSync) }
func TestShiftedHeaderAttack65Full(t *testing.T) { testShiftedHeaderAttack(t, 65, FullSync) } func TestShiftedHeaderAttack65Full(t *testing.T) { testShiftedHeaderAttack(t, 65, FullSync) }
func TestShiftedHeaderAttack65Fast(t *testing.T) { testShiftedHeaderAttack(t, 65, FastSync) } func TestShiftedHeaderAttack65Fast(t *testing.T) { testShiftedHeaderAttack(t, 65, FastSync) }
func TestShiftedHeaderAttack65Light(t *testing.T) { testShiftedHeaderAttack(t, 65, LightSync) } func TestShiftedHeaderAttack65Light(t *testing.T) { testShiftedHeaderAttack(t, 65, LightSync) }
func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) { func testShiftedHeaderAttack(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
@ -1011,11 +987,10 @@ func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
// Tests that upon detecting an invalid header, the recent ones are rolled back // Tests that upon detecting an invalid header, the recent ones are rolled back
// for various failure scenarios. Afterwards a full sync is attempted to make // for various failure scenarios. Afterwards a full sync is attempted to make
// sure no state was corrupted. // sure no state was corrupted.
func TestInvalidHeaderRollback63Fast(t *testing.T) { testInvalidHeaderRollback(t, 63, FastSync) }
func TestInvalidHeaderRollback64Fast(t *testing.T) { testInvalidHeaderRollback(t, 64, FastSync) } func TestInvalidHeaderRollback64Fast(t *testing.T) { testInvalidHeaderRollback(t, 64, FastSync) }
func TestInvalidHeaderRollback65Fast(t *testing.T) { testInvalidHeaderRollback(t, 65, FastSync) } func TestInvalidHeaderRollback65Fast(t *testing.T) { testInvalidHeaderRollback(t, 65, FastSync) }
func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { func testInvalidHeaderRollback(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
@ -1103,15 +1078,13 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
// Tests that a peer advertising a high TD doesn't get to stall the downloader // Tests that a peer advertising a high TD doesn't get to stall the downloader
// afterwards by not sending any useful hashes. // afterwards by not sending any useful hashes.
func TestHighTDStarvationAttack63Full(t *testing.T) { testHighTDStarvationAttack(t, 63, FullSync) }
func TestHighTDStarvationAttack63Fast(t *testing.T) { testHighTDStarvationAttack(t, 63, FastSync) }
func TestHighTDStarvationAttack64Full(t *testing.T) { testHighTDStarvationAttack(t, 64, FullSync) } func TestHighTDStarvationAttack64Full(t *testing.T) { testHighTDStarvationAttack(t, 64, FullSync) }
func TestHighTDStarvationAttack64Fast(t *testing.T) { testHighTDStarvationAttack(t, 64, FastSync) } func TestHighTDStarvationAttack64Fast(t *testing.T) { testHighTDStarvationAttack(t, 64, FastSync) }
func TestHighTDStarvationAttack65Full(t *testing.T) { testHighTDStarvationAttack(t, 65, FullSync) } func TestHighTDStarvationAttack65Full(t *testing.T) { testHighTDStarvationAttack(t, 65, FullSync) }
func TestHighTDStarvationAttack65Fast(t *testing.T) { testHighTDStarvationAttack(t, 65, FastSync) } func TestHighTDStarvationAttack65Fast(t *testing.T) { testHighTDStarvationAttack(t, 65, FastSync) }
func TestHighTDStarvationAttack65Light(t *testing.T) { testHighTDStarvationAttack(t, 65, LightSync) } func TestHighTDStarvationAttack65Light(t *testing.T) { testHighTDStarvationAttack(t, 65, LightSync) }
func testHighTDStarvationAttack(t *testing.T, protocol int, mode SyncMode) { func testHighTDStarvationAttack(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
@ -1125,11 +1098,10 @@ func testHighTDStarvationAttack(t *testing.T, protocol int, mode SyncMode) {
} }
// Tests that misbehaving peers are disconnected, whilst behaving ones are not. // Tests that misbehaving peers are disconnected, whilst behaving ones are not.
func TestBlockHeaderAttackerDropping63(t *testing.T) { testBlockHeaderAttackerDropping(t, 63) }
func TestBlockHeaderAttackerDropping64(t *testing.T) { testBlockHeaderAttackerDropping(t, 64) } func TestBlockHeaderAttackerDropping64(t *testing.T) { testBlockHeaderAttackerDropping(t, 64) }
func TestBlockHeaderAttackerDropping65(t *testing.T) { testBlockHeaderAttackerDropping(t, 65) } func TestBlockHeaderAttackerDropping65(t *testing.T) { testBlockHeaderAttackerDropping(t, 65) }
func testBlockHeaderAttackerDropping(t *testing.T, protocol int) { func testBlockHeaderAttackerDropping(t *testing.T, protocol uint) {
t.Parallel() t.Parallel()
// Define the disconnection requirement for individual hash fetch errors // Define the disconnection requirement for individual hash fetch errors
@ -1179,15 +1151,13 @@ func testBlockHeaderAttackerDropping(t *testing.T, protocol int) {
// Tests that synchronisation progress (origin block number, current block number // Tests that synchronisation progress (origin block number, current block number
// and highest block number) is tracked and updated correctly. // and highest block number) is tracked and updated correctly.
func TestSyncProgress63Full(t *testing.T) { testSyncProgress(t, 63, FullSync) }
func TestSyncProgress63Fast(t *testing.T) { testSyncProgress(t, 63, FastSync) }
func TestSyncProgress64Full(t *testing.T) { testSyncProgress(t, 64, FullSync) } func TestSyncProgress64Full(t *testing.T) { testSyncProgress(t, 64, FullSync) }
func TestSyncProgress64Fast(t *testing.T) { testSyncProgress(t, 64, FastSync) } func TestSyncProgress64Fast(t *testing.T) { testSyncProgress(t, 64, FastSync) }
func TestSyncProgress65Full(t *testing.T) { testSyncProgress(t, 65, FullSync) } func TestSyncProgress65Full(t *testing.T) { testSyncProgress(t, 65, FullSync) }
func TestSyncProgress65Fast(t *testing.T) { testSyncProgress(t, 65, FastSync) } func TestSyncProgress65Fast(t *testing.T) { testSyncProgress(t, 65, FastSync) }
func TestSyncProgress65Light(t *testing.T) { testSyncProgress(t, 65, LightSync) } func TestSyncProgress65Light(t *testing.T) { testSyncProgress(t, 65, LightSync) }
func testSyncProgress(t *testing.T, protocol int, mode SyncMode) { func testSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
@ -1263,21 +1233,19 @@ func checkProgress(t *testing.T, d *Downloader, stage string, want ethereum.Sync
// Tests that synchronisation progress (origin block number and highest block // Tests that synchronisation progress (origin block number and highest block
// number) is tracked and updated correctly in case of a fork (or manual head // number) is tracked and updated correctly in case of a fork (or manual head
// revertal). // revertal).
func TestForkedSyncProgress63Full(t *testing.T) { testForkedSyncProgress(t, 63, FullSync) }
func TestForkedSyncProgress63Fast(t *testing.T) { testForkedSyncProgress(t, 63, FastSync) }
func TestForkedSyncProgress64Full(t *testing.T) { testForkedSyncProgress(t, 64, FullSync) } func TestForkedSyncProgress64Full(t *testing.T) { testForkedSyncProgress(t, 64, FullSync) }
func TestForkedSyncProgress64Fast(t *testing.T) { testForkedSyncProgress(t, 64, FastSync) } func TestForkedSyncProgress64Fast(t *testing.T) { testForkedSyncProgress(t, 64, FastSync) }
func TestForkedSyncProgress65Full(t *testing.T) { testForkedSyncProgress(t, 65, FullSync) } func TestForkedSyncProgress65Full(t *testing.T) { testForkedSyncProgress(t, 65, FullSync) }
func TestForkedSyncProgress65Fast(t *testing.T) { testForkedSyncProgress(t, 65, FastSync) } func TestForkedSyncProgress65Fast(t *testing.T) { testForkedSyncProgress(t, 65, FastSync) }
func TestForkedSyncProgress65Light(t *testing.T) { testForkedSyncProgress(t, 65, LightSync) } func TestForkedSyncProgress65Light(t *testing.T) { testForkedSyncProgress(t, 65, LightSync) }
func testForkedSyncProgress(t *testing.T, protocol int, mode SyncMode) { func testForkedSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
defer tester.terminate() defer tester.terminate()
chainA := testChainForkLightA.shorten(testChainBase.len() + MaxHashFetch) chainA := testChainForkLightA.shorten(testChainBase.len() + MaxHeaderFetch)
chainB := testChainForkLightB.shorten(testChainBase.len() + MaxHashFetch) chainB := testChainForkLightB.shorten(testChainBase.len() + MaxHeaderFetch)
// Set a sync init hook to catch progress changes // Set a sync init hook to catch progress changes
starting := make(chan struct{}) starting := make(chan struct{})
@ -1339,15 +1307,13 @@ func testForkedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
// Tests that if synchronisation is aborted due to some failure, then the progress // Tests that if synchronisation is aborted due to some failure, then the progress
// origin is not updated in the next sync cycle, as it should be considered the // origin is not updated in the next sync cycle, as it should be considered the
// continuation of the previous sync and not a new instance. // continuation of the previous sync and not a new instance.
func TestFailedSyncProgress63Full(t *testing.T) { testFailedSyncProgress(t, 63, FullSync) }
func TestFailedSyncProgress63Fast(t *testing.T) { testFailedSyncProgress(t, 63, FastSync) }
func TestFailedSyncProgress64Full(t *testing.T) { testFailedSyncProgress(t, 64, FullSync) } func TestFailedSyncProgress64Full(t *testing.T) { testFailedSyncProgress(t, 64, FullSync) }
func TestFailedSyncProgress64Fast(t *testing.T) { testFailedSyncProgress(t, 64, FastSync) } func TestFailedSyncProgress64Fast(t *testing.T) { testFailedSyncProgress(t, 64, FastSync) }
func TestFailedSyncProgress65Full(t *testing.T) { testFailedSyncProgress(t, 65, FullSync) } func TestFailedSyncProgress65Full(t *testing.T) { testFailedSyncProgress(t, 65, FullSync) }
func TestFailedSyncProgress65Fast(t *testing.T) { testFailedSyncProgress(t, 65, FastSync) } func TestFailedSyncProgress65Fast(t *testing.T) { testFailedSyncProgress(t, 65, FastSync) }
func TestFailedSyncProgress65Light(t *testing.T) { testFailedSyncProgress(t, 65, LightSync) } func TestFailedSyncProgress65Light(t *testing.T) { testFailedSyncProgress(t, 65, LightSync) }
func testFailedSyncProgress(t *testing.T, protocol int, mode SyncMode) { func testFailedSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
@ -1412,15 +1378,13 @@ func testFailedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
// Tests that if an attacker fakes a chain height, after the attack is detected, // Tests that if an attacker fakes a chain height, after the attack is detected,
// the progress height is successfully reduced at the next sync invocation. // the progress height is successfully reduced at the next sync invocation.
func TestFakedSyncProgress63Full(t *testing.T) { testFakedSyncProgress(t, 63, FullSync) }
func TestFakedSyncProgress63Fast(t *testing.T) { testFakedSyncProgress(t, 63, FastSync) }
func TestFakedSyncProgress64Full(t *testing.T) { testFakedSyncProgress(t, 64, FullSync) } func TestFakedSyncProgress64Full(t *testing.T) { testFakedSyncProgress(t, 64, FullSync) }
func TestFakedSyncProgress64Fast(t *testing.T) { testFakedSyncProgress(t, 64, FastSync) } func TestFakedSyncProgress64Fast(t *testing.T) { testFakedSyncProgress(t, 64, FastSync) }
func TestFakedSyncProgress65Full(t *testing.T) { testFakedSyncProgress(t, 65, FullSync) } func TestFakedSyncProgress65Full(t *testing.T) { testFakedSyncProgress(t, 65, FullSync) }
func TestFakedSyncProgress65Fast(t *testing.T) { testFakedSyncProgress(t, 65, FastSync) } func TestFakedSyncProgress65Fast(t *testing.T) { testFakedSyncProgress(t, 65, FastSync) }
func TestFakedSyncProgress65Light(t *testing.T) { testFakedSyncProgress(t, 65, LightSync) } func TestFakedSyncProgress65Light(t *testing.T) { testFakedSyncProgress(t, 65, LightSync) }
func testFakedSyncProgress(t *testing.T, protocol int, mode SyncMode) { func testFakedSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
tester := newTester() tester := newTester()
@ -1489,31 +1453,15 @@ func testFakedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
// This test reproduces an issue where unexpected deliveries would // This test reproduces an issue where unexpected deliveries would
// block indefinitely if they arrived at the right time. // block indefinitely if they arrived at the right time.
func TestDeliverHeadersHang(t *testing.T) { func TestDeliverHeadersHang64Full(t *testing.T) { testDeliverHeadersHang(t, 64, FullSync) }
func TestDeliverHeadersHang64Fast(t *testing.T) { testDeliverHeadersHang(t, 64, FastSync) }
func TestDeliverHeadersHang65Full(t *testing.T) { testDeliverHeadersHang(t, 65, FullSync) }
func TestDeliverHeadersHang65Fast(t *testing.T) { testDeliverHeadersHang(t, 65, FastSync) }
func TestDeliverHeadersHang65Light(t *testing.T) { testDeliverHeadersHang(t, 65, LightSync) }
func testDeliverHeadersHang(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
testCases := []struct {
protocol int
syncMode SyncMode
}{
{63, FullSync},
{63, FastSync},
{64, FullSync},
{64, FastSync},
{64, LightSync},
{65, FullSync},
{65, FastSync},
{65, LightSync},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("protocol %d mode %v", tc.protocol, tc.syncMode), func(t *testing.T) {
t.Parallel()
testDeliverHeadersHang(t, tc.protocol, tc.syncMode)
})
}
}
func testDeliverHeadersHang(t *testing.T, protocol int, mode SyncMode) {
master := newTester() master := newTester()
defer master.terminate() defer master.terminate()
chain := testChainBase.shorten(15) chain := testChainBase.shorten(15)
@ -1664,15 +1612,13 @@ func TestRemoteHeaderRequestSpan(t *testing.T) {
// Tests that peers below a pre-configured checkpoint block are prevented from // Tests that peers below a pre-configured checkpoint block are prevented from
// being fast-synced from, avoiding potential cheap eclipse attacks. // being fast-synced from, avoiding potential cheap eclipse attacks.
func TestCheckpointEnforcement63Full(t *testing.T) { testCheckpointEnforcement(t, 63, FullSync) }
func TestCheckpointEnforcement63Fast(t *testing.T) { testCheckpointEnforcement(t, 63, FastSync) }
func TestCheckpointEnforcement64Full(t *testing.T) { testCheckpointEnforcement(t, 64, FullSync) } func TestCheckpointEnforcement64Full(t *testing.T) { testCheckpointEnforcement(t, 64, FullSync) }
func TestCheckpointEnforcement64Fast(t *testing.T) { testCheckpointEnforcement(t, 64, FastSync) } func TestCheckpointEnforcement64Fast(t *testing.T) { testCheckpointEnforcement(t, 64, FastSync) }
func TestCheckpointEnforcement65Full(t *testing.T) { testCheckpointEnforcement(t, 65, FullSync) } func TestCheckpointEnforcement65Full(t *testing.T) { testCheckpointEnforcement(t, 65, FullSync) }
func TestCheckpointEnforcement65Fast(t *testing.T) { testCheckpointEnforcement(t, 65, FastSync) } func TestCheckpointEnforcement65Fast(t *testing.T) { testCheckpointEnforcement(t, 65, FastSync) }
func TestCheckpointEnforcement65Light(t *testing.T) { testCheckpointEnforcement(t, 65, LightSync) } func TestCheckpointEnforcement65Light(t *testing.T) { testCheckpointEnforcement(t, 65, LightSync) }
func testCheckpointEnforcement(t *testing.T, protocol int, mode SyncMode) { func testCheckpointEnforcement(t *testing.T, protocol uint, mode SyncMode) {
t.Parallel() t.Parallel()
// Create a new tester with a particular hard coded checkpoint block // Create a new tester with a particular hard coded checkpoint block

View File

@ -24,7 +24,8 @@ type SyncMode uint32
const ( const (
FullSync SyncMode = iota // Synchronise the entire blockchain history from full blocks FullSync SyncMode = iota // Synchronise the entire blockchain history from full blocks
FastSync // Quickly download the headers, full sync only at the chain head FastSync // Quickly download the headers, full sync only at the chain
SnapSync // Download the chain and the state via compact snashots
LightSync // Download only the headers and terminate afterwards LightSync // Download only the headers and terminate afterwards
) )
@ -39,6 +40,8 @@ func (mode SyncMode) String() string {
return "full" return "full"
case FastSync: case FastSync:
return "fast" return "fast"
case SnapSync:
return "snap"
case LightSync: case LightSync:
return "light" return "light"
default: default:
@ -52,6 +55,8 @@ func (mode SyncMode) MarshalText() ([]byte, error) {
return []byte("full"), nil return []byte("full"), nil
case FastSync: case FastSync:
return []byte("fast"), nil return []byte("fast"), nil
case SnapSync:
return []byte("snap"), nil
case LightSync: case LightSync:
return []byte("light"), nil return []byte("light"), nil
default: default:
@ -65,6 +70,8 @@ func (mode *SyncMode) UnmarshalText(text []byte) error {
*mode = FullSync *mode = FullSync
case "fast": case "fast":
*mode = FastSync *mode = FastSync
case "snap":
*mode = SnapSync
case "light": case "light":
*mode = LightSync *mode = LightSync
default: default:

View File

@ -69,7 +69,7 @@ type peerConnection struct {
peer Peer peer Peer
version int // Eth protocol version number to switch strategies version uint // Eth protocol version number to switch strategies
log log.Logger // Contextual logger to add extra infos to peer logs log log.Logger // Contextual logger to add extra infos to peer logs
lock sync.RWMutex lock sync.RWMutex
} }
@ -112,7 +112,7 @@ func (w *lightPeerWrapper) RequestNodeData([]common.Hash) error {
} }
// newPeerConnection creates a new downloader peer. // newPeerConnection creates a new downloader peer.
func newPeerConnection(id string, version int, peer Peer, logger log.Logger) *peerConnection { func newPeerConnection(id string, version uint, peer Peer, logger log.Logger) *peerConnection {
return &peerConnection{ return &peerConnection{
id: id, id: id,
lacking: make(map[common.Hash]struct{}), lacking: make(map[common.Hash]struct{}),
@ -457,7 +457,7 @@ func (ps *peerSet) HeaderIdlePeers() ([]*peerConnection, int) {
defer p.lock.RUnlock() defer p.lock.RUnlock()
return p.headerThroughput return p.headerThroughput
} }
return ps.idlePeers(63, 65, idle, throughput) return ps.idlePeers(64, 65, idle, throughput)
} }
// BodyIdlePeers retrieves a flat list of all the currently body-idle peers within // BodyIdlePeers retrieves a flat list of all the currently body-idle peers within
@ -471,7 +471,7 @@ func (ps *peerSet) BodyIdlePeers() ([]*peerConnection, int) {
defer p.lock.RUnlock() defer p.lock.RUnlock()
return p.blockThroughput return p.blockThroughput
} }
return ps.idlePeers(63, 65, idle, throughput) return ps.idlePeers(64, 65, idle, throughput)
} }
// ReceiptIdlePeers retrieves a flat list of all the currently receipt-idle peers // ReceiptIdlePeers retrieves a flat list of all the currently receipt-idle peers
@ -485,7 +485,7 @@ func (ps *peerSet) ReceiptIdlePeers() ([]*peerConnection, int) {
defer p.lock.RUnlock() defer p.lock.RUnlock()
return p.receiptThroughput return p.receiptThroughput
} }
return ps.idlePeers(63, 65, idle, throughput) return ps.idlePeers(64, 65, idle, throughput)
} }
// NodeDataIdlePeers retrieves a flat list of all the currently node-data-idle // NodeDataIdlePeers retrieves a flat list of all the currently node-data-idle
@ -499,13 +499,13 @@ func (ps *peerSet) NodeDataIdlePeers() ([]*peerConnection, int) {
defer p.lock.RUnlock() defer p.lock.RUnlock()
return p.stateThroughput return p.stateThroughput
} }
return ps.idlePeers(63, 65, idle, throughput) return ps.idlePeers(64, 65, idle, throughput)
} }
// idlePeers retrieves a flat list of all currently idle peers satisfying the // idlePeers retrieves a flat list of all currently idle peers satisfying the
// protocol version constraints, using the provided function to check idleness. // protocol version constraints, using the provided function to check idleness.
// The resulting set of peers are sorted by their measure throughput. // The resulting set of peers are sorted by their measure throughput.
func (ps *peerSet) idlePeers(minProtocol, maxProtocol int, idleCheck func(*peerConnection) bool, throughput func(*peerConnection) float64) ([]*peerConnection, int) { func (ps *peerSet) idlePeers(minProtocol, maxProtocol uint, idleCheck func(*peerConnection) bool, throughput func(*peerConnection) float64) ([]*peerConnection, int) {
ps.lock.RLock() ps.lock.RLock()
defer ps.lock.RUnlock() defer ps.lock.RUnlock()

View File

@ -113,24 +113,24 @@ type queue struct {
mode SyncMode // Synchronisation mode to decide on the block parts to schedule for fetching mode SyncMode // Synchronisation mode to decide on the block parts to schedule for fetching
// Headers are "special", they download in batches, supported by a skeleton chain // Headers are "special", they download in batches, supported by a skeleton chain
headerHead common.Hash // [eth/62] Hash of the last queued header to verify order headerHead common.Hash // Hash of the last queued header to verify order
headerTaskPool map[uint64]*types.Header // [eth/62] Pending header retrieval tasks, mapping starting indexes to skeleton headers headerTaskPool map[uint64]*types.Header // Pending header retrieval tasks, mapping starting indexes to skeleton headers
headerTaskQueue *prque.Prque // [eth/62] Priority queue of the skeleton indexes to fetch the filling headers for headerTaskQueue *prque.Prque // Priority queue of the skeleton indexes to fetch the filling headers for
headerPeerMiss map[string]map[uint64]struct{} // [eth/62] Set of per-peer header batches known to be unavailable headerPeerMiss map[string]map[uint64]struct{} // Set of per-peer header batches known to be unavailable
headerPendPool map[string]*fetchRequest // [eth/62] Currently pending header retrieval operations headerPendPool map[string]*fetchRequest // Currently pending header retrieval operations
headerResults []*types.Header // [eth/62] Result cache accumulating the completed headers headerResults []*types.Header // Result cache accumulating the completed headers
headerProced int // [eth/62] Number of headers already processed from the results headerProced int // Number of headers already processed from the results
headerOffset uint64 // [eth/62] Number of the first header in the result cache headerOffset uint64 // Number of the first header in the result cache
headerContCh chan bool // [eth/62] Channel to notify when header download finishes headerContCh chan bool // Channel to notify when header download finishes
// All data retrievals below are based on an already assembles header chain // All data retrievals below are based on an already assembles header chain
blockTaskPool map[common.Hash]*types.Header // [eth/62] Pending block (body) retrieval tasks, mapping hashes to headers blockTaskPool map[common.Hash]*types.Header // Pending block (body) retrieval tasks, mapping hashes to headers
blockTaskQueue *prque.Prque // [eth/62] Priority queue of the headers to fetch the blocks (bodies) for blockTaskQueue *prque.Prque // Priority queue of the headers to fetch the blocks (bodies) for
blockPendPool map[string]*fetchRequest // [eth/62] Currently pending block (body) retrieval operations blockPendPool map[string]*fetchRequest // Currently pending block (body) retrieval operations
receiptTaskPool map[common.Hash]*types.Header // [eth/63] Pending receipt retrieval tasks, mapping hashes to headers receiptTaskPool map[common.Hash]*types.Header // Pending receipt retrieval tasks, mapping hashes to headers
receiptTaskQueue *prque.Prque // [eth/63] Priority queue of the headers to fetch the receipts for receiptTaskQueue *prque.Prque // Priority queue of the headers to fetch the receipts for
receiptPendPool map[string]*fetchRequest // [eth/63] Currently pending receipt retrieval operations receiptPendPool map[string]*fetchRequest // Currently pending receipt retrieval operations
resultCache *resultStore // Downloaded but not yet delivered fetch results resultCache *resultStore // Downloaded but not yet delivered fetch results
resultSize common.StorageSize // Approximate size of a block (exponential moving average) resultSize common.StorageSize // Approximate size of a block (exponential moving average)
@ -690,6 +690,13 @@ func (q *queue) DeliverHeaders(id string, headers []*types.Header, headerProcCh
q.lock.Lock() q.lock.Lock()
defer q.lock.Unlock() defer q.lock.Unlock()
var logger log.Logger
if len(id) < 16 {
// Tests use short IDs, don't choke on them
logger = log.New("peer", id)
} else {
logger = log.New("peer", id[:16])
}
// Short circuit if the data was never requested // Short circuit if the data was never requested
request := q.headerPendPool[id] request := q.headerPendPool[id]
if request == nil { if request == nil {
@ -704,10 +711,10 @@ func (q *queue) DeliverHeaders(id string, headers []*types.Header, headerProcCh
accepted := len(headers) == MaxHeaderFetch accepted := len(headers) == MaxHeaderFetch
if accepted { if accepted {
if headers[0].Number.Uint64() != request.From { if headers[0].Number.Uint64() != request.From {
log.Trace("First header broke chain ordering", "peer", id, "number", headers[0].Number, "hash", headers[0].Hash(), request.From) logger.Trace("First header broke chain ordering", "number", headers[0].Number, "hash", headers[0].Hash(), "expected", request.From)
accepted = false accepted = false
} else if headers[len(headers)-1].Hash() != target { } else if headers[len(headers)-1].Hash() != target {
log.Trace("Last header broke skeleton structure ", "peer", id, "number", headers[len(headers)-1].Number, "hash", headers[len(headers)-1].Hash(), "expected", target) logger.Trace("Last header broke skeleton structure ", "number", headers[len(headers)-1].Number, "hash", headers[len(headers)-1].Hash(), "expected", target)
accepted = false accepted = false
} }
} }
@ -716,12 +723,12 @@ func (q *queue) DeliverHeaders(id string, headers []*types.Header, headerProcCh
for i, header := range headers[1:] { for i, header := range headers[1:] {
hash := header.Hash() hash := header.Hash()
if want := request.From + 1 + uint64(i); header.Number.Uint64() != want { if want := request.From + 1 + uint64(i); header.Number.Uint64() != want {
log.Warn("Header broke chain ordering", "peer", id, "number", header.Number, "hash", hash, "expected", want) logger.Warn("Header broke chain ordering", "number", header.Number, "hash", hash, "expected", want)
accepted = false accepted = false
break break
} }
if parentHash != header.ParentHash { if parentHash != header.ParentHash {
log.Warn("Header broke chain ancestry", "peer", id, "number", header.Number, "hash", hash) logger.Warn("Header broke chain ancestry", "number", header.Number, "hash", hash)
accepted = false accepted = false
break break
} }
@ -731,7 +738,7 @@ func (q *queue) DeliverHeaders(id string, headers []*types.Header, headerProcCh
} }
// If the batch of headers wasn't accepted, mark as unavailable // If the batch of headers wasn't accepted, mark as unavailable
if !accepted { if !accepted {
log.Trace("Skeleton filling not accepted", "peer", id, "from", request.From) logger.Trace("Skeleton filling not accepted", "from", request.From)
miss := q.headerPeerMiss[id] miss := q.headerPeerMiss[id]
if miss == nil { if miss == nil {
@ -758,7 +765,7 @@ func (q *queue) DeliverHeaders(id string, headers []*types.Header, headerProcCh
select { select {
case headerProcCh <- process: case headerProcCh <- process:
log.Trace("Pre-scheduled new headers", "peer", id, "count", len(process), "from", process[0].Number) logger.Trace("Pre-scheduled new headers", "count", len(process), "from", process[0].Number)
q.headerProced += len(process) q.headerProced += len(process)
default: default:
} }

View File

@ -101,8 +101,16 @@ func (d *Downloader) runStateSync(s *stateSync) *stateSync {
finished []*stateReq // Completed or failed requests finished []*stateReq // Completed or failed requests
timeout = make(chan *stateReq) // Timed out active requests timeout = make(chan *stateReq) // Timed out active requests
) )
// Run the state sync.
log.Trace("State sync starting", "root", s.root) log.Trace("State sync starting", "root", s.root)
defer func() {
// Cancel active request timers on exit. Also set peers to idle so they're
// available for the next sync.
for _, req := range active {
req.timer.Stop()
req.peer.SetNodeDataIdle(int(req.nItems), time.Now())
}
}()
go s.run() go s.run()
defer s.Cancel() defer s.Cancel()
@ -252,6 +260,7 @@ func (d *Downloader) spindownStateSync(active map[string]*stateReq, finished []*
type stateSync struct { type stateSync struct {
d *Downloader // Downloader instance to access and manage current peerset d *Downloader // Downloader instance to access and manage current peerset
root common.Hash // State root currently being synced
sched *trie.Sync // State trie sync scheduler defining the tasks sched *trie.Sync // State trie sync scheduler defining the tasks
keccak hash.Hash // Keccak256 hasher to verify deliveries with keccak hash.Hash // Keccak256 hasher to verify deliveries with
@ -268,8 +277,6 @@ type stateSync struct {
cancelOnce sync.Once // Ensures cancel only ever gets called once cancelOnce sync.Once // Ensures cancel only ever gets called once
done chan struct{} // Channel to signal termination completion done chan struct{} // Channel to signal termination completion
err error // Any error hit during sync (set before completion) err error // Any error hit during sync (set before completion)
root common.Hash
} }
// trieTask represents a single trie node download task, containing a set of // trieTask represents a single trie node download task, containing a set of
@ -290,6 +297,7 @@ type codeTask struct {
func newStateSync(d *Downloader, root common.Hash) *stateSync { func newStateSync(d *Downloader, root common.Hash) *stateSync {
return &stateSync{ return &stateSync{
d: d, d: d,
root: root,
sched: state.NewStateSync(root, d.stateDB, d.stateBloom), sched: state.NewStateSync(root, d.stateDB, d.stateBloom),
keccak: sha3.NewLegacyKeccak256(), keccak: sha3.NewLegacyKeccak256(),
trieTasks: make(map[common.Hash]*trieTask), trieTasks: make(map[common.Hash]*trieTask),
@ -298,7 +306,6 @@ func newStateSync(d *Downloader, root common.Hash) *stateSync {
cancel: make(chan struct{}), cancel: make(chan struct{}),
done: make(chan struct{}), done: make(chan struct{}),
started: make(chan struct{}), started: make(chan struct{}),
root: root,
} }
} }
@ -306,7 +313,12 @@ func newStateSync(d *Downloader, root common.Hash) *stateSync {
// it finishes, and finally notifying any goroutines waiting for the loop to // it finishes, and finally notifying any goroutines waiting for the loop to
// finish. // finish.
func (s *stateSync) run() { func (s *stateSync) run() {
close(s.started)
if s.d.snapSync {
s.err = s.d.SnapSyncer.Sync(s.root, s.cancel)
} else {
s.err = s.loop() s.err = s.loop()
}
close(s.done) close(s.done)
} }
@ -318,7 +330,9 @@ func (s *stateSync) Wait() error {
// Cancel cancels the sync and waits until it has shut down. // Cancel cancels the sync and waits until it has shut down.
func (s *stateSync) Cancel() error { func (s *stateSync) Cancel() error {
s.cancelOnce.Do(func() { close(s.cancel) }) s.cancelOnce.Do(func() {
close(s.cancel)
})
return s.Wait() return s.Wait()
} }
@ -329,7 +343,6 @@ func (s *stateSync) Cancel() error {
// pushed here async. The reason is to decouple processing from data receipt // pushed here async. The reason is to decouple processing from data receipt
// and timeouts. // and timeouts.
func (s *stateSync) loop() (err error) { func (s *stateSync) loop() (err error) {
close(s.started)
// Listen for new peer events to assign tasks to them // Listen for new peer events to assign tasks to them
newPeer := make(chan *peerConnection, 1024) newPeer := make(chan *peerConnection, 1024)
peerSub := s.d.peers.SubscribeNewPeers(newPeer) peerSub := s.d.peers.SubscribeNewPeers(newPeer)

View File

@ -20,7 +20,7 @@ func (c Config) MarshalTOML() (interface{}, error) {
Genesis *core.Genesis `toml:",omitempty"` Genesis *core.Genesis `toml:",omitempty"`
NetworkId uint64 NetworkId uint64
SyncMode downloader.SyncMode SyncMode downloader.SyncMode
DiscoveryURLs []string EthDiscoveryURLs []string
NoPruning bool NoPruning bool
NoPrefetch bool NoPrefetch bool
TxLookupLimit uint64 `toml:",omitempty"` TxLookupLimit uint64 `toml:",omitempty"`
@ -61,7 +61,7 @@ func (c Config) MarshalTOML() (interface{}, error) {
enc.Genesis = c.Genesis enc.Genesis = c.Genesis
enc.NetworkId = c.NetworkId enc.NetworkId = c.NetworkId
enc.SyncMode = c.SyncMode enc.SyncMode = c.SyncMode
enc.DiscoveryURLs = c.DiscoveryURLs enc.EthDiscoveryURLs = c.EthDiscoveryURLs
enc.NoPruning = c.NoPruning enc.NoPruning = c.NoPruning
enc.NoPrefetch = c.NoPrefetch enc.NoPrefetch = c.NoPrefetch
enc.TxLookupLimit = c.TxLookupLimit enc.TxLookupLimit = c.TxLookupLimit
@ -106,7 +106,7 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error {
Genesis *core.Genesis `toml:",omitempty"` Genesis *core.Genesis `toml:",omitempty"`
NetworkId *uint64 NetworkId *uint64
SyncMode *downloader.SyncMode SyncMode *downloader.SyncMode
DiscoveryURLs []string EthDiscoveryURLs []string
NoPruning *bool NoPruning *bool
NoPrefetch *bool NoPrefetch *bool
TxLookupLimit *uint64 `toml:",omitempty"` TxLookupLimit *uint64 `toml:",omitempty"`
@ -156,8 +156,8 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error {
if dec.SyncMode != nil { if dec.SyncMode != nil {
c.SyncMode = *dec.SyncMode c.SyncMode = *dec.SyncMode
} }
if dec.DiscoveryURLs != nil { if dec.EthDiscoveryURLs != nil {
c.DiscoveryURLs = dec.DiscoveryURLs c.EthDiscoveryURLs = dec.EthDiscoveryURLs
} }
if dec.NoPruning != nil { if dec.NoPruning != nil {
c.NoPruning = *dec.NoPruning c.NoPruning = *dec.NoPruning

File diff suppressed because it is too large Load Diff

218
eth/handler_eth.go Normal file
View File

@ -0,0 +1,218 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package eth
import (
"errors"
"fmt"
"math/big"
"sync/atomic"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/trie"
)
// ethHandler implements the eth.Backend interface to handle the various network
// packets that are sent as replies or broadcasts.
type ethHandler handler
func (h *ethHandler) Chain() *core.BlockChain { return h.chain }
func (h *ethHandler) StateBloom() *trie.SyncBloom { return h.stateBloom }
func (h *ethHandler) TxPool() eth.TxPool { return h.txpool }
// RunPeer is invoked when a peer joins on the `eth` protocol.
func (h *ethHandler) RunPeer(peer *eth.Peer, hand eth.Handler) error {
return (*handler)(h).runEthPeer(peer, hand)
}
// PeerInfo retrieves all known `eth` information about a peer.
func (h *ethHandler) PeerInfo(id enode.ID) interface{} {
if p := h.peers.ethPeer(id.String()); p != nil {
return p.info()
}
return nil
}
// AcceptTxs retrieves whether transaction processing is enabled on the node
// or if inbound transactions should simply be dropped.
func (h *ethHandler) AcceptTxs() bool {
return atomic.LoadUint32(&h.acceptTxs) == 1
}
// Handle is invoked from a peer's message handler when it receives a new remote
// message that the handler couldn't consume and serve itself.
func (h *ethHandler) Handle(peer *eth.Peer, packet eth.Packet) error {
// Consume any broadcasts and announces, forwarding the rest to the downloader
switch packet := packet.(type) {
case *eth.BlockHeadersPacket:
return h.handleHeaders(peer, *packet)
case *eth.BlockBodiesPacket:
txset, uncleset := packet.Unpack()
return h.handleBodies(peer, txset, uncleset)
case *eth.NodeDataPacket:
if err := h.downloader.DeliverNodeData(peer.ID(), *packet); err != nil {
log.Debug("Failed to deliver node state data", "err", err)
}
return nil
case *eth.ReceiptsPacket:
if err := h.downloader.DeliverReceipts(peer.ID(), *packet); err != nil {
log.Debug("Failed to deliver receipts", "err", err)
}
return nil
case *eth.NewBlockHashesPacket:
hashes, numbers := packet.Unpack()
return h.handleBlockAnnounces(peer, hashes, numbers)
case *eth.NewBlockPacket:
return h.handleBlockBroadcast(peer, packet.Block, packet.TD)
case *eth.NewPooledTransactionHashesPacket:
return h.txFetcher.Notify(peer.ID(), *packet)
case *eth.TransactionsPacket:
return h.txFetcher.Enqueue(peer.ID(), *packet, false)
case *eth.PooledTransactionsPacket:
return h.txFetcher.Enqueue(peer.ID(), *packet, true)
default:
return fmt.Errorf("unexpected eth packet type: %T", packet)
}
}
// handleHeaders is invoked from a peer's message handler when it transmits a batch
// of headers for the local node to process.
func (h *ethHandler) handleHeaders(peer *eth.Peer, headers []*types.Header) error {
p := h.peers.ethPeer(peer.ID())
if p == nil {
return errors.New("unregistered during callback")
}
// If no headers were received, but we're expencting a checkpoint header, consider it that
if len(headers) == 0 && p.syncDrop != nil {
// Stop the timer either way, decide later to drop or not
p.syncDrop.Stop()
p.syncDrop = nil
// If we're doing a fast (or snap) sync, we must enforce the checkpoint block to avoid
// eclipse attacks. Unsynced nodes are welcome to connect after we're done
// joining the network
if atomic.LoadUint32(&h.fastSync) == 1 {
peer.Log().Warn("Dropping unsynced node during sync", "addr", peer.RemoteAddr(), "type", peer.Name())
return errors.New("unsynced node cannot serve sync")
}
}
// Filter out any explicitly requested headers, deliver the rest to the downloader
filter := len(headers) == 1
if filter {
// If it's a potential sync progress check, validate the content and advertised chain weight
if p.syncDrop != nil && headers[0].Number.Uint64() == h.checkpointNumber {
// Disable the sync drop timer
p.syncDrop.Stop()
p.syncDrop = nil
// Validate the header and either drop the peer or continue
if headers[0].Hash() != h.checkpointHash {
return errors.New("checkpoint hash mismatch")
}
return nil
}
// Otherwise if it's a whitelisted block, validate against the set
if want, ok := h.whitelist[headers[0].Number.Uint64()]; ok {
if hash := headers[0].Hash(); want != hash {
peer.Log().Info("Whitelist mismatch, dropping peer", "number", headers[0].Number.Uint64(), "hash", hash, "want", want)
return errors.New("whitelist block mismatch")
}
peer.Log().Debug("Whitelist block verified", "number", headers[0].Number.Uint64(), "hash", want)
}
// Irrelevant of the fork checks, send the header to the fetcher just in case
headers = h.blockFetcher.FilterHeaders(peer.ID(), headers, time.Now())
}
if len(headers) > 0 || !filter {
err := h.downloader.DeliverHeaders(peer.ID(), headers)
if err != nil {
log.Debug("Failed to deliver headers", "err", err)
}
}
return nil
}
// handleBodies is invoked from a peer's message handler when it transmits a batch
// of block bodies for the local node to process.
func (h *ethHandler) handleBodies(peer *eth.Peer, txs [][]*types.Transaction, uncles [][]*types.Header) error {
// Filter out any explicitly requested bodies, deliver the rest to the downloader
filter := len(txs) > 0 || len(uncles) > 0
if filter {
txs, uncles = h.blockFetcher.FilterBodies(peer.ID(), txs, uncles, time.Now())
}
if len(txs) > 0 || len(uncles) > 0 || !filter {
err := h.downloader.DeliverBodies(peer.ID(), txs, uncles)
if err != nil {
log.Debug("Failed to deliver bodies", "err", err)
}
}
return nil
}
// handleBlockAnnounces is invoked from a peer's message handler when it transmits a
// batch of block announcements for the local node to process.
func (h *ethHandler) handleBlockAnnounces(peer *eth.Peer, hashes []common.Hash, numbers []uint64) error {
// Schedule all the unknown hashes for retrieval
var (
unknownHashes = make([]common.Hash, 0, len(hashes))
unknownNumbers = make([]uint64, 0, len(numbers))
)
for i := 0; i < len(hashes); i++ {
if !h.chain.HasBlock(hashes[i], numbers[i]) {
unknownHashes = append(unknownHashes, hashes[i])
unknownNumbers = append(unknownNumbers, numbers[i])
}
}
for i := 0; i < len(unknownHashes); i++ {
h.blockFetcher.Notify(peer.ID(), unknownHashes[i], unknownNumbers[i], time.Now(), peer.RequestOneHeader, peer.RequestBodies)
}
return nil
}
// handleBlockBroadcast is invoked from a peer's message handler when it transmits a
// block broadcast for the local node to process.
func (h *ethHandler) handleBlockBroadcast(peer *eth.Peer, block *types.Block, td *big.Int) error {
// Schedule the block for import
h.blockFetcher.Enqueue(peer.ID(), block)
// Assuming the block is importable by the peer, but possibly not yet done so,
// calculate the head hash and TD that the peer truly must have.
var (
trueHead = block.ParentHash()
trueTD = new(big.Int).Sub(td, block.Difficulty())
)
// Update the peer's total difficulty if better than the previous
if _, td := peer.Head(); trueTD.Cmp(td) > 0 {
peer.SetHead(trueHead, trueTD)
h.chainSync.handlePeerEvent(peer)
}
return nil
}

740
eth/handler_eth_test.go Normal file
View File

@ -0,0 +1,740 @@
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package eth
import (
"fmt"
"math/big"
"math/rand"
"sync/atomic"
"testing"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus/ethash"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/trie"
)
// testEthHandler is a mock event handler to listen for inbound network requests
// on the `eth` protocol and convert them into a more easily testable form.
type testEthHandler struct {
blockBroadcasts event.Feed
txAnnounces event.Feed
txBroadcasts event.Feed
}
func (h *testEthHandler) Chain() *core.BlockChain { panic("no backing chain") }
func (h *testEthHandler) StateBloom() *trie.SyncBloom { panic("no backing state bloom") }
func (h *testEthHandler) TxPool() eth.TxPool { panic("no backing tx pool") }
func (h *testEthHandler) AcceptTxs() bool { return true }
func (h *testEthHandler) RunPeer(*eth.Peer, eth.Handler) error { panic("not used in tests") }
func (h *testEthHandler) PeerInfo(enode.ID) interface{} { panic("not used in tests") }
func (h *testEthHandler) Handle(peer *eth.Peer, packet eth.Packet) error {
switch packet := packet.(type) {
case *eth.NewBlockPacket:
h.blockBroadcasts.Send(packet.Block)
return nil
case *eth.NewPooledTransactionHashesPacket:
h.txAnnounces.Send(([]common.Hash)(*packet))
return nil
case *eth.TransactionsPacket:
h.txBroadcasts.Send(([]*types.Transaction)(*packet))
return nil
case *eth.PooledTransactionsPacket:
h.txBroadcasts.Send(([]*types.Transaction)(*packet))
return nil
default:
panic(fmt.Sprintf("unexpected eth packet type in tests: %T", packet))
}
}
// Tests that peers are correctly accepted (or rejected) based on the advertised
// fork IDs in the protocol handshake.
func TestForkIDSplit64(t *testing.T) { testForkIDSplit(t, 64) }
func TestForkIDSplit65(t *testing.T) { testForkIDSplit(t, 65) }
func testForkIDSplit(t *testing.T, protocol uint) {
t.Parallel()
var (
engine = ethash.NewFaker()
configNoFork = &params.ChainConfig{HomesteadBlock: big.NewInt(1)}
configProFork = &params.ChainConfig{
HomesteadBlock: big.NewInt(1),
EIP150Block: big.NewInt(2),
EIP155Block: big.NewInt(2),
EIP158Block: big.NewInt(2),
ByzantiumBlock: big.NewInt(3),
}
dbNoFork = rawdb.NewMemoryDatabase()
dbProFork = rawdb.NewMemoryDatabase()
gspecNoFork = &core.Genesis{Config: configNoFork}
gspecProFork = &core.Genesis{Config: configProFork}
genesisNoFork = gspecNoFork.MustCommit(dbNoFork)
genesisProFork = gspecProFork.MustCommit(dbProFork)
chainNoFork, _ = core.NewBlockChain(dbNoFork, nil, configNoFork, engine, vm.Config{}, nil, nil)
chainProFork, _ = core.NewBlockChain(dbProFork, nil, configProFork, engine, vm.Config{}, nil, nil)
blocksNoFork, _ = core.GenerateChain(configNoFork, genesisNoFork, engine, dbNoFork, 2, nil)
blocksProFork, _ = core.GenerateChain(configProFork, genesisProFork, engine, dbProFork, 2, nil)
ethNoFork, _ = newHandler(&handlerConfig{
Database: dbNoFork,
Chain: chainNoFork,
TxPool: newTestTxPool(),
Network: 1,
Sync: downloader.FullSync,
BloomCache: 1,
})
ethProFork, _ = newHandler(&handlerConfig{
Database: dbProFork,
Chain: chainProFork,
TxPool: newTestTxPool(),
Network: 1,
Sync: downloader.FullSync,
BloomCache: 1,
})
)
ethNoFork.Start(1000)
ethProFork.Start(1000)
// Clean up everything after ourselves
defer chainNoFork.Stop()
defer chainProFork.Stop()
defer ethNoFork.Stop()
defer ethProFork.Stop()
// Both nodes should allow the other to connect (same genesis, next fork is the same)
p2pNoFork, p2pProFork := p2p.MsgPipe()
defer p2pNoFork.Close()
defer p2pProFork.Close()
peerNoFork := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil)
peerProFork := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil)
defer peerNoFork.Close()
defer peerProFork.Close()
errc := make(chan error, 2)
go func(errc chan error) {
errc <- ethNoFork.runEthPeer(peerProFork, func(peer *eth.Peer) error { return nil })
}(errc)
go func(errc chan error) {
errc <- ethProFork.runEthPeer(peerNoFork, func(peer *eth.Peer) error { return nil })
}(errc)
for i := 0; i < 2; i++ {
select {
case err := <-errc:
if err != nil {
t.Fatalf("frontier nofork <-> profork failed: %v", err)
}
case <-time.After(250 * time.Millisecond):
t.Fatalf("frontier nofork <-> profork handler timeout")
}
}
// Progress into Homestead. Fork's match, so we don't care what the future holds
chainNoFork.InsertChain(blocksNoFork[:1])
chainProFork.InsertChain(blocksProFork[:1])
p2pNoFork, p2pProFork = p2p.MsgPipe()
defer p2pNoFork.Close()
defer p2pProFork.Close()
peerNoFork = eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil)
peerProFork = eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil)
defer peerNoFork.Close()
defer peerProFork.Close()
errc = make(chan error, 2)
go func(errc chan error) {
errc <- ethNoFork.runEthPeer(peerProFork, func(peer *eth.Peer) error { return nil })
}(errc)
go func(errc chan error) {
errc <- ethProFork.runEthPeer(peerNoFork, func(peer *eth.Peer) error { return nil })
}(errc)
for i := 0; i < 2; i++ {
select {
case err := <-errc:
if err != nil {
t.Fatalf("homestead nofork <-> profork failed: %v", err)
}
case <-time.After(250 * time.Millisecond):
t.Fatalf("homestead nofork <-> profork handler timeout")
}
}
// Progress into Spurious. Forks mismatch, signalling differing chains, reject
chainNoFork.InsertChain(blocksNoFork[1:2])
chainProFork.InsertChain(blocksProFork[1:2])
p2pNoFork, p2pProFork = p2p.MsgPipe()
defer p2pNoFork.Close()
defer p2pProFork.Close()
peerNoFork = eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil)
peerProFork = eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil)
defer peerNoFork.Close()
defer peerProFork.Close()
errc = make(chan error, 2)
go func(errc chan error) {
errc <- ethNoFork.runEthPeer(peerProFork, func(peer *eth.Peer) error { return nil })
}(errc)
go func(errc chan error) {
errc <- ethProFork.runEthPeer(peerNoFork, func(peer *eth.Peer) error { return nil })
}(errc)
var successes int
for i := 0; i < 2; i++ {
select {
case err := <-errc:
if err == nil {
successes++
if successes == 2 { // Only one side disconnects
t.Fatalf("fork ID rejection didn't happen")
}
}
case <-time.After(250 * time.Millisecond):
t.Fatalf("split peers not rejected")
}
}
}
// Tests that received transactions are added to the local pool.
func TestRecvTransactions64(t *testing.T) { testRecvTransactions(t, 64) }
func TestRecvTransactions65(t *testing.T) { testRecvTransactions(t, 65) }
func testRecvTransactions(t *testing.T, protocol uint) {
t.Parallel()
// Create a message handler, configure it to accept transactions and watch them
handler := newTestHandler()
defer handler.close()
handler.handler.acceptTxs = 1 // mark synced to accept transactions
txs := make(chan core.NewTxsEvent)
sub := handler.txpool.SubscribeNewTxsEvent(txs)
defer sub.Unsubscribe()
// Create a source peer to send messages through and a sink handler to receive them
p2pSrc, p2pSink := p2p.MsgPipe()
defer p2pSrc.Close()
defer p2pSink.Close()
src := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pSrc, handler.txpool)
sink := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pSink, handler.txpool)
defer src.Close()
defer sink.Close()
go handler.handler.runEthPeer(sink, func(peer *eth.Peer) error {
return eth.Handle((*ethHandler)(handler.handler), peer)
})
// Run the handshake locally to avoid spinning up a source handler
var (
genesis = handler.chain.Genesis()
head = handler.chain.CurrentBlock()
td = handler.chain.GetTd(head.Hash(), head.NumberU64())
)
if err := src.Handshake(1, td, head.Hash(), genesis.Hash(), forkid.NewIDWithChain(handler.chain), forkid.NewFilter(handler.chain)); err != nil {
t.Fatalf("failed to run protocol handshake")
}
// Send the transaction to the sink and verify that it's added to the tx pool
tx := types.NewTransaction(0, common.Address{}, big.NewInt(0), 100000, big.NewInt(0), nil)
tx, _ = types.SignTx(tx, types.HomesteadSigner{}, testKey)
if err := src.SendTransactions([]*types.Transaction{tx}); err != nil {
t.Fatalf("failed to send transaction: %v", err)
}
select {
case event := <-txs:
if len(event.Txs) != 1 {
t.Errorf("wrong number of added transactions: got %d, want 1", len(event.Txs))
} else if event.Txs[0].Hash() != tx.Hash() {
t.Errorf("added wrong tx hash: got %v, want %v", event.Txs[0].Hash(), tx.Hash())
}
case <-time.After(2 * time.Second):
t.Errorf("no NewTxsEvent received within 2 seconds")
}
}
// This test checks that pending transactions are sent.
func TestSendTransactions64(t *testing.T) { testSendTransactions(t, 64) }
func TestSendTransactions65(t *testing.T) { testSendTransactions(t, 65) }
func testSendTransactions(t *testing.T, protocol uint) {
t.Parallel()
// Create a message handler and fill the pool with big transactions
handler := newTestHandler()
defer handler.close()
insert := make([]*types.Transaction, 100)
for nonce := range insert {
tx := types.NewTransaction(uint64(nonce), common.Address{}, big.NewInt(0), 100000, big.NewInt(0), make([]byte, txsyncPackSize/10))
tx, _ = types.SignTx(tx, types.HomesteadSigner{}, testKey)
insert[nonce] = tx
}
go handler.txpool.AddRemotes(insert) // Need goroutine to not block on feed
time.Sleep(250 * time.Millisecond) // Wait until tx events get out of the system (can't use events, tx broadcaster races with peer join)
// Create a source handler to send messages through and a sink peer to receive them
p2pSrc, p2pSink := p2p.MsgPipe()
defer p2pSrc.Close()
defer p2pSink.Close()
src := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pSrc, handler.txpool)
sink := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pSink, handler.txpool)
defer src.Close()
defer sink.Close()
go handler.handler.runEthPeer(src, func(peer *eth.Peer) error {
return eth.Handle((*ethHandler)(handler.handler), peer)
})
// Run the handshake locally to avoid spinning up a source handler
var (
genesis = handler.chain.Genesis()
head = handler.chain.CurrentBlock()
td = handler.chain.GetTd(head.Hash(), head.NumberU64())
)
if err := sink.Handshake(1, td, head.Hash(), genesis.Hash(), forkid.NewIDWithChain(handler.chain), forkid.NewFilter(handler.chain)); err != nil {
t.Fatalf("failed to run protocol handshake")
}
// After the handshake completes, the source handler should stream the sink
// the transactions, subscribe to all inbound network events
backend := new(testEthHandler)
anns := make(chan []common.Hash)
annSub := backend.txAnnounces.Subscribe(anns)
defer annSub.Unsubscribe()
bcasts := make(chan []*types.Transaction)
bcastSub := backend.txBroadcasts.Subscribe(bcasts)
defer bcastSub.Unsubscribe()
go eth.Handle(backend, sink)
// Make sure we get all the transactions on the correct channels
seen := make(map[common.Hash]struct{})
for len(seen) < len(insert) {
switch protocol {
case 63, 64:
select {
case <-anns:
t.Errorf("tx announce received on pre eth/65")
case txs := <-bcasts:
for _, tx := range txs {
if _, ok := seen[tx.Hash()]; ok {
t.Errorf("duplicate transaction announced: %x", tx.Hash())
}
seen[tx.Hash()] = struct{}{}
}
}
case 65:
select {
case hashes := <-anns:
for _, hash := range hashes {
if _, ok := seen[hash]; ok {
t.Errorf("duplicate transaction announced: %x", hash)
}
seen[hash] = struct{}{}
}
case <-bcasts:
t.Errorf("initial tx broadcast received on post eth/65")
}
default:
panic("unsupported protocol, please extend test")
}
}
for _, tx := range insert {
if _, ok := seen[tx.Hash()]; !ok {
t.Errorf("missing transaction: %x", tx.Hash())
}
}
}
// Tests that transactions get propagated to all attached peers, either via direct
// broadcasts or via announcements/retrievals.
func TestTransactionPropagation64(t *testing.T) { testTransactionPropagation(t, 64) }
func TestTransactionPropagation65(t *testing.T) { testTransactionPropagation(t, 65) }
func testTransactionPropagation(t *testing.T, protocol uint) {
t.Parallel()
// Create a source handler to send transactions from and a number of sinks
// to receive them. We need multiple sinks since a one-to-one peering would
// broadcast all transactions without announcement.
source := newTestHandler()
defer source.close()
sinks := make([]*testHandler, 10)
for i := 0; i < len(sinks); i++ {
sinks[i] = newTestHandler()
defer sinks[i].close()
sinks[i].handler.acceptTxs = 1 // mark synced to accept transactions
}
// Interconnect all the sink handlers with the source handler
for i, sink := range sinks {
sink := sink // Closure for gorotuine below
sourcePipe, sinkPipe := p2p.MsgPipe()
defer sourcePipe.Close()
defer sinkPipe.Close()
sourcePeer := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{byte(i)}, "", nil), sourcePipe, source.txpool)
sinkPeer := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{0}, "", nil), sinkPipe, sink.txpool)
defer sourcePeer.Close()
defer sinkPeer.Close()
go source.handler.runEthPeer(sourcePeer, func(peer *eth.Peer) error {
return eth.Handle((*ethHandler)(source.handler), peer)
})
go sink.handler.runEthPeer(sinkPeer, func(peer *eth.Peer) error {
return eth.Handle((*ethHandler)(sink.handler), peer)
})
}
// Subscribe to all the transaction pools
txChs := make([]chan core.NewTxsEvent, len(sinks))
for i := 0; i < len(sinks); i++ {
txChs[i] = make(chan core.NewTxsEvent, 1024)
sub := sinks[i].txpool.SubscribeNewTxsEvent(txChs[i])
defer sub.Unsubscribe()
}
// Fill the source pool with transactions and wait for them at the sinks
txs := make([]*types.Transaction, 1024)
for nonce := range txs {
tx := types.NewTransaction(uint64(nonce), common.Address{}, big.NewInt(0), 100000, big.NewInt(0), nil)
tx, _ = types.SignTx(tx, types.HomesteadSigner{}, testKey)
txs[nonce] = tx
}
source.txpool.AddRemotes(txs)
// Iterate through all the sinks and ensure they all got the transactions
for i := range sinks {
for arrived := 0; arrived < len(txs); {
select {
case event := <-txChs[i]:
arrived += len(event.Txs)
case <-time.NewTimer(time.Second).C:
t.Errorf("sink %d: transaction propagation timed out: have %d, want %d", i, arrived, len(txs))
}
}
}
}
// Tests that post eth protocol handshake, clients perform a mutual checkpoint
// challenge to validate each other's chains. Hash mismatches, or missing ones
// during a fast sync should lead to the peer getting dropped.
func TestCheckpointChallenge(t *testing.T) {
tests := []struct {
syncmode downloader.SyncMode
checkpoint bool
timeout bool
empty bool
match bool
drop bool
}{
// If checkpointing is not enabled locally, don't challenge and don't drop
{downloader.FullSync, false, false, false, false, false},
{downloader.FastSync, false, false, false, false, false},
// If checkpointing is enabled locally and remote response is empty, only drop during fast sync
{downloader.FullSync, true, false, true, false, false},
{downloader.FastSync, true, false, true, false, true}, // Special case, fast sync, unsynced peer
// If checkpointing is enabled locally and remote response mismatches, always drop
{downloader.FullSync, true, false, false, false, true},
{downloader.FastSync, true, false, false, false, true},
// If checkpointing is enabled locally and remote response matches, never drop
{downloader.FullSync, true, false, false, true, false},
{downloader.FastSync, true, false, false, true, false},
// If checkpointing is enabled locally and remote times out, always drop
{downloader.FullSync, true, true, false, true, true},
{downloader.FastSync, true, true, false, true, true},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("sync %v checkpoint %v timeout %v empty %v match %v", tt.syncmode, tt.checkpoint, tt.timeout, tt.empty, tt.match), func(t *testing.T) {
testCheckpointChallenge(t, tt.syncmode, tt.checkpoint, tt.timeout, tt.empty, tt.match, tt.drop)
})
}
}
func testCheckpointChallenge(t *testing.T, syncmode downloader.SyncMode, checkpoint bool, timeout bool, empty bool, match bool, drop bool) {
// Reduce the checkpoint handshake challenge timeout
defer func(old time.Duration) { syncChallengeTimeout = old }(syncChallengeTimeout)
syncChallengeTimeout = 250 * time.Millisecond
// Create a test handler and inject a CHT into it. The injection is a bit
// ugly, but it beats creating everything manually just to avoid reaching
// into the internals a bit.
handler := newTestHandler()
defer handler.close()
if syncmode == downloader.FastSync {
atomic.StoreUint32(&handler.handler.fastSync, 1)
} else {
atomic.StoreUint32(&handler.handler.fastSync, 0)
}
var response *types.Header
if checkpoint {
number := (uint64(rand.Intn(500))+1)*params.CHTFrequency - 1
response = &types.Header{Number: big.NewInt(int64(number)), Extra: []byte("valid")}
handler.handler.checkpointNumber = number
handler.handler.checkpointHash = response.Hash()
}
// Create a challenger peer and a challenged one
p2pLocal, p2pRemote := p2p.MsgPipe()
defer p2pLocal.Close()
defer p2pRemote.Close()
local := eth.NewPeer(eth.ETH64, p2p.NewPeer(enode.ID{1}, "", nil), p2pLocal, handler.txpool)
remote := eth.NewPeer(eth.ETH64, p2p.NewPeer(enode.ID{2}, "", nil), p2pRemote, handler.txpool)
defer local.Close()
defer remote.Close()
go handler.handler.runEthPeer(local, func(peer *eth.Peer) error {
return eth.Handle((*ethHandler)(handler.handler), peer)
})
// Run the handshake locally to avoid spinning up a remote handler
var (
genesis = handler.chain.Genesis()
head = handler.chain.CurrentBlock()
td = handler.chain.GetTd(head.Hash(), head.NumberU64())
)
if err := remote.Handshake(1, td, head.Hash(), genesis.Hash(), forkid.NewIDWithChain(handler.chain), forkid.NewFilter(handler.chain)); err != nil {
t.Fatalf("failed to run protocol handshake")
}
// Connect a new peer and check that we receive the checkpoint challenge
if checkpoint {
if err := remote.ExpectRequestHeadersByNumber(response.Number.Uint64(), 1, 0, false); err != nil {
t.Fatalf("challenge mismatch: %v", err)
}
// Create a block to reply to the challenge if no timeout is simulated
if !timeout {
if empty {
if err := remote.SendBlockHeaders([]*types.Header{}); err != nil {
t.Fatalf("failed to answer challenge: %v", err)
}
} else if match {
if err := remote.SendBlockHeaders([]*types.Header{response}); err != nil {
t.Fatalf("failed to answer challenge: %v", err)
}
} else {
if err := remote.SendBlockHeaders([]*types.Header{{Number: response.Number}}); err != nil {
t.Fatalf("failed to answer challenge: %v", err)
}
}
}
}
// Wait until the test timeout passes to ensure proper cleanup
time.Sleep(syncChallengeTimeout + 300*time.Millisecond)
// Verify that the remote peer is maintained or dropped
if drop {
if peers := handler.handler.peers.Len(); peers != 0 {
t.Fatalf("peer count mismatch: have %d, want %d", peers, 0)
}
} else {
if peers := handler.handler.peers.Len(); peers != 1 {
t.Fatalf("peer count mismatch: have %d, want %d", peers, 1)
}
}
}
// Tests that blocks are broadcast to a sqrt number of peers only.
func TestBroadcastBlock1Peer(t *testing.T) { testBroadcastBlock(t, 1, 1) }
func TestBroadcastBlock2Peers(t *testing.T) { testBroadcastBlock(t, 2, 1) }
func TestBroadcastBlock3Peers(t *testing.T) { testBroadcastBlock(t, 3, 1) }
func TestBroadcastBlock4Peers(t *testing.T) { testBroadcastBlock(t, 4, 2) }
func TestBroadcastBlock5Peers(t *testing.T) { testBroadcastBlock(t, 5, 2) }
func TestBroadcastBlock8Peers(t *testing.T) { testBroadcastBlock(t, 9, 3) }
func TestBroadcastBlock12Peers(t *testing.T) { testBroadcastBlock(t, 12, 3) }
func TestBroadcastBlock16Peers(t *testing.T) { testBroadcastBlock(t, 16, 4) }
func TestBroadcastBloc26Peers(t *testing.T) { testBroadcastBlock(t, 26, 5) }
func TestBroadcastBlock100Peers(t *testing.T) { testBroadcastBlock(t, 100, 10) }
func testBroadcastBlock(t *testing.T, peers, bcasts int) {
t.Parallel()
// Create a source handler to broadcast blocks from and a number of sinks
// to receive them.
source := newTestHandlerWithBlocks(1)
defer source.close()
sinks := make([]*testEthHandler, peers)
for i := 0; i < len(sinks); i++ {
sinks[i] = new(testEthHandler)
}
// Interconnect all the sink handlers with the source handler
var (
genesis = source.chain.Genesis()
td = source.chain.GetTd(genesis.Hash(), genesis.NumberU64())
)
for i, sink := range sinks {
sink := sink // Closure for gorotuine below
sourcePipe, sinkPipe := p2p.MsgPipe()
defer sourcePipe.Close()
defer sinkPipe.Close()
sourcePeer := eth.NewPeer(eth.ETH64, p2p.NewPeer(enode.ID{byte(i)}, "", nil), sourcePipe, nil)
sinkPeer := eth.NewPeer(eth.ETH64, p2p.NewPeer(enode.ID{0}, "", nil), sinkPipe, nil)
defer sourcePeer.Close()
defer sinkPeer.Close()
go source.handler.runEthPeer(sourcePeer, func(peer *eth.Peer) error {
return eth.Handle((*ethHandler)(source.handler), peer)
})
if err := sinkPeer.Handshake(1, td, genesis.Hash(), genesis.Hash(), forkid.NewIDWithChain(source.chain), forkid.NewFilter(source.chain)); err != nil {
t.Fatalf("failed to run protocol handshake")
}
go eth.Handle(sink, sinkPeer)
}
// Subscribe to all the transaction pools
blockChs := make([]chan *types.Block, len(sinks))
for i := 0; i < len(sinks); i++ {
blockChs[i] = make(chan *types.Block, 1)
defer close(blockChs[i])
sub := sinks[i].blockBroadcasts.Subscribe(blockChs[i])
defer sub.Unsubscribe()
}
// Initiate a block propagation across the peers
time.Sleep(100 * time.Millisecond)
source.handler.BroadcastBlock(source.chain.CurrentBlock(), true)
// Iterate through all the sinks and ensure the correct number got the block
done := make(chan struct{}, peers)
for _, ch := range blockChs {
ch := ch
go func() {
<-ch
done <- struct{}{}
}()
}
var received int
for {
select {
case <-done:
received++
case <-time.After(100 * time.Millisecond):
if received != bcasts {
t.Errorf("broadcast count mismatch: have %d, want %d", received, bcasts)
}
return
}
}
}
// Tests that a propagated malformed block (uncles or transactions don't match
// with the hashes in the header) gets discarded and not broadcast forward.
func TestBroadcastMalformedBlock64(t *testing.T) { testBroadcastMalformedBlock(t, 64) }
func TestBroadcastMalformedBlock65(t *testing.T) { testBroadcastMalformedBlock(t, 65) }
func testBroadcastMalformedBlock(t *testing.T, protocol uint) {
t.Parallel()
// Create a source handler to broadcast blocks from and a number of sinks
// to receive them.
source := newTestHandlerWithBlocks(1)
defer source.close()
// Create a source handler to send messages through and a sink peer to receive them
p2pSrc, p2pSink := p2p.MsgPipe()
defer p2pSrc.Close()
defer p2pSink.Close()
src := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pSrc, source.txpool)
sink := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pSink, source.txpool)
defer src.Close()
defer sink.Close()
go source.handler.runEthPeer(src, func(peer *eth.Peer) error {
return eth.Handle((*ethHandler)(source.handler), peer)
})
// Run the handshake locally to avoid spinning up a sink handler
var (
genesis = source.chain.Genesis()
td = source.chain.GetTd(genesis.Hash(), genesis.NumberU64())
)
if err := sink.Handshake(1, td, genesis.Hash(), genesis.Hash(), forkid.NewIDWithChain(source.chain), forkid.NewFilter(source.chain)); err != nil {
t.Fatalf("failed to run protocol handshake")
}
// After the handshake completes, the source handler should stream the sink
// the blocks, subscribe to inbound network events
backend := new(testEthHandler)
blocks := make(chan *types.Block, 1)
sub := backend.blockBroadcasts.Subscribe(blocks)
defer sub.Unsubscribe()
go eth.Handle(backend, sink)
// Create various combinations of malformed blocks
head := source.chain.CurrentBlock()
malformedUncles := head.Header()
malformedUncles.UncleHash[0]++
malformedTransactions := head.Header()
malformedTransactions.TxHash[0]++
malformedEverything := head.Header()
malformedEverything.UncleHash[0]++
malformedEverything.TxHash[0]++
// Try to broadcast all malformations and ensure they all get discarded
for _, header := range []*types.Header{malformedUncles, malformedTransactions, malformedEverything} {
block := types.NewBlockWithHeader(header).WithBody(head.Transactions(), head.Uncles())
if err := src.SendNewBlock(block, big.NewInt(131136)); err != nil {
t.Fatalf("failed to broadcast block: %v", err)
}
select {
case <-blocks:
t.Fatalf("malformed block forwarded")
case <-time.After(100 * time.Millisecond):
}
}
}

48
eth/handler_snap.go Normal file
View File

@ -0,0 +1,48 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package eth
import (
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/eth/protocols/snap"
"github.com/ethereum/go-ethereum/p2p/enode"
)
// snapHandler implements the snap.Backend interface to handle the various network
// packets that are sent as replies or broadcasts.
type snapHandler handler
func (h *snapHandler) Chain() *core.BlockChain { return h.chain }
// RunPeer is invoked when a peer joins on the `snap` protocol.
func (h *snapHandler) RunPeer(peer *snap.Peer, hand snap.Handler) error {
return (*handler)(h).runSnapPeer(peer, hand)
}
// PeerInfo retrieves all known `snap` information about a peer.
func (h *snapHandler) PeerInfo(id enode.ID) interface{} {
if p := h.peers.snapPeer(id.String()); p != nil {
return p.info()
}
return nil
}
// Handle is invoked from a peer's message handler when it receives a new remote
// message that the handler couldn't consume and serve itself.
func (h *snapHandler) Handle(peer *snap.Peer, packet snap.Packet) error {
return h.downloader.DeliverSnapPacket(peer, packet)
}

View File

@ -17,678 +17,154 @@
package eth package eth
import ( import (
"fmt"
"math"
"math/big" "math/big"
"math/rand" "sort"
"testing" "sync"
"time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus/ethash" "github.com/ethereum/go-ethereum/consensus/ethash"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
) )
// Tests that block headers can be retrieved from a remote chain based on user queries. var (
func TestGetBlockHeaders63(t *testing.T) { testGetBlockHeaders(t, 63) } // testKey is a private key to use for funding a tester account.
func TestGetBlockHeaders64(t *testing.T) { testGetBlockHeaders(t, 64) } testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
func testGetBlockHeaders(t *testing.T, protocol int) { // testAddr is the Ethereum address of the tester account.
pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxHashFetch+15, nil, nil) testAddr = crypto.PubkeyToAddress(testKey.PublicKey)
peer, _ := newTestPeer("peer", protocol, pm, true) )
defer peer.close()
// Create a "random" unknown hash for testing // testTxPool is a mock transaction pool that blindly accepts all transactions.
var unknown common.Hash // Its goal is to get around setting up a valid statedb for the balance and nonce
for i := range unknown { // checks.
unknown[i] = byte(i) type testTxPool struct {
} pool map[common.Hash]*types.Transaction // Hash map of collected transactions
// Create a batch of tests for various scenarios
limit := uint64(downloader.MaxHeaderFetch)
tests := []struct {
query *getBlockHeadersData // The query to execute for header retrieval
expect []common.Hash // The hashes of the block whose headers are expected
}{
// A single random block should be retrievable by hash and number too
{
&getBlockHeadersData{Origin: hashOrNumber{Hash: pm.blockchain.GetBlockByNumber(limit / 2).Hash()}, Amount: 1},
[]common.Hash{pm.blockchain.GetBlockByNumber(limit / 2).Hash()},
}, {
&getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 1},
[]common.Hash{pm.blockchain.GetBlockByNumber(limit / 2).Hash()},
},
// Multiple headers should be retrievable in both directions
{
&getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3},
[]common.Hash{
pm.blockchain.GetBlockByNumber(limit / 2).Hash(),
pm.blockchain.GetBlockByNumber(limit/2 + 1).Hash(),
pm.blockchain.GetBlockByNumber(limit/2 + 2).Hash(),
},
}, {
&getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3, Reverse: true},
[]common.Hash{
pm.blockchain.GetBlockByNumber(limit / 2).Hash(),
pm.blockchain.GetBlockByNumber(limit/2 - 1).Hash(),
pm.blockchain.GetBlockByNumber(limit/2 - 2).Hash(),
},
},
// Multiple headers with skip lists should be retrievable
{
&getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3},
[]common.Hash{
pm.blockchain.GetBlockByNumber(limit / 2).Hash(),
pm.blockchain.GetBlockByNumber(limit/2 + 4).Hash(),
pm.blockchain.GetBlockByNumber(limit/2 + 8).Hash(),
},
}, {
&getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3, Reverse: true},
[]common.Hash{
pm.blockchain.GetBlockByNumber(limit / 2).Hash(),
pm.blockchain.GetBlockByNumber(limit/2 - 4).Hash(),
pm.blockchain.GetBlockByNumber(limit/2 - 8).Hash(),
},
},
// The chain endpoints should be retrievable
{
&getBlockHeadersData{Origin: hashOrNumber{Number: 0}, Amount: 1},
[]common.Hash{pm.blockchain.GetBlockByNumber(0).Hash()},
}, {
&getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64()}, Amount: 1},
[]common.Hash{pm.blockchain.CurrentBlock().Hash()},
},
// Ensure protocol limits are honored
{
&getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true},
pm.blockchain.GetBlockHashesFromHash(pm.blockchain.CurrentBlock().Hash(), limit),
},
// Check that requesting more than available is handled gracefully
{
&getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3},
[]common.Hash{
pm.blockchain.GetBlockByNumber(pm.blockchain.CurrentBlock().NumberU64() - 4).Hash(),
pm.blockchain.GetBlockByNumber(pm.blockchain.CurrentBlock().NumberU64()).Hash(),
},
}, {
&getBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 3, Amount: 3, Reverse: true},
[]common.Hash{
pm.blockchain.GetBlockByNumber(4).Hash(),
pm.blockchain.GetBlockByNumber(0).Hash(),
},
},
// Check that requesting more than available is handled gracefully, even if mid skip
{
&getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64() - 4}, Skip: 2, Amount: 3},
[]common.Hash{
pm.blockchain.GetBlockByNumber(pm.blockchain.CurrentBlock().NumberU64() - 4).Hash(),
pm.blockchain.GetBlockByNumber(pm.blockchain.CurrentBlock().NumberU64() - 1).Hash(),
},
}, {
&getBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 2, Amount: 3, Reverse: true},
[]common.Hash{
pm.blockchain.GetBlockByNumber(4).Hash(),
pm.blockchain.GetBlockByNumber(1).Hash(),
},
},
// Check a corner case where requesting more can iterate past the endpoints
{
&getBlockHeadersData{Origin: hashOrNumber{Number: 2}, Amount: 5, Reverse: true},
[]common.Hash{
pm.blockchain.GetBlockByNumber(2).Hash(),
pm.blockchain.GetBlockByNumber(1).Hash(),
pm.blockchain.GetBlockByNumber(0).Hash(),
},
},
// Check a corner case where skipping overflow loops back into the chain start
{
&getBlockHeadersData{Origin: hashOrNumber{Hash: pm.blockchain.GetBlockByNumber(3).Hash()}, Amount: 2, Reverse: false, Skip: math.MaxUint64 - 1},
[]common.Hash{
pm.blockchain.GetBlockByNumber(3).Hash(),
},
},
// Check a corner case where skipping overflow loops back to the same header
{
&getBlockHeadersData{Origin: hashOrNumber{Hash: pm.blockchain.GetBlockByNumber(1).Hash()}, Amount: 2, Reverse: false, Skip: math.MaxUint64},
[]common.Hash{
pm.blockchain.GetBlockByNumber(1).Hash(),
},
},
// Check that non existing headers aren't returned
{
&getBlockHeadersData{Origin: hashOrNumber{Hash: unknown}, Amount: 1},
[]common.Hash{},
}, {
&getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64() + 1}, Amount: 1},
[]common.Hash{},
},
}
// Run each of the tests and verify the results against the chain
for i, tt := range tests {
// Collect the headers to expect in the response
headers := []*types.Header{}
for _, hash := range tt.expect {
headers = append(headers, pm.blockchain.GetBlockByHash(hash).Header())
}
// Send the hash request and verify the response
p2p.Send(peer.app, 0x03, tt.query)
if err := p2p.ExpectMsg(peer.app, 0x04, headers); err != nil {
t.Errorf("test %d: headers mismatch: %v", i, err)
}
// If the test used number origins, repeat with hashes as the too
if tt.query.Origin.Hash == (common.Hash{}) {
if origin := pm.blockchain.GetBlockByNumber(tt.query.Origin.Number); origin != nil {
tt.query.Origin.Hash, tt.query.Origin.Number = origin.Hash(), 0
p2p.Send(peer.app, 0x03, tt.query) txFeed event.Feed // Notification feed to allow waiting for inclusion
if err := p2p.ExpectMsg(peer.app, 0x04, headers); err != nil { lock sync.RWMutex // Protects the transaction pool
t.Errorf("test %d: headers mismatch: %v", i, err)
}
}
} }
// newTestTxPool creates a mock transaction pool.
func newTestTxPool() *testTxPool {
return &testTxPool{
pool: make(map[common.Hash]*types.Transaction),
} }
} }
// Tests that block contents can be retrieved from a remote chain based on their hashes. // Has returns an indicator whether txpool has a transaction
func TestGetBlockBodies63(t *testing.T) { testGetBlockBodies(t, 63) } // cached with the given hash.
func TestGetBlockBodies64(t *testing.T) { testGetBlockBodies(t, 64) } func (p *testTxPool) Has(hash common.Hash) bool {
p.lock.Lock()
defer p.lock.Unlock()
func testGetBlockBodies(t *testing.T, protocol int) { return p.pool[hash] != nil
pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxBlockFetch+15, nil, nil)
peer, _ := newTestPeer("peer", protocol, pm, true)
defer peer.close()
// Create a batch of tests for various scenarios
limit := downloader.MaxBlockFetch
tests := []struct {
random int // Number of blocks to fetch randomly from the chain
explicit []common.Hash // Explicitly requested blocks
available []bool // Availability of explicitly requested blocks
expected int // Total number of existing blocks to expect
}{
{1, nil, nil, 1}, // A single random block should be retrievable
{10, nil, nil, 10}, // Multiple random blocks should be retrievable
{limit, nil, nil, limit}, // The maximum possible blocks should be retrievable
{limit + 1, nil, nil, limit}, // No more than the possible block count should be returned
{0, []common.Hash{pm.blockchain.Genesis().Hash()}, []bool{true}, 1}, // The genesis block should be retrievable
{0, []common.Hash{pm.blockchain.CurrentBlock().Hash()}, []bool{true}, 1}, // The chains head block should be retrievable
{0, []common.Hash{{}}, []bool{false}, 0}, // A non existent block should not be returned
// Existing and non-existing blocks interleaved should not cause problems
{0, []common.Hash{
{},
pm.blockchain.GetBlockByNumber(1).Hash(),
{},
pm.blockchain.GetBlockByNumber(10).Hash(),
{},
pm.blockchain.GetBlockByNumber(100).Hash(),
{},
}, []bool{false, true, false, true, false, true, false}, 3},
}
// Run each of the tests and verify the results against the chain
for i, tt := range tests {
// Collect the hashes to request, and the response to expect
hashes, seen := []common.Hash{}, make(map[int64]bool)
bodies := []*blockBody{}
for j := 0; j < tt.random; j++ {
for {
num := rand.Int63n(int64(pm.blockchain.CurrentBlock().NumberU64()))
if !seen[num] {
seen[num] = true
block := pm.blockchain.GetBlockByNumber(uint64(num))
hashes = append(hashes, block.Hash())
if len(bodies) < tt.expected {
bodies = append(bodies, &blockBody{Transactions: block.Transactions(), Uncles: block.Uncles()})
}
break
}
}
}
for j, hash := range tt.explicit {
hashes = append(hashes, hash)
if tt.available[j] && len(bodies) < tt.expected {
block := pm.blockchain.GetBlockByHash(hash)
bodies = append(bodies, &blockBody{Transactions: block.Transactions(), Uncles: block.Uncles()})
}
}
// Send the hash request and verify the response
p2p.Send(peer.app, 0x05, hashes)
if err := p2p.ExpectMsg(peer.app, 0x06, bodies); err != nil {
t.Errorf("test %d: bodies mismatch: %v", i, err)
}
}
} }
// Tests that the node state database can be retrieved based on hashes. // Get retrieves the transaction from local txpool with given
func TestGetNodeData63(t *testing.T) { testGetNodeData(t, 63) } // tx hash.
func TestGetNodeData64(t *testing.T) { testGetNodeData(t, 64) } func (p *testTxPool) Get(hash common.Hash) *types.Transaction {
p.lock.Lock()
defer p.lock.Unlock()
func testGetNodeData(t *testing.T, protocol int) { return p.pool[hash]
// Define three accounts to simulate transactions with
acc1Key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a")
acc2Key, _ := crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee")
acc1Addr := crypto.PubkeyToAddress(acc1Key.PublicKey)
acc2Addr := crypto.PubkeyToAddress(acc2Key.PublicKey)
signer := types.HomesteadSigner{}
// Create a chain generator with some simple transactions (blatantly stolen from @fjl/chain_markets_test)
generator := func(i int, block *core.BlockGen) {
switch i {
case 0:
// In block 1, the test bank sends account #1 some ether.
tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBank), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil), signer, testBankKey)
block.AddTx(tx)
case 1:
// In block 2, the test bank sends some more ether to account #1.
// acc1Addr passes it on to account #2.
tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBank), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, testBankKey)
tx2, _ := types.SignTx(types.NewTransaction(block.TxNonce(acc1Addr), acc2Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, acc1Key)
block.AddTx(tx1)
block.AddTx(tx2)
case 2:
// Block 3 is empty but was mined by account #2.
block.SetCoinbase(acc2Addr)
block.SetExtra([]byte("yeehaw"))
case 3:
// Block 4 includes blocks 2 and 3 as uncle headers (with modified extra data).
b2 := block.PrevBlock(1).Header()
b2.Extra = []byte("foo")
block.AddUncle(b2)
b3 := block.PrevBlock(2).Header()
b3.Extra = []byte("foo")
block.AddUncle(b3)
}
}
// Assemble the test environment
pm, db := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil)
peer, _ := newTestPeer("peer", protocol, pm, true)
defer peer.close()
// Fetch for now the entire chain db
hashes := []common.Hash{}
it := db.NewIterator(nil, nil)
for it.Next() {
if key := it.Key(); len(key) == common.HashLength {
hashes = append(hashes, common.BytesToHash(key))
}
}
it.Release()
p2p.Send(peer.app, 0x0d, hashes)
msg, err := peer.app.ReadMsg()
if err != nil {
t.Fatalf("failed to read node data response: %v", err)
}
if msg.Code != 0x0e {
t.Fatalf("response packet code mismatch: have %x, want %x", msg.Code, 0x0c)
}
var data [][]byte
if err := msg.Decode(&data); err != nil {
t.Fatalf("failed to decode response node data: %v", err)
}
// Verify that all hashes correspond to the requested data, and reconstruct a state tree
for i, want := range hashes {
if hash := crypto.Keccak256Hash(data[i]); hash != want {
t.Errorf("data hash mismatch: have %x, want %x", hash, want)
}
}
statedb := rawdb.NewMemoryDatabase()
for i := 0; i < len(data); i++ {
statedb.Put(hashes[i].Bytes(), data[i])
}
accounts := []common.Address{testBank, acc1Addr, acc2Addr}
for i := uint64(0); i <= pm.blockchain.CurrentBlock().NumberU64(); i++ {
trie, _ := state.New(pm.blockchain.GetBlockByNumber(i).Root(), state.NewDatabase(statedb), nil)
for j, acc := range accounts {
state, _ := pm.blockchain.State()
bw := state.GetBalance(acc)
bh := trie.GetBalance(acc)
if (bw != nil && bh == nil) || (bw == nil && bh != nil) {
t.Errorf("test %d, account %d: balance mismatch: have %v, want %v", i, j, bh, bw)
}
if bw != nil && bh != nil && bw.Cmp(bw) != 0 {
t.Errorf("test %d, account %d: balance mismatch: have %v, want %v", i, j, bh, bw)
}
}
}
} }
// Tests that the transaction receipts can be retrieved based on hashes. // AddRemotes appends a batch of transactions to the pool, and notifies any
func TestGetReceipt63(t *testing.T) { testGetReceipt(t, 63) } // listeners if the addition channel is non nil
func TestGetReceipt64(t *testing.T) { testGetReceipt(t, 64) } func (p *testTxPool) AddRemotes(txs []*types.Transaction) []error {
p.lock.Lock()
defer p.lock.Unlock()
func testGetReceipt(t *testing.T, protocol int) { for _, tx := range txs {
// Define three accounts to simulate transactions with p.pool[tx.Hash()] = tx
acc1Key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a")
acc2Key, _ := crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee")
acc1Addr := crypto.PubkeyToAddress(acc1Key.PublicKey)
acc2Addr := crypto.PubkeyToAddress(acc2Key.PublicKey)
signer := types.HomesteadSigner{}
// Create a chain generator with some simple transactions (blatantly stolen from @fjl/chain_markets_test)
generator := func(i int, block *core.BlockGen) {
switch i {
case 0:
// In block 1, the test bank sends account #1 some ether.
tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBank), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil), signer, testBankKey)
block.AddTx(tx)
case 1:
// In block 2, the test bank sends some more ether to account #1.
// acc1Addr passes it on to account #2.
tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBank), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, testBankKey)
tx2, _ := types.SignTx(types.NewTransaction(block.TxNonce(acc1Addr), acc2Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, acc1Key)
block.AddTx(tx1)
block.AddTx(tx2)
case 2:
// Block 3 is empty but was mined by account #2.
block.SetCoinbase(acc2Addr)
block.SetExtra([]byte("yeehaw"))
case 3:
// Block 4 includes blocks 2 and 3 as uncle headers (with modified extra data).
b2 := block.PrevBlock(1).Header()
b2.Extra = []byte("foo")
block.AddUncle(b2)
b3 := block.PrevBlock(2).Header()
b3.Extra = []byte("foo")
block.AddUncle(b3)
}
}
// Assemble the test environment
pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil)
peer, _ := newTestPeer("peer", protocol, pm, true)
defer peer.close()
// Collect the hashes to request, and the response to expect
hashes, receipts := []common.Hash{}, []types.Receipts{}
for i := uint64(0); i <= pm.blockchain.CurrentBlock().NumberU64(); i++ {
block := pm.blockchain.GetBlockByNumber(i)
hashes = append(hashes, block.Hash())
receipts = append(receipts, pm.blockchain.GetReceiptsByHash(block.Hash()))
}
// Send the hash request and verify the response
p2p.Send(peer.app, 0x0f, hashes)
if err := p2p.ExpectMsg(peer.app, 0x10, receipts); err != nil {
t.Errorf("receipts mismatch: %v", err)
} }
p.txFeed.Send(core.NewTxsEvent{Txs: txs})
return make([]error, len(txs))
} }
// Tests that post eth protocol handshake, clients perform a mutual checkpoint // Pending returns all the transactions known to the pool
// challenge to validate each other's chains. Hash mismatches, or missing ones func (p *testTxPool) Pending() (map[common.Address]types.Transactions, error) {
// during a fast sync should lead to the peer getting dropped. p.lock.RLock()
func TestCheckpointChallenge(t *testing.T) { defer p.lock.RUnlock()
tests := []struct {
syncmode downloader.SyncMode
checkpoint bool
timeout bool
empty bool
match bool
drop bool
}{
// If checkpointing is not enabled locally, don't challenge and don't drop
{downloader.FullSync, false, false, false, false, false},
{downloader.FastSync, false, false, false, false, false},
// If checkpointing is enabled locally and remote response is empty, only drop during fast sync batches := make(map[common.Address]types.Transactions)
{downloader.FullSync, true, false, true, false, false}, for _, tx := range p.pool {
{downloader.FastSync, true, false, true, false, true}, // Special case, fast sync, unsynced peer from, _ := types.Sender(types.HomesteadSigner{}, tx)
batches[from] = append(batches[from], tx)
// If checkpointing is enabled locally and remote response mismatches, always drop
{downloader.FullSync, true, false, false, false, true},
{downloader.FastSync, true, false, false, false, true},
// If checkpointing is enabled locally and remote response matches, never drop
{downloader.FullSync, true, false, false, true, false},
{downloader.FastSync, true, false, false, true, false},
// If checkpointing is enabled locally and remote times out, always drop
{downloader.FullSync, true, true, false, true, true},
{downloader.FastSync, true, true, false, true, true},
} }
for _, tt := range tests { for _, batch := range batches {
t.Run(fmt.Sprintf("sync %v checkpoint %v timeout %v empty %v match %v", tt.syncmode, tt.checkpoint, tt.timeout, tt.empty, tt.match), func(t *testing.T) { sort.Sort(types.TxByNonce(batch))
testCheckpointChallenge(t, tt.syncmode, tt.checkpoint, tt.timeout, tt.empty, tt.match, tt.drop) }
return batches, nil
}
// SubscribeNewTxsEvent should return an event subscription of NewTxsEvent and
// send events to the given channel.
func (p *testTxPool) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription {
return p.txFeed.Subscribe(ch)
}
// testHandler is a live implementation of the Ethereum protocol handler, just
// preinitialized with some sane testing defaults and the transaction pool mocked
// out.
type testHandler struct {
db ethdb.Database
chain *core.BlockChain
txpool *testTxPool
handler *handler
}
// newTestHandler creates a new handler for testing purposes with no blocks.
func newTestHandler() *testHandler {
return newTestHandlerWithBlocks(0)
}
// newTestHandlerWithBlocks creates a new handler for testing purposes, with a
// given number of initial blocks.
func newTestHandlerWithBlocks(blocks int) *testHandler {
// Create a database pre-initialize with a genesis block
db := rawdb.NewMemoryDatabase()
(&core.Genesis{
Config: params.TestChainConfig,
Alloc: core.GenesisAlloc{testAddr: {Balance: big.NewInt(1000000)}},
}).MustCommit(db)
chain, _ := core.NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{}, nil, nil)
bs, _ := core.GenerateChain(params.TestChainConfig, chain.Genesis(), ethash.NewFaker(), db, blocks, nil)
if _, err := chain.InsertChain(bs); err != nil {
panic(err)
}
txpool := newTestTxPool()
handler, _ := newHandler(&handlerConfig{
Database: db,
Chain: chain,
TxPool: txpool,
Network: 1,
Sync: downloader.FastSync,
BloomCache: 1,
}) })
handler.Start(1000)
return &testHandler{
db: db,
chain: chain,
txpool: txpool,
handler: handler,
} }
} }
func testCheckpointChallenge(t *testing.T, syncmode downloader.SyncMode, checkpoint bool, timeout bool, empty bool, match bool, drop bool) { // close tears down the handler and all its internal constructs.
// Reduce the checkpoint handshake challenge timeout func (b *testHandler) close() {
defer func(old time.Duration) { syncChallengeTimeout = old }(syncChallengeTimeout) b.handler.Stop()
syncChallengeTimeout = 250 * time.Millisecond b.chain.Stop()
// Initialize a chain and generate a fake CHT if checkpointing is enabled
var (
db = rawdb.NewMemoryDatabase()
config = new(params.ChainConfig)
)
(&core.Genesis{Config: config}).MustCommit(db) // Commit genesis block
// If checkpointing is enabled, create and inject a fake CHT and the corresponding
// chllenge response.
var response *types.Header
var cht *params.TrustedCheckpoint
if checkpoint {
index := uint64(rand.Intn(500))
number := (index+1)*params.CHTFrequency - 1
response = &types.Header{Number: big.NewInt(int64(number)), Extra: []byte("valid")}
cht = &params.TrustedCheckpoint{
SectionIndex: index,
SectionHead: response.Hash(),
}
}
// Create a checkpoint aware protocol manager
blockchain, err := core.NewBlockChain(db, nil, config, ethash.NewFaker(), vm.Config{}, nil, nil)
if err != nil {
t.Fatalf("failed to create new blockchain: %v", err)
}
pm, err := NewProtocolManager(config, cht, syncmode, DefaultConfig.NetworkId, new(event.TypeMux), &testTxPool{pool: make(map[common.Hash]*types.Transaction)}, ethash.NewFaker(), blockchain, db, 1, nil)
if err != nil {
t.Fatalf("failed to start test protocol manager: %v", err)
}
pm.Start(1000)
defer pm.Stop()
// Connect a new peer and check that we receive the checkpoint challenge
peer, _ := newTestPeer("peer", eth63, pm, true)
defer peer.close()
if checkpoint {
challenge := &getBlockHeadersData{
Origin: hashOrNumber{Number: response.Number.Uint64()},
Amount: 1,
Skip: 0,
Reverse: false,
}
if err := p2p.ExpectMsg(peer.app, GetBlockHeadersMsg, challenge); err != nil {
t.Fatalf("challenge mismatch: %v", err)
}
// Create a block to reply to the challenge if no timeout is simulated
if !timeout {
if empty {
if err := p2p.Send(peer.app, BlockHeadersMsg, []*types.Header{}); err != nil {
t.Fatalf("failed to answer challenge: %v", err)
}
} else if match {
if err := p2p.Send(peer.app, BlockHeadersMsg, []*types.Header{response}); err != nil {
t.Fatalf("failed to answer challenge: %v", err)
}
} else {
if err := p2p.Send(peer.app, BlockHeadersMsg, []*types.Header{{Number: response.Number}}); err != nil {
t.Fatalf("failed to answer challenge: %v", err)
}
}
}
}
// Wait until the test timeout passes to ensure proper cleanup
time.Sleep(syncChallengeTimeout + 300*time.Millisecond)
// Verify that the remote peer is maintained or dropped
if drop {
if peers := pm.peers.Len(); peers != 0 {
t.Fatalf("peer count mismatch: have %d, want %d", peers, 0)
}
} else {
if peers := pm.peers.Len(); peers != 1 {
t.Fatalf("peer count mismatch: have %d, want %d", peers, 1)
}
}
}
func TestBroadcastBlock(t *testing.T) {
var tests = []struct {
totalPeers int
broadcastExpected int
}{
{1, 1},
{2, 1},
{3, 1},
{4, 2},
{5, 2},
{9, 3},
{12, 3},
{16, 4},
{26, 5},
{100, 10},
}
for _, test := range tests {
testBroadcastBlock(t, test.totalPeers, test.broadcastExpected)
}
}
func testBroadcastBlock(t *testing.T, totalPeers, broadcastExpected int) {
var (
evmux = new(event.TypeMux)
pow = ethash.NewFaker()
db = rawdb.NewMemoryDatabase()
config = &params.ChainConfig{}
gspec = &core.Genesis{Config: config}
genesis = gspec.MustCommit(db)
)
blockchain, err := core.NewBlockChain(db, nil, config, pow, vm.Config{}, nil, nil)
if err != nil {
t.Fatalf("failed to create new blockchain: %v", err)
}
pm, err := NewProtocolManager(config, nil, downloader.FullSync, DefaultConfig.NetworkId, evmux, &testTxPool{pool: make(map[common.Hash]*types.Transaction)}, pow, blockchain, db, 1, nil)
if err != nil {
t.Fatalf("failed to start test protocol manager: %v", err)
}
pm.Start(1000)
defer pm.Stop()
var peers []*testPeer
for i := 0; i < totalPeers; i++ {
peer, _ := newTestPeer(fmt.Sprintf("peer %d", i), eth63, pm, true)
defer peer.close()
peers = append(peers, peer)
}
chain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 1, func(i int, gen *core.BlockGen) {})
pm.BroadcastBlock(chain[0], true /*propagate*/)
errCh := make(chan error, totalPeers)
doneCh := make(chan struct{}, totalPeers)
for _, peer := range peers {
go func(p *testPeer) {
if err := p2p.ExpectMsg(p.app, NewBlockMsg, &newBlockData{Block: chain[0], TD: big.NewInt(131136)}); err != nil {
errCh <- err
} else {
doneCh <- struct{}{}
}
}(peer)
}
var received int
for {
select {
case <-doneCh:
received++
if received > broadcastExpected {
// We can bail early here
t.Errorf("broadcast count mismatch: have %d > want %d", received, broadcastExpected)
return
}
case <-time.After(2 * time.Second):
if received != broadcastExpected {
t.Errorf("broadcast count mismatch: have %d, want %d", received, broadcastExpected)
}
return
case err = <-errCh:
t.Fatalf("broadcast failed: %v", err)
}
}
}
// Tests that a propagated malformed block (uncles or transactions don't match
// with the hashes in the header) gets discarded and not broadcast forward.
func TestBroadcastMalformedBlock(t *testing.T) {
// Create a live node to test propagation with
var (
engine = ethash.NewFaker()
db = rawdb.NewMemoryDatabase()
config = &params.ChainConfig{}
gspec = &core.Genesis{Config: config}
genesis = gspec.MustCommit(db)
)
blockchain, err := core.NewBlockChain(db, nil, config, engine, vm.Config{}, nil, nil)
if err != nil {
t.Fatalf("failed to create new blockchain: %v", err)
}
pm, err := NewProtocolManager(config, nil, downloader.FullSync, DefaultConfig.NetworkId, new(event.TypeMux), new(testTxPool), engine, blockchain, db, 1, nil)
if err != nil {
t.Fatalf("failed to start test protocol manager: %v", err)
}
pm.Start(2)
defer pm.Stop()
// Create two peers, one to send the malformed block with and one to check
// propagation
source, _ := newTestPeer("source", eth63, pm, true)
defer source.close()
sink, _ := newTestPeer("sink", eth63, pm, true)
defer sink.close()
// Create various combinations of malformed blocks
chain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 1, func(i int, gen *core.BlockGen) {})
malformedUncles := chain[0].Header()
malformedUncles.UncleHash[0]++
malformedTransactions := chain[0].Header()
malformedTransactions.TxHash[0]++
malformedEverything := chain[0].Header()
malformedEverything.UncleHash[0]++
malformedEverything.TxHash[0]++
// Keep listening to broadcasts and notify if any arrives
notify := make(chan struct{}, 1)
go func() {
if _, err := sink.app.ReadMsg(); err == nil {
notify <- struct{}{}
}
}()
// Try to broadcast all malformations and ensure they all get discarded
for _, header := range []*types.Header{malformedUncles, malformedTransactions, malformedEverything} {
block := types.NewBlockWithHeader(header).WithBody(chain[0].Transactions(), chain[0].Uncles())
if err := p2p.Send(source.app, NewBlockMsg, []interface{}{block, big.NewInt(131136)}); err != nil {
t.Fatalf("failed to broadcast block: %v", err)
}
select {
case <-notify:
t.Fatalf("malformed block forwarded")
case <-time.After(100 * time.Millisecond):
}
}
} }

View File

@ -1,231 +0,0 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
// This file contains some shares testing functionality, common to multiple
// different files and modules being tested.
package eth
import (
"crypto/ecdsa"
"crypto/rand"
"fmt"
"math/big"
"sort"
"sync"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus/ethash"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/params"
)
var (
testBankKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
testBank = crypto.PubkeyToAddress(testBankKey.PublicKey)
)
// newTestProtocolManager creates a new protocol manager for testing purposes,
// with the given number of blocks already known, and potential notification
// channels for different events.
func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, ethdb.Database, error) {
var (
evmux = new(event.TypeMux)
engine = ethash.NewFaker()
db = rawdb.NewMemoryDatabase()
gspec = &core.Genesis{
Config: params.TestChainConfig,
Alloc: core.GenesisAlloc{testBank: {Balance: big.NewInt(1000000)}},
}
genesis = gspec.MustCommit(db)
blockchain, _ = core.NewBlockChain(db, nil, gspec.Config, engine, vm.Config{}, nil, nil)
)
chain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, blocks, generator)
if _, err := blockchain.InsertChain(chain); err != nil {
panic(err)
}
pm, err := NewProtocolManager(gspec.Config, nil, mode, DefaultConfig.NetworkId, evmux, &testTxPool{added: newtx, pool: make(map[common.Hash]*types.Transaction)}, engine, blockchain, db, 1, nil)
if err != nil {
return nil, nil, err
}
pm.Start(1000)
return pm, db, nil
}
// newTestProtocolManagerMust creates a new protocol manager for testing purposes,
// with the given number of blocks already known, and potential notification
// channels for different events. In case of an error, the constructor force-
// fails the test.
func newTestProtocolManagerMust(t *testing.T, mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, ethdb.Database) {
pm, db, err := newTestProtocolManager(mode, blocks, generator, newtx)
if err != nil {
t.Fatalf("Failed to create protocol manager: %v", err)
}
return pm, db
}
// testTxPool is a fake, helper transaction pool for testing purposes
type testTxPool struct {
txFeed event.Feed
pool map[common.Hash]*types.Transaction // Hash map of collected transactions
added chan<- []*types.Transaction // Notification channel for new transactions
lock sync.RWMutex // Protects the transaction pool
}
// Has returns an indicator whether txpool has a transaction
// cached with the given hash.
func (p *testTxPool) Has(hash common.Hash) bool {
p.lock.Lock()
defer p.lock.Unlock()
return p.pool[hash] != nil
}
// Get retrieves the transaction from local txpool with given
// tx hash.
func (p *testTxPool) Get(hash common.Hash) *types.Transaction {
p.lock.Lock()
defer p.lock.Unlock()
return p.pool[hash]
}
// AddRemotes appends a batch of transactions to the pool, and notifies any
// listeners if the addition channel is non nil
func (p *testTxPool) AddRemotes(txs []*types.Transaction) []error {
p.lock.Lock()
defer p.lock.Unlock()
for _, tx := range txs {
p.pool[tx.Hash()] = tx
}
if p.added != nil {
p.added <- txs
}
p.txFeed.Send(core.NewTxsEvent{Txs: txs})
return make([]error, len(txs))
}
// Pending returns all the transactions known to the pool
func (p *testTxPool) Pending() (map[common.Address]types.Transactions, error) {
p.lock.RLock()
defer p.lock.RUnlock()
batches := make(map[common.Address]types.Transactions)
for _, tx := range p.pool {
from, _ := types.Sender(types.HomesteadSigner{}, tx)
batches[from] = append(batches[from], tx)
}
for _, batch := range batches {
sort.Sort(types.TxByNonce(batch))
}
return batches, nil
}
func (p *testTxPool) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription {
return p.txFeed.Subscribe(ch)
}
// newTestTransaction create a new dummy transaction.
func newTestTransaction(from *ecdsa.PrivateKey, nonce uint64, datasize int) *types.Transaction {
tx := types.NewTransaction(nonce, common.Address{}, big.NewInt(0), 100000, big.NewInt(0), make([]byte, datasize))
tx, _ = types.SignTx(tx, types.HomesteadSigner{}, from)
return tx
}
// testPeer is a simulated peer to allow testing direct network calls.
type testPeer struct {
net p2p.MsgReadWriter // Network layer reader/writer to simulate remote messaging
app *p2p.MsgPipeRW // Application layer reader/writer to simulate the local side
*peer
}
// newTestPeer creates a new peer registered at the given protocol manager.
func newTestPeer(name string, version int, pm *ProtocolManager, shake bool) (*testPeer, <-chan error) {
// Create a message pipe to communicate through
app, net := p2p.MsgPipe()
// Start the peer on a new thread
var id enode.ID
rand.Read(id[:])
peer := pm.newPeer(version, p2p.NewPeer(id, name, nil), net, pm.txpool.Get)
errc := make(chan error, 1)
go func() { errc <- pm.runPeer(peer) }()
tp := &testPeer{app: app, net: net, peer: peer}
// Execute any implicitly requested handshakes and return
if shake {
var (
genesis = pm.blockchain.Genesis()
head = pm.blockchain.CurrentHeader()
td = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64())
)
forkID := forkid.NewID(pm.blockchain.Config(), pm.blockchain.Genesis().Hash(), pm.blockchain.CurrentHeader().Number.Uint64())
tp.handshake(nil, td, head.Hash(), genesis.Hash(), forkID, forkid.NewFilter(pm.blockchain))
}
return tp, errc
}
// handshake simulates a trivial handshake that expects the same state from the
// remote side as we are simulating locally.
func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter) {
var msg interface{}
switch {
case p.version == eth63:
msg = &statusData63{
ProtocolVersion: uint32(p.version),
NetworkId: DefaultConfig.NetworkId,
TD: td,
CurrentBlock: head,
GenesisBlock: genesis,
}
case p.version >= eth64:
msg = &statusData{
ProtocolVersion: uint32(p.version),
NetworkID: DefaultConfig.NetworkId,
TD: td,
Head: head,
Genesis: genesis,
ForkID: forkID,
}
default:
panic(fmt.Sprintf("unsupported eth protocol version: %d", p.version))
}
if err := p2p.ExpectMsg(p.app, StatusMsg, msg); err != nil {
t.Fatalf("status recv: %v", err)
}
if err := p2p.Send(p.app, StatusMsg, msg); err != nil {
t.Fatalf("status send: %v", err)
}
}
// close terminates the local side of the peer, notifying the remote protocol
// manager of termination.
func (p *testPeer) close() {
p.app.Close()
}

View File

@ -17,806 +17,58 @@
package eth package eth
import ( import (
"errors"
"fmt"
"math/big" "math/big"
"sync" "sync"
"time" "time"
mapset "github.com/deckarep/golang-set" "github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/eth/protocols/snap"
"github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rlp"
) )
var ( // ethPeerInfo represents a short summary of the `eth` sub-protocol metadata known
errClosed = errors.New("peer set is closed")
errAlreadyRegistered = errors.New("peer is already registered")
errNotRegistered = errors.New("peer is not registered")
)
const (
maxKnownTxs = 32768 // Maximum transactions hashes to keep in the known list (prevent DOS)
maxKnownBlocks = 1024 // Maximum block hashes to keep in the known list (prevent DOS)
// maxQueuedTxs is the maximum number of transactions to queue up before dropping
// older broadcasts.
maxQueuedTxs = 4096
// maxQueuedTxAnns is the maximum number of transaction announcements to queue up
// before dropping older announcements.
maxQueuedTxAnns = 4096
// maxQueuedBlocks is the maximum number of block propagations to queue up before
// dropping broadcasts. There's not much point in queueing stale blocks, so a few
// that might cover uncles should be enough.
maxQueuedBlocks = 4
// maxQueuedBlockAnns is the maximum number of block announcements to queue up before
// dropping broadcasts. Similarly to block propagations, there's no point to queue
// above some healthy uncle limit, so use that.
maxQueuedBlockAnns = 4
handshakeTimeout = 5 * time.Second
)
// max is a helper function which returns the larger of the two given integers.
func max(a, b int) int {
if a > b {
return a
}
return b
}
// PeerInfo represents a short summary of the Ethereum sub-protocol metadata known
// about a connected peer. // about a connected peer.
type PeerInfo struct { type ethPeerInfo struct {
Version int `json:"version"` // Ethereum protocol version negotiated Version uint `json:"version"` // Ethereum protocol version negotiated
Difficulty *big.Int `json:"difficulty"` // Total difficulty of the peer's blockchain Difficulty *big.Int `json:"difficulty"` // Total difficulty of the peer's blockchain
Head string `json:"head"` // SHA3 hash of the peer's best owned block Head string `json:"head"` // Hex hash of the peer's best owned block
} }
// propEvent is a block propagation, waiting for its turn in the broadcast queue. // ethPeer is a wrapper around eth.Peer to maintain a few extra metadata.
type propEvent struct { type ethPeer struct {
block *types.Block *eth.Peer
td *big.Int
syncDrop *time.Timer // Connection dropper if `eth` sync progress isn't validated in time
lock sync.RWMutex // Mutex protecting the internal fields
} }
type peer struct { // info gathers and returns some `eth` protocol metadata known about a peer.
id string func (p *ethPeer) info() *ethPeerInfo {
*p2p.Peer
rw p2p.MsgReadWriter
version int // Protocol version negotiated
syncDrop *time.Timer // Timed connection dropper if sync progress isn't validated in time
head common.Hash
td *big.Int
lock sync.RWMutex
knownBlocks mapset.Set // Set of block hashes known to be known by this peer
queuedBlocks chan *propEvent // Queue of blocks to broadcast to the peer
queuedBlockAnns chan *types.Block // Queue of blocks to announce to the peer
knownTxs mapset.Set // Set of transaction hashes known to be known by this peer
txBroadcast chan []common.Hash // Channel used to queue transaction propagation requests
txAnnounce chan []common.Hash // Channel used to queue transaction announcement requests
getPooledTx func(common.Hash) *types.Transaction // Callback used to retrieve transaction from txpool
term chan struct{} // Termination channel to stop the broadcaster
}
func newPeer(version int, p *p2p.Peer, rw p2p.MsgReadWriter, getPooledTx func(hash common.Hash) *types.Transaction) *peer {
return &peer{
Peer: p,
rw: rw,
version: version,
id: fmt.Sprintf("%x", p.ID().Bytes()[:8]),
knownTxs: mapset.NewSet(),
knownBlocks: mapset.NewSet(),
queuedBlocks: make(chan *propEvent, maxQueuedBlocks),
queuedBlockAnns: make(chan *types.Block, maxQueuedBlockAnns),
txBroadcast: make(chan []common.Hash),
txAnnounce: make(chan []common.Hash),
getPooledTx: getPooledTx,
term: make(chan struct{}),
}
}
// broadcastBlocks is a write loop that multiplexes blocks and block accouncements
// to the remote peer. The goal is to have an async writer that does not lock up
// node internals and at the same time rate limits queued data.
func (p *peer) broadcastBlocks(removePeer func(string)) {
for {
select {
case prop := <-p.queuedBlocks:
if err := p.SendNewBlock(prop.block, prop.td); err != nil {
removePeer(p.id)
return
}
p.Log().Trace("Propagated block", "number", prop.block.Number(), "hash", prop.block.Hash(), "td", prop.td)
case block := <-p.queuedBlockAnns:
if err := p.SendNewBlockHashes([]common.Hash{block.Hash()}, []uint64{block.NumberU64()}); err != nil {
removePeer(p.id)
return
}
p.Log().Trace("Announced block", "number", block.Number(), "hash", block.Hash())
case <-p.term:
return
}
}
}
// broadcastTransactions is a write loop that schedules transaction broadcasts
// to the remote peer. The goal is to have an async writer that does not lock up
// node internals and at the same time rate limits queued data.
func (p *peer) broadcastTransactions(removePeer func(string)) {
var (
queue []common.Hash // Queue of hashes to broadcast as full transactions
done chan struct{} // Non-nil if background broadcaster is running
fail = make(chan error, 1) // Channel used to receive network error
)
for {
// If there's no in-flight broadcast running, check if a new one is needed
if done == nil && len(queue) > 0 {
// Pile transaction until we reach our allowed network limit
var (
hashes []common.Hash
txs []*types.Transaction
size common.StorageSize
)
for i := 0; i < len(queue) && size < txsyncPackSize; i++ {
if tx := p.getPooledTx(queue[i]); tx != nil {
txs = append(txs, tx)
size += tx.Size()
}
hashes = append(hashes, queue[i])
}
queue = queue[:copy(queue, queue[len(hashes):])]
// If there's anything available to transfer, fire up an async writer
if len(txs) > 0 {
done = make(chan struct{})
go func() {
if err := p.sendTransactions(txs); err != nil {
fail <- err
return
}
close(done)
p.Log().Trace("Sent transactions", "count", len(txs))
}()
}
}
// Transfer goroutine may or may not have been started, listen for events
select {
case hashes := <-p.txBroadcast:
// New batch of transactions to be broadcast, queue them (with cap)
queue = append(queue, hashes...)
if len(queue) > maxQueuedTxs {
// Fancy copy and resize to ensure buffer doesn't grow indefinitely
queue = queue[:copy(queue, queue[len(queue)-maxQueuedTxs:])]
}
case <-done:
done = nil
case <-fail:
removePeer(p.id)
return
case <-p.term:
return
}
}
}
// announceTransactions is a write loop that schedules transaction broadcasts
// to the remote peer. The goal is to have an async writer that does not lock up
// node internals and at the same time rate limits queued data.
func (p *peer) announceTransactions(removePeer func(string)) {
var (
queue []common.Hash // Queue of hashes to announce as transaction stubs
done chan struct{} // Non-nil if background announcer is running
fail = make(chan error, 1) // Channel used to receive network error
)
for {
// If there's no in-flight announce running, check if a new one is needed
if done == nil && len(queue) > 0 {
// Pile transaction hashes until we reach our allowed network limit
var (
hashes []common.Hash
pending []common.Hash
size common.StorageSize
)
for i := 0; i < len(queue) && size < txsyncPackSize; i++ {
if p.getPooledTx(queue[i]) != nil {
pending = append(pending, queue[i])
size += common.HashLength
}
hashes = append(hashes, queue[i])
}
queue = queue[:copy(queue, queue[len(hashes):])]
// If there's anything available to transfer, fire up an async writer
if len(pending) > 0 {
done = make(chan struct{})
go func() {
if err := p.sendPooledTransactionHashes(pending); err != nil {
fail <- err
return
}
close(done)
p.Log().Trace("Sent transaction announcements", "count", len(pending))
}()
}
}
// Transfer goroutine may or may not have been started, listen for events
select {
case hashes := <-p.txAnnounce:
// New batch of transactions to be broadcast, queue them (with cap)
queue = append(queue, hashes...)
if len(queue) > maxQueuedTxAnns {
// Fancy copy and resize to ensure buffer doesn't grow indefinitely
queue = queue[:copy(queue, queue[len(queue)-maxQueuedTxAnns:])]
}
case <-done:
done = nil
case <-fail:
removePeer(p.id)
return
case <-p.term:
return
}
}
}
// close signals the broadcast goroutine to terminate.
func (p *peer) close() {
close(p.term)
}
// Info gathers and returns a collection of metadata known about a peer.
func (p *peer) Info() *PeerInfo {
hash, td := p.Head() hash, td := p.Head()
return &PeerInfo{ return &ethPeerInfo{
Version: p.version, Version: p.Version(),
Difficulty: td, Difficulty: td,
Head: hash.Hex(), Head: hash.Hex(),
} }
} }
// Head retrieves a copy of the current head hash and total difficulty of the // snapPeerInfo represents a short summary of the `snap` sub-protocol metadata known
// peer. // about a connected peer.
func (p *peer) Head() (hash common.Hash, td *big.Int) { type snapPeerInfo struct {
p.lock.RLock() Version uint `json:"version"` // Snapshot protocol version negotiated
defer p.lock.RUnlock()
copy(hash[:], p.head[:])
return hash, new(big.Int).Set(p.td)
} }
// SetHead updates the head hash and total difficulty of the peer. // snapPeer is a wrapper around snap.Peer to maintain a few extra metadata.
func (p *peer) SetHead(hash common.Hash, td *big.Int) { type snapPeer struct {
p.lock.Lock() *snap.Peer
defer p.lock.Unlock()
copy(p.head[:], hash[:]) ethDrop *time.Timer // Connection dropper if `eth` doesn't connect in time
p.td.Set(td) lock sync.RWMutex // Mutex protecting the internal fields
} }
// MarkBlock marks a block as known for the peer, ensuring that the block will // info gathers and returns some `snap` protocol metadata known about a peer.
// never be propagated to this particular peer. func (p *snapPeer) info() *snapPeerInfo {
func (p *peer) MarkBlock(hash common.Hash) { return &snapPeerInfo{
// If we reached the memory allowance, drop a previously known block hash Version: p.Version(),
for p.knownBlocks.Cardinality() >= maxKnownBlocks {
p.knownBlocks.Pop()
} }
p.knownBlocks.Add(hash)
}
// MarkTransaction marks a transaction as known for the peer, ensuring that it
// will never be propagated to this particular peer.
func (p *peer) MarkTransaction(hash common.Hash) {
// If we reached the memory allowance, drop a previously known transaction hash
for p.knownTxs.Cardinality() >= maxKnownTxs {
p.knownTxs.Pop()
}
p.knownTxs.Add(hash)
}
// SendTransactions64 sends transactions to the peer and includes the hashes
// in its transaction hash set for future reference.
//
// This method is legacy support for initial transaction exchange in eth/64 and
// prior. For eth/65 and higher use SendPooledTransactionHashes.
func (p *peer) SendTransactions64(txs types.Transactions) error {
return p.sendTransactions(txs)
}
// sendTransactions sends transactions to the peer and includes the hashes
// in its transaction hash set for future reference.
//
// This method is a helper used by the async transaction sender. Don't call it
// directly as the queueing (memory) and transmission (bandwidth) costs should
// not be managed directly.
func (p *peer) sendTransactions(txs types.Transactions) error {
// Mark all the transactions as known, but ensure we don't overflow our limits
for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(txs)) {
p.knownTxs.Pop()
}
for _, tx := range txs {
p.knownTxs.Add(tx.Hash())
}
return p2p.Send(p.rw, TransactionMsg, txs)
}
// AsyncSendTransactions queues a list of transactions (by hash) to eventually
// propagate to a remote peer. The number of pending sends are capped (new ones
// will force old sends to be dropped)
func (p *peer) AsyncSendTransactions(hashes []common.Hash) {
select {
case p.txBroadcast <- hashes:
// Mark all the transactions as known, but ensure we don't overflow our limits
for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
p.knownTxs.Pop()
}
for _, hash := range hashes {
p.knownTxs.Add(hash)
}
case <-p.term:
p.Log().Debug("Dropping transaction propagation", "count", len(hashes))
}
}
// sendPooledTransactionHashes sends transaction hashes to the peer and includes
// them in its transaction hash set for future reference.
//
// This method is a helper used by the async transaction announcer. Don't call it
// directly as the queueing (memory) and transmission (bandwidth) costs should
// not be managed directly.
func (p *peer) sendPooledTransactionHashes(hashes []common.Hash) error {
// Mark all the transactions as known, but ensure we don't overflow our limits
for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
p.knownTxs.Pop()
}
for _, hash := range hashes {
p.knownTxs.Add(hash)
}
return p2p.Send(p.rw, NewPooledTransactionHashesMsg, hashes)
}
// AsyncSendPooledTransactionHashes queues a list of transactions hashes to eventually
// announce to a remote peer. The number of pending sends are capped (new ones
// will force old sends to be dropped)
func (p *peer) AsyncSendPooledTransactionHashes(hashes []common.Hash) {
select {
case p.txAnnounce <- hashes:
// Mark all the transactions as known, but ensure we don't overflow our limits
for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
p.knownTxs.Pop()
}
for _, hash := range hashes {
p.knownTxs.Add(hash)
}
case <-p.term:
p.Log().Debug("Dropping transaction announcement", "count", len(hashes))
}
}
// SendPooledTransactionsRLP sends requested transactions to the peer and adds the
// hashes in its transaction hash set for future reference.
//
// Note, the method assumes the hashes are correct and correspond to the list of
// transactions being sent.
func (p *peer) SendPooledTransactionsRLP(hashes []common.Hash, txs []rlp.RawValue) error {
// Mark all the transactions as known, but ensure we don't overflow our limits
for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
p.knownTxs.Pop()
}
for _, hash := range hashes {
p.knownTxs.Add(hash)
}
return p2p.Send(p.rw, PooledTransactionsMsg, txs)
}
// SendNewBlockHashes announces the availability of a number of blocks through
// a hash notification.
func (p *peer) SendNewBlockHashes(hashes []common.Hash, numbers []uint64) error {
// Mark all the block hashes as known, but ensure we don't overflow our limits
for p.knownBlocks.Cardinality() > max(0, maxKnownBlocks-len(hashes)) {
p.knownBlocks.Pop()
}
for _, hash := range hashes {
p.knownBlocks.Add(hash)
}
request := make(newBlockHashesData, len(hashes))
for i := 0; i < len(hashes); i++ {
request[i].Hash = hashes[i]
request[i].Number = numbers[i]
}
return p2p.Send(p.rw, NewBlockHashesMsg, request)
}
// AsyncSendNewBlockHash queues the availability of a block for propagation to a
// remote peer. If the peer's broadcast queue is full, the event is silently
// dropped.
func (p *peer) AsyncSendNewBlockHash(block *types.Block) {
select {
case p.queuedBlockAnns <- block:
// Mark all the block hash as known, but ensure we don't overflow our limits
for p.knownBlocks.Cardinality() >= maxKnownBlocks {
p.knownBlocks.Pop()
}
p.knownBlocks.Add(block.Hash())
default:
p.Log().Debug("Dropping block announcement", "number", block.NumberU64(), "hash", block.Hash())
}
}
// SendNewBlock propagates an entire block to a remote peer.
func (p *peer) SendNewBlock(block *types.Block, td *big.Int) error {
// Mark all the block hash as known, but ensure we don't overflow our limits
for p.knownBlocks.Cardinality() >= maxKnownBlocks {
p.knownBlocks.Pop()
}
p.knownBlocks.Add(block.Hash())
return p2p.Send(p.rw, NewBlockMsg, []interface{}{block, td})
}
// AsyncSendNewBlock queues an entire block for propagation to a remote peer. If
// the peer's broadcast queue is full, the event is silently dropped.
func (p *peer) AsyncSendNewBlock(block *types.Block, td *big.Int) {
select {
case p.queuedBlocks <- &propEvent{block: block, td: td}:
// Mark all the block hash as known, but ensure we don't overflow our limits
for p.knownBlocks.Cardinality() >= maxKnownBlocks {
p.knownBlocks.Pop()
}
p.knownBlocks.Add(block.Hash())
default:
p.Log().Debug("Dropping block propagation", "number", block.NumberU64(), "hash", block.Hash())
}
}
// SendBlockHeaders sends a batch of block headers to the remote peer.
func (p *peer) SendBlockHeaders(headers []*types.Header) error {
return p2p.Send(p.rw, BlockHeadersMsg, headers)
}
// SendBlockBodies sends a batch of block contents to the remote peer.
func (p *peer) SendBlockBodies(bodies []*blockBody) error {
return p2p.Send(p.rw, BlockBodiesMsg, blockBodiesData(bodies))
}
// SendBlockBodiesRLP sends a batch of block contents to the remote peer from
// an already RLP encoded format.
func (p *peer) SendBlockBodiesRLP(bodies []rlp.RawValue) error {
return p2p.Send(p.rw, BlockBodiesMsg, bodies)
}
// SendNodeDataRLP sends a batch of arbitrary internal data, corresponding to the
// hashes requested.
func (p *peer) SendNodeData(data [][]byte) error {
return p2p.Send(p.rw, NodeDataMsg, data)
}
// SendReceiptsRLP sends a batch of transaction receipts, corresponding to the
// ones requested from an already RLP encoded format.
func (p *peer) SendReceiptsRLP(receipts []rlp.RawValue) error {
return p2p.Send(p.rw, ReceiptsMsg, receipts)
}
// RequestOneHeader is a wrapper around the header query functions to fetch a
// single header. It is used solely by the fetcher.
func (p *peer) RequestOneHeader(hash common.Hash) error {
p.Log().Debug("Fetching single header", "hash", hash)
return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: uint64(1), Skip: uint64(0), Reverse: false})
}
// RequestHeadersByHash fetches a batch of blocks' headers corresponding to the
// specified header query, based on the hash of an origin block.
func (p *peer) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error {
p.Log().Debug("Fetching batch of headers", "count", amount, "fromhash", origin, "skip", skip, "reverse", reverse)
return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse})
}
// RequestHeadersByNumber fetches a batch of blocks' headers corresponding to the
// specified header query, based on the number of an origin block.
func (p *peer) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error {
p.Log().Debug("Fetching batch of headers", "count", amount, "fromnum", origin, "skip", skip, "reverse", reverse)
return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse})
}
// RequestBodies fetches a batch of blocks' bodies corresponding to the hashes
// specified.
func (p *peer) RequestBodies(hashes []common.Hash) error {
p.Log().Debug("Fetching batch of block bodies", "count", len(hashes))
return p2p.Send(p.rw, GetBlockBodiesMsg, hashes)
}
// RequestNodeData fetches a batch of arbitrary data from a node's known state
// data, corresponding to the specified hashes.
func (p *peer) RequestNodeData(hashes []common.Hash) error {
p.Log().Debug("Fetching batch of state data", "count", len(hashes))
return p2p.Send(p.rw, GetNodeDataMsg, hashes)
}
// RequestReceipts fetches a batch of transaction receipts from a remote node.
func (p *peer) RequestReceipts(hashes []common.Hash) error {
p.Log().Debug("Fetching batch of receipts", "count", len(hashes))
return p2p.Send(p.rw, GetReceiptsMsg, hashes)
}
// RequestTxs fetches a batch of transactions from a remote node.
func (p *peer) RequestTxs(hashes []common.Hash) error {
p.Log().Debug("Fetching batch of transactions", "count", len(hashes))
return p2p.Send(p.rw, GetPooledTransactionsMsg, hashes)
}
// Handshake executes the eth protocol handshake, negotiating version number,
// network IDs, difficulties, head and genesis blocks.
func (p *peer) Handshake(network uint64, td *big.Int, head common.Hash, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter) error {
// Send out own handshake in a new thread
errc := make(chan error, 2)
var (
status63 statusData63 // safe to read after two values have been received from errc
status statusData // safe to read after two values have been received from errc
)
go func() {
switch {
case p.version == eth63:
errc <- p2p.Send(p.rw, StatusMsg, &statusData63{
ProtocolVersion: uint32(p.version),
NetworkId: network,
TD: td,
CurrentBlock: head,
GenesisBlock: genesis,
})
case p.version >= eth64:
errc <- p2p.Send(p.rw, StatusMsg, &statusData{
ProtocolVersion: uint32(p.version),
NetworkID: network,
TD: td,
Head: head,
Genesis: genesis,
ForkID: forkID,
})
default:
panic(fmt.Sprintf("unsupported eth protocol version: %d", p.version))
}
}()
go func() {
switch {
case p.version == eth63:
errc <- p.readStatusLegacy(network, &status63, genesis)
case p.version >= eth64:
errc <- p.readStatus(network, &status, genesis, forkFilter)
default:
panic(fmt.Sprintf("unsupported eth protocol version: %d", p.version))
}
}()
timeout := time.NewTimer(handshakeTimeout)
defer timeout.Stop()
for i := 0; i < 2; i++ {
select {
case err := <-errc:
if err != nil {
return err
}
case <-timeout.C:
return p2p.DiscReadTimeout
}
}
switch {
case p.version == eth63:
p.td, p.head = status63.TD, status63.CurrentBlock
case p.version >= eth64:
p.td, p.head = status.TD, status.Head
default:
panic(fmt.Sprintf("unsupported eth protocol version: %d", p.version))
}
return nil
}
func (p *peer) readStatusLegacy(network uint64, status *statusData63, genesis common.Hash) error {
msg, err := p.rw.ReadMsg()
if err != nil {
return err
}
if msg.Code != StatusMsg {
return errResp(ErrNoStatusMsg, "first msg has code %x (!= %x)", msg.Code, StatusMsg)
}
if msg.Size > protocolMaxMsgSize {
return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, protocolMaxMsgSize)
}
// Decode the handshake and make sure everything matches
if err := msg.Decode(&status); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
if status.GenesisBlock != genesis {
return errResp(ErrGenesisMismatch, "%x (!= %x)", status.GenesisBlock[:8], genesis[:8])
}
if status.NetworkId != network {
return errResp(ErrNetworkIDMismatch, "%d (!= %d)", status.NetworkId, network)
}
if int(status.ProtocolVersion) != p.version {
return errResp(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, p.version)
}
return nil
}
func (p *peer) readStatus(network uint64, status *statusData, genesis common.Hash, forkFilter forkid.Filter) error {
msg, err := p.rw.ReadMsg()
if err != nil {
return err
}
if msg.Code != StatusMsg {
return errResp(ErrNoStatusMsg, "first msg has code %x (!= %x)", msg.Code, StatusMsg)
}
if msg.Size > protocolMaxMsgSize {
return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, protocolMaxMsgSize)
}
// Decode the handshake and make sure everything matches
if err := msg.Decode(&status); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
if status.NetworkID != network {
return errResp(ErrNetworkIDMismatch, "%d (!= %d)", status.NetworkID, network)
}
if int(status.ProtocolVersion) != p.version {
return errResp(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, p.version)
}
if status.Genesis != genesis {
return errResp(ErrGenesisMismatch, "%x (!= %x)", status.Genesis, genesis)
}
if err := forkFilter(status.ForkID); err != nil {
return errResp(ErrForkIDRejected, "%v", err)
}
return nil
}
// String implements fmt.Stringer.
func (p *peer) String() string {
return fmt.Sprintf("Peer %s [%s]", p.id,
fmt.Sprintf("eth/%2d", p.version),
)
}
// peerSet represents the collection of active peers currently participating in
// the Ethereum sub-protocol.
type peerSet struct {
peers map[string]*peer
lock sync.RWMutex
closed bool
}
// newPeerSet creates a new peer set to track the active participants.
func newPeerSet() *peerSet {
return &peerSet{
peers: make(map[string]*peer),
}
}
// Register injects a new peer into the working set, or returns an error if the
// peer is already known. If a new peer it registered, its broadcast loop is also
// started.
func (ps *peerSet) Register(p *peer, removePeer func(string)) error {
ps.lock.Lock()
defer ps.lock.Unlock()
if ps.closed {
return errClosed
}
if _, ok := ps.peers[p.id]; ok {
return errAlreadyRegistered
}
ps.peers[p.id] = p
go p.broadcastBlocks(removePeer)
go p.broadcastTransactions(removePeer)
if p.version >= eth65 {
go p.announceTransactions(removePeer)
}
return nil
}
// Unregister removes a remote peer from the active set, disabling any further
// actions to/from that particular entity.
func (ps *peerSet) Unregister(id string) error {
ps.lock.Lock()
defer ps.lock.Unlock()
p, ok := ps.peers[id]
if !ok {
return errNotRegistered
}
delete(ps.peers, id)
p.close()
return nil
}
// Peer retrieves the registered peer with the given id.
func (ps *peerSet) Peer(id string) *peer {
ps.lock.RLock()
defer ps.lock.RUnlock()
return ps.peers[id]
}
// Len returns if the current number of peers in the set.
func (ps *peerSet) Len() int {
ps.lock.RLock()
defer ps.lock.RUnlock()
return len(ps.peers)
}
// PeersWithoutBlock retrieves a list of peers that do not have a given block in
// their set of known hashes.
func (ps *peerSet) PeersWithoutBlock(hash common.Hash) []*peer {
ps.lock.RLock()
defer ps.lock.RUnlock()
list := make([]*peer, 0, len(ps.peers))
for _, p := range ps.peers {
if !p.knownBlocks.Contains(hash) {
list = append(list, p)
}
}
return list
}
// PeersWithoutTx retrieves a list of peers that do not have a given transaction
// in their set of known hashes.
func (ps *peerSet) PeersWithoutTx(hash common.Hash) []*peer {
ps.lock.RLock()
defer ps.lock.RUnlock()
list := make([]*peer, 0, len(ps.peers))
for _, p := range ps.peers {
if !p.knownTxs.Contains(hash) {
list = append(list, p)
}
}
return list
}
// BestPeer retrieves the known peer with the currently highest total difficulty.
func (ps *peerSet) BestPeer() *peer {
ps.lock.RLock()
defer ps.lock.RUnlock()
var (
bestPeer *peer
bestTd *big.Int
)
for _, p := range ps.peers {
if _, td := p.Head(); bestPeer == nil || td.Cmp(bestTd) > 0 {
bestPeer, bestTd = p, td
}
}
return bestPeer
}
// Close disconnects all peers.
// No new peers can be registered after Close has returned.
func (ps *peerSet) Close() {
ps.lock.Lock()
defer ps.lock.Unlock()
for _, p := range ps.peers {
p.Disconnect(p2p.DiscQuitting)
}
ps.closed = true
} }

301
eth/peerset.go Normal file
View File

@ -0,0 +1,301 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package eth
import (
"errors"
"math/big"
"sync"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/eth/protocols/snap"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/p2p"
)
var (
// errPeerSetClosed is returned if a peer is attempted to be added or removed
// from the peer set after it has been terminated.
errPeerSetClosed = errors.New("peerset closed")
// errPeerAlreadyRegistered is returned if a peer is attempted to be added
// to the peer set, but one with the same id already exists.
errPeerAlreadyRegistered = errors.New("peer already registered")
// errPeerNotRegistered is returned if a peer is attempted to be removed from
// a peer set, but no peer with the given id exists.
errPeerNotRegistered = errors.New("peer not registered")
// ethConnectTimeout is the `snap` timeout for `eth` to connect too.
ethConnectTimeout = 3 * time.Second
)
// peerSet represents the collection of active peers currently participating in
// the `eth` or `snap` protocols.
type peerSet struct {
ethPeers map[string]*ethPeer // Peers connected on the `eth` protocol
snapPeers map[string]*snapPeer // Peers connected on the `snap` protocol
ethJoinFeed event.Feed // Events when an `eth` peer successfully joins
ethDropFeed event.Feed // Events when an `eth` peer gets dropped
snapJoinFeed event.Feed // Events when a `snap` peer joins on both `eth` and `snap`
snapDropFeed event.Feed // Events when a `snap` peer gets dropped (only if fully joined)
scope event.SubscriptionScope // Subscription group to unsubscribe everyone at once
lock sync.RWMutex
closed bool
}
// newPeerSet creates a new peer set to track the active participants.
func newPeerSet() *peerSet {
return &peerSet{
ethPeers: make(map[string]*ethPeer),
snapPeers: make(map[string]*snapPeer),
}
}
// subscribeEthJoin registers a subscription for peers joining (and completing
// the handshake) on the `eth` protocol.
func (ps *peerSet) subscribeEthJoin(ch chan<- *eth.Peer) event.Subscription {
return ps.scope.Track(ps.ethJoinFeed.Subscribe(ch))
}
// subscribeEthDrop registers a subscription for peers being dropped from the
// `eth` protocol.
func (ps *peerSet) subscribeEthDrop(ch chan<- *eth.Peer) event.Subscription {
return ps.scope.Track(ps.ethDropFeed.Subscribe(ch))
}
// subscribeSnapJoin registers a subscription for peers joining (and completing
// the `eth` join) on the `snap` protocol.
func (ps *peerSet) subscribeSnapJoin(ch chan<- *snap.Peer) event.Subscription {
return ps.scope.Track(ps.snapJoinFeed.Subscribe(ch))
}
// subscribeSnapDrop registers a subscription for peers being dropped from the
// `snap` protocol.
func (ps *peerSet) subscribeSnapDrop(ch chan<- *snap.Peer) event.Subscription {
return ps.scope.Track(ps.snapDropFeed.Subscribe(ch))
}
// registerEthPeer injects a new `eth` peer into the working set, or returns an
// error if the peer is already known. The peer is announced on the `eth` join
// feed and if it completes a pending `snap` peer, also on that feed.
func (ps *peerSet) registerEthPeer(peer *eth.Peer) error {
ps.lock.Lock()
if ps.closed {
ps.lock.Unlock()
return errPeerSetClosed
}
id := peer.ID()
if _, ok := ps.ethPeers[id]; ok {
ps.lock.Unlock()
return errPeerAlreadyRegistered
}
ps.ethPeers[id] = &ethPeer{Peer: peer}
snap, ok := ps.snapPeers[id]
ps.lock.Unlock()
if ok {
// Previously dangling `snap` peer, stop it's timer since `eth` connected
snap.lock.Lock()
if snap.ethDrop != nil {
snap.ethDrop.Stop()
snap.ethDrop = nil
}
snap.lock.Unlock()
}
ps.ethJoinFeed.Send(peer)
if ok {
ps.snapJoinFeed.Send(snap.Peer)
}
return nil
}
// unregisterEthPeer removes a remote peer from the active set, disabling any further
// actions to/from that particular entity. The drop is announced on the `eth` drop
// feed and also on the `snap` feed if the eth/snap duality was broken just now.
func (ps *peerSet) unregisterEthPeer(id string) error {
ps.lock.Lock()
eth, ok := ps.ethPeers[id]
if !ok {
ps.lock.Unlock()
return errPeerNotRegistered
}
delete(ps.ethPeers, id)
snap, ok := ps.snapPeers[id]
ps.lock.Unlock()
ps.ethDropFeed.Send(eth)
if ok {
ps.snapDropFeed.Send(snap)
}
return nil
}
// registerSnapPeer injects a new `snap` peer into the working set, or returns
// an error if the peer is already known. The peer is announced on the `snap`
// join feed if it completes an existing `eth` peer.
//
// If the peer isn't yet connected on `eth` and fails to do so within a given
// amount of time, it is dropped. This enforces that `snap` is an extension to
// `eth`, not a standalone leeching protocol.
func (ps *peerSet) registerSnapPeer(peer *snap.Peer) error {
ps.lock.Lock()
if ps.closed {
ps.lock.Unlock()
return errPeerSetClosed
}
id := peer.ID()
if _, ok := ps.snapPeers[id]; ok {
ps.lock.Unlock()
return errPeerAlreadyRegistered
}
ps.snapPeers[id] = &snapPeer{Peer: peer}
_, ok := ps.ethPeers[id]
if !ok {
// Dangling `snap` peer, start a timer to drop if `eth` doesn't connect
ps.snapPeers[id].ethDrop = time.AfterFunc(ethConnectTimeout, func() {
peer.Log().Warn("Snapshot peer missing eth, dropping", "addr", peer.RemoteAddr(), "type", peer.Name())
peer.Disconnect(p2p.DiscUselessPeer)
})
}
ps.lock.Unlock()
if ok {
ps.snapJoinFeed.Send(peer)
}
return nil
}
// unregisterSnapPeer removes a remote peer from the active set, disabling any
// further actions to/from that particular entity. The drop is announced on the
// `snap` drop feed.
func (ps *peerSet) unregisterSnapPeer(id string) error {
ps.lock.Lock()
peer, ok := ps.snapPeers[id]
if !ok {
ps.lock.Unlock()
return errPeerNotRegistered
}
delete(ps.snapPeers, id)
ps.lock.Unlock()
peer.lock.Lock()
if peer.ethDrop != nil {
peer.ethDrop.Stop()
peer.ethDrop = nil
}
peer.lock.Unlock()
ps.snapDropFeed.Send(peer)
return nil
}
// ethPeer retrieves the registered `eth` peer with the given id.
func (ps *peerSet) ethPeer(id string) *ethPeer {
ps.lock.RLock()
defer ps.lock.RUnlock()
return ps.ethPeers[id]
}
// snapPeer retrieves the registered `snap` peer with the given id.
func (ps *peerSet) snapPeer(id string) *snapPeer {
ps.lock.RLock()
defer ps.lock.RUnlock()
return ps.snapPeers[id]
}
// ethPeersWithoutBlock retrieves a list of `eth` peers that do not have a given
// block in their set of known hashes so it might be propagated to them.
func (ps *peerSet) ethPeersWithoutBlock(hash common.Hash) []*ethPeer {
ps.lock.RLock()
defer ps.lock.RUnlock()
list := make([]*ethPeer, 0, len(ps.ethPeers))
for _, p := range ps.ethPeers {
if !p.KnownBlock(hash) {
list = append(list, p)
}
}
return list
}
// ethPeersWithoutTransacion retrieves a list of `eth` peers that do not have a
// given transaction in their set of known hashes.
func (ps *peerSet) ethPeersWithoutTransacion(hash common.Hash) []*ethPeer {
ps.lock.RLock()
defer ps.lock.RUnlock()
list := make([]*ethPeer, 0, len(ps.ethPeers))
for _, p := range ps.ethPeers {
if !p.KnownTransaction(hash) {
list = append(list, p)
}
}
return list
}
// Len returns if the current number of `eth` peers in the set. Since the `snap`
// peers are tied to the existnce of an `eth` connection, that will always be a
// subset of `eth`.
func (ps *peerSet) Len() int {
ps.lock.RLock()
defer ps.lock.RUnlock()
return len(ps.ethPeers)
}
// ethPeerWithHighestTD retrieves the known peer with the currently highest total
// difficulty.
func (ps *peerSet) ethPeerWithHighestTD() *eth.Peer {
ps.lock.RLock()
defer ps.lock.RUnlock()
var (
bestPeer *eth.Peer
bestTd *big.Int
)
for _, p := range ps.ethPeers {
if _, td := p.Head(); bestPeer == nil || td.Cmp(bestTd) > 0 {
bestPeer, bestTd = p.Peer, td
}
}
return bestPeer
}
// close disconnects all peers.
func (ps *peerSet) close() {
ps.lock.Lock()
defer ps.lock.Unlock()
for _, p := range ps.ethPeers {
p.Disconnect(p2p.DiscQuitting)
}
for _, p := range ps.snapPeers {
p.Disconnect(p2p.DiscQuitting)
}
ps.closed = true
}

View File

@ -1,221 +0,0 @@
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package eth
import (
"fmt"
"io"
"math/big"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/rlp"
)
// Constants to match up protocol versions and messages
const (
eth63 = 63
eth64 = 64
eth65 = 65
)
// protocolName is the official short name of the protocol used during capability negotiation.
const protocolName = "eth"
// ProtocolVersions are the supported versions of the eth protocol (first is primary).
var ProtocolVersions = []uint{eth65, eth64, eth63}
// protocolLengths are the number of implemented message corresponding to different protocol versions.
var protocolLengths = map[uint]uint64{eth65: 17, eth64: 17, eth63: 17}
const protocolMaxMsgSize = 10 * 1024 * 1024 // Maximum cap on the size of a protocol message
// eth protocol message codes
const (
StatusMsg = 0x00
NewBlockHashesMsg = 0x01
TransactionMsg = 0x02
GetBlockHeadersMsg = 0x03
BlockHeadersMsg = 0x04
GetBlockBodiesMsg = 0x05
BlockBodiesMsg = 0x06
NewBlockMsg = 0x07
GetNodeDataMsg = 0x0d
NodeDataMsg = 0x0e
GetReceiptsMsg = 0x0f
ReceiptsMsg = 0x10
// New protocol message codes introduced in eth65
//
// Previously these message ids were used by some legacy and unsupported
// eth protocols, reown them here.
NewPooledTransactionHashesMsg = 0x08
GetPooledTransactionsMsg = 0x09
PooledTransactionsMsg = 0x0a
)
type errCode int
const (
ErrMsgTooLarge = iota
ErrDecode
ErrInvalidMsgCode
ErrProtocolVersionMismatch
ErrNetworkIDMismatch
ErrGenesisMismatch
ErrForkIDRejected
ErrNoStatusMsg
ErrExtraStatusMsg
)
func (e errCode) String() string {
return errorToString[int(e)]
}
// XXX change once legacy code is out
var errorToString = map[int]string{
ErrMsgTooLarge: "Message too long",
ErrDecode: "Invalid message",
ErrInvalidMsgCode: "Invalid message code",
ErrProtocolVersionMismatch: "Protocol version mismatch",
ErrNetworkIDMismatch: "Network ID mismatch",
ErrGenesisMismatch: "Genesis mismatch",
ErrForkIDRejected: "Fork ID rejected",
ErrNoStatusMsg: "No status message",
ErrExtraStatusMsg: "Extra status message",
}
type txPool interface {
// Has returns an indicator whether txpool has a transaction
// cached with the given hash.
Has(hash common.Hash) bool
// Get retrieves the transaction from local txpool with given
// tx hash.
Get(hash common.Hash) *types.Transaction
// AddRemotes should add the given transactions to the pool.
AddRemotes([]*types.Transaction) []error
// Pending should return pending transactions.
// The slice should be modifiable by the caller.
Pending() (map[common.Address]types.Transactions, error)
// SubscribeNewTxsEvent should return an event subscription of
// NewTxsEvent and send events to the given channel.
SubscribeNewTxsEvent(chan<- core.NewTxsEvent) event.Subscription
}
// statusData63 is the network packet for the status message for eth/63.
type statusData63 struct {
ProtocolVersion uint32
NetworkId uint64
TD *big.Int
CurrentBlock common.Hash
GenesisBlock common.Hash
}
// statusData is the network packet for the status message for eth/64 and later.
type statusData struct {
ProtocolVersion uint32
NetworkID uint64
TD *big.Int
Head common.Hash
Genesis common.Hash
ForkID forkid.ID
}
// newBlockHashesData is the network packet for the block announcements.
type newBlockHashesData []struct {
Hash common.Hash // Hash of one particular block being announced
Number uint64 // Number of one particular block being announced
}
// getBlockHeadersData represents a block header query.
type getBlockHeadersData struct {
Origin hashOrNumber // Block from which to retrieve headers
Amount uint64 // Maximum number of headers to retrieve
Skip uint64 // Blocks to skip between consecutive headers
Reverse bool // Query direction (false = rising towards latest, true = falling towards genesis)
}
// hashOrNumber is a combined field for specifying an origin block.
type hashOrNumber struct {
Hash common.Hash // Block hash from which to retrieve headers (excludes Number)
Number uint64 // Block hash from which to retrieve headers (excludes Hash)
}
// EncodeRLP is a specialized encoder for hashOrNumber to encode only one of the
// two contained union fields.
func (hn *hashOrNumber) EncodeRLP(w io.Writer) error {
if hn.Hash == (common.Hash{}) {
return rlp.Encode(w, hn.Number)
}
if hn.Number != 0 {
return fmt.Errorf("both origin hash (%x) and number (%d) provided", hn.Hash, hn.Number)
}
return rlp.Encode(w, hn.Hash)
}
// DecodeRLP is a specialized decoder for hashOrNumber to decode the contents
// into either a block hash or a block number.
func (hn *hashOrNumber) DecodeRLP(s *rlp.Stream) error {
_, size, _ := s.Kind()
origin, err := s.Raw()
if err == nil {
switch {
case size == 32:
err = rlp.DecodeBytes(origin, &hn.Hash)
case size <= 8:
err = rlp.DecodeBytes(origin, &hn.Number)
default:
err = fmt.Errorf("invalid input size %d for origin", size)
}
}
return err
}
// newBlockData is the network packet for the block propagation message.
type newBlockData struct {
Block *types.Block
TD *big.Int
}
// sanityCheck verifies that the values are reasonable, as a DoS protection
func (request *newBlockData) sanityCheck() error {
if err := request.Block.SanityCheck(); err != nil {
return err
}
//TD at mainnet block #7753254 is 76 bits. If it becomes 100 million times
// larger, it will still fit within 100 bits
if tdlen := request.TD.BitLen(); tdlen > 100 {
return fmt.Errorf("too large block TD: bitlen %d", tdlen)
}
return nil
}
// blockBody represents the data content of a single block.
type blockBody struct {
Transactions []*types.Transaction // Transactions contained within a block
Uncles []*types.Header // Uncles contained within a block
}
// blockBodiesData is the network packet for block content distribution.
type blockBodiesData []*blockBody

View File

@ -1,459 +0,0 @@
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package eth
import (
"fmt"
"math/big"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus/ethash"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rlp"
)
func init() {
// log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(false))))
}
var testAccount, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
// Tests that handshake failures are detected and reported correctly.
func TestStatusMsgErrors63(t *testing.T) {
pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
var (
genesis = pm.blockchain.Genesis()
head = pm.blockchain.CurrentHeader()
td = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64())
)
defer pm.Stop()
tests := []struct {
code uint64
data interface{}
wantError error
}{
{
code: TransactionMsg, data: []interface{}{},
wantError: errResp(ErrNoStatusMsg, "first msg has code 2 (!= 0)"),
},
{
code: StatusMsg, data: statusData63{10, DefaultConfig.NetworkId, td, head.Hash(), genesis.Hash()},
wantError: errResp(ErrProtocolVersionMismatch, "10 (!= %d)", 63),
},
{
code: StatusMsg, data: statusData63{63, 999, td, head.Hash(), genesis.Hash()},
wantError: errResp(ErrNetworkIDMismatch, "999 (!= %d)", DefaultConfig.NetworkId),
},
{
code: StatusMsg, data: statusData63{63, DefaultConfig.NetworkId, td, head.Hash(), common.Hash{3}},
wantError: errResp(ErrGenesisMismatch, "0300000000000000 (!= %x)", genesis.Hash().Bytes()[:8]),
},
}
for i, test := range tests {
p, errc := newTestPeer("peer", 63, pm, false)
// The send call might hang until reset because
// the protocol might not read the payload.
go p2p.Send(p.app, test.code, test.data)
select {
case err := <-errc:
if err == nil {
t.Errorf("test %d: protocol returned nil error, want %q", i, test.wantError)
} else if err.Error() != test.wantError.Error() {
t.Errorf("test %d: wrong error: got %q, want %q", i, err, test.wantError)
}
case <-time.After(2 * time.Second):
t.Errorf("protocol did not shut down within 2 seconds")
}
p.close()
}
}
func TestStatusMsgErrors64(t *testing.T) {
pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
var (
genesis = pm.blockchain.Genesis()
head = pm.blockchain.CurrentHeader()
td = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64())
forkID = forkid.NewID(pm.blockchain.Config(), pm.blockchain.Genesis().Hash(), pm.blockchain.CurrentHeader().Number.Uint64())
)
defer pm.Stop()
tests := []struct {
code uint64
data interface{}
wantError error
}{
{
code: TransactionMsg, data: []interface{}{},
wantError: errResp(ErrNoStatusMsg, "first msg has code 2 (!= 0)"),
},
{
code: StatusMsg, data: statusData{10, DefaultConfig.NetworkId, td, head.Hash(), genesis.Hash(), forkID},
wantError: errResp(ErrProtocolVersionMismatch, "10 (!= %d)", 64),
},
{
code: StatusMsg, data: statusData{64, 999, td, head.Hash(), genesis.Hash(), forkID},
wantError: errResp(ErrNetworkIDMismatch, "999 (!= %d)", DefaultConfig.NetworkId),
},
{
code: StatusMsg, data: statusData{64, DefaultConfig.NetworkId, td, head.Hash(), common.Hash{3}, forkID},
wantError: errResp(ErrGenesisMismatch, "0300000000000000000000000000000000000000000000000000000000000000 (!= %x)", genesis.Hash()),
},
{
code: StatusMsg, data: statusData{64, DefaultConfig.NetworkId, td, head.Hash(), genesis.Hash(), forkid.ID{Hash: [4]byte{0x00, 0x01, 0x02, 0x03}}},
wantError: errResp(ErrForkIDRejected, forkid.ErrLocalIncompatibleOrStale.Error()),
},
}
for i, test := range tests {
p, errc := newTestPeer("peer", 64, pm, false)
// The send call might hang until reset because
// the protocol might not read the payload.
go p2p.Send(p.app, test.code, test.data)
select {
case err := <-errc:
if err == nil {
t.Errorf("test %d: protocol returned nil error, want %q", i, test.wantError)
} else if err.Error() != test.wantError.Error() {
t.Errorf("test %d: wrong error: got %q, want %q", i, err, test.wantError)
}
case <-time.After(2 * time.Second):
t.Errorf("protocol did not shut down within 2 seconds")
}
p.close()
}
}
func TestForkIDSplit(t *testing.T) {
var (
engine = ethash.NewFaker()
configNoFork = &params.ChainConfig{HomesteadBlock: big.NewInt(1)}
configProFork = &params.ChainConfig{
HomesteadBlock: big.NewInt(1),
EIP150Block: big.NewInt(2),
EIP155Block: big.NewInt(2),
EIP158Block: big.NewInt(2),
ByzantiumBlock: big.NewInt(3),
}
dbNoFork = rawdb.NewMemoryDatabase()
dbProFork = rawdb.NewMemoryDatabase()
gspecNoFork = &core.Genesis{Config: configNoFork}
gspecProFork = &core.Genesis{Config: configProFork}
genesisNoFork = gspecNoFork.MustCommit(dbNoFork)
genesisProFork = gspecProFork.MustCommit(dbProFork)
chainNoFork, _ = core.NewBlockChain(dbNoFork, nil, configNoFork, engine, vm.Config{}, nil, nil)
chainProFork, _ = core.NewBlockChain(dbProFork, nil, configProFork, engine, vm.Config{}, nil, nil)
blocksNoFork, _ = core.GenerateChain(configNoFork, genesisNoFork, engine, dbNoFork, 2, nil)
blocksProFork, _ = core.GenerateChain(configProFork, genesisProFork, engine, dbProFork, 2, nil)
ethNoFork, _ = NewProtocolManager(configNoFork, nil, downloader.FullSync, 1, new(event.TypeMux), &testTxPool{pool: make(map[common.Hash]*types.Transaction)}, engine, chainNoFork, dbNoFork, 1, nil)
ethProFork, _ = NewProtocolManager(configProFork, nil, downloader.FullSync, 1, new(event.TypeMux), &testTxPool{pool: make(map[common.Hash]*types.Transaction)}, engine, chainProFork, dbProFork, 1, nil)
)
ethNoFork.Start(1000)
ethProFork.Start(1000)
// Both nodes should allow the other to connect (same genesis, next fork is the same)
p2pNoFork, p2pProFork := p2p.MsgPipe()
peerNoFork := newPeer(64, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil)
peerProFork := newPeer(64, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil)
errc := make(chan error, 2)
go func() { errc <- ethNoFork.handle(peerProFork) }()
go func() { errc <- ethProFork.handle(peerNoFork) }()
select {
case err := <-errc:
t.Fatalf("frontier nofork <-> profork failed: %v", err)
case <-time.After(250 * time.Millisecond):
p2pNoFork.Close()
p2pProFork.Close()
}
// Progress into Homestead. Fork's match, so we don't care what the future holds
chainNoFork.InsertChain(blocksNoFork[:1])
chainProFork.InsertChain(blocksProFork[:1])
p2pNoFork, p2pProFork = p2p.MsgPipe()
peerNoFork = newPeer(64, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil)
peerProFork = newPeer(64, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil)
errc = make(chan error, 2)
go func() { errc <- ethNoFork.handle(peerProFork) }()
go func() { errc <- ethProFork.handle(peerNoFork) }()
select {
case err := <-errc:
t.Fatalf("homestead nofork <-> profork failed: %v", err)
case <-time.After(250 * time.Millisecond):
p2pNoFork.Close()
p2pProFork.Close()
}
// Progress into Spurious. Forks mismatch, signalling differing chains, reject
chainNoFork.InsertChain(blocksNoFork[1:2])
chainProFork.InsertChain(blocksProFork[1:2])
p2pNoFork, p2pProFork = p2p.MsgPipe()
peerNoFork = newPeer(64, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil)
peerProFork = newPeer(64, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil)
errc = make(chan error, 2)
go func() { errc <- ethNoFork.handle(peerProFork) }()
go func() { errc <- ethProFork.handle(peerNoFork) }()
select {
case err := <-errc:
if want := errResp(ErrForkIDRejected, forkid.ErrLocalIncompatibleOrStale.Error()); err.Error() != want.Error() {
t.Fatalf("fork ID rejection error mismatch: have %v, want %v", err, want)
}
case <-time.After(250 * time.Millisecond):
t.Fatalf("split peers not rejected")
}
}
// This test checks that received transactions are added to the local pool.
func TestRecvTransactions63(t *testing.T) { testRecvTransactions(t, 63) }
func TestRecvTransactions64(t *testing.T) { testRecvTransactions(t, 64) }
func TestRecvTransactions65(t *testing.T) { testRecvTransactions(t, 65) }
func testRecvTransactions(t *testing.T, protocol int) {
txAdded := make(chan []*types.Transaction)
pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, txAdded)
pm.acceptTxs = 1 // mark synced to accept transactions
p, _ := newTestPeer("peer", protocol, pm, true)
defer pm.Stop()
defer p.close()
tx := newTestTransaction(testAccount, 0, 0)
if err := p2p.Send(p.app, TransactionMsg, []interface{}{tx}); err != nil {
t.Fatalf("send error: %v", err)
}
select {
case added := <-txAdded:
if len(added) != 1 {
t.Errorf("wrong number of added transactions: got %d, want 1", len(added))
} else if added[0].Hash() != tx.Hash() {
t.Errorf("added wrong tx hash: got %v, want %v", added[0].Hash(), tx.Hash())
}
case <-time.After(2 * time.Second):
t.Errorf("no NewTxsEvent received within 2 seconds")
}
}
// This test checks that pending transactions are sent.
func TestSendTransactions63(t *testing.T) { testSendTransactions(t, 63) }
func TestSendTransactions64(t *testing.T) { testSendTransactions(t, 64) }
func TestSendTransactions65(t *testing.T) { testSendTransactions(t, 65) }
func testSendTransactions(t *testing.T, protocol int) {
pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
defer pm.Stop()
// Fill the pool with big transactions (use a subscription to wait until all
// the transactions are announced to avoid spurious events causing extra
// broadcasts).
const txsize = txsyncPackSize / 10
alltxs := make([]*types.Transaction, 100)
for nonce := range alltxs {
alltxs[nonce] = newTestTransaction(testAccount, uint64(nonce), txsize)
}
pm.txpool.AddRemotes(alltxs)
time.Sleep(100 * time.Millisecond) // Wait until new tx even gets out of the system (lame)
// Connect several peers. They should all receive the pending transactions.
var wg sync.WaitGroup
checktxs := func(p *testPeer) {
defer wg.Done()
defer p.close()
seen := make(map[common.Hash]bool)
for _, tx := range alltxs {
seen[tx.Hash()] = false
}
for n := 0; n < len(alltxs) && !t.Failed(); {
var forAllHashes func(callback func(hash common.Hash))
switch protocol {
case 63:
fallthrough
case 64:
msg, err := p.app.ReadMsg()
if err != nil {
t.Errorf("%v: read error: %v", p.Peer, err)
continue
} else if msg.Code != TransactionMsg {
t.Errorf("%v: got code %d, want TxMsg", p.Peer, msg.Code)
continue
}
var txs []*types.Transaction
if err := msg.Decode(&txs); err != nil {
t.Errorf("%v: %v", p.Peer, err)
continue
}
forAllHashes = func(callback func(hash common.Hash)) {
for _, tx := range txs {
callback(tx.Hash())
}
}
case 65:
msg, err := p.app.ReadMsg()
if err != nil {
t.Errorf("%v: read error: %v", p.Peer, err)
continue
} else if msg.Code != NewPooledTransactionHashesMsg {
t.Errorf("%v: got code %d, want NewPooledTransactionHashesMsg", p.Peer, msg.Code)
continue
}
var hashes []common.Hash
if err := msg.Decode(&hashes); err != nil {
t.Errorf("%v: %v", p.Peer, err)
continue
}
forAllHashes = func(callback func(hash common.Hash)) {
for _, h := range hashes {
callback(h)
}
}
}
forAllHashes(func(hash common.Hash) {
seentx, want := seen[hash]
if seentx {
t.Errorf("%v: got tx more than once: %x", p.Peer, hash)
}
if !want {
t.Errorf("%v: got unexpected tx: %x", p.Peer, hash)
}
seen[hash] = true
n++
})
}
}
for i := 0; i < 3; i++ {
p, _ := newTestPeer(fmt.Sprintf("peer #%d", i), protocol, pm, true)
wg.Add(1)
go checktxs(p)
}
wg.Wait()
}
func TestTransactionPropagation(t *testing.T) { testSyncTransaction(t, true) }
func TestTransactionAnnouncement(t *testing.T) { testSyncTransaction(t, false) }
func testSyncTransaction(t *testing.T, propagtion bool) {
// Create a protocol manager for transaction fetcher and sender
pmFetcher, _ := newTestProtocolManagerMust(t, downloader.FastSync, 0, nil, nil)
defer pmFetcher.Stop()
pmSender, _ := newTestProtocolManagerMust(t, downloader.FastSync, 1024, nil, nil)
pmSender.broadcastTxAnnouncesOnly = !propagtion
defer pmSender.Stop()
// Sync up the two peers
io1, io2 := p2p.MsgPipe()
go pmSender.handle(pmSender.newPeer(65, p2p.NewPeer(enode.ID{}, "sender", nil), io2, pmSender.txpool.Get))
go pmFetcher.handle(pmFetcher.newPeer(65, p2p.NewPeer(enode.ID{}, "fetcher", nil), io1, pmFetcher.txpool.Get))
time.Sleep(250 * time.Millisecond)
pmFetcher.doSync(peerToSyncOp(downloader.FullSync, pmFetcher.peers.BestPeer()))
atomic.StoreUint32(&pmFetcher.acceptTxs, 1)
newTxs := make(chan core.NewTxsEvent, 1024)
sub := pmFetcher.txpool.SubscribeNewTxsEvent(newTxs)
defer sub.Unsubscribe()
// Fill the pool with new transactions
alltxs := make([]*types.Transaction, 1024)
for nonce := range alltxs {
alltxs[nonce] = newTestTransaction(testAccount, uint64(nonce), 0)
}
pmSender.txpool.AddRemotes(alltxs)
var got int
loop:
for {
select {
case ev := <-newTxs:
got += len(ev.Txs)
if got == 1024 {
break loop
}
case <-time.NewTimer(time.Second).C:
t.Fatal("Failed to retrieve all transaction")
}
}
}
// Tests that the custom union field encoder and decoder works correctly.
func TestGetBlockHeadersDataEncodeDecode(t *testing.T) {
// Create a "random" hash for testing
var hash common.Hash
for i := range hash {
hash[i] = byte(i)
}
// Assemble some table driven tests
tests := []struct {
packet *getBlockHeadersData
fail bool
}{
// Providing the origin as either a hash or a number should both work
{fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Number: 314}}},
{fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}}},
// Providing arbitrary query field should also work
{fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Number: 314}, Amount: 314, Skip: 1, Reverse: true}},
{fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: 314, Skip: 1, Reverse: true}},
// Providing both the origin hash and origin number must fail
{fail: true, packet: &getBlockHeadersData{Origin: hashOrNumber{Hash: hash, Number: 314}}},
}
// Iterate over each of the tests and try to encode and then decode
for i, tt := range tests {
bytes, err := rlp.EncodeToBytes(tt.packet)
if err != nil && !tt.fail {
t.Fatalf("test %d: failed to encode packet: %v", i, err)
} else if err == nil && tt.fail {
t.Fatalf("test %d: encode should have failed", i)
}
if !tt.fail {
packet := new(getBlockHeadersData)
if err := rlp.DecodeBytes(bytes, packet); err != nil {
t.Fatalf("test %d: failed to decode packet: %v", i, err)
}
if packet.Origin.Hash != tt.packet.Origin.Hash || packet.Origin.Number != tt.packet.Origin.Number || packet.Amount != tt.packet.Amount ||
packet.Skip != tt.packet.Skip || packet.Reverse != tt.packet.Reverse {
t.Fatalf("test %d: encode decode mismatch: have %+v, want %+v", i, packet, tt.packet)
}
}
}
}

View File

@ -0,0 +1,195 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package eth
import (
"math/big"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
)
const (
// This is the target size for the packs of transactions or announcements. A
// pack can get larger than this if a single transactions exceeds this size.
maxTxPacketSize = 100 * 1024
)
// blockPropagation is a block propagation event, waiting for its turn in the
// broadcast queue.
type blockPropagation struct {
block *types.Block
td *big.Int
}
// broadcastBlocks is a write loop that multiplexes blocks and block accouncements
// to the remote peer. The goal is to have an async writer that does not lock up
// node internals and at the same time rate limits queued data.
func (p *Peer) broadcastBlocks() {
for {
select {
case prop := <-p.queuedBlocks:
if err := p.SendNewBlock(prop.block, prop.td); err != nil {
return
}
p.Log().Trace("Propagated block", "number", prop.block.Number(), "hash", prop.block.Hash(), "td", prop.td)
case block := <-p.queuedBlockAnns:
if err := p.SendNewBlockHashes([]common.Hash{block.Hash()}, []uint64{block.NumberU64()}); err != nil {
return
}
p.Log().Trace("Announced block", "number", block.Number(), "hash", block.Hash())
case <-p.term:
return
}
}
}
// broadcastTransactions is a write loop that schedules transaction broadcasts
// to the remote peer. The goal is to have an async writer that does not lock up
// node internals and at the same time rate limits queued data.
func (p *Peer) broadcastTransactions() {
var (
queue []common.Hash // Queue of hashes to broadcast as full transactions
done chan struct{} // Non-nil if background broadcaster is running
fail = make(chan error, 1) // Channel used to receive network error
failed bool // Flag whether a send failed, discard everything onward
)
for {
// If there's no in-flight broadcast running, check if a new one is needed
if done == nil && len(queue) > 0 {
// Pile transaction until we reach our allowed network limit
var (
hashes []common.Hash
txs []*types.Transaction
size common.StorageSize
)
for i := 0; i < len(queue) && size < maxTxPacketSize; i++ {
if tx := p.txpool.Get(queue[i]); tx != nil {
txs = append(txs, tx)
size += tx.Size()
}
hashes = append(hashes, queue[i])
}
queue = queue[:copy(queue, queue[len(hashes):])]
// If there's anything available to transfer, fire up an async writer
if len(txs) > 0 {
done = make(chan struct{})
go func() {
if err := p.SendTransactions(txs); err != nil {
fail <- err
return
}
close(done)
p.Log().Trace("Sent transactions", "count", len(txs))
}()
}
}
// Transfer goroutine may or may not have been started, listen for events
select {
case hashes := <-p.txBroadcast:
// If the connection failed, discard all transaction events
if failed {
continue
}
// New batch of transactions to be broadcast, queue them (with cap)
queue = append(queue, hashes...)
if len(queue) > maxQueuedTxs {
// Fancy copy and resize to ensure buffer doesn't grow indefinitely
queue = queue[:copy(queue, queue[len(queue)-maxQueuedTxs:])]
}
case <-done:
done = nil
case <-fail:
failed = true
case <-p.term:
return
}
}
}
// announceTransactions is a write loop that schedules transaction broadcasts
// to the remote peer. The goal is to have an async writer that does not lock up
// node internals and at the same time rate limits queued data.
func (p *Peer) announceTransactions() {
var (
queue []common.Hash // Queue of hashes to announce as transaction stubs
done chan struct{} // Non-nil if background announcer is running
fail = make(chan error, 1) // Channel used to receive network error
failed bool // Flag whether a send failed, discard everything onward
)
for {
// If there's no in-flight announce running, check if a new one is needed
if done == nil && len(queue) > 0 {
// Pile transaction hashes until we reach our allowed network limit
var (
hashes []common.Hash
pending []common.Hash
size common.StorageSize
)
for i := 0; i < len(queue) && size < maxTxPacketSize; i++ {
if p.txpool.Get(queue[i]) != nil {
pending = append(pending, queue[i])
size += common.HashLength
}
hashes = append(hashes, queue[i])
}
queue = queue[:copy(queue, queue[len(hashes):])]
// If there's anything available to transfer, fire up an async writer
if len(pending) > 0 {
done = make(chan struct{})
go func() {
if err := p.sendPooledTransactionHashes(pending); err != nil {
fail <- err
return
}
close(done)
p.Log().Trace("Sent transaction announcements", "count", len(pending))
}()
}
}
// Transfer goroutine may or may not have been started, listen for events
select {
case hashes := <-p.txAnnounce:
// If the connection failed, discard all transaction events
if failed {
continue
}
// New batch of transactions to be broadcast, queue them (with cap)
queue = append(queue, hashes...)
if len(queue) > maxQueuedTxAnns {
// Fancy copy and resize to ensure buffer doesn't grow indefinitely
queue = queue[:copy(queue, queue[len(queue)-maxQueuedTxs:])]
}
case <-done:
done = nil
case <-fail:
failed = true
case <-p.term:
return
}
}
}

View File

@ -0,0 +1,65 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package eth
import (
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/rlp"
)
// enrEntry is the ENR entry which advertises `eth` protocol on the discovery.
type enrEntry struct {
ForkID forkid.ID // Fork identifier per EIP-2124
// Ignore additional fields (for forward compatibility).
Rest []rlp.RawValue `rlp:"tail"`
}
// ENRKey implements enr.Entry.
func (e enrEntry) ENRKey() string {
return "eth"
}
// StartENRUpdater starts the `eth` ENR updater loop, which listens for chain
// head events and updates the requested node record whenever a fork is passed.
func StartENRUpdater(chain *core.BlockChain, ln *enode.LocalNode) {
var newHead = make(chan core.ChainHeadEvent, 10)
sub := chain.SubscribeChainHeadEvent(newHead)
go func() {
defer sub.Unsubscribe()
for {
select {
case <-newHead:
ln.Set(currentENREntry(chain))
case <-sub.Err():
// Would be nice to sync with Stop, but there is no
// good way to do that.
return
}
}
}()
}
// currentENREntry constructs an `eth` ENR entry based on the current state of the chain.
func currentENREntry(chain *core.BlockChain) *enrEntry {
return &enrEntry{
ForkID: forkid.NewID(chain.Config(), chain.Genesis().Hash(), chain.CurrentHeader().Number.Uint64()),
}
}

View File

@ -0,0 +1,512 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package eth
import (
"encoding/json"
"fmt"
"math/big"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
)
const (
// softResponseLimit is the target maximum size of replies to data retrievals.
softResponseLimit = 2 * 1024 * 1024
// estHeaderSize is the approximate size of an RLP encoded block header.
estHeaderSize = 500
// maxHeadersServe is the maximum number of block headers to serve. This number
// is there to limit the number of disk lookups.
maxHeadersServe = 1024
// maxBodiesServe is the maximum number of block bodies to serve. This number
// is mostly there to limit the number of disk lookups. With 24KB block sizes
// nowadays, the practical limit will always be softResponseLimit.
maxBodiesServe = 1024
// maxNodeDataServe is the maximum number of state trie nodes to serve. This
// number is there to limit the number of disk lookups.
maxNodeDataServe = 1024
// maxReceiptsServe is the maximum number of block receipts to serve. This
// number is mostly there to limit the number of disk lookups. With block
// containing 200+ transactions nowadays, the practical limit will always
// be softResponseLimit.
maxReceiptsServe = 1024
)
// Handler is a callback to invoke from an outside runner after the boilerplate
// exchanges have passed.
type Handler func(peer *Peer) error
// Backend defines the data retrieval methods to serve remote requests and the
// callback methods to invoke on remote deliveries.
type Backend interface {
// Chain retrieves the blockchain object to serve data.
Chain() *core.BlockChain
// StateBloom retrieves the bloom filter - if any - for state trie nodes.
StateBloom() *trie.SyncBloom
// TxPool retrieves the transaction pool object to serve data.
TxPool() TxPool
// AcceptTxs retrieves whether transaction processing is enabled on the node
// or if inbound transactions should simply be dropped.
AcceptTxs() bool
// RunPeer is invoked when a peer joins on the `eth` protocol. The handler
// should do any peer maintenance work, handshakes and validations. If all
// is passed, control should be given back to the `handler` to process the
// inbound messages going forward.
RunPeer(peer *Peer, handler Handler) error
// PeerInfo retrieves all known `eth` information about a peer.
PeerInfo(id enode.ID) interface{}
// Handle is a callback to be invoked when a data packet is received from
// the remote peer. Only packets not consumed by the protocol handler will
// be forwarded to the backend.
Handle(peer *Peer, packet Packet) error
}
// TxPool defines the methods needed by the protocol handler to serve transactions.
type TxPool interface {
// Get retrieves the the transaction from the local txpool with the given hash.
Get(hash common.Hash) *types.Transaction
}
// MakeProtocols constructs the P2P protocol definitions for `eth`.
func MakeProtocols(backend Backend, network uint64, dnsdisc enode.Iterator) []p2p.Protocol {
protocols := make([]p2p.Protocol, len(protocolVersions))
for i, version := range protocolVersions {
version := version // Closure
protocols[i] = p2p.Protocol{
Name: protocolName,
Version: version,
Length: protocolLengths[version],
Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
peer := NewPeer(version, p, rw, backend.TxPool())
defer peer.Close()
return backend.RunPeer(peer, func(peer *Peer) error {
return Handle(backend, peer)
})
},
NodeInfo: func() interface{} {
return nodeInfo(backend.Chain(), network)
},
PeerInfo: func(id enode.ID) interface{} {
return backend.PeerInfo(id)
},
Attributes: []enr.Entry{currentENREntry(backend.Chain())},
DialCandidates: dnsdisc,
}
}
return protocols
}
// NodeInfo represents a short summary of the `eth` sub-protocol metadata
// known about the host peer.
type NodeInfo struct {
Network uint64 `json:"network"` // Ethereum network ID (1=Frontier, 2=Morden, Ropsten=3, Rinkeby=4)
Difficulty *big.Int `json:"difficulty"` // Total difficulty of the host's blockchain
Genesis common.Hash `json:"genesis"` // SHA3 hash of the host's genesis block
Config *params.ChainConfig `json:"config"` // Chain configuration for the fork rules
Head common.Hash `json:"head"` // Hex hash of the host's best owned block
}
// nodeInfo retrieves some `eth` protocol metadata about the running host node.
func nodeInfo(chain *core.BlockChain, network uint64) *NodeInfo {
head := chain.CurrentBlock()
return &NodeInfo{
Network: network,
Difficulty: chain.GetTd(head.Hash(), head.NumberU64()),
Genesis: chain.Genesis().Hash(),
Config: chain.Config(),
Head: head.Hash(),
}
}
// Handle is invoked whenever an `eth` connection is made that successfully passes
// the protocol handshake. This method will keep processing messages until the
// connection is torn down.
func Handle(backend Backend, peer *Peer) error {
for {
if err := handleMessage(backend, peer); err != nil {
peer.Log().Debug("Message handling failed in `eth`", "err", err)
return err
}
}
}
// handleMessage is invoked whenever an inbound message is received from a remote
// peer. The remote connection is torn down upon returning any error.
func handleMessage(backend Backend, peer *Peer) error {
// Read the next message from the remote peer, and ensure it's fully consumed
msg, err := peer.rw.ReadMsg()
if err != nil {
return err
}
if msg.Size > maxMessageSize {
return fmt.Errorf("%w: %v > %v", errMsgTooLarge, msg.Size, maxMessageSize)
}
defer msg.Discard()
// Handle the message depending on its contents
switch {
case msg.Code == StatusMsg:
// Status messages should never arrive after the handshake
return fmt.Errorf("%w: uncontrolled status message", errExtraStatusMsg)
// Block header query, collect the requested headers and reply
case msg.Code == GetBlockHeadersMsg:
// Decode the complex header query
var query GetBlockHeadersPacket
if err := msg.Decode(&query); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
hashMode := query.Origin.Hash != (common.Hash{})
first := true
maxNonCanonical := uint64(100)
// Gather headers until the fetch or network limits is reached
var (
bytes common.StorageSize
headers []*types.Header
unknown bool
lookups int
)
for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit &&
len(headers) < maxHeadersServe && lookups < 2*maxHeadersServe {
lookups++
// Retrieve the next header satisfying the query
var origin *types.Header
if hashMode {
if first {
first = false
origin = backend.Chain().GetHeaderByHash(query.Origin.Hash)
if origin != nil {
query.Origin.Number = origin.Number.Uint64()
}
} else {
origin = backend.Chain().GetHeader(query.Origin.Hash, query.Origin.Number)
}
} else {
origin = backend.Chain().GetHeaderByNumber(query.Origin.Number)
}
if origin == nil {
break
}
headers = append(headers, origin)
bytes += estHeaderSize
// Advance to the next header of the query
switch {
case hashMode && query.Reverse:
// Hash based traversal towards the genesis block
ancestor := query.Skip + 1
if ancestor == 0 {
unknown = true
} else {
query.Origin.Hash, query.Origin.Number = backend.Chain().GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
unknown = (query.Origin.Hash == common.Hash{})
}
case hashMode && !query.Reverse:
// Hash based traversal towards the leaf block
var (
current = origin.Number.Uint64()
next = current + query.Skip + 1
)
if next <= current {
infos, _ := json.MarshalIndent(peer.Peer.Info(), "", " ")
peer.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos)
unknown = true
} else {
if header := backend.Chain().GetHeaderByNumber(next); header != nil {
nextHash := header.Hash()
expOldHash, _ := backend.Chain().GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical)
if expOldHash == query.Origin.Hash {
query.Origin.Hash, query.Origin.Number = nextHash, next
} else {
unknown = true
}
} else {
unknown = true
}
}
case query.Reverse:
// Number based traversal towards the genesis block
if query.Origin.Number >= query.Skip+1 {
query.Origin.Number -= query.Skip + 1
} else {
unknown = true
}
case !query.Reverse:
// Number based traversal towards the leaf block
query.Origin.Number += query.Skip + 1
}
}
return peer.SendBlockHeaders(headers)
case msg.Code == BlockHeadersMsg:
// A batch of headers arrived to one of our previous requests
res := new(BlockHeadersPacket)
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
return backend.Handle(peer, res)
case msg.Code == GetBlockBodiesMsg:
// Decode the block body retrieval message
var query GetBlockBodiesPacket
if err := msg.Decode(&query); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
// Gather blocks until the fetch or network limits is reached
var (
bytes int
bodies []rlp.RawValue
)
for lookups, hash := range query {
if bytes >= softResponseLimit || len(bodies) >= maxBodiesServe ||
lookups >= 2*maxBodiesServe {
break
}
if data := backend.Chain().GetBodyRLP(hash); len(data) != 0 {
bodies = append(bodies, data)
bytes += len(data)
}
}
return peer.SendBlockBodiesRLP(bodies)
case msg.Code == BlockBodiesMsg:
// A batch of block bodies arrived to one of our previous requests
res := new(BlockBodiesPacket)
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
return backend.Handle(peer, res)
case msg.Code == GetNodeDataMsg:
// Decode the trie node data retrieval message
var query GetNodeDataPacket
if err := msg.Decode(&query); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
// Gather state data until the fetch or network limits is reached
var (
bytes int
nodes [][]byte
)
for lookups, hash := range query {
if bytes >= softResponseLimit || len(nodes) >= maxNodeDataServe ||
lookups >= 2*maxNodeDataServe {
break
}
// Retrieve the requested state entry
if bloom := backend.StateBloom(); bloom != nil && !bloom.Contains(hash[:]) {
// Only lookup the trie node if there's chance that we actually have it
continue
}
entry, err := backend.Chain().TrieNode(hash)
if len(entry) == 0 || err != nil {
// Read the contract code with prefix only to save unnecessary lookups.
entry, err = backend.Chain().ContractCodeWithPrefix(hash)
}
if err == nil && len(entry) > 0 {
nodes = append(nodes, entry)
bytes += len(entry)
}
}
return peer.SendNodeData(nodes)
case msg.Code == NodeDataMsg:
// A batch of node state data arrived to one of our previous requests
res := new(NodeDataPacket)
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
return backend.Handle(peer, res)
case msg.Code == GetReceiptsMsg:
// Decode the block receipts retrieval message
var query GetReceiptsPacket
if err := msg.Decode(&query); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
// Gather state data until the fetch or network limits is reached
var (
bytes int
receipts []rlp.RawValue
)
for lookups, hash := range query {
if bytes >= softResponseLimit || len(receipts) >= maxReceiptsServe ||
lookups >= 2*maxReceiptsServe {
break
}
// Retrieve the requested block's receipts
results := backend.Chain().GetReceiptsByHash(hash)
if results == nil {
if header := backend.Chain().GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
continue
}
}
// If known, encode and queue for response packet
if encoded, err := rlp.EncodeToBytes(results); err != nil {
log.Error("Failed to encode receipt", "err", err)
} else {
receipts = append(receipts, encoded)
bytes += len(encoded)
}
}
return peer.SendReceiptsRLP(receipts)
case msg.Code == ReceiptsMsg:
// A batch of receipts arrived to one of our previous requests
res := new(ReceiptsPacket)
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
return backend.Handle(peer, res)
case msg.Code == NewBlockHashesMsg:
// A batch of new block announcements just arrived
ann := new(NewBlockHashesPacket)
if err := msg.Decode(ann); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
// Mark the hashes as present at the remote node
for _, block := range *ann {
peer.markBlock(block.Hash)
}
// Deliver them all to the backend for queuing
return backend.Handle(peer, ann)
case msg.Code == NewBlockMsg:
// Retrieve and decode the propagated block
ann := new(NewBlockPacket)
if err := msg.Decode(ann); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
if hash := types.CalcUncleHash(ann.Block.Uncles()); hash != ann.Block.UncleHash() {
log.Warn("Propagated block has invalid uncles", "have", hash, "exp", ann.Block.UncleHash())
break // TODO(karalabe): return error eventually, but wait a few releases
}
if hash := types.DeriveSha(ann.Block.Transactions(), trie.NewStackTrie(nil)); hash != ann.Block.TxHash() {
log.Warn("Propagated block has invalid body", "have", hash, "exp", ann.Block.TxHash())
break // TODO(karalabe): return error eventually, but wait a few releases
}
if err := ann.sanityCheck(); err != nil {
return err
}
ann.Block.ReceivedAt = msg.ReceivedAt
ann.Block.ReceivedFrom = peer
// Mark the peer as owning the block
peer.markBlock(ann.Block.Hash())
return backend.Handle(peer, ann)
case msg.Code == NewPooledTransactionHashesMsg && peer.version >= ETH65:
// New transaction announcement arrived, make sure we have
// a valid and fresh chain to handle them
if !backend.AcceptTxs() {
break
}
ann := new(NewPooledTransactionHashesPacket)
if err := msg.Decode(ann); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
// Schedule all the unknown hashes for retrieval
for _, hash := range *ann {
peer.markTransaction(hash)
}
return backend.Handle(peer, ann)
case msg.Code == GetPooledTransactionsMsg && peer.version >= ETH65:
// Decode the pooled transactions retrieval message
var query GetPooledTransactionsPacket
if err := msg.Decode(&query); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
// Gather transactions until the fetch or network limits is reached
var (
bytes int
hashes []common.Hash
txs []rlp.RawValue
)
for _, hash := range query {
if bytes >= softResponseLimit {
break
}
// Retrieve the requested transaction, skipping if unknown to us
tx := backend.TxPool().Get(hash)
if tx == nil {
continue
}
// If known, encode and queue for response packet
if encoded, err := rlp.EncodeToBytes(tx); err != nil {
log.Error("Failed to encode transaction", "err", err)
} else {
hashes = append(hashes, hash)
txs = append(txs, encoded)
bytes += len(encoded)
}
}
return peer.SendPooledTransactionsRLP(hashes, txs)
case msg.Code == TransactionsMsg || (msg.Code == PooledTransactionsMsg && peer.version >= ETH65):
// Transactions arrived, make sure we have a valid and fresh chain to handle them
if !backend.AcceptTxs() {
break
}
// Transactions can be processed, parse all of them and deliver to the pool
var txs []*types.Transaction
if err := msg.Decode(&txs); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
for i, tx := range txs {
// Validate and mark the remote transaction
if tx == nil {
return fmt.Errorf("%w: transaction %d is nil", errDecode, i)
}
peer.markTransaction(tx.Hash())
}
if msg.Code == PooledTransactionsMsg {
return backend.Handle(peer, (*PooledTransactionsPacket)(&txs))
}
return backend.Handle(peer, (*TransactionsPacket)(&txs))
default:
return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code)
}
return nil
}

View File

@ -0,0 +1,519 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package eth
import (
"math"
"math/big"
"math/rand"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus/ethash"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/trie"
)
var (
// testKey is a private key to use for funding a tester account.
testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
// testAddr is the Ethereum address of the tester account.
testAddr = crypto.PubkeyToAddress(testKey.PublicKey)
)
// testBackend is a mock implementation of the live Ethereum message handler. Its
// purpose is to allow testing the request/reply workflows and wire serialization
// in the `eth` protocol without actually doing any data processing.
type testBackend struct {
db ethdb.Database
chain *core.BlockChain
txpool *core.TxPool
}
// newTestBackend creates an empty chain and wraps it into a mock backend.
func newTestBackend(blocks int) *testBackend {
return newTestBackendWithGenerator(blocks, nil)
}
// newTestBackend creates a chain with a number of explicitly defined blocks and
// wraps it into a mock backend.
func newTestBackendWithGenerator(blocks int, generator func(int, *core.BlockGen)) *testBackend {
// Create a database pre-initialize with a genesis block
db := rawdb.NewMemoryDatabase()
(&core.Genesis{
Config: params.TestChainConfig,
Alloc: core.GenesisAlloc{testAddr: {Balance: big.NewInt(1000000)}},
}).MustCommit(db)
chain, _ := core.NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{}, nil, nil)
bs, _ := core.GenerateChain(params.TestChainConfig, chain.Genesis(), ethash.NewFaker(), db, blocks, generator)
if _, err := chain.InsertChain(bs); err != nil {
panic(err)
}
txconfig := core.DefaultTxPoolConfig
txconfig.Journal = "" // Don't litter the disk with test journals
return &testBackend{
db: db,
chain: chain,
txpool: core.NewTxPool(txconfig, params.TestChainConfig, chain),
}
}
// close tears down the transaction pool and chain behind the mock backend.
func (b *testBackend) close() {
b.txpool.Stop()
b.chain.Stop()
}
func (b *testBackend) Chain() *core.BlockChain { return b.chain }
func (b *testBackend) StateBloom() *trie.SyncBloom { return nil }
func (b *testBackend) TxPool() TxPool { return b.txpool }
func (b *testBackend) RunPeer(peer *Peer, handler Handler) error {
// Normally the backend would do peer mainentance and handshakes. All that
// is omitted and we will just give control back to the handler.
return handler(peer)
}
func (b *testBackend) PeerInfo(enode.ID) interface{} { panic("not implemented") }
func (b *testBackend) AcceptTxs() bool {
panic("data processing tests should be done in the handler package")
}
func (b *testBackend) Handle(*Peer, Packet) error {
panic("data processing tests should be done in the handler package")
}
// Tests that block headers can be retrieved from a remote chain based on user queries.
func TestGetBlockHeaders64(t *testing.T) { testGetBlockHeaders(t, 64) }
func TestGetBlockHeaders65(t *testing.T) { testGetBlockHeaders(t, 65) }
func testGetBlockHeaders(t *testing.T, protocol uint) {
t.Parallel()
backend := newTestBackend(maxHeadersServe + 15)
defer backend.close()
peer, _ := newTestPeer("peer", protocol, backend)
defer peer.close()
// Create a "random" unknown hash for testing
var unknown common.Hash
for i := range unknown {
unknown[i] = byte(i)
}
// Create a batch of tests for various scenarios
limit := uint64(maxHeadersServe)
tests := []struct {
query *GetBlockHeadersPacket // The query to execute for header retrieval
expect []common.Hash // The hashes of the block whose headers are expected
}{
// A single random block should be retrievable by hash and number too
{
&GetBlockHeadersPacket{Origin: HashOrNumber{Hash: backend.chain.GetBlockByNumber(limit / 2).Hash()}, Amount: 1},
[]common.Hash{backend.chain.GetBlockByNumber(limit / 2).Hash()},
}, {
&GetBlockHeadersPacket{Origin: HashOrNumber{Number: limit / 2}, Amount: 1},
[]common.Hash{backend.chain.GetBlockByNumber(limit / 2).Hash()},
},
// Multiple headers should be retrievable in both directions
{
&GetBlockHeadersPacket{Origin: HashOrNumber{Number: limit / 2}, Amount: 3},
[]common.Hash{
backend.chain.GetBlockByNumber(limit / 2).Hash(),
backend.chain.GetBlockByNumber(limit/2 + 1).Hash(),
backend.chain.GetBlockByNumber(limit/2 + 2).Hash(),
},
}, {
&GetBlockHeadersPacket{Origin: HashOrNumber{Number: limit / 2}, Amount: 3, Reverse: true},
[]common.Hash{
backend.chain.GetBlockByNumber(limit / 2).Hash(),
backend.chain.GetBlockByNumber(limit/2 - 1).Hash(),
backend.chain.GetBlockByNumber(limit/2 - 2).Hash(),
},
},
// Multiple headers with skip lists should be retrievable
{
&GetBlockHeadersPacket{Origin: HashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3},
[]common.Hash{
backend.chain.GetBlockByNumber(limit / 2).Hash(),
backend.chain.GetBlockByNumber(limit/2 + 4).Hash(),
backend.chain.GetBlockByNumber(limit/2 + 8).Hash(),
},
}, {
&GetBlockHeadersPacket{Origin: HashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3, Reverse: true},
[]common.Hash{
backend.chain.GetBlockByNumber(limit / 2).Hash(),
backend.chain.GetBlockByNumber(limit/2 - 4).Hash(),
backend.chain.GetBlockByNumber(limit/2 - 8).Hash(),
},
},
// The chain endpoints should be retrievable
{
&GetBlockHeadersPacket{Origin: HashOrNumber{Number: 0}, Amount: 1},
[]common.Hash{backend.chain.GetBlockByNumber(0).Hash()},
}, {
&GetBlockHeadersPacket{Origin: HashOrNumber{Number: backend.chain.CurrentBlock().NumberU64()}, Amount: 1},
[]common.Hash{backend.chain.CurrentBlock().Hash()},
},
// Ensure protocol limits are honored
{
&GetBlockHeadersPacket{Origin: HashOrNumber{Number: backend.chain.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true},
backend.chain.GetBlockHashesFromHash(backend.chain.CurrentBlock().Hash(), limit),
},
// Check that requesting more than available is handled gracefully
{
&GetBlockHeadersPacket{Origin: HashOrNumber{Number: backend.chain.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3},
[]common.Hash{
backend.chain.GetBlockByNumber(backend.chain.CurrentBlock().NumberU64() - 4).Hash(),
backend.chain.GetBlockByNumber(backend.chain.CurrentBlock().NumberU64()).Hash(),
},
}, {
&GetBlockHeadersPacket{Origin: HashOrNumber{Number: 4}, Skip: 3, Amount: 3, Reverse: true},
[]common.Hash{
backend.chain.GetBlockByNumber(4).Hash(),
backend.chain.GetBlockByNumber(0).Hash(),
},
},
// Check that requesting more than available is handled gracefully, even if mid skip
{
&GetBlockHeadersPacket{Origin: HashOrNumber{Number: backend.chain.CurrentBlock().NumberU64() - 4}, Skip: 2, Amount: 3},
[]common.Hash{
backend.chain.GetBlockByNumber(backend.chain.CurrentBlock().NumberU64() - 4).Hash(),
backend.chain.GetBlockByNumber(backend.chain.CurrentBlock().NumberU64() - 1).Hash(),
},
}, {
&GetBlockHeadersPacket{Origin: HashOrNumber{Number: 4}, Skip: 2, Amount: 3, Reverse: true},
[]common.Hash{
backend.chain.GetBlockByNumber(4).Hash(),
backend.chain.GetBlockByNumber(1).Hash(),
},
},
// Check a corner case where requesting more can iterate past the endpoints
{
&GetBlockHeadersPacket{Origin: HashOrNumber{Number: 2}, Amount: 5, Reverse: true},
[]common.Hash{
backend.chain.GetBlockByNumber(2).Hash(),
backend.chain.GetBlockByNumber(1).Hash(),
backend.chain.GetBlockByNumber(0).Hash(),
},
},
// Check a corner case where skipping overflow loops back into the chain start
{
&GetBlockHeadersPacket{Origin: HashOrNumber{Hash: backend.chain.GetBlockByNumber(3).Hash()}, Amount: 2, Reverse: false, Skip: math.MaxUint64 - 1},
[]common.Hash{
backend.chain.GetBlockByNumber(3).Hash(),
},
},
// Check a corner case where skipping overflow loops back to the same header
{
&GetBlockHeadersPacket{Origin: HashOrNumber{Hash: backend.chain.GetBlockByNumber(1).Hash()}, Amount: 2, Reverse: false, Skip: math.MaxUint64},
[]common.Hash{
backend.chain.GetBlockByNumber(1).Hash(),
},
},
// Check that non existing headers aren't returned
{
&GetBlockHeadersPacket{Origin: HashOrNumber{Hash: unknown}, Amount: 1},
[]common.Hash{},
}, {
&GetBlockHeadersPacket{Origin: HashOrNumber{Number: backend.chain.CurrentBlock().NumberU64() + 1}, Amount: 1},
[]common.Hash{},
},
}
// Run each of the tests and verify the results against the chain
for i, tt := range tests {
// Collect the headers to expect in the response
var headers []*types.Header
for _, hash := range tt.expect {
headers = append(headers, backend.chain.GetBlockByHash(hash).Header())
}
// Send the hash request and verify the response
p2p.Send(peer.app, 0x03, tt.query)
if err := p2p.ExpectMsg(peer.app, 0x04, headers); err != nil {
t.Errorf("test %d: headers mismatch: %v", i, err)
}
// If the test used number origins, repeat with hashes as the too
if tt.query.Origin.Hash == (common.Hash{}) {
if origin := backend.chain.GetBlockByNumber(tt.query.Origin.Number); origin != nil {
tt.query.Origin.Hash, tt.query.Origin.Number = origin.Hash(), 0
p2p.Send(peer.app, 0x03, tt.query)
if err := p2p.ExpectMsg(peer.app, 0x04, headers); err != nil {
t.Errorf("test %d: headers mismatch: %v", i, err)
}
}
}
}
}
// Tests that block contents can be retrieved from a remote chain based on their hashes.
func TestGetBlockBodies64(t *testing.T) { testGetBlockBodies(t, 64) }
func TestGetBlockBodies65(t *testing.T) { testGetBlockBodies(t, 65) }
func testGetBlockBodies(t *testing.T, protocol uint) {
t.Parallel()
backend := newTestBackend(maxBodiesServe + 15)
defer backend.close()
peer, _ := newTestPeer("peer", protocol, backend)
defer peer.close()
// Create a batch of tests for various scenarios
limit := maxBodiesServe
tests := []struct {
random int // Number of blocks to fetch randomly from the chain
explicit []common.Hash // Explicitly requested blocks
available []bool // Availability of explicitly requested blocks
expected int // Total number of existing blocks to expect
}{
{1, nil, nil, 1}, // A single random block should be retrievable
{10, nil, nil, 10}, // Multiple random blocks should be retrievable
{limit, nil, nil, limit}, // The maximum possible blocks should be retrievable
{limit + 1, nil, nil, limit}, // No more than the possible block count should be returned
{0, []common.Hash{backend.chain.Genesis().Hash()}, []bool{true}, 1}, // The genesis block should be retrievable
{0, []common.Hash{backend.chain.CurrentBlock().Hash()}, []bool{true}, 1}, // The chains head block should be retrievable
{0, []common.Hash{{}}, []bool{false}, 0}, // A non existent block should not be returned
// Existing and non-existing blocks interleaved should not cause problems
{0, []common.Hash{
{},
backend.chain.GetBlockByNumber(1).Hash(),
{},
backend.chain.GetBlockByNumber(10).Hash(),
{},
backend.chain.GetBlockByNumber(100).Hash(),
{},
}, []bool{false, true, false, true, false, true, false}, 3},
}
// Run each of the tests and verify the results against the chain
for i, tt := range tests {
// Collect the hashes to request, and the response to expectva
var (
hashes []common.Hash
bodies []*BlockBody
seen = make(map[int64]bool)
)
for j := 0; j < tt.random; j++ {
for {
num := rand.Int63n(int64(backend.chain.CurrentBlock().NumberU64()))
if !seen[num] {
seen[num] = true
block := backend.chain.GetBlockByNumber(uint64(num))
hashes = append(hashes, block.Hash())
if len(bodies) < tt.expected {
bodies = append(bodies, &BlockBody{Transactions: block.Transactions(), Uncles: block.Uncles()})
}
break
}
}
}
for j, hash := range tt.explicit {
hashes = append(hashes, hash)
if tt.available[j] && len(bodies) < tt.expected {
block := backend.chain.GetBlockByHash(hash)
bodies = append(bodies, &BlockBody{Transactions: block.Transactions(), Uncles: block.Uncles()})
}
}
// Send the hash request and verify the response
p2p.Send(peer.app, 0x05, hashes)
if err := p2p.ExpectMsg(peer.app, 0x06, bodies); err != nil {
t.Errorf("test %d: bodies mismatch: %v", i, err)
}
}
}
// Tests that the state trie nodes can be retrieved based on hashes.
func TestGetNodeData64(t *testing.T) { testGetNodeData(t, 64) }
func TestGetNodeData65(t *testing.T) { testGetNodeData(t, 65) }
func testGetNodeData(t *testing.T, protocol uint) {
t.Parallel()
// Define three accounts to simulate transactions with
acc1Key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a")
acc2Key, _ := crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee")
acc1Addr := crypto.PubkeyToAddress(acc1Key.PublicKey)
acc2Addr := crypto.PubkeyToAddress(acc2Key.PublicKey)
signer := types.HomesteadSigner{}
// Create a chain generator with some simple transactions (blatantly stolen from @fjl/chain_markets_test)
generator := func(i int, block *core.BlockGen) {
switch i {
case 0:
// In block 1, the test bank sends account #1 some ether.
tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testAddr), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil), signer, testKey)
block.AddTx(tx)
case 1:
// In block 2, the test bank sends some more ether to account #1.
// acc1Addr passes it on to account #2.
tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testAddr), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, testKey)
tx2, _ := types.SignTx(types.NewTransaction(block.TxNonce(acc1Addr), acc2Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, acc1Key)
block.AddTx(tx1)
block.AddTx(tx2)
case 2:
// Block 3 is empty but was mined by account #2.
block.SetCoinbase(acc2Addr)
block.SetExtra([]byte("yeehaw"))
case 3:
// Block 4 includes blocks 2 and 3 as uncle headers (with modified extra data).
b2 := block.PrevBlock(1).Header()
b2.Extra = []byte("foo")
block.AddUncle(b2)
b3 := block.PrevBlock(2).Header()
b3.Extra = []byte("foo")
block.AddUncle(b3)
}
}
// Assemble the test environment
backend := newTestBackendWithGenerator(4, generator)
defer backend.close()
peer, _ := newTestPeer("peer", protocol, backend)
defer peer.close()
// Fetch for now the entire chain db
var hashes []common.Hash
it := backend.db.NewIterator(nil, nil)
for it.Next() {
if key := it.Key(); len(key) == common.HashLength {
hashes = append(hashes, common.BytesToHash(key))
}
}
it.Release()
p2p.Send(peer.app, 0x0d, hashes)
msg, err := peer.app.ReadMsg()
if err != nil {
t.Fatalf("failed to read node data response: %v", err)
}
if msg.Code != 0x0e {
t.Fatalf("response packet code mismatch: have %x, want %x", msg.Code, 0x0c)
}
var data [][]byte
if err := msg.Decode(&data); err != nil {
t.Fatalf("failed to decode response node data: %v", err)
}
// Verify that all hashes correspond to the requested data, and reconstruct a state tree
for i, want := range hashes {
if hash := crypto.Keccak256Hash(data[i]); hash != want {
t.Errorf("data hash mismatch: have %x, want %x", hash, want)
}
}
statedb := rawdb.NewMemoryDatabase()
for i := 0; i < len(data); i++ {
statedb.Put(hashes[i].Bytes(), data[i])
}
accounts := []common.Address{testAddr, acc1Addr, acc2Addr}
for i := uint64(0); i <= backend.chain.CurrentBlock().NumberU64(); i++ {
trie, _ := state.New(backend.chain.GetBlockByNumber(i).Root(), state.NewDatabase(statedb), nil)
for j, acc := range accounts {
state, _ := backend.chain.State()
bw := state.GetBalance(acc)
bh := trie.GetBalance(acc)
if (bw != nil && bh == nil) || (bw == nil && bh != nil) {
t.Errorf("test %d, account %d: balance mismatch: have %v, want %v", i, j, bh, bw)
}
if bw != nil && bh != nil && bw.Cmp(bw) != 0 {
t.Errorf("test %d, account %d: balance mismatch: have %v, want %v", i, j, bh, bw)
}
}
}
}
// Tests that the transaction receipts can be retrieved based on hashes.
func TestGetBlockReceipts64(t *testing.T) { testGetBlockReceipts(t, 64) }
func TestGetBlockReceipts65(t *testing.T) { testGetBlockReceipts(t, 65) }
func testGetBlockReceipts(t *testing.T, protocol uint) {
t.Parallel()
// Define three accounts to simulate transactions with
acc1Key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a")
acc2Key, _ := crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee")
acc1Addr := crypto.PubkeyToAddress(acc1Key.PublicKey)
acc2Addr := crypto.PubkeyToAddress(acc2Key.PublicKey)
signer := types.HomesteadSigner{}
// Create a chain generator with some simple transactions (blatantly stolen from @fjl/chain_markets_test)
generator := func(i int, block *core.BlockGen) {
switch i {
case 0:
// In block 1, the test bank sends account #1 some ether.
tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testAddr), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil), signer, testKey)
block.AddTx(tx)
case 1:
// In block 2, the test bank sends some more ether to account #1.
// acc1Addr passes it on to account #2.
tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testAddr), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, testKey)
tx2, _ := types.SignTx(types.NewTransaction(block.TxNonce(acc1Addr), acc2Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, acc1Key)
block.AddTx(tx1)
block.AddTx(tx2)
case 2:
// Block 3 is empty but was mined by account #2.
block.SetCoinbase(acc2Addr)
block.SetExtra([]byte("yeehaw"))
case 3:
// Block 4 includes blocks 2 and 3 as uncle headers (with modified extra data).
b2 := block.PrevBlock(1).Header()
b2.Extra = []byte("foo")
block.AddUncle(b2)
b3 := block.PrevBlock(2).Header()
b3.Extra = []byte("foo")
block.AddUncle(b3)
}
}
// Assemble the test environment
backend := newTestBackendWithGenerator(4, generator)
defer backend.close()
peer, _ := newTestPeer("peer", protocol, backend)
defer peer.close()
// Collect the hashes to request, and the response to expect
var (
hashes []common.Hash
receipts []types.Receipts
)
for i := uint64(0); i <= backend.chain.CurrentBlock().NumberU64(); i++ {
block := backend.chain.GetBlockByNumber(i)
hashes = append(hashes, block.Hash())
receipts = append(receipts, backend.chain.GetReceiptsByHash(block.Hash()))
}
// Send the hash request and verify the response
p2p.Send(peer.app, 0x0f, hashes)
if err := p2p.ExpectMsg(peer.app, 0x10, receipts); err != nil {
t.Errorf("receipts mismatch: %v", err)
}
}

View File

@ -0,0 +1,107 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package eth
import (
"fmt"
"math/big"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/p2p"
)
const (
// handshakeTimeout is the maximum allowed time for the `eth` handshake to
// complete before dropping the connection.= as malicious.
handshakeTimeout = 5 * time.Second
)
// Handshake executes the eth protocol handshake, negotiating version number,
// network IDs, difficulties, head and genesis blocks.
func (p *Peer) Handshake(network uint64, td *big.Int, head common.Hash, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter) error {
// Send out own handshake in a new thread
errc := make(chan error, 2)
var status StatusPacket // safe to read after two values have been received from errc
go func() {
errc <- p2p.Send(p.rw, StatusMsg, &StatusPacket{
ProtocolVersion: uint32(p.version),
NetworkID: network,
TD: td,
Head: head,
Genesis: genesis,
ForkID: forkID,
})
}()
go func() {
errc <- p.readStatus(network, &status, genesis, forkFilter)
}()
timeout := time.NewTimer(handshakeTimeout)
defer timeout.Stop()
for i := 0; i < 2; i++ {
select {
case err := <-errc:
if err != nil {
return err
}
case <-timeout.C:
return p2p.DiscReadTimeout
}
}
p.td, p.head = status.TD, status.Head
// TD at mainnet block #7753254 is 76 bits. If it becomes 100 million times
// larger, it will still fit within 100 bits
if tdlen := p.td.BitLen(); tdlen > 100 {
return fmt.Errorf("too large total difficulty: bitlen %d", tdlen)
}
return nil
}
// readStatus reads the remote handshake message.
func (p *Peer) readStatus(network uint64, status *StatusPacket, genesis common.Hash, forkFilter forkid.Filter) error {
msg, err := p.rw.ReadMsg()
if err != nil {
return err
}
if msg.Code != StatusMsg {
return fmt.Errorf("%w: first msg has code %x (!= %x)", errNoStatusMsg, msg.Code, StatusMsg)
}
if msg.Size > maxMessageSize {
return fmt.Errorf("%w: %v > %v", errMsgTooLarge, msg.Size, maxMessageSize)
}
// Decode the handshake and make sure everything matches
if err := msg.Decode(&status); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
if status.NetworkID != network {
return fmt.Errorf("%w: %d (!= %d)", errNetworkIDMismatch, status.NetworkID, network)
}
if uint(status.ProtocolVersion) != p.version {
return fmt.Errorf("%w: %d (!= %d)", errProtocolVersionMismatch, status.ProtocolVersion, p.version)
}
if status.Genesis != genesis {
return fmt.Errorf("%w: %x (!= %x)", errGenesisMismatch, status.Genesis, genesis)
}
if err := forkFilter(status.ForkID); err != nil {
return fmt.Errorf("%w: %v", errForkIDRejected, err)
}
return nil
}

View File

@ -0,0 +1,91 @@
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package eth
import (
"errors"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
)
// Tests that handshake failures are detected and reported correctly.
func TestHandshake64(t *testing.T) { testHandshake(t, 64) }
func TestHandshake65(t *testing.T) { testHandshake(t, 65) }
func testHandshake(t *testing.T, protocol uint) {
t.Parallel()
// Create a test backend only to have some valid genesis chain
backend := newTestBackend(3)
defer backend.close()
var (
genesis = backend.chain.Genesis()
head = backend.chain.CurrentBlock()
td = backend.chain.GetTd(head.Hash(), head.NumberU64())
forkID = forkid.NewID(backend.chain.Config(), backend.chain.Genesis().Hash(), backend.chain.CurrentHeader().Number.Uint64())
)
tests := []struct {
code uint64
data interface{}
want error
}{
{
code: TransactionsMsg, data: []interface{}{},
want: errNoStatusMsg,
},
{
code: StatusMsg, data: StatusPacket{10, 1, td, head.Hash(), genesis.Hash(), forkID},
want: errProtocolVersionMismatch,
},
{
code: StatusMsg, data: StatusPacket{uint32(protocol), 999, td, head.Hash(), genesis.Hash(), forkID},
want: errNetworkIDMismatch,
},
{
code: StatusMsg, data: StatusPacket{uint32(protocol), 1, td, head.Hash(), common.Hash{3}, forkID},
want: errGenesisMismatch,
},
{
code: StatusMsg, data: StatusPacket{uint32(protocol), 1, td, head.Hash(), genesis.Hash(), forkid.ID{Hash: [4]byte{0x00, 0x01, 0x02, 0x03}}},
want: errForkIDRejected,
},
}
for i, test := range tests {
// Create the two peers to shake with each other
app, net := p2p.MsgPipe()
defer app.Close()
defer net.Close()
peer := NewPeer(protocol, p2p.NewPeer(enode.ID{}, "peer", nil), net, nil)
defer peer.Close()
// Send the junk test with one peer, check the handshake failure
go p2p.Send(app, test.code, test.data)
err := peer.Handshake(1, td, head.Hash(), genesis.Hash(), forkID, forkid.NewFilter(backend.chain))
if err == nil {
t.Errorf("test %d: protocol returned nil error, want %q", i, test.want)
} else if !errors.Is(err, test.want) {
t.Errorf("test %d: wrong error: got %q, want %q", i, err, test.want)
}
}
}

429
eth/protocols/eth/peer.go Normal file
View File

@ -0,0 +1,429 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package eth
import (
"math/big"
"sync"
mapset "github.com/deckarep/golang-set"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rlp"
)
const (
// maxKnownTxs is the maximum transactions hashes to keep in the known list
// before starting to randomly evict them.
maxKnownTxs = 32768
// maxKnownBlocks is the maximum block hashes to keep in the known list
// before starting to randomly evict them.
maxKnownBlocks = 1024
// maxQueuedTxs is the maximum number of transactions to queue up before dropping
// older broadcasts.
maxQueuedTxs = 4096
// maxQueuedTxAnns is the maximum number of transaction announcements to queue up
// before dropping older announcements.
maxQueuedTxAnns = 4096
// maxQueuedBlocks is the maximum number of block propagations to queue up before
// dropping broadcasts. There's not much point in queueing stale blocks, so a few
// that might cover uncles should be enough.
maxQueuedBlocks = 4
// maxQueuedBlockAnns is the maximum number of block announcements to queue up before
// dropping broadcasts. Similarly to block propagations, there's no point to queue
// above some healthy uncle limit, so use that.
maxQueuedBlockAnns = 4
)
// max is a helper function which returns the larger of the two given integers.
func max(a, b int) int {
if a > b {
return a
}
return b
}
// Peer is a collection of relevant information we have about a `eth` peer.
type Peer struct {
id string // Unique ID for the peer, cached
*p2p.Peer // The embedded P2P package peer
rw p2p.MsgReadWriter // Input/output streams for snap
version uint // Protocol version negotiated
head common.Hash // Latest advertised head block hash
td *big.Int // Latest advertised head block total difficulty
knownBlocks mapset.Set // Set of block hashes known to be known by this peer
queuedBlocks chan *blockPropagation // Queue of blocks to broadcast to the peer
queuedBlockAnns chan *types.Block // Queue of blocks to announce to the peer
txpool TxPool // Transaction pool used by the broadcasters for liveness checks
knownTxs mapset.Set // Set of transaction hashes known to be known by this peer
txBroadcast chan []common.Hash // Channel used to queue transaction propagation requests
txAnnounce chan []common.Hash // Channel used to queue transaction announcement requests
term chan struct{} // Termination channel to stop the broadcasters
lock sync.RWMutex // Mutex protecting the internal fields
}
// NewPeer create a wrapper for a network connection and negotiated protocol
// version.
func NewPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter, txpool TxPool) *Peer {
peer := &Peer{
id: p.ID().String(),
Peer: p,
rw: rw,
version: version,
knownTxs: mapset.NewSet(),
knownBlocks: mapset.NewSet(),
queuedBlocks: make(chan *blockPropagation, maxQueuedBlocks),
queuedBlockAnns: make(chan *types.Block, maxQueuedBlockAnns),
txBroadcast: make(chan []common.Hash),
txAnnounce: make(chan []common.Hash),
txpool: txpool,
term: make(chan struct{}),
}
// Start up all the broadcasters
go peer.broadcastBlocks()
go peer.broadcastTransactions()
if version >= ETH65 {
go peer.announceTransactions()
}
return peer
}
// Close signals the broadcast goroutine to terminate. Only ever call this if
// you created the peer yourself via NewPeer. Otherwise let whoever created it
// clean it up!
func (p *Peer) Close() {
close(p.term)
}
// ID retrieves the peer's unique identifier.
func (p *Peer) ID() string {
return p.id
}
// Version retrieves the peer's negoatiated `eth` protocol version.
func (p *Peer) Version() uint {
return p.version
}
// Head retrieves the current head hash and total difficulty of the peer.
func (p *Peer) Head() (hash common.Hash, td *big.Int) {
p.lock.RLock()
defer p.lock.RUnlock()
copy(hash[:], p.head[:])
return hash, new(big.Int).Set(p.td)
}
// SetHead updates the head hash and total difficulty of the peer.
func (p *Peer) SetHead(hash common.Hash, td *big.Int) {
p.lock.Lock()
defer p.lock.Unlock()
copy(p.head[:], hash[:])
p.td.Set(td)
}
// KnownBlock returns whether peer is known to already have a block.
func (p *Peer) KnownBlock(hash common.Hash) bool {
return p.knownBlocks.Contains(hash)
}
// KnownTransaction returns whether peer is known to already have a transaction.
func (p *Peer) KnownTransaction(hash common.Hash) bool {
return p.knownTxs.Contains(hash)
}
// markBlock marks a block as known for the peer, ensuring that the block will
// never be propagated to this particular peer.
func (p *Peer) markBlock(hash common.Hash) {
// If we reached the memory allowance, drop a previously known block hash
for p.knownBlocks.Cardinality() >= maxKnownBlocks {
p.knownBlocks.Pop()
}
p.knownBlocks.Add(hash)
}
// markTransaction marks a transaction as known for the peer, ensuring that it
// will never be propagated to this particular peer.
func (p *Peer) markTransaction(hash common.Hash) {
// If we reached the memory allowance, drop a previously known transaction hash
for p.knownTxs.Cardinality() >= maxKnownTxs {
p.knownTxs.Pop()
}
p.knownTxs.Add(hash)
}
// SendTransactions sends transactions to the peer and includes the hashes
// in its transaction hash set for future reference.
//
// This method is a helper used by the async transaction sender. Don't call it
// directly as the queueing (memory) and transmission (bandwidth) costs should
// not be managed directly.
//
// The reasons this is public is to allow packages using this protocol to write
// tests that directly send messages without having to do the asyn queueing.
func (p *Peer) SendTransactions(txs types.Transactions) error {
// Mark all the transactions as known, but ensure we don't overflow our limits
for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(txs)) {
p.knownTxs.Pop()
}
for _, tx := range txs {
p.knownTxs.Add(tx.Hash())
}
return p2p.Send(p.rw, TransactionsMsg, txs)
}
// AsyncSendTransactions queues a list of transactions (by hash) to eventually
// propagate to a remote peer. The number of pending sends are capped (new ones
// will force old sends to be dropped)
func (p *Peer) AsyncSendTransactions(hashes []common.Hash) {
select {
case p.txBroadcast <- hashes:
// Mark all the transactions as known, but ensure we don't overflow our limits
for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
p.knownTxs.Pop()
}
for _, hash := range hashes {
p.knownTxs.Add(hash)
}
case <-p.term:
p.Log().Debug("Dropping transaction propagation", "count", len(hashes))
}
}
// sendPooledTransactionHashes sends transaction hashes to the peer and includes
// them in its transaction hash set for future reference.
//
// This method is a helper used by the async transaction announcer. Don't call it
// directly as the queueing (memory) and transmission (bandwidth) costs should
// not be managed directly.
func (p *Peer) sendPooledTransactionHashes(hashes []common.Hash) error {
// Mark all the transactions as known, but ensure we don't overflow our limits
for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
p.knownTxs.Pop()
}
for _, hash := range hashes {
p.knownTxs.Add(hash)
}
return p2p.Send(p.rw, NewPooledTransactionHashesMsg, NewPooledTransactionHashesPacket(hashes))
}
// AsyncSendPooledTransactionHashes queues a list of transactions hashes to eventually
// announce to a remote peer. The number of pending sends are capped (new ones
// will force old sends to be dropped)
func (p *Peer) AsyncSendPooledTransactionHashes(hashes []common.Hash) {
select {
case p.txAnnounce <- hashes:
// Mark all the transactions as known, but ensure we don't overflow our limits
for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
p.knownTxs.Pop()
}
for _, hash := range hashes {
p.knownTxs.Add(hash)
}
case <-p.term:
p.Log().Debug("Dropping transaction announcement", "count", len(hashes))
}
}
// SendPooledTransactionsRLP sends requested transactions to the peer and adds the
// hashes in its transaction hash set for future reference.
//
// Note, the method assumes the hashes are correct and correspond to the list of
// transactions being sent.
func (p *Peer) SendPooledTransactionsRLP(hashes []common.Hash, txs []rlp.RawValue) error {
// Mark all the transactions as known, but ensure we don't overflow our limits
for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
p.knownTxs.Pop()
}
for _, hash := range hashes {
p.knownTxs.Add(hash)
}
return p2p.Send(p.rw, PooledTransactionsMsg, txs) // Not packed into PooledTransactionsPacket to avoid RLP decoding
}
// SendNewBlockHashes announces the availability of a number of blocks through
// a hash notification.
func (p *Peer) SendNewBlockHashes(hashes []common.Hash, numbers []uint64) error {
// Mark all the block hashes as known, but ensure we don't overflow our limits
for p.knownBlocks.Cardinality() > max(0, maxKnownBlocks-len(hashes)) {
p.knownBlocks.Pop()
}
for _, hash := range hashes {
p.knownBlocks.Add(hash)
}
request := make(NewBlockHashesPacket, len(hashes))
for i := 0; i < len(hashes); i++ {
request[i].Hash = hashes[i]
request[i].Number = numbers[i]
}
return p2p.Send(p.rw, NewBlockHashesMsg, request)
}
// AsyncSendNewBlockHash queues the availability of a block for propagation to a
// remote peer. If the peer's broadcast queue is full, the event is silently
// dropped.
func (p *Peer) AsyncSendNewBlockHash(block *types.Block) {
select {
case p.queuedBlockAnns <- block:
// Mark all the block hash as known, but ensure we don't overflow our limits
for p.knownBlocks.Cardinality() >= maxKnownBlocks {
p.knownBlocks.Pop()
}
p.knownBlocks.Add(block.Hash())
default:
p.Log().Debug("Dropping block announcement", "number", block.NumberU64(), "hash", block.Hash())
}
}
// SendNewBlock propagates an entire block to a remote peer.
func (p *Peer) SendNewBlock(block *types.Block, td *big.Int) error {
// Mark all the block hash as known, but ensure we don't overflow our limits
for p.knownBlocks.Cardinality() >= maxKnownBlocks {
p.knownBlocks.Pop()
}
p.knownBlocks.Add(block.Hash())
return p2p.Send(p.rw, NewBlockMsg, &NewBlockPacket{block, td})
}
// AsyncSendNewBlock queues an entire block for propagation to a remote peer. If
// the peer's broadcast queue is full, the event is silently dropped.
func (p *Peer) AsyncSendNewBlock(block *types.Block, td *big.Int) {
select {
case p.queuedBlocks <- &blockPropagation{block: block, td: td}:
// Mark all the block hash as known, but ensure we don't overflow our limits
for p.knownBlocks.Cardinality() >= maxKnownBlocks {
p.knownBlocks.Pop()
}
p.knownBlocks.Add(block.Hash())
default:
p.Log().Debug("Dropping block propagation", "number", block.NumberU64(), "hash", block.Hash())
}
}
// SendBlockHeaders sends a batch of block headers to the remote peer.
func (p *Peer) SendBlockHeaders(headers []*types.Header) error {
return p2p.Send(p.rw, BlockHeadersMsg, BlockHeadersPacket(headers))
}
// SendBlockBodies sends a batch of block contents to the remote peer.
func (p *Peer) SendBlockBodies(bodies []*BlockBody) error {
return p2p.Send(p.rw, BlockBodiesMsg, BlockBodiesPacket(bodies))
}
// SendBlockBodiesRLP sends a batch of block contents to the remote peer from
// an already RLP encoded format.
func (p *Peer) SendBlockBodiesRLP(bodies []rlp.RawValue) error {
return p2p.Send(p.rw, BlockBodiesMsg, bodies) // Not packed into BlockBodiesPacket to avoid RLP decoding
}
// SendNodeDataRLP sends a batch of arbitrary internal data, corresponding to the
// hashes requested.
func (p *Peer) SendNodeData(data [][]byte) error {
return p2p.Send(p.rw, NodeDataMsg, NodeDataPacket(data))
}
// SendReceiptsRLP sends a batch of transaction receipts, corresponding to the
// ones requested from an already RLP encoded format.
func (p *Peer) SendReceiptsRLP(receipts []rlp.RawValue) error {
return p2p.Send(p.rw, ReceiptsMsg, receipts) // Not packed into ReceiptsPacket to avoid RLP decoding
}
// RequestOneHeader is a wrapper around the header query functions to fetch a
// single header. It is used solely by the fetcher.
func (p *Peer) RequestOneHeader(hash common.Hash) error {
p.Log().Debug("Fetching single header", "hash", hash)
return p2p.Send(p.rw, GetBlockHeadersMsg, &GetBlockHeadersPacket{
Origin: HashOrNumber{Hash: hash},
Amount: uint64(1),
Skip: uint64(0),
Reverse: false,
})
}
// RequestHeadersByHash fetches a batch of blocks' headers corresponding to the
// specified header query, based on the hash of an origin block.
func (p *Peer) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error {
p.Log().Debug("Fetching batch of headers", "count", amount, "fromhash", origin, "skip", skip, "reverse", reverse)
return p2p.Send(p.rw, GetBlockHeadersMsg, &GetBlockHeadersPacket{
Origin: HashOrNumber{Hash: origin},
Amount: uint64(amount),
Skip: uint64(skip),
Reverse: reverse,
})
}
// RequestHeadersByNumber fetches a batch of blocks' headers corresponding to the
// specified header query, based on the number of an origin block.
func (p *Peer) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error {
p.Log().Debug("Fetching batch of headers", "count", amount, "fromnum", origin, "skip", skip, "reverse", reverse)
return p2p.Send(p.rw, GetBlockHeadersMsg, &GetBlockHeadersPacket{
Origin: HashOrNumber{Number: origin},
Amount: uint64(amount),
Skip: uint64(skip),
Reverse: reverse,
})
}
// ExpectRequestHeadersByNumber is a testing method to mirror the recipient side
// of the RequestHeadersByNumber operation.
func (p *Peer) ExpectRequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error {
req := &GetBlockHeadersPacket{
Origin: HashOrNumber{Number: origin},
Amount: uint64(amount),
Skip: uint64(skip),
Reverse: reverse,
}
return p2p.ExpectMsg(p.rw, GetBlockHeadersMsg, req)
}
// RequestBodies fetches a batch of blocks' bodies corresponding to the hashes
// specified.
func (p *Peer) RequestBodies(hashes []common.Hash) error {
p.Log().Debug("Fetching batch of block bodies", "count", len(hashes))
return p2p.Send(p.rw, GetBlockBodiesMsg, GetBlockBodiesPacket(hashes))
}
// RequestNodeData fetches a batch of arbitrary data from a node's known state
// data, corresponding to the specified hashes.
func (p *Peer) RequestNodeData(hashes []common.Hash) error {
p.Log().Debug("Fetching batch of state data", "count", len(hashes))
return p2p.Send(p.rw, GetNodeDataMsg, GetNodeDataPacket(hashes))
}
// RequestReceipts fetches a batch of transaction receipts from a remote node.
func (p *Peer) RequestReceipts(hashes []common.Hash) error {
p.Log().Debug("Fetching batch of receipts", "count", len(hashes))
return p2p.Send(p.rw, GetReceiptsMsg, GetReceiptsPacket(hashes))
}
// RequestTxs fetches a batch of transactions from a remote node.
func (p *Peer) RequestTxs(hashes []common.Hash) error {
p.Log().Debug("Fetching batch of transactions", "count", len(hashes))
return p2p.Send(p.rw, GetPooledTransactionsMsg, GetPooledTransactionsPacket(hashes))
}

View File

@ -0,0 +1,61 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
// This file contains some shares testing functionality, common to multiple
// different files and modules being tested.
package eth
import (
"crypto/rand"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
)
// testPeer is a simulated peer to allow testing direct network calls.
type testPeer struct {
*Peer
net p2p.MsgReadWriter // Network layer reader/writer to simulate remote messaging
app *p2p.MsgPipeRW // Application layer reader/writer to simulate the local side
}
// newTestPeer creates a new peer registered at the given data backend.
func newTestPeer(name string, version uint, backend Backend) (*testPeer, <-chan error) {
// Create a message pipe to communicate through
app, net := p2p.MsgPipe()
// Start the peer on a new thread
var id enode.ID
rand.Read(id[:])
peer := NewPeer(version, p2p.NewPeer(id, name, nil), net, backend.TxPool())
errc := make(chan error, 1)
go func() {
errc <- backend.RunPeer(peer, func(peer *Peer) error {
return Handle(backend, peer)
})
}()
return &testPeer{app: app, net: net, Peer: peer}, errc
}
// close terminates the local side of the peer, notifying the remote protocol
// manager of termination.
func (p *testPeer) close() {
p.Peer.Close()
p.app.Close()
}

View File

@ -0,0 +1,279 @@
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package eth
import (
"errors"
"fmt"
"io"
"math/big"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/rlp"
)
// Constants to match up protocol versions and messages
const (
ETH64 = 64
ETH65 = 65
)
// protocolName is the official short name of the `eth` protocol used during
// devp2p capability negotiation.
const protocolName = "eth"
// protocolVersions are the supported versions of the `eth` protocol (first
// is primary).
var protocolVersions = []uint{ETH65, ETH64}
// protocolLengths are the number of implemented message corresponding to
// different protocol versions.
var protocolLengths = map[uint]uint64{ETH65: 17, ETH64: 17}
// maxMessageSize is the maximum cap on the size of a protocol message.
const maxMessageSize = 10 * 1024 * 1024
const (
// Protocol messages in eth/64
StatusMsg = 0x00
NewBlockHashesMsg = 0x01
TransactionsMsg = 0x02
GetBlockHeadersMsg = 0x03
BlockHeadersMsg = 0x04
GetBlockBodiesMsg = 0x05
BlockBodiesMsg = 0x06
NewBlockMsg = 0x07
GetNodeDataMsg = 0x0d
NodeDataMsg = 0x0e
GetReceiptsMsg = 0x0f
ReceiptsMsg = 0x10
// Protocol messages overloaded in eth/65
NewPooledTransactionHashesMsg = 0x08
GetPooledTransactionsMsg = 0x09
PooledTransactionsMsg = 0x0a
)
var (
errNoStatusMsg = errors.New("no status message")
errMsgTooLarge = errors.New("message too long")
errDecode = errors.New("invalid message")
errInvalidMsgCode = errors.New("invalid message code")
errProtocolVersionMismatch = errors.New("protocol version mismatch")
errNetworkIDMismatch = errors.New("network ID mismatch")
errGenesisMismatch = errors.New("genesis mismatch")
errForkIDRejected = errors.New("fork ID rejected")
errExtraStatusMsg = errors.New("extra status message")
)
// Packet represents a p2p message in the `eth` protocol.
type Packet interface {
Name() string // Name returns a string corresponding to the message type.
Kind() byte // Kind returns the message type.
}
// StatusPacket is the network packet for the status message for eth/64 and later.
type StatusPacket struct {
ProtocolVersion uint32
NetworkID uint64
TD *big.Int
Head common.Hash
Genesis common.Hash
ForkID forkid.ID
}
// NewBlockHashesPacket is the network packet for the block announcements.
type NewBlockHashesPacket []struct {
Hash common.Hash // Hash of one particular block being announced
Number uint64 // Number of one particular block being announced
}
// Unpack retrieves the block hashes and numbers from the announcement packet
// and returns them in a split flat format that's more consistent with the
// internal data structures.
func (p *NewBlockHashesPacket) Unpack() ([]common.Hash, []uint64) {
var (
hashes = make([]common.Hash, len(*p))
numbers = make([]uint64, len(*p))
)
for i, body := range *p {
hashes[i], numbers[i] = body.Hash, body.Number
}
return hashes, numbers
}
// TransactionsPacket is the network packet for broadcasting new transactions.
type TransactionsPacket []*types.Transaction
// GetBlockHeadersPacket represents a block header query.
type GetBlockHeadersPacket struct {
Origin HashOrNumber // Block from which to retrieve headers
Amount uint64 // Maximum number of headers to retrieve
Skip uint64 // Blocks to skip between consecutive headers
Reverse bool // Query direction (false = rising towards latest, true = falling towards genesis)
}
// HashOrNumber is a combined field for specifying an origin block.
type HashOrNumber struct {
Hash common.Hash // Block hash from which to retrieve headers (excludes Number)
Number uint64 // Block hash from which to retrieve headers (excludes Hash)
}
// EncodeRLP is a specialized encoder for HashOrNumber to encode only one of the
// two contained union fields.
func (hn *HashOrNumber) EncodeRLP(w io.Writer) error {
if hn.Hash == (common.Hash{}) {
return rlp.Encode(w, hn.Number)
}
if hn.Number != 0 {
return fmt.Errorf("both origin hash (%x) and number (%d) provided", hn.Hash, hn.Number)
}
return rlp.Encode(w, hn.Hash)
}
// DecodeRLP is a specialized decoder for HashOrNumber to decode the contents
// into either a block hash or a block number.
func (hn *HashOrNumber) DecodeRLP(s *rlp.Stream) error {
_, size, _ := s.Kind()
origin, err := s.Raw()
if err == nil {
switch {
case size == 32:
err = rlp.DecodeBytes(origin, &hn.Hash)
case size <= 8:
err = rlp.DecodeBytes(origin, &hn.Number)
default:
err = fmt.Errorf("invalid input size %d for origin", size)
}
}
return err
}
// BlockHeadersPacket represents a block header response.
type BlockHeadersPacket []*types.Header
// NewBlockPacket is the network packet for the block propagation message.
type NewBlockPacket struct {
Block *types.Block
TD *big.Int
}
// sanityCheck verifies that the values are reasonable, as a DoS protection
func (request *NewBlockPacket) sanityCheck() error {
if err := request.Block.SanityCheck(); err != nil {
return err
}
//TD at mainnet block #7753254 is 76 bits. If it becomes 100 million times
// larger, it will still fit within 100 bits
if tdlen := request.TD.BitLen(); tdlen > 100 {
return fmt.Errorf("too large block TD: bitlen %d", tdlen)
}
return nil
}
// GetBlockBodiesPacket represents a block body query.
type GetBlockBodiesPacket []common.Hash
// BlockBodiesPacket is the network packet for block content distribution.
type BlockBodiesPacket []*BlockBody
// BlockBody represents the data content of a single block.
type BlockBody struct {
Transactions []*types.Transaction // Transactions contained within a block
Uncles []*types.Header // Uncles contained within a block
}
// Unpack retrieves the transactions and uncles from the range packet and returns
// them in a split flat format that's more consistent with the internal data structures.
func (p *BlockBodiesPacket) Unpack() ([][]*types.Transaction, [][]*types.Header) {
var (
txset = make([][]*types.Transaction, len(*p))
uncleset = make([][]*types.Header, len(*p))
)
for i, body := range *p {
txset[i], uncleset[i] = body.Transactions, body.Uncles
}
return txset, uncleset
}
// GetNodeDataPacket represents a trie node data query.
type GetNodeDataPacket []common.Hash
// NodeDataPacket is the network packet for trie node data distribution.
type NodeDataPacket [][]byte
// GetReceiptsPacket represents a block receipts query.
type GetReceiptsPacket []common.Hash
// ReceiptsPacket is the network packet for block receipts distribution.
type ReceiptsPacket [][]*types.Receipt
// NewPooledTransactionHashesPacket represents a transaction announcement packet.
type NewPooledTransactionHashesPacket []common.Hash
// GetPooledTransactionsPacket represents a transaction query.
type GetPooledTransactionsPacket []common.Hash
// PooledTransactionsPacket is the network packet for transaction distribution.
type PooledTransactionsPacket []*types.Transaction
func (*StatusPacket) Name() string { return "Status" }
func (*StatusPacket) Kind() byte { return StatusMsg }
func (*NewBlockHashesPacket) Name() string { return "NewBlockHashes" }
func (*NewBlockHashesPacket) Kind() byte { return NewBlockHashesMsg }
func (*TransactionsPacket) Name() string { return "Transactions" }
func (*TransactionsPacket) Kind() byte { return TransactionsMsg }
func (*GetBlockHeadersPacket) Name() string { return "GetBlockHeaders" }
func (*GetBlockHeadersPacket) Kind() byte { return GetBlockHeadersMsg }
func (*BlockHeadersPacket) Name() string { return "BlockHeaders" }
func (*BlockHeadersPacket) Kind() byte { return BlockHeadersMsg }
func (*GetBlockBodiesPacket) Name() string { return "GetBlockBodies" }
func (*GetBlockBodiesPacket) Kind() byte { return GetBlockBodiesMsg }
func (*BlockBodiesPacket) Name() string { return "BlockBodies" }
func (*BlockBodiesPacket) Kind() byte { return BlockBodiesMsg }
func (*NewBlockPacket) Name() string { return "NewBlock" }
func (*NewBlockPacket) Kind() byte { return NewBlockMsg }
func (*GetNodeDataPacket) Name() string { return "GetNodeData" }
func (*GetNodeDataPacket) Kind() byte { return GetNodeDataMsg }
func (*NodeDataPacket) Name() string { return "NodeData" }
func (*NodeDataPacket) Kind() byte { return NodeDataMsg }
func (*GetReceiptsPacket) Name() string { return "GetReceipts" }
func (*GetReceiptsPacket) Kind() byte { return GetReceiptsMsg }
func (*ReceiptsPacket) Name() string { return "Receipts" }
func (*ReceiptsPacket) Kind() byte { return ReceiptsMsg }
func (*NewPooledTransactionHashesPacket) Name() string { return "NewPooledTransactionHashes" }
func (*NewPooledTransactionHashesPacket) Kind() byte { return NewPooledTransactionHashesMsg }
func (*GetPooledTransactionsPacket) Name() string { return "GetPooledTransactions" }
func (*GetPooledTransactionsPacket) Kind() byte { return GetPooledTransactionsMsg }
func (*PooledTransactionsPacket) Name() string { return "PooledTransactions" }
func (*PooledTransactionsPacket) Kind() byte { return PooledTransactionsMsg }

View File

@ -0,0 +1,68 @@
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package eth
import (
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp"
)
// Tests that the custom union field encoder and decoder works correctly.
func TestGetBlockHeadersDataEncodeDecode(t *testing.T) {
// Create a "random" hash for testing
var hash common.Hash
for i := range hash {
hash[i] = byte(i)
}
// Assemble some table driven tests
tests := []struct {
packet *GetBlockHeadersPacket
fail bool
}{
// Providing the origin as either a hash or a number should both work
{fail: false, packet: &GetBlockHeadersPacket{Origin: HashOrNumber{Number: 314}}},
{fail: false, packet: &GetBlockHeadersPacket{Origin: HashOrNumber{Hash: hash}}},
// Providing arbitrary query field should also work
{fail: false, packet: &GetBlockHeadersPacket{Origin: HashOrNumber{Number: 314}, Amount: 314, Skip: 1, Reverse: true}},
{fail: false, packet: &GetBlockHeadersPacket{Origin: HashOrNumber{Hash: hash}, Amount: 314, Skip: 1, Reverse: true}},
// Providing both the origin hash and origin number must fail
{fail: true, packet: &GetBlockHeadersPacket{Origin: HashOrNumber{Hash: hash, Number: 314}}},
}
// Iterate over each of the tests and try to encode and then decode
for i, tt := range tests {
bytes, err := rlp.EncodeToBytes(tt.packet)
if err != nil && !tt.fail {
t.Fatalf("test %d: failed to encode packet: %v", i, err)
} else if err == nil && tt.fail {
t.Fatalf("test %d: encode should have failed", i)
}
if !tt.fail {
packet := new(GetBlockHeadersPacket)
if err := rlp.DecodeBytes(bytes, packet); err != nil {
t.Fatalf("test %d: failed to decode packet: %v", i, err)
}
if packet.Origin.Hash != tt.packet.Origin.Hash || packet.Origin.Number != tt.packet.Origin.Number || packet.Amount != tt.packet.Amount ||
packet.Skip != tt.packet.Skip || packet.Reverse != tt.packet.Reverse {
t.Fatalf("test %d: encode decode mismatch: have %+v, want %+v", i, packet, tt.packet)
}
}
}
}

View File

@ -0,0 +1,32 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package snap
import (
"github.com/ethereum/go-ethereum/rlp"
)
// enrEntry is the ENR entry which advertises `snap` protocol on the discovery.
type enrEntry struct {
// Ignore additional fields (for forward compatibility).
Rest []rlp.RawValue `rlp:"tail"`
}
// ENRKey implements enr.Entry.
func (e enrEntry) ENRKey() string {
return "snap"
}

View File

@ -0,0 +1,490 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package snap
import (
"bytes"
"fmt"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
)
const (
// softResponseLimit is the target maximum size of replies to data retrievals.
softResponseLimit = 2 * 1024 * 1024
// maxCodeLookups is the maximum number of bytecodes to serve. This number is
// there to limit the number of disk lookups.
maxCodeLookups = 1024
// stateLookupSlack defines the ratio by how much a state response can exceed
// the requested limit in order to try and avoid breaking up contracts into
// multiple packages and proving them.
stateLookupSlack = 0.1
// maxTrieNodeLookups is the maximum number of state trie nodes to serve. This
// number is there to limit the number of disk lookups.
maxTrieNodeLookups = 1024
)
// Handler is a callback to invoke from an outside runner after the boilerplate
// exchanges have passed.
type Handler func(peer *Peer) error
// Backend defines the data retrieval methods to serve remote requests and the
// callback methods to invoke on remote deliveries.
type Backend interface {
// Chain retrieves the blockchain object to serve data.
Chain() *core.BlockChain
// RunPeer is invoked when a peer joins on the `eth` protocol. The handler
// should do any peer maintenance work, handshakes and validations. If all
// is passed, control should be given back to the `handler` to process the
// inbound messages going forward.
RunPeer(peer *Peer, handler Handler) error
// PeerInfo retrieves all known `snap` information about a peer.
PeerInfo(id enode.ID) interface{}
// Handle is a callback to be invoked when a data packet is received from
// the remote peer. Only packets not consumed by the protocol handler will
// be forwarded to the backend.
Handle(peer *Peer, packet Packet) error
}
// MakeProtocols constructs the P2P protocol definitions for `snap`.
func MakeProtocols(backend Backend, dnsdisc enode.Iterator) []p2p.Protocol {
protocols := make([]p2p.Protocol, len(protocolVersions))
for i, version := range protocolVersions {
version := version // Closure
protocols[i] = p2p.Protocol{
Name: protocolName,
Version: version,
Length: protocolLengths[version],
Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
return backend.RunPeer(newPeer(version, p, rw), func(peer *Peer) error {
return handle(backend, peer)
})
},
NodeInfo: func() interface{} {
return nodeInfo(backend.Chain())
},
PeerInfo: func(id enode.ID) interface{} {
return backend.PeerInfo(id)
},
Attributes: []enr.Entry{&enrEntry{}},
DialCandidates: dnsdisc,
}
}
return protocols
}
// handle is the callback invoked to manage the life cycle of a `snap` peer.
// When this function terminates, the peer is disconnected.
func handle(backend Backend, peer *Peer) error {
for {
if err := handleMessage(backend, peer); err != nil {
peer.Log().Debug("Message handling failed in `snap`", "err", err)
return err
}
}
}
// handleMessage is invoked whenever an inbound message is received from a
// remote peer on the `spap` protocol. The remote connection is torn down upon
// returning any error.
func handleMessage(backend Backend, peer *Peer) error {
// Read the next message from the remote peer, and ensure it's fully consumed
msg, err := peer.rw.ReadMsg()
if err != nil {
return err
}
if msg.Size > maxMessageSize {
return fmt.Errorf("%w: %v > %v", errMsgTooLarge, msg.Size, maxMessageSize)
}
defer msg.Discard()
// Handle the message depending on its contents
switch {
case msg.Code == GetAccountRangeMsg:
// Decode the account retrieval request
var req GetAccountRangePacket
if err := msg.Decode(&req); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
if req.Bytes > softResponseLimit {
req.Bytes = softResponseLimit
}
// Retrieve the requested state and bail out if non existent
tr, err := trie.New(req.Root, backend.Chain().StateCache().TrieDB())
if err != nil {
return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{ID: req.ID})
}
it, err := backend.Chain().Snapshots().AccountIterator(req.Root, req.Origin)
if err != nil {
return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{ID: req.ID})
}
// Iterate over the requested range and pile accounts up
var (
accounts []*AccountData
size uint64
last common.Hash
)
for it.Next() && size < req.Bytes {
hash, account := it.Hash(), common.CopyBytes(it.Account())
// Track the returned interval for the Merkle proofs
last = hash
// Assemble the reply item
size += uint64(common.HashLength + len(account))
accounts = append(accounts, &AccountData{
Hash: hash,
Body: account,
})
// If we've exceeded the request threshold, abort
if bytes.Compare(hash[:], req.Limit[:]) >= 0 {
break
}
}
it.Release()
// Generate the Merkle proofs for the first and last account
proof := light.NewNodeSet()
if err := tr.Prove(req.Origin[:], 0, proof); err != nil {
log.Warn("Failed to prove account range", "origin", req.Origin, "err", err)
return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{ID: req.ID})
}
if last != (common.Hash{}) {
if err := tr.Prove(last[:], 0, proof); err != nil {
log.Warn("Failed to prove account range", "last", last, "err", err)
return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{ID: req.ID})
}
}
var proofs [][]byte
for _, blob := range proof.NodeList() {
proofs = append(proofs, blob)
}
// Send back anything accumulated
return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{
ID: req.ID,
Accounts: accounts,
Proof: proofs,
})
case msg.Code == AccountRangeMsg:
// A range of accounts arrived to one of our previous requests
res := new(AccountRangePacket)
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
// Ensure the range is monotonically increasing
for i := 1; i < len(res.Accounts); i++ {
if bytes.Compare(res.Accounts[i-1].Hash[:], res.Accounts[i].Hash[:]) >= 0 {
return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, res.Accounts[i-1].Hash[:], i, res.Accounts[i].Hash[:])
}
}
return backend.Handle(peer, res)
case msg.Code == GetStorageRangesMsg:
// Decode the storage retrieval request
var req GetStorageRangesPacket
if err := msg.Decode(&req); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
if req.Bytes > softResponseLimit {
req.Bytes = softResponseLimit
}
// TODO(karalabe): Do we want to enforce > 0 accounts and 1 account if origin is set?
// TODO(karalabe): - Logging locally is not ideal as remote faulst annoy the local user
// TODO(karalabe): - Dropping the remote peer is less flexible wrt client bugs (slow is better than non-functional)
// Calculate the hard limit at which to abort, even if mid storage trie
hardLimit := uint64(float64(req.Bytes) * (1 + stateLookupSlack))
// Retrieve storage ranges until the packet limit is reached
var (
slots [][]*StorageData
proofs [][]byte
size uint64
)
for _, account := range req.Accounts {
// If we've exceeded the requested data limit, abort without opening
// a new storage range (that we'd need to prove due to exceeded size)
if size >= req.Bytes {
break
}
// The first account might start from a different origin and end sooner
var origin common.Hash
if len(req.Origin) > 0 {
origin, req.Origin = common.BytesToHash(req.Origin), nil
}
var limit = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
if len(req.Limit) > 0 {
limit, req.Limit = common.BytesToHash(req.Limit), nil
}
// Retrieve the requested state and bail out if non existent
it, err := backend.Chain().Snapshots().StorageIterator(req.Root, account, origin)
if err != nil {
return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
}
// Iterate over the requested range and pile slots up
var (
storage []*StorageData
last common.Hash
)
for it.Next() && size < hardLimit {
hash, slot := it.Hash(), common.CopyBytes(it.Slot())
// Track the returned interval for the Merkle proofs
last = hash
// Assemble the reply item
size += uint64(common.HashLength + len(slot))
storage = append(storage, &StorageData{
Hash: hash,
Body: slot,
})
// If we've exceeded the request threshold, abort
if bytes.Compare(hash[:], limit[:]) >= 0 {
break
}
}
slots = append(slots, storage)
it.Release()
// Generate the Merkle proofs for the first and last storage slot, but
// only if the response was capped. If the entire storage trie included
// in the response, no need for any proofs.
if origin != (common.Hash{}) || size >= hardLimit {
// Request started at a non-zero hash or was capped prematurely, add
// the endpoint Merkle proofs
accTrie, err := trie.New(req.Root, backend.Chain().StateCache().TrieDB())
if err != nil {
return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
}
var acc state.Account
if err := rlp.DecodeBytes(accTrie.Get(account[:]), &acc); err != nil {
return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
}
stTrie, err := trie.New(acc.Root, backend.Chain().StateCache().TrieDB())
if err != nil {
return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
}
proof := light.NewNodeSet()
if err := stTrie.Prove(origin[:], 0, proof); err != nil {
log.Warn("Failed to prove storage range", "origin", req.Origin, "err", err)
return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
}
if last != (common.Hash{}) {
if err := stTrie.Prove(last[:], 0, proof); err != nil {
log.Warn("Failed to prove storage range", "last", last, "err", err)
return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
}
}
for _, blob := range proof.NodeList() {
proofs = append(proofs, blob)
}
// Proof terminates the reply as proofs are only added if a node
// refuses to serve more data (exception when a contract fetch is
// finishing, but that's that).
break
}
}
// Send back anything accumulated
return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{
ID: req.ID,
Slots: slots,
Proof: proofs,
})
case msg.Code == StorageRangesMsg:
// A range of storage slots arrived to one of our previous requests
res := new(StorageRangesPacket)
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
// Ensure the ranges ae monotonically increasing
for i, slots := range res.Slots {
for j := 1; j < len(slots); j++ {
if bytes.Compare(slots[j-1].Hash[:], slots[j].Hash[:]) >= 0 {
return fmt.Errorf("storage slots not monotonically increasing for account #%d: #%d [%x] vs #%d [%x]", i, j-1, slots[j-1].Hash[:], j, slots[j].Hash[:])
}
}
}
return backend.Handle(peer, res)
case msg.Code == GetByteCodesMsg:
// Decode bytecode retrieval request
var req GetByteCodesPacket
if err := msg.Decode(&req); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
if req.Bytes > softResponseLimit {
req.Bytes = softResponseLimit
}
if len(req.Hashes) > maxCodeLookups {
req.Hashes = req.Hashes[:maxCodeLookups]
}
// Retrieve bytecodes until the packet size limit is reached
var (
codes [][]byte
bytes uint64
)
for _, hash := range req.Hashes {
if hash == emptyCode {
// Peers should not request the empty code, but if they do, at
// least sent them back a correct response without db lookups
codes = append(codes, []byte{})
} else if blob, err := backend.Chain().ContractCode(hash); err == nil {
codes = append(codes, blob)
bytes += uint64(len(blob))
}
if bytes > req.Bytes {
break
}
}
// Send back anything accumulated
return p2p.Send(peer.rw, ByteCodesMsg, &ByteCodesPacket{
ID: req.ID,
Codes: codes,
})
case msg.Code == ByteCodesMsg:
// A batch of byte codes arrived to one of our previous requests
res := new(ByteCodesPacket)
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
return backend.Handle(peer, res)
case msg.Code == GetTrieNodesMsg:
// Decode trie node retrieval request
var req GetTrieNodesPacket
if err := msg.Decode(&req); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
if req.Bytes > softResponseLimit {
req.Bytes = softResponseLimit
}
// Make sure we have the state associated with the request
triedb := backend.Chain().StateCache().TrieDB()
accTrie, err := trie.NewSecure(req.Root, triedb)
if err != nil {
// We don't have the requested state available, bail out
return p2p.Send(peer.rw, TrieNodesMsg, &TrieNodesPacket{ID: req.ID})
}
snap := backend.Chain().Snapshots().Snapshot(req.Root)
if snap == nil {
// We don't have the requested state snapshotted yet, bail out.
// In reality we could still serve using the account and storage
// tries only, but let's protect the node a bit while it's doing
// snapshot generation.
return p2p.Send(peer.rw, TrieNodesMsg, &TrieNodesPacket{ID: req.ID})
}
// Retrieve trie nodes until the packet size limit is reached
var (
nodes [][]byte
bytes uint64
loads int // Trie hash expansions to cound database reads
)
for _, pathset := range req.Paths {
switch len(pathset) {
case 0:
// Ensure we penalize invalid requests
return fmt.Errorf("%w: zero-item pathset requested", errBadRequest)
case 1:
// If we're only retrieving an account trie node, fetch it directly
blob, resolved, err := accTrie.TryGetNode(pathset[0])
loads += resolved // always account database reads, even for failures
if err != nil {
break
}
nodes = append(nodes, blob)
bytes += uint64(len(blob))
default:
// Storage slots requested, open the storage trie and retrieve from there
account, err := snap.Account(common.BytesToHash(pathset[0]))
loads++ // always account database reads, even for failures
if err != nil {
break
}
stTrie, err := trie.NewSecure(common.BytesToHash(account.Root), triedb)
loads++ // always account database reads, even for failures
if err != nil {
break
}
for _, path := range pathset[1:] {
blob, resolved, err := stTrie.TryGetNode(path)
loads += resolved // always account database reads, even for failures
if err != nil {
break
}
nodes = append(nodes, blob)
bytes += uint64(len(blob))
// Sanity check limits to avoid DoS on the store trie loads
if bytes > req.Bytes || loads > maxTrieNodeLookups {
break
}
}
}
// Abort request processing if we've exceeded our limits
if bytes > req.Bytes || loads > maxTrieNodeLookups {
break
}
}
// Send back anything accumulated
return p2p.Send(peer.rw, TrieNodesMsg, &TrieNodesPacket{
ID: req.ID,
Nodes: nodes,
})
case msg.Code == TrieNodesMsg:
// A batch of trie nodes arrived to one of our previous requests
res := new(TrieNodesPacket)
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
return backend.Handle(peer, res)
default:
return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code)
}
}
// NodeInfo represents a short summary of the `snap` sub-protocol metadata
// known about the host peer.
type NodeInfo struct{}
// nodeInfo retrieves some `snap` protocol metadata about the running host node.
func nodeInfo(chain *core.BlockChain) *NodeInfo {
return &NodeInfo{}
}

111
eth/protocols/snap/peer.go Normal file
View File

@ -0,0 +1,111 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package snap
import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p"
)
// Peer is a collection of relevant information we have about a `snap` peer.
type Peer struct {
id string // Unique ID for the peer, cached
*p2p.Peer // The embedded P2P package peer
rw p2p.MsgReadWriter // Input/output streams for snap
version uint // Protocol version negotiated
logger log.Logger // Contextual logger with the peer id injected
}
// newPeer create a wrapper for a network connection and negotiated protocol
// version.
func newPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) *Peer {
id := p.ID().String()
return &Peer{
id: id,
Peer: p,
rw: rw,
version: version,
logger: log.New("peer", id[:8]),
}
}
// ID retrieves the peer's unique identifier.
func (p *Peer) ID() string {
return p.id
}
// Version retrieves the peer's negoatiated `snap` protocol version.
func (p *Peer) Version() uint {
return p.version
}
// RequestAccountRange fetches a batch of accounts rooted in a specific account
// trie, starting with the origin.
func (p *Peer) RequestAccountRange(id uint64, root common.Hash, origin, limit common.Hash, bytes uint64) error {
p.logger.Trace("Fetching range of accounts", "reqid", id, "root", root, "origin", origin, "limit", limit, "bytes", common.StorageSize(bytes))
return p2p.Send(p.rw, GetAccountRangeMsg, &GetAccountRangePacket{
ID: id,
Root: root,
Origin: origin,
Limit: limit,
Bytes: bytes,
})
}
// RequestStorageRange fetches a batch of storage slots belonging to one or more
// accounts. If slots from only one accout is requested, an origin marker may also
// be used to retrieve from there.
func (p *Peer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error {
if len(accounts) == 1 && origin != nil {
p.logger.Trace("Fetching range of large storage slots", "reqid", id, "root", root, "account", accounts[0], "origin", common.BytesToHash(origin), "limit", common.BytesToHash(limit), "bytes", common.StorageSize(bytes))
} else {
p.logger.Trace("Fetching ranges of small storage slots", "reqid", id, "root", root, "accounts", len(accounts), "first", accounts[0], "bytes", common.StorageSize(bytes))
}
return p2p.Send(p.rw, GetStorageRangesMsg, &GetStorageRangesPacket{
ID: id,
Root: root,
Accounts: accounts,
Origin: origin,
Limit: limit,
Bytes: bytes,
})
}
// RequestByteCodes fetches a batch of bytecodes by hash.
func (p *Peer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error {
p.logger.Trace("Fetching set of byte codes", "reqid", id, "hashes", len(hashes), "bytes", common.StorageSize(bytes))
return p2p.Send(p.rw, GetByteCodesMsg, &GetByteCodesPacket{
ID: id,
Hashes: hashes,
Bytes: bytes,
})
}
// RequestTrieNodes fetches a batch of account or storage trie nodes rooted in
// a specificstate trie.
func (p *Peer) RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error {
p.logger.Trace("Fetching set of trie nodes", "reqid", id, "root", root, "pathsets", len(paths), "bytes", common.StorageSize(bytes))
return p2p.Send(p.rw, GetTrieNodesMsg, &GetTrieNodesPacket{
ID: id,
Root: root,
Paths: paths,
Bytes: bytes,
})
}

View File

@ -0,0 +1,218 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package snap
import (
"errors"
"fmt"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state/snapshot"
"github.com/ethereum/go-ethereum/rlp"
)
// Constants to match up protocol versions and messages
const (
snap1 = 1
)
// protocolName is the official short name of the `snap` protocol used during
// devp2p capability negotiation.
const protocolName = "snap"
// protocolVersions are the supported versions of the `snap` protocol (first
// is primary).
var protocolVersions = []uint{snap1}
// protocolLengths are the number of implemented message corresponding to
// different protocol versions.
var protocolLengths = map[uint]uint64{snap1: 8}
// maxMessageSize is the maximum cap on the size of a protocol message.
const maxMessageSize = 10 * 1024 * 1024
const (
GetAccountRangeMsg = 0x00
AccountRangeMsg = 0x01
GetStorageRangesMsg = 0x02
StorageRangesMsg = 0x03
GetByteCodesMsg = 0x04
ByteCodesMsg = 0x05
GetTrieNodesMsg = 0x06
TrieNodesMsg = 0x07
)
var (
errMsgTooLarge = errors.New("message too long")
errDecode = errors.New("invalid message")
errInvalidMsgCode = errors.New("invalid message code")
errBadRequest = errors.New("bad request")
)
// Packet represents a p2p message in the `snap` protocol.
type Packet interface {
Name() string // Name returns a string corresponding to the message type.
Kind() byte // Kind returns the message type.
}
// GetAccountRangePacket represents an account query.
type GetAccountRangePacket struct {
ID uint64 // Request ID to match up responses with
Root common.Hash // Root hash of the account trie to serve
Origin common.Hash // Hash of the first account to retrieve
Limit common.Hash // Hash of the last account to retrieve
Bytes uint64 // Soft limit at which to stop returning data
}
// AccountRangePacket represents an account query response.
type AccountRangePacket struct {
ID uint64 // ID of the request this is a response for
Accounts []*AccountData // List of consecutive accounts from the trie
Proof [][]byte // List of trie nodes proving the account range
}
// AccountData represents a single account in a query response.
type AccountData struct {
Hash common.Hash // Hash of the account
Body rlp.RawValue // Account body in slim format
}
// Unpack retrieves the accounts from the range packet and converts from slim
// wire representation to consensus format. The returned data is RLP encoded
// since it's expected to be serialized to disk without further interpretation.
//
// Note, this method does a round of RLP decoding and reencoding, so only use it
// once and cache the results if need be. Ideally discard the packet afterwards
// to not double the memory use.
func (p *AccountRangePacket) Unpack() ([]common.Hash, [][]byte, error) {
var (
hashes = make([]common.Hash, len(p.Accounts))
accounts = make([][]byte, len(p.Accounts))
)
for i, acc := range p.Accounts {
val, err := snapshot.FullAccountRLP(acc.Body)
if err != nil {
return nil, nil, fmt.Errorf("invalid account %x: %v", acc.Body, err)
}
hashes[i], accounts[i] = acc.Hash, val
}
return hashes, accounts, nil
}
// GetStorageRangesPacket represents an storage slot query.
type GetStorageRangesPacket struct {
ID uint64 // Request ID to match up responses with
Root common.Hash // Root hash of the account trie to serve
Accounts []common.Hash // Account hashes of the storage tries to serve
Origin []byte // Hash of the first storage slot to retrieve (large contract mode)
Limit []byte // Hash of the last storage slot to retrieve (large contract mode)
Bytes uint64 // Soft limit at which to stop returning data
}
// StorageRangesPacket represents a storage slot query response.
type StorageRangesPacket struct {
ID uint64 // ID of the request this is a response for
Slots [][]*StorageData // Lists of consecutive storage slots for the requested accounts
Proof [][]byte // Merkle proofs for the *last* slot range, if it's incomplete
}
// StorageData represents a single storage slot in a query response.
type StorageData struct {
Hash common.Hash // Hash of the storage slot
Body []byte // Data content of the slot
}
// Unpack retrieves the storage slots from the range packet and returns them in
// a split flat format that's more consistent with the internal data structures.
func (p *StorageRangesPacket) Unpack() ([][]common.Hash, [][][]byte) {
var (
hashset = make([][]common.Hash, len(p.Slots))
slotset = make([][][]byte, len(p.Slots))
)
for i, slots := range p.Slots {
hashset[i] = make([]common.Hash, len(slots))
slotset[i] = make([][]byte, len(slots))
for j, slot := range slots {
hashset[i][j] = slot.Hash
slotset[i][j] = slot.Body
}
}
return hashset, slotset
}
// GetByteCodesPacket represents a contract bytecode query.
type GetByteCodesPacket struct {
ID uint64 // Request ID to match up responses with
Hashes []common.Hash // Code hashes to retrieve the code for
Bytes uint64 // Soft limit at which to stop returning data
}
// ByteCodesPacket represents a contract bytecode query response.
type ByteCodesPacket struct {
ID uint64 // ID of the request this is a response for
Codes [][]byte // Requested contract bytecodes
}
// GetTrieNodesPacket represents a state trie node query.
type GetTrieNodesPacket struct {
ID uint64 // Request ID to match up responses with
Root common.Hash // Root hash of the account trie to serve
Paths []TrieNodePathSet // Trie node hashes to retrieve the nodes for
Bytes uint64 // Soft limit at which to stop returning data
}
// TrieNodePathSet is a list of trie node paths to retrieve. A naive way to
// represent trie nodes would be a simple list of `account || storage` path
// segments concatenated, but that would be very wasteful on the network.
//
// Instead, this array special cases the first element as the path in the
// account trie and the remaining elements as paths in the storage trie. To
// address an account node, the slice should have a length of 1 consisting
// of only the account path. There's no need to be able to address both an
// account node and a storage node in the same request as it cannot happen
// that a slot is accessed before the account path is fully expanded.
type TrieNodePathSet [][]byte
// TrieNodesPacket represents a state trie node query response.
type TrieNodesPacket struct {
ID uint64 // ID of the request this is a response for
Nodes [][]byte // Requested state trie nodes
}
func (*GetAccountRangePacket) Name() string { return "GetAccountRange" }
func (*GetAccountRangePacket) Kind() byte { return GetAccountRangeMsg }
func (*AccountRangePacket) Name() string { return "AccountRange" }
func (*AccountRangePacket) Kind() byte { return AccountRangeMsg }
func (*GetStorageRangesPacket) Name() string { return "GetStorageRanges" }
func (*GetStorageRangesPacket) Kind() byte { return GetStorageRangesMsg }
func (*StorageRangesPacket) Name() string { return "StorageRanges" }
func (*StorageRangesPacket) Kind() byte { return StorageRangesMsg }
func (*GetByteCodesPacket) Name() string { return "GetByteCodes" }
func (*GetByteCodesPacket) Kind() byte { return GetByteCodesMsg }
func (*ByteCodesPacket) Name() string { return "ByteCodes" }
func (*ByteCodesPacket) Kind() byte { return ByteCodesMsg }
func (*GetTrieNodesPacket) Name() string { return "GetTrieNodes" }
func (*GetTrieNodesPacket) Kind() byte { return GetTrieNodesMsg }
func (*TrieNodesPacket) Name() string { return "TrieNodes" }
func (*TrieNodesPacket) Kind() byte { return TrieNodesMsg }

2481
eth/protocols/snap/sync.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -26,6 +26,7 @@ import (
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
) )
@ -40,12 +41,12 @@ const (
) )
type txsync struct { type txsync struct {
p *peer p *eth.Peer
txs []*types.Transaction txs []*types.Transaction
} }
// syncTransactions starts sending all currently pending transactions to the given peer. // syncTransactions starts sending all currently pending transactions to the given peer.
func (pm *ProtocolManager) syncTransactions(p *peer) { func (h *handler) syncTransactions(p *eth.Peer) {
// Assemble the set of transaction to broadcast or announce to the remote // Assemble the set of transaction to broadcast or announce to the remote
// peer. Fun fact, this is quite an expensive operation as it needs to sort // peer. Fun fact, this is quite an expensive operation as it needs to sort
// the transactions if the sorting is not cached yet. However, with a random // the transactions if the sorting is not cached yet. However, with a random
@ -53,7 +54,7 @@ func (pm *ProtocolManager) syncTransactions(p *peer) {
// //
// TODO(karalabe): Figure out if we could get away with random order somehow // TODO(karalabe): Figure out if we could get away with random order somehow
var txs types.Transactions var txs types.Transactions
pending, _ := pm.txpool.Pending() pending, _ := h.txpool.Pending()
for _, batch := range pending { for _, batch := range pending {
txs = append(txs, batch...) txs = append(txs, batch...)
} }
@ -63,7 +64,7 @@ func (pm *ProtocolManager) syncTransactions(p *peer) {
// The eth/65 protocol introduces proper transaction announcements, so instead // The eth/65 protocol introduces proper transaction announcements, so instead
// of dripping transactions across multiple peers, just send the entire list as // of dripping transactions across multiple peers, just send the entire list as
// an announcement and let the remote side decide what they need (likely nothing). // an announcement and let the remote side decide what they need (likely nothing).
if p.version >= eth65 { if p.Version() >= eth.ETH65 {
hashes := make([]common.Hash, len(txs)) hashes := make([]common.Hash, len(txs))
for i, tx := range txs { for i, tx := range txs {
hashes[i] = tx.Hash() hashes[i] = tx.Hash()
@ -73,8 +74,8 @@ func (pm *ProtocolManager) syncTransactions(p *peer) {
} }
// Out of luck, peer is running legacy protocols, drop the txs over // Out of luck, peer is running legacy protocols, drop the txs over
select { select {
case pm.txsyncCh <- &txsync{p: p, txs: txs}: case h.txsyncCh <- &txsync{p: p, txs: txs}:
case <-pm.quitSync: case <-h.quitSync:
} }
} }
@ -82,8 +83,8 @@ func (pm *ProtocolManager) syncTransactions(p *peer) {
// connection. When a new peer appears, we relay all currently pending // connection. When a new peer appears, we relay all currently pending
// transactions. In order to minimise egress bandwidth usage, we send // transactions. In order to minimise egress bandwidth usage, we send
// the transactions in small packs to one peer at a time. // the transactions in small packs to one peer at a time.
func (pm *ProtocolManager) txsyncLoop64() { func (h *handler) txsyncLoop64() {
defer pm.wg.Done() defer h.wg.Done()
var ( var (
pending = make(map[enode.ID]*txsync) pending = make(map[enode.ID]*txsync)
@ -94,7 +95,7 @@ func (pm *ProtocolManager) txsyncLoop64() {
// send starts a sending a pack of transactions from the sync. // send starts a sending a pack of transactions from the sync.
send := func(s *txsync) { send := func(s *txsync) {
if s.p.version >= eth65 { if s.p.Version() >= eth.ETH65 {
panic("initial transaction syncer running on eth/65+") panic("initial transaction syncer running on eth/65+")
} }
// Fill pack with transactions up to the target size. // Fill pack with transactions up to the target size.
@ -108,14 +109,13 @@ func (pm *ProtocolManager) txsyncLoop64() {
// Remove the transactions that will be sent. // Remove the transactions that will be sent.
s.txs = s.txs[:copy(s.txs, s.txs[len(pack.txs):])] s.txs = s.txs[:copy(s.txs, s.txs[len(pack.txs):])]
if len(s.txs) == 0 { if len(s.txs) == 0 {
delete(pending, s.p.ID()) delete(pending, s.p.Peer.ID())
} }
// Send the pack in the background. // Send the pack in the background.
s.p.Log().Trace("Sending batch of transactions", "count", len(pack.txs), "bytes", size) s.p.Log().Trace("Sending batch of transactions", "count", len(pack.txs), "bytes", size)
sending = true sending = true
go func() { done <- pack.p.SendTransactions64(pack.txs) }() go func() { done <- pack.p.SendTransactions(pack.txs) }()
} }
// pick chooses the next pending sync. // pick chooses the next pending sync.
pick := func() *txsync { pick := func() *txsync {
if len(pending) == 0 { if len(pending) == 0 {
@ -132,8 +132,8 @@ func (pm *ProtocolManager) txsyncLoop64() {
for { for {
select { select {
case s := <-pm.txsyncCh: case s := <-h.txsyncCh:
pending[s.p.ID()] = s pending[s.p.Peer.ID()] = s
if !sending { if !sending {
send(s) send(s)
} }
@ -142,13 +142,13 @@ func (pm *ProtocolManager) txsyncLoop64() {
// Stop tracking peers that cause send failures. // Stop tracking peers that cause send failures.
if err != nil { if err != nil {
pack.p.Log().Debug("Transaction send failed", "err", err) pack.p.Log().Debug("Transaction send failed", "err", err)
delete(pending, pack.p.ID()) delete(pending, pack.p.Peer.ID())
} }
// Schedule the next send. // Schedule the next send.
if s := pick(); s != nil { if s := pick(); s != nil {
send(s) send(s)
} }
case <-pm.quitSync: case <-h.quitSync:
return return
} }
} }
@ -156,7 +156,7 @@ func (pm *ProtocolManager) txsyncLoop64() {
// chainSyncer coordinates blockchain sync components. // chainSyncer coordinates blockchain sync components.
type chainSyncer struct { type chainSyncer struct {
pm *ProtocolManager handler *handler
force *time.Timer force *time.Timer
forced bool // true when force timer fired forced bool // true when force timer fired
peerEventCh chan struct{} peerEventCh chan struct{}
@ -166,15 +166,15 @@ type chainSyncer struct {
// chainSyncOp is a scheduled sync operation. // chainSyncOp is a scheduled sync operation.
type chainSyncOp struct { type chainSyncOp struct {
mode downloader.SyncMode mode downloader.SyncMode
peer *peer peer *eth.Peer
td *big.Int td *big.Int
head common.Hash head common.Hash
} }
// newChainSyncer creates a chainSyncer. // newChainSyncer creates a chainSyncer.
func newChainSyncer(pm *ProtocolManager) *chainSyncer { func newChainSyncer(handler *handler) *chainSyncer {
return &chainSyncer{ return &chainSyncer{
pm: pm, handler: handler,
peerEventCh: make(chan struct{}), peerEventCh: make(chan struct{}),
} }
} }
@ -182,23 +182,24 @@ func newChainSyncer(pm *ProtocolManager) *chainSyncer {
// handlePeerEvent notifies the syncer about a change in the peer set. // handlePeerEvent notifies the syncer about a change in the peer set.
// This is called for new peers and every time a peer announces a new // This is called for new peers and every time a peer announces a new
// chain head. // chain head.
func (cs *chainSyncer) handlePeerEvent(p *peer) bool { func (cs *chainSyncer) handlePeerEvent(peer *eth.Peer) bool {
select { select {
case cs.peerEventCh <- struct{}{}: case cs.peerEventCh <- struct{}{}:
return true return true
case <-cs.pm.quitSync: case <-cs.handler.quitSync:
return false return false
} }
} }
// loop runs in its own goroutine and launches the sync when necessary. // loop runs in its own goroutine and launches the sync when necessary.
func (cs *chainSyncer) loop() { func (cs *chainSyncer) loop() {
defer cs.pm.wg.Done() defer cs.handler.wg.Done()
cs.pm.blockFetcher.Start() cs.handler.blockFetcher.Start()
cs.pm.txFetcher.Start() cs.handler.txFetcher.Start()
defer cs.pm.blockFetcher.Stop() defer cs.handler.blockFetcher.Stop()
defer cs.pm.txFetcher.Stop() defer cs.handler.txFetcher.Stop()
defer cs.handler.downloader.Terminate()
// The force timer lowers the peer count threshold down to one when it fires. // The force timer lowers the peer count threshold down to one when it fires.
// This ensures we'll always start sync even if there aren't enough peers. // This ensures we'll always start sync even if there aren't enough peers.
@ -209,7 +210,6 @@ func (cs *chainSyncer) loop() {
if op := cs.nextSyncOp(); op != nil { if op := cs.nextSyncOp(); op != nil {
cs.startSync(op) cs.startSync(op)
} }
select { select {
case <-cs.peerEventCh: case <-cs.peerEventCh:
// Peer information changed, recheck. // Peer information changed, recheck.
@ -220,14 +220,13 @@ func (cs *chainSyncer) loop() {
case <-cs.force.C: case <-cs.force.C:
cs.forced = true cs.forced = true
case <-cs.pm.quitSync: case <-cs.handler.quitSync:
// Disable all insertion on the blockchain. This needs to happen before // Disable all insertion on the blockchain. This needs to happen before
// terminating the downloader because the downloader waits for blockchain // terminating the downloader because the downloader waits for blockchain
// inserts, and these can take a long time to finish. // inserts, and these can take a long time to finish.
cs.pm.blockchain.StopInsert() cs.handler.chain.StopInsert()
cs.pm.downloader.Terminate() cs.handler.downloader.Terminate()
if cs.doneCh != nil { if cs.doneCh != nil {
// Wait for the current sync to end.
<-cs.doneCh <-cs.doneCh
} }
return return
@ -245,19 +244,22 @@ func (cs *chainSyncer) nextSyncOp() *chainSyncOp {
minPeers := defaultMinSyncPeers minPeers := defaultMinSyncPeers
if cs.forced { if cs.forced {
minPeers = 1 minPeers = 1
} else if minPeers > cs.pm.maxPeers { } else if minPeers > cs.handler.maxPeers {
minPeers = cs.pm.maxPeers minPeers = cs.handler.maxPeers
} }
if cs.pm.peers.Len() < minPeers { if cs.handler.peers.Len() < minPeers {
return nil return nil
} }
// We have enough peers, check TD
// We have enough peers, check TD. peer := cs.handler.peers.ethPeerWithHighestTD()
peer := cs.pm.peers.BestPeer()
if peer == nil { if peer == nil {
return nil return nil
} }
mode, ourTD := cs.modeAndLocalHead() mode, ourTD := cs.modeAndLocalHead()
if mode == downloader.FastSync && atomic.LoadUint32(&cs.handler.snapSync) == 1 {
// Fast sync via the snap protocol
mode = downloader.SnapSync
}
op := peerToSyncOp(mode, peer) op := peerToSyncOp(mode, peer)
if op.td.Cmp(ourTD) <= 0 { if op.td.Cmp(ourTD) <= 0 {
return nil // We're in sync. return nil // We're in sync.
@ -265,42 +267,42 @@ func (cs *chainSyncer) nextSyncOp() *chainSyncOp {
return op return op
} }
func peerToSyncOp(mode downloader.SyncMode, p *peer) *chainSyncOp { func peerToSyncOp(mode downloader.SyncMode, p *eth.Peer) *chainSyncOp {
peerHead, peerTD := p.Head() peerHead, peerTD := p.Head()
return &chainSyncOp{mode: mode, peer: p, td: peerTD, head: peerHead} return &chainSyncOp{mode: mode, peer: p, td: peerTD, head: peerHead}
} }
func (cs *chainSyncer) modeAndLocalHead() (downloader.SyncMode, *big.Int) { func (cs *chainSyncer) modeAndLocalHead() (downloader.SyncMode, *big.Int) {
// If we're in fast sync mode, return that directly // If we're in fast sync mode, return that directly
if atomic.LoadUint32(&cs.pm.fastSync) == 1 { if atomic.LoadUint32(&cs.handler.fastSync) == 1 {
block := cs.pm.blockchain.CurrentFastBlock() block := cs.handler.chain.CurrentFastBlock()
td := cs.pm.blockchain.GetTdByHash(block.Hash()) td := cs.handler.chain.GetTdByHash(block.Hash())
return downloader.FastSync, td return downloader.FastSync, td
} }
// We are probably in full sync, but we might have rewound to before the // We are probably in full sync, but we might have rewound to before the
// fast sync pivot, check if we should reenable // fast sync pivot, check if we should reenable
if pivot := rawdb.ReadLastPivotNumber(cs.pm.chaindb); pivot != nil { if pivot := rawdb.ReadLastPivotNumber(cs.handler.database); pivot != nil {
if head := cs.pm.blockchain.CurrentBlock(); head.NumberU64() < *pivot { if head := cs.handler.chain.CurrentBlock(); head.NumberU64() < *pivot {
block := cs.pm.blockchain.CurrentFastBlock() block := cs.handler.chain.CurrentFastBlock()
td := cs.pm.blockchain.GetTdByHash(block.Hash()) td := cs.handler.chain.GetTdByHash(block.Hash())
return downloader.FastSync, td return downloader.FastSync, td
} }
} }
// Nope, we're really full syncing // Nope, we're really full syncing
head := cs.pm.blockchain.CurrentHeader() head := cs.handler.chain.CurrentHeader()
td := cs.pm.blockchain.GetTd(head.Hash(), head.Number.Uint64()) td := cs.handler.chain.GetTd(head.Hash(), head.Number.Uint64())
return downloader.FullSync, td return downloader.FullSync, td
} }
// startSync launches doSync in a new goroutine. // startSync launches doSync in a new goroutine.
func (cs *chainSyncer) startSync(op *chainSyncOp) { func (cs *chainSyncer) startSync(op *chainSyncOp) {
cs.doneCh = make(chan error, 1) cs.doneCh = make(chan error, 1)
go func() { cs.doneCh <- cs.pm.doSync(op) }() go func() { cs.doneCh <- cs.handler.doSync(op) }()
} }
// doSync synchronizes the local blockchain with a remote peer. // doSync synchronizes the local blockchain with a remote peer.
func (pm *ProtocolManager) doSync(op *chainSyncOp) error { func (h *handler) doSync(op *chainSyncOp) error {
if op.mode == downloader.FastSync { if op.mode == downloader.FastSync || op.mode == downloader.SnapSync {
// Before launch the fast sync, we have to ensure user uses the same // Before launch the fast sync, we have to ensure user uses the same
// txlookup limit. // txlookup limit.
// The main concern here is: during the fast sync Geth won't index the // The main concern here is: during the fast sync Geth won't index the
@ -310,35 +312,33 @@ func (pm *ProtocolManager) doSync(op *chainSyncOp) error {
// has been indexed. So here for the user-experience wise, it's non-optimal // has been indexed. So here for the user-experience wise, it's non-optimal
// that user can't change limit during the fast sync. If changed, Geth // that user can't change limit during the fast sync. If changed, Geth
// will just blindly use the original one. // will just blindly use the original one.
limit := pm.blockchain.TxLookupLimit() limit := h.chain.TxLookupLimit()
if stored := rawdb.ReadFastTxLookupLimit(pm.chaindb); stored == nil { if stored := rawdb.ReadFastTxLookupLimit(h.database); stored == nil {
rawdb.WriteFastTxLookupLimit(pm.chaindb, limit) rawdb.WriteFastTxLookupLimit(h.database, limit)
} else if *stored != limit { } else if *stored != limit {
pm.blockchain.SetTxLookupLimit(*stored) h.chain.SetTxLookupLimit(*stored)
log.Warn("Update txLookup limit", "provided", limit, "updated", *stored) log.Warn("Update txLookup limit", "provided", limit, "updated", *stored)
} }
} }
// Run the sync cycle, and disable fast sync if we're past the pivot block // Run the sync cycle, and disable fast sync if we're past the pivot block
err := pm.downloader.Synchronise(op.peer.id, op.head, op.td, op.mode) err := h.downloader.Synchronise(op.peer.ID(), op.head, op.td, op.mode)
if err != nil { if err != nil {
return err return err
} }
if atomic.LoadUint32(&pm.fastSync) == 1 { if atomic.LoadUint32(&h.fastSync) == 1 {
log.Info("Fast sync complete, auto disabling") log.Info("Fast sync complete, auto disabling")
atomic.StoreUint32(&pm.fastSync, 0) atomic.StoreUint32(&h.fastSync, 0)
} }
// If we've successfully finished a sync cycle and passed any required checkpoint, // If we've successfully finished a sync cycle and passed any required checkpoint,
// enable accepting transactions from the network. // enable accepting transactions from the network.
head := pm.blockchain.CurrentBlock() head := h.chain.CurrentBlock()
if head.NumberU64() >= pm.checkpointNumber { if head.NumberU64() >= h.checkpointNumber {
// Checkpoint passed, sanity check the timestamp to have a fallback mechanism // Checkpoint passed, sanity check the timestamp to have a fallback mechanism
// for non-checkpointed (number = 0) private networks. // for non-checkpointed (number = 0) private networks.
if head.Time() >= uint64(time.Now().AddDate(0, -1, 0).Unix()) { if head.Time() >= uint64(time.Now().AddDate(0, -1, 0).Unix()) {
atomic.StoreUint32(&pm.acceptTxs, 1) atomic.StoreUint32(&h.acceptTxs, 1)
} }
} }
if head.NumberU64() > 0 { if head.NumberU64() > 0 {
// We've completed a sync cycle, notify all peers of new state. This path is // We've completed a sync cycle, notify all peers of new state. This path is
// essential in star-topology networks where a gateway node needs to notify // essential in star-topology networks where a gateway node needs to notify
@ -346,8 +346,7 @@ func (pm *ProtocolManager) doSync(op *chainSyncOp) error {
// scenario will most often crop up in private and hackathon networks with // scenario will most often crop up in private and hackathon networks with
// degenerate connectivity, but it should be healthy for the mainnet too to // degenerate connectivity, but it should be healthy for the mainnet too to
// more reliably update peers or the local TD state. // more reliably update peers or the local TD state.
pm.BroadcastBlock(head, false) h.BroadcastBlock(head, false)
} }
return nil return nil
} }

View File

@ -22,43 +22,59 @@ import (
"time" "time"
"github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
) )
func TestFastSyncDisabling63(t *testing.T) { testFastSyncDisabling(t, 63) } // Tests that fast sync is disabled after a successful sync cycle.
func TestFastSyncDisabling64(t *testing.T) { testFastSyncDisabling(t, 64) } func TestFastSyncDisabling64(t *testing.T) { testFastSyncDisabling(t, 64) }
func TestFastSyncDisabling65(t *testing.T) { testFastSyncDisabling(t, 65) } func TestFastSyncDisabling65(t *testing.T) { testFastSyncDisabling(t, 65) }
// Tests that fast sync gets disabled as soon as a real block is successfully // Tests that fast sync gets disabled as soon as a real block is successfully
// imported into the blockchain. // imported into the blockchain.
func testFastSyncDisabling(t *testing.T, protocol int) { func testFastSyncDisabling(t *testing.T, protocol uint) {
t.Parallel() t.Parallel()
// Create a pristine protocol manager, check that fast sync is left enabled // Create an empty handler and ensure it's in fast sync mode
pmEmpty, _ := newTestProtocolManagerMust(t, downloader.FastSync, 0, nil, nil) empty := newTestHandler()
if atomic.LoadUint32(&pmEmpty.fastSync) == 0 { if atomic.LoadUint32(&empty.handler.fastSync) == 0 {
t.Fatalf("fast sync disabled on pristine blockchain") t.Fatalf("fast sync disabled on pristine blockchain")
} }
// Create a full protocol manager, check that fast sync gets disabled defer empty.close()
pmFull, _ := newTestProtocolManagerMust(t, downloader.FastSync, 1024, nil, nil)
if atomic.LoadUint32(&pmFull.fastSync) == 1 { // Create a full handler and ensure fast sync ends up disabled
full := newTestHandlerWithBlocks(1024)
if atomic.LoadUint32(&full.handler.fastSync) == 1 {
t.Fatalf("fast sync not disabled on non-empty blockchain") t.Fatalf("fast sync not disabled on non-empty blockchain")
} }
defer full.close()
// Sync up the two peers // Sync up the two handlers
io1, io2 := p2p.MsgPipe() emptyPipe, fullPipe := p2p.MsgPipe()
go pmFull.handle(pmFull.newPeer(protocol, p2p.NewPeer(enode.ID{}, "empty", nil), io2, pmFull.txpool.Get)) defer emptyPipe.Close()
go pmEmpty.handle(pmEmpty.newPeer(protocol, p2p.NewPeer(enode.ID{}, "full", nil), io1, pmEmpty.txpool.Get)) defer fullPipe.Close()
emptyPeer := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), emptyPipe, empty.txpool)
fullPeer := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), fullPipe, full.txpool)
defer emptyPeer.Close()
defer fullPeer.Close()
go empty.handler.runEthPeer(emptyPeer, func(peer *eth.Peer) error {
return eth.Handle((*ethHandler)(empty.handler), peer)
})
go full.handler.runEthPeer(fullPeer, func(peer *eth.Peer) error {
return eth.Handle((*ethHandler)(full.handler), peer)
})
// Wait a bit for the above handlers to start
time.Sleep(250 * time.Millisecond) time.Sleep(250 * time.Millisecond)
op := peerToSyncOp(downloader.FastSync, pmEmpty.peers.BestPeer())
if err := pmEmpty.doSync(op); err != nil {
t.Fatal("sync failed:", err)
}
// Check that fast sync was disabled // Check that fast sync was disabled
if atomic.LoadUint32(&pmEmpty.fastSync) == 1 { op := peerToSyncOp(downloader.FastSync, empty.handler.peers.ethPeerWithHighestTD())
if err := empty.handler.doSync(op); err != nil {
t.Fatal("sync failed:", err)
}
if atomic.LoadUint32(&empty.handler.fastSync) == 1 {
t.Fatalf("fast sync not disabled after successful synchronisation") t.Fatalf("fast sync not disabled after successful synchronisation")
} }
} }

View File

@ -36,8 +36,8 @@ import (
"github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/consensus"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/downloader"
ethproto "github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/les" "github.com/ethereum/go-ethereum/les"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
@ -444,13 +444,15 @@ func (s *Service) login(conn *connWrapper) error {
// Construct and send the login authentication // Construct and send the login authentication
infos := s.server.NodeInfo() infos := s.server.NodeInfo()
var network, protocol string var protocols []string
for _, proto := range s.server.Protocols {
protocols = append(protocols, fmt.Sprintf("%s/%d", proto.Name, proto.Version))
}
var network string
if info := infos.Protocols["eth"]; info != nil { if info := infos.Protocols["eth"]; info != nil {
network = fmt.Sprintf("%d", info.(*eth.NodeInfo).Network) network = fmt.Sprintf("%d", info.(*ethproto.NodeInfo).Network)
protocol = fmt.Sprintf("eth/%d", eth.ProtocolVersions[0])
} else { } else {
network = fmt.Sprintf("%d", infos.Protocols["les"].(*les.NodeInfo).Network) network = fmt.Sprintf("%d", infos.Protocols["les"].(*les.NodeInfo).Network)
protocol = fmt.Sprintf("les/%d", les.ClientProtocolVersions[0])
} }
auth := &authMsg{ auth := &authMsg{
ID: s.node, ID: s.node,
@ -459,7 +461,7 @@ func (s *Service) login(conn *connWrapper) error {
Node: infos.Name, Node: infos.Name,
Port: infos.Ports.Listener, Port: infos.Ports.Listener,
Network: network, Network: network,
Protocol: protocol, Protocol: strings.Join(protocols, ", "),
API: "No", API: "No",
Os: runtime.GOOS, Os: runtime.GOOS,
OsVer: runtime.GOARCH, OsVer: runtime.GOARCH,

View File

@ -1040,10 +1040,6 @@ func (r *Resolver) GasPrice(ctx context.Context) (hexutil.Big, error) {
return hexutil.Big(*price), err return hexutil.Big(*price), err
} }
func (r *Resolver) ProtocolVersion(ctx context.Context) (int32, error) {
return int32(r.backend.ProtocolVersion()), nil
}
func (r *Resolver) ChainID(ctx context.Context) (hexutil.Big, error) { func (r *Resolver) ChainID(ctx context.Context) (hexutil.Big, error) {
return hexutil.Big(*r.backend.ChainConfig().ChainID), nil return hexutil.Big(*r.backend.ChainConfig().ChainID), nil
} }

View File

@ -310,8 +310,6 @@ const schema string = `
# GasPrice returns the node's estimate of a gas price sufficient to # GasPrice returns the node's estimate of a gas price sufficient to
# ensure a transaction is mined in a timely fashion. # ensure a transaction is mined in a timely fashion.
gasPrice: BigInt! gasPrice: BigInt!
# ProtocolVersion returns the current wire protocol version number.
protocolVersion: Int!
# Syncing returns information on the current synchronisation state. # Syncing returns information on the current synchronisation state.
syncing: SyncState syncing: SyncState
# ChainID returns the current chain ID for transaction replay protection. # ChainID returns the current chain ID for transaction replay protection.

View File

@ -64,11 +64,6 @@ func (s *PublicEthereumAPI) GasPrice(ctx context.Context) (*hexutil.Big, error)
return (*hexutil.Big)(price), err return (*hexutil.Big)(price), err
} }
// ProtocolVersion returns the current Ethereum protocol version this node supports
func (s *PublicEthereumAPI) ProtocolVersion() hexutil.Uint {
return hexutil.Uint(s.b.ProtocolVersion())
}
// Syncing returns false in case the node is currently not syncing with the network. It can be up to date or has not // Syncing returns false in case the node is currently not syncing with the network. It can be up to date or has not
// yet received the latest block headers from its pears. In case it is synchronizing: // yet received the latest block headers from its pears. In case it is synchronizing:
// - startingBlock: block number this node started to synchronise from // - startingBlock: block number this node started to synchronise from
@ -1906,12 +1901,11 @@ func (api *PrivateDebugAPI) SetHead(number hexutil.Uint64) {
// PublicNetAPI offers network related RPC methods // PublicNetAPI offers network related RPC methods
type PublicNetAPI struct { type PublicNetAPI struct {
net *p2p.Server net *p2p.Server
networkVersion uint64
} }
// NewPublicNetAPI creates a new net API instance. // NewPublicNetAPI creates a new net API instance.
func NewPublicNetAPI(net *p2p.Server, networkVersion uint64) *PublicNetAPI { func NewPublicNetAPI(net *p2p.Server) *PublicNetAPI {
return &PublicNetAPI{net, networkVersion} return &PublicNetAPI{net}
} }
// Listening returns an indication if the node is listening for network connections. // Listening returns an indication if the node is listening for network connections.
@ -1924,11 +1918,6 @@ func (s *PublicNetAPI) PeerCount() hexutil.Uint {
return hexutil.Uint(s.net.PeerCount()) return hexutil.Uint(s.net.PeerCount())
} }
// Version returns the current ethereum protocol version.
func (s *PublicNetAPI) Version() string {
return fmt.Sprintf("%d", s.networkVersion)
}
// checkTxFee is an internal function used to check whether the fee of // checkTxFee is an internal function used to check whether the fee of
// the given transaction is _reasonable_(under the cap). // the given transaction is _reasonable_(under the cap).
func checkTxFee(gasPrice *big.Int, gas uint64, cap float64) error { func checkTxFee(gasPrice *big.Int, gas uint64, cap float64) error {

View File

@ -41,7 +41,6 @@ import (
type Backend interface { type Backend interface {
// General Ethereum API // General Ethereum API
Downloader() *downloader.Downloader Downloader() *downloader.Downloader
ProtocolVersion() int
SuggestPrice(ctx context.Context) (*big.Int, error) SuggestPrice(ctx context.Context) (*big.Int, error)
ChainDb() ethdb.Database ChainDb() ethdb.Database
AccountManager() *accounts.Manager AccountManager() *accounts.Manager

View File

@ -171,7 +171,7 @@ func New(stack *node.Node, config *eth.Config) (*LightEthereum, error) {
leth.blockchain.DisableCheckFreq() leth.blockchain.DisableCheckFreq()
} }
leth.netRPCService = ethapi.NewPublicNetAPI(leth.p2pServer, leth.config.NetworkId) leth.netRPCService = ethapi.NewPublicNetAPI(leth.p2pServer)
// Register the backend on the node // Register the backend on the node
stack.RegisterAPIs(leth.APIs()) stack.RegisterAPIs(leth.APIs())

View File

@ -35,9 +35,9 @@ func (e lesEntry) ENRKey() string {
// setupDiscovery creates the node discovery source for the eth protocol. // setupDiscovery creates the node discovery source for the eth protocol.
func (eth *LightEthereum) setupDiscovery() (enode.Iterator, error) { func (eth *LightEthereum) setupDiscovery() (enode.Iterator, error) {
if len(eth.config.DiscoveryURLs) == 0 { if len(eth.config.EthDiscoveryURLs) == 0 {
return nil, nil return nil, nil
} }
client := dnsdisc.NewClient(dnsdisc.Config{}) client := dnsdisc.NewClient(dnsdisc.Config{})
return client.NewIterator(eth.config.DiscoveryURLs...) return client.NewIterator(eth.config.EthDiscoveryURLs...)
} }

View File

@ -51,7 +51,7 @@ func TestGetBlockHeadersLes2(t *testing.T) { testGetBlockHeaders(t, 2) }
func TestGetBlockHeadersLes3(t *testing.T) { testGetBlockHeaders(t, 3) } func TestGetBlockHeadersLes3(t *testing.T) { testGetBlockHeaders(t, 3) }
func testGetBlockHeaders(t *testing.T, protocol int) { func testGetBlockHeaders(t *testing.T, protocol int) {
server, tearDown := newServerEnv(t, downloader.MaxHashFetch+15, protocol, nil, false, true, 0) server, tearDown := newServerEnv(t, downloader.MaxHeaderFetch+15, protocol, nil, false, true, 0)
defer tearDown() defer tearDown()
bc := server.handler.blockchain bc := server.handler.blockchain

View File

@ -31,7 +31,6 @@ import (
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/les/flowcontrol" "github.com/ethereum/go-ethereum/les/flowcontrol"
lpc "github.com/ethereum/go-ethereum/les/lespay/client" lpc "github.com/ethereum/go-ethereum/les/lespay/client"
lps "github.com/ethereum/go-ethereum/les/lespay/server" lps "github.com/ethereum/go-ethereum/les/lespay/server"
@ -162,9 +161,17 @@ func (p *peerCommons) String() string {
return fmt.Sprintf("Peer %s [%s]", p.id, fmt.Sprintf("les/%d", p.version)) return fmt.Sprintf("Peer %s [%s]", p.id, fmt.Sprintf("les/%d", p.version))
} }
// PeerInfo represents a short summary of the `eth` sub-protocol metadata known
// about a connected peer.
type PeerInfo struct {
Version int `json:"version"` // Ethereum protocol version negotiated
Difficulty *big.Int `json:"difficulty"` // Total difficulty of the peer's blockchain
Head string `json:"head"` // SHA3 hash of the peer's best owned block
}
// Info gathers and returns a collection of metadata known about a peer. // Info gathers and returns a collection of metadata known about a peer.
func (p *peerCommons) Info() *eth.PeerInfo { func (p *peerCommons) Info() *PeerInfo {
return &eth.PeerInfo{ return &PeerInfo{
Version: p.version, Version: p.version,
Difficulty: p.Td(), Difficulty: p.Td(),
Head: fmt.Sprintf("%x", p.Head()), Head: fmt.Sprintf("%x", p.Head()),

View File

@ -47,7 +47,7 @@ import (
const ( const (
softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data. softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data.
estHeaderRlpSize = 500 // Approximate size of an RLP encoded block header estHeaderRlpSize = 500 // Approximate size of an RLP encoded block header
ethVersion = 63 // equivalent eth version for the downloader ethVersion = 64 // equivalent eth version for the downloader
MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request
MaxBodyFetch = 32 // Amount of block bodies to be fetched per retrieval request MaxBodyFetch = 32 // Amount of block bodies to be fetched per retrieval request

View File

@ -147,7 +147,7 @@ func (t *BlockTest) Run(snapshotter bool) error {
} }
// Cross-check the snapshot-to-hash against the trie hash // Cross-check the snapshot-to-hash against the trie hash
if snapshotter { if snapshotter {
if err := snapshot.VerifyState(chain.Snapshot(), chain.CurrentBlock().Root()); err != nil { if err := snapshot.VerifyState(chain.Snapshots(), chain.CurrentBlock().Root()); err != nil {
return err return err
} }
} }

Binary file not shown.

View File

@ -0,0 +1,41 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package main
import (
"fmt"
"io/ioutil"
"os"
"github.com/ethereum/go-ethereum/tests/fuzzers/rangeproof"
)
func main() {
if len(os.Args) != 2 {
fmt.Fprintf(os.Stderr, "Usage: debug <file>\n")
fmt.Fprintf(os.Stderr, "Example\n")
fmt.Fprintf(os.Stderr, " $ debug ../crashers/4bbef6857c733a87ecf6fd8b9e7238f65eb9862a\n")
os.Exit(1)
}
crasher := os.Args[1]
data, err := ioutil.ReadFile(crasher)
if err != nil {
fmt.Fprintf(os.Stderr, "error loading crasher %v: %v", crasher, err)
os.Exit(1)
}
rangeproof.Fuzz(data)
}

View File

@ -0,0 +1,218 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package rangeproof
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"sort"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
"github.com/ethereum/go-ethereum/trie"
)
type kv struct {
k, v []byte
t bool
}
type entrySlice []*kv
func (p entrySlice) Len() int { return len(p) }
func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 }
func (p entrySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
type fuzzer struct {
input io.Reader
exhausted bool
}
func (f *fuzzer) randBytes(n int) []byte {
r := make([]byte, n)
if _, err := f.input.Read(r); err != nil {
f.exhausted = true
}
return r
}
func (f *fuzzer) readInt() uint64 {
var x uint64
if err := binary.Read(f.input, binary.LittleEndian, &x); err != nil {
f.exhausted = true
}
return x
}
func (f *fuzzer) randomTrie(n int) (*trie.Trie, map[string]*kv) {
trie := new(trie.Trie)
vals := make(map[string]*kv)
size := f.readInt()
// Fill it with some fluff
for i := byte(0); i < byte(size); i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false}
trie.Update(value.k, value.v)
trie.Update(value2.k, value2.v)
vals[string(value.k)] = value
vals[string(value2.k)] = value2
}
if f.exhausted {
return nil, nil
}
// And now fill with some random
for i := 0; i < n; i++ {
k := f.randBytes(32)
v := f.randBytes(20)
value := &kv{k, v, false}
trie.Update(k, v)
vals[string(k)] = value
if f.exhausted {
return nil, nil
}
}
return trie, vals
}
func (f *fuzzer) fuzz() int {
maxSize := 200
tr, vals := f.randomTrie(1 + int(f.readInt())%maxSize)
if f.exhausted {
return 0 // input too short
}
var entries entrySlice
for _, kv := range vals {
entries = append(entries, kv)
}
if len(entries) <= 1 {
return 0
}
sort.Sort(entries)
var ok = 0
for {
start := int(f.readInt() % uint64(len(entries)))
end := 1 + int(f.readInt()%uint64(len(entries)-1))
testcase := int(f.readInt() % uint64(6))
index := int(f.readInt() & 0xFFFFFFFF)
index2 := int(f.readInt() & 0xFFFFFFFF)
if f.exhausted {
break
}
proof := memorydb.New()
if err := tr.Prove(entries[start].k, 0, proof); err != nil {
panic(fmt.Sprintf("Failed to prove the first node %v", err))
}
if err := tr.Prove(entries[end-1].k, 0, proof); err != nil {
panic(fmt.Sprintf("Failed to prove the last node %v", err))
}
var keys [][]byte
var vals [][]byte
for i := start; i < end; i++ {
keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v)
}
if len(keys) == 0 {
return 0
}
var first, last = keys[0], keys[len(keys)-1]
testcase %= 6
switch testcase {
case 0:
// Modified key
keys[index%len(keys)] = f.randBytes(32) // In theory it can't be same
case 1:
// Modified val
vals[index%len(vals)] = f.randBytes(20) // In theory it can't be same
case 2:
// Gapped entry slice
index = index % len(keys)
keys = append(keys[:index], keys[index+1:]...)
vals = append(vals[:index], vals[index+1:]...)
case 3:
// Out of order
index1 := index % len(keys)
index2 := index2 % len(keys)
keys[index1], keys[index2] = keys[index2], keys[index1]
vals[index1], vals[index2] = vals[index2], vals[index1]
case 4:
// Set random key to nil, do nothing
keys[index%len(keys)] = nil
case 5:
// Set random value to nil, deletion
vals[index%len(vals)] = nil
// Other cases:
// Modify something in the proof db
// add stuff to proof db
// drop stuff from proof db
}
if f.exhausted {
break
}
ok = 1
//nodes, subtrie
nodes, subtrie, notary, hasMore, err := trie.VerifyRangeProof(tr.Hash(), first, last, keys, vals, proof)
if err != nil {
if nodes != nil {
panic("err != nil && nodes != nil")
}
if subtrie != nil {
panic("err != nil && subtrie != nil")
}
if notary != nil {
panic("err != nil && notary != nil")
}
if hasMore {
panic("err != nil && hasMore == true")
}
} else {
if nodes == nil {
panic("err == nil && nodes == nil")
}
if subtrie == nil {
panic("err == nil && subtrie == nil")
}
if notary == nil {
panic("err == nil && subtrie == nil")
}
}
}
return ok
}
// The function must return
// 1 if the fuzzer should increase priority of the
// given input during subsequent fuzzing (for example, the input is lexically
// correct and was parsed successfully);
// -1 if the input must not be added to corpus even if gives new coverage; and
// 0 otherwise; other values are reserved for future use.
func Fuzz(input []byte) int {
if len(input) < 100 {
return 0
}
r := bytes.NewReader(input)
f := fuzzer{
input: r,
exhausted: false,
}
return f.fuzz()
}

57
trie/notary.go Normal file
View File

@ -0,0 +1,57 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
)
// KeyValueNotary tracks which keys have been accessed through a key-value reader
// with te scope of verifying if certain proof datasets are maliciously bloated.
type KeyValueNotary struct {
ethdb.KeyValueReader
reads map[string]struct{}
}
// NewKeyValueNotary wraps a key-value database with an access notary to track
// which items have bene accessed.
func NewKeyValueNotary(db ethdb.KeyValueReader) *KeyValueNotary {
return &KeyValueNotary{
KeyValueReader: db,
reads: make(map[string]struct{}),
}
}
// Get retrieves an item from the underlying database, but also tracks it as an
// accessed slot for bloat checks.
func (k *KeyValueNotary) Get(key []byte) ([]byte, error) {
k.reads[string(key)] = struct{}{}
return k.KeyValueReader.Get(key)
}
// Accessed returns s snapshot of the original key-value store containing only the
// data accessed through the notary.
func (k *KeyValueNotary) Accessed() ethdb.KeyValueStore {
db := memorydb.New()
for keystr := range k.reads {
key := []byte(keystr)
val, _ := k.KeyValueReader.Get(key)
db.Put(key, val)
}
return db
}

View File

@ -454,96 +454,136 @@ func hasRightElement(node node, key []byte) bool {
// //
// Except returning the error to indicate the proof is valid or not, the function will // Except returning the error to indicate the proof is valid or not, the function will
// also return a flag to indicate whether there exists more accounts/slots in the trie. // also return a flag to indicate whether there exists more accounts/slots in the trie.
func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, keys [][]byte, values [][]byte, proof ethdb.KeyValueReader) (error, bool) { func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, keys [][]byte, values [][]byte, proof ethdb.KeyValueReader) (ethdb.KeyValueStore, *Trie, *KeyValueNotary, bool, error) {
if len(keys) != len(values) { if len(keys) != len(values) {
return fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values)), false return nil, nil, nil, false, fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values))
} }
// Ensure the received batch is monotonic increasing. // Ensure the received batch is monotonic increasing.
for i := 0; i < len(keys)-1; i++ { for i := 0; i < len(keys)-1; i++ {
if bytes.Compare(keys[i], keys[i+1]) >= 0 { if bytes.Compare(keys[i], keys[i+1]) >= 0 {
return errors.New("range is not monotonically increasing"), false return nil, nil, nil, false, errors.New("range is not monotonically increasing")
} }
} }
// Create a key-value notary to track which items from the given proof the
// range prover actually needed to verify the data
notary := NewKeyValueNotary(proof)
// Special case, there is no edge proof at all. The given range is expected // Special case, there is no edge proof at all. The given range is expected
// to be the whole leaf-set in the trie. // to be the whole leaf-set in the trie.
if proof == nil { if proof == nil {
emptytrie, err := New(common.Hash{}, NewDatabase(memorydb.New())) var (
diskdb = memorydb.New()
triedb = NewDatabase(diskdb)
)
tr, err := New(common.Hash{}, triedb)
if err != nil { if err != nil {
return err, false return nil, nil, nil, false, err
} }
for index, key := range keys { for index, key := range keys {
emptytrie.TryUpdate(key, values[index]) tr.TryUpdate(key, values[index])
} }
if emptytrie.Hash() != rootHash { if tr.Hash() != rootHash {
return fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, emptytrie.Hash()), false return nil, nil, nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash())
} }
return nil, false // no more element. // Proof seems valid, serialize all the nodes into the database
if _, err := tr.Commit(nil); err != nil {
return nil, nil, nil, false, err
}
if err := triedb.Commit(rootHash, false, nil); err != nil {
return nil, nil, nil, false, err
}
return diskdb, tr, notary, false, nil // No more elements
} }
// Special case, there is a provided edge proof but zero key/value // Special case, there is a provided edge proof but zero key/value
// pairs, ensure there are no more accounts / slots in the trie. // pairs, ensure there are no more accounts / slots in the trie.
if len(keys) == 0 { if len(keys) == 0 {
root, val, err := proofToPath(rootHash, nil, firstKey, proof, true) root, val, err := proofToPath(rootHash, nil, firstKey, notary, true)
if err != nil { if err != nil {
return err, false return nil, nil, nil, false, err
} }
if val != nil || hasRightElement(root, firstKey) { if val != nil || hasRightElement(root, firstKey) {
return errors.New("more entries available"), false return nil, nil, nil, false, errors.New("more entries available")
} }
return nil, false // Since the entire proof is a single path, we can construct a trie and a
// node database directly out of the inputs, no need to generate them
diskdb := notary.Accessed()
tr := &Trie{
db: NewDatabase(diskdb),
root: root,
}
return diskdb, tr, notary, hasRightElement(root, firstKey), nil
} }
// Special case, there is only one element and two edge keys are same. // Special case, there is only one element and two edge keys are same.
// In this case, we can't construct two edge paths. So handle it here. // In this case, we can't construct two edge paths. So handle it here.
if len(keys) == 1 && bytes.Equal(firstKey, lastKey) { if len(keys) == 1 && bytes.Equal(firstKey, lastKey) {
root, val, err := proofToPath(rootHash, nil, firstKey, proof, false) root, val, err := proofToPath(rootHash, nil, firstKey, notary, false)
if err != nil { if err != nil {
return err, false return nil, nil, nil, false, err
} }
if !bytes.Equal(firstKey, keys[0]) { if !bytes.Equal(firstKey, keys[0]) {
return errors.New("correct proof but invalid key"), false return nil, nil, nil, false, errors.New("correct proof but invalid key")
} }
if !bytes.Equal(val, values[0]) { if !bytes.Equal(val, values[0]) {
return errors.New("correct proof but invalid data"), false return nil, nil, nil, false, errors.New("correct proof but invalid data")
} }
return nil, hasRightElement(root, firstKey) // Since the entire proof is a single path, we can construct a trie and a
// node database directly out of the inputs, no need to generate them
diskdb := notary.Accessed()
tr := &Trie{
db: NewDatabase(diskdb),
root: root,
}
return diskdb, tr, notary, hasRightElement(root, firstKey), nil
} }
// Ok, in all other cases, we require two edge paths available. // Ok, in all other cases, we require two edge paths available.
// First check the validity of edge keys. // First check the validity of edge keys.
if bytes.Compare(firstKey, lastKey) >= 0 { if bytes.Compare(firstKey, lastKey) >= 0 {
return errors.New("invalid edge keys"), false return nil, nil, nil, false, errors.New("invalid edge keys")
} }
// todo(rjl493456442) different length edge keys should be supported // todo(rjl493456442) different length edge keys should be supported
if len(firstKey) != len(lastKey) { if len(firstKey) != len(lastKey) {
return errors.New("inconsistent edge keys"), false return nil, nil, nil, false, errors.New("inconsistent edge keys")
} }
// Convert the edge proofs to edge trie paths. Then we can // Convert the edge proofs to edge trie paths. Then we can
// have the same tree architecture with the original one. // have the same tree architecture with the original one.
// For the first edge proof, non-existent proof is allowed. // For the first edge proof, non-existent proof is allowed.
root, _, err := proofToPath(rootHash, nil, firstKey, proof, true) root, _, err := proofToPath(rootHash, nil, firstKey, notary, true)
if err != nil { if err != nil {
return err, false return nil, nil, nil, false, err
} }
// Pass the root node here, the second path will be merged // Pass the root node here, the second path will be merged
// with the first one. For the last edge proof, non-existent // with the first one. For the last edge proof, non-existent
// proof is also allowed. // proof is also allowed.
root, _, err = proofToPath(rootHash, root, lastKey, proof, true) root, _, err = proofToPath(rootHash, root, lastKey, notary, true)
if err != nil { if err != nil {
return err, false return nil, nil, nil, false, err
} }
// Remove all internal references. All the removed parts should // Remove all internal references. All the removed parts should
// be re-filled(or re-constructed) by the given leaves range. // be re-filled(or re-constructed) by the given leaves range.
if err := unsetInternal(root, firstKey, lastKey); err != nil { if err := unsetInternal(root, firstKey, lastKey); err != nil {
return err, false return nil, nil, nil, false, err
} }
// Rebuild the trie with the leave stream, the shape of trie // Rebuild the trie with the leaf stream, the shape of trie
// should be same with the original one. // should be same with the original one.
newtrie := &Trie{root: root, db: NewDatabase(memorydb.New())} var (
diskdb = memorydb.New()
triedb = NewDatabase(diskdb)
)
tr := &Trie{root: root, db: triedb}
for index, key := range keys { for index, key := range keys {
newtrie.TryUpdate(key, values[index]) tr.TryUpdate(key, values[index])
} }
if newtrie.Hash() != rootHash { if tr.Hash() != rootHash {
return fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, newtrie.Hash()), false return nil, nil, nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash())
} }
return nil, hasRightElement(root, keys[len(keys)-1]) // Proof seems valid, serialize all the nodes into the database
if _, err := tr.Commit(nil); err != nil {
return nil, nil, nil, false, err
}
if err := triedb.Commit(rootHash, false, nil); err != nil {
return nil, nil, nil, false, err
}
return diskdb, tr, notary, hasRightElement(root, keys[len(keys)-1]), nil
} }
// get returns the child of the given node. Return nil if the // get returns the child of the given node. Return nil if the

View File

@ -19,6 +19,7 @@ package trie
import ( import (
"bytes" "bytes"
crand "crypto/rand" crand "crypto/rand"
"encoding/binary"
mrand "math/rand" mrand "math/rand"
"sort" "sort"
"testing" "testing"
@ -181,7 +182,7 @@ func TestRangeProof(t *testing.T) {
keys = append(keys, entries[i].k) keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v) vals = append(vals, entries[i].v)
} }
err, _ := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) _, _, _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
if err != nil { if err != nil {
t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
} }
@ -232,7 +233,7 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
keys = append(keys, entries[i].k) keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v) vals = append(vals, entries[i].v)
} }
err, _ := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof) _, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
if err != nil { if err != nil {
t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
} }
@ -253,7 +254,7 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
err, _ := VerifyRangeProof(trie.Hash(), first, last, k, v, proof) _, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
if err != nil { if err != nil {
t.Fatal("Failed to verify whole rang with non-existent edges") t.Fatal("Failed to verify whole rang with non-existent edges")
} }
@ -288,7 +289,7 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
err, _ := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof) _, _, _, _, err := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof)
if err == nil { if err == nil {
t.Fatalf("Expected to detect the error, got nil") t.Fatalf("Expected to detect the error, got nil")
} }
@ -310,7 +311,7 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
err, _ = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof) _, _, _, _, err = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof)
if err == nil { if err == nil {
t.Fatalf("Expected to detect the error, got nil") t.Fatalf("Expected to detect the error, got nil")
} }
@ -334,7 +335,7 @@ func TestOneElementRangeProof(t *testing.T) {
if err := trie.Prove(entries[start].k, 0, proof); err != nil { if err := trie.Prove(entries[start].k, 0, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
err, _ := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) _, _, _, _, err := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -349,7 +350,7 @@ func TestOneElementRangeProof(t *testing.T) {
if err := trie.Prove(entries[start].k, 0, proof); err != nil { if err := trie.Prove(entries[start].k, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
err, _ = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) _, _, _, _, err = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -364,7 +365,7 @@ func TestOneElementRangeProof(t *testing.T) {
if err := trie.Prove(last, 0, proof); err != nil { if err := trie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
err, _ = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) _, _, _, _, err = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -379,7 +380,7 @@ func TestOneElementRangeProof(t *testing.T) {
if err := trie.Prove(last, 0, proof); err != nil { if err := trie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
err, _ = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) _, _, _, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -401,7 +402,7 @@ func TestAllElementsProof(t *testing.T) {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
err, _ := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil) _, _, _, _, err := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -414,7 +415,7 @@ func TestAllElementsProof(t *testing.T) {
if err := trie.Prove(entries[len(entries)-1].k, 0, proof); err != nil { if err := trie.Prove(entries[len(entries)-1].k, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
err, _ = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof) _, _, _, _, err = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -429,7 +430,7 @@ func TestAllElementsProof(t *testing.T) {
if err := trie.Prove(last, 0, proof); err != nil { if err := trie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
err, _ = VerifyRangeProof(trie.Hash(), first, last, k, v, proof) _, _, _, _, err = VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -462,7 +463,7 @@ func TestSingleSideRangeProof(t *testing.T) {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
err, _ := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof) _, _, _, _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -498,7 +499,7 @@ func TestReverseSingleSideRangeProof(t *testing.T) {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
err, _ := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof) _, _, _, _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -570,7 +571,7 @@ func TestBadRangeProof(t *testing.T) {
index = mrand.Intn(end - start) index = mrand.Intn(end - start)
vals[index] = nil vals[index] = nil
} }
err, _ := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof) _, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
if err == nil { if err == nil {
t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1) t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1)
} }
@ -604,7 +605,7 @@ func TestGappedRangeProof(t *testing.T) {
keys = append(keys, entries[i].k) keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v) vals = append(vals, entries[i].v)
} }
err, _ := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) _, _, _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
if err == nil { if err == nil {
t.Fatal("expect error, got nil") t.Fatal("expect error, got nil")
} }
@ -631,7 +632,7 @@ func TestSameSideProofs(t *testing.T) {
if err := trie.Prove(last, 0, proof); err != nil { if err := trie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
err, _ := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) _, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
if err == nil { if err == nil {
t.Fatalf("Expected error, got nil") t.Fatalf("Expected error, got nil")
} }
@ -647,7 +648,7 @@ func TestSameSideProofs(t *testing.T) {
if err := trie.Prove(last, 0, proof); err != nil { if err := trie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
err, _ = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) _, _, _, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
if err == nil { if err == nil {
t.Fatalf("Expected error, got nil") t.Fatalf("Expected error, got nil")
} }
@ -715,7 +716,7 @@ func TestHasRightElement(t *testing.T) {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
err, hasMore := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof) _, _, _, hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -748,13 +749,57 @@ func TestEmptyRangeProof(t *testing.T) {
if err := trie.Prove(first, 0, proof); err != nil { if err := trie.Prove(first, 0, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
err, _ := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof) db, tr, not, _, err := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof)
if c.err && err == nil { if c.err && err == nil {
t.Fatalf("Expected error, got nil") t.Fatalf("Expected error, got nil")
} }
if !c.err && err != nil { if !c.err && err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
// If no error was returned, ensure the returned trie and database contains
// the entire proof, since there's no value
if !c.err {
if err := tr.Prove(first, 0, memorydb.New()); err != nil {
t.Errorf("returned trie doesn't contain original proof: %v", err)
}
if memdb := db.(*memorydb.Database); memdb.Len() != proof.Len() {
t.Errorf("database entry count mismatch: have %d, want %d", memdb.Len(), proof.Len())
}
if not == nil {
t.Errorf("missing notary")
}
}
}
}
// TestBloatedProof tests a malicious proof, where the proof is more or less the
// whole trie.
func TestBloatedProof(t *testing.T) {
// Use a small trie
trie, kvs := nonRandomTrie(100)
var entries entrySlice
for _, kv := range kvs {
entries = append(entries, kv)
}
sort.Sort(entries)
var keys [][]byte
var vals [][]byte
proof := memorydb.New()
for i, entry := range entries {
trie.Prove(entry.k, 0, proof)
if i == 50 {
keys = append(keys, entry.k)
vals = append(vals, entry.v)
}
}
want := memorydb.New()
trie.Prove(keys[0], 0, want)
trie.Prove(keys[len(keys)-1], 0, want)
_, _, notary, _, _ := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
if used := notary.Accessed().(*memorydb.Database); used.Len() != want.Len() {
t.Fatalf("notary proof size mismatch: have %d, want %d", used.Len(), want.Len())
} }
} }
@ -858,7 +903,7 @@ func benchmarkVerifyRangeProof(b *testing.B, size int) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
err, _ := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof) _, _, _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof)
if err != nil { if err != nil {
b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
} }
@ -889,3 +934,20 @@ func randBytes(n int) []byte {
crand.Read(r) crand.Read(r)
return r return r
} }
func nonRandomTrie(n int) (*Trie, map[string]*kv) {
trie := new(Trie)
vals := make(map[string]*kv)
max := uint64(0xffffffffffffffff)
for i := uint64(0); i < uint64(n); i++ {
value := make([]byte, 32)
key := make([]byte, 32)
binary.LittleEndian.PutUint64(key, i)
binary.LittleEndian.PutUint64(value, i-max)
//value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
elem := &kv{key, value, false}
trie.Update(elem.k, elem.v)
vals[string(elem.k)] = elem
}
return trie, vals
}

View File

@ -125,14 +125,14 @@ func (b *SyncBloom) init(database ethdb.Iteratee) {
it.Release() it.Release()
it = database.NewIterator(nil, key) it = database.NewIterator(nil, key)
log.Info("Initializing fast sync bloom", "items", b.bloom.N(), "errorrate", b.errorRate(), "elapsed", common.PrettyDuration(time.Since(start))) log.Info("Initializing state bloom", "items", b.bloom.N(), "errorrate", b.errorRate(), "elapsed", common.PrettyDuration(time.Since(start)))
swap = time.Now() swap = time.Now()
} }
} }
it.Release() it.Release()
// Mark the bloom filter inited and return // Mark the bloom filter inited and return
log.Info("Initialized fast sync bloom", "items", b.bloom.N(), "errorrate", b.errorRate(), "elapsed", common.PrettyDuration(time.Since(start))) log.Info("Initialized state bloom", "items", b.bloom.N(), "errorrate", b.errorRate(), "elapsed", common.PrettyDuration(time.Since(start)))
atomic.StoreUint32(&b.inited, 1) atomic.StoreUint32(&b.inited, 1)
} }
@ -162,7 +162,7 @@ func (b *SyncBloom) Close() error {
b.pend.Wait() b.pend.Wait()
// Wipe the bloom, but mark it "uninited" just in case someone attempts an access // Wipe the bloom, but mark it "uninited" just in case someone attempts an access
log.Info("Deallocated fast sync bloom", "items", b.bloom.N(), "errorrate", b.errorRate()) log.Info("Deallocated state bloom", "items", b.bloom.N(), "errorrate", b.errorRate())
atomic.StoreUint32(&b.inited, 0) atomic.StoreUint32(&b.inited, 0)
b.bloom = nil b.bloom = nil

View File

@ -19,13 +19,13 @@ package trie
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"sync" "sync"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp"
) )
var ( var (
@ -159,29 +159,26 @@ func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) {
if item == nil { if item == nil {
return nil, resolved, nil return nil, resolved, nil
} }
enc, err := rlp.EncodeToBytes(item) return item, resolved, err
if err != nil {
log.Error("Encoding existing trie node failed", "err", err)
return nil, resolved, err
}
return enc, resolved, err
} }
func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item node, newnode node, resolved int, err error) { func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, newnode node, resolved int, err error) {
// If we reached the requested path, return the current node // If we reached the requested path, return the current node
if pos >= len(path) { if pos >= len(path) {
// Don't return collapsed hash nodes though // Although we most probably have the original node expanded, encoding
if _, ok := origNode.(hashNode); !ok { // that into consensus form can be nasty (needs to cascade down) and
// Short nodes have expanded keys, compact them before returning // time consuming. Instead, just pull the hash up from disk directly.
item := origNode var hash hashNode
if sn, ok := item.(*shortNode); ok { if node, ok := origNode.(hashNode); ok {
item = &shortNode{ hash = node
Key: hexToCompact(sn.Key), } else {
Val: sn.Val, hash, _ = origNode.cache()
} }
if hash == nil {
return nil, origNode, 0, errors.New("non-consensus node")
} }
return item, origNode, 0, nil blob, err := t.db.Node(common.BytesToHash(hash))
} return blob, origNode, 1, err
} }
// Path still needs to be traversed, descend into children // Path still needs to be traversed, descend into children
switch n := (origNode).(type) { switch n := (origNode).(type) {
@ -491,7 +488,7 @@ func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) {
// Hash returns the root hash of the trie. It does not write to the // Hash returns the root hash of the trie. It does not write to the
// database and can be used even if the trie doesn't have one. // database and can be used even if the trie doesn't have one.
func (t *Trie) Hash() common.Hash { func (t *Trie) Hash() common.Hash {
hash, cached, _ := t.hashRoot(nil) hash, cached, _ := t.hashRoot()
t.root = cached t.root = cached
return common.BytesToHash(hash.(hashNode)) return common.BytesToHash(hash.(hashNode))
} }
@ -545,7 +542,7 @@ func (t *Trie) Commit(onleaf LeafCallback) (root common.Hash, err error) {
} }
// hashRoot calculates the root hash of the given trie // hashRoot calculates the root hash of the given trie
func (t *Trie) hashRoot(db *Database) (node, node, error) { func (t *Trie) hashRoot() (node, node, error) {
if t.root == nil { if t.root == nil {
return hashNode(emptyRoot.Bytes()), nil, nil return hashNode(emptyRoot.Bytes()), nil, nil
} }