Refactor to use statediff plugin #1

Merged
roysc merged 27 commits from refactor-use-plugin into v5 2023-09-29 18:43:28 +00:00
2 changed files with 156 additions and 8 deletions
Showing only changes of commit bcda12388b - Show all commits

133
pkg/prom/tracker.go Normal file
View File

@ -0,0 +1,133 @@
package prom
import (
"fmt"
"sync"
"sync/atomic"
iterutil "github.com/cerc-io/eth-iterator-utils"
"github.com/cerc-io/eth-iterator-utils/tracker"
"github.com/ethereum/go-ethereum/trie"
)
var trackedIterCount atomic.Int32
// Tracker which wraps a tracked iterators in metrics-reporting iterators
type MetricsTracker struct {
*tracker.TrackerImpl
}
type metricsIterator struct {
trie.NodeIterator
id int32
// count uint
done bool
lastPath []byte
sync.RWMutex
}
func NewTracker(file string, bufsize uint) *MetricsTracker {
return &MetricsTracker{TrackerImpl: tracker.NewImpl(file, bufsize)}
}
func (t *MetricsTracker) wrap(tracked *tracker.Iterator) *metricsIterator {
startPath, endPath := tracked.Bounds()
startDepth := max(len(startPath), len(endPath))
ret := &metricsIterator{
NodeIterator: tracked,
id: trackedIterCount.Add(1),
}
RegisterGaugeFunc(
fmt.Sprintf("tracked_iterator_%d", ret.id),
func() float64 {
ret.RLock()
if ret.done {
return 1
}
lastPath := ret.lastPath
ret.RUnlock()
if lastPath == nil {
return 0
}
// estimate remaining distance based on current position and node count
depth := max(startDepth, len(lastPath))
startPath := normalizePath(startPath, depth)
endPath := normalizePath(endPath, depth)
progressed := subtractPaths(lastPath, startPath)
total := subtractPaths(endPath, startPath)
return float64(countSteps(progressed, depth)) / float64(countSteps(total, depth))
})
return ret
}
func (t *MetricsTracker) Restore(ctor iterutil.IteratorConstructor) (
[]trie.NodeIterator, []trie.NodeIterator, error,
) {
iters, bases, err := t.TrackerImpl.Restore(ctor)
if err != nil {
return nil, nil, err
}
ret := make([]trie.NodeIterator, len(iters))
for i, tracked := range iters {
ret[i] = t.wrap(tracked)
}
return ret, bases, nil
}
func (t *MetricsTracker) Tracked(it trie.NodeIterator) trie.NodeIterator {
tracked := t.TrackerImpl.Tracked(it)
return t.wrap(tracked)
}
func (it *metricsIterator) Next(descend bool) bool {
ret := it.NodeIterator.Next(descend)
it.Lock()
defer it.Unlock()
if ret {
it.lastPath = it.Path()
} else {
it.done = true
}
return ret
}
func normalizePath(path []byte, depth int) []byte {
normalized := make([]byte, depth)
for i := 0; i < depth; i++ {
if i < len(path) {
normalized[i] = path[i]
}
}
return normalized
}
// Subtract each component, right to left, carrying over if necessary.
func subtractPaths(a, b []byte) []byte {
diff := make([]byte, len(a))
carry := false
for i := len(a) - 1; i >= 0; i-- {
diff[i] = a[i] - b[i]
if carry {
diff[i]--
}
carry = a[i] < b[i]
}
return diff
}
// count total steps in a path according to its depth (length)
func countSteps(path []byte, depth int) uint {
var steps uint
for _, b := range path {
steps *= 16
steps += uint(b)
}
return steps
}
func max(a int, b int) int {
if a > b {
return a
}
return b
}

View File

@ -19,9 +19,12 @@ import (
"context" "context"
"fmt" "fmt"
"math/big" "math/big"
"os"
"os/signal"
"sync" "sync"
"syscall"
"github.com/cerc-io/eth-iterator-utils/tracker" "github.com/cerc-io/ipld-eth-state-snapshot/pkg/prom"
statediff "github.com/cerc-io/plugeth-statediff" statediff "github.com/cerc-io/plugeth-statediff"
"github.com/cerc-io/plugeth-statediff/adapt" "github.com/cerc-io/plugeth-statediff/adapt"
"github.com/cerc-io/plugeth-statediff/indexer" "github.com/cerc-io/plugeth-statediff/indexer"
@ -90,11 +93,14 @@ func (s *Service) CreateSnapshot(params SnapshotParams) error {
if header == nil { if header == nil {
return fmt.Errorf("unable to read canonical header at height %d", params.Height) return fmt.Errorf("unable to read canonical header at height %d", params.Height)
} }
log.Info("Creating snapshot", "height", params.Height, "hash", hash) log.WithField("height", params.Height).WithField("hash", hash).Info("Creating snapshot")
// Context for snapshot work // Context for snapshot work
ctx, cancelCtx := context.WithCancel(context.Background()) ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx() defer cancelCtx()
// Cancel context on receiving a signal. On cancellation, all tracked iterators complete
// processing of their current node before stopping.
captureSignal(cancelCtx)
var err error var err error
tx := s.indexer.BeginTx(header.Number, ctx) tx := s.indexer.BeginTx(header.Number, ctx)
@ -106,11 +112,9 @@ func (s *Service) CreateSnapshot(params SnapshotParams) error {
return err return err
} }
tracker := tracker.New(s.recoveryFile, params.Workers) tr := prom.NewTracker(s.recoveryFile, params.Workers)
tracker.CaptureSignal(cancelCtx)
defer func() { defer func() {
err := tracker.HaltAndDump() err := tr.CloseAndSave()
if err != nil { if err != nil {
log.Errorf("failed to write recovery file: %v", err) log.Errorf("failed to write recovery file: %v", err)
} }
@ -128,7 +132,7 @@ func (s *Service) CreateSnapshot(params SnapshotParams) error {
return s.indexer.PushIPLD(tx, c) return s.indexer.PushIPLD(tx, c)
} }
// Build a diff compared against the zero hash to get a full snapshot // Build a diff against the zero hash (empty trie) to get a full snapshot
sdargs := statediff.Args{ sdargs := statediff.Args{
NewStateRoot: header.Root, NewStateRoot: header.Root,
BlockHash: header.Hash(), BlockHash: header.Hash(),
@ -140,7 +144,7 @@ func (s *Service) CreateSnapshot(params SnapshotParams) error {
sdparams.ComputeWatchedAddressesLeafPaths() sdparams.ComputeWatchedAddressesLeafPaths()
builder := statediff.NewBuilder(adapt.GethStateView(s.stateDB)) builder := statediff.NewBuilder(adapt.GethStateView(s.stateDB))
builder.SetSubtrieWorkers(params.Workers) builder.SetSubtrieWorkers(params.Workers)
if err = builder.WriteStateDiffTracked(sdargs, sdparams, nodeSink, ipldSink, &tracker); err != nil { if err = builder.WriteStateDiffTracked(sdargs, sdparams, nodeSink, ipldSink, tr); err != nil {
return err return err
} }
@ -160,3 +164,14 @@ func (s *Service) CreateLatestSnapshot(workers uint, watchedAddresses []common.A
} }
return s.CreateSnapshot(SnapshotParams{Height: *height, Workers: workers, WatchedAddresses: watchedAddresses}) return s.CreateSnapshot(SnapshotParams{Height: *height, Workers: workers, WatchedAddresses: watchedAddresses})
} }
func captureSignal(cb func()) {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigChan
log.Errorf("Signal received (%v), stopping", sig)
cb()
}()
}