diff --git a/cmd/geth/main.go b/cmd/geth/main.go index ff556c984..dc7e19c61 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -283,6 +283,7 @@ JavaScript API. See https://github.com/ethereum/go-ethereum/wiki/Javascipt-Conso utils.DataDirFlag, utils.BlockchainVersionFlag, utils.OlympicFlag, + utils.EthVersionFlag, utils.CacheFlag, utils.JSpathFlag, utils.ListenPortFlag, @@ -333,6 +334,7 @@ JavaScript API. See https://github.com/ethereum/go-ethereum/wiki/Javascipt-Conso app.Before = func(ctx *cli.Context) error { utils.SetupLogger(ctx) utils.SetupVM(ctx) + utils.SetupEth(ctx) if ctx.GlobalBool(utils.PProfEanbledFlag.Name) { utils.StartPProf(ctx) } diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index af2929d10..5ebc4ea61 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -138,6 +138,11 @@ var ( Name: "olympic", Usage: "Use olympic style protocol", } + EthVersionFlag = cli.IntFlag{ + Name: "eth", + Value: 61, + Usage: "Highest eth protocol to advertise (temporary, dev option)", + } // miner settings MinerThreadsFlag = cli.IntFlag{ @@ -459,6 +464,18 @@ func SetupVM(ctx *cli.Context) { vm.SetJITCacheSize(ctx.GlobalInt(VMJitCacheFlag.Name)) } +// SetupEth configures the eth packages global settings +func SetupEth(ctx *cli.Context) { + version := ctx.GlobalInt(EthVersionFlag.Name) + for len(eth.ProtocolVersions) > 0 && eth.ProtocolVersions[0] > uint(version) { + eth.ProtocolVersions = eth.ProtocolVersions[1:] + eth.ProtocolLengths = eth.ProtocolLengths[1:] + } + if len(eth.ProtocolVersions) == 0 { + Fatalf("No valid eth protocols remaining") + } +} + // MakeChain creates a chain manager from set command line flags. func MakeChain(ctx *cli.Context) (chain *core.ChainManager, chainDb common.Database) { datadir := ctx.GlobalString(DataDirFlag.Name) diff --git a/eth/backend.go b/eth/backend.go index 2b21a7c96..ad2a2c1f9 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -373,7 +373,7 @@ func New(config *Config) (*Ethereum, error) { eth.blockProcessor = core.NewBlockProcessor(chainDb, eth.pow, eth.chainManager, eth.EventMux()) eth.chainManager.SetProcessor(eth.blockProcessor) - eth.protocolManager = NewProtocolManager(config.NetworkId, eth.eventMux, eth.txPool, eth.pow, eth.chainManager) + eth.protocolManager = NewProtocolManager(config.NetworkId, eth.eventMux, eth.txPool, eth.pow, eth.chainManager, chainDb) eth.miner = miner.New(eth, eth.EventMux(), eth.pow) eth.miner.SetGasPrice(config.GasPrice) diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index 23d2e045e..6a6bce644 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -39,13 +39,15 @@ import ( const ( eth60 = 60 // Constant to check for old protocol support eth61 = 61 // Constant to check for new protocol support - eth62 = 62 // Constant to check for experimental protocol support ) var ( - MinHashFetch = 512 // Minimum amount of hashes to not consider a peer stalling - MaxHashFetch = 512 // Amount of hashes to be fetched per retrieval request - MaxBlockFetch = 128 // Amount of blocks to be fetched per retrieval request + MinHashFetch = 512 // Minimum amount of hashes to not consider a peer stalling + MaxHashFetch = 512 // Amount of hashes to be fetched per retrieval request + MaxBlockFetch = 128 // Amount of blocks to be fetched per retrieval request + MaxHeaderFetch = 256 // Amount of block headers to be fetched per retrieval request + MaxStateFetch = 384 // Amount of node state values to allow fetching per request + MaxReceiptsFetch = 384 // Amount of transaction receipts to allow fetching per request hashTTL = 5 * time.Second // Time it takes for a hash request to time out blockSoftTTL = 3 * time.Second // Request completion threshold for increasing or decreasing a peer's bandwidth @@ -330,7 +332,7 @@ func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err e if err = d.fetchBlocks60(); err != nil { return err } - case eth61, eth62: + case eth61: // New eth/61, use forward, concurrent hash and block retrieval algorithm number, err := d.findAncestor(p) if err != nil { diff --git a/eth/fetcher/fetcher.go b/eth/fetcher/fetcher.go index 55b6c5c1c..07eb165dc 100644 --- a/eth/fetcher/fetcher.go +++ b/eth/fetcher/fetcher.go @@ -69,8 +69,9 @@ type peerDropFn func(id string) // announce is the hash notification of the availability of a new block in the // network. type announce struct { - hash common.Hash // Hash of the block being announced - time time.Time // Timestamp of the announcement + hash common.Hash // Hash of the block being announced + number uint64 // Number of the block being announced (0 = unknown | old protocol) + time time.Time // Timestamp of the announcement origin string // Identifier of the peer originating the notification fetch blockRequesterFn // Fetcher function to retrieve @@ -152,9 +153,10 @@ func (f *Fetcher) Stop() { // Notify announces the fetcher of the potential availability of a new block in // the network. -func (f *Fetcher) Notify(peer string, hash common.Hash, time time.Time, fetcher blockRequesterFn) error { +func (f *Fetcher) Notify(peer string, hash common.Hash, number uint64, time time.Time, fetcher blockRequesterFn) error { block := &announce{ hash: hash, + number: number, time: time, origin: peer, fetch: fetcher, diff --git a/eth/fetcher/fetcher_test.go b/eth/fetcher/fetcher_test.go index ecbb3f868..b0d9ce843 100644 --- a/eth/fetcher/fetcher_test.go +++ b/eth/fetcher/fetcher_test.go @@ -194,7 +194,7 @@ func TestSequentialAnnouncements(t *testing.T) { tester.fetcher.importedHook = func(block *types.Block) { imported <- block } for i := len(hashes) - 2; i >= 0; i-- { - tester.fetcher.Notify("valid", hashes[i], time.Now().Add(-arriveTimeout), fetcher) + tester.fetcher.Notify("valid", hashes[i], 0, time.Now().Add(-arriveTimeout), fetcher) verifyImportEvent(t, imported) } verifyImportDone(t, imported) @@ -221,9 +221,9 @@ func TestConcurrentAnnouncements(t *testing.T) { tester.fetcher.importedHook = func(block *types.Block) { imported <- block } for i := len(hashes) - 2; i >= 0; i-- { - tester.fetcher.Notify("first", hashes[i], time.Now().Add(-arriveTimeout), wrapper) - tester.fetcher.Notify("second", hashes[i], time.Now().Add(-arriveTimeout+time.Millisecond), wrapper) - tester.fetcher.Notify("second", hashes[i], time.Now().Add(-arriveTimeout-time.Millisecond), wrapper) + tester.fetcher.Notify("first", hashes[i], 0, time.Now().Add(-arriveTimeout), wrapper) + tester.fetcher.Notify("second", hashes[i], 0, time.Now().Add(-arriveTimeout+time.Millisecond), wrapper) + tester.fetcher.Notify("second", hashes[i], 0, time.Now().Add(-arriveTimeout-time.Millisecond), wrapper) verifyImportEvent(t, imported) } @@ -252,7 +252,7 @@ func TestOverlappingAnnouncements(t *testing.T) { tester.fetcher.importedHook = func(block *types.Block) { imported <- block } for i := len(hashes) - 2; i >= 0; i-- { - tester.fetcher.Notify("valid", hashes[i], time.Now().Add(-arriveTimeout), fetcher) + tester.fetcher.Notify("valid", hashes[i], 0, time.Now().Add(-arriveTimeout), fetcher) select { case <-fetching: case <-time.After(time.Second): @@ -286,7 +286,7 @@ func TestPendingDeduplication(t *testing.T) { } // Announce the same block many times until it's fetched (wait for any pending ops) for tester.getBlock(hashes[0]) == nil { - tester.fetcher.Notify("repeater", hashes[0], time.Now().Add(-arriveTimeout), wrapper) + tester.fetcher.Notify("repeater", hashes[0], 0, time.Now().Add(-arriveTimeout), wrapper) time.Sleep(time.Millisecond) } time.Sleep(delay) @@ -317,12 +317,12 @@ func TestRandomArrivalImport(t *testing.T) { for i := len(hashes) - 1; i >= 0; i-- { if i != skip { - tester.fetcher.Notify("valid", hashes[i], time.Now().Add(-arriveTimeout), fetcher) + tester.fetcher.Notify("valid", hashes[i], 0, time.Now().Add(-arriveTimeout), fetcher) time.Sleep(time.Millisecond) } } // Finally announce the skipped entry and check full import - tester.fetcher.Notify("valid", hashes[skip], time.Now().Add(-arriveTimeout), fetcher) + tester.fetcher.Notify("valid", hashes[skip], 0, time.Now().Add(-arriveTimeout), fetcher) verifyImportCount(t, imported, len(hashes)-1) } @@ -343,7 +343,7 @@ func TestQueueGapFill(t *testing.T) { for i := len(hashes) - 1; i >= 0; i-- { if i != skip { - tester.fetcher.Notify("valid", hashes[i], time.Now().Add(-arriveTimeout), fetcher) + tester.fetcher.Notify("valid", hashes[i], 0, time.Now().Add(-arriveTimeout), fetcher) time.Sleep(time.Millisecond) } } @@ -374,7 +374,7 @@ func TestImportDeduplication(t *testing.T) { tester.fetcher.importedHook = func(block *types.Block) { imported <- block } // Announce the duplicating block, wait for retrieval, and also propagate directly - tester.fetcher.Notify("valid", hashes[0], time.Now().Add(-arriveTimeout), fetcher) + tester.fetcher.Notify("valid", hashes[0], 0, time.Now().Add(-arriveTimeout), fetcher) <-fetching tester.fetcher.Enqueue("valid", blocks[hashes[0]]) @@ -437,9 +437,9 @@ func TestHashMemoryExhaustionAttack(t *testing.T) { // Feed the tester a huge hashset from the attacker, and a limited from the valid peer for i := 0; i < len(attack); i++ { if i < maxQueueDist { - tester.fetcher.Notify("valid", hashes[len(hashes)-2-i], time.Now(), valid) + tester.fetcher.Notify("valid", hashes[len(hashes)-2-i], 0, time.Now(), valid) } - tester.fetcher.Notify("attacker", attack[i], time.Now(), attacker) + tester.fetcher.Notify("attacker", attack[i], 0, time.Now(), attacker) } if len(tester.fetcher.announced) != hashLimit+maxQueueDist { t.Fatalf("queued announce count mismatch: have %d, want %d", len(tester.fetcher.announced), hashLimit+maxQueueDist) @@ -449,7 +449,7 @@ func TestHashMemoryExhaustionAttack(t *testing.T) { // Feed the remaining valid hashes to ensure DOS protection state remains clean for i := len(hashes) - maxQueueDist - 2; i >= 0; i-- { - tester.fetcher.Notify("valid", hashes[i], time.Now().Add(-arriveTimeout), valid) + tester.fetcher.Notify("valid", hashes[i], 0, time.Now().Add(-arriveTimeout), valid) verifyImportEvent(t, imported) } verifyImportDone(t, imported) diff --git a/eth/handler.go b/eth/handler.go index 6c1895bbd..25ff0eef0 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -36,10 +36,8 @@ import ( "github.com/ethereum/go-ethereum/rlp" ) -// This is the target maximum size of returned blocks for the -// getBlocks message. The reply message may exceed it -// if a single block is larger than the limit. -const maxBlockRespSize = 2 * 1024 * 1024 +// This is the target maximum size of returned blocks, headers or node data. +const softResponseLimit = 2 * 1024 * 1024 func errResp(code errCode, format string, v ...interface{}) error { return fmt.Errorf("%v - %v", code, fmt.Sprintf(format, v...)) @@ -59,12 +57,13 @@ func (ep extProt) GetHashes(hash common.Hash) error { return ep.getHashes(has func (ep extProt) GetBlock(hashes []common.Hash) error { return ep.getBlocks(hashes) } type ProtocolManager struct { - protVer, netId int - txpool txPool - chainman *core.ChainManager - downloader *downloader.Downloader - fetcher *fetcher.Fetcher - peers *peerSet + txpool txPool + chainman *core.ChainManager + chaindb common.Database + + downloader *downloader.Downloader + fetcher *fetcher.Fetcher + peers *peerSet SubProtocols []p2p.Protocol @@ -85,17 +84,17 @@ type ProtocolManager struct { // NewProtocolManager returns a new ethereum sub protocol manager. The Ethereum sub protocol manages peers capable // with the ethereum network. -func NewProtocolManager(networkId int, mux *event.TypeMux, txpool txPool, pow pow.PoW, chainman *core.ChainManager) *ProtocolManager { +func NewProtocolManager(networkId int, mux *event.TypeMux, txpool txPool, pow pow.PoW, chainman *core.ChainManager, chaindb common.Database) *ProtocolManager { // Create the protocol manager with the base fields manager := &ProtocolManager{ eventMux: mux, txpool: txpool, chainman: chainman, + chaindb: chaindb, peers: newPeerSet(), newPeerCh: make(chan *peer, 1), txsyncCh: make(chan *txsync), quitSync: make(chan struct{}), - netId: networkId, } // Initiate a sub-protocol for every implemented version we can handle manager.SubProtocols = make([]p2p.Protocol, len(ProtocolVersions)) @@ -190,6 +189,9 @@ func (pm *ProtocolManager) handle(p *peer) error { glog.V(logger.Debug).Infof("%v: handshake failed: %v", p, err) return err } + if rw, ok := p.rw.(*meteredMsgReadWriter); ok { + rw.Init(p.version) + } // Register the peer locally glog.V(logger.Detail).Infof("%v: adding peer", p) if err := pm.peers.Register(p); err != nil { @@ -230,12 +232,12 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { defer msg.Discard() // Handle the message depending on its contents - switch msg.Code { - case StatusMsg: + switch { + case msg.Code == StatusMsg: // Status messages should never arrive after the handshake return errResp(ErrExtraStatusMsg, "uncontrolled status message") - case GetBlockHashesMsg: + case p.version < eth62 && msg.Code == GetBlockHashesMsg: // Retrieve the number of hashes to return and from which origin hash var request getBlockHashesData if err := msg.Decode(&request); err != nil { @@ -251,7 +253,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } return p.SendBlockHashes(hashes) - case GetBlockHashesFromNumberMsg: + case p.version < eth62 && msg.Code == GetBlockHashesFromNumberMsg: // Retrieve and decode the number of hashes to return and from which origin number var request getBlockHashesFromNumberData if err := msg.Decode(&request); err != nil { @@ -278,12 +280,10 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } return p.SendBlockHashes(hashes) - case BlockHashesMsg: + case p.version < eth62 && msg.Code == BlockHashesMsg: // A batch of hashes arrived to one of our previous requests - msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size)) - var hashes []common.Hash - if err := msgStream.Decode(&hashes); err != nil { + if err := msg.Decode(&hashes); err != nil { break } // Deliver them all to the downloader for queuing @@ -292,7 +292,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { glog.V(logger.Debug).Infoln(err) } - case GetBlocksMsg: + case p.version < eth62 && msg.Code == GetBlocksMsg: // Decode the retrieval message msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size)) if _, err := msgStream.List(); err != nil { @@ -302,44 +302,28 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { var ( hash common.Hash bytes common.StorageSize - hashes []common.Hash blocks []*types.Block ) - for { + for len(blocks) < downloader.MaxBlockFetch && bytes < softResponseLimit { + //Retrieve the hash of the next block err := msgStream.Decode(&hash) if err == rlp.EOL { break } else if err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - hashes = append(hashes, hash) - // Retrieve the requested block, stopping if enough was found if block := pm.chainman.GetBlock(hash); block != nil { blocks = append(blocks, block) bytes += block.Size() - if len(blocks) >= downloader.MaxBlockFetch || bytes > maxBlockRespSize { - break - } } } - if glog.V(logger.Detail) && len(blocks) == 0 && len(hashes) > 0 { - list := "[" - for _, hash := range hashes { - list += fmt.Sprintf("%x, ", hash[:4]) - } - list = list[:len(list)-2] + "]" - - glog.Infof("%v: no blocks found for requested hashes %s", p, list) - } return p.SendBlocks(blocks) - case BlocksMsg: + case p.version < eth62 && msg.Code == BlocksMsg: // Decode the arrived block message - msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size)) - var blocks []*types.Block - if err := msgStream.Decode(&blocks); err != nil { + if err := msg.Decode(&blocks); err != nil { glog.V(logger.Detail).Infoln("Decode error", err) blocks = nil } @@ -352,31 +336,196 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { pm.downloader.DeliverBlocks(p.id, blocks) } - case NewBlockHashesMsg: - // Retrieve and deseralize the remote new block hashes notification - msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size)) + // Block header query, collect the requested headers and reply + case p.version >= eth62 && msg.Code == GetBlockHeadersMsg: + // Decode the complex header query + var query getBlockHeadersData + if err := msg.Decode(&query); err != nil { + return errResp(ErrDecode, "%v: %v", msg, err) + } + // Gather blocks until the fetch or network limits is reached + var ( + bytes common.StorageSize + headers []*types.Header + unknown bool + ) + for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit && len(headers) < downloader.MaxHeaderFetch { + // Retrieve the next block satisfying the query + var origin *types.Block + if query.Origin.Hash != (common.Hash{}) { + origin = pm.chainman.GetBlock(query.Origin.Hash) + } else { + origin = pm.chainman.GetBlockByNumber(query.Origin.Number) + } + if origin == nil { + break + } + headers = append(headers, origin.Header()) + bytes += origin.Size() - var hashes []common.Hash - if err := msgStream.Decode(&hashes); err != nil { - break - } - // Mark the hashes as present at the remote node - for _, hash := range hashes { - p.MarkBlock(hash) - p.SetHead(hash) - } - // Schedule all the unknown hashes for retrieval - unknown := make([]common.Hash, 0, len(hashes)) - for _, hash := range hashes { - if !pm.chainman.HasBlock(hash) { - unknown = append(unknown, hash) + // Advance to the next block of the query + switch { + case query.Origin.Hash != (common.Hash{}) && query.Reverse: + // Hash based traversal towards the genesis block + for i := 0; i < int(query.Skip)+1; i++ { + if block := pm.chainman.GetBlock(query.Origin.Hash); block != nil { + query.Origin.Hash = block.ParentHash() + } else { + unknown = true + break + } + } + case query.Origin.Hash != (common.Hash{}) && !query.Reverse: + // Hash based traversal towards the leaf block + if block := pm.chainman.GetBlockByNumber(origin.NumberU64() + query.Skip + 1); block != nil { + if pm.chainman.GetBlockHashesFromHash(block.Hash(), query.Skip+1)[query.Skip] == query.Origin.Hash { + query.Origin.Hash = block.Hash() + } else { + unknown = true + } + } else { + unknown = true + } + case query.Reverse: + // Number based traversal towards the genesis block + if query.Origin.Number >= query.Skip+1 { + query.Origin.Number -= (query.Skip + 1) + } else { + unknown = true + } + + case !query.Reverse: + // Number based traversal towards the leaf block + query.Origin.Number += (query.Skip + 1) } } - for _, hash := range unknown { - pm.fetcher.Notify(p.id, hash, time.Now(), p.RequestBlocks) + return p.SendBlockHeaders(headers) + + case p.version >= eth62 && msg.Code == GetBlockBodiesMsg: + // Decode the retrieval message + msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size)) + if _, err := msgStream.List(); err != nil { + return err + } + // Gather blocks until the fetch or network limits is reached + var ( + hash common.Hash + bytes common.StorageSize + bodies []*blockBody + ) + for bytes < softResponseLimit && len(bodies) < downloader.MaxBlockFetch { + //Retrieve the hash of the next block + if err := msgStream.Decode(&hash); err == rlp.EOL { + break + } else if err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + // Retrieve the requested block, stopping if enough was found + if block := pm.chainman.GetBlock(hash); block != nil { + bodies = append(bodies, &blockBody{Transactions: block.Transactions(), Uncles: block.Uncles()}) + bytes += block.Size() + } + } + return p.SendBlockBodies(bodies) + + case p.version >= eth63 && msg.Code == GetNodeDataMsg: + // Decode the retrieval message + msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size)) + if _, err := msgStream.List(); err != nil { + return err + } + // Gather state data until the fetch or network limits is reached + var ( + hash common.Hash + bytes int + data [][]byte + ) + for bytes < softResponseLimit && len(data) < downloader.MaxStateFetch { + // Retrieve the hash of the next state entry + if err := msgStream.Decode(&hash); err == rlp.EOL { + break + } else if err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + // Retrieve the requested state entry, stopping if enough was found + if entry, err := pm.chaindb.Get(hash.Bytes()); err == nil { + data = append(data, entry) + bytes += len(entry) + } + } + return p.SendNodeData(data) + + case p.version >= eth63 && msg.Code == GetReceiptsMsg: + // Decode the retrieval message + msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size)) + if _, err := msgStream.List(); err != nil { + return err + } + // Gather state data until the fetch or network limits is reached + var ( + hash common.Hash + bytes int + receipts []*types.Receipt + ) + for bytes < softResponseLimit && len(receipts) < downloader.MaxReceiptsFetch { + // Retrieve the hash of the next transaction receipt + if err := msgStream.Decode(&hash); err == rlp.EOL { + break + } else if err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + // Retrieve the requested receipt, stopping if enough was found + if receipt := core.GetReceipt(pm.chaindb, hash); receipt != nil { + receipts = append(receipts, receipt) + bytes += len(receipt.RlpEncode()) + } + } + return p.SendReceipts(receipts) + + case msg.Code == NewBlockHashesMsg: + // Retrieve and deseralize the remote new block hashes notification + type announce struct { + Hash common.Hash + Number uint64 + } + var announces = []announce{} + + if p.version < eth62 { + // We're running the old protocol, make block number unknown (0) + var hashes []common.Hash + if err := msg.Decode(&hashes); err != nil { + return errResp(ErrDecode, "%v: %v", msg, err) + } + for _, hash := range hashes { + announces = append(announces, announce{hash, 0}) + } + } else { + // Otherwise extract both block hash and number + var request newBlockHashesData + if err := msg.Decode(&request); err != nil { + return errResp(ErrDecode, "%v: %v", msg, err) + } + for _, block := range request { + announces = append(announces, announce{block.Hash, block.Number}) + } + } + // Mark the hashes as present at the remote node + for _, block := range announces { + p.MarkBlock(block.Hash) + p.SetHead(block.Hash) + } + // Schedule all the unknown hashes for retrieval + unknown := make([]announce, 0, len(announces)) + for _, block := range announces { + if !pm.chainman.HasBlock(block.Hash) { + unknown = append(unknown, block) + } + } + for _, block := range unknown { + pm.fetcher.Notify(p.id, block.Hash, block.Number, time.Now(), p.RequestBlocks) } - case NewBlockMsg: + case msg.Code == NewBlockMsg: // Retrieve and decode the propagated block var request newBlockData if err := msg.Decode(&request); err != nil { @@ -410,7 +559,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } } - case TxMsg: + case msg.Code == TxMsg: // Transactions arrived, parse all of them and deliver to the pool var txs []*types.Transaction if err := msg.Decode(&txs); err != nil { diff --git a/eth/handler_test.go b/eth/handler_test.go new file mode 100644 index 000000000..63c94faa1 --- /dev/null +++ b/eth/handler_test.go @@ -0,0 +1,525 @@ +package eth + +import ( + "fmt" + "math/big" + "math/rand" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/eth/downloader" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/params" +) + +// Tests that hashes can be retrieved from a remote chain by hashes in reverse +// order. +func TestGetBlockHashes60(t *testing.T) { testGetBlockHashes(t, 60) } +func TestGetBlockHashes61(t *testing.T) { testGetBlockHashes(t, 61) } + +func testGetBlockHashes(t *testing.T, protocol int) { + pm := newTestProtocolManager(downloader.MaxHashFetch+15, nil, nil) + peer, _ := newTestPeer("peer", protocol, pm, true) + defer peer.close() + + // Create a batch of tests for various scenarios + limit := downloader.MaxHashFetch + tests := []struct { + origin common.Hash + number int + result int + }{ + {common.Hash{}, 1, 0}, // Make sure non existent hashes don't return results + {pm.chainman.Genesis().Hash(), 1, 0}, // There are no hashes to retrieve up from the genesis + {pm.chainman.GetBlockByNumber(5).Hash(), 5, 5}, // All the hashes including the genesis requested + {pm.chainman.GetBlockByNumber(5).Hash(), 10, 5}, // More hashes than available till the genesis requested + {pm.chainman.GetBlockByNumber(100).Hash(), 10, 10}, // All hashes available from the middle of the chain + {pm.chainman.CurrentBlock().Hash(), 10, 10}, // All hashes available from the head of the chain + {pm.chainman.CurrentBlock().Hash(), limit, limit}, // Request the maximum allowed hash count + {pm.chainman.CurrentBlock().Hash(), limit + 1, limit}, // Request more than the maximum allowed hash count + } + // Run each of the tests and verify the results against the chain + for i, tt := range tests { + // Assemble the hash response we would like to receive + resp := make([]common.Hash, tt.result) + if len(resp) > 0 { + from := pm.chainman.GetBlock(tt.origin).NumberU64() - 1 + for j := 0; j < len(resp); j++ { + resp[j] = pm.chainman.GetBlockByNumber(uint64(int(from) - j)).Hash() + } + } + // Send the hash request and verify the response + p2p.Send(peer.app, 0x03, getBlockHashesData{tt.origin, uint64(tt.number)}) + if err := p2p.ExpectMsg(peer.app, 0x04, resp); err != nil { + t.Errorf("test %d: block hashes mismatch: %v", i, err) + } + } +} + +// Tests that hashes can be retrieved from a remote chain by numbers in forward +// order. +func TestGetBlockHashesFromNumber60(t *testing.T) { testGetBlockHashesFromNumber(t, 60) } +func TestGetBlockHashesFromNumber61(t *testing.T) { testGetBlockHashesFromNumber(t, 61) } + +func testGetBlockHashesFromNumber(t *testing.T, protocol int) { + pm := newTestProtocolManager(downloader.MaxHashFetch+15, nil, nil) + peer, _ := newTestPeer("peer", protocol, pm, true) + defer peer.close() + + // Create a batch of tests for various scenarios + limit := downloader.MaxHashFetch + tests := []struct { + origin uint64 + number int + result int + }{ + {pm.chainman.CurrentBlock().NumberU64() + 1, 1, 0}, // Out of bounds requests should return empty + {pm.chainman.CurrentBlock().NumberU64(), 1, 1}, // Make sure the head hash can be retrieved + {pm.chainman.CurrentBlock().NumberU64() - 4, 5, 5}, // All hashes, including the head hash requested + {pm.chainman.CurrentBlock().NumberU64() - 4, 10, 5}, // More hashes requested than available till the head + {pm.chainman.CurrentBlock().NumberU64() - 100, 10, 10}, // All hashes available from the middle of the chain + {0, 10, 10}, // All hashes available from the root of the chain + {0, limit, limit}, // Request the maximum allowed hash count + {0, limit + 1, limit}, // Request more than the maximum allowed hash count + {0, 1, 1}, // Make sure the genesis hash can be retrieved + } + // Run each of the tests and verify the results against the chain + for i, tt := range tests { + // Assemble the hash response we would like to receive + resp := make([]common.Hash, tt.result) + for j := 0; j < len(resp); j++ { + resp[j] = pm.chainman.GetBlockByNumber(tt.origin + uint64(j)).Hash() + } + // Send the hash request and verify the response + p2p.Send(peer.app, 0x08, getBlockHashesFromNumberData{tt.origin, uint64(tt.number)}) + if err := p2p.ExpectMsg(peer.app, 0x04, resp); err != nil { + t.Errorf("test %d: block hashes mismatch: %v", i, err) + } + } +} + +// Tests that blocks can be retrieved from a remote chain based on their hashes. +func TestGetBlocks60(t *testing.T) { testGetBlocks(t, 60) } +func TestGetBlocks61(t *testing.T) { testGetBlocks(t, 61) } + +func testGetBlocks(t *testing.T, protocol int) { + pm := newTestProtocolManager(downloader.MaxHashFetch+15, nil, nil) + peer, _ := newTestPeer("peer", protocol, pm, true) + defer peer.close() + + // Create a batch of tests for various scenarios + limit := downloader.MaxBlockFetch + tests := []struct { + random int // Number of blocks to fetch randomly from the chain + explicit []common.Hash // Explicitly requested blocks + available []bool // Availability of explicitly requested blocks + expected int // Total number of existing blocks to expect + }{ + {1, nil, nil, 1}, // A single random block should be retrievable + {10, nil, nil, 10}, // Multiple random blocks should be retrievable + {limit, nil, nil, limit}, // The maximum possible blocks should be retrievable + {limit + 1, nil, nil, limit}, // No more that the possible block count should be returned + {0, []common.Hash{pm.chainman.Genesis().Hash()}, []bool{true}, 1}, // The genesis block should be retrievable + {0, []common.Hash{pm.chainman.CurrentBlock().Hash()}, []bool{true}, 1}, // The chains head block should be retrievable + {0, []common.Hash{common.Hash{}}, []bool{false}, 0}, // A non existent block should not be returned + + // Existing and non-existing blocks interleaved should not cause problems + {0, []common.Hash{ + common.Hash{}, + pm.chainman.GetBlockByNumber(1).Hash(), + common.Hash{}, + pm.chainman.GetBlockByNumber(10).Hash(), + common.Hash{}, + pm.chainman.GetBlockByNumber(100).Hash(), + common.Hash{}, + }, []bool{false, true, false, true, false, true, false}, 3}, + } + // Run each of the tests and verify the results against the chain + for i, tt := range tests { + // Collect the hashes to request, and the response to expect + hashes, seen := []common.Hash{}, make(map[int64]bool) + blocks := []*types.Block{} + + for j := 0; j < tt.random; j++ { + for { + num := rand.Int63n(int64(pm.chainman.CurrentBlock().NumberU64())) + if !seen[num] { + seen[num] = true + + block := pm.chainman.GetBlockByNumber(uint64(num)) + hashes = append(hashes, block.Hash()) + if len(blocks) < tt.expected { + blocks = append(blocks, block) + } + break + } + } + } + for j, hash := range tt.explicit { + hashes = append(hashes, hash) + if tt.available[j] && len(blocks) < tt.expected { + blocks = append(blocks, pm.chainman.GetBlock(hash)) + } + } + // Send the hash request and verify the response + p2p.Send(peer.app, 0x05, hashes) + if err := p2p.ExpectMsg(peer.app, 0x06, blocks); err != nil { + t.Errorf("test %d: blocks mismatch: %v", i, err) + } + } +} + +// Tests that block headers can be retrieved from a remote chain based on user queries. +func TestGetBlockHeaders62(t *testing.T) { testGetBlockHeaders(t, 62) } +func TestGetBlockHeaders63(t *testing.T) { testGetBlockHeaders(t, 63) } +func TestGetBlockHeaders64(t *testing.T) { testGetBlockHeaders(t, 64) } + +func testGetBlockHeaders(t *testing.T, protocol int) { + pm := newTestProtocolManager(downloader.MaxHashFetch+15, nil, nil) + peer, _ := newTestPeer("peer", protocol, pm, true) + defer peer.close() + + // Create a "random" unknown hash for testing + var unknown common.Hash + for i, _ := range unknown { + unknown[i] = byte(i) + } + // Create a batch of tests for various scenarios + limit := uint64(downloader.MaxHeaderFetch) + tests := []struct { + query *getBlockHeadersData // The query to execute for header retrieval + expect []common.Hash // The hashes of the block whose headers are expected + }{ + // A single random block should be retrievable by hash and number too + { + &getBlockHeadersData{Origin: hashOrNumber{Hash: pm.chainman.GetBlockByNumber(limit / 2).Hash()}, Amount: 1}, + []common.Hash{pm.chainman.GetBlockByNumber(limit / 2).Hash()}, + }, { + &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 1}, + []common.Hash{pm.chainman.GetBlockByNumber(limit / 2).Hash()}, + }, + // Multiple headers should be retrievable in both directions + { + &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3}, + []common.Hash{ + pm.chainman.GetBlockByNumber(limit / 2).Hash(), + pm.chainman.GetBlockByNumber(limit/2 + 1).Hash(), + pm.chainman.GetBlockByNumber(limit/2 + 2).Hash(), + }, + }, { + &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3, Reverse: true}, + []common.Hash{ + pm.chainman.GetBlockByNumber(limit / 2).Hash(), + pm.chainman.GetBlockByNumber(limit/2 - 1).Hash(), + pm.chainman.GetBlockByNumber(limit/2 - 2).Hash(), + }, + }, + // Multiple headers with skip lists should be retrievable + { + &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3}, + []common.Hash{ + pm.chainman.GetBlockByNumber(limit / 2).Hash(), + pm.chainman.GetBlockByNumber(limit/2 + 4).Hash(), + pm.chainman.GetBlockByNumber(limit/2 + 8).Hash(), + }, + }, { + &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3, Reverse: true}, + []common.Hash{ + pm.chainman.GetBlockByNumber(limit / 2).Hash(), + pm.chainman.GetBlockByNumber(limit/2 - 4).Hash(), + pm.chainman.GetBlockByNumber(limit/2 - 8).Hash(), + }, + }, + // The chain endpoints should be retrievable + { + &getBlockHeadersData{Origin: hashOrNumber{Number: 0}, Amount: 1}, + []common.Hash{pm.chainman.GetBlockByNumber(0).Hash()}, + }, { + &getBlockHeadersData{Origin: hashOrNumber{Number: pm.chainman.CurrentBlock().NumberU64()}, Amount: 1}, + []common.Hash{pm.chainman.CurrentBlock().Hash()}, + }, + // Ensure protocol limits are honored + { + &getBlockHeadersData{Origin: hashOrNumber{Number: pm.chainman.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true}, + pm.chainman.GetBlockHashesFromHash(pm.chainman.CurrentBlock().Hash(), limit), + }, + // Check that requesting more than available is handled gracefully + { + &getBlockHeadersData{Origin: hashOrNumber{Number: pm.chainman.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3}, + []common.Hash{ + pm.chainman.GetBlockByNumber(pm.chainman.CurrentBlock().NumberU64() - 4).Hash(), + pm.chainman.GetBlockByNumber(pm.chainman.CurrentBlock().NumberU64()).Hash(), + }, + }, { + &getBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 3, Amount: 3, Reverse: true}, + []common.Hash{ + pm.chainman.GetBlockByNumber(4).Hash(), + pm.chainman.GetBlockByNumber(0).Hash(), + }, + }, + // Check that requesting more than available is handled gracefully, even if mid skip + { + &getBlockHeadersData{Origin: hashOrNumber{Number: pm.chainman.CurrentBlock().NumberU64() - 4}, Skip: 2, Amount: 3}, + []common.Hash{ + pm.chainman.GetBlockByNumber(pm.chainman.CurrentBlock().NumberU64() - 4).Hash(), + pm.chainman.GetBlockByNumber(pm.chainman.CurrentBlock().NumberU64() - 1).Hash(), + }, + }, { + &getBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 2, Amount: 3, Reverse: true}, + []common.Hash{ + pm.chainman.GetBlockByNumber(4).Hash(), + pm.chainman.GetBlockByNumber(1).Hash(), + }, + }, + // Check that non existing headers aren't returned + { + &getBlockHeadersData{Origin: hashOrNumber{Hash: unknown}, Amount: 1}, + []common.Hash{}, + }, { + &getBlockHeadersData{Origin: hashOrNumber{Number: pm.chainman.CurrentBlock().NumberU64() + 1}, Amount: 1}, + []common.Hash{}, + }, + } + // Run each of the tests and verify the results against the chain + for i, tt := range tests { + // Collect the headers to expect in the response + headers := []*types.Header{} + for _, hash := range tt.expect { + headers = append(headers, pm.chainman.GetBlock(hash).Header()) + } + // Send the hash request and verify the response + p2p.Send(peer.app, 0x03, tt.query) + if err := p2p.ExpectMsg(peer.app, 0x04, headers); err != nil { + t.Errorf("test %d: headers mismatch: %v", i, err) + } + } +} + +// Tests that block contents can be retrieved from a remote chain based on their hashes. +func TestGetBlockBodies62(t *testing.T) { testGetBlockBodies(t, 62) } +func TestGetBlockBodies63(t *testing.T) { testGetBlockBodies(t, 63) } +func TestGetBlockBodies64(t *testing.T) { testGetBlockBodies(t, 64) } + +func testGetBlockBodies(t *testing.T, protocol int) { + pm := newTestProtocolManager(downloader.MaxBlockFetch+15, nil, nil) + peer, _ := newTestPeer("peer", protocol, pm, true) + defer peer.close() + + // Create a batch of tests for various scenarios + limit := downloader.MaxBlockFetch + tests := []struct { + random int // Number of blocks to fetch randomly from the chain + explicit []common.Hash // Explicitly requested blocks + available []bool // Availability of explicitly requested blocks + expected int // Total number of existing blocks to expect + }{ + {1, nil, nil, 1}, // A single random block should be retrievable + {10, nil, nil, 10}, // Multiple random blocks should be retrievable + {limit, nil, nil, limit}, // The maximum possible blocks should be retrievable + {limit + 1, nil, nil, limit}, // No more that the possible block count should be returned + {0, []common.Hash{pm.chainman.Genesis().Hash()}, []bool{true}, 1}, // The genesis block should be retrievable + {0, []common.Hash{pm.chainman.CurrentBlock().Hash()}, []bool{true}, 1}, // The chains head block should be retrievable + {0, []common.Hash{common.Hash{}}, []bool{false}, 0}, // A non existent block should not be returned + + // Existing and non-existing blocks interleaved should not cause problems + {0, []common.Hash{ + common.Hash{}, + pm.chainman.GetBlockByNumber(1).Hash(), + common.Hash{}, + pm.chainman.GetBlockByNumber(10).Hash(), + common.Hash{}, + pm.chainman.GetBlockByNumber(100).Hash(), + common.Hash{}, + }, []bool{false, true, false, true, false, true, false}, 3}, + } + // Run each of the tests and verify the results against the chain + for i, tt := range tests { + // Collect the hashes to request, and the response to expect + hashes, seen := []common.Hash{}, make(map[int64]bool) + bodies := []*blockBody{} + + for j := 0; j < tt.random; j++ { + for { + num := rand.Int63n(int64(pm.chainman.CurrentBlock().NumberU64())) + if !seen[num] { + seen[num] = true + + block := pm.chainman.GetBlockByNumber(uint64(num)) + hashes = append(hashes, block.Hash()) + if len(bodies) < tt.expected { + bodies = append(bodies, &blockBody{Transactions: block.Transactions(), Uncles: block.Uncles()}) + } + break + } + } + } + for j, hash := range tt.explicit { + hashes = append(hashes, hash) + if tt.available[j] && len(bodies) < tt.expected { + block := pm.chainman.GetBlock(hash) + bodies = append(bodies, &blockBody{Transactions: block.Transactions(), Uncles: block.Uncles()}) + } + } + // Send the hash request and verify the response + p2p.Send(peer.app, 0x05, hashes) + if err := p2p.ExpectMsg(peer.app, 0x06, bodies); err != nil { + t.Errorf("test %d: bodies mismatch: %v", i, err) + } + } +} + +// Tests that the node state database can be retrieved based on hashes. +func TestGetNodeData63(t *testing.T) { testGetNodeData(t, 63) } +func TestGetNodeData64(t *testing.T) { testGetNodeData(t, 64) } + +func testGetNodeData(t *testing.T, protocol int) { + // Define three accounts to simulate transactions with + acc1Key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a") + acc2Key, _ := crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee") + acc1Addr := crypto.PubkeyToAddress(acc1Key.PublicKey) + acc2Addr := crypto.PubkeyToAddress(acc2Key.PublicKey) + + // Create a chain generator with some simple transactions (blatantly stolen from @fjl/chain_makerts_test) + generator := func(i int, block *core.BlockGen) { + switch i { + case 0: + // In block 1, the test bank sends account #1 some ether. + tx, _ := types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil).SignECDSA(testBankKey) + block.AddTx(tx) + case 1: + // In block 2, the test bank sends some more ether to account #1. + // acc1Addr passes it on to account #2. + tx1, _ := types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil).SignECDSA(testBankKey) + tx2, _ := types.NewTransaction(block.TxNonce(acc1Addr), acc2Addr, big.NewInt(1000), params.TxGas, nil, nil).SignECDSA(acc1Key) + block.AddTx(tx1) + block.AddTx(tx2) + case 2: + // Block 3 is empty but was mined by account #2. + block.SetCoinbase(acc2Addr) + block.SetExtra([]byte("yeehaw")) + case 3: + // Block 4 includes blocks 2 and 3 as uncle headers (with modified extra data). + b2 := block.PrevBlock(1).Header() + b2.Extra = []byte("foo") + block.AddUncle(b2) + b3 := block.PrevBlock(2).Header() + b3.Extra = []byte("foo") + block.AddUncle(b3) + } + } + // Assemble the test environment + pm := newTestProtocolManager(4, generator, nil) + peer, _ := newTestPeer("peer", protocol, pm, true) + defer peer.close() + + // Fetch for now the entire chain db + hashes := []common.Hash{} + for _, key := range pm.chaindb.(*ethdb.MemDatabase).Keys() { + hashes = append(hashes, common.BytesToHash(key)) + } + p2p.Send(peer.app, 0x0d, hashes) + msg, err := peer.app.ReadMsg() + if err != nil { + t.Fatalf("failed to read node data response: %v", err) + } + if msg.Code != 0x0e { + t.Fatalf("response packet code mismatch: have %x, want %x", msg.Code, 0x0c) + } + var data [][]byte + if err := msg.Decode(&data); err != nil { + t.Fatalf("failed to decode response node data: %v", err) + } + // Verify that all hashes correspond to the requested data, and reconstruct a state tree + for i, want := range hashes { + if hash := crypto.Sha3Hash(data[i]); hash != want { + fmt.Errorf("data hash mismatch: have %x, want %x", hash, want) + } + } + statedb, _ := ethdb.NewMemDatabase() + for i := 0; i < len(data); i++ { + statedb.Put(hashes[i].Bytes(), data[i]) + } + accounts := []common.Address{testBankAddress, acc1Addr, acc2Addr} + for i := uint64(0); i <= pm.chainman.CurrentBlock().NumberU64(); i++ { + trie := state.New(pm.chainman.GetBlockByNumber(i).Root(), statedb) + + for j, acc := range accounts { + bw := pm.chainman.State().GetBalance(acc) + bh := trie.GetBalance(acc) + + if (bw != nil && bh == nil) || (bw == nil && bh != nil) { + t.Errorf("test %d, account %d: balance mismatch: have %v, want %v", i, j, bh, bw) + } + if bw != nil && bh != nil && bw.Cmp(bw) != 0 { + t.Errorf("test %d, account %d: balance mismatch: have %v, want %v", i, j, bh, bw) + } + } + } +} + +// Tests that the transaction receipts can be retrieved based on hashes. +func TestGetReceipt63(t *testing.T) { testGetReceipt(t, 63) } +func TestGetReceipt64(t *testing.T) { testGetReceipt(t, 64) } + +func testGetReceipt(t *testing.T, protocol int) { + // Define three accounts to simulate transactions with + acc1Key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a") + acc2Key, _ := crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee") + acc1Addr := crypto.PubkeyToAddress(acc1Key.PublicKey) + acc2Addr := crypto.PubkeyToAddress(acc2Key.PublicKey) + + // Create a chain generator with some simple transactions (blatantly stolen from @fjl/chain_makerts_test) + generator := func(i int, block *core.BlockGen) { + switch i { + case 0: + // In block 1, the test bank sends account #1 some ether. + tx, _ := types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil).SignECDSA(testBankKey) + block.AddTx(tx) + case 1: + // In block 2, the test bank sends some more ether to account #1. + // acc1Addr passes it on to account #2. + tx1, _ := types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil).SignECDSA(testBankKey) + tx2, _ := types.NewTransaction(block.TxNonce(acc1Addr), acc2Addr, big.NewInt(1000), params.TxGas, nil, nil).SignECDSA(acc1Key) + block.AddTx(tx1) + block.AddTx(tx2) + case 2: + // Block 3 is empty but was mined by account #2. + block.SetCoinbase(acc2Addr) + block.SetExtra([]byte("yeehaw")) + case 3: + // Block 4 includes blocks 2 and 3 as uncle headers (with modified extra data). + b2 := block.PrevBlock(1).Header() + b2.Extra = []byte("foo") + block.AddUncle(b2) + b3 := block.PrevBlock(2).Header() + b3.Extra = []byte("foo") + block.AddUncle(b3) + } + } + // Assemble the test environment + pm := newTestProtocolManager(4, generator, nil) + peer, _ := newTestPeer("peer", protocol, pm, true) + defer peer.close() + + // Collect the hashes to request, and the response to expect + hashes := []common.Hash{} + for i := uint64(0); i <= pm.chainman.CurrentBlock().NumberU64(); i++ { + for _, tx := range pm.chainman.GetBlockByNumber(i).Transactions() { + hashes = append(hashes, tx.Hash()) + } + } + receipts := make([]*types.Receipt, len(hashes)) + for i, hash := range hashes { + receipts[i] = core.GetReceipt(pm.chaindb, hash) + } + // Send the hash request and verify the response + p2p.Send(peer.app, 0x0f, hashes) + if err := p2p.ExpectMsg(peer.app, 0x10, receipts); err != nil { + t.Errorf("receipts mismatch: %v", err) + } +} diff --git a/eth/helper_test.go b/eth/helper_test.go new file mode 100644 index 000000000..3a799e6f6 --- /dev/null +++ b/eth/helper_test.go @@ -0,0 +1,147 @@ +// This file contains some shares testing functionality, common to multiple +// different files and modules being tested. + +package eth + +import ( + "crypto/rand" + "math/big" + "sync" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" +) + +var ( + testBankKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + testBankAddress = crypto.PubkeyToAddress(testBankKey.PublicKey) + testBankFunds = big.NewInt(1000000) +) + +// newTestProtocolManager creates a new protocol manager for testing purposes, +// with the given number of blocks already known, and potential notification +// channels for different events. +func newTestProtocolManager(blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) *ProtocolManager { + var ( + evmux = new(event.TypeMux) + pow = new(core.FakePow) + db, _ = ethdb.NewMemDatabase() + genesis = core.WriteGenesisBlockForTesting(db, testBankAddress, testBankFunds) + chainman, _ = core.NewChainManager(db, pow, evmux) + blockproc = core.NewBlockProcessor(db, pow, chainman, evmux) + ) + chainman.SetProcessor(blockproc) + if _, err := chainman.InsertChain(core.GenerateChain(genesis, db, blocks, generator)); err != nil { + panic(err) + } + pm := NewProtocolManager(NetworkId, evmux, &testTxPool{added: newtx}, pow, chainman, db) + pm.Start() + return pm +} + +// testTxPool is a fake, helper transaction pool for testing purposes +type testTxPool struct { + pool []*types.Transaction // Collection of all transactions + added chan<- []*types.Transaction // Notification channel for new transactions + + lock sync.RWMutex // Protects the transaction pool +} + +// AddTransactions appends a batch of transactions to the pool, and notifies any +// listeners if the addition channel is non nil +func (p *testTxPool) AddTransactions(txs []*types.Transaction) { + p.lock.Lock() + defer p.lock.Unlock() + + p.pool = append(p.pool, txs...) + if p.added != nil { + p.added <- txs + } +} + +// GetTransactions returns all the transactions known to the pool +func (p *testTxPool) GetTransactions() types.Transactions { + p.lock.RLock() + defer p.lock.RUnlock() + + txs := make([]*types.Transaction, len(p.pool)) + copy(txs, p.pool) + + return txs +} + +// newTestTransaction create a new dummy transaction. +func newTestTransaction(from *crypto.Key, nonce uint64, datasize int) *types.Transaction { + tx := types.NewTransaction(nonce, common.Address{}, big.NewInt(0), big.NewInt(100000), big.NewInt(0), make([]byte, datasize)) + tx, _ = tx.SignECDSA(from.PrivateKey) + + return tx +} + +// testPeer is a simulated peer to allow testing direct network calls. +type testPeer struct { + net p2p.MsgReadWriter // Network layer reader/writer to simulate remote messaging + app *p2p.MsgPipeRW // Application layer reader/writer to simulate the local side + *peer +} + +// newTestPeer creates a new peer registered at the given protocol manager. +func newTestPeer(name string, version int, pm *ProtocolManager, shake bool) (*testPeer, <-chan error) { + // Create a message pipe to communicate through + app, net := p2p.MsgPipe() + + // Generate a random id and create the peer + var id discover.NodeID + rand.Read(id[:]) + + peer := pm.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net) + + // Start the peer on a new thread + errc := make(chan error, 1) + go func() { + pm.newPeerCh <- peer + errc <- pm.handle(peer) + }() + tp := &testPeer{ + app: app, + net: net, + peer: peer, + } + // Execute any implicitly requested handshakes and return + if shake { + td, head, genesis := pm.chainman.Status() + tp.handshake(nil, td, head, genesis) + } + return tp, errc +} + +// handshake simulates a trivial handshake that expects the same state from the +// remote side as we are simulating locally. +func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, genesis common.Hash) { + msg := &statusData{ + ProtocolVersion: uint32(p.version), + NetworkId: uint32(NetworkId), + TD: td, + CurrentBlock: head, + GenesisBlock: genesis, + } + if err := p2p.ExpectMsg(p.app, StatusMsg, msg); err != nil { + t.Fatalf("status recv: %v", err) + } + if err := p2p.Send(p.app, StatusMsg, msg); err != nil { + t.Fatalf("status send: %v", err) + } +} + +// close terminates the local side of the peer, notifying the remote protocol +// manager of termination. +func (p *testPeer) close() { + p.app.Close() +} diff --git a/eth/metrics.go b/eth/metrics.go index 13745dc43..778747210 100644 --- a/eth/metrics.go +++ b/eth/metrics.go @@ -22,44 +22,53 @@ import ( ) var ( - propTxnInPacketsMeter = metrics.NewMeter("eth/prop/txns/in/packets") - propTxnInTrafficMeter = metrics.NewMeter("eth/prop/txns/in/traffic") - propTxnOutPacketsMeter = metrics.NewMeter("eth/prop/txns/out/packets") - propTxnOutTrafficMeter = metrics.NewMeter("eth/prop/txns/out/traffic") - propHashInPacketsMeter = metrics.NewMeter("eth/prop/hashes/in/packets") - propHashInTrafficMeter = metrics.NewMeter("eth/prop/hashes/in/traffic") - propHashOutPacketsMeter = metrics.NewMeter("eth/prop/hashes/out/packets") - propHashOutTrafficMeter = metrics.NewMeter("eth/prop/hashes/out/traffic") - propBlockInPacketsMeter = metrics.NewMeter("eth/prop/blocks/in/packets") - propBlockInTrafficMeter = metrics.NewMeter("eth/prop/blocks/in/traffic") - propBlockOutPacketsMeter = metrics.NewMeter("eth/prop/blocks/out/packets") - propBlockOutTrafficMeter = metrics.NewMeter("eth/prop/blocks/out/traffic") - reqHashInPacketsMeter = metrics.NewMeter("eth/req/hashes/in/packets") - reqHashInTrafficMeter = metrics.NewMeter("eth/req/hashes/in/traffic") - reqHashOutPacketsMeter = metrics.NewMeter("eth/req/hashes/out/packets") - reqHashOutTrafficMeter = metrics.NewMeter("eth/req/hashes/out/traffic") - reqBlockInPacketsMeter = metrics.NewMeter("eth/req/blocks/in/packets") - reqBlockInTrafficMeter = metrics.NewMeter("eth/req/blocks/in/traffic") - reqBlockOutPacketsMeter = metrics.NewMeter("eth/req/blocks/out/packets") - reqBlockOutTrafficMeter = metrics.NewMeter("eth/req/blocks/out/traffic") - reqHeaderInPacketsMeter = metrics.NewMeter("eth/req/header/in/packets") - reqHeaderInTrafficMeter = metrics.NewMeter("eth/req/header/in/traffic") - reqHeaderOutPacketsMeter = metrics.NewMeter("eth/req/header/out/packets") - reqHeaderOutTrafficMeter = metrics.NewMeter("eth/req/header/out/traffic") - reqStateInPacketsMeter = metrics.NewMeter("eth/req/state/in/packets") - reqStateInTrafficMeter = metrics.NewMeter("eth/req/state/in/traffic") - reqStateOutPacketsMeter = metrics.NewMeter("eth/req/state/out/packets") - reqStateOutTrafficMeter = metrics.NewMeter("eth/req/state/out/traffic") - miscInPacketsMeter = metrics.NewMeter("eth/misc/in/packets") - miscInTrafficMeter = metrics.NewMeter("eth/misc/in/traffic") - miscOutPacketsMeter = metrics.NewMeter("eth/misc/out/packets") - miscOutTrafficMeter = metrics.NewMeter("eth/misc/out/traffic") + propTxnInPacketsMeter = metrics.NewMeter("eth/prop/txns/in/packets") + propTxnInTrafficMeter = metrics.NewMeter("eth/prop/txns/in/traffic") + propTxnOutPacketsMeter = metrics.NewMeter("eth/prop/txns/out/packets") + propTxnOutTrafficMeter = metrics.NewMeter("eth/prop/txns/out/traffic") + propHashInPacketsMeter = metrics.NewMeter("eth/prop/hashes/in/packets") + propHashInTrafficMeter = metrics.NewMeter("eth/prop/hashes/in/traffic") + propHashOutPacketsMeter = metrics.NewMeter("eth/prop/hashes/out/packets") + propHashOutTrafficMeter = metrics.NewMeter("eth/prop/hashes/out/traffic") + propBlockInPacketsMeter = metrics.NewMeter("eth/prop/blocks/in/packets") + propBlockInTrafficMeter = metrics.NewMeter("eth/prop/blocks/in/traffic") + propBlockOutPacketsMeter = metrics.NewMeter("eth/prop/blocks/out/packets") + propBlockOutTrafficMeter = metrics.NewMeter("eth/prop/blocks/out/traffic") + reqHashInPacketsMeter = metrics.NewMeter("eth/req/hashes/in/packets") + reqHashInTrafficMeter = metrics.NewMeter("eth/req/hashes/in/traffic") + reqHashOutPacketsMeter = metrics.NewMeter("eth/req/hashes/out/packets") + reqHashOutTrafficMeter = metrics.NewMeter("eth/req/hashes/out/traffic") + reqBlockInPacketsMeter = metrics.NewMeter("eth/req/blocks/in/packets") + reqBlockInTrafficMeter = metrics.NewMeter("eth/req/blocks/in/traffic") + reqBlockOutPacketsMeter = metrics.NewMeter("eth/req/blocks/out/packets") + reqBlockOutTrafficMeter = metrics.NewMeter("eth/req/blocks/out/traffic") + reqHeaderInPacketsMeter = metrics.NewMeter("eth/req/header/in/packets") + reqHeaderInTrafficMeter = metrics.NewMeter("eth/req/header/in/traffic") + reqHeaderOutPacketsMeter = metrics.NewMeter("eth/req/header/out/packets") + reqHeaderOutTrafficMeter = metrics.NewMeter("eth/req/header/out/traffic") + reqBodyInPacketsMeter = metrics.NewMeter("eth/req/body/in/packets") + reqBodyInTrafficMeter = metrics.NewMeter("eth/req/body/in/traffic") + reqBodyOutPacketsMeter = metrics.NewMeter("eth/req/body/out/packets") + reqBodyOutTrafficMeter = metrics.NewMeter("eth/req/body/out/traffic") + reqStateInPacketsMeter = metrics.NewMeter("eth/req/state/in/packets") + reqStateInTrafficMeter = metrics.NewMeter("eth/req/state/in/traffic") + reqStateOutPacketsMeter = metrics.NewMeter("eth/req/state/out/packets") + reqStateOutTrafficMeter = metrics.NewMeter("eth/req/state/out/traffic") + reqReceiptInPacketsMeter = metrics.NewMeter("eth/req/receipt/in/packets") + reqReceiptInTrafficMeter = metrics.NewMeter("eth/req/receipt/in/traffic") + reqReceiptOutPacketsMeter = metrics.NewMeter("eth/req/receipt/out/packets") + reqReceiptOutTrafficMeter = metrics.NewMeter("eth/req/receipt/out/traffic") + miscInPacketsMeter = metrics.NewMeter("eth/misc/in/packets") + miscInTrafficMeter = metrics.NewMeter("eth/misc/in/traffic") + miscOutPacketsMeter = metrics.NewMeter("eth/misc/out/packets") + miscOutTrafficMeter = metrics.NewMeter("eth/misc/out/traffic") ) // meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of // accumulating the above defined metrics based on the data stream contents. type meteredMsgReadWriter struct { - p2p.MsgReadWriter + p2p.MsgReadWriter // Wrapped message stream to meter + version int // Protocol version to select correct meters } // newMeteredMsgWriter wraps a p2p MsgReadWriter with metering support. If the @@ -68,7 +77,13 @@ func newMeteredMsgWriter(rw p2p.MsgReadWriter) p2p.MsgReadWriter { if !metrics.Enabled { return rw } - return &meteredMsgReadWriter{rw} + return &meteredMsgReadWriter{MsgReadWriter: rw} +} + +// Init sets the protocol version used by the stream to know which meters to +// increment in case of overlapping message ids between protocol versions. +func (rw *meteredMsgReadWriter) Init(version int) { + rw.version = version } func (rw *meteredMsgReadWriter) ReadMsg() (p2p.Msg, error) { @@ -79,20 +94,27 @@ func (rw *meteredMsgReadWriter) ReadMsg() (p2p.Msg, error) { } // Account for the data traffic packets, traffic := miscInPacketsMeter, miscInTrafficMeter - switch msg.Code { - case BlockHashesMsg: + switch { + case (rw.version == eth60 || rw.version == eth61) && msg.Code == BlockHashesMsg: packets, traffic = reqHashInPacketsMeter, reqHashInTrafficMeter - case BlocksMsg: + case (rw.version == eth60 || rw.version == eth61) && msg.Code == BlocksMsg: packets, traffic = reqBlockInPacketsMeter, reqBlockInTrafficMeter - case BlockHeadersMsg: - packets, traffic = reqHeaderInPacketsMeter, reqHeaderInTrafficMeter - case NodeDataMsg: + + case rw.version == eth62 && msg.Code == BlockHeadersMsg: + packets, traffic = reqBlockInPacketsMeter, reqBlockInTrafficMeter + case rw.version == eth62 && msg.Code == BlockBodiesMsg: + packets, traffic = reqBodyInPacketsMeter, reqBodyInTrafficMeter + + case rw.version == eth63 && msg.Code == NodeDataMsg: packets, traffic = reqStateInPacketsMeter, reqStateInTrafficMeter - case NewBlockHashesMsg: + case rw.version == eth63 && msg.Code == ReceiptsMsg: + packets, traffic = reqReceiptInPacketsMeter, reqReceiptInTrafficMeter + + case msg.Code == NewBlockHashesMsg: packets, traffic = propHashInPacketsMeter, propHashInTrafficMeter - case NewBlockMsg: + case msg.Code == NewBlockMsg: packets, traffic = propBlockInPacketsMeter, propBlockInTrafficMeter - case TxMsg: + case msg.Code == TxMsg: packets, traffic = propTxnInPacketsMeter, propTxnInTrafficMeter } packets.Mark(1) @@ -104,20 +126,27 @@ func (rw *meteredMsgReadWriter) ReadMsg() (p2p.Msg, error) { func (rw *meteredMsgReadWriter) WriteMsg(msg p2p.Msg) error { // Account for the data traffic packets, traffic := miscOutPacketsMeter, miscOutTrafficMeter - switch msg.Code { - case BlockHashesMsg: + switch { + case (rw.version == eth60 || rw.version == eth61) && msg.Code == BlockHashesMsg: packets, traffic = reqHashOutPacketsMeter, reqHashOutTrafficMeter - case BlocksMsg: + case (rw.version == eth60 || rw.version == eth61) && msg.Code == BlocksMsg: packets, traffic = reqBlockOutPacketsMeter, reqBlockOutTrafficMeter - case BlockHeadersMsg: + + case rw.version == eth62 && msg.Code == BlockHeadersMsg: packets, traffic = reqHeaderOutPacketsMeter, reqHeaderOutTrafficMeter - case NodeDataMsg: + case rw.version == eth62 && msg.Code == BlockBodiesMsg: + packets, traffic = reqBodyOutPacketsMeter, reqBodyOutTrafficMeter + + case rw.version == eth63 && msg.Code == NodeDataMsg: packets, traffic = reqStateOutPacketsMeter, reqStateOutTrafficMeter - case NewBlockHashesMsg: + case rw.version == eth63 && msg.Code == ReceiptsMsg: + packets, traffic = reqReceiptOutPacketsMeter, reqReceiptOutTrafficMeter + + case msg.Code == NewBlockHashesMsg: packets, traffic = propHashOutPacketsMeter, propHashOutTrafficMeter - case NewBlockMsg: + case msg.Code == NewBlockMsg: packets, traffic = propBlockOutPacketsMeter, propBlockOutTrafficMeter - case TxMsg: + case msg.Code == TxMsg: packets, traffic = propTxnOutPacketsMeter, propTxnOutTrafficMeter } packets.Mark(1) diff --git a/eth/peer.go b/eth/peer.go index c17cdfca7..78de8a9d3 100644 --- a/eth/peer.go +++ b/eth/peer.go @@ -165,12 +165,23 @@ func (p *peer) SendBlockHeaders(headers []*types.Header) error { return p2p.Send(p.rw, BlockHeadersMsg, headers) } +// SendBlockBodies sends a batch of block contents to the remote peer. +func (p *peer) SendBlockBodies(bodies []*blockBody) error { + return p2p.Send(p.rw, BlockBodiesMsg, blockBodiesData(bodies)) +} + // SendNodeData sends a batch of arbitrary internal data, corresponding to the // hashes requested. func (p *peer) SendNodeData(data [][]byte) error { return p2p.Send(p.rw, NodeDataMsg, data) } +// SendReceipts sends a batch of transaction receipts, corresponding to the ones +// requested. +func (p *peer) SendReceipts(receipts []*types.Receipt) error { + return p2p.Send(p.rw, ReceiptsMsg, receipts) +} + // RequestHashes fetches a batch of hashes from a peer, starting at from, going // towards the genesis block. func (p *peer) RequestHashes(from common.Hash) error { @@ -205,6 +216,12 @@ func (p *peer) RequestNodeData(hashes []common.Hash) error { return p2p.Send(p.rw, GetNodeDataMsg, hashes) } +// RequestReceipts fetches a batch of transaction receipts from a remote node. +func (p *peer) RequestReceipts(hashes []common.Hash) error { + glog.V(logger.Debug).Infof("%v fetching %v receipts\n", p, len(hashes)) + return p2p.Send(p.rw, GetReceiptsMsg, hashes) +} + // Handshake executes the eth protocol handshake, negotiating version number, // network IDs, difficulties, head and genesis blocks. func (p *peer) Handshake(td *big.Int, head common.Hash, genesis common.Hash) error { diff --git a/eth/protocol.go b/eth/protocol.go index fcc5f21e2..c16223ccf 100644 --- a/eth/protocol.go +++ b/eth/protocol.go @@ -17,17 +17,29 @@ package eth import ( + "fmt" + "io" "math/big" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/rlp" +) + +// Constants to match up protocol versions and messages +const ( + eth60 = 60 + eth61 = 61 + eth62 = 62 + eth63 = 63 + eth64 = 64 ) // Supported versions of the eth protocol (first is primary). -var ProtocolVersions = []uint{62, 61, 60} +var ProtocolVersions = []uint{eth64, eth63, eth62, eth61, eth60} // Number of implemented message corresponding to different protocol versions. -var ProtocolLengths = []uint64{13, 9, 8} +var ProtocolLengths = []uint64{15, 12, 8, 9, 8} const ( NetworkId = 1 @@ -37,23 +49,38 @@ const ( // eth protocol message codes const ( // Protocol messages belonging to eth/60 - StatusMsg = iota - NewBlockHashesMsg - TxMsg - GetBlockHashesMsg - BlockHashesMsg - GetBlocksMsg - BlocksMsg - NewBlockMsg + StatusMsg = 0x00 + NewBlockHashesMsg = 0x01 + TxMsg = 0x02 + GetBlockHashesMsg = 0x03 + BlockHashesMsg = 0x04 + GetBlocksMsg = 0x05 + BlocksMsg = 0x06 + NewBlockMsg = 0x07 - // Protocol messages belonging to eth/61 - GetBlockHashesFromNumberMsg + // Protocol messages belonging to eth/61 (extension of eth/60) + GetBlockHashesFromNumberMsg = 0x08 - // Protocol messages belonging to eth/62 - GetBlockHeadersMsg - BlockHeadersMsg - GetNodeDataMsg - NodeDataMsg + // Protocol messages belonging to eth/62 (new protocol from scratch) + // StatusMsg = 0x00 (uncomment after eth/61 deprecation) + // NewBlockHashesMsg = 0x01 (uncomment after eth/61 deprecation) + // TxMsg = 0x02 (uncomment after eth/61 deprecation) + GetBlockHeadersMsg = 0x03 + BlockHeadersMsg = 0x04 + GetBlockBodiesMsg = 0x05 + BlockBodiesMsg = 0x06 + // NewBlockMsg = 0x07 (uncomment after eth/61 deprecation) + + // Protocol messages belonging to eth/63 + GetNodeDataMsg = 0x0d + NodeDataMsg = 0x0e + GetReceiptsMsg = 0x0f + ReceiptsMsg = 0x10 + + // Protocol messages belonging to eth/64 + GetAcctProofMsg = 0x11 + GetStorageDataProof = 0x12 + Proof = 0x13 ) type errCode int @@ -111,6 +138,12 @@ type statusData struct { GenesisBlock common.Hash } +// newBlockHashesData is the network packet for the block announcements. +type newBlockHashesData []struct { + Hash common.Hash // Hash of one particular block being announced + Number uint64 // Number of one particular block being announced +} + // getBlockHashesData is the network packet for the hash based hash retrieval. type getBlockHashesData struct { Hash common.Hash @@ -124,12 +157,65 @@ type getBlockHashesFromNumberData struct { Amount uint64 } +// getBlockHeadersData represents a block header query. +type getBlockHeadersData struct { + Origin hashOrNumber // Block from which to retrieve headers + Amount uint64 // Maximum number of headers to retrieve + Skip uint64 // Blocks to skip between consecutive headers + Reverse bool // Query direction (false = rising towards latest, true = falling towards genesis) +} + +// hashOrNumber is a combined field for specifying an origin block. +type hashOrNumber struct { + Hash common.Hash // Block hash from which to retrieve headers (excludes Number) + Number uint64 // Block hash from which to retrieve headers (excludes Hash) +} + +// EncodeRLP is a specialized encoder for hashOrNumber to encode only one of the +// two contained union fields. +func (hn *hashOrNumber) EncodeRLP(w io.Writer) error { + if hn.Hash == (common.Hash{}) { + return rlp.Encode(w, hn.Number) + } + if hn.Number != 0 { + return fmt.Errorf("both origin hash (%x) and number (%d) provided", hn.Hash, hn.Number) + } + return rlp.Encode(w, hn.Hash) +} + +// DecodeRLP is a specialized decoder for hashOrNumber to decode the contents +// into either a block hash or a block number. +func (hn *hashOrNumber) DecodeRLP(s *rlp.Stream) error { + _, size, _ := s.Kind() + origin, err := s.Raw() + if err == nil { + switch { + case size == 32: + err = rlp.DecodeBytes(origin, &hn.Hash) + case size <= 8: + err = rlp.DecodeBytes(origin, &hn.Number) + default: + err = fmt.Errorf("invalid input size %d for origin", size) + } + } + return err +} + // newBlockData is the network packet for the block propagation message. type newBlockData struct { Block *types.Block TD *big.Int } +// blockBody represents the data content of a single block. +type blockBody struct { + Transactions []*types.Transaction // Transactions contained within a block + Uncles []*types.Header // Uncles contained within a block +} + +// blockBodiesData is the network packet for block content distribution. +type blockBodiesData []*blockBody + // nodeDataData is the network response packet for a node data retrieval. type nodeDataData []struct { Value []byte diff --git a/eth/protocol_test.go b/eth/protocol_test.go index 08c9b6a06..263088099 100644 --- a/eth/protocol_test.go +++ b/eth/protocol_test.go @@ -18,19 +18,16 @@ package eth import ( "crypto/rand" - "math/big" + "fmt" "sync" "testing" "time" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/rlp" ) func init() { @@ -40,8 +37,15 @@ func init() { var testAccount = crypto.NewKey(rand.Reader) -func TestStatusMsgErrors(t *testing.T) { - pm := newProtocolManagerForTesting(nil) +// Tests that handshake failures are detected and reported correctly. +func TestStatusMsgErrors60(t *testing.T) { testStatusMsgErrors(t, 60) } +func TestStatusMsgErrors61(t *testing.T) { testStatusMsgErrors(t, 61) } +func TestStatusMsgErrors62(t *testing.T) { testStatusMsgErrors(t, 62) } +func TestStatusMsgErrors63(t *testing.T) { testStatusMsgErrors(t, 63) } +func TestStatusMsgErrors64(t *testing.T) { testStatusMsgErrors(t, 64) } + +func testStatusMsgErrors(t *testing.T, protocol int) { + pm := newTestProtocolManager(0, nil, nil) td, currentBlock, genesis := pm.chainman.Status() defer pm.Stop() @@ -56,23 +60,23 @@ func TestStatusMsgErrors(t *testing.T) { }, { code: StatusMsg, data: statusData{10, NetworkId, td, currentBlock, genesis}, - wantError: errResp(ErrProtocolVersionMismatch, "10 (!= 0)"), + wantError: errResp(ErrProtocolVersionMismatch, "10 (!= %d)", protocol), }, { - code: StatusMsg, data: statusData{uint32(ProtocolVersions[0]), 999, td, currentBlock, genesis}, + code: StatusMsg, data: statusData{uint32(protocol), 999, td, currentBlock, genesis}, wantError: errResp(ErrNetworkIdMismatch, "999 (!= 1)"), }, { - code: StatusMsg, data: statusData{uint32(ProtocolVersions[0]), NetworkId, td, currentBlock, common.Hash{3}}, + code: StatusMsg, data: statusData{uint32(protocol), NetworkId, td, currentBlock, common.Hash{3}}, wantError: errResp(ErrGenesisBlockMismatch, "0300000000000000000000000000000000000000000000000000000000000000 (!= %x)", genesis), }, } for i, test := range tests { - p, errc := newTestPeer(pm) + p, errc := newTestPeer("peer", protocol, pm, false) // The send call might hang until reset because // the protocol might not read the payload. - go p2p.Send(p, test.code, test.data) + go p2p.Send(p.app, test.code, test.data) select { case err := <-errc: @@ -89,16 +93,21 @@ func TestStatusMsgErrors(t *testing.T) { } // This test checks that received transactions are added to the local pool. -func TestRecvTransactions(t *testing.T) { +func TestRecvTransactions60(t *testing.T) { testRecvTransactions(t, 60) } +func TestRecvTransactions61(t *testing.T) { testRecvTransactions(t, 61) } +func TestRecvTransactions62(t *testing.T) { testRecvTransactions(t, 62) } +func TestRecvTransactions63(t *testing.T) { testRecvTransactions(t, 63) } +func TestRecvTransactions64(t *testing.T) { testRecvTransactions(t, 64) } + +func testRecvTransactions(t *testing.T, protocol int) { txAdded := make(chan []*types.Transaction) - pm := newProtocolManagerForTesting(txAdded) - p, _ := newTestPeer(pm) + pm := newTestProtocolManager(0, nil, txAdded) + p, _ := newTestPeer("peer", protocol, pm, true) defer pm.Stop() defer p.close() - p.handshake(t) - tx := newtx(testAccount, 0, 0) - if err := p2p.Send(p, TxMsg, []interface{}{tx}); err != nil { + tx := newTestTransaction(testAccount, 0, 0) + if err := p2p.Send(p.app, TxMsg, []interface{}{tx}); err != nil { t.Fatalf("send error: %v", err) } select { @@ -114,15 +123,21 @@ func TestRecvTransactions(t *testing.T) { } // This test checks that pending transactions are sent. -func TestSendTransactions(t *testing.T) { - pm := newProtocolManagerForTesting(nil) +func TestSendTransactions60(t *testing.T) { testSendTransactions(t, 60) } +func TestSendTransactions61(t *testing.T) { testSendTransactions(t, 61) } +func TestSendTransactions62(t *testing.T) { testSendTransactions(t, 62) } +func TestSendTransactions63(t *testing.T) { testSendTransactions(t, 63) } +func TestSendTransactions64(t *testing.T) { testSendTransactions(t, 64) } + +func testSendTransactions(t *testing.T, protocol int) { + pm := newTestProtocolManager(0, nil, nil) defer pm.Stop() // Fill the pool with big transactions. const txsize = txsyncPackSize / 10 alltxs := make([]*types.Transaction, 100) for nonce := range alltxs { - alltxs[nonce] = newtx(testAccount, uint64(nonce), txsize) + alltxs[nonce] = newTestTransaction(testAccount, uint64(nonce), txsize) } pm.txpool.AddTransactions(alltxs) @@ -137,7 +152,7 @@ func TestSendTransactions(t *testing.T) { } for n := 0; n < len(alltxs) && !t.Failed(); { var txs []*types.Transaction - msg, err := p.ReadMsg() + msg, err := p.app.ReadMsg() if err != nil { t.Errorf("%v: read error: %v", p.Peer, err) } else if msg.Code != TxMsg { @@ -161,97 +176,53 @@ func TestSendTransactions(t *testing.T) { } } for i := 0; i < 3; i++ { - p, _ := newTestPeer(pm) - p.handshake(t) + p, _ := newTestPeer(fmt.Sprintf("peer #%d", i), protocol, pm, true) wg.Add(1) go checktxs(p) } wg.Wait() } -// testPeer wraps all peer-related data for tests. -type testPeer struct { - p2p.MsgReadWriter // writing to the test peer feeds the protocol - pipe *p2p.MsgPipeRW // the protocol read/writes on this end - pm *ProtocolManager - *peer -} - -func newProtocolManagerForTesting(txAdded chan<- []*types.Transaction) *ProtocolManager { - db, _ := ethdb.NewMemDatabase() - core.WriteTestNetGenesisBlock(db, 0) - var ( - em = new(event.TypeMux) - chain, _ = core.NewChainManager(db, core.FakePow{}, em) - txpool = &fakeTxPool{added: txAdded} - pm = NewProtocolManager(NetworkId, em, txpool, core.FakePow{}, chain) - ) - pm.Start() - return pm -} - -func newTestPeer(pm *ProtocolManager) (*testPeer, <-chan error) { - var id discover.NodeID - rand.Read(id[:]) - rw1, rw2 := p2p.MsgPipe() - peer := pm.newPeer(pm.protVer, pm.netId, p2p.NewPeer(id, "test peer", nil), rw2) - errc := make(chan error, 1) - go func() { - pm.newPeerCh <- peer - errc <- pm.handle(peer) - }() - return &testPeer{rw1, rw2, pm, peer}, errc -} - -func (p *testPeer) handshake(t *testing.T) { - td, currentBlock, genesis := p.pm.chainman.Status() - msg := &statusData{ - ProtocolVersion: uint32(p.pm.protVer), - NetworkId: uint32(p.pm.netId), - TD: td, - CurrentBlock: currentBlock, - GenesisBlock: genesis, +// Tests that the custom union field encoder and decoder works correctly. +func TestGetBlockHeadersDataEncodeDecode(t *testing.T) { + // Create a "random" hash for testing + var hash common.Hash + for i, _ := range hash { + hash[i] = byte(i) } - if err := p2p.ExpectMsg(p, StatusMsg, msg); err != nil { - t.Fatalf("status recv: %v", err) + // Assemble some table driven tests + tests := []struct { + packet *getBlockHeadersData + fail bool + }{ + // Providing the origin as either a hash or a number should both work + {fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Number: 314}}}, + {fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}}}, + + // Providing arbitrary query field should also work + {fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Number: 314}, Amount: 314, Skip: 1, Reverse: true}}, + {fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: 314, Skip: 1, Reverse: true}}, + + // Providing both the origin hash and origin number must fail + {fail: true, packet: &getBlockHeadersData{Origin: hashOrNumber{Hash: hash, Number: 314}}}, } - if err := p2p.Send(p, StatusMsg, msg); err != nil { - t.Fatalf("status send: %v", err) + // Iterate over each of the tests and try to encode and then decode + for i, tt := range tests { + bytes, err := rlp.EncodeToBytes(tt.packet) + if err != nil && !tt.fail { + t.Fatalf("test %d: failed to encode packet: %v", i, err) + } else if err == nil && tt.fail { + t.Fatalf("test %d: encode should have failed", i) + } + if !tt.fail { + packet := new(getBlockHeadersData) + if err := rlp.DecodeBytes(bytes, packet); err != nil { + t.Fatalf("test %d: failed to decode packet: %v", i, err) + } + if packet.Origin.Hash != tt.packet.Origin.Hash || packet.Origin.Number != tt.packet.Origin.Number || packet.Amount != tt.packet.Amount || + packet.Skip != tt.packet.Skip || packet.Reverse != tt.packet.Reverse { + t.Fatalf("test %d: encode decode mismatch: have %+v, want %+v", i, packet, tt.packet) + } + } } } - -func (p *testPeer) close() { - p.pipe.Close() -} - -type fakeTxPool struct { - // all transactions are collected. - mu sync.Mutex - all []*types.Transaction - // if added is non-nil, it receives added transactions. - added chan<- []*types.Transaction -} - -func (pool *fakeTxPool) AddTransactions(txs []*types.Transaction) { - pool.mu.Lock() - defer pool.mu.Unlock() - pool.all = append(pool.all, txs...) - if pool.added != nil { - pool.added <- txs - } -} - -func (pool *fakeTxPool) GetTransactions() types.Transactions { - pool.mu.Lock() - defer pool.mu.Unlock() - txs := make([]*types.Transaction, len(pool.all)) - copy(txs, pool.all) - return types.Transactions(txs) -} - -func newtx(from *crypto.Key, nonce uint64, datasize int) *types.Transaction { - data := make([]byte, datasize) - tx := types.NewTransaction(nonce, common.Address{}, big.NewInt(0), big.NewInt(100000), big.NewInt(0), data) - tx, _ = tx.SignECDSA(from.PrivateKey) - return tx -} diff --git a/ethdb/memory_database.go b/ethdb/memory_database.go index 70b03dfad..d50f8f9d4 100644 --- a/ethdb/memory_database.go +++ b/ethdb/memory_database.go @@ -49,6 +49,14 @@ func (db *MemDatabase) Get(key []byte) ([]byte, error) { return db.db[string(key)], nil } +func (db *MemDatabase) Keys() [][]byte { + keys := [][]byte{} + for key, _ := range db.db { + keys = append(keys, []byte(key)) + } + return keys +} + /* func (db *MemDatabase) GetKeys() []*common.Key { data, _ := db.Get([]byte("KeyRing"))