core/forkid: skip genesis forks by time (#28034)

* core/forkid: skip genesis forks by time

* core/forkid: add comment about skipping non-zero fork times

* core/forkid: skip all time based forks in genesis using loop

* core/forkid: simplify logic for dropping time-based forks
This commit is contained in:
lightclient 2023-09-04 07:32:14 -06:00 committed by GitHub
parent f260a9edb9
commit eff7c3bda0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 18 deletions

View File

@ -25,6 +25,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
@ -228,13 +229,13 @@ func ethFilter(args []string) (nodeFilter, error) {
var filter forkid.Filter var filter forkid.Filter
switch args[0] { switch args[0] {
case "mainnet": case "mainnet":
filter = forkid.NewStaticFilter(params.MainnetChainConfig, params.MainnetGenesisHash) filter = forkid.NewStaticFilter(params.MainnetChainConfig, core.DefaultGenesisBlock().ToBlock())
case "goerli": case "goerli":
filter = forkid.NewStaticFilter(params.GoerliChainConfig, params.GoerliGenesisHash) filter = forkid.NewStaticFilter(params.GoerliChainConfig, core.DefaultGoerliGenesisBlock().ToBlock())
case "sepolia": case "sepolia":
filter = forkid.NewStaticFilter(params.SepoliaChainConfig, params.SepoliaGenesisHash) filter = forkid.NewStaticFilter(params.SepoliaChainConfig, core.DefaultSepoliaGenesisBlock().ToBlock())
case "holesky": case "holesky":
filter = forkid.NewStaticFilter(params.HoleskyChainConfig, params.HoleskyGenesisHash) filter = forkid.NewStaticFilter(params.HoleskyChainConfig, core.DefaultHoleskyGenesisBlock().ToBlock())
default: default:
return nil, fmt.Errorf("unknown network %q", args[0]) return nil, fmt.Errorf("unknown network %q", args[0])
} }

View File

@ -26,7 +26,6 @@ import (
"reflect" "reflect"
"strings" "strings"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
@ -78,7 +77,7 @@ func NewID(config *params.ChainConfig, genesis *types.Block, head, time uint64)
hash := crc32.ChecksumIEEE(genesis.Hash().Bytes()) hash := crc32.ChecksumIEEE(genesis.Hash().Bytes())
// Calculate the current fork checksum and the next fork block // Calculate the current fork checksum and the next fork block
forksByBlock, forksByTime := gatherForks(config) forksByBlock, forksByTime := gatherForks(config, genesis.Time())
for _, fork := range forksByBlock { for _, fork := range forksByBlock {
if fork <= head { if fork <= head {
// Fork already passed, checksum the previous hash and the fork number // Fork already passed, checksum the previous hash and the fork number
@ -88,10 +87,6 @@ func NewID(config *params.ChainConfig, genesis *types.Block, head, time uint64)
return ID{Hash: checksumToBytes(hash), Next: fork} return ID{Hash: checksumToBytes(hash), Next: fork}
} }
for _, fork := range forksByTime { for _, fork := range forksByTime {
if fork <= genesis.Time() {
// Fork active in genesis, skip in forkid calculation
continue
}
if fork <= time { if fork <= time {
// Fork already passed, checksum the previous hash and fork timestamp // Fork already passed, checksum the previous hash and fork timestamp
hash = checksumUpdate(hash, fork) hash = checksumUpdate(hash, fork)
@ -119,7 +114,7 @@ func NewIDWithChain(chain Blockchain) ID {
func NewFilter(chain Blockchain) Filter { func NewFilter(chain Blockchain) Filter {
return newFilter( return newFilter(
chain.Config(), chain.Config(),
chain.Genesis().Hash(), chain.Genesis(),
func() (uint64, uint64) { func() (uint64, uint64) {
head := chain.CurrentHeader() head := chain.CurrentHeader()
return head.Number.Uint64(), head.Time return head.Number.Uint64(), head.Time
@ -128,7 +123,7 @@ func NewFilter(chain Blockchain) Filter {
} }
// NewStaticFilter creates a filter at block zero. // NewStaticFilter creates a filter at block zero.
func NewStaticFilter(config *params.ChainConfig, genesis common.Hash) Filter { func NewStaticFilter(config *params.ChainConfig, genesis *types.Block) Filter {
head := func() (uint64, uint64) { return 0, 0 } head := func() (uint64, uint64) { return 0, 0 }
return newFilter(config, genesis, head) return newFilter(config, genesis, head)
} }
@ -136,14 +131,14 @@ func NewStaticFilter(config *params.ChainConfig, genesis common.Hash) Filter {
// newFilter is the internal version of NewFilter, taking closures as its arguments // newFilter is the internal version of NewFilter, taking closures as its arguments
// instead of a chain. The reason is to allow testing it without having to simulate // instead of a chain. The reason is to allow testing it without having to simulate
// an entire blockchain. // an entire blockchain.
func newFilter(config *params.ChainConfig, genesis common.Hash, headfn func() (uint64, uint64)) Filter { func newFilter(config *params.ChainConfig, genesis *types.Block, headfn func() (uint64, uint64)) Filter {
// Calculate the all the valid fork hash and fork next combos // Calculate the all the valid fork hash and fork next combos
var ( var (
forksByBlock, forksByTime = gatherForks(config) forksByBlock, forksByTime = gatherForks(config, genesis.Time())
forks = append(append([]uint64{}, forksByBlock...), forksByTime...) forks = append(append([]uint64{}, forksByBlock...), forksByTime...)
sums = make([][4]byte, len(forks)+1) // 0th is the genesis sums = make([][4]byte, len(forks)+1) // 0th is the genesis
) )
hash := crc32.ChecksumIEEE(genesis[:]) hash := crc32.ChecksumIEEE(genesis.Hash().Bytes())
sums[0] = checksumToBytes(hash) sums[0] = checksumToBytes(hash)
for i, fork := range forks { for i, fork := range forks {
hash = checksumUpdate(hash, fork) hash = checksumUpdate(hash, fork)
@ -244,7 +239,7 @@ func checksumToBytes(hash uint32) [4]byte {
// gatherForks gathers all the known forks and creates two sorted lists out of // gatherForks gathers all the known forks and creates two sorted lists out of
// them, one for the block number based forks and the second for the timestamps. // them, one for the block number based forks and the second for the timestamps.
func gatherForks(config *params.ChainConfig) ([]uint64, []uint64) { func gatherForks(config *params.ChainConfig, genesis uint64) ([]uint64, []uint64) {
// Gather all the fork block numbers via reflection // Gather all the fork block numbers via reflection
kind := reflect.TypeOf(params.ChainConfig{}) kind := reflect.TypeOf(params.ChainConfig{})
conf := reflect.ValueOf(config).Elem() conf := reflect.ValueOf(config).Elem()
@ -294,7 +289,8 @@ func gatherForks(config *params.ChainConfig) ([]uint64, []uint64) {
if len(forksByBlock) > 0 && forksByBlock[0] == 0 { if len(forksByBlock) > 0 && forksByBlock[0] == 0 {
forksByBlock = forksByBlock[1:] forksByBlock = forksByBlock[1:]
} }
if len(forksByTime) > 0 && forksByTime[0] == 0 { // Skip any forks before genesis.
for len(forksByTime) > 0 && forksByTime[0] <= genesis {
forksByTime = forksByTime[1:] forksByTime = forksByTime[1:]
} }
return forksByBlock, forksByTime return forksByBlock, forksByTime

View File

@ -357,7 +357,7 @@ func TestValidation(t *testing.T) {
//{params.MainnetChainConfig, 20999999, 1677999999, ID{Hash: checksumToBytes(0x71147644), Next: 1678000000}, ErrLocalIncompatibleOrStale}, //{params.MainnetChainConfig, 20999999, 1677999999, ID{Hash: checksumToBytes(0x71147644), Next: 1678000000}, ErrLocalIncompatibleOrStale},
} }
for i, tt := range tests { for i, tt := range tests {
filter := newFilter(tt.config, params.MainnetGenesisHash, func() (uint64, uint64) { return tt.head, tt.time }) filter := newFilter(tt.config, core.DefaultGenesisBlock().ToBlock(), func() (uint64, uint64) { return tt.head, tt.time })
if err := filter(tt.id); err != tt.err { if err := filter(tt.id); err != tt.err {
t.Errorf("test %d: validation error mismatch: have %v, want %v", i, err, tt.err) t.Errorf("test %d: validation error mismatch: have %v, want %v", i, err, tt.err)
} }