From 2ed729d38e90154d1f23ebdf5a9f2808212276d8 Mon Sep 17 00:00:00 2001 From: gary rong Date: Wed, 21 Aug 2019 17:29:34 +0800 Subject: [PATCH] les: handler separation (#19639) les: handler separation --- core/blockchain.go | 22 + les/api.go | 14 +- les/api_backend.go | 2 +- les/api_test.go | 18 +- les/benchmark.go | 47 +- les/bloombits.go | 3 +- les/{backend.go => client.go} | 131 ++- les/client_handler.go | 401 ++++++++ les/commons.go | 69 +- les/costtracker.go | 11 +- les/distributor.go | 37 +- les/distributor_test.go | 2 +- les/fetcher.go | 75 +- les/fetcher_test.go | 168 --- les/handler.go | 1293 ------------------------ les/handler_test.go | 200 ++-- les/metrics.go | 86 +- les/odr.go | 5 +- les/odr_test.go | 38 +- les/peer.go | 48 +- les/peer_test.go | 80 +- les/request_test.go | 28 +- les/server.go | 360 +++---- les/server_handler.go | 921 +++++++++++++++++ les/serverpool.go | 61 +- les/sync.go | 71 +- les/sync_test.go | 17 +- les/{helper_test.go => test_helper.go} | 458 +++++---- les/ulc_test.go | 224 ++-- light/odr_util.go | 6 +- light/postprocess.go | 6 +- 31 files changed, 2377 insertions(+), 2525 deletions(-) rename les/{backend.go => client.go} (78%) create mode 100644 les/client_handler.go delete mode 100644 les/fetcher_test.go delete mode 100644 les/handler.go create mode 100644 les/server_handler.go rename les/{helper_test.go => test_helper.go} (55%) diff --git a/core/blockchain.go b/core/blockchain.go index 361fa8243..2fd373c7c 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -75,6 +75,7 @@ const ( bodyCacheLimit = 256 blockCacheLimit = 256 receiptsCacheLimit = 32 + txLookupCacheLimit = 1024 maxFutureBlocks = 256 maxTimeFutureBlocks = 30 badBlockLimit = 10 @@ -155,6 +156,7 @@ type BlockChain struct { bodyRLPCache *lru.Cache // Cache for the most recent block bodies in RLP encoded format receiptsCache *lru.Cache // Cache for the most recent receipts per block blockCache *lru.Cache // Cache for the most recent entire blocks + txLookupCache *lru.Cache // Cache for the most recent transaction lookup data. futureBlocks *lru.Cache // future blocks are blocks added for later processing quit chan struct{} // blockchain quit channel @@ -189,6 +191,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par bodyRLPCache, _ := lru.New(bodyCacheLimit) receiptsCache, _ := lru.New(receiptsCacheLimit) blockCache, _ := lru.New(blockCacheLimit) + txLookupCache, _ := lru.New(txLookupCacheLimit) futureBlocks, _ := lru.New(maxFutureBlocks) badBlocks, _ := lru.New(badBlockLimit) @@ -204,6 +207,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par bodyRLPCache: bodyRLPCache, receiptsCache: receiptsCache, blockCache: blockCache, + txLookupCache: txLookupCache, futureBlocks: futureBlocks, engine: engine, vmConfig: vmConfig, @@ -440,6 +444,7 @@ func (bc *BlockChain) SetHead(head uint64) error { bc.bodyRLPCache.Purge() bc.receiptsCache.Purge() bc.blockCache.Purge() + bc.txLookupCache.Purge() bc.futureBlocks.Purge() return bc.loadLastState() @@ -921,6 +926,7 @@ func (bc *BlockChain) truncateAncient(head uint64) error { bc.bodyRLPCache.Purge() bc.receiptsCache.Purge() bc.blockCache.Purge() + bc.txLookupCache.Purge() bc.futureBlocks.Purge() log.Info("Rewind ancient data", "number", head) @@ -2151,6 +2157,22 @@ func (bc *BlockChain) GetHeaderByNumber(number uint64) *types.Header { return bc.hc.GetHeaderByNumber(number) } +// GetTransactionLookup retrieves the lookup associate with the given transaction +// hash from the cache or database. +func (bc *BlockChain) GetTransactionLookup(hash common.Hash) *rawdb.LegacyTxLookupEntry { + // Short circuit if the txlookup already in the cache, retrieve otherwise + if lookup, exist := bc.txLookupCache.Get(hash); exist { + return lookup.(*rawdb.LegacyTxLookupEntry) + } + tx, blockHash, blockNumber, txIndex := rawdb.ReadTransaction(bc.db, hash) + if tx == nil { + return nil + } + lookup := &rawdb.LegacyTxLookupEntry{BlockHash: blockHash, BlockIndex: blockNumber, Index: txIndex} + bc.txLookupCache.Add(hash, lookup) + return lookup +} + // Config retrieves the chain's fork configuration. func (bc *BlockChain) Config() *params.ChainConfig { return bc.chainConfig } diff --git a/les/api.go b/les/api.go index e20f72cad..bbef771f0 100644 --- a/les/api.go +++ b/les/api.go @@ -30,15 +30,11 @@ var ( // PrivateLightAPI provides an API to access the LES light server or light client. type PrivateLightAPI struct { backend *lesCommons - reg *checkpointOracle } // NewPrivateLightAPI creates a new LES service API. -func NewPrivateLightAPI(backend *lesCommons, reg *checkpointOracle) *PrivateLightAPI { - return &PrivateLightAPI{ - backend: backend, - reg: reg, - } +func NewPrivateLightAPI(backend *lesCommons) *PrivateLightAPI { + return &PrivateLightAPI{backend: backend} } // LatestCheckpoint returns the latest local checkpoint package. @@ -67,7 +63,7 @@ func (api *PrivateLightAPI) LatestCheckpoint() ([4]string, error) { // result[2], 32 bytes hex encoded latest section bloom trie root hash func (api *PrivateLightAPI) GetCheckpoint(index uint64) ([3]string, error) { var res [3]string - cp := api.backend.getLocalCheckpoint(index) + cp := api.backend.localCheckpoint(index) if cp.Empty() { return res, errNoCheckpoint } @@ -77,8 +73,8 @@ func (api *PrivateLightAPI) GetCheckpoint(index uint64) ([3]string, error) { // GetCheckpointContractAddress returns the contract contract address in hex format. func (api *PrivateLightAPI) GetCheckpointContractAddress() (string, error) { - if api.reg == nil { + if api.backend.oracle == nil { return "", errNotActivated } - return api.reg.config.Address.Hex(), nil + return api.backend.oracle.config.Address.Hex(), nil } diff --git a/les/api_backend.go b/les/api_backend.go index 07601c242..5cd432dcf 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -54,7 +54,7 @@ func (b *LesApiBackend) CurrentBlock() *types.Block { } func (b *LesApiBackend) SetHead(number uint64) { - b.eth.protocolManager.downloader.Cancel() + b.eth.handler.downloader.Cancel() b.eth.blockchain.SetHead(number) } diff --git a/les/api_test.go b/les/api_test.go index 6e622313c..7d3b4ce5d 100644 --- a/les/api_test.go +++ b/les/api_test.go @@ -78,19 +78,16 @@ func TestCapacityAPI10(t *testing.T) { // while connected and going back and forth between free and priority mode with // the supplied API calls is also thoroughly tested. func testCapacityAPI(t *testing.T, clientCount int) { + // Skip test if no data dir specified if testServerDataDir == "" { - // Skip test if no data dir specified return } - for !testSim(t, 1, clientCount, []string{testServerDataDir}, nil, func(ctx context.Context, net *simulations.Network, servers []*simulations.Node, clients []*simulations.Node) bool { if len(servers) != 1 { t.Fatalf("Invalid number of servers: %d", len(servers)) } server := servers[0] - clientRpcClients := make([]*rpc.Client, len(clients)) - serverRpcClient, err := server.Client() if err != nil { t.Fatalf("Failed to obtain rpc client: %v", err) @@ -105,13 +102,13 @@ func testCapacityAPI(t *testing.T, clientCount int) { } freeIdx := rand.Intn(len(clients)) + clientRpcClients := make([]*rpc.Client, len(clients)) for i, client := range clients { var err error clientRpcClients[i], err = client.Client() if err != nil { t.Fatalf("Failed to obtain rpc client: %v", err) } - t.Log("connecting client", i) if i != freeIdx { setCapacity(ctx, t, serverRpcClient, client.ID(), testCap/uint64(len(clients))) @@ -138,10 +135,13 @@ func testCapacityAPI(t *testing.T, clientCount int) { reqCount := make([]uint64, len(clientRpcClients)) + // Send light request like crazy. for i, c := range clientRpcClients { wg.Add(1) i, c := i, c go func() { + defer wg.Done() + queue := make(chan struct{}, 100) reqCount[i] = 0 for { @@ -149,10 +149,8 @@ func testCapacityAPI(t *testing.T, clientCount int) { case queue <- struct{}{}: select { case <-stop: - wg.Done() return case <-ctx.Done(): - wg.Done() return default: wg.Add(1) @@ -169,10 +167,8 @@ func testCapacityAPI(t *testing.T, clientCount int) { }() } case <-stop: - wg.Done() return case <-ctx.Done(): - wg.Done() return } } @@ -313,12 +309,10 @@ func getHead(ctx context.Context, t *testing.T, client *rpc.Client) (uint64, com } func testRequest(ctx context.Context, t *testing.T, client *rpc.Client) bool { - //res := make(map[string]interface{}) var res string var addr common.Address rand.Read(addr[:]) c, _ := context.WithTimeout(ctx, time.Second*12) - // if err := client.CallContext(ctx, &res, "eth_getProof", addr, nil, "latest"); err != nil { err := client.CallContext(c, &res, "eth_getBalance", addr, "latest") if err != nil { t.Log("request error:", err) @@ -418,7 +412,6 @@ func NewNetwork() (*simulations.Network, func(), error) { adapterTeardown() net.Shutdown() } - return net, teardown, nil } @@ -516,7 +509,6 @@ func newLesServerService(ctx *adapters.ServiceContext) (node.Service, error) { if err != nil { return nil, err } - server, err := NewLesServer(ethereum, &config) if err != nil { return nil, err diff --git a/les/benchmark.go b/les/benchmark.go index 74dfcf7c9..fbba53e4d 100644 --- a/les/benchmark.go +++ b/les/benchmark.go @@ -39,7 +39,7 @@ import ( // requestBenchmark is an interface for different randomized request generators type requestBenchmark interface { // init initializes the generator for generating the given number of randomized requests - init(pm *ProtocolManager, count int) error + init(h *serverHandler, count int) error // request initiates sending a single request to the given peer request(peer *peer, index int) error } @@ -52,10 +52,10 @@ type benchmarkBlockHeaders struct { hashes []common.Hash } -func (b *benchmarkBlockHeaders) init(pm *ProtocolManager, count int) error { +func (b *benchmarkBlockHeaders) init(h *serverHandler, count int) error { d := int64(b.amount-1) * int64(b.skip+1) b.offset = 0 - b.randMax = pm.blockchain.CurrentHeader().Number.Int64() + 1 - d + b.randMax = h.blockchain.CurrentHeader().Number.Int64() + 1 - d if b.randMax < 0 { return fmt.Errorf("chain is too short") } @@ -65,7 +65,7 @@ func (b *benchmarkBlockHeaders) init(pm *ProtocolManager, count int) error { if b.byHash { b.hashes = make([]common.Hash, count) for i := range b.hashes { - b.hashes[i] = rawdb.ReadCanonicalHash(pm.chainDb, uint64(b.offset+rand.Int63n(b.randMax))) + b.hashes[i] = rawdb.ReadCanonicalHash(h.chainDb, uint64(b.offset+rand.Int63n(b.randMax))) } } return nil @@ -85,11 +85,11 @@ type benchmarkBodiesOrReceipts struct { hashes []common.Hash } -func (b *benchmarkBodiesOrReceipts) init(pm *ProtocolManager, count int) error { - randMax := pm.blockchain.CurrentHeader().Number.Int64() + 1 +func (b *benchmarkBodiesOrReceipts) init(h *serverHandler, count int) error { + randMax := h.blockchain.CurrentHeader().Number.Int64() + 1 b.hashes = make([]common.Hash, count) for i := range b.hashes { - b.hashes[i] = rawdb.ReadCanonicalHash(pm.chainDb, uint64(rand.Int63n(randMax))) + b.hashes[i] = rawdb.ReadCanonicalHash(h.chainDb, uint64(rand.Int63n(randMax))) } return nil } @@ -108,8 +108,8 @@ type benchmarkProofsOrCode struct { headHash common.Hash } -func (b *benchmarkProofsOrCode) init(pm *ProtocolManager, count int) error { - b.headHash = pm.blockchain.CurrentHeader().Hash() +func (b *benchmarkProofsOrCode) init(h *serverHandler, count int) error { + b.headHash = h.blockchain.CurrentHeader().Hash() return nil } @@ -130,11 +130,11 @@ type benchmarkHelperTrie struct { sectionCount, headNum uint64 } -func (b *benchmarkHelperTrie) init(pm *ProtocolManager, count int) error { +func (b *benchmarkHelperTrie) init(h *serverHandler, count int) error { if b.bloom { - b.sectionCount, b.headNum, _ = pm.server.bloomTrieIndexer.Sections() + b.sectionCount, b.headNum, _ = h.server.bloomTrieIndexer.Sections() } else { - b.sectionCount, _, _ = pm.server.chtIndexer.Sections() + b.sectionCount, _, _ = h.server.chtIndexer.Sections() b.headNum = b.sectionCount*params.CHTFrequency - 1 } if b.sectionCount == 0 { @@ -170,7 +170,7 @@ type benchmarkTxSend struct { txs types.Transactions } -func (b *benchmarkTxSend) init(pm *ProtocolManager, count int) error { +func (b *benchmarkTxSend) init(h *serverHandler, count int) error { key, _ := crypto.GenerateKey() addr := crypto.PubkeyToAddress(key.PublicKey) signer := types.NewEIP155Signer(big.NewInt(18)) @@ -196,7 +196,7 @@ func (b *benchmarkTxSend) request(peer *peer, index int) error { // benchmarkTxStatus implements requestBenchmark type benchmarkTxStatus struct{} -func (b *benchmarkTxStatus) init(pm *ProtocolManager, count int) error { +func (b *benchmarkTxStatus) init(h *serverHandler, count int) error { return nil } @@ -217,7 +217,7 @@ type benchmarkSetup struct { // runBenchmark runs a benchmark cycle for all benchmark types in the specified // number of passes -func (pm *ProtocolManager) runBenchmark(benchmarks []requestBenchmark, passCount int, targetTime time.Duration) []*benchmarkSetup { +func (h *serverHandler) runBenchmark(benchmarks []requestBenchmark, passCount int, targetTime time.Duration) []*benchmarkSetup { setup := make([]*benchmarkSetup, len(benchmarks)) for i, b := range benchmarks { setup[i] = &benchmarkSetup{req: b} @@ -239,7 +239,7 @@ func (pm *ProtocolManager) runBenchmark(benchmarks []requestBenchmark, passCount if next.totalTime > 0 { count = int(uint64(next.totalCount) * uint64(targetTime) / uint64(next.totalTime)) } - if err := pm.measure(next, count); err != nil { + if err := h.measure(next, count); err != nil { next.err = err } } @@ -275,14 +275,15 @@ func (m *meteredPipe) WriteMsg(msg p2p.Msg) error { // measure runs a benchmark for a single type in a single pass, with the given // number of requests -func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error { +func (h *serverHandler) measure(setup *benchmarkSetup, count int) error { clientPipe, serverPipe := p2p.MsgPipe() clientMeteredPipe := &meteredPipe{rw: clientPipe} serverMeteredPipe := &meteredPipe{rw: serverPipe} var id enode.ID rand.Read(id[:]) - clientPeer := pm.newPeer(lpv2, NetworkId, p2p.NewPeer(id, "client", nil), clientMeteredPipe) - serverPeer := pm.newPeer(lpv2, NetworkId, p2p.NewPeer(id, "server", nil), serverMeteredPipe) + + clientPeer := newPeer(lpv2, NetworkId, false, p2p.NewPeer(id, "client", nil), clientMeteredPipe) + serverPeer := newPeer(lpv2, NetworkId, false, p2p.NewPeer(id, "server", nil), serverMeteredPipe) serverPeer.sendQueue = newExecQueue(count) serverPeer.announceType = announceTypeNone serverPeer.fcCosts = make(requestCostTable) @@ -291,10 +292,10 @@ func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error { serverPeer.fcCosts[code] = c } serverPeer.fcParams = flowcontrol.ServerParams{BufLimit: 1, MinRecharge: 1} - serverPeer.fcClient = flowcontrol.NewClientNode(pm.server.fcManager, serverPeer.fcParams) + serverPeer.fcClient = flowcontrol.NewClientNode(h.server.fcManager, serverPeer.fcParams) defer serverPeer.fcClient.Disconnect() - if err := setup.req.init(pm, count); err != nil { + if err := setup.req.init(h, count); err != nil { return err } @@ -311,7 +312,7 @@ func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error { }() go func() { for i := 0; i < count; i++ { - if err := pm.handleMsg(serverPeer); err != nil { + if err := h.handleMsg(serverPeer); err != nil { errCh <- err return } @@ -336,7 +337,7 @@ func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error { if err != nil { return err } - case <-pm.quitSync: + case <-h.closeCh: clientPipe.Close() serverPipe.Close() return fmt.Errorf("Benchmark cancelled") diff --git a/les/bloombits.go b/les/bloombits.go index aea0fcd5f..a98524ce2 100644 --- a/les/bloombits.go +++ b/les/bloombits.go @@ -46,9 +46,10 @@ const ( func (eth *LightEthereum) startBloomHandlers(sectionSize uint64) { for i := 0; i < bloomServiceThreads; i++ { go func() { + defer eth.wg.Done() for { select { - case <-eth.shutdownChan: + case <-eth.closeCh: return case request := <-eth.bloomRequests: diff --git a/les/backend.go b/les/client.go similarity index 78% rename from les/backend.go rename to les/client.go index c067afaea..461baf645 100644 --- a/les/backend.go +++ b/les/client.go @@ -19,8 +19,6 @@ package les import ( "fmt" - "sync" - "time" "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/accounts/abi/bind" @@ -42,7 +40,7 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/p2p/discv5" + "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rpc" ) @@ -50,33 +48,23 @@ import ( type LightEthereum struct { lesCommons - odr *LesOdr - chainConfig *params.ChainConfig - // Channel for shutting down the service - shutdownChan chan bool - - // Handlers - peers *peerSet + reqDist *requestDistributor + retriever *retrieveManager + odr *LesOdr + relay *lesTxRelay + handler *clientHandler txPool *light.TxPool blockchain *light.LightChain serverPool *serverPool - reqDist *requestDistributor - retriever *retrieveManager - relay *lesTxRelay bloomRequests chan chan *bloombits.Retrieval // Channel receiving bloom data retrieval requests - bloomIndexer *core.ChainIndexer - - ApiBackend *LesApiBackend + bloomIndexer *core.ChainIndexer // Bloom indexer operating during block imports + ApiBackend *LesApiBackend eventMux *event.TypeMux engine consensus.Engine accountManager *accounts.Manager - - networkId uint64 - netRPCService *ethapi.PublicNetAPI - - wg sync.WaitGroup + netRPCService *ethapi.PublicNetAPI } func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { @@ -91,26 +79,24 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { log.Info("Initialised chain configuration", "config", chainConfig) peers := newPeerSet() - quitSync := make(chan struct{}) - leth := &LightEthereum{ lesCommons: lesCommons{ - chainDb: chainDb, - config: config, - iConfig: light.DefaultClientIndexerConfig, + genesis: genesisHash, + config: config, + chainConfig: chainConfig, + iConfig: light.DefaultClientIndexerConfig, + chainDb: chainDb, + peers: peers, + closeCh: make(chan struct{}), }, - chainConfig: chainConfig, eventMux: ctx.EventMux, - peers: peers, - reqDist: newRequestDistributor(peers, quitSync, &mclock.System{}), + reqDist: newRequestDistributor(peers, &mclock.System{}), accountManager: ctx.AccountManager, engine: eth.CreateConsensusEngine(ctx, chainConfig, &config.Ethash, nil, false, chainDb), - shutdownChan: make(chan bool), - networkId: config.NetworkId, bloomRequests: make(chan chan *bloombits.Retrieval), bloomIndexer: eth.NewBloomIndexer(chainDb, params.BloomBitsBlocksClient, params.HelperTrieConfirmations), + serverPool: newServerPool(chainDb, config.UltraLightServers), } - leth.serverPool = newServerPool(chainDb, quitSync, &leth.wg, leth.config.UltraLightServers) leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool) leth.relay = newLesTxRelay(peers, leth.retriever) @@ -128,11 +114,26 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { if leth.blockchain, err = light.NewLightChain(leth.odr, leth.chainConfig, leth.engine, checkpoint); err != nil { return nil, err } + leth.chainReader = leth.blockchain + leth.txPool = light.NewTxPool(leth.chainConfig, leth.blockchain, leth.relay) + + // Set up checkpoint oracle. + oracle := config.CheckpointOracle + if oracle == nil { + oracle = params.CheckpointOracles[genesisHash] + } + leth.oracle = newCheckpointOracle(oracle, leth.localCheckpoint) + // Note: AddChildIndexer starts the update process for the child leth.bloomIndexer.AddChildIndexer(leth.bloomTrieIndexer) leth.chtIndexer.Start(leth.blockchain) leth.bloomIndexer.Start(leth.blockchain) + leth.handler = newClientHandler(config.UltraLightServers, config.UltraLightFraction, checkpoint, leth) + if leth.handler.ulc != nil { + log.Warn("Ultra light client is enabled", "trustedNodes", len(leth.handler.ulc.keys), "minTrustedFraction", leth.handler.ulc.fraction) + leth.blockchain.DisableCheckFreq() + } // Rewind the chain in case of an incompatible config upgrade. if compat, ok := genesisErr.(*params.ConfigCompatError); ok { log.Warn("Rewinding chain to upgrade configuration", "err", compat) @@ -140,41 +141,16 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { rawdb.WriteChainConfig(chainDb, genesisHash, chainConfig) } - leth.txPool = light.NewTxPool(leth.chainConfig, leth.blockchain, leth.relay) leth.ApiBackend = &LesApiBackend{ctx.ExtRPCEnabled(), leth, nil} - gpoParams := config.GPO if gpoParams.Default == nil { gpoParams.Default = config.Miner.GasPrice } leth.ApiBackend.gpo = gasprice.NewOracle(leth.ApiBackend, gpoParams) - oracle := config.CheckpointOracle - if oracle == nil { - oracle = params.CheckpointOracles[genesisHash] - } - registrar := newCheckpointOracle(oracle, leth.getLocalCheckpoint) - if leth.protocolManager, err = NewProtocolManager(leth.chainConfig, checkpoint, light.DefaultClientIndexerConfig, config.UltraLightServers, config.UltraLightFraction, true, config.NetworkId, leth.eventMux, leth.peers, leth.blockchain, nil, chainDb, leth.odr, leth.serverPool, registrar, quitSync, &leth.wg, nil); err != nil { - return nil, err - } - if leth.protocolManager.ulc != nil { - log.Warn("Ultra light client is enabled", "servers", len(config.UltraLightServers), "fraction", config.UltraLightFraction) - leth.blockchain.DisableCheckFreq() - } return leth, nil } -func lesTopic(genesisHash common.Hash, protocolVersion uint) discv5.Topic { - var name string - switch protocolVersion { - case lpv2: - name = "LES2" - default: - panic(nil) - } - return discv5.Topic(name + "@" + common.Bytes2Hex(genesisHash.Bytes()[0:8])) -} - type LightDummyAPI struct{} // Etherbase is the address that mining rewards will be send to @@ -209,7 +185,7 @@ func (s *LightEthereum) APIs() []rpc.API { }, { Namespace: "eth", Version: "1.0", - Service: downloader.NewPublicDownloaderAPI(s.protocolManager.downloader, s.eventMux), + Service: downloader.NewPublicDownloaderAPI(s.handler.downloader, s.eventMux), Public: true, }, { Namespace: "eth", @@ -224,7 +200,7 @@ func (s *LightEthereum) APIs() []rpc.API { }, { Namespace: "les", Version: "1.0", - Service: NewPrivateLightAPI(&s.lesCommons, s.protocolManager.reg), + Service: NewPrivateLightAPI(&s.lesCommons), Public: false, }, }...) @@ -238,54 +214,63 @@ func (s *LightEthereum) BlockChain() *light.LightChain { return s.blockchai func (s *LightEthereum) TxPool() *light.TxPool { return s.txPool } func (s *LightEthereum) Engine() consensus.Engine { return s.engine } func (s *LightEthereum) LesVersion() int { return int(ClientProtocolVersions[0]) } -func (s *LightEthereum) Downloader() *downloader.Downloader { return s.protocolManager.downloader } +func (s *LightEthereum) Downloader() *downloader.Downloader { return s.handler.downloader } func (s *LightEthereum) EventMux() *event.TypeMux { return s.eventMux } // Protocols implements node.Service, returning all the currently configured // network protocols to start. func (s *LightEthereum) Protocols() []p2p.Protocol { - return s.makeProtocols(ClientProtocolVersions) + return s.makeProtocols(ClientProtocolVersions, s.handler.runPeer, func(id enode.ID) interface{} { + if p := s.peers.Peer(peerIdToString(id)); p != nil { + return p.Info() + } + return nil + }) } // Start implements node.Service, starting all internal goroutines needed by the -// Ethereum protocol implementation. +// light ethereum protocol implementation. func (s *LightEthereum) Start(srvr *p2p.Server) error { log.Warn("Light client mode is an experimental feature") + + // Start bloom request workers. + s.wg.Add(bloomServiceThreads) s.startBloomHandlers(params.BloomBitsBlocksClient) - s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.networkId) + + s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.config.NetworkId) + // clients are searching for the first advertised protocol in the list protocolVersion := AdvertiseProtocolVersions[0] s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash(), protocolVersion)) - s.protocolManager.Start(s.config.LightPeers) return nil } // Stop implements node.Service, terminating all internal goroutines used by the // Ethereum protocol. func (s *LightEthereum) Stop() error { + close(s.closeCh) + s.peers.Close() + s.reqDist.close() s.odr.Stop() s.relay.Stop() s.bloomIndexer.Close() s.chtIndexer.Close() s.blockchain.Stop() - s.protocolManager.Stop() + s.handler.stop() s.txPool.Stop() s.engine.Close() - s.eventMux.Stop() - - time.Sleep(time.Millisecond * 200) + s.serverPool.stop() s.chainDb.Close() - close(s.shutdownChan) - + s.wg.Wait() + log.Info("Light ethereum stopped") return nil } // SetClient sets the rpc client and binds the registrar contract. func (s *LightEthereum) SetContractBackend(backend bind.ContractBackend) { - // Short circuit if registrar is nil - if s.protocolManager.reg == nil { + if s.oracle == nil { return } - s.protocolManager.reg.start(backend) + s.oracle.start(backend) } diff --git a/les/client_handler.go b/les/client_handler.go new file mode 100644 index 000000000..aff05ddbc --- /dev/null +++ b/les/client_handler.go @@ -0,0 +1,401 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import ( + "math/big" + "sync" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/eth/downloader" + "github.com/ethereum/go-ethereum/light" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/params" +) + +// clientHandler is responsible for receiving and processing all incoming server +// responses. +type clientHandler struct { + ulc *ulc + checkpoint *params.TrustedCheckpoint + fetcher *lightFetcher + downloader *downloader.Downloader + backend *LightEthereum + + closeCh chan struct{} + wg sync.WaitGroup // WaitGroup used to track all connected peers. +} + +func newClientHandler(ulcServers []string, ulcFraction int, checkpoint *params.TrustedCheckpoint, backend *LightEthereum) *clientHandler { + handler := &clientHandler{ + backend: backend, + closeCh: make(chan struct{}), + } + if ulcServers != nil { + ulc, err := newULC(ulcServers, ulcFraction) + if err != nil { + log.Error("Failed to initialize ultra light client") + } + handler.ulc = ulc + log.Info("Enable ultra light client mode") + } + var height uint64 + if checkpoint != nil { + height = (checkpoint.SectionIndex+1)*params.CHTFrequency - 1 + } + handler.fetcher = newLightFetcher(handler) + handler.downloader = downloader.New(height, backend.chainDb, nil, backend.eventMux, nil, backend.blockchain, handler.removePeer) + handler.backend.peers.notify((*downloaderPeerNotify)(handler)) + return handler +} + +func (h *clientHandler) stop() { + close(h.closeCh) + h.downloader.Terminate() + h.fetcher.close() + h.wg.Wait() +} + +// runPeer is the p2p protocol run function for the given version. +func (h *clientHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error { + trusted := false + if h.ulc != nil { + trusted = h.ulc.trusted(p.ID()) + } + peer := newPeer(int(version), h.backend.config.NetworkId, trusted, p, newMeteredMsgWriter(rw, int(version))) + peer.poolEntry = h.backend.serverPool.connect(peer, peer.Node()) + if peer.poolEntry == nil { + return p2p.DiscRequested + } + h.wg.Add(1) + defer h.wg.Done() + err := h.handle(peer) + h.backend.serverPool.disconnect(peer.poolEntry) + return err +} + +func (h *clientHandler) handle(p *peer) error { + if h.backend.peers.Len() >= h.backend.config.LightPeers && !p.Peer.Info().Network.Trusted { + return p2p.DiscTooManyPeers + } + p.Log().Debug("Light Ethereum peer connected", "name", p.Name()) + + // Execute the LES handshake + var ( + head = h.backend.blockchain.CurrentHeader() + hash = head.Hash() + number = head.Number.Uint64() + td = h.backend.blockchain.GetTd(hash, number) + ) + if err := p.Handshake(td, hash, number, h.backend.blockchain.Genesis().Hash(), nil); err != nil { + p.Log().Debug("Light Ethereum handshake failed", "err", err) + return err + } + // Register the peer locally + if err := h.backend.peers.Register(p); err != nil { + p.Log().Error("Light Ethereum peer registration failed", "err", err) + return err + } + serverConnectionGauge.Update(int64(h.backend.peers.Len())) + + connectedAt := mclock.Now() + defer func() { + h.backend.peers.Unregister(p.id) + connectionTimer.Update(time.Duration(mclock.Now() - connectedAt)) + serverConnectionGauge.Update(int64(h.backend.peers.Len())) + }() + + h.fetcher.announce(p, p.headInfo) + + // pool entry can be nil during the unit test. + if p.poolEntry != nil { + h.backend.serverPool.registered(p.poolEntry) + } + // Spawn a main loop to handle all incoming messages. + for { + if err := h.handleMsg(p); err != nil { + p.Log().Debug("Light Ethereum message handling failed", "err", err) + p.fcServer.DumpLogs() + return err + } + } +} + +// handleMsg is invoked whenever an inbound message is received from a remote +// peer. The remote connection is torn down upon returning any error. +func (h *clientHandler) handleMsg(p *peer) error { + // Read the next message from the remote peer, and ensure it's fully consumed + msg, err := p.rw.ReadMsg() + if err != nil { + return err + } + p.Log().Trace("Light Ethereum message arrived", "code", msg.Code, "bytes", msg.Size) + + if msg.Size > ProtocolMaxMsgSize { + return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize) + } + defer msg.Discard() + + var deliverMsg *Msg + + // Handle the message depending on its contents + switch msg.Code { + case AnnounceMsg: + p.Log().Trace("Received announce message") + var req announceData + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "%v: %v", msg, err) + } + if err := req.sanityCheck(); err != nil { + return err + } + update, size := req.Update.decode() + if p.rejectUpdate(size) { + return errResp(ErrRequestRejected, "") + } + p.updateFlowControl(update) + + if req.Hash != (common.Hash{}) { + if p.announceType == announceTypeNone { + return errResp(ErrUnexpectedResponse, "") + } + if p.announceType == announceTypeSigned { + if err := req.checkSignature(p.ID(), update); err != nil { + p.Log().Trace("Invalid announcement signature", "err", err) + return err + } + p.Log().Trace("Valid announcement signature") + } + p.Log().Trace("Announce message content", "number", req.Number, "hash", req.Hash, "td", req.Td, "reorg", req.ReorgDepth) + h.fetcher.announce(p, &req) + } + case BlockHeadersMsg: + p.Log().Trace("Received block header response message") + var resp struct { + ReqID, BV uint64 + Headers []*types.Header + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + if h.fetcher.requestedID(resp.ReqID) { + h.fetcher.deliverHeaders(p, resp.ReqID, resp.Headers) + } else { + if err := h.downloader.DeliverHeaders(p.id, resp.Headers); err != nil { + log.Debug("Failed to deliver headers", "err", err) + } + } + case BlockBodiesMsg: + p.Log().Trace("Received block bodies response") + var resp struct { + ReqID, BV uint64 + Data []*types.Body + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + deliverMsg = &Msg{ + MsgType: MsgBlockBodies, + ReqID: resp.ReqID, + Obj: resp.Data, + } + case CodeMsg: + p.Log().Trace("Received code response") + var resp struct { + ReqID, BV uint64 + Data [][]byte + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + deliverMsg = &Msg{ + MsgType: MsgCode, + ReqID: resp.ReqID, + Obj: resp.Data, + } + case ReceiptsMsg: + p.Log().Trace("Received receipts response") + var resp struct { + ReqID, BV uint64 + Receipts []types.Receipts + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + deliverMsg = &Msg{ + MsgType: MsgReceipts, + ReqID: resp.ReqID, + Obj: resp.Receipts, + } + case ProofsV2Msg: + p.Log().Trace("Received les/2 proofs response") + var resp struct { + ReqID, BV uint64 + Data light.NodeList + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + deliverMsg = &Msg{ + MsgType: MsgProofsV2, + ReqID: resp.ReqID, + Obj: resp.Data, + } + case HelperTrieProofsMsg: + p.Log().Trace("Received helper trie proof response") + var resp struct { + ReqID, BV uint64 + Data HelperTrieResps + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + deliverMsg = &Msg{ + MsgType: MsgHelperTrieProofs, + ReqID: resp.ReqID, + Obj: resp.Data, + } + case TxStatusMsg: + p.Log().Trace("Received tx status response") + var resp struct { + ReqID, BV uint64 + Status []light.TxStatus + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + deliverMsg = &Msg{ + MsgType: MsgTxStatus, + ReqID: resp.ReqID, + Obj: resp.Status, + } + case StopMsg: + p.freezeServer(true) + h.backend.retriever.frozen(p) + p.Log().Debug("Service stopped") + case ResumeMsg: + var bv uint64 + if err := msg.Decode(&bv); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.ResumeFreeze(bv) + p.freezeServer(false) + p.Log().Debug("Service resumed") + default: + p.Log().Trace("Received invalid message", "code", msg.Code) + return errResp(ErrInvalidMsgCode, "%v", msg.Code) + } + // Deliver the received response to retriever. + if deliverMsg != nil { + if err := h.backend.retriever.deliver(p, deliverMsg); err != nil { + p.responseErrors++ + if p.responseErrors > maxResponseErrors { + return err + } + } + } + return nil +} + +func (h *clientHandler) removePeer(id string) { + h.backend.peers.Unregister(id) +} + +type peerConnection struct { + handler *clientHandler + peer *peer +} + +func (pc *peerConnection) Head() (common.Hash, *big.Int) { + return pc.peer.HeadAndTd() +} + +func (pc *peerConnection) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error { + rq := &distReq{ + getCost: func(dp distPeer) uint64 { + peer := dp.(*peer) + return peer.GetRequestCost(GetBlockHeadersMsg, amount) + }, + canSend: func(dp distPeer) bool { + return dp.(*peer) == pc.peer + }, + request: func(dp distPeer) func() { + reqID := genReqID() + peer := dp.(*peer) + cost := peer.GetRequestCost(GetBlockHeadersMsg, amount) + peer.fcServer.QueuedRequest(reqID, cost) + return func() { peer.RequestHeadersByHash(reqID, cost, origin, amount, skip, reverse) } + }, + } + _, ok := <-pc.handler.backend.reqDist.queue(rq) + if !ok { + return light.ErrNoPeers + } + return nil +} + +func (pc *peerConnection) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error { + rq := &distReq{ + getCost: func(dp distPeer) uint64 { + peer := dp.(*peer) + return peer.GetRequestCost(GetBlockHeadersMsg, amount) + }, + canSend: func(dp distPeer) bool { + return dp.(*peer) == pc.peer + }, + request: func(dp distPeer) func() { + reqID := genReqID() + peer := dp.(*peer) + cost := peer.GetRequestCost(GetBlockHeadersMsg, amount) + peer.fcServer.QueuedRequest(reqID, cost) + return func() { peer.RequestHeadersByNumber(reqID, cost, origin, amount, skip, reverse) } + }, + } + _, ok := <-pc.handler.backend.reqDist.queue(rq) + if !ok { + return light.ErrNoPeers + } + return nil +} + +// downloaderPeerNotify implements peerSetNotify +type downloaderPeerNotify clientHandler + +func (d *downloaderPeerNotify) registerPeer(p *peer) { + h := (*clientHandler)(d) + pc := &peerConnection{ + handler: h, + peer: p, + } + h.downloader.RegisterLightPeer(p.id, ethVersion, pc) +} + +func (d *downloaderPeerNotify) unregisterPeer(p *peer) { + h := (*clientHandler)(d) + h.downloader.UnregisterPeer(p.id) +} diff --git a/les/commons.go b/les/commons.go index ef3c470e5..ad3c5aef3 100644 --- a/les/commons.go +++ b/les/commons.go @@ -17,25 +17,56 @@ package les import ( + "fmt" "math/big" + "sync" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/params" ) +func errResp(code errCode, format string, v ...interface{}) error { + return fmt.Errorf("%v - %v", code, fmt.Sprintf(format, v...)) +} + +func lesTopic(genesisHash common.Hash, protocolVersion uint) discv5.Topic { + var name string + switch protocolVersion { + case lpv2: + name = "LES2" + default: + panic(nil) + } + return discv5.Topic(name + "@" + common.Bytes2Hex(genesisHash.Bytes()[0:8])) +} + +type chainReader interface { + CurrentHeader() *types.Header +} + // lesCommons contains fields needed by both server and client. type lesCommons struct { + genesis common.Hash config *eth.Config + chainConfig *params.ChainConfig iConfig *light.IndexerConfig chainDb ethdb.Database - protocolManager *ProtocolManager + peers *peerSet + chainReader chainReader chtIndexer, bloomTrieIndexer *core.ChainIndexer + oracle *checkpointOracle + + closeCh chan struct{} + wg sync.WaitGroup } // NodeInfo represents a short summary of the Ethereum sub-protocol metadata @@ -50,7 +81,7 @@ type NodeInfo struct { } // makeProtocols creates protocol descriptors for the given LES versions. -func (c *lesCommons) makeProtocols(versions []uint) []p2p.Protocol { +func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error, peerInfo func(id enode.ID) interface{}) []p2p.Protocol { protos := make([]p2p.Protocol, len(versions)) for i, version := range versions { version := version @@ -59,15 +90,10 @@ func (c *lesCommons) makeProtocols(versions []uint) []p2p.Protocol { Version: version, Length: ProtocolLengths[version], NodeInfo: c.nodeInfo, - Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error { - return c.protocolManager.runPeer(version, p, rw) - }, - PeerInfo: func(id enode.ID) interface{} { - if p := c.protocolManager.peers.Peer(peerIdToString(id)); p != nil { - return p.Info() - } - return nil + Run: func(peer *p2p.Peer, rw p2p.MsgReadWriter) error { + return runPeer(version, peer, rw) }, + PeerInfo: peerInfo, } } return protos @@ -75,22 +101,21 @@ func (c *lesCommons) makeProtocols(versions []uint) []p2p.Protocol { // nodeInfo retrieves some protocol metadata about the running host node. func (c *lesCommons) nodeInfo() interface{} { - chain := c.protocolManager.blockchain - head := chain.CurrentHeader() + head := c.chainReader.CurrentHeader() hash := head.Hash() return &NodeInfo{ Network: c.config.NetworkId, - Difficulty: chain.GetTd(hash, head.Number.Uint64()), - Genesis: chain.Genesis().Hash(), - Config: chain.Config(), - Head: chain.CurrentHeader().Hash(), + Difficulty: rawdb.ReadTd(c.chainDb, hash, head.Number.Uint64()), + Genesis: c.genesis, + Config: c.chainConfig, + Head: hash, CHT: c.latestLocalCheckpoint(), } } -// latestLocalCheckpoint finds the common stored section index and returns a set of -// post-processed trie roots (CHT and BloomTrie) associated with -// the appropriate section index and head hash as a local checkpoint package. +// latestLocalCheckpoint finds the common stored section index and returns a set +// of post-processed trie roots (CHT and BloomTrie) associated with the appropriate +// section index and head hash as a local checkpoint package. func (c *lesCommons) latestLocalCheckpoint() params.TrustedCheckpoint { sections, _, _ := c.chtIndexer.Sections() sections2, _, _ := c.bloomTrieIndexer.Sections() @@ -102,15 +127,15 @@ func (c *lesCommons) latestLocalCheckpoint() params.TrustedCheckpoint { // No checkpoint information can be provided. return params.TrustedCheckpoint{} } - return c.getLocalCheckpoint(sections - 1) + return c.localCheckpoint(sections - 1) } -// getLocalCheckpoint returns a set of post-processed trie roots (CHT and BloomTrie) +// localCheckpoint returns a set of post-processed trie roots (CHT and BloomTrie) // associated with the appropriate head hash by specific section index. // // The returned checkpoint is only the checkpoint generated by the local indexers, // not the stable checkpoint registered in the registrar contract. -func (c *lesCommons) getLocalCheckpoint(index uint64) params.TrustedCheckpoint { +func (c *lesCommons) localCheckpoint(index uint64) params.TrustedCheckpoint { sectionHead := c.chtIndexer.SectionHead(index) return params.TrustedCheckpoint{ SectionIndex: index, diff --git a/les/costtracker.go b/les/costtracker.go index d1bb172e4..d1f5b54ca 100644 --- a/les/costtracker.go +++ b/les/costtracker.go @@ -81,7 +81,8 @@ var ( ) const ( - maxCostFactor = 2 // ratio of maximum and average cost estimates + maxCostFactor = 2 // ratio of maximum and average cost estimates + bufLimitRatio = 6000 // fixed bufLimit/MRR ratio gfUsageThreshold = 0.5 gfUsageTC = time.Second gfRaiseTC = time.Second * 200 @@ -127,6 +128,10 @@ type costTracker struct { totalRechargeCh chan uint64 stats map[uint64][]uint64 // Used for testing purpose. + + // TestHooks + testing bool // Disable real cost evaluation for testing purpose. + testCostList RequestCostList // Customized cost table for testing purpose. } // newCostTracker creates a cost tracker and loads the cost factor statistics from the database. @@ -265,8 +270,9 @@ func (ct *costTracker) gfLoop() { select { case r := <-ct.reqInfoCh: requestServedMeter.Mark(int64(r.servingTime)) - requestEstimatedMeter.Mark(int64(r.avgTimeCost / factor)) requestServedTimer.Update(time.Duration(r.servingTime)) + requestEstimatedMeter.Mark(int64(r.avgTimeCost / factor)) + requestEstimatedTimer.Update(time.Duration(r.avgTimeCost / factor)) relativeCostHistogram.Update(int64(r.avgTimeCost / factor / r.servingTime)) now := mclock.Now() @@ -323,7 +329,6 @@ func (ct *costTracker) gfLoop() { } recentServedGauge.Update(int64(recentTime)) recentEstimatedGauge.Update(int64(recentAvg)) - totalRechargeGauge.Update(int64(totalRecharge)) case <-saveTicker.C: saveCostFactor() diff --git a/les/distributor.go b/les/distributor.go index 9235adc03..62abef47d 100644 --- a/les/distributor.go +++ b/les/distributor.go @@ -28,14 +28,17 @@ import ( // suitable peers, obeying flow control rules and prioritizing them in creation // order (even when a resend is necessary). type requestDistributor struct { - clock mclock.Clock - reqQueue *list.List - lastReqOrder uint64 - peers map[distPeer]struct{} - peerLock sync.RWMutex - stopChn, loopChn chan struct{} - loopNextSent bool - lock sync.Mutex + clock mclock.Clock + reqQueue *list.List + lastReqOrder uint64 + peers map[distPeer]struct{} + peerLock sync.RWMutex + loopChn chan struct{} + loopNextSent bool + lock sync.Mutex + + closeCh chan struct{} + wg sync.WaitGroup } // distPeer is an LES server peer interface for the request distributor. @@ -66,20 +69,22 @@ type distReq struct { sentChn chan distPeer element *list.Element waitForPeers mclock.AbsTime + enterQueue mclock.AbsTime } // newRequestDistributor creates a new request distributor -func newRequestDistributor(peers *peerSet, stopChn chan struct{}, clock mclock.Clock) *requestDistributor { +func newRequestDistributor(peers *peerSet, clock mclock.Clock) *requestDistributor { d := &requestDistributor{ clock: clock, reqQueue: list.New(), loopChn: make(chan struct{}, 2), - stopChn: stopChn, + closeCh: make(chan struct{}), peers: make(map[distPeer]struct{}), } if peers != nil { peers.notify(d) } + d.wg.Add(1) go d.loop() return d } @@ -115,9 +120,10 @@ const waitForPeers = time.Second * 3 // main event loop func (d *requestDistributor) loop() { + defer d.wg.Done() for { select { - case <-d.stopChn: + case <-d.closeCh: d.lock.Lock() elem := d.reqQueue.Front() for elem != nil { @@ -140,6 +146,7 @@ func (d *requestDistributor) loop() { send := req.request(peer) if send != nil { peer.queueSend(send) + requestSendDelay.Update(time.Duration(d.clock.Now() - req.enterQueue)) } chn <- peer close(chn) @@ -249,6 +256,9 @@ func (d *requestDistributor) queue(r *distReq) chan distPeer { r.reqOrder = d.lastReqOrder r.waitForPeers = d.clock.Now() + mclock.AbsTime(waitForPeers) } + // Assign the timestamp when the request is queued no matter it's + // a new one or re-queued one. + r.enterQueue = d.clock.Now() back := d.reqQueue.Back() if back == nil || r.reqOrder > back.Value.(*distReq).reqOrder { @@ -294,3 +304,8 @@ func (d *requestDistributor) remove(r *distReq) { r.element = nil } } + +func (d *requestDistributor) close() { + close(d.closeCh) + d.wg.Wait() +} diff --git a/les/distributor_test.go b/les/distributor_test.go index d2247212c..00d43e1d6 100644 --- a/les/distributor_test.go +++ b/les/distributor_test.go @@ -121,7 +121,7 @@ func testRequestDistributor(t *testing.T, resend bool) { stop := make(chan struct{}) defer close(stop) - dist := newRequestDistributor(nil, stop, &mclock.System{}) + dist := newRequestDistributor(nil, &mclock.System{}) var peers [testDistPeerCount]*testDistPeer for i := range peers { peers[i] = &testDistPeer{} diff --git a/les/fetcher.go b/les/fetcher.go index 76e4f076a..df76c56d7 100644 --- a/les/fetcher.go +++ b/les/fetcher.go @@ -40,9 +40,8 @@ const ( // ODR system to ensure that we only request data related to a certain block from peers who have already processed // and announced that block. type lightFetcher struct { - pm *ProtocolManager - odr *LesOdr - chain lightChain + handler *clientHandler + chain *light.LightChain lock sync.Mutex // lock protects access to the fetcher's internal state variables except sent requests maxConfirmedTd *big.Int @@ -58,13 +57,9 @@ type lightFetcher struct { requestTriggered bool requestTrigger chan struct{} lastTrustedHeader *types.Header -} -// lightChain extends the BlockChain interface by locking. -type lightChain interface { - BlockChain - LockChain() - UnlockChain() + closeCh chan struct{} + wg sync.WaitGroup } // fetcherPeerInfo holds fetcher-specific information about each active peer @@ -114,32 +109,37 @@ type fetchResponse struct { } // newLightFetcher creates a new light fetcher -func newLightFetcher(pm *ProtocolManager) *lightFetcher { +func newLightFetcher(h *clientHandler) *lightFetcher { f := &lightFetcher{ - pm: pm, - chain: pm.blockchain.(*light.LightChain), - odr: pm.odr, + handler: h, + chain: h.backend.blockchain, peers: make(map[*peer]*fetcherPeerInfo), deliverChn: make(chan fetchResponse, 100), requested: make(map[uint64]fetchRequest), timeoutChn: make(chan uint64), requestTrigger: make(chan struct{}, 1), syncDone: make(chan *peer), + closeCh: make(chan struct{}), maxConfirmedTd: big.NewInt(0), } - pm.peers.notify(f) + h.backend.peers.notify(f) - f.pm.wg.Add(1) + f.wg.Add(1) go f.syncLoop() return f } +func (f *lightFetcher) close() { + close(f.closeCh) + f.wg.Wait() +} + // syncLoop is the main event loop of the light fetcher func (f *lightFetcher) syncLoop() { - defer f.pm.wg.Done() + defer f.wg.Done() for { select { - case <-f.pm.quitSync: + case <-f.closeCh: return // request loop keeps running until no further requests are necessary or possible case <-f.requestTrigger: @@ -156,7 +156,7 @@ func (f *lightFetcher) syncLoop() { f.lock.Unlock() if rq != nil { - if _, ok := <-f.pm.reqDist.queue(rq); ok { + if _, ok := <-f.handler.backend.reqDist.queue(rq); ok { if syncing { f.lock.Lock() f.syncing = true @@ -187,9 +187,9 @@ func (f *lightFetcher) syncLoop() { } f.reqMu.Unlock() if ok { - f.pm.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), true) + f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), true) req.peer.Log().Debug("Fetching data timed out hard") - go f.pm.removePeer(req.peer.id) + go f.handler.removePeer(req.peer.id) } case resp := <-f.deliverChn: f.reqMu.Lock() @@ -202,12 +202,12 @@ func (f *lightFetcher) syncLoop() { } f.reqMu.Unlock() if ok { - f.pm.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), req.timeout) + f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), req.timeout) } f.lock.Lock() if !ok || !(f.syncing || f.processResponse(req, resp)) { resp.peer.Log().Debug("Failed processing response") - go f.pm.removePeer(resp.peer.id) + go f.handler.removePeer(resp.peer.id) } f.lock.Unlock() case p := <-f.syncDone: @@ -264,7 +264,7 @@ func (f *lightFetcher) announce(p *peer, head *announceData) { if fp.lastAnnounced != nil && head.Td.Cmp(fp.lastAnnounced.td) <= 0 { // announced tds should be strictly monotonic p.Log().Debug("Received non-monotonic td", "current", head.Td, "previous", fp.lastAnnounced.td) - go f.pm.removePeer(p.id) + go f.handler.removePeer(p.id) return } @@ -297,7 +297,7 @@ func (f *lightFetcher) announce(p *peer, head *announceData) { // if one of root's children is canonical, keep it, delete other branches and root itself var newRoot *fetcherTreeNode for i, nn := range fp.root.children { - if rawdb.ReadCanonicalHash(f.pm.chainDb, nn.number) == nn.hash { + if rawdb.ReadCanonicalHash(f.handler.backend.chainDb, nn.number) == nn.hash { fp.root.children = append(fp.root.children[:i], fp.root.children[i+1:]...) nn.parent = nil newRoot = nn @@ -390,7 +390,7 @@ func (f *lightFetcher) peerHasBlock(p *peer, hash common.Hash, number uint64, ha // // when syncing, just check if it is part of the known chain, there is nothing better we // can do since we do not know the most recent block hash yet - return rawdb.ReadCanonicalHash(f.pm.chainDb, fp.root.number) == fp.root.hash && rawdb.ReadCanonicalHash(f.pm.chainDb, number) == hash + return rawdb.ReadCanonicalHash(f.handler.backend.chainDb, fp.root.number) == fp.root.hash && rawdb.ReadCanonicalHash(f.handler.backend.chainDb, number) == hash } // requestAmount calculates the amount of headers to be downloaded starting @@ -453,8 +453,7 @@ func (f *lightFetcher) findBestRequest() (bestHash common.Hash, bestAmount uint6 if f.checkKnownNode(p, n) || n.requested { continue } - - //if ulc mode is disabled, isTrustedHash returns true + // if ulc mode is disabled, isTrustedHash returns true amount := f.requestAmount(p, n) if (bestTd == nil || n.td.Cmp(bestTd) > 0 || amount < bestAmount) && (f.isTrustedHash(hash) || f.maxConfirmedTd.Int64() == 0) { bestHash = hash @@ -470,7 +469,7 @@ func (f *lightFetcher) findBestRequest() (bestHash common.Hash, bestAmount uint6 // isTrustedHash checks if the block can be trusted by the minimum trusted fraction. func (f *lightFetcher) isTrustedHash(hash common.Hash) bool { // If ultra light cliet mode is disabled, trust all hashes - if f.pm.ulc == nil { + if f.handler.ulc == nil { return true } // Ultra light enabled, only trust after enough confirmations @@ -480,7 +479,7 @@ func (f *lightFetcher) isTrustedHash(hash common.Hash) bool { agreed++ } } - return 100*agreed/len(f.pm.ulc.keys) >= f.pm.ulc.fraction + return 100*agreed/len(f.handler.ulc.keys) >= f.handler.ulc.fraction } func (f *lightFetcher) newFetcherDistReqForSync(bestHash common.Hash) *distReq { @@ -500,14 +499,14 @@ func (f *lightFetcher) newFetcherDistReqForSync(bestHash common.Hash) *distReq { return fp != nil && fp.nodeByHash[bestHash] != nil }, request: func(dp distPeer) func() { - if f.pm.ulc != nil { + if f.handler.ulc != nil { // Keep last trusted header before sync f.setLastTrustedHeader(f.chain.CurrentHeader()) } go func() { p := dp.(*peer) p.Log().Debug("Synchronisation started") - f.pm.synchronise(p) + f.handler.synchronise(p) f.syncDone <- p }() return nil @@ -607,7 +606,7 @@ func (f *lightFetcher) newHeaders(headers []*types.Header, tds []*big.Int) { for p, fp := range f.peers { if !f.checkAnnouncedHeaders(fp, headers, tds) { p.Log().Debug("Inconsistent announcement") - go f.pm.removePeer(p.id) + go f.handler.removePeer(p.id) } if fp.confirmedTd != nil && (maxTd == nil || maxTd.Cmp(fp.confirmedTd) > 0) { maxTd = fp.confirmedTd @@ -705,7 +704,7 @@ func (f *lightFetcher) checkSyncedHeaders(p *peer) { node = fp.lastAnnounced td *big.Int ) - if f.pm.ulc != nil { + if f.handler.ulc != nil { // Roll back untrusted blocks h, unapproved := f.lastTrustedTreeNode(p) f.chain.Rollback(unapproved) @@ -721,7 +720,7 @@ func (f *lightFetcher) checkSyncedHeaders(p *peer) { // Now node is the latest downloaded/approved header after syncing if node == nil { p.Log().Debug("Synchronisation failed") - go f.pm.removePeer(p.id) + go f.handler.removePeer(p.id) return } header := f.chain.GetHeader(node.hash, node.number) @@ -741,7 +740,7 @@ func (f *lightFetcher) lastTrustedTreeNode(p *peer) (*types.Header, []common.Has if canonical.Number.Uint64() > f.lastTrustedHeader.Number.Uint64() { canonical = f.chain.GetHeaderByNumber(f.lastTrustedHeader.Number.Uint64()) } - commonAncestor := rawdb.FindCommonAncestor(f.pm.chainDb, canonical, f.lastTrustedHeader) + commonAncestor := rawdb.FindCommonAncestor(f.handler.backend.chainDb, canonical, f.lastTrustedHeader) if commonAncestor == nil { log.Error("Common ancestor of last trusted header and canonical header is nil", "canonical hash", canonical.Hash(), "trusted hash", f.lastTrustedHeader.Hash()) return current, unapprovedHashes @@ -787,7 +786,7 @@ func (f *lightFetcher) checkKnownNode(p *peer, n *fetcherTreeNode) bool { } if !f.checkAnnouncedHeaders(fp, []*types.Header{header}, []*big.Int{td}) { p.Log().Debug("Inconsistent announcement") - go f.pm.removePeer(p.id) + go f.handler.removePeer(p.id) } if fp.confirmedTd != nil { f.updateMaxConfirmedTd(fp.confirmedTd) @@ -880,12 +879,12 @@ func (f *lightFetcher) checkUpdateStats(p *peer, newEntry *updateStatsEntry) { fp.firstUpdateStats = newEntry } for fp.firstUpdateStats != nil && fp.firstUpdateStats.time <= now-mclock.AbsTime(blockDelayTimeout) { - f.pm.serverPool.adjustBlockDelay(p.poolEntry, blockDelayTimeout) + f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, blockDelayTimeout) fp.firstUpdateStats = fp.firstUpdateStats.next } if fp.confirmedTd != nil { for fp.firstUpdateStats != nil && fp.firstUpdateStats.td.Cmp(fp.confirmedTd) <= 0 { - f.pm.serverPool.adjustBlockDelay(p.poolEntry, time.Duration(now-fp.firstUpdateStats.time)) + f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, time.Duration(now-fp.firstUpdateStats.time)) fp.firstUpdateStats = fp.firstUpdateStats.next } } diff --git a/les/fetcher_test.go b/les/fetcher_test.go deleted file mode 100644 index c6faabd66..000000000 --- a/les/fetcher_test.go +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2019 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package les - -import ( - "math/big" - "testing" - - "net" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/p2p/enode" -) - -func TestFetcherULCPeerSelector(t *testing.T) { - id1 := newNodeID(t).ID() - id2 := newNodeID(t).ID() - id3 := newNodeID(t).ID() - id4 := newNodeID(t).ID() - - ftn1 := &fetcherTreeNode{ - hash: common.HexToHash("1"), - td: big.NewInt(1), - } - ftn2 := &fetcherTreeNode{ - hash: common.HexToHash("2"), - td: big.NewInt(2), - parent: ftn1, - } - ftn3 := &fetcherTreeNode{ - hash: common.HexToHash("3"), - td: big.NewInt(3), - parent: ftn2, - } - lf := lightFetcher{ - pm: &ProtocolManager{ - ulc: &ulc{ - keys: map[string]bool{ - id1.String(): true, - id2.String(): true, - id3.String(): true, - id4.String(): true, - }, - fraction: 70, - }, - }, - maxConfirmedTd: ftn1.td, - - peers: map[*peer]*fetcherPeerInfo{ - { - id: "peer1", - Peer: p2p.NewPeer(id1, "peer1", []p2p.Cap{}), - trusted: true, - }: { - nodeByHash: map[common.Hash]*fetcherTreeNode{ - ftn1.hash: ftn1, - ftn2.hash: ftn2, - }, - }, - { - Peer: p2p.NewPeer(id2, "peer2", []p2p.Cap{}), - id: "peer2", - trusted: true, - }: { - nodeByHash: map[common.Hash]*fetcherTreeNode{ - ftn1.hash: ftn1, - ftn2.hash: ftn2, - }, - }, - { - id: "peer3", - Peer: p2p.NewPeer(id3, "peer3", []p2p.Cap{}), - trusted: true, - }: { - nodeByHash: map[common.Hash]*fetcherTreeNode{ - ftn1.hash: ftn1, - ftn2.hash: ftn2, - ftn3.hash: ftn3, - }, - }, - { - id: "peer4", - Peer: p2p.NewPeer(id4, "peer4", []p2p.Cap{}), - trusted: true, - }: { - nodeByHash: map[common.Hash]*fetcherTreeNode{ - ftn1.hash: ftn1, - }, - }, - }, - chain: &lightChainStub{ - tds: map[common.Hash]*big.Int{}, - headers: map[common.Hash]*types.Header{ - ftn1.hash: {}, - ftn2.hash: {}, - ftn3.hash: {}, - }, - }, - } - bestHash, bestAmount, bestTD, sync := lf.findBestRequest() - - if bestTD == nil { - t.Fatal("Empty result") - } - - if bestTD.Cmp(ftn2.td) != 0 { - t.Fatal("bad td", bestTD) - } - if bestHash != ftn2.hash { - t.Fatal("bad hash", bestTD) - } - - _, _ = bestAmount, sync -} - -type lightChainStub struct { - BlockChain - tds map[common.Hash]*big.Int - headers map[common.Hash]*types.Header - insertHeaderChainAssertFunc func(chain []*types.Header, checkFreq int) (int, error) -} - -func (l *lightChainStub) GetHeader(hash common.Hash, number uint64) *types.Header { - if h, ok := l.headers[hash]; ok { - return h - } - - return nil -} - -func (l *lightChainStub) LockChain() {} -func (l *lightChainStub) UnlockChain() {} - -func (l *lightChainStub) GetTd(hash common.Hash, number uint64) *big.Int { - if td, ok := l.tds[hash]; ok { - return td - } - return nil -} - -func (l *lightChainStub) InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error) { - return l.insertHeaderChainAssertFunc(chain, checkFreq) -} - -func newNodeID(t *testing.T) *enode.Node { - key, err := crypto.GenerateKey() - if err != nil { - t.Fatal("generate key err:", err) - } - return enode.NewV4(&key.PublicKey, net.IP{}, 35000, 35000) -} diff --git a/les/handler.go b/les/handler.go deleted file mode 100644 index 807065e55..000000000 --- a/les/handler.go +++ /dev/null @@ -1,1293 +0,0 @@ -// Copyright 2016 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package les - -import ( - "encoding/binary" - "encoding/json" - "errors" - "fmt" - "math/big" - "sync" - "sync/atomic" - "time" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/common/mclock" - "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/core/state" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/eth/downloader" - "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/event" - "github.com/ethereum/go-ethereum/light" - "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/p2p/discv5" - "github.com/ethereum/go-ethereum/params" - "github.com/ethereum/go-ethereum/rlp" - "github.com/ethereum/go-ethereum/trie" -) - -var errTooManyInvalidRequest = errors.New("too many invalid requests made") - -const ( - softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data. - estHeaderRlpSize = 500 // Approximate size of an RLP encoded block header - - ethVersion = 63 // equivalent eth version for the downloader - - MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request - MaxBodyFetch = 32 // Amount of block bodies to be fetched per retrieval request - MaxReceiptFetch = 128 // Amount of transaction receipts to allow fetching per request - MaxCodeFetch = 64 // Amount of contract codes to allow fetching per request - MaxProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request - MaxHelperTrieProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request - MaxTxSend = 64 // Amount of transactions to be send per request - MaxTxStatus = 256 // Amount of transactions to queried per request - - disableClientRemovePeer = false -) - -func errResp(code errCode, format string, v ...interface{}) error { - return fmt.Errorf("%v - %v", code, fmt.Sprintf(format, v...)) -} - -type BlockChain interface { - Config() *params.ChainConfig - HasHeader(hash common.Hash, number uint64) bool - GetHeader(hash common.Hash, number uint64) *types.Header - GetHeaderByHash(hash common.Hash) *types.Header - CurrentHeader() *types.Header - GetTd(hash common.Hash, number uint64) *big.Int - StateCache() state.Database - InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error) - Rollback(chain []common.Hash) - GetHeaderByNumber(number uint64) *types.Header - GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64) - Genesis() *types.Block - SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription -} - -type txPool interface { - AddRemotes(txs []*types.Transaction) []error - AddRemotesSync(txs []*types.Transaction) []error - Status(hashes []common.Hash) []core.TxStatus -} - -type ProtocolManager struct { - // Configs - chainConfig *params.ChainConfig - iConfig *light.IndexerConfig - - client bool // The indicator whether the node is light client - maxPeers int // The maximum number peers allowed to connect. - networkId uint64 // The identity of network. - - txpool txPool - txrelay *lesTxRelay - blockchain BlockChain - chainDb ethdb.Database - odr *LesOdr - server *LesServer - serverPool *serverPool - lesTopic discv5.Topic - reqDist *requestDistributor - retriever *retrieveManager - servingQueue *servingQueue - downloader *downloader.Downloader - fetcher *lightFetcher - ulc *ulc - peers *peerSet - checkpoint *params.TrustedCheckpoint - reg *checkpointOracle // If reg == nil, it means the checkpoint registrar is not activated - - // channels for fetcher, syncer, txsyncLoop - newPeerCh chan *peer - quitSync chan struct{} - noMorePeers chan struct{} - - wg *sync.WaitGroup - eventMux *event.TypeMux - - // Callbacks - synced func() bool - - // Testing fields - addTxsSync bool -} - -// NewProtocolManager returns a new ethereum sub protocol manager. The Ethereum sub protocol manages peers capable -// with the ethereum network. -func NewProtocolManager(chainConfig *params.ChainConfig, checkpoint *params.TrustedCheckpoint, indexerConfig *light.IndexerConfig, ulcServers []string, ulcFraction int, client bool, networkId uint64, mux *event.TypeMux, peers *peerSet, blockchain BlockChain, txpool txPool, chainDb ethdb.Database, odr *LesOdr, serverPool *serverPool, registrar *checkpointOracle, quitSync chan struct{}, wg *sync.WaitGroup, synced func() bool) (*ProtocolManager, error) { - // Create the protocol manager with the base fields - manager := &ProtocolManager{ - client: client, - eventMux: mux, - blockchain: blockchain, - chainConfig: chainConfig, - iConfig: indexerConfig, - chainDb: chainDb, - odr: odr, - networkId: networkId, - txpool: txpool, - serverPool: serverPool, - reg: registrar, - peers: peers, - newPeerCh: make(chan *peer), - quitSync: quitSync, - wg: wg, - noMorePeers: make(chan struct{}), - checkpoint: checkpoint, - synced: synced, - } - if odr != nil { - manager.retriever = odr.retriever - manager.reqDist = odr.retriever.dist - } - - if ulcServers != nil { - ulc, err := newULC(ulcServers, ulcFraction) - if err != nil { - log.Warn("Failed to initialize ultra light client", "err", err) - } else { - manager.ulc = ulc - } - } - removePeer := manager.removePeer - if disableClientRemovePeer { - removePeer = func(id string) {} - } - if client { - var checkpointNumber uint64 - if checkpoint != nil { - checkpointNumber = (checkpoint.SectionIndex+1)*params.CHTFrequency - 1 - } - manager.downloader = downloader.New(checkpointNumber, chainDb, nil, manager.eventMux, nil, blockchain, removePeer) - manager.peers.notify((*downloaderPeerNotify)(manager)) - manager.fetcher = newLightFetcher(manager) - } - return manager, nil -} - -// removePeer initiates disconnection from a peer by removing it from the peer set -func (pm *ProtocolManager) removePeer(id string) { - pm.peers.Unregister(id) -} - -func (pm *ProtocolManager) Start(maxPeers int) { - pm.maxPeers = maxPeers - if pm.client { - go pm.syncer() - } else { - go func() { - for range pm.newPeerCh { - } - }() - } -} - -func (pm *ProtocolManager) Stop() { - // Showing a log message. During download / process this could actually - // take between 5 to 10 seconds and therefor feedback is required. - log.Info("Stopping light Ethereum protocol") - - // Quit the sync loop. - // After this send has completed, no new peers will be accepted. - pm.noMorePeers <- struct{}{} - - close(pm.quitSync) // quits syncer, fetcher - - if pm.servingQueue != nil { - pm.servingQueue.stop() - } - - // Disconnect existing sessions. - // This also closes the gate for any new registrations on the peer set. - // sessions which are already established but not added to pm.peers yet - // will exit when they try to register. - pm.peers.Close() - - // Wait for any process action - pm.wg.Wait() - - log.Info("Light Ethereum protocol stopped") -} - -// runPeer is the p2p protocol run function for the given version. -func (pm *ProtocolManager) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error { - var entry *poolEntry - peer := pm.newPeer(int(version), pm.networkId, p, rw) - if pm.serverPool != nil { - entry = pm.serverPool.connect(peer, peer.Node()) - } - peer.poolEntry = entry - select { - case pm.newPeerCh <- peer: - pm.wg.Add(1) - defer pm.wg.Done() - err := pm.handle(peer) - if entry != nil { - pm.serverPool.disconnect(entry) - } - return err - case <-pm.quitSync: - if entry != nil { - pm.serverPool.disconnect(entry) - } - return p2p.DiscQuitting - } -} - -func (pm *ProtocolManager) newPeer(pv int, nv uint64, p *p2p.Peer, rw p2p.MsgReadWriter) *peer { - var trusted bool - if pm.ulc != nil { - trusted = pm.ulc.trusted(p.ID()) - } - return newPeer(pv, nv, trusted, p, newMeteredMsgWriter(rw)) -} - -// handle is the callback invoked to manage the life cycle of a les peer. When -// this function terminates, the peer is disconnected. -func (pm *ProtocolManager) handle(p *peer) error { - // Ignore maxPeers if this is a trusted peer - // In server mode we try to check into the client pool after handshake - if pm.client && pm.peers.Len() >= pm.maxPeers && !p.Peer.Info().Network.Trusted { - clientRejectedMeter.Mark(1) - return p2p.DiscTooManyPeers - } - // Reject light clients if server is not synced. - if !pm.client && !pm.synced() { - clientRejectedMeter.Mark(1) - return p2p.DiscRequested - } - p.Log().Debug("Light Ethereum peer connected", "name", p.Name()) - - // Execute the LES handshake - var ( - genesis = pm.blockchain.Genesis() - head = pm.blockchain.CurrentHeader() - hash = head.Hash() - number = head.Number.Uint64() - td = pm.blockchain.GetTd(hash, number) - ) - if err := p.Handshake(td, hash, number, genesis.Hash(), pm.server); err != nil { - p.Log().Debug("Light Ethereum handshake failed", "err", err) - clientErrorMeter.Mark(1) - return err - } - if p.fcClient != nil { - defer p.fcClient.Disconnect() - } - - if rw, ok := p.rw.(*meteredMsgReadWriter); ok { - rw.Init(p.version) - } - - // Register the peer locally - if err := pm.peers.Register(p); err != nil { - clientErrorMeter.Mark(1) - p.Log().Error("Light Ethereum peer registration failed", "err", err) - return err - } - if !pm.client && p.balanceTracker == nil { - // add dummy balance tracker for tests - p.balanceTracker = &balanceTracker{} - p.balanceTracker.init(&mclock.System{}, 1) - } - connectedAt := time.Now() - defer func() { - p.balanceTracker = nil - pm.removePeer(p.id) - connectionTimer.UpdateSince(connectedAt) - }() - - // Register the peer in the downloader. If the downloader considers it banned, we disconnect - if pm.client { - p.lock.Lock() - head := p.headInfo - p.lock.Unlock() - if pm.fetcher != nil { - pm.fetcher.announce(p, head) - } - - if p.poolEntry != nil { - pm.serverPool.registered(p.poolEntry) - } - } - // main loop. handle incoming messages. - for { - if err := pm.handleMsg(p); err != nil { - p.Log().Debug("Light Ethereum message handling failed", "err", err) - if p.fcServer != nil { - p.fcServer.DumpLogs() - } - return err - } - } -} - -// handleMsg is invoked whenever an inbound message is received from a remote -// peer. The remote connection is torn down upon returning any error. -func (pm *ProtocolManager) handleMsg(p *peer) error { - select { - case err := <-p.errCh: - return err - default: - } - // Read the next message from the remote peer, and ensure it's fully consumed - msg, err := p.rw.ReadMsg() - if err != nil { - return err - } - p.Log().Trace("Light Ethereum message arrived", "code", msg.Code, "bytes", msg.Size) - - p.responseCount++ - responseCount := p.responseCount - var ( - maxCost uint64 - task *servingTask - ) - - accept := func(reqID, reqCnt, maxCnt uint64) bool { - inSizeCost := func() uint64 { - if pm.server.costTracker != nil { - return pm.server.costTracker.realCost(0, msg.Size, 0) - } - return 0 - } - if p.isFrozen() || reqCnt == 0 || p.fcClient == nil || reqCnt > maxCnt { - p.fcClient.OneTimeCost(inSizeCost()) - return false - } - maxCost = p.fcCosts.getMaxCost(msg.Code, reqCnt) - gf := float64(1) - if pm.server.costTracker != nil { - gf = pm.server.costTracker.globalFactor() - if gf < 0.001 { - p.Log().Error("Invalid global cost factor", "globalFactor", gf) - gf = 1 - } - } - maxTime := uint64(float64(maxCost) / gf) - - if accepted, bufShort, servingPriority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost); !accepted { - p.freezeClient() - p.Log().Warn("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge))) - p.fcClient.OneTimeCost(inSizeCost()) - return false - } else { - task = pm.servingQueue.newTask(p, maxTime, servingPriority) - } - if task.start() { - return true - } - p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost()) - return false - } - - if msg.Size > ProtocolMaxMsgSize { - return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize) - } - defer msg.Discard() - - var deliverMsg *Msg - balanceTracker := p.balanceTracker - - sendResponse := func(reqID, amount uint64, reply *reply, servingTime uint64) { - p.responseLock.Lock() - defer p.responseLock.Unlock() - - if p.isFrozen() { - amount = 0 - reply = nil - } - var replySize uint32 - if reply != nil { - replySize = reply.size() - } - var realCost uint64 - if pm.server.costTracker != nil { - realCost = pm.server.costTracker.realCost(servingTime, msg.Size, replySize) - if amount != 0 { - pm.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost) - balanceTracker.requestCost(realCost) - } - } else { - realCost = maxCost - } - bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost) - if reply != nil { - p.queueSend(func() { - if err := reply.send(bv); err != nil { - select { - case p.errCh <- err: - default: - } - } - }) - } - } - - // Handle the message depending on its contents - switch msg.Code { - case StatusMsg: - p.Log().Trace("Received status message") - // Status messages should never arrive after the handshake - return errResp(ErrExtraStatusMsg, "uncontrolled status message") - - // Block header query, collect the requested headers and reply - case AnnounceMsg: - p.Log().Trace("Received announce message") - var req announceData - if err := msg.Decode(&req); err != nil { - return errResp(ErrDecode, "%v: %v", msg, err) - } - if err := req.sanityCheck(); err != nil { - return err - } - update, size := req.Update.decode() - if p.rejectUpdate(size) { - return errResp(ErrRequestRejected, "") - } - p.updateFlowControl(update) - - if req.Hash != (common.Hash{}) { - if p.announceType == announceTypeNone { - return errResp(ErrUnexpectedResponse, "") - } - if p.announceType == announceTypeSigned { - if err := req.checkSignature(p.ID(), update); err != nil { - p.Log().Trace("Invalid announcement signature", "err", err) - return err - } - p.Log().Trace("Valid announcement signature") - } - - p.Log().Trace("Announce message content", "number", req.Number, "hash", req.Hash, "td", req.Td, "reorg", req.ReorgDepth) - if pm.fetcher != nil { - pm.fetcher.announce(p, &req) - } - } - - case GetBlockHeadersMsg: - p.Log().Trace("Received block header request") - // Decode the complex header query - var req struct { - ReqID uint64 - Query getBlockHeadersData - } - if err := msg.Decode(&req); err != nil { - return errResp(ErrDecode, "%v: %v", msg, err) - } - - query := req.Query - if accept(req.ReqID, query.Amount, MaxHeaderFetch) { - go func() { - hashMode := query.Origin.Hash != (common.Hash{}) - first := true - maxNonCanonical := uint64(100) - - // Gather headers until the fetch or network limits is reached - var ( - bytes common.StorageSize - headers []*types.Header - unknown bool - ) - for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit { - if !first && !task.waitOrStop() { - sendResponse(req.ReqID, 0, nil, task.servingTime) - return - } - // Retrieve the next header satisfying the query - var origin *types.Header - if hashMode { - if first { - origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash) - if origin != nil { - query.Origin.Number = origin.Number.Uint64() - } - } else { - origin = pm.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number) - } - } else { - origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number) - } - if origin == nil { - atomic.AddUint32(&p.invalidCount, 1) - break - } - headers = append(headers, origin) - bytes += estHeaderRlpSize - - // Advance to the next header of the query - switch { - case hashMode && query.Reverse: - // Hash based traversal towards the genesis block - ancestor := query.Skip + 1 - if ancestor == 0 { - unknown = true - } else { - query.Origin.Hash, query.Origin.Number = pm.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical) - unknown = (query.Origin.Hash == common.Hash{}) - } - case hashMode && !query.Reverse: - // Hash based traversal towards the leaf block - var ( - current = origin.Number.Uint64() - next = current + query.Skip + 1 - ) - if next <= current { - infos, _ := json.MarshalIndent(p.Peer.Info(), "", " ") - p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos) - unknown = true - } else { - if header := pm.blockchain.GetHeaderByNumber(next); header != nil { - nextHash := header.Hash() - expOldHash, _ := pm.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical) - if expOldHash == query.Origin.Hash { - query.Origin.Hash, query.Origin.Number = nextHash, next - } else { - unknown = true - } - } else { - unknown = true - } - } - case query.Reverse: - // Number based traversal towards the genesis block - if query.Origin.Number >= query.Skip+1 { - query.Origin.Number -= query.Skip + 1 - } else { - unknown = true - } - case !query.Reverse: - // Number based traversal towards the leaf block - query.Origin.Number += query.Skip + 1 - } - first = false - } - sendResponse(req.ReqID, query.Amount, p.ReplyBlockHeaders(req.ReqID, headers), task.done()) - }() - } - - case BlockHeadersMsg: - if pm.downloader == nil { - return errResp(ErrUnexpectedResponse, "") - } - - p.Log().Trace("Received block header response message") - // A batch of headers arrived to one of our previous requests - var resp struct { - ReqID, BV uint64 - Headers []*types.Header - } - if err := msg.Decode(&resp); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) - if pm.fetcher != nil && pm.fetcher.requestedID(resp.ReqID) { - pm.fetcher.deliverHeaders(p, resp.ReqID, resp.Headers) - } else { - err := pm.downloader.DeliverHeaders(p.id, resp.Headers) - if err != nil { - log.Debug(fmt.Sprint(err)) - } - } - - case GetBlockBodiesMsg: - p.Log().Trace("Received block bodies request") - // Decode the retrieval message - var req struct { - ReqID uint64 - Hashes []common.Hash - } - if err := msg.Decode(&req); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - // Gather blocks until the fetch or network limits is reached - var ( - bytes int - bodies []rlp.RawValue - ) - reqCnt := len(req.Hashes) - if accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) { - go func() { - for i, hash := range req.Hashes { - if i != 0 && !task.waitOrStop() { - sendResponse(req.ReqID, 0, nil, task.servingTime) - return - } - // Retrieve the requested block body, stopping if enough was found - if bytes >= softResponseLimit { - break - } - number := rawdb.ReadHeaderNumber(pm.chainDb, hash) - if number == nil { - atomic.AddUint32(&p.invalidCount, 1) - continue - } - if data := rawdb.ReadBodyRLP(pm.chainDb, hash, *number); len(data) != 0 { - bodies = append(bodies, data) - bytes += len(data) - } - } - sendResponse(req.ReqID, uint64(reqCnt), p.ReplyBlockBodiesRLP(req.ReqID, bodies), task.done()) - }() - } - - case BlockBodiesMsg: - if pm.odr == nil { - return errResp(ErrUnexpectedResponse, "") - } - - p.Log().Trace("Received block bodies response") - // A batch of block bodies arrived to one of our previous requests - var resp struct { - ReqID, BV uint64 - Data []*types.Body - } - if err := msg.Decode(&resp); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) - deliverMsg = &Msg{ - MsgType: MsgBlockBodies, - ReqID: resp.ReqID, - Obj: resp.Data, - } - - case GetCodeMsg: - p.Log().Trace("Received code request") - // Decode the retrieval message - var req struct { - ReqID uint64 - Reqs []CodeReq - } - if err := msg.Decode(&req); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - // Gather state data until the fetch or network limits is reached - var ( - bytes int - data [][]byte - ) - reqCnt := len(req.Reqs) - if accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) { - go func() { - for i, request := range req.Reqs { - if i != 0 && !task.waitOrStop() { - sendResponse(req.ReqID, 0, nil, task.servingTime) - return - } - // Look up the root hash belonging to the request - number := rawdb.ReadHeaderNumber(pm.chainDb, request.BHash) - if number == nil { - p.Log().Warn("Failed to retrieve block num for code", "hash", request.BHash) - atomic.AddUint32(&p.invalidCount, 1) - continue - } - header := rawdb.ReadHeader(pm.chainDb, request.BHash, *number) - if header == nil { - p.Log().Warn("Failed to retrieve header for code", "block", *number, "hash", request.BHash) - continue - } - // Refuse to search stale state data in the database since looking for - // a non-exist key is kind of expensive. - local := pm.blockchain.CurrentHeader().Number.Uint64() - if !pm.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local { - p.Log().Debug("Reject stale code request", "number", header.Number.Uint64(), "head", local) - atomic.AddUint32(&p.invalidCount, 1) - continue - } - triedb := pm.blockchain.StateCache().TrieDB() - - account, err := pm.getAccount(triedb, header.Root, common.BytesToHash(request.AccKey)) - if err != nil { - p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) - atomic.AddUint32(&p.invalidCount, 1) - continue - } - code, err := triedb.Node(common.BytesToHash(account.CodeHash)) - if err != nil { - p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err) - continue - } - // Accumulate the code and abort if enough data was retrieved - data = append(data, code) - if bytes += len(code); bytes >= softResponseLimit { - break - } - } - sendResponse(req.ReqID, uint64(reqCnt), p.ReplyCode(req.ReqID, data), task.done()) - }() - } - - case CodeMsg: - if pm.odr == nil { - return errResp(ErrUnexpectedResponse, "") - } - - p.Log().Trace("Received code response") - // A batch of node state data arrived to one of our previous requests - var resp struct { - ReqID, BV uint64 - Data [][]byte - } - if err := msg.Decode(&resp); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) - deliverMsg = &Msg{ - MsgType: MsgCode, - ReqID: resp.ReqID, - Obj: resp.Data, - } - - case GetReceiptsMsg: - p.Log().Trace("Received receipts request") - // Decode the retrieval message - var req struct { - ReqID uint64 - Hashes []common.Hash - } - if err := msg.Decode(&req); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - // Gather state data until the fetch or network limits is reached - var ( - bytes int - receipts []rlp.RawValue - ) - reqCnt := len(req.Hashes) - if accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) { - go func() { - for i, hash := range req.Hashes { - if i != 0 && !task.waitOrStop() { - sendResponse(req.ReqID, 0, nil, task.servingTime) - return - } - if bytes >= softResponseLimit { - break - } - // Retrieve the requested block's receipts, skipping if unknown to us - var results types.Receipts - number := rawdb.ReadHeaderNumber(pm.chainDb, hash) - if number == nil { - atomic.AddUint32(&p.invalidCount, 1) - continue - } - results = rawdb.ReadRawReceipts(pm.chainDb, hash, *number) - if results == nil { - if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { - continue - } - } - // If known, encode and queue for response packet - if encoded, err := rlp.EncodeToBytes(results); err != nil { - log.Error("Failed to encode receipt", "err", err) - } else { - receipts = append(receipts, encoded) - bytes += len(encoded) - } - } - sendResponse(req.ReqID, uint64(reqCnt), p.ReplyReceiptsRLP(req.ReqID, receipts), task.done()) - }() - } - - case ReceiptsMsg: - if pm.odr == nil { - return errResp(ErrUnexpectedResponse, "") - } - - p.Log().Trace("Received receipts response") - // A batch of receipts arrived to one of our previous requests - var resp struct { - ReqID, BV uint64 - Receipts []types.Receipts - } - if err := msg.Decode(&resp); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) - deliverMsg = &Msg{ - MsgType: MsgReceipts, - ReqID: resp.ReqID, - Obj: resp.Receipts, - } - - case GetProofsV2Msg: - p.Log().Trace("Received les/2 proofs request") - // Decode the retrieval message - var req struct { - ReqID uint64 - Reqs []ProofReq - } - if err := msg.Decode(&req); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - // Gather state data until the fetch or network limits is reached - var ( - lastBHash common.Hash - root common.Hash - ) - reqCnt := len(req.Reqs) - if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) { - go func() { - nodes := light.NewNodeSet() - - for i, request := range req.Reqs { - if i != 0 && !task.waitOrStop() { - sendResponse(req.ReqID, 0, nil, task.servingTime) - return - } - // Look up the root hash belonging to the request - var ( - number *uint64 - header *types.Header - trie state.Trie - ) - if request.BHash != lastBHash { - root, lastBHash = common.Hash{}, request.BHash - - if number = rawdb.ReadHeaderNumber(pm.chainDb, request.BHash); number == nil { - p.Log().Warn("Failed to retrieve block num for proof", "hash", request.BHash) - atomic.AddUint32(&p.invalidCount, 1) - continue - } - if header = rawdb.ReadHeader(pm.chainDb, request.BHash, *number); header == nil { - p.Log().Warn("Failed to retrieve header for proof", "block", *number, "hash", request.BHash) - continue - } - // Refuse to search stale state data in the database since looking for - // a non-exist key is kind of expensive. - local := pm.blockchain.CurrentHeader().Number.Uint64() - if !pm.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local { - p.Log().Debug("Reject stale trie request", "number", header.Number.Uint64(), "head", local) - atomic.AddUint32(&p.invalidCount, 1) - continue - } - root = header.Root - } - // If a header lookup failed (non existent), ignore subsequent requests for the same header - if root == (common.Hash{}) { - atomic.AddUint32(&p.invalidCount, 1) - continue - } - // Open the account or storage trie for the request - statedb := pm.blockchain.StateCache() - - switch len(request.AccKey) { - case 0: - // No account key specified, open an account trie - trie, err = statedb.OpenTrie(root) - if trie == nil || err != nil { - p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err) - continue - } - default: - // Account key specified, open a storage trie - account, err := pm.getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey)) - if err != nil { - p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) - atomic.AddUint32(&p.invalidCount, 1) - continue - } - trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root) - if trie == nil || err != nil { - p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err) - continue - } - } - // Prove the user's request from the account or stroage trie - if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil { - p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err) - continue - } - if nodes.DataSize() >= softResponseLimit { - break - } - } - sendResponse(req.ReqID, uint64(reqCnt), p.ReplyProofsV2(req.ReqID, nodes.NodeList()), task.done()) - }() - } - - case ProofsV2Msg: - if pm.odr == nil { - return errResp(ErrUnexpectedResponse, "") - } - - p.Log().Trace("Received les/2 proofs response") - // A batch of merkle proofs arrived to one of our previous requests - var resp struct { - ReqID, BV uint64 - Data light.NodeList - } - if err := msg.Decode(&resp); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) - deliverMsg = &Msg{ - MsgType: MsgProofsV2, - ReqID: resp.ReqID, - Obj: resp.Data, - } - - case GetHelperTrieProofsMsg: - p.Log().Trace("Received helper trie proof request") - // Decode the retrieval message - var req struct { - ReqID uint64 - Reqs []HelperTrieReq - } - if err := msg.Decode(&req); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - // Gather state data until the fetch or network limits is reached - var ( - auxBytes int - auxData [][]byte - ) - reqCnt := len(req.Reqs) - if accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) { - go func() { - - var ( - lastIdx uint64 - lastType uint - root common.Hash - auxTrie *trie.Trie - ) - nodes := light.NewNodeSet() - for i, request := range req.Reqs { - if i != 0 && !task.waitOrStop() { - sendResponse(req.ReqID, 0, nil, task.servingTime) - return - } - if auxTrie == nil || request.Type != lastType || request.TrieIdx != lastIdx { - auxTrie, lastType, lastIdx = nil, request.Type, request.TrieIdx - - var prefix string - if root, prefix = pm.getHelperTrie(request.Type, request.TrieIdx); root != (common.Hash{}) { - auxTrie, _ = trie.New(root, trie.NewDatabase(rawdb.NewTable(pm.chainDb, prefix))) - } - } - if request.AuxReq == auxRoot { - var data []byte - if root != (common.Hash{}) { - data = root[:] - } - auxData = append(auxData, data) - auxBytes += len(data) - } else { - if auxTrie != nil { - auxTrie.Prove(request.Key, request.FromLevel, nodes) - } - if request.AuxReq != 0 { - data := pm.getHelperTrieAuxData(request) - auxData = append(auxData, data) - auxBytes += len(data) - } - } - if nodes.DataSize()+auxBytes >= softResponseLimit { - break - } - } - sendResponse(req.ReqID, uint64(reqCnt), p.ReplyHelperTrieProofs(req.ReqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData}), task.done()) - }() - } - - case HelperTrieProofsMsg: - if pm.odr == nil { - return errResp(ErrUnexpectedResponse, "") - } - - p.Log().Trace("Received helper trie proof response") - var resp struct { - ReqID, BV uint64 - Data HelperTrieResps - } - if err := msg.Decode(&resp); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) - deliverMsg = &Msg{ - MsgType: MsgHelperTrieProofs, - ReqID: resp.ReqID, - Obj: resp.Data, - } - - case SendTxV2Msg: - if pm.txpool == nil { - return errResp(ErrRequestRejected, "") - } - // Transactions arrived, parse all of them and deliver to the pool - var req struct { - ReqID uint64 - Txs []*types.Transaction - } - if err := msg.Decode(&req); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - reqCnt := len(req.Txs) - if accept(req.ReqID, uint64(reqCnt), MaxTxSend) { - go func() { - stats := make([]light.TxStatus, len(req.Txs)) - for i, tx := range req.Txs { - if i != 0 && !task.waitOrStop() { - sendResponse(req.ReqID, 0, nil, task.servingTime) - return - } - hash := tx.Hash() - stats[i] = pm.txStatus(hash) - if stats[i].Status == core.TxStatusUnknown { - addFn := pm.txpool.AddRemotes - // Add txs synchronously for testing purpose - if pm.addTxsSync { - addFn = pm.txpool.AddRemotesSync - } - if errs := addFn([]*types.Transaction{tx}); errs[0] != nil { - stats[i].Error = errs[0].Error() - continue - } - stats[i] = pm.txStatus(hash) - } - } - sendResponse(req.ReqID, uint64(reqCnt), p.ReplyTxStatus(req.ReqID, stats), task.done()) - }() - } - - case GetTxStatusMsg: - if pm.txpool == nil { - return errResp(ErrUnexpectedResponse, "") - } - // Transactions arrived, parse all of them and deliver to the pool - var req struct { - ReqID uint64 - Hashes []common.Hash - } - if err := msg.Decode(&req); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - reqCnt := len(req.Hashes) - if accept(req.ReqID, uint64(reqCnt), MaxTxStatus) { - go func() { - stats := make([]light.TxStatus, len(req.Hashes)) - for i, hash := range req.Hashes { - if i != 0 && !task.waitOrStop() { - sendResponse(req.ReqID, 0, nil, task.servingTime) - return - } - stats[i] = pm.txStatus(hash) - } - sendResponse(req.ReqID, uint64(reqCnt), p.ReplyTxStatus(req.ReqID, stats), task.done()) - }() - } - - case TxStatusMsg: - if pm.odr == nil { - return errResp(ErrUnexpectedResponse, "") - } - - p.Log().Trace("Received tx status response") - var resp struct { - ReqID, BV uint64 - Status []light.TxStatus - } - if err := msg.Decode(&resp); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) - - p.Log().Trace("Received helper trie proof response") - deliverMsg = &Msg{ - MsgType: MsgTxStatus, - ReqID: resp.ReqID, - Obj: resp.Status, - } - - case StopMsg: - if pm.odr == nil { - return errResp(ErrUnexpectedResponse, "") - } - p.freezeServer(true) - pm.retriever.frozen(p) - p.Log().Debug("Service stopped") - - case ResumeMsg: - if pm.odr == nil { - return errResp(ErrUnexpectedResponse, "") - } - var bv uint64 - if err := msg.Decode(&bv); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - p.fcServer.ResumeFreeze(bv) - p.freezeServer(false) - p.Log().Debug("Service resumed") - - default: - p.Log().Trace("Received unknown message", "code", msg.Code) - return errResp(ErrInvalidMsgCode, "%v", msg.Code) - } - - if deliverMsg != nil { - err := pm.retriever.deliver(p, deliverMsg) - if err != nil { - p.responseErrors++ - if p.responseErrors > maxResponseErrors { - return err - } - } - } - // If the client has made too much invalid request(e.g. request a non-exist data), - // reject them to prevent SPAM attack. - if atomic.LoadUint32(&p.invalidCount) > maxRequestErrors { - return errTooManyInvalidRequest - } - return nil -} - -// getAccount retrieves an account from the state based at root. -func (pm *ProtocolManager) getAccount(triedb *trie.Database, root, hash common.Hash) (state.Account, error) { - trie, err := trie.New(root, triedb) - if err != nil { - return state.Account{}, err - } - blob, err := trie.TryGet(hash[:]) - if err != nil { - return state.Account{}, err - } - var account state.Account - if err = rlp.DecodeBytes(blob, &account); err != nil { - return state.Account{}, err - } - return account, nil -} - -// getHelperTrie returns the post-processed trie root for the given trie ID and section index -func (pm *ProtocolManager) getHelperTrie(id uint, idx uint64) (common.Hash, string) { - switch id { - case htCanonical: - sectionHead := rawdb.ReadCanonicalHash(pm.chainDb, (idx+1)*pm.iConfig.ChtSize-1) - return light.GetChtRoot(pm.chainDb, idx, sectionHead), light.ChtTablePrefix - case htBloomBits: - sectionHead := rawdb.ReadCanonicalHash(pm.chainDb, (idx+1)*pm.iConfig.BloomTrieSize-1) - return light.GetBloomTrieRoot(pm.chainDb, idx, sectionHead), light.BloomTrieTablePrefix - } - return common.Hash{}, "" -} - -// getHelperTrieAuxData returns requested auxiliary data for the given HelperTrie request -func (pm *ProtocolManager) getHelperTrieAuxData(req HelperTrieReq) []byte { - if req.Type == htCanonical && req.AuxReq == auxHeader && len(req.Key) == 8 { - blockNum := binary.BigEndian.Uint64(req.Key) - hash := rawdb.ReadCanonicalHash(pm.chainDb, blockNum) - return rawdb.ReadHeaderRLP(pm.chainDb, hash, blockNum) - } - return nil -} - -func (pm *ProtocolManager) txStatus(hash common.Hash) light.TxStatus { - var stat light.TxStatus - stat.Status = pm.txpool.Status([]common.Hash{hash})[0] - // If the transaction is unknown to the pool, try looking it up locally - if stat.Status == core.TxStatusUnknown { - if tx, blockHash, blockNumber, txIndex := rawdb.ReadTransaction(pm.chainDb, hash); tx != nil { - stat.Status = core.TxStatusIncluded - stat.Lookup = &rawdb.LegacyTxLookupEntry{BlockHash: blockHash, BlockIndex: blockNumber, Index: txIndex} - } - } - return stat -} - -// downloaderPeerNotify implements peerSetNotify -type downloaderPeerNotify ProtocolManager - -type peerConnection struct { - manager *ProtocolManager - peer *peer -} - -func (pc *peerConnection) Head() (common.Hash, *big.Int) { - return pc.peer.HeadAndTd() -} - -func (pc *peerConnection) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error { - reqID := genReqID() - rq := &distReq{ - getCost: func(dp distPeer) uint64 { - peer := dp.(*peer) - return peer.GetRequestCost(GetBlockHeadersMsg, amount) - }, - canSend: func(dp distPeer) bool { - return dp.(*peer) == pc.peer - }, - request: func(dp distPeer) func() { - peer := dp.(*peer) - cost := peer.GetRequestCost(GetBlockHeadersMsg, amount) - peer.fcServer.QueuedRequest(reqID, cost) - return func() { peer.RequestHeadersByHash(reqID, cost, origin, amount, skip, reverse) } - }, - } - _, ok := <-pc.manager.reqDist.queue(rq) - if !ok { - return light.ErrNoPeers - } - return nil -} - -func (pc *peerConnection) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error { - reqID := genReqID() - rq := &distReq{ - getCost: func(dp distPeer) uint64 { - peer := dp.(*peer) - return peer.GetRequestCost(GetBlockHeadersMsg, amount) - }, - canSend: func(dp distPeer) bool { - return dp.(*peer) == pc.peer - }, - request: func(dp distPeer) func() { - peer := dp.(*peer) - cost := peer.GetRequestCost(GetBlockHeadersMsg, amount) - peer.fcServer.QueuedRequest(reqID, cost) - return func() { peer.RequestHeadersByNumber(reqID, cost, origin, amount, skip, reverse) } - }, - } - _, ok := <-pc.manager.reqDist.queue(rq) - if !ok { - return light.ErrNoPeers - } - return nil -} - -func (d *downloaderPeerNotify) registerPeer(p *peer) { - pm := (*ProtocolManager)(d) - pc := &peerConnection{ - manager: pm, - peer: p, - } - pm.downloader.RegisterLightPeer(p.id, ethVersion, pc) -} - -func (d *downloaderPeerNotify) unregisterPeer(p *peer) { - pm := (*ProtocolManager)(d) - pm.downloader.UnregisterPeer(p.id) -} diff --git a/les/handler_test.go b/les/handler_test.go index dae583f6d..aad8d18e4 100644 --- a/les/handler_test.go +++ b/les/handler_test.go @@ -48,11 +48,13 @@ func expectResponse(r p2p.MsgReader, msgcode, reqID, bv uint64, data interface{} // Tests that block headers can be retrieved from a remote chain based on user queries. func TestGetBlockHeadersLes2(t *testing.T) { testGetBlockHeaders(t, 2) } +func TestGetBlockHeadersLes3(t *testing.T) { testGetBlockHeaders(t, 3) } func testGetBlockHeaders(t *testing.T, protocol int) { - server, tearDown := newServerEnv(t, downloader.MaxHashFetch+15, protocol, nil) + server, tearDown := newServerEnv(t, downloader.MaxHashFetch+15, protocol, nil, false, true, 0) defer tearDown() - bc := server.pm.blockchain.(*core.BlockChain) + + bc := server.handler.blockchain // Create a "random" unknown hash for testing var unknown common.Hash @@ -114,10 +116,10 @@ func testGetBlockHeaders(t *testing.T, protocol int) { []common.Hash{bc.CurrentBlock().Hash()}, }, // Ensure protocol limits are honored - /*{ - &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true}, - bc.GetBlockHashesFromHash(bc.CurrentBlock().Hash(), limit), - },*/ + //{ + // &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true}, + // []common.Hash{}, + //}, // Check that requesting more than available is handled gracefully { &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3}, @@ -165,9 +167,10 @@ func testGetBlockHeaders(t *testing.T, protocol int) { } // Send the hash request and verify the response reqID++ - cost := server.tPeer.GetRequestCost(GetBlockHeadersMsg, int(tt.query.Amount)) - sendRequest(server.tPeer.app, GetBlockHeadersMsg, reqID, cost, tt.query) - if err := expectResponse(server.tPeer.app, BlockHeadersMsg, reqID, testBufLimit, headers); err != nil { + + cost := server.peer.peer.GetRequestCost(GetBlockHeadersMsg, int(tt.query.Amount)) + sendRequest(server.peer.app, GetBlockHeadersMsg, reqID, cost, tt.query) + if err := expectResponse(server.peer.app, BlockHeadersMsg, reqID, testBufLimit, headers); err != nil { t.Errorf("test %d: headers mismatch: %v", i, err) } } @@ -175,11 +178,13 @@ func testGetBlockHeaders(t *testing.T, protocol int) { // Tests that block contents can be retrieved from a remote chain based on their hashes. func TestGetBlockBodiesLes2(t *testing.T) { testGetBlockBodies(t, 2) } +func TestGetBlockBodiesLes3(t *testing.T) { testGetBlockBodies(t, 3) } func testGetBlockBodies(t *testing.T, protocol int) { - server, tearDown := newServerEnv(t, downloader.MaxBlockFetch+15, protocol, nil) + server, tearDown := newServerEnv(t, downloader.MaxBlockFetch+15, protocol, nil, false, true, 0) defer tearDown() - bc := server.pm.blockchain.(*core.BlockChain) + + bc := server.handler.blockchain // Create a batch of tests for various scenarios limit := MaxBodyFetch @@ -239,10 +244,11 @@ func testGetBlockBodies(t *testing.T, protocol int) { } } reqID++ + // Send the hash request and verify the response - cost := server.tPeer.GetRequestCost(GetBlockBodiesMsg, len(hashes)) - sendRequest(server.tPeer.app, GetBlockBodiesMsg, reqID, cost, hashes) - if err := expectResponse(server.tPeer.app, BlockBodiesMsg, reqID, testBufLimit, bodies); err != nil { + cost := server.peer.peer.GetRequestCost(GetBlockBodiesMsg, len(hashes)) + sendRequest(server.peer.app, GetBlockBodiesMsg, reqID, cost, hashes) + if err := expectResponse(server.peer.app, BlockBodiesMsg, reqID, testBufLimit, bodies); err != nil { t.Errorf("test %d: bodies mismatch: %v", i, err) } } @@ -250,12 +256,13 @@ func testGetBlockBodies(t *testing.T, protocol int) { // Tests that the contract codes can be retrieved based on account addresses. func TestGetCodeLes2(t *testing.T) { testGetCode(t, 2) } +func TestGetCodeLes3(t *testing.T) { testGetCode(t, 3) } func testGetCode(t *testing.T, protocol int) { // Assemble the test environment - server, tearDown := newServerEnv(t, 4, protocol, nil) + server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0) defer tearDown() - bc := server.pm.blockchain.(*core.BlockChain) + bc := server.handler.blockchain var codereqs []*CodeReq var codes [][]byte @@ -271,9 +278,9 @@ func testGetCode(t *testing.T, protocol int) { } } - cost := server.tPeer.GetRequestCost(GetCodeMsg, len(codereqs)) - sendRequest(server.tPeer.app, GetCodeMsg, 42, cost, codereqs) - if err := expectResponse(server.tPeer.app, CodeMsg, 42, testBufLimit, codes); err != nil { + cost := server.peer.peer.GetRequestCost(GetCodeMsg, len(codereqs)) + sendRequest(server.peer.app, GetCodeMsg, 42, cost, codereqs) + if err := expectResponse(server.peer.app, CodeMsg, 42, testBufLimit, codes); err != nil { t.Errorf("codes mismatch: %v", err) } } @@ -283,18 +290,18 @@ func TestGetStaleCodeLes2(t *testing.T) { testGetStaleCode(t, 2) } func TestGetStaleCodeLes3(t *testing.T) { testGetStaleCode(t, 3) } func testGetStaleCode(t *testing.T, protocol int) { - server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil) + server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0) defer tearDown() - bc := server.pm.blockchain.(*core.BlockChain) + bc := server.handler.blockchain check := func(number uint64, expected [][]byte) { req := &CodeReq{ BHash: bc.GetHeaderByNumber(number).Hash(), AccKey: crypto.Keccak256(testContractAddr[:]), } - cost := server.tPeer.GetRequestCost(GetCodeMsg, 1) - sendRequest(server.tPeer.app, GetCodeMsg, 42, cost, []*CodeReq{req}) - if err := expectResponse(server.tPeer.app, CodeMsg, 42, testBufLimit, expected); err != nil { + cost := server.peer.peer.GetRequestCost(GetCodeMsg, 1) + sendRequest(server.peer.app, GetCodeMsg, 42, cost, []*CodeReq{req}) + if err := expectResponse(server.peer.app, CodeMsg, 42, testBufLimit, expected); err != nil { t.Errorf("codes mismatch: %v", err) } } @@ -305,12 +312,14 @@ func testGetStaleCode(t *testing.T, protocol int) { // Tests that the transaction receipts can be retrieved based on hashes. func TestGetReceiptLes2(t *testing.T) { testGetReceipt(t, 2) } +func TestGetReceiptLes3(t *testing.T) { testGetReceipt(t, 3) } func testGetReceipt(t *testing.T, protocol int) { // Assemble the test environment - server, tearDown := newServerEnv(t, 4, protocol, nil) + server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0) defer tearDown() - bc := server.pm.blockchain.(*core.BlockChain) + + bc := server.handler.blockchain // Collect the hashes to request, and the response to expect var receipts []types.Receipts @@ -322,26 +331,28 @@ func testGetReceipt(t *testing.T, protocol int) { receipts = append(receipts, rawdb.ReadRawReceipts(server.db, block.Hash(), block.NumberU64())) } // Send the hash request and verify the response - cost := server.tPeer.GetRequestCost(GetReceiptsMsg, len(hashes)) - sendRequest(server.tPeer.app, GetReceiptsMsg, 42, cost, hashes) - if err := expectResponse(server.tPeer.app, ReceiptsMsg, 42, testBufLimit, receipts); err != nil { + cost := server.peer.peer.GetRequestCost(GetReceiptsMsg, len(hashes)) + sendRequest(server.peer.app, GetReceiptsMsg, 42, cost, hashes) + if err := expectResponse(server.peer.app, ReceiptsMsg, 42, testBufLimit, receipts); err != nil { t.Errorf("receipts mismatch: %v", err) } } // Tests that trie merkle proofs can be retrieved func TestGetProofsLes2(t *testing.T) { testGetProofs(t, 2) } +func TestGetProofsLes3(t *testing.T) { testGetProofs(t, 3) } func testGetProofs(t *testing.T, protocol int) { // Assemble the test environment - server, tearDown := newServerEnv(t, 4, protocol, nil) + server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0) defer tearDown() - bc := server.pm.blockchain.(*core.BlockChain) + + bc := server.handler.blockchain var proofreqs []ProofReq proofsV2 := light.NewNodeSet() - accounts := []common.Address{bankAddr, userAddr1, userAddr2, {}} + accounts := []common.Address{bankAddr, userAddr1, userAddr2, signerAddr, {}} for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ { header := bc.GetHeaderByNumber(i) trie, _ := trie.New(header.Root, trie.NewDatabase(server.db)) @@ -356,9 +367,9 @@ func testGetProofs(t *testing.T, protocol int) { } } // Send the proof request and verify the response - cost := server.tPeer.GetRequestCost(GetProofsV2Msg, len(proofreqs)) - sendRequest(server.tPeer.app, GetProofsV2Msg, 42, cost, proofreqs) - if err := expectResponse(server.tPeer.app, ProofsV2Msg, 42, testBufLimit, proofsV2.NodeList()); err != nil { + cost := server.peer.peer.GetRequestCost(GetProofsV2Msg, len(proofreqs)) + sendRequest(server.peer.app, GetProofsV2Msg, 42, cost, proofreqs) + if err := expectResponse(server.peer.app, ProofsV2Msg, 42, testBufLimit, proofsV2.NodeList()); err != nil { t.Errorf("proofs mismatch: %v", err) } } @@ -368,9 +379,9 @@ func TestGetStaleProofLes2(t *testing.T) { testGetStaleProof(t, 2) } func TestGetStaleProofLes3(t *testing.T) { testGetStaleProof(t, 3) } func testGetStaleProof(t *testing.T, protocol int) { - server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil) + server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0) defer tearDown() - bc := server.pm.blockchain.(*core.BlockChain) + bc := server.handler.blockchain check := func(number uint64, wantOK bool) { var ( @@ -381,8 +392,8 @@ func testGetStaleProof(t *testing.T, protocol int) { BHash: header.Hash(), Key: account, } - cost := server.tPeer.GetRequestCost(GetProofsV2Msg, 1) - sendRequest(server.tPeer.app, GetProofsV2Msg, 42, cost, []*ProofReq{req}) + cost := server.peer.peer.GetRequestCost(GetProofsV2Msg, 1) + sendRequest(server.peer.app, GetProofsV2Msg, 42, cost, []*ProofReq{req}) var expected []rlp.RawValue if wantOK { @@ -391,7 +402,7 @@ func testGetStaleProof(t *testing.T, protocol int) { t.Prove(account, 0, proofsV2) expected = proofsV2.NodeList() } - if err := expectResponse(server.tPeer.app, ProofsV2Msg, 42, testBufLimit, expected); err != nil { + if err := expectResponse(server.peer.app, ProofsV2Msg, 42, testBufLimit, expected); err != nil { t.Errorf("codes mismatch: %v", err) } } @@ -402,6 +413,7 @@ func testGetStaleProof(t *testing.T, protocol int) { // Tests that CHT proofs can be correctly retrieved. func TestGetCHTProofsLes2(t *testing.T) { testGetCHTProofs(t, 2) } +func TestGetCHTProofsLes3(t *testing.T) { testGetCHTProofs(t, 3) } func testGetCHTProofs(t *testing.T, protocol int) { config := light.TestServerIndexerConfig @@ -415,9 +427,10 @@ func testGetCHTProofs(t *testing.T, protocol int) { time.Sleep(10 * time.Millisecond) } } - server, tearDown := newServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers) + server, tearDown := newServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, false, true, 0) defer tearDown() - bc := server.pm.blockchain.(*core.BlockChain) + + bc := server.handler.blockchain // Assemble the proofs from the different protocols header := bc.GetHeaderByNumber(config.ChtSize - 1) @@ -440,15 +453,18 @@ func testGetCHTProofs(t *testing.T, protocol int) { AuxReq: auxHeader, }} // Send the proof request and verify the response - cost := server.tPeer.GetRequestCost(GetHelperTrieProofsMsg, len(requestsV2)) - sendRequest(server.tPeer.app, GetHelperTrieProofsMsg, 42, cost, requestsV2) - if err := expectResponse(server.tPeer.app, HelperTrieProofsMsg, 42, testBufLimit, proofsV2); err != nil { + cost := server.peer.peer.GetRequestCost(GetHelperTrieProofsMsg, len(requestsV2)) + sendRequest(server.peer.app, GetHelperTrieProofsMsg, 42, cost, requestsV2) + if err := expectResponse(server.peer.app, HelperTrieProofsMsg, 42, testBufLimit, proofsV2); err != nil { t.Errorf("proofs mismatch: %v", err) } } +func TestGetBloombitsProofsLes2(t *testing.T) { testGetBloombitsProofs(t, 2) } +func TestGetBloombitsProofsLes3(t *testing.T) { testGetBloombitsProofs(t, 3) } + // Tests that bloombits proofs can be correctly retrieved. -func TestGetBloombitsProofs(t *testing.T) { +func testGetBloombitsProofs(t *testing.T, protocol int) { config := light.TestServerIndexerConfig waitIndexers := func(cIndexer, bIndexer, btIndexer *core.ChainIndexer) { @@ -460,9 +476,10 @@ func TestGetBloombitsProofs(t *testing.T) { time.Sleep(10 * time.Millisecond) } } - server, tearDown := newServerEnv(t, int(config.BloomTrieSize+config.BloomTrieConfirms), 2, waitIndexers) + server, tearDown := newServerEnv(t, int(config.BloomTrieSize+config.BloomTrieConfirms), protocol, waitIndexers, false, true, 0) defer tearDown() - bc := server.pm.blockchain.(*core.BlockChain) + + bc := server.handler.blockchain // Request and verify each bit of the bloom bits proofs for bit := 0; bit < 2048; bit++ { @@ -485,43 +502,39 @@ func TestGetBloombitsProofs(t *testing.T) { trie.Prove(key, 0, &proofs.Proofs) // Send the proof request and verify the response - cost := server.tPeer.GetRequestCost(GetHelperTrieProofsMsg, len(requests)) - sendRequest(server.tPeer.app, GetHelperTrieProofsMsg, 42, cost, requests) - if err := expectResponse(server.tPeer.app, HelperTrieProofsMsg, 42, testBufLimit, proofs); err != nil { + cost := server.peer.peer.GetRequestCost(GetHelperTrieProofsMsg, len(requests)) + sendRequest(server.peer.app, GetHelperTrieProofsMsg, 42, cost, requests) + if err := expectResponse(server.peer.app, HelperTrieProofsMsg, 42, testBufLimit, proofs); err != nil { t.Errorf("bit %d: proofs mismatch: %v", bit, err) } } } -func TestTransactionStatusLes2(t *testing.T) { - server, tearDown := newServerEnv(t, 0, 2, nil) - defer tearDown() - server.pm.addTxsSync = true +func TestTransactionStatusLes2(t *testing.T) { testTransactionStatus(t, 2) } +func TestTransactionStatusLes3(t *testing.T) { testTransactionStatus(t, 3) } - chain := server.pm.blockchain.(*core.BlockChain) - config := core.DefaultTxPoolConfig - config.Journal = "" - txpool := core.NewTxPool(config, params.TestChainConfig, chain) - server.pm.txpool = txpool - peer, _ := newTestPeer(t, "peer", 2, server.pm, true, 0) - defer peer.close() +func testTransactionStatus(t *testing.T, protocol int) { + server, tearDown := newServerEnv(t, 0, protocol, nil, false, true, 0) + defer tearDown() + server.handler.addTxsSync = true + + chain := server.handler.blockchain var reqID uint64 test := func(tx *types.Transaction, send bool, expStatus light.TxStatus) { reqID++ if send { - cost := server.tPeer.GetRequestCost(SendTxV2Msg, 1) - sendRequest(server.tPeer.app, SendTxV2Msg, reqID, cost, types.Transactions{tx}) + cost := server.peer.peer.GetRequestCost(SendTxV2Msg, 1) + sendRequest(server.peer.app, SendTxV2Msg, reqID, cost, types.Transactions{tx}) } else { - cost := server.tPeer.GetRequestCost(GetTxStatusMsg, 1) - sendRequest(server.tPeer.app, GetTxStatusMsg, reqID, cost, []common.Hash{tx.Hash()}) + cost := server.peer.peer.GetRequestCost(GetTxStatusMsg, 1) + sendRequest(server.peer.app, GetTxStatusMsg, reqID, cost, []common.Hash{tx.Hash()}) } - if err := expectResponse(server.tPeer.app, TxStatusMsg, reqID, testBufLimit, []light.TxStatus{expStatus}); err != nil { + if err := expectResponse(server.peer.app, TxStatusMsg, reqID, testBufLimit, []light.TxStatus{expStatus}); err != nil { t.Errorf("transaction status mismatch") } } - signer := types.HomesteadSigner{} // test error status by sending an underpriced transaction @@ -551,18 +564,22 @@ func TestTransactionStatusLes2(t *testing.T) { } // wait until TxPool processes the inserted block for i := 0; i < 10; i++ { - if pending, _ := txpool.Stats(); pending == 1 { + if pending, _ := server.handler.txpool.Stats(); pending == 1 { break } time.Sleep(100 * time.Millisecond) } - if pending, _ := txpool.Stats(); pending != 1 { + if pending, _ := server.handler.txpool.Stats(); pending != 1 { t.Fatalf("pending count mismatch: have %d, want 1", pending) } + // Discard new block announcement + msg, _ := server.peer.app.ReadMsg() + msg.Discard() // check if their status is included now block1hash := rawdb.ReadCanonicalHash(server.db, 1) test(tx1, false, light.TxStatus{Status: core.TxStatusIncluded, Lookup: &rawdb.LegacyTxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 0}}) + test(tx2, false, light.TxStatus{Status: core.TxStatusIncluded, Lookup: &rawdb.LegacyTxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 1}}) // create a reorg that rolls them back @@ -572,46 +589,46 @@ func TestTransactionStatusLes2(t *testing.T) { } // wait until TxPool processes the reorg for i := 0; i < 10; i++ { - if pending, _ := txpool.Stats(); pending == 3 { + if pending, _ := server.handler.txpool.Stats(); pending == 3 { break } time.Sleep(100 * time.Millisecond) } - if pending, _ := txpool.Stats(); pending != 3 { + if pending, _ := server.handler.txpool.Stats(); pending != 3 { t.Fatalf("pending count mismatch: have %d, want 3", pending) } + // Discard new block announcement + msg, _ = server.peer.app.ReadMsg() + msg.Discard() + // check if their status is pending again test(tx1, false, light.TxStatus{Status: core.TxStatusPending}) test(tx2, false, light.TxStatus{Status: core.TxStatusPending}) } func TestStopResumeLes3(t *testing.T) { - db := rawdb.NewMemoryDatabase() - clock := &mclock.Simulated{} - testCost := testBufLimit / 10 - pm, _, err := newTestProtocolManager(false, 0, nil, nil, nil, db, nil, 0, testCost, clock) - if err != nil { - t.Fatalf("Failed to create protocol manager: %v", err) - } - peer, _ := newTestPeer(t, "peer", 3, pm, true, testCost) - defer peer.close() + server, tearDown := newServerEnv(t, 0, 3, nil, true, true, testBufLimit/10) + defer tearDown() - expBuf := testBufLimit - var reqID uint64 + server.handler.server.costTracker.testing = true - header := pm.blockchain.CurrentHeader() + var ( + reqID uint64 + expBuf = testBufLimit + testCost = testBufLimit / 10 + ) + header := server.handler.blockchain.CurrentHeader() req := func() { reqID++ - sendRequest(peer.app, GetBlockHeadersMsg, reqID, testCost, &getBlockHeadersData{Origin: hashOrNumber{Hash: header.Hash()}, Amount: 1}) + sendRequest(server.peer.app, GetBlockHeadersMsg, reqID, testCost, &getBlockHeadersData{Origin: hashOrNumber{Hash: header.Hash()}, Amount: 1}) } - for i := 1; i <= 5; i++ { // send requests while we still have enough buffer and expect a response for expBuf >= testCost { req() expBuf -= testCost - if err := expectResponse(peer.app, BlockHeadersMsg, reqID, expBuf, []*types.Header{header}); err != nil { - t.Fatalf("expected response and failed: %v", err) + if err := expectResponse(server.peer.app, BlockHeadersMsg, reqID, expBuf, []*types.Header{header}); err != nil { + t.Errorf("expected response and failed: %v", err) } } // send some more requests in excess and expect a single StopMsg @@ -620,15 +637,16 @@ func TestStopResumeLes3(t *testing.T) { req() c-- } - if err := p2p.ExpectMsg(peer.app, StopMsg, nil); err != nil { + if err := p2p.ExpectMsg(server.peer.app, StopMsg, nil); err != nil { t.Errorf("expected StopMsg and failed: %v", err) } // wait until the buffer is recharged by half of the limit wait := testBufLimit / testBufRecharge / 2 - clock.Run(time.Millisecond * time.Duration(wait)) + server.clock.(*mclock.Simulated).Run(time.Millisecond * time.Duration(wait)) + // expect a ResumeMsg with the partially recharged buffer value expBuf += testBufRecharge * wait - if err := p2p.ExpectMsg(peer.app, ResumeMsg, expBuf); err != nil { + if err := p2p.ExpectMsg(server.peer.app, ResumeMsg, expBuf); err != nil { t.Errorf("expected ResumeMsg and failed: %v", err) } } diff --git a/les/metrics.go b/les/metrics.go index 4fe703116..797631b8e 100644 --- a/les/metrics.go +++ b/les/metrics.go @@ -22,31 +22,73 @@ import ( ) var ( - miscInPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets", nil) - miscInTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic", nil) - miscOutPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets", nil) - miscOutTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic", nil) + miscInPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/total", nil) + miscInTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/total", nil) + miscInHeaderPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/header", nil) + miscInHeaderTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/header", nil) + miscInBodyPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/body", nil) + miscInBodyTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/body", nil) + miscInCodePacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/code", nil) + miscInCodeTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/code", nil) + miscInReceiptPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/receipt", nil) + miscInReceiptTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/receipt", nil) + miscInTrieProofPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/proof", nil) + miscInTrieProofTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/proof", nil) + miscInHelperTriePacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/helperTrie", nil) + miscInHelperTrieTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/helperTrie", nil) + miscInTxsPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/txs", nil) + miscInTxsTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/txs", nil) + miscInTxStatusPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/txStatus", nil) + miscInTxStatusTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/txStatus", nil) - connectionTimer = metrics.NewRegisteredTimer("les/connectionTime", nil) + miscOutPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/total", nil) + miscOutTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/total", nil) + miscOutHeaderPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/header", nil) + miscOutHeaderTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/header", nil) + miscOutBodyPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/body", nil) + miscOutBodyTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/body", nil) + miscOutCodePacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/code", nil) + miscOutCodeTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/code", nil) + miscOutReceiptPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/receipt", nil) + miscOutReceiptTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/receipt", nil) + miscOutTrieProofPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/proof", nil) + miscOutTrieProofTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/proof", nil) + miscOutHelperTriePacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/helperTrie", nil) + miscOutHelperTrieTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/helperTrie", nil) + miscOutTxsPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/txs", nil) + miscOutTxsTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/txs", nil) + miscOutTxStatusPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/txStatus", nil) + miscOutTxStatusTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/txStatus", nil) + + connectionTimer = metrics.NewRegisteredTimer("les/connection/duration", nil) + serverConnectionGauge = metrics.NewRegisteredGauge("les/connection/server", nil) + clientConnectionGauge = metrics.NewRegisteredGauge("les/connection/client", nil) + + totalCapacityGauge = metrics.NewRegisteredGauge("les/server/totalCapacity", nil) + totalRechargeGauge = metrics.NewRegisteredGauge("les/server/totalRecharge", nil) + totalConnectedGauge = metrics.NewRegisteredGauge("les/server/totalConnected", nil) + blockProcessingTimer = metrics.NewRegisteredTimer("les/server/blockProcessingTime", nil) + + requestServedMeter = metrics.NewRegisteredMeter("les/server/req/avgServedTime", nil) + requestServedTimer = metrics.NewRegisteredTimer("les/server/req/servedTime", nil) + requestEstimatedMeter = metrics.NewRegisteredMeter("les/server/req/avgEstimatedTime", nil) + requestEstimatedTimer = metrics.NewRegisteredTimer("les/server/req/estimatedTime", nil) + relativeCostHistogram = metrics.NewRegisteredHistogram("les/server/req/relative", nil, metrics.NewExpDecaySample(1028, 0.015)) + + recentServedGauge = metrics.NewRegisteredGauge("les/server/recentRequestServed", nil) + recentEstimatedGauge = metrics.NewRegisteredGauge("les/server/recentRequestEstimated", nil) + sqServedGauge = metrics.NewRegisteredGauge("les/server/servingQueue/served", nil) + sqQueuedGauge = metrics.NewRegisteredGauge("les/server/servingQueue/queued", nil) - totalConnectedGauge = metrics.NewRegisteredGauge("les/server/totalConnected", nil) - totalCapacityGauge = metrics.NewRegisteredGauge("les/server/totalCapacity", nil) - totalRechargeGauge = metrics.NewRegisteredGauge("les/server/totalRecharge", nil) - blockProcessingTimer = metrics.NewRegisteredTimer("les/server/blockProcessingTime", nil) - requestServedTimer = metrics.NewRegisteredTimer("les/server/requestServed", nil) - requestServedMeter = metrics.NewRegisteredMeter("les/server/totalRequestServed", nil) - requestEstimatedMeter = metrics.NewRegisteredMeter("les/server/totalRequestEstimated", nil) - relativeCostHistogram = metrics.NewRegisteredHistogram("les/server/relativeCost", nil, metrics.NewExpDecaySample(1028, 0.015)) - recentServedGauge = metrics.NewRegisteredGauge("les/server/recentRequestServed", nil) - recentEstimatedGauge = metrics.NewRegisteredGauge("les/server/recentRequestEstimated", nil) - sqServedGauge = metrics.NewRegisteredGauge("les/server/servingQueue/served", nil) - sqQueuedGauge = metrics.NewRegisteredGauge("les/server/servingQueue/queued", nil) clientConnectedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/connected", nil) clientRejectedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/rejected", nil) clientKickedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/kicked", nil) clientDisconnectedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/disconnected", nil) clientFreezeMeter = metrics.NewRegisteredMeter("les/server/clientEvent/freeze", nil) clientErrorMeter = metrics.NewRegisteredMeter("les/server/clientEvent/error", nil) + + requestRTT = metrics.NewRegisteredTimer("les/client/req/rtt", nil) + requestSendDelay = metrics.NewRegisteredTimer("les/client/req/sendDelay", nil) ) // meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of @@ -58,17 +100,11 @@ type meteredMsgReadWriter struct { // newMeteredMsgWriter wraps a p2p MsgReadWriter with metering support. If the // metrics system is disabled, this function returns the original object. -func newMeteredMsgWriter(rw p2p.MsgReadWriter) p2p.MsgReadWriter { +func newMeteredMsgWriter(rw p2p.MsgReadWriter, version int) p2p.MsgReadWriter { if !metrics.Enabled { return rw } - return &meteredMsgReadWriter{MsgReadWriter: rw} -} - -// Init sets the protocol version used by the stream to know which meters to -// increment in case of overlapping message ids between protocol versions. -func (rw *meteredMsgReadWriter) Init(version int) { - rw.version = version + return &meteredMsgReadWriter{MsgReadWriter: rw, version: version} } func (rw *meteredMsgReadWriter) ReadMsg() (p2p.Msg, error) { diff --git a/les/odr.go b/les/odr.go index a26c06680..136ecf4df 100644 --- a/les/odr.go +++ b/les/odr.go @@ -18,7 +18,9 @@ package les import ( "context" + "time" + "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/light" @@ -120,10 +122,11 @@ func (odr *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err erro return func() { lreq.Request(reqID, p) } }, } - + sent := mclock.Now() if err = odr.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(odr.db, msg) }, odr.stop); err == nil { // retrieved from network, store in db req.StoreResult(odr.db) + requestRTT.Update(time.Duration(mclock.Now() - sent)) } else { log.Debug("Failed to retrieve data from network", "err", err) } diff --git a/les/odr_test.go b/les/odr_test.go index 1e8a5f8b4..97217e948 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -39,6 +39,7 @@ import ( type odrTestFn func(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte func TestOdrGetBlockLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetBlock) } +func TestOdrGetBlockLes3(t *testing.T) { testOdr(t, 3, 1, true, odrGetBlock) } func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var block *types.Block @@ -55,6 +56,7 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainCon } func TestOdrGetReceiptsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetReceipts) } +func TestOdrGetReceiptsLes3(t *testing.T) { testOdr(t, 3, 1, true, odrGetReceipts) } func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var receipts types.Receipts @@ -75,6 +77,7 @@ func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.Chain } func TestOdrAccountsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrAccounts) } +func TestOdrAccountsLes3(t *testing.T) { testOdr(t, 3, 1, true, odrAccounts) } func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678") @@ -103,6 +106,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon } func TestOdrContractCallLes2(t *testing.T) { testOdr(t, 2, 2, true, odrContractCall) } +func TestOdrContractCallLes3(t *testing.T) { testOdr(t, 3, 2, true, odrContractCall) } type callmsg struct { types.Message @@ -152,6 +156,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai } func TestOdrTxStatusLes2(t *testing.T) { testOdr(t, 2, 1, false, odrTxStatus) } +func TestOdrTxStatusLes3(t *testing.T) { testOdr(t, 3, 1, false, odrTxStatus) } func odrTxStatus(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var txs types.Transactions @@ -178,21 +183,22 @@ func odrTxStatus(ctx context.Context, db ethdb.Database, config *params.ChainCon // testOdr tests odr requests whose validation guaranteed by block headers. func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn odrTestFn) { // Assemble the test environment - server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, true) + server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, nil, 0, false, true) defer tearDown() - client.pm.synchronise(client.rPeer) + + client.handler.synchronise(client.peer.peer) test := func(expFail uint64) { // Mark this as a helper to put the failures at the correct lines t.Helper() - for i := uint64(0); i <= server.pm.blockchain.CurrentHeader().Number.Uint64(); i++ { + for i := uint64(0); i <= server.handler.blockchain.CurrentHeader().Number.Uint64(); i++ { bhash := rawdb.ReadCanonicalHash(server.db, i) - b1 := fn(light.NoOdr, server.db, server.pm.chainConfig, server.pm.blockchain.(*core.BlockChain), nil, bhash) + b1 := fn(light.NoOdr, server.db, server.handler.server.chainConfig, server.handler.blockchain, nil, bhash) ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) - defer cancel() - b2 := fn(ctx, client.db, client.pm.chainConfig, nil, client.pm.blockchain.(*light.LightChain), bhash) + b2 := fn(ctx, client.db, client.handler.backend.chainConfig, nil, client.handler.backend.blockchain, bhash) + cancel() eq := bytes.Equal(b1, b2) exp := i < expFail @@ -204,22 +210,22 @@ func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn od } } } - // temporarily remove peer to test odr fails + // expect retrievals to fail (except genesis block) without a les peer - client.peers.Unregister(client.rPeer.id) - time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed + client.handler.backend.peers.lock.Lock() + client.peer.peer.hasBlock = func(common.Hash, uint64, bool) bool { return false } + client.handler.backend.peers.lock.Unlock() test(expFail) // expect all retrievals to pass - client.peers.Register(client.rPeer) - time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed - client.peers.lock.Lock() - client.rPeer.hasBlock = func(common.Hash, uint64, bool) bool { return true } - client.peers.lock.Unlock() + client.handler.backend.peers.lock.Lock() + client.peer.peer.hasBlock = func(common.Hash, uint64, bool) bool { return true } + client.handler.backend.peers.lock.Unlock() test(5) + + // still expect all retrievals to pass, now data should be cached locally if checkCached { - // still expect all retrievals to pass, now data should be cached locally - client.peers.Unregister(client.rPeer.id) + client.handler.backend.peers.Unregister(client.peer.peer.id) time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed test(5) } diff --git a/les/peer.go b/les/peer.go index bcd91cd83..edf3c7c85 100644 --- a/les/peer.go +++ b/les/peer.go @@ -111,7 +111,7 @@ type peer struct { fcServer *flowcontrol.ServerNode // nil if the peer is client only fcParams flowcontrol.ServerParams fcCosts requestCostTable - balanceTracker *balanceTracker // set by clientPool.connect, used and removed by ProtocolManager.handle + balanceTracker *balanceTracker // set by clientPool.connect, used and removed by serverHandler. trusted bool onlyAnnounce bool @@ -291,6 +291,11 @@ func (p *peer) updateCapacity(cap uint64) { p.queueSend(func() { p.SendAnnounce(announceData{Update: kvList}) }) } +func (p *peer) responseID() uint64 { + p.responseCount += 1 + return p.responseCount +} + func sendRequest(w p2p.MsgWriter, msgcode, reqID, cost uint64, data interface{}) error { type req struct { ReqID uint64 @@ -373,6 +378,7 @@ func (p *peer) HasBlock(hash common.Hash, number uint64, hasState bool) bool { } hasBlock := p.hasBlock p.lock.RUnlock() + return head >= number && number >= since && (recent == 0 || number+recent+4 > head) && hasBlock != nil && hasBlock(hash, number, hasState) } @@ -571,6 +577,8 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis defer p.lock.Unlock() var send keyValueList + + // Add some basic handshake fields send = send.add("protocolVersion", uint64(p.version)) send = send.add("networkId", p.network) send = send.add("headTd", td) @@ -578,7 +586,8 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis send = send.add("headNum", headNum) send = send.add("genesisHash", genesis) if server != nil { - if !server.onlyAnnounce { + // Add some information which services server can offer. + if !server.config.UltraLightOnlyAnnounce { send = send.add("serveHeaders", nil) send = send.add("serveChainSince", uint64(0)) send = send.add("serveStateSince", uint64(0)) @@ -594,25 +603,28 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis } send = send.add("flowControl/BL", server.defParams.BufLimit) send = send.add("flowControl/MRR", server.defParams.MinRecharge) + var costList RequestCostList - if server.costTracker != nil { - costList = server.costTracker.makeCostList(server.costTracker.globalFactor()) + if server.costTracker.testCostList != nil { + costList = server.costTracker.testCostList } else { - costList = testCostList(server.testCost) + costList = server.costTracker.makeCostList(server.costTracker.globalFactor()) } send = send.add("flowControl/MRC", costList) p.fcCosts = costList.decode(ProtocolLengths[uint(p.version)]) p.fcParams = server.defParams - if server.protocolManager != nil && server.protocolManager.reg != nil && server.protocolManager.reg.isRunning() { - cp, height := server.protocolManager.reg.stableCheckpoint() + // Add advertised checkpoint and register block height which + // client can verify the checkpoint validity. + if server.oracle != nil && server.oracle.isRunning() { + cp, height := server.oracle.stableCheckpoint() if cp != nil { send = send.add("checkpoint/value", cp) send = send.add("checkpoint/registerHeight", height) } } } else { - //on client node + // Add some client-specific handshake fields p.announceType = announceTypeSimple if p.trusted { p.announceType = announceTypeSigned @@ -663,17 +675,12 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis } if server != nil { - // until we have a proper peer connectivity API, allow LES connection to other servers - /*if recv.get("serveStateSince", nil) == nil { - return errResp(ErrUselessPeer, "wanted client, got server") - }*/ if recv.get("announceType", &p.announceType) != nil { - //set default announceType on server side + // set default announceType on server side p.announceType = announceTypeSimple } p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams) } else { - //mark OnlyAnnounce server if "serveHeaders", "serveChainSince", "serveStateSince" or "txRelay" fields don't exist if recv.get("serveChainSince", &p.chainSince) != nil { p.onlyAnnounce = true } @@ -730,15 +737,10 @@ func (p *peer) updateFlowControl(update keyValueMap) { if p.fcServer == nil { return } - params := p.fcParams - updateParams := false - if update.get("flowControl/BL", ¶ms.BufLimit) == nil { - updateParams = true - } - if update.get("flowControl/MRR", ¶ms.MinRecharge) == nil { - updateParams = true - } - if updateParams { + // If any of the flow control params is nil, refuse to update. + var params flowcontrol.ServerParams + if update.get("flowControl/BL", ¶ms.BufLimit) == nil && update.get("flowControl/MRR", ¶ms.MinRecharge) == nil { + // todo can light client set a minimal acceptable flow control params? p.fcParams = params p.fcServer.UpdateParams(params) } diff --git a/les/peer_test.go b/les/peer_test.go index ba8a79fe2..db74a052c 100644 --- a/les/peer_test.go +++ b/les/peer_test.go @@ -18,47 +18,54 @@ package les import ( "math/big" + "net" "testing" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/les/flowcontrol" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/rlp" ) -const ( - test_networkid = 10 - protocol_version = lpv2 -) +const protocolVersion = lpv2 var ( - hash = common.HexToHash("some string") - genesis = common.HexToHash("genesis hash") + hash = common.HexToHash("deadbeef") + genesis = common.HexToHash("cafebabe") headNum = uint64(1234) td = big.NewInt(123) ) -//ulc connects to trusted peer and send announceType=announceTypeSigned +func newNodeID(t *testing.T) *enode.Node { + key, err := crypto.GenerateKey() + if err != nil { + t.Fatal("generate key err:", err) + } + return enode.NewV4(&key.PublicKey, net.IP{}, 35000, 35000) +} + +// ulc connects to trusted peer and send announceType=announceTypeSigned func TestPeerHandshakeSetAnnounceTypeToAnnounceTypeSignedForTrustedPeer(t *testing.T) { id := newNodeID(t).ID() - //peer to connect(on ulc side) + // peer to connect(on ulc side) p := peer{ Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), - version: protocol_version, + version: protocolVersion, trusted: true, rw: &rwStub{ WriteHook: func(recvList keyValueList) { - //checking that ulc sends to peer allowedRequests=onlyAnnounceRequests and announceType = announceTypeSigned recv, _ := recvList.decode() var reqType uint64 - err := recv.get("announceType", &reqType) if err != nil { t.Fatal(err) } - if reqType != announceTypeSigned { t.Fatal("Expected announceTypeSigned") } @@ -71,18 +78,15 @@ func TestPeerHandshakeSetAnnounceTypeToAnnounceTypeSignedForTrustedPeer(t *testi l = l.add("flowControl/BL", uint64(0)) l = l.add("flowControl/MRR", uint64(0)) l = l.add("flowControl/MRC", testCostList(0)) - return l }, }, - network: test_networkid, + network: NetworkId, } - err := p.Handshake(td, hash, headNum, genesis, nil) if err != nil { t.Fatalf("Handshake error: %s", err) } - if p.announceType != announceTypeSigned { t.Fatal("Incorrect announceType") } @@ -92,18 +96,16 @@ func TestPeerHandshakeAnnounceTypeSignedForTrustedPeersPeerNotInTrusted(t *testi id := newNodeID(t).ID() p := peer{ Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), - version: protocol_version, + version: protocolVersion, rw: &rwStub{ WriteHook: func(recvList keyValueList) { - //checking that ulc sends to peer allowedRequests=noRequests and announceType != announceTypeSigned + // checking that ulc sends to peer allowedRequests=noRequests and announceType != announceTypeSigned recv, _ := recvList.decode() var reqType uint64 - err := recv.get("announceType", &reqType) if err != nil { t.Fatal(err) } - if reqType == announceTypeSigned { t.Fatal("Expected not announceTypeSigned") } @@ -116,13 +118,11 @@ func TestPeerHandshakeAnnounceTypeSignedForTrustedPeersPeerNotInTrusted(t *testi l = l.add("flowControl/BL", uint64(0)) l = l.add("flowControl/MRR", uint64(0)) l = l.add("flowControl/MRC", testCostList(0)) - return l }, }, - network: test_networkid, + network: NetworkId, } - err := p.Handshake(td, hash, headNum, genesis, nil) if err != nil { t.Fatal(err) @@ -139,16 +139,15 @@ func TestPeerHandshakeDefaultAllRequests(t *testing.T) { p := peer{ Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), - version: protocol_version, + version: protocolVersion, rw: &rwStub{ ReadHook: func(l keyValueList) keyValueList { l = l.add("announceType", uint64(announceTypeSigned)) l = l.add("allowedRequests", uint64(0)) - return l }, }, - network: test_networkid, + network: NetworkId, } err := p.Handshake(td, hash, headNum, genesis, s) @@ -165,15 +164,14 @@ func TestPeerHandshakeServerSendOnlyAnnounceRequestsHeaders(t *testing.T) { id := newNodeID(t).ID() s := generateLesServer() - s.onlyAnnounce = true + s.config.UltraLightOnlyAnnounce = true p := peer{ Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), - version: protocol_version, + version: protocolVersion, rw: &rwStub{ ReadHook: func(l keyValueList) keyValueList { l = l.add("announceType", uint64(announceTypeSigned)) - return l }, WriteHook: func(l keyValueList) { @@ -187,7 +185,7 @@ func TestPeerHandshakeServerSendOnlyAnnounceRequestsHeaders(t *testing.T) { } }, }, - network: test_networkid, + network: NetworkId, } err := p.Handshake(td, hash, headNum, genesis, s) @@ -200,7 +198,7 @@ func TestPeerHandshakeClientReceiveOnlyAnnounceRequestsHeaders(t *testing.T) { p := peer{ Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), - version: protocol_version, + version: protocolVersion, rw: &rwStub{ ReadHook: func(l keyValueList) keyValueList { l = l.add("flowControl/BL", uint64(0)) @@ -212,7 +210,7 @@ func TestPeerHandshakeClientReceiveOnlyAnnounceRequestsHeaders(t *testing.T) { return l }, }, - network: test_networkid, + network: NetworkId, trusted: true, } @@ -231,19 +229,17 @@ func TestPeerHandshakeClientReturnErrorOnUselessPeer(t *testing.T) { p := peer{ Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), - version: protocol_version, + version: protocolVersion, rw: &rwStub{ ReadHook: func(l keyValueList) keyValueList { l = l.add("flowControl/BL", uint64(0)) l = l.add("flowControl/MRR", uint64(0)) l = l.add("flowControl/MRC", RequestCostList{}) - l = l.add("announceType", uint64(announceTypeSigned)) - return l }, }, - network: test_networkid, + network: NetworkId, } err := p.Handshake(td, hash, headNum, genesis, nil) @@ -254,12 +250,16 @@ func TestPeerHandshakeClientReturnErrorOnUselessPeer(t *testing.T) { func generateLesServer() *LesServer { s := &LesServer{ + lesCommons: lesCommons{ + config: ð.Config{UltraLightOnlyAnnounce: true}, + }, defParams: flowcontrol.ServerParams{ BufLimit: uint64(300000000), MinRecharge: uint64(50000), }, fcManager: flowcontrol.NewClientManager(nil, &mclock.System{}), } + s.costTracker, _ = newCostTracker(rawdb.NewMemoryDatabase(), s.config) return s } @@ -270,8 +270,8 @@ type rwStub struct { func (s *rwStub) ReadMsg() (p2p.Msg, error) { payload := keyValueList{} - payload = payload.add("protocolVersion", uint64(protocol_version)) - payload = payload.add("networkId", uint64(test_networkid)) + payload = payload.add("protocolVersion", uint64(protocolVersion)) + payload = payload.add("networkId", uint64(NetworkId)) payload = payload.add("headTd", td) payload = payload.add("headHash", hash) payload = payload.add("headNum", headNum) @@ -280,12 +280,10 @@ func (s *rwStub) ReadMsg() (p2p.Msg, error) { if s.ReadHook != nil { payload = s.ReadHook(payload) } - size, p, err := rlp.EncodeToReader(payload) if err != nil { return p2p.Msg{}, err } - return p2p.Msg{ Size: uint32(size), Payload: p, @@ -297,10 +295,8 @@ func (s *rwStub) WriteMsg(m p2p.Msg) error { if err := m.Decode(&recvList); err != nil { return err } - if s.WriteHook != nil { s.WriteHook(recvList) } - return nil } diff --git a/les/request_test.go b/les/request_test.go index 42a63c351..69b57ca31 100644 --- a/les/request_test.go +++ b/les/request_test.go @@ -37,18 +37,21 @@ func secAddr(addr common.Address) []byte { type accessTestFn func(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest func TestBlockAccessLes2(t *testing.T) { testAccess(t, 2, tfBlockAccess) } +func TestBlockAccessLes3(t *testing.T) { testAccess(t, 3, tfBlockAccess) } func tfBlockAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { return &light.BlockRequest{Hash: bhash, Number: number} } func TestReceiptsAccessLes2(t *testing.T) { testAccess(t, 2, tfReceiptsAccess) } +func TestReceiptsAccessLes3(t *testing.T) { testAccess(t, 3, tfReceiptsAccess) } func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { return &light.ReceiptsRequest{Hash: bhash, Number: number} } func TestTrieEntryAccessLes2(t *testing.T) { testAccess(t, 2, tfTrieEntryAccess) } +func TestTrieEntryAccessLes3(t *testing.T) { testAccess(t, 3, tfTrieEntryAccess) } func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { if number := rawdb.ReadHeaderNumber(db, bhash); number != nil { @@ -58,6 +61,7 @@ func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) ligh } func TestCodeAccessLes2(t *testing.T) { testAccess(t, 2, tfCodeAccess) } +func TestCodeAccessLes3(t *testing.T) { testAccess(t, 3, tfCodeAccess) } func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrRequest { number := rawdb.ReadHeaderNumber(db, bhash) @@ -75,17 +79,18 @@ func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrReq func testAccess(t *testing.T, protocol int, fn accessTestFn) { // Assemble the test environment - server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, true) + server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, nil, 0, false, true) defer tearDown() - client.pm.synchronise(client.rPeer) + client.handler.synchronise(client.peer.peer) test := func(expFail uint64) { - for i := uint64(0); i <= server.pm.blockchain.CurrentHeader().Number.Uint64(); i++ { + for i := uint64(0); i <= server.handler.blockchain.CurrentHeader().Number.Uint64(); i++ { bhash := rawdb.ReadCanonicalHash(server.db, i) if req := fn(client.db, bhash, i); req != nil { ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) - defer cancel() - err := client.pm.odr.Retrieve(ctx, req) + err := client.handler.backend.odr.Retrieve(ctx, req) + cancel() + got := err == nil exp := i < expFail if exp && !got { @@ -97,18 +102,5 @@ func testAccess(t *testing.T, protocol int, fn accessTestFn) { } } } - - // temporarily remove peer to test odr fails - client.peers.Unregister(client.rPeer.id) - time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed - // expect retrievals to fail (except genesis block) without a les peer - test(0) - - client.peers.Register(client.rPeer) - time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed - client.rPeer.lock.Lock() - client.rPeer.hasBlock = func(common.Hash, uint64, bool) bool { return true } - client.rPeer.lock.Unlock() - // expect all retrievals to pass test(5) } diff --git a/les/server.go b/les/server.go index 97e82a42b..416cabd13 100644 --- a/les/server.go +++ b/les/server.go @@ -18,15 +18,11 @@ package les import ( "crypto/ecdsa" - "sync" "time" "github.com/ethereum/go-ethereum/accounts/abi/bind" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/les/flowcontrol" "github.com/ethereum/go-ethereum/light" @@ -38,80 +34,94 @@ import ( "github.com/ethereum/go-ethereum/rpc" ) -const bufLimitRatio = 6000 // fixed bufLimit/MRR ratio - type LesServer struct { lesCommons archiveMode bool // Flag whether the ethereum node runs in archive mode. + handler *serverHandler + lesTopics []discv5.Topic + privateKey *ecdsa.PrivateKey - fcManager *flowcontrol.ClientManager // nil if our node is client only + // Flow control and capacity management + fcManager *flowcontrol.ClientManager costTracker *costTracker - testCost uint64 defParams flowcontrol.ServerParams - lesTopics []discv5.Topic - privateKey *ecdsa.PrivateKey - quitSync chan struct{} - onlyAnnounce bool + servingQueue *servingQueue + clientPool *clientPool - thcNormal, thcBlockProcessing int // serving thread count for normal operation and block processing mode - - maxPeers int - minCapacity, maxCapacity, freeClientCap uint64 - clientPool *clientPool + freeCapacity uint64 // The minimal client capacity used for free client. + threadsIdle int // Request serving threads count when system is idle. + threadsBusy int // Request serving threads count when system is busy(block insertion). } func NewLesServer(e *eth.Ethereum, config *eth.Config) (*LesServer, error) { + // Collect les protocol version information supported by local node. lesTopics := make([]discv5.Topic, len(AdvertiseProtocolVersions)) for i, pv := range AdvertiseProtocolVersions { lesTopics[i] = lesTopic(e.BlockChain().Genesis().Hash(), pv) } - quitSync := make(chan struct{}) + // Calculate the number of threads used to service the light client + // requests based on the user-specified value. + threads := config.LightServ * 4 / 100 + if threads < 4 { + threads = 4 + } srv := &LesServer{ lesCommons: lesCommons{ + genesis: e.BlockChain().Genesis().Hash(), config: config, + chainConfig: e.BlockChain().Config(), iConfig: light.DefaultServerIndexerConfig, chainDb: e.ChainDb(), + peers: newPeerSet(), + chainReader: e.BlockChain(), chtIndexer: light.NewChtIndexer(e.ChainDb(), nil, params.CHTFrequency, params.HelperTrieProcessConfirmations), bloomTrieIndexer: light.NewBloomTrieIndexer(e.ChainDb(), nil, params.BloomBitsBlocks, params.BloomTrieFrequency), + closeCh: make(chan struct{}), }, archiveMode: e.ArchiveMode(), - quitSync: quitSync, lesTopics: lesTopics, - onlyAnnounce: config.UltraLightOnlyAnnounce, + fcManager: flowcontrol.NewClientManager(nil, &mclock.System{}), + servingQueue: newServingQueue(int64(time.Millisecond*10), float64(config.LightServ)/100), + threadsBusy: config.LightServ/100 + 1, + threadsIdle: threads, } - srv.costTracker, srv.minCapacity = newCostTracker(e.ChainDb(), config) - - logger := log.New() - srv.thcNormal = config.LightServ * 4 / 100 - if srv.thcNormal < 4 { - srv.thcNormal = 4 - } - srv.thcBlockProcessing = config.LightServ/100 + 1 - srv.fcManager = flowcontrol.NewClientManager(nil, &mclock.System{}) - - checkpoint := srv.latestLocalCheckpoint() - if !checkpoint.Empty() { - logger.Info("Loaded latest checkpoint", "section", checkpoint.SectionIndex, "head", checkpoint.SectionHead, - "chtroot", checkpoint.CHTRoot, "bloomroot", checkpoint.BloomRoot) - } - - srv.chtIndexer.Start(e.BlockChain()) + srv.handler = newServerHandler(srv, e.BlockChain(), e.ChainDb(), e.TxPool(), e.Synced) + srv.costTracker, srv.freeCapacity = newCostTracker(e.ChainDb(), config) + // Set up checkpoint oracle. oracle := config.CheckpointOracle if oracle == nil { oracle = params.CheckpointOracles[e.BlockChain().Genesis().Hash()] } - registrar := newCheckpointOracle(oracle, srv.getLocalCheckpoint) - // TODO(rjl493456442) Checkpoint is useless for les server, separate handler for client and server. - pm, err := NewProtocolManager(e.BlockChain().Config(), nil, light.DefaultServerIndexerConfig, config.UltraLightServers, config.UltraLightFraction, false, config.NetworkId, e.EventMux(), newPeerSet(), e.BlockChain(), e.TxPool(), e.ChainDb(), nil, nil, registrar, quitSync, new(sync.WaitGroup), e.Synced) - if err != nil { - return nil, err - } - srv.protocolManager = pm - pm.servingQueue = newServingQueue(int64(time.Millisecond*10), float64(config.LightServ)/100) - pm.server = srv + srv.oracle = newCheckpointOracle(oracle, srv.localCheckpoint) + // Initialize server capacity management fields. + srv.defParams = flowcontrol.ServerParams{ + BufLimit: srv.freeCapacity * bufLimitRatio, + MinRecharge: srv.freeCapacity, + } + // LES flow control tries to more or less guarantee the possibility for the + // clients to send a certain amount of requests at any time and get a quick + // response. Most of the clients want this guarantee but don't actually need + // to send requests most of the time. Our goal is to serve as many clients as + // possible while the actually used server capacity does not exceed the limits + totalRecharge := srv.costTracker.totalRecharge() + maxCapacity := srv.freeCapacity * uint64(srv.config.LightPeers) + if totalRecharge > maxCapacity { + maxCapacity = totalRecharge + } + srv.fcManager.SetCapacityLimits(srv.freeCapacity, maxCapacity, srv.freeCapacity*2) + + srv.clientPool = newClientPool(srv.chainDb, srv.freeCapacity, 10000, mclock.System{}, func(id enode.ID) { go srv.peers.Unregister(peerIdToString(id)) }) + srv.peers.notify(srv.clientPool) + + checkpoint := srv.latestLocalCheckpoint() + if !checkpoint.Empty() { + log.Info("Loaded latest checkpoint", "section", checkpoint.SectionIndex, "head", checkpoint.SectionHead, + "chtroot", checkpoint.CHTRoot, "bloomroot", checkpoint.BloomRoot) + } + srv.chtIndexer.Start(e.BlockChain()) return srv, nil } @@ -120,102 +130,29 @@ func (s *LesServer) APIs() []rpc.API { { Namespace: "les", Version: "1.0", - Service: NewPrivateLightAPI(&s.lesCommons, s.protocolManager.reg), + Service: NewPrivateLightAPI(&s.lesCommons), Public: false, }, } } -// startEventLoop starts an event handler loop that updates the recharge curve of -// the client manager and adjusts the client pool's size according to the total -// capacity updates coming from the client manager -func (s *LesServer) startEventLoop() { - s.protocolManager.wg.Add(1) - - var ( - processing, procLast bool - procStarted time.Time - ) - blockProcFeed := make(chan bool, 100) - s.protocolManager.blockchain.(*core.BlockChain).SubscribeBlockProcessingEvent(blockProcFeed) - totalRechargeCh := make(chan uint64, 100) - totalRecharge := s.costTracker.subscribeTotalRecharge(totalRechargeCh) - totalCapacityCh := make(chan uint64, 100) - updateRecharge := func() { - if processing { - if !procLast { - procStarted = time.Now() - } - s.protocolManager.servingQueue.setThreads(s.thcBlockProcessing) - s.fcManager.SetRechargeCurve(flowcontrol.PieceWiseLinear{{0, 0}, {totalRecharge, totalRecharge}}) - } else { - if procLast { - blockProcessingTimer.UpdateSince(procStarted) - } - s.protocolManager.servingQueue.setThreads(s.thcNormal) - s.fcManager.SetRechargeCurve(flowcontrol.PieceWiseLinear{{0, 0}, {totalRecharge / 16, totalRecharge / 2}, {totalRecharge / 2, totalRecharge / 2}, {totalRecharge, totalRecharge}}) - } - procLast = processing - } - updateRecharge() - totalCapacity := s.fcManager.SubscribeTotalCapacity(totalCapacityCh) - s.clientPool.setLimits(s.maxPeers, totalCapacity) - - var maxFreePeers uint64 - go func() { - for { - select { - case processing = <-blockProcFeed: - updateRecharge() - case totalRecharge = <-totalRechargeCh: - updateRecharge() - case totalCapacity = <-totalCapacityCh: - totalCapacityGauge.Update(int64(totalCapacity)) - newFreePeers := totalCapacity / s.freeClientCap - if newFreePeers < maxFreePeers && newFreePeers < uint64(s.maxPeers) { - log.Warn("Reduced total capacity", "maxFreePeers", newFreePeers) - } - maxFreePeers = newFreePeers - s.clientPool.setLimits(s.maxPeers, totalCapacity) - case <-s.protocolManager.quitSync: - s.protocolManager.wg.Done() - return - } - } - }() -} - func (s *LesServer) Protocols() []p2p.Protocol { - return s.makeProtocols(ServerProtocolVersions) + return s.makeProtocols(ServerProtocolVersions, s.handler.runPeer, func(id enode.ID) interface{} { + if p := s.peers.Peer(peerIdToString(id)); p != nil { + return p.Info() + } + return nil + }) } // Start starts the LES server func (s *LesServer) Start(srvr *p2p.Server) { - s.maxPeers = s.config.LightPeers - totalRecharge := s.costTracker.totalRecharge() - if s.maxPeers > 0 { - s.freeClientCap = s.minCapacity //totalRecharge / uint64(s.maxPeers) - if s.freeClientCap < s.minCapacity { - s.freeClientCap = s.minCapacity - } - if s.freeClientCap > 0 { - s.defParams = flowcontrol.ServerParams{ - BufLimit: s.freeClientCap * bufLimitRatio, - MinRecharge: s.freeClientCap, - } - } - } + s.privateKey = srvr.PrivateKey + s.handler.start() + + s.wg.Add(1) + go s.capacityManagement() - s.maxCapacity = s.freeClientCap * uint64(s.maxPeers) - if totalRecharge > s.maxCapacity { - s.maxCapacity = totalRecharge - } - s.fcManager.SetCapacityLimits(s.freeClientCap, s.maxCapacity, s.freeClientCap*2) - s.clientPool = newClientPool(s.chainDb, s.freeClientCap, 10000, mclock.System{}, func(id enode.ID) { go s.protocolManager.removePeer(peerIdToString(id)) }) - s.clientPool.setPriceFactors(priceFactors{0, 1, 1}, priceFactors{0, 1, 1}) - s.protocolManager.peers.notify(s.clientPool) - s.startEventLoop() - s.protocolManager.Start(s.config.LightPeers) if srvr.DiscV5 != nil { for _, topic := range s.lesTopics { topic := topic @@ -224,12 +161,32 @@ func (s *LesServer) Start(srvr *p2p.Server) { logger.Info("Starting topic registration") defer logger.Info("Terminated topic registration") - srvr.DiscV5.RegisterTopic(topic, s.quitSync) + srvr.DiscV5.RegisterTopic(topic, s.closeCh) }() } } - s.privateKey = srvr.PrivateKey - s.protocolManager.blockLoop() +} + +// Stop stops the LES service +func (s *LesServer) Stop() { + close(s.closeCh) + + // Disconnect existing sessions. + // This also closes the gate for any new registrations on the peer set. + // sessions which are already established but not added to pm.peers yet + // will exit when they try to register. + s.peers.Close() + + s.fcManager.Stop() + s.clientPool.stop() + s.costTracker.stop() + s.handler.stop() + s.servingQueue.stop() + + // Note, bloom trie indexer is closed by parent bloombits indexer. + s.chtIndexer.Close() + s.wg.Wait() + log.Info("Les server stopped") } func (s *LesServer) SetBloomBitsIndexer(bloomIndexer *core.ChainIndexer) { @@ -238,78 +195,67 @@ func (s *LesServer) SetBloomBitsIndexer(bloomIndexer *core.ChainIndexer) { // SetClient sets the rpc client and starts running checkpoint contract if it is not yet watched. func (s *LesServer) SetContractBackend(backend bind.ContractBackend) { - if s.protocolManager.reg != nil { - s.protocolManager.reg.start(backend) + if s.oracle == nil { + return + } + s.oracle.start(backend) +} + +// capacityManagement starts an event handler loop that updates the recharge curve of +// the client manager and adjusts the client pool's size according to the total +// capacity updates coming from the client manager +func (s *LesServer) capacityManagement() { + defer s.wg.Done() + + processCh := make(chan bool, 100) + sub := s.handler.blockchain.SubscribeBlockProcessingEvent(processCh) + defer sub.Unsubscribe() + + totalRechargeCh := make(chan uint64, 100) + totalRecharge := s.costTracker.subscribeTotalRecharge(totalRechargeCh) + + totalCapacityCh := make(chan uint64, 100) + totalCapacity := s.fcManager.SubscribeTotalCapacity(totalCapacityCh) + s.clientPool.setLimits(s.config.LightPeers, totalCapacity) + + var ( + busy bool + freePeers uint64 + blockProcess mclock.AbsTime + ) + updateRecharge := func() { + if busy { + s.servingQueue.setThreads(s.threadsBusy) + s.fcManager.SetRechargeCurve(flowcontrol.PieceWiseLinear{{0, 0}, {totalRecharge, totalRecharge}}) + } else { + s.servingQueue.setThreads(s.threadsIdle) + s.fcManager.SetRechargeCurve(flowcontrol.PieceWiseLinear{{0, 0}, {totalRecharge / 10, totalRecharge}, {totalRecharge, totalRecharge}}) + } + } + updateRecharge() + + for { + select { + case busy = <-processCh: + if busy { + blockProcess = mclock.Now() + } else { + blockProcessingTimer.Update(time.Duration(mclock.Now() - blockProcess)) + } + updateRecharge() + case totalRecharge = <-totalRechargeCh: + totalRechargeGauge.Update(int64(totalRecharge)) + updateRecharge() + case totalCapacity = <-totalCapacityCh: + totalCapacityGauge.Update(int64(totalCapacity)) + newFreePeers := totalCapacity / s.freeCapacity + if newFreePeers < freePeers && newFreePeers < uint64(s.config.LightPeers) { + log.Warn("Reduced free peer connections", "from", freePeers, "to", newFreePeers) + } + freePeers = newFreePeers + s.clientPool.setLimits(s.config.LightPeers, totalCapacity) + case <-s.closeCh: + return + } } } - -// Stop stops the LES service -func (s *LesServer) Stop() { - s.fcManager.Stop() - s.chtIndexer.Close() - // bloom trie indexer is closed by parent bloombits indexer - go func() { - <-s.protocolManager.noMorePeers - }() - s.clientPool.stop() - s.costTracker.stop() - s.protocolManager.Stop() -} - -// todo(rjl493456442) separate client and server implementation. -func (pm *ProtocolManager) blockLoop() { - pm.wg.Add(1) - headCh := make(chan core.ChainHeadEvent, 10) - headSub := pm.blockchain.SubscribeChainHeadEvent(headCh) - go func() { - var lastHead *types.Header - lastBroadcastTd := common.Big0 - for { - select { - case ev := <-headCh: - peers := pm.peers.AllPeers() - if len(peers) > 0 { - header := ev.Block.Header() - hash := header.Hash() - number := header.Number.Uint64() - td := rawdb.ReadTd(pm.chainDb, hash, number) - if td != nil && td.Cmp(lastBroadcastTd) > 0 { - var reorg uint64 - if lastHead != nil { - reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(pm.chainDb, header, lastHead).Number.Uint64() - } - lastHead = header - lastBroadcastTd = td - - log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg) - - announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg} - var ( - signed bool - signedAnnounce announceData - ) - - for _, p := range peers { - p := p - switch p.announceType { - case announceTypeSimple: - p.queueSend(func() { p.SendAnnounce(announce) }) - case announceTypeSigned: - if !signed { - signedAnnounce = announce - signedAnnounce.sign(pm.server.privateKey) - signed = true - } - p.queueSend(func() { p.SendAnnounce(signedAnnounce) }) - } - } - } - } - case <-pm.quitSync: - headSub.Unsubscribe() - pm.wg.Done() - return - } - } - }() -} diff --git a/les/server_handler.go b/les/server_handler.go new file mode 100644 index 000000000..af9c077bc --- /dev/null +++ b/les/server_handler.go @@ -0,0 +1,921 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import ( + "encoding/binary" + "encoding/json" + "errors" + "sync" + "sync/atomic" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/light" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" +) + +const ( + softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data. + estHeaderRlpSize = 500 // Approximate size of an RLP encoded block header + ethVersion = 63 // equivalent eth version for the downloader + + MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request + MaxBodyFetch = 32 // Amount of block bodies to be fetched per retrieval request + MaxReceiptFetch = 128 // Amount of transaction receipts to allow fetching per request + MaxCodeFetch = 64 // Amount of contract codes to allow fetching per request + MaxProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request + MaxHelperTrieProofsFetch = 64 // Amount of helper tries to be fetched per retrieval request + MaxTxSend = 64 // Amount of transactions to be send per request + MaxTxStatus = 256 // Amount of transactions to queried per request +) + +var errTooManyInvalidRequest = errors.New("too many invalid requests made") + +// serverHandler is responsible for serving light client and process +// all incoming light requests. +type serverHandler struct { + blockchain *core.BlockChain + chainDb ethdb.Database + txpool *core.TxPool + server *LesServer + + closeCh chan struct{} // Channel used to exit all background routines of handler. + wg sync.WaitGroup // WaitGroup used to track all background routines of handler. + synced func() bool // Callback function used to determine whether local node is synced. + + // Testing fields + addTxsSync bool +} + +func newServerHandler(server *LesServer, blockchain *core.BlockChain, chainDb ethdb.Database, txpool *core.TxPool, synced func() bool) *serverHandler { + handler := &serverHandler{ + server: server, + blockchain: blockchain, + chainDb: chainDb, + txpool: txpool, + closeCh: make(chan struct{}), + synced: synced, + } + return handler +} + +// start starts the server handler. +func (h *serverHandler) start() { + h.wg.Add(1) + go h.broadcastHeaders() +} + +// stop stops the server handler. +func (h *serverHandler) stop() { + close(h.closeCh) + h.wg.Wait() +} + +// runPeer is the p2p protocol run function for the given version. +func (h *serverHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error { + peer := newPeer(int(version), h.server.config.NetworkId, false, p, newMeteredMsgWriter(rw, int(version))) + h.wg.Add(1) + defer h.wg.Done() + return h.handle(peer) +} + +func (h *serverHandler) handle(p *peer) error { + // Reject light clients if server is not synced. + if !h.synced() { + return p2p.DiscRequested + } + p.Log().Debug("Light Ethereum peer connected", "name", p.Name()) + + // Execute the LES handshake + var ( + head = h.blockchain.CurrentHeader() + hash = head.Hash() + number = head.Number.Uint64() + td = h.blockchain.GetTd(hash, number) + ) + if err := p.Handshake(td, hash, number, h.blockchain.Genesis().Hash(), h.server); err != nil { + p.Log().Debug("Light Ethereum handshake failed", "err", err) + return err + } + defer p.fcClient.Disconnect() + + // Register the peer locally + if err := h.server.peers.Register(p); err != nil { + p.Log().Error("Light Ethereum peer registration failed", "err", err) + return err + } + clientConnectionGauge.Update(int64(h.server.peers.Len())) + + // add dummy balance tracker for tests + if p.balanceTracker == nil { + p.balanceTracker = &balanceTracker{} + p.balanceTracker.init(&mclock.System{}, 1) + } + + connectedAt := mclock.Now() + defer func() { + p.balanceTracker = nil + h.server.peers.Unregister(p.id) + clientConnectionGauge.Update(int64(h.server.peers.Len())) + connectionTimer.Update(time.Duration(mclock.Now() - connectedAt)) + }() + + // Spawn a main loop to handle all incoming messages. + for { + select { + case err := <-p.errCh: + p.Log().Debug("Failed to send light ethereum response", "err", err) + return err + default: + } + if err := h.handleMsg(p); err != nil { + p.Log().Debug("Light Ethereum message handling failed", "err", err) + return err + } + } +} + +// handleMsg is invoked whenever an inbound message is received from a remote +// peer. The remote connection is torn down upon returning any error. +func (h *serverHandler) handleMsg(p *peer) error { + // Read the next message from the remote peer, and ensure it's fully consumed + msg, err := p.rw.ReadMsg() + if err != nil { + return err + } + p.Log().Trace("Light Ethereum message arrived", "code", msg.Code, "bytes", msg.Size) + + // Discard large message which exceeds the limitation. + if msg.Size > ProtocolMaxMsgSize { + clientErrorMeter.Mark(1) + return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize) + } + defer msg.Discard() + + var ( + maxCost uint64 + task *servingTask + ) + p.responseCount++ + responseCount := p.responseCount + // accept returns an indicator whether the request can be served. + // If so, deduct the max cost from the flow control buffer. + accept := func(reqID, reqCnt, maxCnt uint64) bool { + // Short circuit if the peer is already frozen or the request is invalid. + inSizeCost := h.server.costTracker.realCost(0, msg.Size, 0) + if p.isFrozen() || reqCnt == 0 || reqCnt > maxCnt { + p.fcClient.OneTimeCost(inSizeCost) + return false + } + // Prepaid max cost units before request been serving. + maxCost = p.fcCosts.getMaxCost(msg.Code, reqCnt) + accepted, bufShort, priority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost) + if !accepted { + p.freezeClient() + p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge))) + p.fcClient.OneTimeCost(inSizeCost) + return false + } + // Create a multi-stage task, estimate the time it takes for the task to + // execute, and cache it in the request service queue. + factor := h.server.costTracker.globalFactor() + if factor < 0.001 { + factor = 1 + p.Log().Error("Invalid global cost factor", "factor", factor) + } + maxTime := uint64(float64(maxCost) / factor) + task = h.server.servingQueue.newTask(p, maxTime, priority) + if task.start() { + return true + } + p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost) + return false + } + // sendResponse sends back the response and updates the flow control statistic. + sendResponse := func(reqID, amount uint64, reply *reply, servingTime uint64) { + p.responseLock.Lock() + defer p.responseLock.Unlock() + + // Short circuit if the client is already frozen. + if p.isFrozen() { + realCost := h.server.costTracker.realCost(servingTime, msg.Size, 0) + p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost) + return + } + // Positive correction buffer value with real cost. + var replySize uint32 + if reply != nil { + replySize = reply.size() + } + var realCost uint64 + if h.server.costTracker.testing { + realCost = maxCost // Assign a fake cost for testing purpose + } else { + realCost = h.server.costTracker.realCost(servingTime, msg.Size, replySize) + } + bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost) + if amount != 0 { + // Feed cost tracker request serving statistic. + h.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost) + // Reduce priority "balance" for the specific peer. + p.balanceTracker.requestCost(realCost) + } + if reply != nil { + p.queueSend(func() { + if err := reply.send(bv); err != nil { + select { + case p.errCh <- err: + default: + } + } + }) + } + } + switch msg.Code { + case GetBlockHeadersMsg: + p.Log().Trace("Received block header request") + if metrics.EnabledExpensive { + miscInHeaderPacketsMeter.Mark(1) + miscInHeaderTrafficMeter.Mark(int64(msg.Size)) + } + var req struct { + ReqID uint64 + Query getBlockHeadersData + } + if err := msg.Decode(&req); err != nil { + clientErrorMeter.Mark(1) + return errResp(ErrDecode, "%v: %v", msg, err) + } + query := req.Query + if accept(req.ReqID, query.Amount, MaxHeaderFetch) { + go func() { + hashMode := query.Origin.Hash != (common.Hash{}) + first := true + maxNonCanonical := uint64(100) + + // Gather headers until the fetch or network limits is reached + var ( + bytes common.StorageSize + headers []*types.Header + unknown bool + ) + for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit { + if !first && !task.waitOrStop() { + sendResponse(req.ReqID, 0, nil, task.servingTime) + return + } + // Retrieve the next header satisfying the query + var origin *types.Header + if hashMode { + if first { + origin = h.blockchain.GetHeaderByHash(query.Origin.Hash) + if origin != nil { + query.Origin.Number = origin.Number.Uint64() + } + } else { + origin = h.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number) + } + } else { + origin = h.blockchain.GetHeaderByNumber(query.Origin.Number) + } + if origin == nil { + atomic.AddUint32(&p.invalidCount, 1) + break + } + headers = append(headers, origin) + bytes += estHeaderRlpSize + + // Advance to the next header of the query + switch { + case hashMode && query.Reverse: + // Hash based traversal towards the genesis block + ancestor := query.Skip + 1 + if ancestor == 0 { + unknown = true + } else { + query.Origin.Hash, query.Origin.Number = h.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical) + unknown = query.Origin.Hash == common.Hash{} + } + case hashMode && !query.Reverse: + // Hash based traversal towards the leaf block + var ( + current = origin.Number.Uint64() + next = current + query.Skip + 1 + ) + if next <= current { + infos, _ := json.MarshalIndent(p.Peer.Info(), "", " ") + p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos) + unknown = true + } else { + if header := h.blockchain.GetHeaderByNumber(next); header != nil { + nextHash := header.Hash() + expOldHash, _ := h.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical) + if expOldHash == query.Origin.Hash { + query.Origin.Hash, query.Origin.Number = nextHash, next + } else { + unknown = true + } + } else { + unknown = true + } + } + case query.Reverse: + // Number based traversal towards the genesis block + if query.Origin.Number >= query.Skip+1 { + query.Origin.Number -= query.Skip + 1 + } else { + unknown = true + } + + case !query.Reverse: + // Number based traversal towards the leaf block + query.Origin.Number += query.Skip + 1 + } + first = false + } + reply := p.ReplyBlockHeaders(req.ReqID, headers) + sendResponse(req.ReqID, query.Amount, p.ReplyBlockHeaders(req.ReqID, headers), task.done()) + if metrics.EnabledExpensive { + miscOutHeaderPacketsMeter.Mark(1) + miscOutHeaderTrafficMeter.Mark(int64(reply.size())) + } + }() + } + + case GetBlockBodiesMsg: + p.Log().Trace("Received block bodies request") + if metrics.EnabledExpensive { + miscInBodyPacketsMeter.Mark(1) + miscInBodyTrafficMeter.Mark(int64(msg.Size)) + } + var req struct { + ReqID uint64 + Hashes []common.Hash + } + if err := msg.Decode(&req); err != nil { + clientErrorMeter.Mark(1) + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + var ( + bytes int + bodies []rlp.RawValue + ) + reqCnt := len(req.Hashes) + if accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) { + go func() { + for i, hash := range req.Hashes { + if i != 0 && !task.waitOrStop() { + sendResponse(req.ReqID, 0, nil, task.servingTime) + return + } + if bytes >= softResponseLimit { + break + } + body := h.blockchain.GetBodyRLP(hash) + if body == nil { + atomic.AddUint32(&p.invalidCount, 1) + continue + } + bodies = append(bodies, body) + bytes += len(body) + } + reply := p.ReplyBlockBodiesRLP(req.ReqID, bodies) + sendResponse(req.ReqID, uint64(reqCnt), reply, task.done()) + if metrics.EnabledExpensive { + miscOutBodyPacketsMeter.Mark(1) + miscOutBodyTrafficMeter.Mark(int64(reply.size())) + } + }() + } + + case GetCodeMsg: + p.Log().Trace("Received code request") + if metrics.EnabledExpensive { + miscInCodePacketsMeter.Mark(1) + miscInCodeTrafficMeter.Mark(int64(msg.Size)) + } + var req struct { + ReqID uint64 + Reqs []CodeReq + } + if err := msg.Decode(&req); err != nil { + clientErrorMeter.Mark(1) + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + var ( + bytes int + data [][]byte + ) + reqCnt := len(req.Reqs) + if accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) { + go func() { + for i, request := range req.Reqs { + if i != 0 && !task.waitOrStop() { + sendResponse(req.ReqID, 0, nil, task.servingTime) + return + } + // Look up the root hash belonging to the request + header := h.blockchain.GetHeaderByHash(request.BHash) + if header == nil { + p.Log().Warn("Failed to retrieve associate header for code", "hash", request.BHash) + atomic.AddUint32(&p.invalidCount, 1) + continue + } + // Refuse to search stale state data in the database since looking for + // a non-exist key is kind of expensive. + local := h.blockchain.CurrentHeader().Number.Uint64() + if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local { + p.Log().Debug("Reject stale code request", "number", header.Number.Uint64(), "head", local) + atomic.AddUint32(&p.invalidCount, 1) + continue + } + triedb := h.blockchain.StateCache().TrieDB() + + account, err := h.getAccount(triedb, header.Root, common.BytesToHash(request.AccKey)) + if err != nil { + p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) + atomic.AddUint32(&p.invalidCount, 1) + continue + } + code, err := triedb.Node(common.BytesToHash(account.CodeHash)) + if err != nil { + p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err) + continue + } + // Accumulate the code and abort if enough data was retrieved + data = append(data, code) + if bytes += len(code); bytes >= softResponseLimit { + break + } + } + reply := p.ReplyCode(req.ReqID, data) + sendResponse(req.ReqID, uint64(reqCnt), reply, task.done()) + if metrics.EnabledExpensive { + miscOutCodePacketsMeter.Mark(1) + miscOutCodeTrafficMeter.Mark(int64(reply.size())) + } + }() + } + + case GetReceiptsMsg: + p.Log().Trace("Received receipts request") + if metrics.EnabledExpensive { + miscInReceiptPacketsMeter.Mark(1) + miscInReceiptTrafficMeter.Mark(int64(msg.Size)) + } + var req struct { + ReqID uint64 + Hashes []common.Hash + } + if err := msg.Decode(&req); err != nil { + clientErrorMeter.Mark(1) + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + var ( + bytes int + receipts []rlp.RawValue + ) + reqCnt := len(req.Hashes) + if accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) { + go func() { + for i, hash := range req.Hashes { + if i != 0 && !task.waitOrStop() { + sendResponse(req.ReqID, 0, nil, task.servingTime) + return + } + if bytes >= softResponseLimit { + break + } + // Retrieve the requested block's receipts, skipping if unknown to us + results := h.blockchain.GetReceiptsByHash(hash) + if results == nil { + if header := h.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { + atomic.AddUint32(&p.invalidCount, 1) + continue + } + } + // If known, encode and queue for response packet + if encoded, err := rlp.EncodeToBytes(results); err != nil { + log.Error("Failed to encode receipt", "err", err) + } else { + receipts = append(receipts, encoded) + bytes += len(encoded) + } + } + reply := p.ReplyReceiptsRLP(req.ReqID, receipts) + sendResponse(req.ReqID, uint64(reqCnt), reply, task.done()) + if metrics.EnabledExpensive { + miscOutReceiptPacketsMeter.Mark(1) + miscOutReceiptTrafficMeter.Mark(int64(reply.size())) + } + }() + } + + case GetProofsV2Msg: + p.Log().Trace("Received les/2 proofs request") + if metrics.EnabledExpensive { + miscInTrieProofPacketsMeter.Mark(1) + miscInTrieProofTrafficMeter.Mark(int64(msg.Size)) + } + var req struct { + ReqID uint64 + Reqs []ProofReq + } + if err := msg.Decode(&req); err != nil { + clientErrorMeter.Mark(1) + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + // Gather state data until the fetch or network limits is reached + var ( + lastBHash common.Hash + root common.Hash + ) + reqCnt := len(req.Reqs) + if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) { + go func() { + nodes := light.NewNodeSet() + + for i, request := range req.Reqs { + if i != 0 && !task.waitOrStop() { + sendResponse(req.ReqID, 0, nil, task.servingTime) + return + } + // Look up the root hash belonging to the request + var ( + number *uint64 + header *types.Header + trie state.Trie + ) + if request.BHash != lastBHash { + root, lastBHash = common.Hash{}, request.BHash + + if header = h.blockchain.GetHeaderByHash(request.BHash); header == nil { + p.Log().Warn("Failed to retrieve header for proof", "block", *number, "hash", request.BHash) + atomic.AddUint32(&p.invalidCount, 1) + continue + } + // Refuse to search stale state data in the database since looking for + // a non-exist key is kind of expensive. + local := h.blockchain.CurrentHeader().Number.Uint64() + if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local { + p.Log().Debug("Reject stale trie request", "number", header.Number.Uint64(), "head", local) + atomic.AddUint32(&p.invalidCount, 1) + continue + } + root = header.Root + } + // If a header lookup failed (non existent), ignore subsequent requests for the same header + if root == (common.Hash{}) { + atomic.AddUint32(&p.invalidCount, 1) + continue + } + // Open the account or storage trie for the request + statedb := h.blockchain.StateCache() + + switch len(request.AccKey) { + case 0: + // No account key specified, open an account trie + trie, err = statedb.OpenTrie(root) + if trie == nil || err != nil { + p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err) + continue + } + default: + // Account key specified, open a storage trie + account, err := h.getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey)) + if err != nil { + p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) + atomic.AddUint32(&p.invalidCount, 1) + continue + } + trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root) + if trie == nil || err != nil { + p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err) + continue + } + } + // Prove the user's request from the account or stroage trie + if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil { + p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err) + continue + } + if nodes.DataSize() >= softResponseLimit { + break + } + } + reply := p.ReplyProofsV2(req.ReqID, nodes.NodeList()) + sendResponse(req.ReqID, uint64(reqCnt), reply, task.done()) + if metrics.EnabledExpensive { + miscOutTrieProofPacketsMeter.Mark(1) + miscOutTrieProofTrafficMeter.Mark(int64(reply.size())) + } + }() + } + + case GetHelperTrieProofsMsg: + p.Log().Trace("Received helper trie proof request") + if metrics.EnabledExpensive { + miscInHelperTriePacketsMeter.Mark(1) + miscInHelperTrieTrafficMeter.Mark(int64(msg.Size)) + } + var req struct { + ReqID uint64 + Reqs []HelperTrieReq + } + if err := msg.Decode(&req); err != nil { + clientErrorMeter.Mark(1) + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + // Gather state data until the fetch or network limits is reached + var ( + auxBytes int + auxData [][]byte + ) + reqCnt := len(req.Reqs) + if accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) { + go func() { + var ( + lastIdx uint64 + lastType uint + root common.Hash + auxTrie *trie.Trie + ) + nodes := light.NewNodeSet() + for i, request := range req.Reqs { + if i != 0 && !task.waitOrStop() { + sendResponse(req.ReqID, 0, nil, task.servingTime) + return + } + if auxTrie == nil || request.Type != lastType || request.TrieIdx != lastIdx { + auxTrie, lastType, lastIdx = nil, request.Type, request.TrieIdx + + var prefix string + if root, prefix = h.getHelperTrie(request.Type, request.TrieIdx); root != (common.Hash{}) { + auxTrie, _ = trie.New(root, trie.NewDatabase(rawdb.NewTable(h.chainDb, prefix))) + } + } + if request.AuxReq == auxRoot { + var data []byte + if root != (common.Hash{}) { + data = root[:] + } + auxData = append(auxData, data) + auxBytes += len(data) + } else { + if auxTrie != nil { + auxTrie.Prove(request.Key, request.FromLevel, nodes) + } + if request.AuxReq != 0 { + data := h.getAuxiliaryHeaders(request) + auxData = append(auxData, data) + auxBytes += len(data) + } + } + if nodes.DataSize()+auxBytes >= softResponseLimit { + break + } + } + reply := p.ReplyHelperTrieProofs(req.ReqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData}) + sendResponse(req.ReqID, uint64(reqCnt), reply, task.done()) + if metrics.EnabledExpensive { + miscOutHelperTriePacketsMeter.Mark(1) + miscOutHelperTrieTrafficMeter.Mark(int64(reply.size())) + } + }() + } + + case SendTxV2Msg: + p.Log().Trace("Received new transactions") + if metrics.EnabledExpensive { + miscInTxsPacketsMeter.Mark(1) + miscInTxsTrafficMeter.Mark(int64(msg.Size)) + } + var req struct { + ReqID uint64 + Txs []*types.Transaction + } + if err := msg.Decode(&req); err != nil { + clientErrorMeter.Mark(1) + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + reqCnt := len(req.Txs) + if accept(req.ReqID, uint64(reqCnt), MaxTxSend) { + go func() { + stats := make([]light.TxStatus, len(req.Txs)) + for i, tx := range req.Txs { + if i != 0 && !task.waitOrStop() { + return + } + hash := tx.Hash() + stats[i] = h.txStatus(hash) + if stats[i].Status == core.TxStatusUnknown { + addFn := h.txpool.AddRemotes + // Add txs synchronously for testing purpose + if h.addTxsSync { + addFn = h.txpool.AddRemotesSync + } + if errs := addFn([]*types.Transaction{tx}); errs[0] != nil { + stats[i].Error = errs[0].Error() + continue + } + stats[i] = h.txStatus(hash) + } + } + reply := p.ReplyTxStatus(req.ReqID, stats) + sendResponse(req.ReqID, uint64(reqCnt), reply, task.done()) + if metrics.EnabledExpensive { + miscOutTxsPacketsMeter.Mark(1) + miscOutTxsTrafficMeter.Mark(int64(reply.size())) + } + }() + } + + case GetTxStatusMsg: + p.Log().Trace("Received transaction status query request") + if metrics.EnabledExpensive { + miscInTxStatusPacketsMeter.Mark(1) + miscInTxStatusTrafficMeter.Mark(int64(msg.Size)) + } + var req struct { + ReqID uint64 + Hashes []common.Hash + } + if err := msg.Decode(&req); err != nil { + clientErrorMeter.Mark(1) + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + reqCnt := len(req.Hashes) + if accept(req.ReqID, uint64(reqCnt), MaxTxStatus) { + go func() { + stats := make([]light.TxStatus, len(req.Hashes)) + for i, hash := range req.Hashes { + if i != 0 && !task.waitOrStop() { + sendResponse(req.ReqID, 0, nil, task.servingTime) + return + } + stats[i] = h.txStatus(hash) + } + reply := p.ReplyTxStatus(req.ReqID, stats) + sendResponse(req.ReqID, uint64(reqCnt), reply, task.done()) + if metrics.EnabledExpensive { + miscOutTxStatusPacketsMeter.Mark(1) + miscOutTxStatusTrafficMeter.Mark(int64(reply.size())) + } + }() + } + + default: + p.Log().Trace("Received invalid message", "code", msg.Code) + clientErrorMeter.Mark(1) + return errResp(ErrInvalidMsgCode, "%v", msg.Code) + } + // If the client has made too much invalid request(e.g. request a non-exist data), + // reject them to prevent SPAM attack. + if atomic.LoadUint32(&p.invalidCount) > maxRequestErrors { + clientErrorMeter.Mark(1) + return errTooManyInvalidRequest + } + return nil +} + +// getAccount retrieves an account from the state based on root. +func (h *serverHandler) getAccount(triedb *trie.Database, root, hash common.Hash) (state.Account, error) { + trie, err := trie.New(root, triedb) + if err != nil { + return state.Account{}, err + } + blob, err := trie.TryGet(hash[:]) + if err != nil { + return state.Account{}, err + } + var account state.Account + if err = rlp.DecodeBytes(blob, &account); err != nil { + return state.Account{}, err + } + return account, nil +} + +// getHelperTrie returns the post-processed trie root for the given trie ID and section index +func (h *serverHandler) getHelperTrie(typ uint, index uint64) (common.Hash, string) { + switch typ { + case htCanonical: + sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.ChtSize-1) + return light.GetChtRoot(h.chainDb, index, sectionHead), light.ChtTablePrefix + case htBloomBits: + sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.BloomTrieSize-1) + return light.GetBloomTrieRoot(h.chainDb, index, sectionHead), light.BloomTrieTablePrefix + } + return common.Hash{}, "" +} + +// getAuxiliaryHeaders returns requested auxiliary headers for the CHT request. +func (h *serverHandler) getAuxiliaryHeaders(req HelperTrieReq) []byte { + if req.Type == htCanonical && req.AuxReq == auxHeader && len(req.Key) == 8 { + blockNum := binary.BigEndian.Uint64(req.Key) + hash := rawdb.ReadCanonicalHash(h.chainDb, blockNum) + return rawdb.ReadHeaderRLP(h.chainDb, hash, blockNum) + } + return nil +} + +// txStatus returns the status of a specified transaction. +func (h *serverHandler) txStatus(hash common.Hash) light.TxStatus { + var stat light.TxStatus + // Looking the transaction in txpool first. + stat.Status = h.txpool.Status([]common.Hash{hash})[0] + + // If the transaction is unknown to the pool, try looking it up locally. + if stat.Status == core.TxStatusUnknown { + lookup := h.blockchain.GetTransactionLookup(hash) + if lookup != nil { + stat.Status = core.TxStatusIncluded + stat.Lookup = lookup + } + } + return stat +} + +// broadcastHeaders broadcasts new block information to all connected light +// clients. According to the agreement between client and server, server should +// only broadcast new announcement if the total difficulty is higher than the +// last one. Besides server will add the signature if client requires. +func (h *serverHandler) broadcastHeaders() { + defer h.wg.Done() + + headCh := make(chan core.ChainHeadEvent, 10) + headSub := h.blockchain.SubscribeChainHeadEvent(headCh) + defer headSub.Unsubscribe() + + var ( + lastHead *types.Header + lastTd = common.Big0 + ) + for { + select { + case ev := <-headCh: + peers := h.server.peers.AllPeers() + if len(peers) == 0 { + continue + } + header := ev.Block.Header() + hash, number := header.Hash(), header.Number.Uint64() + td := h.blockchain.GetTd(hash, number) + if td == nil || td.Cmp(lastTd) <= 0 { + continue + } + var reorg uint64 + if lastHead != nil { + reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(h.chainDb, header, lastHead).Number.Uint64() + } + lastHead, lastTd = header, td + + log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg) + var ( + signed bool + signedAnnounce announceData + ) + announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg} + for _, p := range peers { + p := p + switch p.announceType { + case announceTypeSimple: + p.queueSend(func() { p.SendAnnounce(announce) }) + case announceTypeSigned: + if !signed { + signedAnnounce = announce + signedAnnounce.sign(h.server.privateKey) + signed = true + } + p.queueSend(func() { p.SendAnnounce(signedAnnounce) }) + } + } + case <-h.closeCh: + return + } + } +} diff --git a/les/serverpool.go b/les/serverpool.go index 3e8cdee41..37621dc63 100644 --- a/les/serverpool.go +++ b/les/serverpool.go @@ -115,8 +115,6 @@ type serverPool struct { db ethdb.Database dbKey []byte server *p2p.Server - quit chan struct{} - wg *sync.WaitGroup connWg sync.WaitGroup topic discv5.Topic @@ -137,14 +135,15 @@ type serverPool struct { connCh chan *connReq disconnCh chan *disconnReq registerCh chan *registerReq + + closeCh chan struct{} + wg sync.WaitGroup } // newServerPool creates a new serverPool instance -func newServerPool(db ethdb.Database, quit chan struct{}, wg *sync.WaitGroup, trustedNodes []string) *serverPool { +func newServerPool(db ethdb.Database, ulcServers []string) *serverPool { pool := &serverPool{ db: db, - quit: quit, - wg: wg, entries: make(map[enode.ID]*poolEntry), timeout: make(chan *poolEntry, 1), adjustStats: make(chan poolStatAdjust, 100), @@ -152,10 +151,11 @@ func newServerPool(db ethdb.Database, quit chan struct{}, wg *sync.WaitGroup, tr connCh: make(chan *connReq), disconnCh: make(chan *disconnReq), registerCh: make(chan *registerReq), + closeCh: make(chan struct{}), knownSelect: newWeightedRandomSelect(), newSelect: newWeightedRandomSelect(), fastDiscover: true, - trustedNodes: parseTrustedNodes(trustedNodes), + trustedNodes: parseTrustedNodes(ulcServers), } pool.knownQueue = newPoolEntryQueue(maxKnownEntries, pool.removeEntry) @@ -167,7 +167,6 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) { pool.server = server pool.topic = topic pool.dbKey = append([]byte("serverPool/"), []byte(topic)...) - pool.wg.Add(1) pool.loadNodes() pool.connectToTrustedNodes() @@ -178,9 +177,15 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) { go pool.discoverNodes() } pool.checkDial() + pool.wg.Add(1) go pool.eventLoop() } +func (pool *serverPool) stop() { + close(pool.closeCh) + pool.wg.Wait() +} + // discoverNodes wraps SearchTopic, converting result nodes to enode.Node. func (pool *serverPool) discoverNodes() { ch := make(chan *discv5.Node) @@ -207,7 +212,7 @@ func (pool *serverPool) connect(p *peer, node *enode.Node) *poolEntry { req := &connReq{p: p, node: node, result: make(chan *poolEntry, 1)} select { case pool.connCh <- req: - case <-pool.quit: + case <-pool.closeCh: return nil } return <-req.result @@ -219,7 +224,7 @@ func (pool *serverPool) registered(entry *poolEntry) { req := ®isterReq{entry: entry, done: make(chan struct{})} select { case pool.registerCh <- req: - case <-pool.quit: + case <-pool.closeCh: return } <-req.done @@ -231,7 +236,7 @@ func (pool *serverPool) registered(entry *poolEntry) { func (pool *serverPool) disconnect(entry *poolEntry) { stopped := false select { - case <-pool.quit: + case <-pool.closeCh: stopped = true default: } @@ -278,6 +283,7 @@ func (pool *serverPool) adjustResponseTime(entry *poolEntry, time time.Duration, // eventLoop handles pool events and mutex locking for all internal functions func (pool *serverPool) eventLoop() { + defer pool.wg.Done() lookupCnt := 0 var convTime mclock.AbsTime if pool.discSetPeriod != nil { @@ -361,7 +367,7 @@ func (pool *serverPool) eventLoop() { case req := <-pool.connCh: if pool.trustedNodes[req.p.ID()] != nil { // ignore trusted nodes - req.result <- nil + req.result <- &poolEntry{trusted: true} } else { // Handle peer connection requests. entry := pool.entries[req.p.ID()] @@ -389,6 +395,9 @@ func (pool *serverPool) eventLoop() { } case req := <-pool.registerCh: + if req.entry.trusted { + continue + } // Handle peer registration requests. entry := req.entry entry.state = psRegistered @@ -402,10 +411,13 @@ func (pool *serverPool) eventLoop() { close(req.done) case req := <-pool.disconnCh: + if req.entry.trusted { + continue + } // Handle peer disconnection requests. disconnect(req, req.stopped) - case <-pool.quit: + case <-pool.closeCh: if pool.discSetPeriod != nil { close(pool.discSetPeriod) } @@ -421,7 +433,6 @@ func (pool *serverPool) eventLoop() { disconnect(req, true) } pool.saveNodes() - pool.wg.Done() return } } @@ -549,10 +560,10 @@ func (pool *serverPool) setRetryDial(entry *poolEntry) { entry.delayedRetry = true go func() { select { - case <-pool.quit: + case <-pool.closeCh: case <-time.After(delay): select { - case <-pool.quit: + case <-pool.closeCh: case pool.enableRetry <- entry: } } @@ -618,10 +629,10 @@ func (pool *serverPool) dial(entry *poolEntry, knownSelected bool) { go func() { pool.server.AddPeer(entry.node) select { - case <-pool.quit: + case <-pool.closeCh: case <-time.After(dialTimeout): select { - case <-pool.quit: + case <-pool.closeCh: case pool.timeout <- entry: } } @@ -662,14 +673,14 @@ type poolEntry struct { lastConnected, dialed *poolEntryAddress addrSelect weightedRandomSelect - lastDiscovered mclock.AbsTime - known, knownSelected bool - connectStats, delayStats poolStats - responseStats, timeoutStats poolStats - state int - regTime mclock.AbsTime - queueIdx int - removed bool + lastDiscovered mclock.AbsTime + known, knownSelected, trusted bool + connectStats, delayStats poolStats + responseStats, timeoutStats poolStats + state int + regTime mclock.AbsTime + queueIdx int + removed bool delayedRetry bool shortRetry int diff --git a/les/sync.go b/les/sync.go index 54fd81c2c..693394464 100644 --- a/les/sync.go +++ b/les/sync.go @@ -43,35 +43,6 @@ const ( checkpointSync ) -// syncer is responsible for periodically synchronising with the network, both -// downloading hashes and blocks as well as handling the announcement handler. -func (pm *ProtocolManager) syncer() { - // Start and ensure cleanup of sync mechanisms - //pm.fetcher.Start() - //defer pm.fetcher.Stop() - defer pm.downloader.Terminate() - - // Wait for different events to fire synchronisation operations - //forceSync := time.Tick(forceSyncCycle) - for { - select { - case <-pm.newPeerCh: - /* // Make sure we have peers to select from, then sync - if pm.peers.Len() < minDesiredPeerCount { - break - } - go pm.synchronise(pm.peers.BestPeer()) - */ - /*case <-forceSync: - // Force a sync even if not enough peers are present - go pm.synchronise(pm.peers.BestPeer()) - */ - case <-pm.noMorePeers: - return - } - } -} - // validateCheckpoint verifies the advertised checkpoint by peer is valid or not. // // Each network has several hard-coded checkpoint signer addresses. Only the @@ -80,22 +51,22 @@ func (pm *ProtocolManager) syncer() { // In addition to the checkpoint registered in the registrar contract, there are // several legacy hardcoded checkpoints in our codebase. These checkpoints are // also considered as valid. -func (pm *ProtocolManager) validateCheckpoint(peer *peer) error { +func (h *clientHandler) validateCheckpoint(peer *peer) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() // Fetch the block header corresponding to the checkpoint registration. cp := peer.checkpoint - header, err := light.GetUntrustedHeaderByNumber(ctx, pm.odr, peer.checkpointNumber, peer.id) + header, err := light.GetUntrustedHeaderByNumber(ctx, h.backend.odr, peer.checkpointNumber, peer.id) if err != nil { return err } // Fetch block logs associated with the block header. - logs, err := light.GetUntrustedBlockLogs(ctx, pm.odr, header) + logs, err := light.GetUntrustedBlockLogs(ctx, h.backend.odr, header) if err != nil { return err } - events := pm.reg.contract.LookupCheckpointEvents(logs, cp.SectionIndex, cp.Hash()) + events := h.backend.oracle.contract.LookupCheckpointEvents(logs, cp.SectionIndex, cp.Hash()) if len(events) == 0 { return errInvalidCheckpoint } @@ -107,7 +78,7 @@ func (pm *ProtocolManager) validateCheckpoint(peer *peer) error { for _, event := range events { signatures = append(signatures, append(event.R[:], append(event.S[:], event.V)...)) } - valid, signers := pm.reg.verifySigners(index, hash, signatures) + valid, signers := h.backend.oracle.verifySigners(index, hash, signatures) if !valid { return errInvalidCheckpoint } @@ -116,14 +87,14 @@ func (pm *ProtocolManager) validateCheckpoint(peer *peer) error { } // synchronise tries to sync up our local chain with a remote peer. -func (pm *ProtocolManager) synchronise(peer *peer) { +func (h *clientHandler) synchronise(peer *peer) { // Short circuit if the peer is nil. if peer == nil { return } // Make sure the peer's TD is higher than our own. - latest := pm.blockchain.CurrentHeader() - currentTd := rawdb.ReadTd(pm.chainDb, latest.Hash(), latest.Number.Uint64()) + latest := h.backend.blockchain.CurrentHeader() + currentTd := rawdb.ReadTd(h.backend.chainDb, latest.Hash(), latest.Number.Uint64()) if currentTd != nil && peer.headBlockInfo().Td.Cmp(currentTd) < 0 { return } @@ -140,8 +111,8 @@ func (pm *ProtocolManager) synchronise(peer *peer) { // => Use provided checkpoint var checkpoint = &peer.checkpoint var hardcoded bool - if pm.checkpoint != nil && pm.checkpoint.SectionIndex >= peer.checkpoint.SectionIndex { - checkpoint = pm.checkpoint // Use the hardcoded one. + if h.checkpoint != nil && h.checkpoint.SectionIndex >= peer.checkpoint.SectionIndex { + checkpoint = h.checkpoint // Use the hardcoded one. hardcoded = true } // Determine whether we should run checkpoint syncing or normal light syncing. @@ -157,34 +128,34 @@ func (pm *ProtocolManager) synchronise(peer *peer) { case checkpoint.Empty(): mode = lightSync log.Debug("Disable checkpoint syncing", "reason", "empty checkpoint") - case latest.Number.Uint64() >= (checkpoint.SectionIndex+1)*pm.iConfig.ChtSize-1: + case latest.Number.Uint64() >= (checkpoint.SectionIndex+1)*h.backend.iConfig.ChtSize-1: mode = lightSync log.Debug("Disable checkpoint syncing", "reason", "local chain beyond the checkpoint") case hardcoded: mode = legacyCheckpointSync log.Debug("Disable checkpoint syncing", "reason", "checkpoint is hardcoded") - case pm.reg == nil || !pm.reg.isRunning(): + case h.backend.oracle == nil || !h.backend.oracle.isRunning(): mode = legacyCheckpointSync log.Debug("Disable checkpoint syncing", "reason", "checkpoint syncing is not activated") } // Notify testing framework if syncing has completed(for testing purpose). defer func() { - if pm.reg != nil && pm.reg.syncDoneHook != nil { - pm.reg.syncDoneHook() + if h.backend.oracle != nil && h.backend.oracle.syncDoneHook != nil { + h.backend.oracle.syncDoneHook() } }() start := time.Now() if mode == checkpointSync || mode == legacyCheckpointSync { // Validate the advertised checkpoint if mode == legacyCheckpointSync { - checkpoint = pm.checkpoint + checkpoint = h.checkpoint } else if mode == checkpointSync { - if err := pm.validateCheckpoint(peer); err != nil { + if err := h.validateCheckpoint(peer); err != nil { log.Debug("Failed to validate checkpoint", "reason", err) - pm.removePeer(peer.id) + h.removePeer(peer.id) return } - pm.blockchain.(*light.LightChain).AddTrustedCheckpoint(checkpoint) + h.backend.blockchain.AddTrustedCheckpoint(checkpoint) } log.Debug("Checkpoint syncing start", "peer", peer.id, "checkpoint", checkpoint.SectionIndex) @@ -197,14 +168,14 @@ func (pm *ProtocolManager) synchronise(peer *peer) { // of the latest epoch covered by checkpoint. ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - if !checkpoint.Empty() && !pm.blockchain.(*light.LightChain).SyncCheckpoint(ctx, checkpoint) { + if !checkpoint.Empty() && !h.backend.blockchain.SyncCheckpoint(ctx, checkpoint) { log.Debug("Sync checkpoint failed") - pm.removePeer(peer.id) + h.removePeer(peer.id) return } } // Fetch the remaining block headers based on the current chain header. - if err := pm.downloader.Synchronise(peer.id, peer.Head(), peer.Td(), downloader.LightSync); err != nil { + if err := h.downloader.Synchronise(peer.id, peer.Head(), peer.Td(), downloader.LightSync); err != nil { log.Debug("Synchronise failed", "reason", err) return } diff --git a/les/sync_test.go b/les/sync_test.go index 3a75d6856..63833c1ab 100644 --- a/les/sync_test.go +++ b/les/sync_test.go @@ -57,7 +57,7 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { } } // Generate 512+4 blocks (totally 1 CHT sections) - server, client, tearDown := newClientServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, false) + server, client, tearDown := newClientServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, nil, 0, false, false) defer tearDown() expected := config.ChtSize + config.ChtConfirms @@ -74,8 +74,8 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { } if syncMode == 1 { // Register the assembled checkpoint as hardcoded one. - client.pm.checkpoint = cp - client.pm.blockchain.(*light.LightChain).AddTrustedCheckpoint(cp) + client.handler.checkpoint = cp + client.handler.backend.blockchain.AddTrustedCheckpoint(cp) } else { // Register the assembled checkpoint into oracle. header := server.backend.Blockchain().CurrentHeader() @@ -83,14 +83,14 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { data := append([]byte{0x19, 0x00}, append(registrarAddr.Bytes(), append([]byte{0, 0, 0, 0, 0, 0, 0, 0}, cp.Hash().Bytes()...)...)...) sig, _ := crypto.Sign(crypto.Keccak256(data), signerKey) sig[64] += 27 // Transform V from 0/1 to 27/28 according to the yellow paper - if _, err := server.pm.reg.contract.RegisterCheckpoint(bind.NewKeyedTransactor(signerKey), cp.SectionIndex, cp.Hash().Bytes(), new(big.Int).Sub(header.Number, big.NewInt(1)), header.ParentHash, [][]byte{sig}); err != nil { + if _, err := server.handler.server.oracle.contract.RegisterCheckpoint(bind.NewKeyedTransactor(signerKey), cp.SectionIndex, cp.Hash().Bytes(), new(big.Int).Sub(header.Number, big.NewInt(1)), header.ParentHash, [][]byte{sig}); err != nil { t.Error("register checkpoint failed", err) } server.backend.Commit() // Wait for the checkpoint registration for { - _, hash, _, err := server.pm.reg.contract.Contract().GetLatestCheckpoint(nil) + _, hash, _, err := server.handler.server.oracle.contract.Contract().GetLatestCheckpoint(nil) if err != nil || hash == [32]byte{} { time.Sleep(100 * time.Millisecond) continue @@ -102,8 +102,8 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { } done := make(chan error) - client.pm.reg.syncDoneHook = func() { - header := client.pm.blockchain.CurrentHeader() + client.handler.backend.oracle.syncDoneHook = func() { + header := client.handler.backend.blockchain.CurrentHeader() if header.Number.Uint64() == expected { done <- nil } else { @@ -112,7 +112,7 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { } // Create connected peer pair. - peer, err1, lPeer, err2 := newTestPeerPair("peer", protocol, server.pm, client.pm) + _, err1, _, err2 := newTestPeerPair("peer", protocol, server.handler, client.handler) select { case <-time.After(time.Millisecond * 100): case err := <-err1: @@ -120,7 +120,6 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { case err := <-err2: t.Fatalf("peer 2 handshake error: %v", err) } - server.rPeer, client.rPeer = peer, lPeer select { case err := <-done: diff --git a/les/helper_test.go b/les/test_helper.go similarity index 55% rename from les/helper_test.go rename to les/test_helper.go index d66cfa1a5..2efaa769f 100644 --- a/les/helper_test.go +++ b/les/test_helper.go @@ -23,7 +23,6 @@ import ( "context" "crypto/rand" "math/big" - "sync" "testing" "time" @@ -57,8 +56,8 @@ var ( userAddr1 = crypto.PubkeyToAddress(userKey1.PublicKey) userAddr2 = crypto.PubkeyToAddress(userKey2.PublicKey) - testContractCode = common.Hex2Bytes("606060405260cc8060106000396000f360606040526000357c01000000000000000000000000000000000000000000000000000000009004806360cd2685146041578063c16431b914606b57603f565b005b6055600480803590602001909190505060a9565b6040518082815260200191505060405180910390f35b60886004808035906020019091908035906020019091905050608a565b005b80600060005083606481101560025790900160005b50819055505b5050565b6000600060005082606481101560025790900160005b5054905060c7565b91905056") testContractAddr common.Address + testContractCode = common.Hex2Bytes("606060405260cc8060106000396000f360606040526000357c01000000000000000000000000000000000000000000000000000000009004806360cd2685146041578063c16431b914606b57603f565b005b6055600480803590602001909190505060a9565b6040518082815260200191505060405180910390f35b60886004808035906020019091908035906020019091905050608a565b005b80600060005083606481101560025790900160005b50819055505b5050565b6000600060005082606481101560025790900160005b5054905060c7565b91905056") testContractCodeDeployed = testContractCode[16:] testContractDeployed = uint64(2) @@ -77,8 +76,10 @@ var ( // The number of confirmations needed to generate a checkpoint(only used in test). processConfirms = big.NewInt(4) - // - testBufLimit = uint64(1000000) + // The token bucket buffer limit for testing purpose. + testBufLimit = uint64(1000000) + + // The buffer recharging speed for testing purpose. testBufRecharge = uint64(1000) ) @@ -97,8 +98,8 @@ contract test { } */ -// prepareTestchain pre-commits specified number customized blocks into chain. -func prepareTestchain(n int, backend *backends.SimulatedBackend) { +// prepare pre-commits specified number customized blocks into chain. +func prepare(n int, backend *backends.SimulatedBackend) { var ( ctx = context.Background() signer = types.HomesteadSigner{} @@ -164,51 +165,25 @@ func testIndexers(db ethdb.Database, odr light.OdrBackend, config *light.Indexer return indexers[:] } -// newTestProtocolManager creates a new protocol manager for testing purposes, -// with the given number of blocks already known, potential notification -// channels for different events and relative chain indexers array. -func newTestProtocolManager(lightSync bool, blocks int, odr *LesOdr, indexers []*core.ChainIndexer, peers *peerSet, db ethdb.Database, ulcServers []string, ulcFraction int, testCost uint64, clock mclock.Clock) (*ProtocolManager, *backends.SimulatedBackend, error) { +func newTestClientHandler(backend *backends.SimulatedBackend, odr *LesOdr, indexers []*core.ChainIndexer, db ethdb.Database, peers *peerSet, ulcServers []string, ulcFraction int) *clientHandler { var ( evmux = new(event.TypeMux) engine = ethash.NewFaker() gspec = core.Genesis{ - Config: params.AllEthashProtocolChanges, - Alloc: core.GenesisAlloc{bankAddr: {Balance: bankFunds}}, + Config: params.AllEthashProtocolChanges, + Alloc: core.GenesisAlloc{bankAddr: {Balance: bankFunds}}, + GasLimit: 100000000, } - pool txPool - chain BlockChain - exitCh = make(chan struct{}) + oracle *checkpointOracle ) - gspec.MustCommit(db) - if peers == nil { - peers = newPeerSet() - } - // create a simulation backend and pre-commit several customized block to the database. - simulation := backends.NewSimulatedBackendWithDatabase(db, gspec.Alloc, 100000000) - prepareTestchain(blocks, simulation) - - // initialize empty chain for light client or pre-committed chain for server. - if lightSync { - chain, _ = light.NewLightChain(odr, gspec.Config, engine, nil) - } else { - chain = simulation.Blockchain() - config := core.DefaultTxPoolConfig - config.Journal = "" - pool = core.NewTxPool(config, gspec.Config, simulation.Blockchain()) - } - - // Create contract registrar - indexConfig := light.TestServerIndexerConfig - if lightSync { - indexConfig = light.TestClientIndexerConfig - } - config := ¶ms.CheckpointOracleConfig{ - Address: crypto.CreateAddress(bankAddr, 0), - Signers: []common.Address{signerAddr}, - Threshold: 1, - } - var reg *checkpointOracle + genesis := gspec.MustCommit(db) + chain, _ := light.NewLightChain(odr, gspec.Config, engine, nil) if indexers != nil { + checkpointConfig := ¶ms.CheckpointOracleConfig{ + Address: crypto.CreateAddress(bankAddr, 0), + Signers: []common.Address{signerAddr}, + Threshold: 1, + } getLocal := func(index uint64) params.TrustedCheckpoint { chtIndexer := indexers[0] sectionHead := chtIndexer.SectionHead(index) @@ -219,72 +194,126 @@ func newTestProtocolManager(lightSync bool, blocks int, odr *LesOdr, indexers [] BloomRoot: light.GetBloomTrieRoot(db, index, sectionHead), } } - reg = newCheckpointOracle(config, getLocal) + oracle = newCheckpointOracle(checkpointConfig, getLocal) } - pm, err := NewProtocolManager(gspec.Config, nil, indexConfig, ulcServers, ulcFraction, lightSync, NetworkId, evmux, peers, chain, pool, db, odr, nil, reg, exitCh, new(sync.WaitGroup), func() bool { return true }) - if err != nil { - return nil, nil, err + client := &LightEthereum{ + lesCommons: lesCommons{ + genesis: genesis.Hash(), + config: ð.Config{LightPeers: 100, NetworkId: NetworkId}, + chainConfig: params.AllEthashProtocolChanges, + iConfig: light.TestClientIndexerConfig, + chainDb: db, + oracle: oracle, + chainReader: chain, + peers: peers, + closeCh: make(chan struct{}), + }, + reqDist: odr.retriever.dist, + retriever: odr.retriever, + odr: odr, + engine: engine, + blockchain: chain, + eventMux: evmux, } - // Registrar initialization could failed if checkpoint contract is not specified. - if pm.reg != nil { - pm.reg.start(simulation) - } - // Set up les server stuff. - if !lightSync { - srv := &LesServer{lesCommons: lesCommons{protocolManager: pm, chainDb: db}} - pm.server = srv - pm.servingQueue = newServingQueue(int64(time.Millisecond*10), 1) - pm.servingQueue.setThreads(4) + client.handler = newClientHandler(ulcServers, ulcFraction, nil, client) - srv.defParams = flowcontrol.ServerParams{ - BufLimit: testBufLimit, - MinRecharge: testBufRecharge, - } - srv.testCost = testCost - srv.fcManager = flowcontrol.NewClientManager(nil, clock) + if client.oracle != nil { + client.oracle.start(backend) } - pm.Start(1000) - return pm, simulation, nil + return client.handler } -// newTestProtocolManagerMust creates a new protocol manager for testing purposes, -// with the given number of blocks already known, potential notification channels -// for different events and relative chain indexers array. In case of an error, the -// constructor force-fails the test. -func newTestProtocolManagerMust(t *testing.T, lightSync bool, blocks int, odr *LesOdr, indexers []*core.ChainIndexer, peers *peerSet, db ethdb.Database, ulcServers []string, ulcFraction int) (*ProtocolManager, *backends.SimulatedBackend) { - pm, backend, err := newTestProtocolManager(lightSync, blocks, odr, indexers, peers, db, ulcServers, ulcFraction, 0, &mclock.System{}) - if err != nil { - t.Fatalf("Failed to create protocol manager: %v", err) +func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Database, peers *peerSet, clock mclock.Clock) (*serverHandler, *backends.SimulatedBackend) { + var ( + gspec = core.Genesis{ + Config: params.AllEthashProtocolChanges, + Alloc: core.GenesisAlloc{bankAddr: {Balance: bankFunds}}, + GasLimit: 100000000, + } + oracle *checkpointOracle + ) + genesis := gspec.MustCommit(db) + + // create a simulation backend and pre-commit several customized block to the database. + simulation := backends.NewSimulatedBackendWithDatabase(db, gspec.Alloc, 100000000) + prepare(blocks, simulation) + + txpoolConfig := core.DefaultTxPoolConfig + txpoolConfig.Journal = "" + txpool := core.NewTxPool(txpoolConfig, gspec.Config, simulation.Blockchain()) + if indexers != nil { + checkpointConfig := ¶ms.CheckpointOracleConfig{ + Address: crypto.CreateAddress(bankAddr, 0), + Signers: []common.Address{signerAddr}, + Threshold: 1, + } + getLocal := func(index uint64) params.TrustedCheckpoint { + chtIndexer := indexers[0] + sectionHead := chtIndexer.SectionHead(index) + return params.TrustedCheckpoint{ + SectionIndex: index, + SectionHead: sectionHead, + CHTRoot: light.GetChtRoot(db, index, sectionHead), + BloomRoot: light.GetBloomTrieRoot(db, index, sectionHead), + } + } + oracle = newCheckpointOracle(checkpointConfig, getLocal) } - return pm, backend + server := &LesServer{ + lesCommons: lesCommons{ + genesis: genesis.Hash(), + config: ð.Config{LightPeers: 100, NetworkId: NetworkId}, + chainConfig: params.AllEthashProtocolChanges, + iConfig: light.TestServerIndexerConfig, + chainDb: db, + chainReader: simulation.Blockchain(), + oracle: oracle, + peers: peers, + closeCh: make(chan struct{}), + }, + servingQueue: newServingQueue(int64(time.Millisecond*10), 1), + defParams: flowcontrol.ServerParams{ + BufLimit: testBufLimit, + MinRecharge: testBufRecharge, + }, + fcManager: flowcontrol.NewClientManager(nil, clock), + } + server.costTracker, server.freeCapacity = newCostTracker(db, server.config) + server.costTracker.testCostList = testCostList(0) // Disable flow control mechanism. + server.handler = newServerHandler(server, simulation.Blockchain(), db, txpool, func() bool { return true }) + if server.oracle != nil { + server.oracle.start(simulation) + } + server.servingQueue.setThreads(4) + server.handler.start() + return server.handler, simulation } // testPeer is a simulated peer to allow testing direct network calls. type testPeer struct { + peer *peer + net p2p.MsgReadWriter // Network layer reader/writer to simulate remote messaging app *p2p.MsgPipeRW // Application layer reader/writer to simulate the local side - *peer } // newTestPeer creates a new peer registered at the given protocol manager. -func newTestPeer(t *testing.T, name string, version int, pm *ProtocolManager, shake bool, testCost uint64) (*testPeer, <-chan error) { +func newTestPeer(t *testing.T, name string, version int, handler *serverHandler, shake bool, testCost uint64) (*testPeer, <-chan error) { // Create a message pipe to communicate through app, net := p2p.MsgPipe() // Generate a random id and create the peer var id enode.ID rand.Read(id[:]) - - peer := pm.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net) + peer := newPeer(version, NetworkId, false, p2p.NewPeer(id, name, nil), net) // Start the peer on a new thread - errc := make(chan error, 1) + errCh := make(chan error, 1) go func() { select { - case pm.newPeerCh <- peer: - errc <- pm.handle(peer) - case <-pm.quitSync: - errc <- p2p.DiscQuitting + case <-handler.closeCh: + errCh <- p2p.DiscQuitting + case errCh <- handler.handle(peer): } }() tp := &testPeer{ @@ -294,17 +323,27 @@ func newTestPeer(t *testing.T, name string, version int, pm *ProtocolManager, sh } // Execute any implicitly requested handshakes and return if shake { + // Customize the cost table if required. + if testCost != 0 { + handler.server.costTracker.testCostList = testCostList(testCost) + } var ( - genesis = pm.blockchain.Genesis() - head = pm.blockchain.CurrentHeader() - td = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64()) + genesis = handler.blockchain.Genesis() + head = handler.blockchain.CurrentHeader() + td = handler.blockchain.GetTd(head.Hash(), head.Number.Uint64()) ) - tp.handshake(t, td, head.Hash(), head.Number.Uint64(), genesis.Hash(), testCost) + tp.handshake(t, td, head.Hash(), head.Number.Uint64(), genesis.Hash(), testCostList(testCost)) } - return tp, errc + return tp, errCh } -func newTestPeerPair(name string, version int, pm, pm2 *ProtocolManager) (*peer, <-chan error, *peer, <-chan error) { +// close terminates the local side of the peer, notifying the remote protocol +// manager of termination. +func (p *testPeer) close() { + p.app.Close() +} + +func newTestPeerPair(name string, version int, server *serverHandler, client *clientHandler) (*testPeer, <-chan error, *testPeer, <-chan error) { // Create a message pipe to communicate through app, net := p2p.MsgPipe() @@ -312,36 +351,34 @@ func newTestPeerPair(name string, version int, pm, pm2 *ProtocolManager) (*peer, var id enode.ID rand.Read(id[:]) - peer := pm.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net) - peer2 := pm2.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), app) + peer1 := newPeer(version, NetworkId, false, p2p.NewPeer(id, name, nil), net) + peer2 := newPeer(version, NetworkId, false, p2p.NewPeer(id, name, nil), app) // Start the peer on a new thread - errc := make(chan error, 1) + errc1 := make(chan error, 1) errc2 := make(chan error, 1) go func() { select { - case pm.newPeerCh <- peer: - errc <- pm.handle(peer) - case <-pm.quitSync: - errc <- p2p.DiscQuitting + case <-server.closeCh: + errc1 <- p2p.DiscQuitting + case errc1 <- server.handle(peer1): } }() go func() { select { - case pm2.newPeerCh <- peer2: - errc2 <- pm2.handle(peer2) - case <-pm2.quitSync: - errc2 <- p2p.DiscQuitting + case <-client.closeCh: + errc1 <- p2p.DiscQuitting + case errc1 <- client.handle(peer2): } }() - return peer, errc, peer2, errc2 + return &testPeer{peer: peer1, net: net, app: app}, errc1, &testPeer{peer: peer2, net: app, app: net}, errc2 } // handshake simulates a trivial handshake that expects the same state from the // remote side as we are simulating locally. -func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, testCost uint64) { +func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, costList RequestCostList) { var expList keyValueList - expList = expList.add("protocolVersion", uint64(p.version)) + expList = expList.add("protocolVersion", uint64(p.peer.version)) expList = expList.add("networkId", uint64(NetworkId)) expList = expList.add("headTd", td) expList = expList.add("headHash", head) @@ -356,7 +393,7 @@ func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNu expList = expList.add("txRelay", nil) expList = expList.add("flowControl/BL", testBufLimit) expList = expList.add("flowControl/MRR", testBufRecharge) - expList = expList.add("flowControl/MRC", testCostList(testCost)) + expList = expList.add("flowControl/MRC", costList) if err := p2p.ExpectMsg(p.app, StatusMsg, expList); err != nil { t.Fatalf("status recv: %v", err) @@ -364,113 +401,119 @@ func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNu if err := p2p.Send(p.app, StatusMsg, sendList); err != nil { t.Fatalf("status send: %v", err) } - - p.fcParams = flowcontrol.ServerParams{ + p.peer.fcParams = flowcontrol.ServerParams{ BufLimit: testBufLimit, MinRecharge: testBufRecharge, } } -// close terminates the local side of the peer, notifying the remote protocol -// manager of termination. -func (p *testPeer) close() { - p.app.Close() -} +type indexerCallback func(*core.ChainIndexer, *core.ChainIndexer, *core.ChainIndexer) -// TestEntity represents a network entity for testing with necessary auxiliary fields. -type TestEntity struct { +// testClient represents a client for testing with necessary auxiliary fields. +type testClient struct { + clock mclock.Clock db ethdb.Database - rPeer *peer - tPeer *testPeer - peers *peerSet - pm *ProtocolManager - backend *backends.SimulatedBackend + peer *testPeer + handler *clientHandler - // Indexers chtIndexer *core.ChainIndexer bloomIndexer *core.ChainIndexer bloomTrieIndexer *core.ChainIndexer } -// newServerEnv creates a server testing environment with a connected test peer for testing purpose. -func newServerEnv(t *testing.T, blocks int, protocol int, waitIndexers func(*core.ChainIndexer, *core.ChainIndexer, *core.ChainIndexer)) (*TestEntity, func()) { +// testServer represents a server for testing with necessary auxiliary fields. +type testServer struct { + clock mclock.Clock + backend *backends.SimulatedBackend + db ethdb.Database + peer *testPeer + handler *serverHandler + + chtIndexer *core.ChainIndexer + bloomIndexer *core.ChainIndexer + bloomTrieIndexer *core.ChainIndexer +} + +func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallback, simClock bool, newPeer bool, testCost uint64) (*testServer, func()) { db := rawdb.NewMemoryDatabase() indexers := testIndexers(db, nil, light.TestServerIndexerConfig) - pm, b := newTestProtocolManagerMust(t, false, blocks, nil, indexers, nil, db, nil, 0) - peer, _ := newTestPeer(t, "peer", protocol, pm, true, 0) + var clock mclock.Clock = &mclock.System{} + if simClock { + clock = &mclock.Simulated{} + } + handler, b := newTestServerHandler(blocks, indexers, db, newPeerSet(), clock) + + var peer *testPeer + if newPeer { + peer, _ = newTestPeer(t, "peer", protocol, handler, true, testCost) + } cIndexer, bIndexer, btIndexer := indexers[0], indexers[1], indexers[2] - cIndexer.Start(pm.blockchain.(*core.BlockChain)) - bIndexer.Start(pm.blockchain.(*core.BlockChain)) + cIndexer.Start(handler.blockchain) + bIndexer.Start(handler.blockchain) // Wait until indexers generate enough index data. - if waitIndexers != nil { - waitIndexers(cIndexer, bIndexer, btIndexer) + if callback != nil { + callback(cIndexer, bIndexer, btIndexer) } - - return &TestEntity{ - db: db, - tPeer: peer, - pm: pm, - backend: b, - chtIndexer: cIndexer, - bloomIndexer: bIndexer, - bloomTrieIndexer: btIndexer, - }, func() { + server := &testServer{ + clock: clock, + backend: b, + db: db, + peer: peer, + handler: handler, + chtIndexer: cIndexer, + bloomIndexer: bIndexer, + bloomTrieIndexer: btIndexer, + } + teardown := func() { + if newPeer { peer.close() - // Note bloom trie indexer will be closed by it parent recursively. - cIndexer.Close() - bIndexer.Close() b.Close() } + cIndexer.Close() + bIndexer.Close() + } + return server, teardown } -// newClientServerEnv creates a client/server arch environment with a connected les server and light client pair -// for testing purpose. -func newClientServerEnv(t *testing.T, blocks int, protocol int, waitIndexers func(*core.ChainIndexer, *core.ChainIndexer, *core.ChainIndexer), newPeer bool) (*TestEntity, *TestEntity, func()) { - db, ldb := rawdb.NewMemoryDatabase(), rawdb.NewMemoryDatabase() - peers, lPeers := newPeerSet(), newPeerSet() +func newClientServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallback, ulcServers []string, ulcFraction int, simClock bool, connect bool) (*testServer, *testClient, func()) { + sdb, cdb := rawdb.NewMemoryDatabase(), rawdb.NewMemoryDatabase() + speers, cPeers := newPeerSet(), newPeerSet() - dist := newRequestDistributor(lPeers, make(chan struct{}), &mclock.System{}) - rm := newRetrieveManager(lPeers, dist, nil) - odr := NewLesOdr(ldb, light.TestClientIndexerConfig, rm) - - indexers := testIndexers(db, nil, light.TestServerIndexerConfig) - lIndexers := testIndexers(ldb, odr, light.TestClientIndexerConfig) - - cIndexer, bIndexer, btIndexer := indexers[0], indexers[1], indexers[2] - lcIndexer, lbIndexer, lbtIndexer := lIndexers[0], lIndexers[1], lIndexers[2] - - odr.SetIndexers(lcIndexer, lbtIndexer, lbIndexer) - - pm, b := newTestProtocolManagerMust(t, false, blocks, nil, indexers, peers, db, nil, 0) - lpm, lb := newTestProtocolManagerMust(t, true, 0, odr, lIndexers, lPeers, ldb, nil, 0) - - startIndexers := func(clientMode bool, pm *ProtocolManager) { - if clientMode { - lcIndexer.Start(pm.blockchain.(*light.LightChain)) - lbIndexer.Start(pm.blockchain.(*light.LightChain)) - } else { - cIndexer.Start(pm.blockchain.(*core.BlockChain)) - bIndexer.Start(pm.blockchain.(*core.BlockChain)) - } + var clock mclock.Clock = &mclock.System{} + if simClock { + clock = &mclock.Simulated{} } + dist := newRequestDistributor(cPeers, clock) + rm := newRetrieveManager(cPeers, dist, nil) + odr := NewLesOdr(cdb, light.TestClientIndexerConfig, rm) - startIndexers(false, pm) - startIndexers(true, lpm) + sindexers := testIndexers(sdb, nil, light.TestServerIndexerConfig) + cIndexers := testIndexers(cdb, odr, light.TestClientIndexerConfig) - // Execute wait until function if it is specified. - if waitIndexers != nil { - waitIndexers(cIndexer, bIndexer, btIndexer) + scIndexer, sbIndexer, sbtIndexer := sindexers[0], sindexers[1], sindexers[2] + ccIndexer, cbIndexer, cbtIndexer := cIndexers[0], cIndexers[1], cIndexers[2] + odr.SetIndexers(ccIndexer, cbIndexer, cbtIndexer) + + server, b := newTestServerHandler(blocks, sindexers, sdb, speers, clock) + client := newTestClientHandler(b, odr, cIndexers, cdb, cPeers, ulcServers, ulcFraction) + + scIndexer.Start(server.blockchain) + sbIndexer.Start(server.blockchain) + ccIndexer.Start(client.backend.blockchain) + cbIndexer.Start(client.backend.blockchain) + + if callback != nil { + callback(scIndexer, sbIndexer, sbtIndexer) } - var ( - peer, lPeer *peer - err1, err2 <-chan error + speer, cpeer *testPeer + err1, err2 <-chan error ) - if newPeer { - peer, err1, lPeer, err2 = newTestPeerPair("peer", protocol, pm, lpm) + if connect { + cpeer, err1, speer, err2 = newTestPeerPair("peer", protocol, server, client) select { case <-time.After(time.Millisecond * 100): case err := <-err1: @@ -479,32 +522,35 @@ func newClientServerEnv(t *testing.T, blocks int, protocol int, waitIndexers fun t.Fatalf("peer 2 handshake error: %v", err) } } - - return &TestEntity{ - db: db, - pm: pm, - rPeer: peer, - peers: peers, - backend: b, - chtIndexer: cIndexer, - bloomIndexer: bIndexer, - bloomTrieIndexer: btIndexer, - }, &TestEntity{ - db: ldb, - pm: lpm, - rPeer: lPeer, - peers: lPeers, - backend: lb, - chtIndexer: lcIndexer, - bloomIndexer: lbIndexer, - bloomTrieIndexer: lbtIndexer, - }, func() { - // Note bloom trie indexers will be closed by their parents recursively. - cIndexer.Close() - bIndexer.Close() - lcIndexer.Close() - lbIndexer.Close() - b.Close() - lb.Close() + s := &testServer{ + clock: clock, + backend: b, + db: sdb, + peer: cpeer, + handler: server, + chtIndexer: scIndexer, + bloomIndexer: sbIndexer, + bloomTrieIndexer: sbtIndexer, + } + c := &testClient{ + clock: clock, + db: cdb, + peer: speer, + handler: client, + chtIndexer: ccIndexer, + bloomIndexer: cbIndexer, + bloomTrieIndexer: cbtIndexer, + } + teardown := func() { + if connect { + speer.close() + cpeer.close() } + ccIndexer.Close() + cbIndexer.Close() + scIndexer.Close() + sbIndexer.Close() + b.Close() + } + return s, c, teardown } diff --git a/les/ulc_test.go b/les/ulc_test.go index 7e9f0b6db..9112bf928 100644 --- a/les/ulc_test.go +++ b/les/ulc_test.go @@ -17,151 +17,100 @@ package les import ( - "crypto/ecdsa" + "crypto/rand" "fmt" - "math/big" "net" - "reflect" "testing" "time" - "github.com/ethereum/go-ethereum/common/mclock" - "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" ) -func TestULCSyncWithOnePeer(t *testing.T) { - f := newFullPeerPair(t, 1, 4) - l := newLightPeer(t, []string{f.Node.String()}, 100) +func TestULCAnnounceThresholdLes2(t *testing.T) { testULCAnnounceThreshold(t, 2) } +func TestULCAnnounceThresholdLes3(t *testing.T) { testULCAnnounceThreshold(t, 3) } - if reflect.DeepEqual(f.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Hash()) { - t.Fatal("blocks are equal") +func testULCAnnounceThreshold(t *testing.T, protocol int) { + // todo figure out why it takes fetcher so longer to fetcher the announced header. + t.Skip("Sometimes it can failed") + var cases = []struct { + height []int + threshold int + expect uint64 + }{ + {[]int{1}, 100, 1}, + {[]int{0, 0, 0}, 100, 0}, + {[]int{1, 2, 3}, 30, 3}, + {[]int{1, 2, 3}, 60, 2}, + {[]int{3, 2, 1}, 67, 1}, + {[]int{3, 2, 1}, 100, 1}, } - _, _, err := connectPeers(f, l, 2) - if err != nil { - t.Fatal(err) - } - l.PM.fetcher.lock.Lock() - l.PM.fetcher.nextRequest() - l.PM.fetcher.lock.Unlock() + for _, testcase := range cases { + var ( + servers []*testServer + teardowns []func() + nodes []*enode.Node + ids []string + ) + for i := 0; i < len(testcase.height); i++ { + s, n, teardown := newServerPeer(t, 0, protocol) - if !reflect.DeepEqual(f.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Hash()) { - t.Fatal("sync doesn't work") + servers = append(servers, s) + nodes = append(nodes, n) + teardowns = append(teardowns, teardown) + ids = append(ids, n.String()) + } + c, teardown := newLightPeer(t, protocol, ids, testcase.threshold) + + // Connect all servers. + for i := 0; i < len(servers); i++ { + connect(servers[i].handler, nodes[i].ID(), c.handler, protocol) + } + for i := 0; i < len(servers); i++ { + for j := 0; j < testcase.height[i]; j++ { + servers[i].backend.Commit() + } + } + time.Sleep(1500 * time.Millisecond) // Ensure the fetcher has done its work. + head := c.handler.backend.blockchain.CurrentHeader().Number.Uint64() + if head != testcase.expect { + t.Fatalf("chain height mismatch, want %d, got %d", testcase.expect, head) + } + + // Release all servers and client resources. + teardown() + for i := 0; i < len(teardowns); i++ { + teardowns[i]() + } } } -func TestULCReceiveAnnounce(t *testing.T) { - f := newFullPeerPair(t, 1, 4) - l := newLightPeer(t, []string{f.Node.String()}, 100) - fPeer, lPeer, err := connectPeers(f, l, 2) - if err != nil { - t.Fatal(err) - } - l.PM.synchronise(fPeer) - - //check that the sync is finished correctly - if !reflect.DeepEqual(f.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Hash()) { - t.Fatal("sync doesn't work") - } - l.PM.peers.lock.Lock() - if len(l.PM.peers.peers) == 0 { - t.Fatal("peer list should not be empty") - } - l.PM.peers.lock.Unlock() - - time.Sleep(time.Second) - //send a signed announce message(payload doesn't matter) - td := f.PM.blockchain.GetTd(l.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Number.Uint64()) - announce := announceData{ - Number: l.PM.blockchain.CurrentHeader().Number.Uint64() + 1, - Td: td.Add(td, big.NewInt(1)), - } - announce.sign(f.Key) - lPeer.SendAnnounce(announce) -} - -func TestULCShouldNotSyncWithTwoPeersOneHaveEmptyChain(t *testing.T) { - f1 := newFullPeerPair(t, 1, 4) - f2 := newFullPeerPair(t, 2, 0) - l := newLightPeer(t, []string{f1.Node.String(), f2.Node.String()}, 100) - _, _, err := connectPeers(f1, l, 2) - if err != nil { - t.Fatal(err) - } - _, _, err = connectPeers(f2, l, 2) - if err != nil { - t.Fatal(err) - } - l.PM.fetcher.lock.Lock() - l.PM.fetcher.nextRequest() - l.PM.fetcher.lock.Unlock() - - if reflect.DeepEqual(f2.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Hash()) { - t.Fatal("Incorrect hash: second peer has empty chain") - } -} - -func TestULCShouldNotSyncWithThreePeersOneHaveEmptyChain(t *testing.T) { - f1 := newFullPeerPair(t, 1, 3) - f2 := newFullPeerPair(t, 2, 4) - f3 := newFullPeerPair(t, 3, 0) - - l := newLightPeer(t, []string{f1.Node.String(), f2.Node.String(), f3.Node.String()}, 60) - _, _, err := connectPeers(f1, l, 2) - if err != nil { - t.Fatal(err) - } - _, _, err = connectPeers(f2, l, 2) - if err != nil { - t.Fatal(err) - } - _, _, err = connectPeers(f3, l, 2) - if err != nil { - t.Fatal(err) - } - l.PM.fetcher.lock.Lock() - l.PM.fetcher.nextRequest() - l.PM.fetcher.lock.Unlock() - - if !reflect.DeepEqual(f1.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Hash()) { - t.Fatal("Incorrect hash") - } -} - -type pairPeer struct { - Name string - Node *enode.Node - PM *ProtocolManager - Key *ecdsa.PrivateKey -} - -func connectPeers(full, light pairPeer, version int) (*peer, *peer, error) { +func connect(server *serverHandler, serverId enode.ID, client *clientHandler, protocol int) (*peer, *peer, error) { // Create a message pipe to communicate through app, net := p2p.MsgPipe() - peerLight := full.PM.newPeer(version, NetworkId, p2p.NewPeer(light.Node.ID(), light.Name, nil), net) - peerFull := light.PM.newPeer(version, NetworkId, p2p.NewPeer(full.Node.ID(), full.Name, nil), app) + var id enode.ID + rand.Read(id[:]) + + peer1 := newPeer(protocol, NetworkId, true, p2p.NewPeer(serverId, "", nil), net) // Mark server as trusted + peer2 := newPeer(protocol, NetworkId, false, p2p.NewPeer(id, "", nil), app) // Start the peerLight on a new thread errc1 := make(chan error, 1) errc2 := make(chan error, 1) go func() { select { - case light.PM.newPeerCh <- peerFull: - errc1 <- light.PM.handle(peerFull) - case <-light.PM.quitSync: + case <-server.closeCh: errc1 <- p2p.DiscQuitting + case errc1 <- server.handle(peer2): } }() go func() { select { - case full.PM.newPeerCh <- peerLight: - errc2 <- full.PM.handle(peerLight) - case <-full.PM.quitSync: - errc2 <- p2p.DiscQuitting + case <-client.closeCh: + errc1 <- p2p.DiscQuitting + case errc1 <- client.handle(peer1): } }() @@ -172,48 +121,23 @@ func connectPeers(full, light pairPeer, version int) (*peer, *peer, error) { case err := <-errc2: return nil, nil, fmt.Errorf("peerFull handshake error: %v", err) } - - return peerFull, peerLight, nil + return peer1, peer2, nil } -// newFullPeerPair creates node with full sync mode -func newFullPeerPair(t *testing.T, index int, numberOfblocks int) pairPeer { - db := rawdb.NewMemoryDatabase() - - pmFull, _ := newTestProtocolManagerMust(t, false, numberOfblocks, nil, nil, nil, db, nil, 0) - - peerPairFull := pairPeer{ - Name: "full node", - PM: pmFull, - } +// newServerPeer creates server peer. +func newServerPeer(t *testing.T, blocks int, protocol int) (*testServer, *enode.Node, func()) { + s, teardown := newServerEnv(t, blocks, protocol, nil, false, false, 0) key, err := crypto.GenerateKey() if err != nil { t.Fatal("generate key err:", err) } - peerPairFull.Key = key - peerPairFull.Node = enode.NewV4(&key.PublicKey, net.ParseIP("127.0.0.1"), 35000, 35000) - return peerPairFull + s.handler.server.privateKey = key + n := enode.NewV4(&key.PublicKey, net.ParseIP("127.0.0.1"), 35000, 35000) + return s, n, teardown } // newLightPeer creates node with light sync mode -func newLightPeer(t *testing.T, ulcServers []string, ulcFraction int) pairPeer { - peers := newPeerSet() - dist := newRequestDistributor(peers, make(chan struct{}), &mclock.System{}) - rm := newRetrieveManager(peers, dist, nil) - ldb := rawdb.NewMemoryDatabase() - - odr := NewLesOdr(ldb, light.DefaultClientIndexerConfig, rm) - - pmLight, _ := newTestProtocolManagerMust(t, true, 0, odr, nil, peers, ldb, ulcServers, ulcFraction) - peerPairLight := pairPeer{ - Name: "ulc node", - PM: pmLight, - } - key, err := crypto.GenerateKey() - if err != nil { - t.Fatal("generate key err:", err) - } - peerPairLight.Key = key - peerPairLight.Node = enode.NewV4(&key.PublicKey, net.IP{}, 35000, 35000) - return peerPairLight +func newLightPeer(t *testing.T, protocol int, ulcServers []string, ulcFraction int) (*testClient, func()) { + _, c, teardown := newClientServerEnv(t, 0, protocol, nil, ulcServers, ulcFraction, false, false) + return c, teardown } diff --git a/light/odr_util.go b/light/odr_util.go index 82e33bb78..2c820d40c 100644 --- a/light/odr_util.go +++ b/light/odr_util.go @@ -60,7 +60,7 @@ func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*typ } } if number >= chtCount*odr.IndexerConfig().ChtSize { - return nil, ErrNoTrustedCht + return nil, errNoTrustedCht } r := &ChtRequest{ChtRoot: GetChtRoot(db, chtCount-1, sectionHead), ChtNum: chtCount - 1, BlockNum: number, Config: odr.IndexerConfig()} if err := odr.Retrieve(ctx, r); err != nil { @@ -124,7 +124,7 @@ func GetBlock(ctx context.Context, odr OdrBackend, hash common.Hash, number uint // Retrieve the block header and body contents header := rawdb.ReadHeader(odr.Database(), hash, number) if header == nil { - return nil, ErrNoHeader + return nil, errNoHeader } body, err := GetBody(ctx, odr, hash, number) if err != nil { @@ -241,7 +241,7 @@ func GetBloomBits(ctx context.Context, odr OdrBackend, bitIdx uint, sectionIdxLi } else { // TODO(rjl493456442) Convert sectionIndex to BloomTrie relative index if sectionIdx >= bloomTrieCount { - return nil, ErrNoTrustedBloomTrie + return nil, errNoTrustedBloomTrie } reqList = append(reqList, sectionIdx) reqIdx = append(reqIdx, i) diff --git a/light/postprocess.go b/light/postprocess.go index bf632a449..083dcfceb 100644 --- a/light/postprocess.go +++ b/light/postprocess.go @@ -98,9 +98,9 @@ var ( ) var ( - ErrNoTrustedCht = errors.New("no trusted canonical hash trie") - ErrNoTrustedBloomTrie = errors.New("no trusted bloom trie") - ErrNoHeader = errors.New("header not found") + errNoTrustedCht = errors.New("no trusted canonical hash trie") + errNoTrustedBloomTrie = errors.New("no trusted bloom trie") + errNoHeader = errors.New("header not found") chtPrefix = []byte("chtRootV2-") // chtPrefix + chtNum (uint64 big endian) -> trie root hash ChtTablePrefix = "cht-" )