diff --git a/pkg/snapshot/service.go b/pkg/snapshot/service.go index 2ceaf5c..e58fda5 100644 --- a/pkg/snapshot/service.go +++ b/pkg/snapshot/service.go @@ -19,9 +19,12 @@ import ( "context" "fmt" "math/big" + "os" + "os/signal" "sync" + "syscall" - "github.com/cerc-io/eth-iterator-utils/tracker" + "github.com/cerc-io/ipld-eth-state-snapshot/pkg/prom" statediff "github.com/cerc-io/plugeth-statediff" "github.com/cerc-io/plugeth-statediff/adapt" "github.com/cerc-io/plugeth-statediff/indexer" @@ -90,11 +93,14 @@ func (s *Service) CreateSnapshot(params SnapshotParams) error { if header == nil { return fmt.Errorf("unable to read canonical header at height %d", params.Height) } - log.Info("Creating snapshot", "height", params.Height, "hash", hash) + log.WithField("height", params.Height).WithField("hash", hash).Info("Creating snapshot") // Context for snapshot work ctx, cancelCtx := context.WithCancel(context.Background()) defer cancelCtx() + // Cancel context on receiving a signal. On cancellation, all tracked iterators complete + // processing of their current node before stopping. + captureSignal(cancelCtx) var err error tx := s.indexer.BeginTx(header.Number, ctx) @@ -106,11 +112,9 @@ func (s *Service) CreateSnapshot(params SnapshotParams) error { return err } - tracker := tracker.New(s.recoveryFile, params.Workers) - tracker.CaptureSignal(cancelCtx) - + tr := prom.NewTracker(s.recoveryFile, params.Workers) defer func() { - err := tracker.HaltAndDump() + err := tr.CloseAndSave() if err != nil { log.Errorf("failed to write recovery file: %v", err) } @@ -128,7 +132,7 @@ func (s *Service) CreateSnapshot(params SnapshotParams) error { return s.indexer.PushIPLD(tx, c) } - // Build a diff compared against the zero hash to get a full snapshot + // Build a diff against the zero hash (empty trie) to get a full snapshot sdargs := statediff.Args{ NewStateRoot: header.Root, BlockHash: header.Hash(), @@ -140,7 +144,7 @@ func (s *Service) CreateSnapshot(params SnapshotParams) error { sdparams.ComputeWatchedAddressesLeafPaths() builder := statediff.NewBuilder(adapt.GethStateView(s.stateDB)) builder.SetSubtrieWorkers(params.Workers) - if err = builder.WriteStateDiffTracked(sdargs, sdparams, nodeSink, ipldSink, &tracker); err != nil { + if err = builder.WriteStateDiffTracked(sdargs, sdparams, nodeSink, ipldSink, tr); err != nil { return err } @@ -160,3 +164,14 @@ func (s *Service) CreateLatestSnapshot(workers uint, watchedAddresses []common.A } return s.CreateSnapshot(SnapshotParams{Height: *height, Workers: workers, WatchedAddresses: watchedAddresses}) } + +func captureSignal(cb func()) { + sigChan := make(chan os.Signal, 1) + + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + sig := <-sigChan + log.Errorf("Signal received (%v), stopping", sig) + cb() + }() +}