consensus/beacon: check that only the latest pow block is valid ttd block (#25187)

* consensus/beacon: check that only the latest pow block is valid ttd block

* consensus/beacon: move verification to async function

* consensus/beacon: fix verifyTerminalPoWBlock, add test cases

* consensus/beacon: cosmetic changes

* consensus/beacon: apply karalabe's fixes
This commit is contained in:
Marius van der Wijden 2022-06-29 14:13:19 +02:00 committed by GitHub
parent c2070f8d15
commit d12b1a91cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 182 additions and 2 deletions

View File

@ -112,10 +112,12 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [
break break
} }
} }
// All the headers have passed the transition point, use new rules. // All the headers have passed the transition point, use new rules.
if len(preHeaders) == 0 { if len(preHeaders) == 0 {
return beacon.verifyHeaders(chain, headers, nil) return beacon.verifyHeaders(chain, headers, nil)
} }
// The transition point exists in the middle, separate the headers // The transition point exists in the middle, separate the headers
// into two batches and apply different verification rules for them. // into two batches and apply different verification rules for them.
var ( var (
@ -130,6 +132,14 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [
oldDone, oldResult = beacon.ethone.VerifyHeaders(chain, preHeaders, preSeals) oldDone, oldResult = beacon.ethone.VerifyHeaders(chain, preHeaders, preSeals)
newDone, newResult = beacon.verifyHeaders(chain, postHeaders, preHeaders[len(preHeaders)-1]) newDone, newResult = beacon.verifyHeaders(chain, postHeaders, preHeaders[len(preHeaders)-1])
) )
// Verify that pre-merge headers don't overflow the TTD
if index, err := verifyTerminalPoWBlock(chain, preHeaders); err != nil {
// Mark all subsequent pow headers with the error.
for i := index; i < len(preHeaders); i++ {
errors[i], done[i] = err, true
}
}
// Collect the results
for { for {
for ; done[out]; out++ { for ; done[out]; out++ {
results <- errors[out] results <- errors[out]
@ -139,7 +149,9 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [
} }
select { select {
case err := <-oldResult: case err := <-oldResult:
if !done[old] { // skip TTD-verified failures
errors[old], done[old] = err, true errors[old], done[old] = err, true
}
old++ old++
case err := <-newResult: case err := <-newResult:
errors[new], done[new] = err, true errors[new], done[new] = err, true
@ -154,6 +166,32 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [
return abort, results return abort, results
} }
// verifyTerminalPoWBlock verifies that the preHeaders confirm to the specification
// wrt. their total difficulty.
// It expects:
// - preHeaders to be at least 1 element
// - the parent of the header element to be stored in the chain correctly
// - the preHeaders to have a set difficulty
// - the last element to be the terminal block
func verifyTerminalPoWBlock(chain consensus.ChainHeaderReader, preHeaders []*types.Header) (int, error) {
td := chain.GetTd(preHeaders[0].ParentHash, preHeaders[0].Number.Uint64()-1)
if td == nil {
return 0, consensus.ErrUnknownAncestor
}
// Check that all blocks before the last one are below the TTD
for i, head := range preHeaders {
if td.Cmp(chain.Config().TerminalTotalDifficulty) >= 0 {
return i, consensus.ErrInvalidTerminalBlock
}
td.Add(td, head.Difficulty)
}
// Check that the last block is the terminal block
if td.Cmp(chain.Config().TerminalTotalDifficulty) < 0 {
return len(preHeaders) - 1, consensus.ErrInvalidTerminalBlock
}
return 0, nil
}
// VerifyUncles verifies that the given block's uncles conform to the consensus // VerifyUncles verifies that the given block's uncles conform to the consensus
// rules of the Ethereum consensus engine. // rules of the Ethereum consensus engine.
func (beacon *Beacon) VerifyUncles(chain consensus.ChainReader, block *types.Block) error { func (beacon *Beacon) VerifyUncles(chain consensus.ChainReader, block *types.Block) error {

View File

@ -0,0 +1,137 @@
package beacon
import (
"fmt"
"math/big"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/params"
)
type mockChain struct {
config *params.ChainConfig
tds map[uint64]*big.Int
}
func newMockChain() *mockChain {
return &mockChain{
config: new(params.ChainConfig),
tds: make(map[uint64]*big.Int),
}
}
func (m *mockChain) Config() *params.ChainConfig {
return m.config
}
func (m *mockChain) CurrentHeader() *types.Header { panic("not implemented") }
func (m *mockChain) GetHeader(hash common.Hash, number uint64) *types.Header {
panic("not implemented")
}
func (m *mockChain) GetHeaderByNumber(number uint64) *types.Header { panic("not implemented") }
func (m *mockChain) GetHeaderByHash(hash common.Hash) *types.Header { panic("not implemented") }
func (m *mockChain) GetTd(hash common.Hash, number uint64) *big.Int {
num, ok := m.tds[number]
if ok {
return new(big.Int).Set(num)
}
return nil
}
func TestVerifyTerminalBlock(t *testing.T) {
chain := newMockChain()
chain.tds[0] = big.NewInt(10)
chain.config.TerminalTotalDifficulty = big.NewInt(50)
tests := []struct {
preHeaders []*types.Header
ttd *big.Int
err error
index int
}{
// valid ttd
{
preHeaders: []*types.Header{
{Number: big.NewInt(1), Difficulty: big.NewInt(10)},
{Number: big.NewInt(2), Difficulty: big.NewInt(10)},
{Number: big.NewInt(3), Difficulty: big.NewInt(10)},
{Number: big.NewInt(4), Difficulty: big.NewInt(10)},
},
ttd: big.NewInt(50),
},
// last block doesn't reach ttd
{
preHeaders: []*types.Header{
{Number: big.NewInt(1), Difficulty: big.NewInt(10)},
{Number: big.NewInt(2), Difficulty: big.NewInt(10)},
{Number: big.NewInt(3), Difficulty: big.NewInt(10)},
{Number: big.NewInt(4), Difficulty: big.NewInt(9)},
},
ttd: big.NewInt(50),
err: consensus.ErrInvalidTerminalBlock,
index: 3,
},
// two blocks reach ttd
{
preHeaders: []*types.Header{
{Number: big.NewInt(1), Difficulty: big.NewInt(10)},
{Number: big.NewInt(2), Difficulty: big.NewInt(10)},
{Number: big.NewInt(3), Difficulty: big.NewInt(20)},
{Number: big.NewInt(4), Difficulty: big.NewInt(10)},
},
ttd: big.NewInt(50),
err: consensus.ErrInvalidTerminalBlock,
index: 3,
},
// three blocks reach ttd
{
preHeaders: []*types.Header{
{Number: big.NewInt(1), Difficulty: big.NewInt(10)},
{Number: big.NewInt(2), Difficulty: big.NewInt(10)},
{Number: big.NewInt(3), Difficulty: big.NewInt(20)},
{Number: big.NewInt(4), Difficulty: big.NewInt(10)},
{Number: big.NewInt(4), Difficulty: big.NewInt(10)},
},
ttd: big.NewInt(50),
err: consensus.ErrInvalidTerminalBlock,
index: 3,
},
// parent reached ttd
{
preHeaders: []*types.Header{
{Number: big.NewInt(1), Difficulty: big.NewInt(10)},
},
ttd: big.NewInt(9),
err: consensus.ErrInvalidTerminalBlock,
index: 0,
},
// unknown parent
{
preHeaders: []*types.Header{
{Number: big.NewInt(4), Difficulty: big.NewInt(10)},
},
ttd: big.NewInt(9),
err: consensus.ErrUnknownAncestor,
index: 0,
},
}
for i, test := range tests {
fmt.Printf("Test: %v\n", i)
chain.config.TerminalTotalDifficulty = test.ttd
index, err := verifyTerminalPoWBlock(chain, test.preHeaders)
if err != test.err {
t.Fatalf("Invalid error encountered, expected %v got %v", test.err, err)
}
if index != test.index {
t.Fatalf("Invalid index, expected %v got %v", test.index, index)
}
}
}

View File

@ -34,4 +34,8 @@ var (
// ErrInvalidNumber is returned if a block's number doesn't equal its parent's // ErrInvalidNumber is returned if a block's number doesn't equal its parent's
// plus one. // plus one.
ErrInvalidNumber = errors.New("invalid block number") ErrInvalidNumber = errors.New("invalid block number")
// ErrInvalidTerminalBlock is returned if a block is invalid wrt. the terminal
// total difficulty.
ErrInvalidTerminalBlock = errors.New("invalid terminal block")
) )

View File

@ -108,6 +108,7 @@ func testHeaderVerificationForMerging(t *testing.T, isClique bool) {
addr: {Balance: big.NewInt(1)}, addr: {Balance: big.NewInt(1)},
}, },
BaseFee: big.NewInt(params.InitialBaseFee), BaseFee: big.NewInt(params.InitialBaseFee),
Difficulty: new(big.Int),
} }
copy(genspec.ExtraData[32:], addr[:]) copy(genspec.ExtraData[32:], addr[:])
genesis := genspec.MustCommit(testdb) genesis := genspec.MustCommit(testdb)