eth/fetcher: make tests thread safe

This commit is contained in:
Péter Szilágyi 2015-06-17 17:08:32 +03:00
parent 37c5ff392f
commit a9ada0b5ba

View File

@ -4,6 +4,7 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"math/big" "math/big"
"sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@ -67,15 +68,17 @@ func createBlocksFromHashes(hashes []common.Hash) map[common.Hash]*types.Block {
type fetcherTester struct { type fetcherTester struct {
fetcher *Fetcher fetcher *Fetcher
ownHashes []common.Hash // Hash chain belonging to the tester hashes []common.Hash // Hash chain belonging to the tester
ownBlocks map[common.Hash]*types.Block // Blocks belonging to the tester blocks map[common.Hash]*types.Block // Blocks belonging to the tester
lock sync.RWMutex
} }
// newTester creates a new fetcher test mocker. // newTester creates a new fetcher test mocker.
func newTester() *fetcherTester { func newTester() *fetcherTester {
tester := &fetcherTester{ tester := &fetcherTester{
ownHashes: []common.Hash{knownHash}, hashes: []common.Hash{knownHash},
ownBlocks: map[common.Hash]*types.Block{knownHash: genesis}, blocks: map[common.Hash]*types.Block{knownHash: genesis},
} }
tester.fetcher = New(tester.hasBlock, tester.importBlock, tester.chainHeight) tester.fetcher = New(tester.hasBlock, tester.importBlock, tester.chainHeight)
tester.fetcher.Start() tester.fetcher.Start()
@ -85,29 +88,38 @@ func newTester() *fetcherTester {
// hasBlock checks if a block is pres ent in the testers canonical chain. // hasBlock checks if a block is pres ent in the testers canonical chain.
func (f *fetcherTester) hasBlock(hash common.Hash) bool { func (f *fetcherTester) hasBlock(hash common.Hash) bool {
_, ok := f.ownBlocks[hash] f.lock.RLock()
defer f.lock.RUnlock()
_, ok := f.blocks[hash]
return ok return ok
} }
// importBlock injects a new blocks into the simulated chain. // importBlock injects a new blocks into the simulated chain.
func (f *fetcherTester) importBlock(peer string, block *types.Block) error { func (f *fetcherTester) importBlock(peer string, block *types.Block) error {
f.lock.Lock()
defer f.lock.Unlock()
// Make sure the parent in known // Make sure the parent in known
if _, ok := f.ownBlocks[block.ParentHash()]; !ok { if _, ok := f.blocks[block.ParentHash()]; !ok {
return errors.New("unknown parent") return errors.New("unknown parent")
} }
// Discard any new blocks if the same height already exists // Discard any new blocks if the same height already exists
if block.NumberU64() <= f.ownBlocks[f.ownHashes[len(f.ownHashes)-1]].NumberU64() { if block.NumberU64() <= f.blocks[f.hashes[len(f.hashes)-1]].NumberU64() {
return nil return nil
} }
// Otherwise build our current chain // Otherwise build our current chain
f.ownHashes = append(f.ownHashes, block.Hash()) f.hashes = append(f.hashes, block.Hash())
f.ownBlocks[block.Hash()] = block f.blocks[block.Hash()] = block
return nil return nil
} }
// chainHeight retrieves the current height (block number) of the chain. // chainHeight retrieves the current height (block number) of the chain.
func (f *fetcherTester) chainHeight() uint64 { func (f *fetcherTester) chainHeight() uint64 {
return f.ownBlocks[f.ownHashes[len(f.ownHashes)-1]].NumberU64() f.lock.RLock()
defer f.lock.RUnlock()
return f.blocks[f.hashes[len(f.hashes)-1]].NumberU64()
} }
// peerFetcher retrieves a fetcher associated with a simulated peer. // peerFetcher retrieves a fetcher associated with a simulated peer.
@ -149,7 +161,7 @@ func TestSequentialAnnouncements(t *testing.T) {
tester.fetcher.Notify("valid", hashes[i], time.Now().Add(-arriveTimeout), fetcher) tester.fetcher.Notify("valid", hashes[i], time.Now().Add(-arriveTimeout), fetcher)
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
} }
if imported := len(tester.ownBlocks); imported != targetBlocks+1 { if imported := len(tester.blocks); imported != targetBlocks+1 {
t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1) t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
} }
} }
@ -179,7 +191,7 @@ func TestConcurrentAnnouncements(t *testing.T) {
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
} }
if imported := len(tester.ownBlocks); imported != targetBlocks+1 { if imported := len(tester.blocks); imported != targetBlocks+1 {
t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1) t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
} }
// Make sure no blocks were retrieved twice // Make sure no blocks were retrieved twice
@ -207,7 +219,7 @@ func TestOverlappingAnnouncements(t *testing.T) {
} }
time.Sleep(overlap * delay) time.Sleep(overlap * delay)
if imported := len(tester.ownBlocks); imported != targetBlocks+1 { if imported := len(tester.blocks); imported != targetBlocks+1 {
t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1) t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
} }
} }
@ -242,7 +254,7 @@ func TestPendingDeduplication(t *testing.T) {
time.Sleep(delay) time.Sleep(delay)
// Check that all blocks were imported and none fetched twice // Check that all blocks were imported and none fetched twice
if imported := len(tester.ownBlocks); imported != 2 { if imported := len(tester.blocks); imported != 2 {
t.Fatalf("synchronised block mismatch: have %v, want %v", imported, 2) t.Fatalf("synchronised block mismatch: have %v, want %v", imported, 2)
} }
if int(counter) != 1 { if int(counter) != 1 {
@ -273,7 +285,7 @@ func TestRandomArrivalImport(t *testing.T) {
tester.fetcher.Notify("valid", hashes[skip], time.Now().Add(-arriveTimeout), fetcher) tester.fetcher.Notify("valid", hashes[skip], time.Now().Add(-arriveTimeout), fetcher)
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
if imported := len(tester.ownBlocks); imported != targetBlocks+1 { if imported := len(tester.blocks); imported != targetBlocks+1 {
t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1) t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
} }
} }
@ -301,7 +313,7 @@ func TestQueueGapFill(t *testing.T) {
tester.fetcher.Enqueue("valid", blocks[hashes[skip]]) tester.fetcher.Enqueue("valid", blocks[hashes[skip]])
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
if imported := len(tester.ownBlocks); imported != targetBlocks+1 { if imported := len(tester.blocks); imported != targetBlocks+1 {
t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1) t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
} }
} }
@ -334,7 +346,7 @@ func TestImportDeduplication(t *testing.T) {
tester.fetcher.Enqueue("valid", blocks[hashes[1]]) tester.fetcher.Enqueue("valid", blocks[hashes[1]])
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
if imported := len(tester.ownBlocks); imported != 3 { if imported := len(tester.blocks); imported != 3 {
t.Fatalf("synchronised block mismatch: have %v, want %v", imported, 3) t.Fatalf("synchronised block mismatch: have %v, want %v", imported, 3)
} }
if counter != 2 { if counter != 2 {
@ -353,8 +365,8 @@ func TestDistantDiscarding(t *testing.T) {
// Create a tester and simulate a head block being the middle of the above chain // Create a tester and simulate a head block being the middle of the above chain
tester := newTester() tester := newTester()
tester.ownHashes = []common.Hash{head} tester.hashes = []common.Hash{head}
tester.ownBlocks = map[common.Hash]*types.Block{head: blocks[head]} tester.blocks = map[common.Hash]*types.Block{head: blocks[head]}
// Ensure that a block with a lower number than the threshold is discarded // Ensure that a block with a lower number than the threshold is discarded
tester.fetcher.Enqueue("lower", blocks[hashes[0]]) tester.fetcher.Enqueue("lower", blocks[hashes[0]])
@ -413,10 +425,10 @@ func TestCompetingImports(t *testing.T) {
tester.fetcher.Enqueue("chain C", blocksC[hashesC[len(hashesC)-2]]) tester.fetcher.Enqueue("chain C", blocksC[hashesC[len(hashesC)-2]])
start := time.Now() start := time.Now()
for len(tester.ownHashes) != len(hashesA) && time.Since(start) < time.Second { for len(tester.hashes) != len(hashesA) && time.Since(start) < time.Second {
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
} }
if len(tester.ownHashes) != len(hashesA) { if len(tester.hashes) != len(hashesA) {
t.Fatalf("chain length mismatch: have %v, want %v", len(tester.ownHashes), len(hashesA)) t.Fatalf("chain length mismatch: have %v, want %v", len(tester.hashes), len(hashesA))
} }
} }