diff --git a/statediff/indexer/database/sql/batch_tx.go b/statediff/indexer/database/sql/batch_tx.go
index 06bb49c9e..f5cbdeadc 100644
--- a/statediff/indexer/database/sql/batch_tx.go
+++ b/statediff/indexer/database/sql/batch_tx.go
@@ -18,6 +18,7 @@ package sql
 
 import (
 	"context"
+	"sync"
 	"sync/atomic"
 
 	blockstore "github.com/ipfs/go-ipfs-blockstore"
@@ -42,6 +43,8 @@ type BatchTx struct {
 	iplds            chan models.IPLDModel
 	ipldCache        models.IPLDBatch
 	removedCacheFlag *uint32
+	// Tracks expected cache size and ensures cache is caught up before flush
+	cacheWg sync.WaitGroup
 
 	submit func(blockTx *BatchTx, err error) error
 }
@@ -52,6 +55,7 @@ func (tx *BatchTx) Submit(err error) error {
 }
 
 func (tx *BatchTx) flush() error {
+	tx.cacheWg.Wait()
 	_, err := tx.dbtx.Exec(tx.ctx, tx.stm, pq.Array(tx.ipldCache.BlockNumbers), pq.Array(tx.ipldCache.Keys),
 		pq.Array(tx.ipldCache.Values))
 	if err != nil {
@@ -69,6 +73,7 @@ func (tx *BatchTx) cache() {
 			tx.ipldCache.BlockNumbers = append(tx.ipldCache.BlockNumbers, i.BlockNumber)
 			tx.ipldCache.Keys = append(tx.ipldCache.Keys, i.Key)
 			tx.ipldCache.Values = append(tx.ipldCache.Values, i.Data)
+			tx.cacheWg.Done()
 		case <-tx.quit:
 			tx.ipldCache = models.IPLDBatch{}
 			return
@@ -77,6 +82,7 @@
 }
 
 func (tx *BatchTx) cacheDirect(key string, value []byte) {
+	tx.cacheWg.Add(1)
 	tx.iplds <- models.IPLDModel{
 		BlockNumber: tx.BlockNumber,
 		Key:         key,
@@ -85,6 +91,7 @@
 }
 
 func (tx *BatchTx) cacheIPLD(i node.Node) {
+	tx.cacheWg.Add(1)
 	tx.iplds <- models.IPLDModel{
 		BlockNumber: tx.BlockNumber,
 		Key:         blockstore.BlockPrefix.String() + dshelp.MultihashToDsKey(i.Cid().Hash()).String(),
@@ -98,6 +105,7 @@ func (tx *BatchTx) cacheRaw(codec, mh uint64, raw []byte) (string, string, error
 		return "", "", err
 	}
 	prefixedKey := blockstore.BlockPrefix.String() + dshelp.MultihashToDsKey(c.Hash()).String()
+	tx.cacheWg.Add(1)
 	tx.iplds <- models.IPLDModel{
 		BlockNumber: tx.BlockNumber,
 		Key:         prefixedKey,
@@ -109,6 +117,7 @@ func (tx *BatchTx) cacheRaw(codec, mh uint64, raw []byte) (string, string, error
 func (tx *BatchTx) cacheRemoved(key string, value []byte) {
 	if atomic.LoadUint32(tx.removedCacheFlag) == 0 {
 		atomic.StoreUint32(tx.removedCacheFlag, 1)
+		tx.cacheWg.Add(1)
 		tx.iplds <- models.IPLDModel{
 			BlockNumber: tx.BlockNumber,
 			Key:         key,
diff --git a/statediff/indexer/database/sql/indexer.go b/statediff/indexer/database/sql/indexer.go
index 762107ee5..9e23405a0 100644
--- a/statediff/indexer/database/sql/indexer.go
+++ b/statediff/indexer/database/sql/indexer.go
@@ -122,11 +122,8 @@ func (sdi *StateDiffIndexer) PushBlock(block *types.Block, receipts types.Receip
 	}
 	t = time.Now()
 
-	// Begin new db tx for everything
-	tx, err := sdi.dbWriter.db.Begin(sdi.ctx)
-	if err != nil {
-		return nil, err
-	}
+	// Begin new DB tx for everything
+	tx := NewDelayedTx(sdi.dbWriter.db)
 	defer func() {
 		if p := recover(); p != nil {
 			rollback(sdi.ctx, tx)
@@ -589,11 +586,8 @@ func (sdi *StateDiffIndexer) LoadWatchedAddresses() ([]common.Address, error) {
 }
 
 // InsertWatchedAddresses inserts the given addresses in the database
-func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error {
-	tx, err := sdi.dbWriter.db.Begin(sdi.ctx)
-	if err != nil {
-		return err
-	}
+func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) (err error) {
+	tx := NewDelayedTx(sdi.dbWriter.db)
 	defer func() {
 		if p := recover(); p != nil {
 			rollback(sdi.ctx, tx)
@@ -617,11 +611,8 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressA
 }
 
 // RemoveWatchedAddresses removes the given watched addresses from the database
-func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressArg) error {
-	tx, err := sdi.dbWriter.db.Begin(sdi.ctx)
-	if err != nil {
-		return err
-	}
+func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressArg) (err error) {
+	tx := NewDelayedTx(sdi.dbWriter.db)
 	defer func() {
 		if p := recover(); p != nil {
 			rollback(sdi.ctx, tx)
@@ -644,11 +635,8 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressA
 }
 
 // SetWatchedAddresses clears and inserts the given addresses in the database
-func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error {
-	tx, err := sdi.dbWriter.db.Begin(sdi.ctx)
-	if err != nil {
-		return err
-	}
+func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) (err error) {
+	tx := NewDelayedTx(sdi.dbWriter.db)
 	defer func() {
 		if p := recover(); p != nil {
 			rollback(sdi.ctx, tx)
diff --git a/statediff/indexer/database/sql/lazy_tx.go b/statediff/indexer/database/sql/lazy_tx.go
new file mode 100644
index 000000000..922bf84a0
--- /dev/null
+++ b/statediff/indexer/database/sql/lazy_tx.go
@@ -0,0 +1,55 @@
+package sql
+
+import (
+	"context"
+)
+
+type DelayedTx struct {
+	cache []cachedStmt
+	db    Database
+}
+type cachedStmt struct {
+	sql  string
+	args []interface{}
+}
+
+func NewDelayedTx(db Database) *DelayedTx {
+	return &DelayedTx{db: db}
+}
+
+func (tx *DelayedTx) QueryRow(ctx context.Context, sql string, args ...interface{}) ScannableRow {
+	return tx.db.QueryRow(ctx, sql, args...)
+}
+
+func (tx *DelayedTx) Exec(ctx context.Context, sql string, args ...interface{}) (Result, error) {
+	tx.cache = append(tx.cache, cachedStmt{sql, args})
+	return nil, nil
+}
+
+func (tx *DelayedTx) Commit(ctx context.Context) error {
+	base, err := tx.db.Begin(ctx)
+	if err != nil {
+		return err
+	}
+	defer func() {
+		if p := recover(); p != nil {
+			rollback(ctx, base)
+			panic(p)
+		} else if err != nil {
+			rollback(ctx, base)
+		}
+	}()
+	for _, stmt := range tx.cache {
+		_, err := base.Exec(ctx, stmt.sql, stmt.args...)
+		if err != nil {
+			return err
+		}
+	}
+	tx.cache = nil
+	return base.Commit(ctx)
+}
+
+func (tx *DelayedTx) Rollback(ctx context.Context) error {
+	tx.cache = nil
+	return nil
+}
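Note on the batch_tx.go changes above: the new cacheWg WaitGroup keeps flush() from running before the cache goroutine has drained every IPLD that was sent on the channel. Below is a minimal, self-contained sketch of that pattern; the batcher type and item names are illustrative stand-ins, not part of this package.

package main

import (
	"fmt"
	"sync"
)

// batcher mirrors the BatchTx pattern: producers announce an item with
// wg.Add(1) before sending it on the channel, the single consumer goroutine
// calls wg.Done() once the item has landed in the cache, and flush() calls
// wg.Wait() so it never reads a partially populated cache.
type batcher struct {
	items chan string
	cache []string
	wg    sync.WaitGroup
	quit  chan struct{}
}

func newBatcher() *batcher {
	b := &batcher{items: make(chan string), quit: make(chan struct{})}
	go b.run()
	return b
}

func (b *batcher) run() {
	for {
		select {
		case it := <-b.items:
			b.cache = append(b.cache, it)
			b.wg.Done() // item is now visible to flush
		case <-b.quit:
			return
		}
	}
}

func (b *batcher) add(it string) {
	b.wg.Add(1) // announce the item before handing it to the consumer
	b.items <- it
}

func (b *batcher) flush() []string {
	b.wg.Wait() // block until every announced item has been cached
	return b.cache
}

func main() {
	b := newBatcher()
	for i := 0; i < 5; i++ {
		b.add(fmt.Sprintf("item-%d", i))
	}
	fmt.Println(b.flush()) // all five items, no race with the consumer
	close(b.quit)
}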
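Note on the new DelayedTx: statements passed to Exec are only buffered in memory; the real database transaction is opened, replayed, and committed inside Commit, with a rollback on error or panic. A rough usage sketch under that reading follows, written as if inside this package; the helper name, table, and values are hypothetical, not taken from the indexer.

package sql

import "context"

// exampleDelayedWrite is a hypothetical helper (not part of this diff) showing
// how the indexer methods above now drive a DelayedTx: Exec only queues the
// statement, and Commit begins the underlying transaction, replays the queue,
// and commits it.
func exampleDelayedWrite(ctx context.Context, db Database) error {
	tx := NewDelayedTx(db)

	// Queued only; nothing reaches the database yet. The table name is a placeholder.
	if _, err := tx.Exec(ctx, `INSERT INTO example_table (key, data) VALUES ($1, $2)`,
		"some-key", []byte("some-data")); err != nil {
		return err
	}

	// All queued statements execute inside a single real transaction here.
	return tx.Commit(ctx)
}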