diff --git a/les/handler.go b/les/handler.go index 4e98e0b32..c7bd23103 100644 --- a/les/handler.go +++ b/les/handler.go @@ -19,9 +19,11 @@ package les import ( "encoding/binary" "encoding/json" + "errors" "fmt" "math/big" "sync" + "sync/atomic" "time" "github.com/ethereum/go-ethereum/common" @@ -44,6 +46,8 @@ import ( "github.com/ethereum/go-ethereum/trie" ) +var errTooManyInvalidRequest = errors.New("too many invalid requests made") + const ( softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data. estHeaderRlpSize = 500 // Approximate size of an RLP encoded block header @@ -524,6 +528,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number) } if origin == nil { + atomic.AddUint32(&p.invalidCount, 1) break } headers = append(headers, origin) @@ -570,7 +575,6 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } else { unknown = true } - case !query.Reverse: // Number based traversal towards the leaf block query.Origin.Number += query.Skip + 1 @@ -628,15 +632,18 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { sendResponse(req.ReqID, 0, nil, task.servingTime) return } + // Retrieve the requested block body, stopping if enough was found if bytes >= softResponseLimit { break } - // Retrieve the requested block body, stopping if enough was found - if number := rawdb.ReadHeaderNumber(pm.chainDb, hash); number != nil { - if data := rawdb.ReadBodyRLP(pm.chainDb, hash, *number); len(data) != 0 { - bodies = append(bodies, data) - bytes += len(data) - } + number := rawdb.ReadHeaderNumber(pm.chainDb, hash) + if number == nil { + atomic.AddUint32(&p.invalidCount, 1) + continue + } + if data := rawdb.ReadBodyRLP(pm.chainDb, hash, *number); len(data) != 0 { + bodies = append(bodies, data) + bytes += len(data) } } sendResponse(req.ReqID, uint64(reqCnt), p.ReplyBlockBodiesRLP(req.ReqID, bodies), task.done()) @@ -691,6 +698,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { number := rawdb.ReadHeaderNumber(pm.chainDb, request.BHash) if number == nil { p.Log().Warn("Failed to retrieve block num for code", "hash", request.BHash) + atomic.AddUint32(&p.invalidCount, 1) continue } header := rawdb.ReadHeader(pm.chainDb, request.BHash, *number) @@ -703,6 +711,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { local := pm.blockchain.CurrentHeader().Number.Uint64() if !pm.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local { p.Log().Debug("Reject stale code request", "number", header.Number.Uint64(), "head", local) + atomic.AddUint32(&p.invalidCount, 1) continue } triedb := pm.blockchain.StateCache().TrieDB() @@ -710,6 +719,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { account, err := pm.getAccount(triedb, header.Root, common.BytesToHash(request.AccKey)) if err != nil { p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) + atomic.AddUint32(&p.invalidCount, 1) continue } code, err := triedb.Node(common.BytesToHash(account.CodeHash)) @@ -776,9 +786,12 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } // Retrieve the requested block's receipts, skipping if unknown to us var results types.Receipts - if number := rawdb.ReadHeaderNumber(pm.chainDb, hash); number != nil { - results = rawdb.ReadRawReceipts(pm.chainDb, hash, *number) + number := rawdb.ReadHeaderNumber(pm.chainDb, hash) + if number == nil { + atomic.AddUint32(&p.invalidCount, 1) + continue } + results = rawdb.ReadRawReceipts(pm.chainDb, hash, *number) if results == nil { if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { continue @@ -853,6 +866,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { if number = rawdb.ReadHeaderNumber(pm.chainDb, request.BHash); number == nil { p.Log().Warn("Failed to retrieve block num for proof", "hash", request.BHash) + atomic.AddUint32(&p.invalidCount, 1) continue } if header = rawdb.ReadHeader(pm.chainDb, request.BHash, *number); header == nil { @@ -864,12 +878,14 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { local := pm.blockchain.CurrentHeader().Number.Uint64() if !pm.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local { p.Log().Debug("Reject stale trie request", "number", header.Number.Uint64(), "head", local) + atomic.AddUint32(&p.invalidCount, 1) continue } root = header.Root } // If a header lookup failed (non existent), ignore subsequent requests for the same header if root == (common.Hash{}) { + atomic.AddUint32(&p.invalidCount, 1) continue } // Open the account or storage trie for the request @@ -888,6 +904,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { account, err := pm.getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey)) if err != nil { p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) + atomic.AddUint32(&p.invalidCount, 1) continue } trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root) @@ -1134,6 +1151,11 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } } } + // If the client has made too much invalid request(e.g. request a non-exist data), + // reject them to prevent SPAM attack. + if atomic.LoadUint32(&p.invalidCount) > maxRequestErrors { + return errTooManyInvalidRequest + } return nil } diff --git a/les/handler_test.go b/les/handler_test.go index 51f0a1a0e..c3608ebdd 100644 --- a/les/handler_test.go +++ b/les/handler_test.go @@ -597,9 +597,10 @@ func TestStopResumeLes3(t *testing.T) { expBuf := testBufLimit var reqID uint64 + header := pm.blockchain.CurrentHeader() req := func() { reqID++ - sendRequest(peer.app, GetBlockHeadersMsg, reqID, testCost, &getBlockHeadersData{Origin: hashOrNumber{Hash: common.Hash{1}}, Amount: 1}) + sendRequest(peer.app, GetBlockHeadersMsg, reqID, testCost, &getBlockHeadersData{Origin: hashOrNumber{Hash: header.Hash()}, Amount: 1}) } for i := 1; i <= 5; i++ { @@ -607,8 +608,8 @@ func TestStopResumeLes3(t *testing.T) { for expBuf >= testCost { req() expBuf -= testCost - if err := expectResponse(peer.app, BlockHeadersMsg, reqID, expBuf, nil); err != nil { - t.Errorf("expected response and failed: %v", err) + if err := expectResponse(peer.app, BlockHeadersMsg, reqID, expBuf, []*types.Header{header}); err != nil { + t.Fatalf("expected response and failed: %v", err) } } // send some more requests in excess and expect a single StopMsg diff --git a/les/peer.go b/les/peer.go index 56d316f50..a615c9b73 100644 --- a/les/peer.go +++ b/les/peer.go @@ -42,7 +42,10 @@ var ( errNotRegistered = errors.New("peer is not registered") ) -const maxResponseErrors = 50 // number of invalid responses tolerated (makes the protocol less brittle but still avoids spam) +const ( + maxRequestErrors = 20 // number of invalid requests tolerated (makes the protocol less brittle but still avoids spam) + maxResponseErrors = 50 // number of invalid responses tolerated (makes the protocol less brittle but still avoids spam) +) // capacity limitation for parameter updates const ( @@ -69,7 +72,6 @@ const ( type peer struct { *p2p.Peer - rw p2p.MsgReadWriter version int // Protocol version negotiated @@ -89,6 +91,7 @@ type peer struct { // RequestProcessed is called responseLock sync.Mutex responseCount uint64 + invalidCount uint32 poolEntry *poolEntry hasBlock func(common.Hash, uint64, bool) bool