diff --git a/eth/fetcher/fetcher_test.go b/eth/fetcher/fetcher_test.go index 00fad3498..4c6a1bf6a 100644 --- a/eth/fetcher/fetcher_test.go +++ b/eth/fetcher/fetcher_test.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "errors" "math/big" + "sync" "sync/atomic" "testing" "time" @@ -67,15 +68,17 @@ func createBlocksFromHashes(hashes []common.Hash) map[common.Hash]*types.Block { type fetcherTester struct { fetcher *Fetcher - ownHashes []common.Hash // Hash chain belonging to the tester - ownBlocks map[common.Hash]*types.Block // Blocks belonging to the tester + hashes []common.Hash // Hash chain belonging to the tester + blocks map[common.Hash]*types.Block // Blocks belonging to the tester + + lock sync.RWMutex } // newTester creates a new fetcher test mocker. func newTester() *fetcherTester { tester := &fetcherTester{ - ownHashes: []common.Hash{knownHash}, - ownBlocks: map[common.Hash]*types.Block{knownHash: genesis}, + hashes: []common.Hash{knownHash}, + blocks: map[common.Hash]*types.Block{knownHash: genesis}, } tester.fetcher = New(tester.hasBlock, tester.importBlock, tester.chainHeight) tester.fetcher.Start() @@ -85,29 +88,38 @@ func newTester() *fetcherTester { // hasBlock checks if a block is pres ent in the testers canonical chain. func (f *fetcherTester) hasBlock(hash common.Hash) bool { - _, ok := f.ownBlocks[hash] + f.lock.RLock() + defer f.lock.RUnlock() + + _, ok := f.blocks[hash] return ok } // importBlock injects a new blocks into the simulated chain. func (f *fetcherTester) importBlock(peer string, block *types.Block) error { + f.lock.Lock() + defer f.lock.Unlock() + // Make sure the parent in known - if _, ok := f.ownBlocks[block.ParentHash()]; !ok { + if _, ok := f.blocks[block.ParentHash()]; !ok { return errors.New("unknown parent") } // Discard any new blocks if the same height already exists - if block.NumberU64() <= f.ownBlocks[f.ownHashes[len(f.ownHashes)-1]].NumberU64() { + if block.NumberU64() <= f.blocks[f.hashes[len(f.hashes)-1]].NumberU64() { return nil } // Otherwise build our current chain - f.ownHashes = append(f.ownHashes, block.Hash()) - f.ownBlocks[block.Hash()] = block + f.hashes = append(f.hashes, block.Hash()) + f.blocks[block.Hash()] = block return nil } // chainHeight retrieves the current height (block number) of the chain. func (f *fetcherTester) chainHeight() uint64 { - return f.ownBlocks[f.ownHashes[len(f.ownHashes)-1]].NumberU64() + f.lock.RLock() + defer f.lock.RUnlock() + + return f.blocks[f.hashes[len(f.hashes)-1]].NumberU64() } // peerFetcher retrieves a fetcher associated with a simulated peer. @@ -149,7 +161,7 @@ func TestSequentialAnnouncements(t *testing.T) { tester.fetcher.Notify("valid", hashes[i], time.Now().Add(-arriveTimeout), fetcher) time.Sleep(50 * time.Millisecond) } - if imported := len(tester.ownBlocks); imported != targetBlocks+1 { + if imported := len(tester.blocks); imported != targetBlocks+1 { t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1) } } @@ -179,7 +191,7 @@ func TestConcurrentAnnouncements(t *testing.T) { time.Sleep(50 * time.Millisecond) } - if imported := len(tester.ownBlocks); imported != targetBlocks+1 { + if imported := len(tester.blocks); imported != targetBlocks+1 { t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1) } // Make sure no blocks were retrieved twice @@ -207,7 +219,7 @@ func TestOverlappingAnnouncements(t *testing.T) { } time.Sleep(overlap * delay) - if imported := len(tester.ownBlocks); imported != targetBlocks+1 { + if imported := len(tester.blocks); imported != targetBlocks+1 { t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1) } } @@ -242,7 +254,7 @@ func TestPendingDeduplication(t *testing.T) { time.Sleep(delay) // Check that all blocks were imported and none fetched twice - if imported := len(tester.ownBlocks); imported != 2 { + if imported := len(tester.blocks); imported != 2 { t.Fatalf("synchronised block mismatch: have %v, want %v", imported, 2) } if int(counter) != 1 { @@ -273,7 +285,7 @@ func TestRandomArrivalImport(t *testing.T) { tester.fetcher.Notify("valid", hashes[skip], time.Now().Add(-arriveTimeout), fetcher) time.Sleep(50 * time.Millisecond) - if imported := len(tester.ownBlocks); imported != targetBlocks+1 { + if imported := len(tester.blocks); imported != targetBlocks+1 { t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1) } } @@ -301,7 +313,7 @@ func TestQueueGapFill(t *testing.T) { tester.fetcher.Enqueue("valid", blocks[hashes[skip]]) time.Sleep(50 * time.Millisecond) - if imported := len(tester.ownBlocks); imported != targetBlocks+1 { + if imported := len(tester.blocks); imported != targetBlocks+1 { t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1) } } @@ -334,7 +346,7 @@ func TestImportDeduplication(t *testing.T) { tester.fetcher.Enqueue("valid", blocks[hashes[1]]) time.Sleep(50 * time.Millisecond) - if imported := len(tester.ownBlocks); imported != 3 { + if imported := len(tester.blocks); imported != 3 { t.Fatalf("synchronised block mismatch: have %v, want %v", imported, 3) } if counter != 2 { @@ -353,8 +365,8 @@ func TestDistantDiscarding(t *testing.T) { // Create a tester and simulate a head block being the middle of the above chain tester := newTester() - tester.ownHashes = []common.Hash{head} - tester.ownBlocks = map[common.Hash]*types.Block{head: blocks[head]} + tester.hashes = []common.Hash{head} + tester.blocks = map[common.Hash]*types.Block{head: blocks[head]} // Ensure that a block with a lower number than the threshold is discarded tester.fetcher.Enqueue("lower", blocks[hashes[0]]) @@ -413,10 +425,10 @@ func TestCompetingImports(t *testing.T) { tester.fetcher.Enqueue("chain C", blocksC[hashesC[len(hashesC)-2]]) start := time.Now() - for len(tester.ownHashes) != len(hashesA) && time.Since(start) < time.Second { + for len(tester.hashes) != len(hashesA) && time.Since(start) < time.Second { time.Sleep(50 * time.Millisecond) } - if len(tester.ownHashes) != len(hashesA) { - t.Fatalf("chain length mismatch: have %v, want %v", len(tester.ownHashes), len(hashesA)) + if len(tester.hashes) != len(hashesA) { + t.Fatalf("chain length mismatch: have %v, want %v", len(tester.hashes), len(hashesA)) } }