diff --git a/pkg/prom/tracker.go b/pkg/prom/tracker.go new file mode 100644 index 0000000..74bf6c6 --- /dev/null +++ b/pkg/prom/tracker.go @@ -0,0 +1,133 @@ +package prom + +import ( + "fmt" + "sync" + "sync/atomic" + + iterutil "github.com/cerc-io/eth-iterator-utils" + "github.com/cerc-io/eth-iterator-utils/tracker" + "github.com/ethereum/go-ethereum/trie" +) + +var trackedIterCount atomic.Int32 + +// Tracker which wraps a tracked iterators in metrics-reporting iterators +type MetricsTracker struct { + *tracker.TrackerImpl +} + +type metricsIterator struct { + trie.NodeIterator + id int32 + // count uint + done bool + lastPath []byte + sync.RWMutex +} + +func NewTracker(file string, bufsize uint) *MetricsTracker { + return &MetricsTracker{TrackerImpl: tracker.NewImpl(file, bufsize)} +} + +func (t *MetricsTracker) wrap(tracked *tracker.Iterator) *metricsIterator { + startPath, endPath := tracked.Bounds() + startDepth := max(len(startPath), len(endPath)) + ret := &metricsIterator{ + NodeIterator: tracked, + id: trackedIterCount.Add(1), + } + RegisterGaugeFunc( + fmt.Sprintf("tracked_iterator_%d", ret.id), + func() float64 { + ret.RLock() + if ret.done { + return 1 + } + lastPath := ret.lastPath + ret.RUnlock() + if lastPath == nil { + return 0 + } + // estimate remaining distance based on current position and node count + depth := max(startDepth, len(lastPath)) + startPath := normalizePath(startPath, depth) + endPath := normalizePath(endPath, depth) + progressed := subtractPaths(lastPath, startPath) + total := subtractPaths(endPath, startPath) + return float64(countSteps(progressed, depth)) / float64(countSteps(total, depth)) + }) + return ret +} + +func (t *MetricsTracker) Restore(ctor iterutil.IteratorConstructor) ( + []trie.NodeIterator, []trie.NodeIterator, error, +) { + iters, bases, err := t.TrackerImpl.Restore(ctor) + if err != nil { + return nil, nil, err + } + ret := make([]trie.NodeIterator, len(iters)) + for i, tracked := range iters { + ret[i] = t.wrap(tracked) + } + return ret, bases, nil +} + +func (t *MetricsTracker) Tracked(it trie.NodeIterator) trie.NodeIterator { + tracked := t.TrackerImpl.Tracked(it) + return t.wrap(tracked) +} + +func (it *metricsIterator) Next(descend bool) bool { + ret := it.NodeIterator.Next(descend) + it.Lock() + defer it.Unlock() + if ret { + it.lastPath = it.Path() + } else { + it.done = true + } + return ret +} + +func normalizePath(path []byte, depth int) []byte { + normalized := make([]byte, depth) + for i := 0; i < depth; i++ { + if i < len(path) { + normalized[i] = path[i] + } + } + return normalized +} + +// Subtract each component, right to left, carrying over if necessary. +func subtractPaths(a, b []byte) []byte { + diff := make([]byte, len(a)) + carry := false + for i := len(a) - 1; i >= 0; i-- { + diff[i] = a[i] - b[i] + if carry { + diff[i]-- + } + carry = a[i] < b[i] + } + return diff +} + +// count total steps in a path according to its depth (length) +func countSteps(path []byte, depth int) uint { + var steps uint + for _, b := range path { + steps *= 16 + steps += uint(b) + } + return steps +} + +func max(a int, b int) int { + if a > b { + return a + } + return b +} diff --git a/pkg/snapshot/service.go b/pkg/snapshot/service.go index 2ceaf5c..e58fda5 100644 --- a/pkg/snapshot/service.go +++ b/pkg/snapshot/service.go @@ -19,9 +19,12 @@ import ( "context" "fmt" "math/big" + "os" + "os/signal" "sync" + "syscall" - "github.com/cerc-io/eth-iterator-utils/tracker" + "github.com/cerc-io/ipld-eth-state-snapshot/pkg/prom" statediff "github.com/cerc-io/plugeth-statediff" "github.com/cerc-io/plugeth-statediff/adapt" "github.com/cerc-io/plugeth-statediff/indexer" @@ -90,11 +93,14 @@ func (s *Service) CreateSnapshot(params SnapshotParams) error { if header == nil { return fmt.Errorf("unable to read canonical header at height %d", params.Height) } - log.Info("Creating snapshot", "height", params.Height, "hash", hash) + log.WithField("height", params.Height).WithField("hash", hash).Info("Creating snapshot") // Context for snapshot work ctx, cancelCtx := context.WithCancel(context.Background()) defer cancelCtx() + // Cancel context on receiving a signal. On cancellation, all tracked iterators complete + // processing of their current node before stopping. + captureSignal(cancelCtx) var err error tx := s.indexer.BeginTx(header.Number, ctx) @@ -106,11 +112,9 @@ func (s *Service) CreateSnapshot(params SnapshotParams) error { return err } - tracker := tracker.New(s.recoveryFile, params.Workers) - tracker.CaptureSignal(cancelCtx) - + tr := prom.NewTracker(s.recoveryFile, params.Workers) defer func() { - err := tracker.HaltAndDump() + err := tr.CloseAndSave() if err != nil { log.Errorf("failed to write recovery file: %v", err) } @@ -128,7 +132,7 @@ func (s *Service) CreateSnapshot(params SnapshotParams) error { return s.indexer.PushIPLD(tx, c) } - // Build a diff compared against the zero hash to get a full snapshot + // Build a diff against the zero hash (empty trie) to get a full snapshot sdargs := statediff.Args{ NewStateRoot: header.Root, BlockHash: header.Hash(), @@ -140,7 +144,7 @@ func (s *Service) CreateSnapshot(params SnapshotParams) error { sdparams.ComputeWatchedAddressesLeafPaths() builder := statediff.NewBuilder(adapt.GethStateView(s.stateDB)) builder.SetSubtrieWorkers(params.Workers) - if err = builder.WriteStateDiffTracked(sdargs, sdparams, nodeSink, ipldSink, &tracker); err != nil { + if err = builder.WriteStateDiffTracked(sdargs, sdparams, nodeSink, ipldSink, tr); err != nil { return err } @@ -160,3 +164,14 @@ func (s *Service) CreateLatestSnapshot(workers uint, watchedAddresses []common.A } return s.CreateSnapshot(SnapshotParams{Height: *height, Workers: workers, WatchedAddresses: watchedAddresses}) } + +func captureSignal(cb func()) { + sigChan := make(chan os.Signal, 1) + + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + sig := <-sigChan + log.Errorf("Signal received (%v), stopping", sig) + cb() + }() +}