fix signal handling

This commit is contained in:
Roy Crihfield 2023-09-30 13:46:45 +08:00
parent 32d91694d2
commit 9e9820853a
2 changed files with 58 additions and 26 deletions

View File

@ -70,26 +70,30 @@ func validateTrie() {
stateRootStr := viper.GetString("validator.stateRoot")
storageRootStr := viper.GetString("validator.storageRoot")
contractAddrStr := viper.GetString("validator.address")
traversal := strings.ToLower(viper.GetString("validator.type"))
switch traversal {
case "f", "full":
if stateRootStr == "" {
logWithCommand.Fatal("must provide a state root for full state validation")
}
stateRoot := common.HexToHash(stateRootStr)
if err = v.ValidateTrie(stateRoot); err != nil {
logWithCommand.Fatalf("State for root %s is not complete\r\nerr: %v", stateRoot.String(), err)
}
logWithCommand.Infof("State for root %s is complete", stateRoot.String())
case "state":
if stateRootStr == "" {
logWithCommand.Fatal("must provide a state root for state trie validation")
}
stateRoot := common.HexToHash(stateRootStr)
if err = v.ValidateStateTrie(stateRoot); err != nil {
logWithCommand.Fatalf("State trie for root %s is not complete\r\nerr: %v", stateRoot.String(), err)
traversal := strings.ToLower(viper.GetString("validator.type"))
switch traversal {
case "f", "full":
logWithCommand.
WithField("root", stateRoot).
Debug("Validating full state")
if err = v.ValidateTrie(stateRoot); err != nil {
logWithCommand.Fatalf("Validation failed: %v", err)
}
logWithCommand.Infof("State trie for root %s is complete", stateRoot.String())
logWithCommand.Infof("State for root %s is complete", stateRoot)
case "state":
logWithCommand.
WithField("root", stateRoot).
Debug("Validating state trie")
if err = v.ValidateStateTrie(stateRoot); err != nil {
logWithCommand.Fatalf("Validation failed: %s", err)
}
logWithCommand.Infof("State trie for root %s is complete", stateRoot)
case "storage":
if storageRootStr == "" {
logWithCommand.Fatal("must provide a storage root for storage trie validation")
@ -97,16 +101,16 @@ func validateTrie() {
if contractAddrStr == "" {
logWithCommand.Fatal("must provide a contract address for storage trie validation")
}
if stateRootStr == "" {
logWithCommand.Fatal("must provide a state root for state trie validation")
}
storageRoot := common.HexToHash(storageRootStr)
addr := common.HexToAddress(contractAddrStr)
stateRoot := common.HexToHash(stateRootStr)
logWithCommand.
WithField("contract", addr).
WithField("storage root", storageRoot).
Debug("Validating storage trie")
if err = v.ValidateStorageTrie(stateRoot, addr, storageRoot); err != nil {
logWithCommand.Fatalf("Storage trie for contract %s and root %s not complete\r\nerr: %v", addr.String(), storageRoot.String(), err)
logWithCommand.Fatalf("Validation failed", err)
}
logWithCommand.Infof("Storage trie for contract %s and root %s is complete", addr.String(), storageRoot.String())
logWithCommand.Infof("Storage trie for contract %s and root %s is complete", addr, storageRoot)
default:
logWithCommand.Fatalf("Invalid traversal level: '%s'", traversal)
}

View File

@ -20,6 +20,9 @@ import (
"bytes"
"context"
"fmt"
"os"
"os/signal"
"syscall"
"time"
"github.com/spf13/viper"
@ -143,7 +146,7 @@ func (v *Validator) ValidateTrie(stateRoot common.Hash) error {
if err != nil {
return err
}
iterate := func(it trie.NodeIterator) error { return v.iterate(it, true) }
iterate := func(ctx context.Context, it trie.NodeIterator) error { return v.iterate(ctx, it, true) }
return iterateTracked(t, fmt.Sprintf(v.params.RecoveryFormat, fullTraversal), v.params.Workers, iterate)
}
@ -155,7 +158,7 @@ func (v *Validator) ValidateStateTrie(stateRoot common.Hash) error {
if err != nil {
return err
}
iterate := func(it trie.NodeIterator) error { return v.iterate(it, false) }
iterate := func(ctx context.Context, it trie.NodeIterator) error { return v.iterate(ctx, it, false) }
return iterateTracked(t, fmt.Sprintf(v.params.RecoveryFormat, stateTraversal), v.params.Workers, iterate)
}
@ -167,7 +170,7 @@ func (v *Validator) ValidateStorageTrie(stateRoot common.Hash, address common.Ad
if err != nil {
return err
}
iterate := func(it trie.NodeIterator) error { return v.iterate(it, false) }
iterate := func(ctx context.Context, it trie.NodeIterator) error { return v.iterate(ctx, it, false) }
return iterateTracked(t, fmt.Sprintf(v.params.RecoveryFormat, storageTraversal), v.params.Workers, iterate)
}
@ -181,12 +184,18 @@ func (v *Validator) Close() error {
// Traverses one iterator fully
// If storage = true, also traverse storage tries for each leaf.
func (v *Validator) iterate(it trie.NodeIterator, storage bool) error {
func (v *Validator) iterate(ctx context.Context, it trie.NodeIterator, storage bool) error {
// Iterate through entire state trie. it.Next() will return false when we have
// either completed iteration of the entire trie or run into an error (e.g. a
// missing node). If we are able to iterate through the entire trie without error
// then the trie is complete.
for it.Next(true) {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
// This block adapted from geth - core/state/iterator.go
// If storage is not requested, or the state trie node is an internal entry, skip
if !storage || !it.Leaf() {
@ -219,10 +228,15 @@ func (v *Validator) iterate(it trie.NodeIterator, storage bool) error {
// Traverses each iterator in a separate goroutine.
// Dumps to a recovery file on failure or interrupt.
func iterateTracked(tree state.Trie, recoveryFile string, iterCount uint, fn func(trie.NodeIterator) error) error {
ctx, _ := context.WithCancel(context.Background())
func iterateTracked(
tree state.Trie,
recoveryFile string,
iterCount uint,
fn func(context.Context, trie.NodeIterator) error,
) error {
tracker := tracker.New(recoveryFile, iterCount)
halt := func() {
log.Errorf("writing recovery file: %s", recoveryFile)
if err := tracker.CloseAndSave(); err != nil {
log.Errorf("failed to write recovery file: %v", err)
}
@ -242,14 +256,28 @@ func iterateTracked(tree state.Trie, recoveryFile string, iterCount uint, fn fun
for i, it := range iters {
iters[i] = tracker.Tracked(it)
}
} else {
log.Debugf("restored %d iterators from: %s", len(iters), recoveryFile)
}
ctx, cancel := context.WithCancel(context.Background())
g, ctx := errgroup.WithContext(ctx)
defer halt()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigChan
log.Errorf("Signal received (%v), stopping", sig)
cancel()
}()
defer halt()
for _, it := range iters {
func(it trie.NodeIterator) {
g.Go(func() error { return fn(it) })
g.Go(func() error {
return fn(ctx, it)
})
}(it)
}
return g.Wait()