diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index 2807a4bcf..f9bd5a635 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -90,7 +90,6 @@ func New(mux *event.TypeMux, hasBlock hashCheckFn, getBlock getBlockFn) *Downloa mux: mux, queue: newQueue(), peers: newPeerSet(), - checks: make(map[common.Hash]time.Time), hasBlock: hasBlock, getBlock: getBlock, newPeerCh: make(chan *peer, 1), @@ -160,6 +159,7 @@ func (d *Downloader) Synchronise(id string, hash common.Hash) error { // Reset the queue and peer set to clean any internal leftover state d.queue.Reset() d.peers.Reset() + d.checks = make(map[common.Hash]time.Time) // Retrieve the origin peer and initiate the downloading process p := d.peers.Peer(id) @@ -184,7 +184,7 @@ func (d *Downloader) syncWithPeer(p *peer, hash common.Hash) (err error) { defer func() { // reset on error if err != nil { - d.queue.Reset() + d.Cancel() d.mux.Post(FailedEvent{err}) } else { d.mux.Post(DoneEvent{}) @@ -259,8 +259,6 @@ func (d *Downloader) fetchHashes(p *peer, h common.Hash) error { // Make sure the peer actually gave something valid if len(hashPack.hashes) == 0 { glog.V(logger.Debug).Infof("Peer (%s) responded with empty hash set\n", active.id) - d.queue.Reset() - return errEmptyHashSet } // Determine if we're done fetching hashes (queue up all pending), and continue if not done @@ -277,8 +275,6 @@ func (d *Downloader) fetchHashes(p *peer, h common.Hash) error { inserts := d.queue.Insert(hashPack.hashes) if len(inserts) == 0 && !done { glog.V(logger.Debug).Infof("Peer (%s) responded with stale hashes\n", active.id) - d.queue.Reset() - return ErrBadPeer } if !done { @@ -306,8 +302,13 @@ func (d *Downloader) fetchHashes(p *peer, h common.Hash) error { if blockPack.peerId != active.id || len(blockPack.blocks) != 1 { continue } - hash := blockPack.blocks[0].Hash() - delete(d.checks, hash) + block := blockPack.blocks[0] + if _, ok := d.checks[block.Hash()]; ok { + if !d.queue.Has(block.ParentHash()) { + return ErrCrossCheckFailed + } + delete(d.checks, block.Hash()) + } case <-crossTicker.C: // Iterate over all the cross checks and fail the hash chain if they're not verified @@ -334,7 +335,6 @@ func (d *Downloader) fetchHashes(p *peer, h common.Hash) error { // if all peers have been tried, abort the process entirely or if the hash is // the zero hash. if p == nil || (head == common.Hash{}) { - d.queue.Reset() return ErrTimeout } // set p to the active peer. this will invalidate any hashes that may be returned @@ -380,7 +380,6 @@ out: if err := d.queue.Deliver(blockPack.peerId, blockPack.blocks); err != nil { if err == ErrInvalidChain { // The hash chain is invalid (blocks are not ordered properly), abort - d.queue.Reset() return err } // Peer did deliver, but some blocks were off, penalize @@ -414,7 +413,6 @@ out: } // After removing bad peers make sure we actually have sufficient peer left to keep downloading if d.peers.Len() == 0 { - d.queue.Reset() return errNoPeers } // If there are unrequested hashes left start fetching @@ -448,7 +446,6 @@ out: // Make sure that we have peers available for fetching. If all peers have been tried // and all failed throw an error if d.queue.InFlight() == 0 { - d.queue.Reset() return errPeersUnavailable } diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index 60dcc06cd..d55664314 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -23,25 +23,26 @@ func createHashes(start, amount int) (hashes []common.Hash) { for i := range hashes[:len(hashes)-1] { binary.BigEndian.PutUint64(hashes[i][:8], uint64(i+2)) } - return } -func createBlock(i int, prevHash, hash common.Hash) *types.Block { +func createBlock(i int, parent, hash common.Hash) *types.Block { header := &types.Header{Number: big.NewInt(int64(i))} block := types.NewBlockWithHeader(header) block.HeaderHash = hash - block.ParentHeaderHash = prevHash + block.ParentHeaderHash = parent return block } func createBlocksFromHashes(hashes []common.Hash) map[common.Hash]*types.Block { blocks := make(map[common.Hash]*types.Block) - - for i, hash := range hashes { - blocks[hash] = createBlock(len(hashes)-i, knownHash, hash) + for i := 0; i < len(hashes); i++ { + parent := knownHash + if i < len(hashes)-1 { + parent = hashes[i+1] + } + blocks[hashes[i]] = createBlock(len(hashes)-i, parent, hashes[i]) } - return blocks } @@ -136,6 +137,7 @@ func (dl *downloadTester) getHashes(head common.Hash) error { hashes := make([]common.Hash, 0, maxHashFetch) for i, hash := range dl.hashes { if hash == head { + i++ for len(hashes) < cap(hashes) && i < len(dl.hashes) { hashes = append(hashes, dl.hashes[i]) i++ @@ -144,9 +146,11 @@ func (dl *downloadTester) getHashes(head common.Hash) error { } } // Delay delivery a bit to allow attacks to unfold - time.Sleep(time.Millisecond) - - dl.downloader.DeliverHashes(dl.activePeerId, hashes) + id := dl.activePeerId + go func() { + time.Sleep(time.Millisecond) + dl.downloader.DeliverHashes(id, hashes) + }() return nil } @@ -424,12 +428,15 @@ func TestInvalidHashOrderAttack(t *testing.T) { hashes := createHashes(0, 4*blockCacheLimit) blocks := createBlocksFromHashes(hashes) + chunk1 := make([]common.Hash, blockCacheLimit) + chunk2 := make([]common.Hash, blockCacheLimit) + copy(chunk1, hashes[blockCacheLimit:2*blockCacheLimit]) + copy(chunk2, hashes[2*blockCacheLimit:3*blockCacheLimit]) + reverse := make([]common.Hash, len(hashes)) copy(reverse, hashes) - - for i := len(hashes) / 4; i < 2*len(hashes)/4; i++ { - reverse[i], reverse[len(hashes)-i-1] = reverse[len(hashes)-i-1], reverse[i] - } + copy(reverse[2*blockCacheLimit:], chunk1) + copy(reverse[blockCacheLimit:], chunk2) // Try and sync with the malicious node and check that it fails tester := newTester(t, reverse, blocks) @@ -453,7 +460,6 @@ func TestMadeupHashChainAttack(t *testing.T) { // Create a long chain of hashes without backing blocks hashes := createHashes(0, 1024*blockCacheLimit) - hashes = hashes[:len(hashes)-1] // Try and sync with the malicious node and check that it fails tester := newTester(t, hashes, nil) @@ -462,3 +468,31 @@ func TestMadeupHashChainAttack(t *testing.T) { t.Fatalf("synchronisation error mismatch: have %v, want %v", err, ErrCrossCheckFailed) } } + +// Tests that if a malicious peer makes up a random block chain, and tried to +// push indefinitely, it actually gets caught with it. +func TestMadeupBlockChainAttack(t *testing.T) { + blockTTL = 100 * time.Millisecond + crossCheckCycle = 25 * time.Millisecond + + // Create a long chain of blocks and simulate an invalid chain by dropping every second + hashes := createHashes(0, 32*blockCacheLimit) + blocks := createBlocksFromHashes(hashes) + + gapped := make([]common.Hash, len(hashes)/2) + for i := 0; i < len(gapped); i++ { + gapped[i] = hashes[2*i] + } + // Try and sync with the malicious node and check that it fails + tester := newTester(t, gapped, blocks) + tester.newPeer("attack", big.NewInt(10000), gapped[0]) + if _, err := tester.syncTake("attack", gapped[0]); err != ErrCrossCheckFailed { + t.Fatalf("synchronisation error mismatch: have %v, want %v", err, ErrCrossCheckFailed) + } + // Ensure that a valid chain can still pass sync + tester.hashes = hashes + tester.newPeer("valid", big.NewInt(20000), hashes[0]) + if _, err := tester.syncTake("valid", hashes[0]); err != nil { + t.Fatalf("failed to synchronise blocks: %v", err) + } +}