diff --git a/cmd/validateTrie.go b/cmd/validateTrie.go index 8c1e6da..ce378bf 100644 --- a/cmd/validateTrie.go +++ b/cmd/validateTrie.go @@ -70,26 +70,30 @@ func validateTrie() { stateRootStr := viper.GetString("validator.stateRoot") storageRootStr := viper.GetString("validator.storageRoot") contractAddrStr := viper.GetString("validator.address") + + if stateRootStr == "" { + logWithCommand.Fatal("must provide a state root for state trie validation") + } + stateRoot := common.HexToHash(stateRootStr) + traversal := strings.ToLower(viper.GetString("validator.type")) switch traversal { case "f", "full": - if stateRootStr == "" { - logWithCommand.Fatal("must provide a state root for full state validation") - } - stateRoot := common.HexToHash(stateRootStr) + logWithCommand. + WithField("root", stateRoot). + Debug("Validating full state") if err = v.ValidateTrie(stateRoot); err != nil { - logWithCommand.Fatalf("State for root %s is not complete\r\nerr: %v", stateRoot.String(), err) + logWithCommand.Fatalf("Validation failed: %v", err) } - logWithCommand.Infof("State for root %s is complete", stateRoot.String()) + logWithCommand.Infof("State for root %s is complete", stateRoot) case "state": - if stateRootStr == "" { - logWithCommand.Fatal("must provide a state root for state trie validation") - } - stateRoot := common.HexToHash(stateRootStr) + logWithCommand. + WithField("root", stateRoot). + Debug("Validating state trie") if err = v.ValidateStateTrie(stateRoot); err != nil { - logWithCommand.Fatalf("State trie for root %s is not complete\r\nerr: %v", stateRoot.String(), err) + logWithCommand.Fatalf("Validation failed: %s", err) } - logWithCommand.Infof("State trie for root %s is complete", stateRoot.String()) + logWithCommand.Infof("State trie for root %s is complete", stateRoot) case "storage": if storageRootStr == "" { logWithCommand.Fatal("must provide a storage root for storage trie validation") @@ -97,16 +101,16 @@ func validateTrie() { if contractAddrStr == "" { logWithCommand.Fatal("must provide a contract address for storage trie validation") } - if stateRootStr == "" { - logWithCommand.Fatal("must provide a state root for state trie validation") - } storageRoot := common.HexToHash(storageRootStr) addr := common.HexToAddress(contractAddrStr) - stateRoot := common.HexToHash(stateRootStr) + logWithCommand. + WithField("contract", addr). + WithField("storage root", storageRoot). + Debug("Validating storage trie") if err = v.ValidateStorageTrie(stateRoot, addr, storageRoot); err != nil { - logWithCommand.Fatalf("Storage trie for contract %s and root %s not complete\r\nerr: %v", addr.String(), storageRoot.String(), err) + logWithCommand.Fatalf("Validation failed", err) } - logWithCommand.Infof("Storage trie for contract %s and root %s is complete", addr.String(), storageRoot.String()) + logWithCommand.Infof("Storage trie for contract %s and root %s is complete", addr, storageRoot) default: logWithCommand.Fatalf("Invalid traversal level: '%s'", traversal) } diff --git a/pkg/validator.go b/pkg/validator.go index 52d94a4..4fcfaa0 100644 --- a/pkg/validator.go +++ b/pkg/validator.go @@ -20,6 +20,9 @@ import ( "bytes" "context" "fmt" + "os" + "os/signal" + "syscall" "time" "github.com/spf13/viper" @@ -143,7 +146,7 @@ func (v *Validator) ValidateTrie(stateRoot common.Hash) error { if err != nil { return err } - iterate := func(it trie.NodeIterator) error { return v.iterate(it, true) } + iterate := func(ctx context.Context, it trie.NodeIterator) error { return v.iterate(ctx, it, true) } return iterateTracked(t, fmt.Sprintf(v.params.RecoveryFormat, fullTraversal), v.params.Workers, iterate) } @@ -155,7 +158,7 @@ func (v *Validator) ValidateStateTrie(stateRoot common.Hash) error { if err != nil { return err } - iterate := func(it trie.NodeIterator) error { return v.iterate(it, false) } + iterate := func(ctx context.Context, it trie.NodeIterator) error { return v.iterate(ctx, it, false) } return iterateTracked(t, fmt.Sprintf(v.params.RecoveryFormat, stateTraversal), v.params.Workers, iterate) } @@ -167,7 +170,7 @@ func (v *Validator) ValidateStorageTrie(stateRoot common.Hash, address common.Ad if err != nil { return err } - iterate := func(it trie.NodeIterator) error { return v.iterate(it, false) } + iterate := func(ctx context.Context, it trie.NodeIterator) error { return v.iterate(ctx, it, false) } return iterateTracked(t, fmt.Sprintf(v.params.RecoveryFormat, storageTraversal), v.params.Workers, iterate) } @@ -181,12 +184,18 @@ func (v *Validator) Close() error { // Traverses one iterator fully // If storage = true, also traverse storage tries for each leaf. -func (v *Validator) iterate(it trie.NodeIterator, storage bool) error { +func (v *Validator) iterate(ctx context.Context, it trie.NodeIterator, storage bool) error { // Iterate through entire state trie. it.Next() will return false when we have // either completed iteration of the entire trie or run into an error (e.g. a // missing node). If we are able to iterate through the entire trie without error // then the trie is complete. for it.Next(true) { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + // This block adapted from geth - core/state/iterator.go // If storage is not requested, or the state trie node is an internal entry, skip if !storage || !it.Leaf() { @@ -219,10 +228,15 @@ func (v *Validator) iterate(it trie.NodeIterator, storage bool) error { // Traverses each iterator in a separate goroutine. // Dumps to a recovery file on failure or interrupt. -func iterateTracked(tree state.Trie, recoveryFile string, iterCount uint, fn func(trie.NodeIterator) error) error { - ctx, _ := context.WithCancel(context.Background()) +func iterateTracked( + tree state.Trie, + recoveryFile string, + iterCount uint, + fn func(context.Context, trie.NodeIterator) error, +) error { tracker := tracker.New(recoveryFile, iterCount) halt := func() { + log.Errorf("writing recovery file: %s", recoveryFile) if err := tracker.CloseAndSave(); err != nil { log.Errorf("failed to write recovery file: %v", err) } @@ -242,14 +256,28 @@ func iterateTracked(tree state.Trie, recoveryFile string, iterCount uint, fn fun for i, it := range iters { iters[i] = tracker.Tracked(it) } + } else { + log.Debugf("restored %d iterators from: %s", len(iters), recoveryFile) } + ctx, cancel := context.WithCancel(context.Background()) g, ctx := errgroup.WithContext(ctx) - defer halt() + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + sig := <-sigChan + log.Errorf("Signal received (%v), stopping", sig) + cancel() + }() + + defer halt() for _, it := range iters { func(it trie.NodeIterator) { - g.Go(func() error { return fn(it) }) + g.Go(func() error { + return fn(ctx, it) + }) + }(it) } return g.Wait() diff --git a/test/compose.yml b/test/compose.yml index de8e3dc..730728d 100644 --- a/test/compose.yml +++ b/test/compose.yml @@ -5,7 +5,7 @@ services: restart: on-failure depends_on: - ipld-eth-db - image: git.vdb.to/cerc-io/ipld-eth-db/ipld-eth-db:v5.0.2-alpha + image: git.vdb.to/cerc-io/ipld-eth-db/ipld-eth-db:v5.0.5-alpha environment: DATABASE_USER: "vdbm" DATABASE_NAME: "cerc_testing"