Fix signal handling #3

Merged
roysc merged 2 commits from fix-signal-handling into v5 2023-10-03 13:32:34 +00:00
2 changed files with 58 additions and 26 deletions
Showing only changes of commit 9e9820853a - Show all commits

View File

@ -70,26 +70,30 @@ func validateTrie() {
stateRootStr := viper.GetString("validator.stateRoot") stateRootStr := viper.GetString("validator.stateRoot")
storageRootStr := viper.GetString("validator.storageRoot") storageRootStr := viper.GetString("validator.storageRoot")
contractAddrStr := viper.GetString("validator.address") contractAddrStr := viper.GetString("validator.address")
if stateRootStr == "" {
logWithCommand.Fatal("must provide a state root for state trie validation")
}
stateRoot := common.HexToHash(stateRootStr)
traversal := strings.ToLower(viper.GetString("validator.type")) traversal := strings.ToLower(viper.GetString("validator.type"))
switch traversal { switch traversal {
case "f", "full": case "f", "full":
if stateRootStr == "" { logWithCommand.
logWithCommand.Fatal("must provide a state root for full state validation") WithField("root", stateRoot).
} Debug("Validating full state")
stateRoot := common.HexToHash(stateRootStr)
if err = v.ValidateTrie(stateRoot); err != nil { if err = v.ValidateTrie(stateRoot); err != nil {
logWithCommand.Fatalf("State for root %s is not complete\r\nerr: %v", stateRoot.String(), err) logWithCommand.Fatalf("Validation failed: %v", err)
} }
logWithCommand.Infof("State for root %s is complete", stateRoot.String()) logWithCommand.Infof("State for root %s is complete", stateRoot)
case "state": case "state":
if stateRootStr == "" { logWithCommand.
logWithCommand.Fatal("must provide a state root for state trie validation") WithField("root", stateRoot).
} Debug("Validating state trie")
stateRoot := common.HexToHash(stateRootStr)
if err = v.ValidateStateTrie(stateRoot); err != nil { if err = v.ValidateStateTrie(stateRoot); err != nil {
logWithCommand.Fatalf("State trie for root %s is not complete\r\nerr: %v", stateRoot.String(), err) logWithCommand.Fatalf("Validation failed: %s", err)
} }
logWithCommand.Infof("State trie for root %s is complete", stateRoot.String()) logWithCommand.Infof("State trie for root %s is complete", stateRoot)
case "storage": case "storage":
if storageRootStr == "" { if storageRootStr == "" {
logWithCommand.Fatal("must provide a storage root for storage trie validation") logWithCommand.Fatal("must provide a storage root for storage trie validation")
@ -97,16 +101,16 @@ func validateTrie() {
if contractAddrStr == "" { if contractAddrStr == "" {
logWithCommand.Fatal("must provide a contract address for storage trie validation") logWithCommand.Fatal("must provide a contract address for storage trie validation")
} }
if stateRootStr == "" {
logWithCommand.Fatal("must provide a state root for state trie validation")
}
storageRoot := common.HexToHash(storageRootStr) storageRoot := common.HexToHash(storageRootStr)
addr := common.HexToAddress(contractAddrStr) addr := common.HexToAddress(contractAddrStr)
stateRoot := common.HexToHash(stateRootStr) logWithCommand.
WithField("contract", addr).
WithField("storage root", storageRoot).
Debug("Validating storage trie")
if err = v.ValidateStorageTrie(stateRoot, addr, storageRoot); err != nil { if err = v.ValidateStorageTrie(stateRoot, addr, storageRoot); err != nil {
logWithCommand.Fatalf("Storage trie for contract %s and root %s not complete\r\nerr: %v", addr.String(), storageRoot.String(), err) logWithCommand.Fatalf("Validation failed", err)
} }
logWithCommand.Infof("Storage trie for contract %s and root %s is complete", addr.String(), storageRoot.String()) logWithCommand.Infof("Storage trie for contract %s and root %s is complete", addr, storageRoot)
default: default:
logWithCommand.Fatalf("Invalid traversal level: '%s'", traversal) logWithCommand.Fatalf("Invalid traversal level: '%s'", traversal)
} }

View File

@ -20,6 +20,9 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"os"
"os/signal"
"syscall"
"time" "time"
"github.com/spf13/viper" "github.com/spf13/viper"
@ -143,7 +146,7 @@ func (v *Validator) ValidateTrie(stateRoot common.Hash) error {
if err != nil { if err != nil {
return err return err
} }
iterate := func(it trie.NodeIterator) error { return v.iterate(it, true) } iterate := func(ctx context.Context, it trie.NodeIterator) error { return v.iterate(ctx, it, true) }
return iterateTracked(t, fmt.Sprintf(v.params.RecoveryFormat, fullTraversal), v.params.Workers, iterate) return iterateTracked(t, fmt.Sprintf(v.params.RecoveryFormat, fullTraversal), v.params.Workers, iterate)
} }
@ -155,7 +158,7 @@ func (v *Validator) ValidateStateTrie(stateRoot common.Hash) error {
if err != nil { if err != nil {
return err return err
} }
iterate := func(it trie.NodeIterator) error { return v.iterate(it, false) } iterate := func(ctx context.Context, it trie.NodeIterator) error { return v.iterate(ctx, it, false) }
return iterateTracked(t, fmt.Sprintf(v.params.RecoveryFormat, stateTraversal), v.params.Workers, iterate) return iterateTracked(t, fmt.Sprintf(v.params.RecoveryFormat, stateTraversal), v.params.Workers, iterate)
} }
@ -167,7 +170,7 @@ func (v *Validator) ValidateStorageTrie(stateRoot common.Hash, address common.Ad
if err != nil { if err != nil {
return err return err
} }
iterate := func(it trie.NodeIterator) error { return v.iterate(it, false) } iterate := func(ctx context.Context, it trie.NodeIterator) error { return v.iterate(ctx, it, false) }
return iterateTracked(t, fmt.Sprintf(v.params.RecoveryFormat, storageTraversal), v.params.Workers, iterate) return iterateTracked(t, fmt.Sprintf(v.params.RecoveryFormat, storageTraversal), v.params.Workers, iterate)
} }
@ -181,12 +184,18 @@ func (v *Validator) Close() error {
// Traverses one iterator fully // Traverses one iterator fully
// If storage = true, also traverse storage tries for each leaf. // If storage = true, also traverse storage tries for each leaf.
func (v *Validator) iterate(it trie.NodeIterator, storage bool) error { func (v *Validator) iterate(ctx context.Context, it trie.NodeIterator, storage bool) error {
// Iterate through entire state trie. it.Next() will return false when we have // Iterate through entire state trie. it.Next() will return false when we have
// either completed iteration of the entire trie or run into an error (e.g. a // either completed iteration of the entire trie or run into an error (e.g. a
// missing node). If we are able to iterate through the entire trie without error // missing node). If we are able to iterate through the entire trie without error
// then the trie is complete. // then the trie is complete.
for it.Next(true) { for it.Next(true) {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
// This block adapted from geth - core/state/iterator.go // This block adapted from geth - core/state/iterator.go
// If storage is not requested, or the state trie node is an internal entry, skip // If storage is not requested, or the state trie node is an internal entry, skip
if !storage || !it.Leaf() { if !storage || !it.Leaf() {
@ -219,10 +228,15 @@ func (v *Validator) iterate(it trie.NodeIterator, storage bool) error {
// Traverses each iterator in a separate goroutine. // Traverses each iterator in a separate goroutine.
// Dumps to a recovery file on failure or interrupt. // Dumps to a recovery file on failure or interrupt.
func iterateTracked(tree state.Trie, recoveryFile string, iterCount uint, fn func(trie.NodeIterator) error) error { func iterateTracked(
ctx, _ := context.WithCancel(context.Background()) tree state.Trie,
recoveryFile string,
iterCount uint,
fn func(context.Context, trie.NodeIterator) error,
) error {
tracker := tracker.New(recoveryFile, iterCount) tracker := tracker.New(recoveryFile, iterCount)
halt := func() { halt := func() {
log.Errorf("writing recovery file: %s", recoveryFile)
if err := tracker.CloseAndSave(); err != nil { if err := tracker.CloseAndSave(); err != nil {
log.Errorf("failed to write recovery file: %v", err) log.Errorf("failed to write recovery file: %v", err)
} }
@ -242,14 +256,28 @@ func iterateTracked(tree state.Trie, recoveryFile string, iterCount uint, fn fun
for i, it := range iters { for i, it := range iters {
iters[i] = tracker.Tracked(it) iters[i] = tracker.Tracked(it)
} }
} else {
log.Debugf("restored %d iterators from: %s", len(iters), recoveryFile)
} }
ctx, cancel := context.WithCancel(context.Background())
g, ctx := errgroup.WithContext(ctx) g, ctx := errgroup.WithContext(ctx)
defer halt()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigChan
log.Errorf("Signal received (%v), stopping", sig)
cancel()
}()
defer halt()
for _, it := range iters { for _, it := range iters {
func(it trie.NodeIterator) { func(it trie.NodeIterator) {
g.Go(func() error { return fn(it) }) g.Go(func() error {
return fn(ctx, it)
})
}(it) }(it)
} }
return g.Wait() return g.Wait()