diff --git a/core/blockchain.go b/core/blockchain.go index 4ca618c5b..9fa5b09f9 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -1543,8 +1543,16 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error { for _, tx := range types.TxDifference(deletedTxs, addedTxs) { rawdb.DeleteTxLookupEntry(batch, tx.Hash()) } + // Delete any canonical number assignments above the new head + number := bc.CurrentBlock().NumberU64() + for i := number + 1; ; i++ { + hash := rawdb.ReadCanonicalHash(bc.db, i) + if hash == (common.Hash{}) { + break + } + rawdb.DeleteCanonicalHash(batch, i) + } batch.Write() - // If any logs need to be fired, do it now. In theory we could avoid creating // this goroutine if there are no events to fire, but realistcally that only // ever happens if we're reorging empty blocks, which will only happen on idle diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 5ee1d9f8e..80a949d90 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -17,6 +17,7 @@ package core import ( + "fmt" "math/big" "math/rand" "sync" @@ -1810,3 +1811,123 @@ func TestPrunedImportSide(t *testing.T) { testSideImport(t, 1, 10) testSideImport(t, 1, -10) } + +// getLongAndShortChains returns two chains, +// A is longer, B is heavier +func getLongAndShortChains() (*BlockChain, []*types.Block, []*types.Block, error) { + // Generate a canonical chain to act as the main dataset + engine := ethash.NewFaker() + db := rawdb.NewMemoryDatabase() + genesis := new(Genesis).MustCommit(db) + + // Generate and import the canonical chain, + // Offset the time, to keep the difficulty low + longChain, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 80, func(i int, b *BlockGen) { + b.SetCoinbase(common.Address{1}) + }) + diskdb := rawdb.NewMemoryDatabase() + new(Genesis).MustCommit(diskdb) + + chain, err := NewBlockChain(diskdb, nil, params.TestChainConfig, engine, vm.Config{}, nil) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to create tester chain: %v", err) + } + + // Generate fork chain, make it shorter than canon, with common ancestor pretty early + parentIndex := 3 + parent := longChain[parentIndex] + heavyChain, _ := GenerateChain(params.TestChainConfig, parent, engine, db, 75, func(i int, b *BlockGen) { + b.SetCoinbase(common.Address{2}) + b.OffsetTime(-9) + }) + // Verify that the test is sane + var ( + longerTd = new(big.Int) + shorterTd = new(big.Int) + ) + for index, b := range longChain { + longerTd.Add(longerTd, b.Difficulty()) + if index <= parentIndex { + shorterTd.Add(shorterTd, b.Difficulty()) + } + } + for _, b := range heavyChain { + shorterTd.Add(shorterTd, b.Difficulty()) + } + if shorterTd.Cmp(longerTd) <= 0 { + return nil, nil, nil, fmt.Errorf("Test is moot, heavyChain td (%v) must be larger than canon td (%v)", shorterTd, longerTd) + } + longerNum := longChain[len(longChain)-1].NumberU64() + shorterNum := heavyChain[len(heavyChain)-1].NumberU64() + if shorterNum >= longerNum { + return nil, nil, nil, fmt.Errorf("Test is moot, heavyChain num (%v) must be lower than canon num (%v)", shorterNum, longerNum) + } + return chain, longChain, heavyChain, nil +} + +// TestReorgToShorterRemovesCanonMapping tests that if we +// 1. Have a chain [0 ... N .. X] +// 2. Reorg to shorter but heavier chain [0 ... N ... Y] +// 3. Then there should be no canon mapping for the block at height X +func TestReorgToShorterRemovesCanonMapping(t *testing.T) { + chain, canonblocks, sideblocks, err := getLongAndShortChains() + if err != nil { + t.Fatal(err) + } + if n, err := chain.InsertChain(canonblocks); err != nil { + t.Fatalf("block %d: failed to insert into chain: %v", n, err) + } + canonNum := chain.CurrentBlock().NumberU64() + _, err = chain.InsertChain(sideblocks) + if err != nil { + t.Errorf("Got error, %v", err) + } + head := chain.CurrentBlock() + if got := sideblocks[len(sideblocks)-1].Hash(); got != head.Hash() { + t.Fatalf("head wrong, expected %x got %x", head.Hash(), got) + } + // We have now inserted a sidechain. + if blockByNum := chain.GetBlockByNumber(canonNum); blockByNum != nil { + t.Errorf("expected block to be gone: %v", blockByNum.NumberU64()) + } + if headerByNum := chain.GetHeaderByNumber(canonNum); headerByNum != nil { + t.Errorf("expected header to be gone: %v", headerByNum.Number.Uint64()) + } +} + +// TestReorgToShorterRemovesCanonMappingHeaderChain is the same scenario +// as TestReorgToShorterRemovesCanonMapping, but applied on headerchain +// imports -- that is, for fast sync +func TestReorgToShorterRemovesCanonMappingHeaderChain(t *testing.T) { + chain, canonblocks, sideblocks, err := getLongAndShortChains() + if err != nil { + t.Fatal(err) + } + // Convert into headers + canonHeaders := make([]*types.Header, len(canonblocks)) + for i, block := range canonblocks { + canonHeaders[i] = block.Header() + } + if n, err := chain.InsertHeaderChain(canonHeaders, 0); err != nil { + t.Fatalf("header %d: failed to insert into chain: %v", n, err) + } + canonNum := chain.CurrentHeader().Number.Uint64() + sideHeaders := make([]*types.Header, len(sideblocks)) + for i, block := range sideblocks { + sideHeaders[i] = block.Header() + } + if n, err := chain.InsertHeaderChain(sideHeaders, 0); err != nil { + t.Fatalf("header %d: failed to insert into chain: %v", n, err) + } + head := chain.CurrentHeader() + if got := sideblocks[len(sideblocks)-1].Hash(); got != head.Hash() { + t.Fatalf("head wrong, expected %x got %x", head.Hash(), got) + } + // We have now inserted a sidechain. + if blockByNum := chain.GetBlockByNumber(canonNum); blockByNum != nil { + t.Errorf("expected block to be gone: %v", blockByNum.NumberU64()) + } + if headerByNum := chain.GetHeaderByNumber(canonNum); headerByNum != nil { + t.Errorf("expected header to be gone: %v", headerByNum.Number.Uint64()) + } +}