Fix rollback call on error.

This commit is contained in:
Thomas E Lackey 2023-05-31 13:47:34 -05:00
parent 71f764bafa
commit 49abb1a2fa
4 changed files with 25 additions and 24 deletions

View File

@ -116,14 +116,14 @@ func (sdi *StateDiffIndexer) PushBlock(block *types.Block, receipts types.Receip
// Begin new DB tx for everything // Begin new DB tx for everything
tx := NewDelayedTx(sdi.dbWriter.db) tx := NewDelayedTx(sdi.dbWriter.db)
defer func() { defer func(e *error) {
if p := recover(); p != nil { if p := recover(); p != nil {
rollback(sdi.ctx, tx) rollback(sdi.ctx, tx)
panic(p) panic(p)
} else if err != nil { } else if e != nil && *e != nil {
rollback(sdi.ctx, tx) rollback(sdi.ctx, tx)
} }
}() }(&err)
blockTx := &BatchTx{ blockTx := &BatchTx{
removedCacheFlag: new(uint32), removedCacheFlag: new(uint32),
ctx: sdi.ctx, ctx: sdi.ctx,
@ -496,16 +496,16 @@ func (sdi *StateDiffIndexer) LoadWatchedAddresses() ([]common.Address, error) {
// InsertWatchedAddresses inserts the given addresses in the database // InsertWatchedAddresses inserts the given addresses in the database
func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) (err error) { func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) (err error) {
tx := NewDelayedTx(sdi.dbWriter.db) tx := NewDelayedTx(sdi.dbWriter.db)
defer func() { defer func(e *error) {
if p := recover(); p != nil { if p := recover(); p != nil {
rollback(sdi.ctx, tx) rollback(sdi.ctx, tx)
panic(p) panic(p)
} else if err != nil { } else if e != nil && *e != nil {
rollback(sdi.ctx, tx) rollback(sdi.ctx, tx)
} else { } else {
err = tx.Commit(sdi.ctx) err = tx.Commit(sdi.ctx)
} }
}() }(&err)
for _, arg := range args { for _, arg := range args {
_, err = tx.Exec(sdi.ctx, `INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, _, err = tx.Exec(sdi.ctx, `INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`,
@ -521,16 +521,16 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressA
// RemoveWatchedAddresses removes the given watched addresses from the database // RemoveWatchedAddresses removes the given watched addresses from the database
func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressArg) (err error) { func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressArg) (err error) {
tx := NewDelayedTx(sdi.dbWriter.db) tx := NewDelayedTx(sdi.dbWriter.db)
defer func() { defer func(e *error) {
if p := recover(); p != nil { if p := recover(); p != nil {
rollback(sdi.ctx, tx) rollback(sdi.ctx, tx)
panic(p) panic(p)
} else if err != nil { } else if e != nil && *e != nil {
rollback(sdi.ctx, tx) rollback(sdi.ctx, tx)
} else { } else {
err = tx.Commit(sdi.ctx) err = tx.Commit(sdi.ctx)
} }
}() }(&err)
for _, arg := range args { for _, arg := range args {
_, err = tx.Exec(sdi.ctx, `DELETE FROM eth_meta.watched_addresses WHERE address = $1`, arg.Address) _, err = tx.Exec(sdi.ctx, `DELETE FROM eth_meta.watched_addresses WHERE address = $1`, arg.Address)
@ -545,16 +545,16 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressA
// SetWatchedAddresses clears and inserts the given addresses in the database // SetWatchedAddresses clears and inserts the given addresses in the database
func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) (err error) { func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) (err error) {
tx := NewDelayedTx(sdi.dbWriter.db) tx := NewDelayedTx(sdi.dbWriter.db)
defer func() { defer func(e *error) {
if p := recover(); p != nil { if p := recover(); p != nil {
rollback(sdi.ctx, tx) rollback(sdi.ctx, tx)
panic(p) panic(p)
} else if err != nil { } else if e != nil && *e != nil {
rollback(sdi.ctx, tx) rollback(sdi.ctx, tx)
} else { } else {
err = tx.Commit(sdi.ctx) err = tx.Commit(sdi.ctx)
} }
}() }(&err)
_, err = tx.Exec(sdi.ctx, `DELETE FROM eth_meta.watched_addresses`) _, err = tx.Exec(sdi.ctx, `DELETE FROM eth_meta.watched_addresses`)
if err != nil { if err != nil {

View File

@ -73,24 +73,24 @@ func (tx *DelayedTx) Commit(ctx context.Context) error {
if err != nil { if err != nil {
return err return err
} }
defer func() { defer func(e *error) {
if p := recover(); p != nil { if p := recover(); p != nil {
rollback(ctx, base) rollback(ctx, base)
panic(p) panic(p)
} else if err != nil { } else if e != nil && *e != nil {
rollback(ctx, base) rollback(ctx, base)
} }
}() }(&err)
for _, item := range tx.cache { for _, item := range tx.cache {
switch item := item.(type) { switch item := item.(type) {
case *copyFrom: case *copyFrom:
_, err := base.CopyFrom(ctx, item.tableName, item.columnNames, item.rows) _, err = base.CopyFrom(ctx, item.tableName, item.columnNames, item.rows)
if err != nil { if err != nil {
log.Error("COPY error", "table", item.tableName, "err", err) log.Error("COPY error", "table", item.tableName, "err", err)
return err return err
} }
case cachedStmt: case cachedStmt:
_, err := base.Exec(ctx, item.sql, item.args...) _, err = base.Exec(ctx, item.sql, item.args...)
if err != nil { if err != nil {
return err return err
} }

View File

@ -39,18 +39,19 @@ func countStateDiffBegin(block *types.Block) (time.Time, log.Logger) {
return start, logger return start, logger
} }
func countStateDiffEnd(start time.Time, logger log.Logger, err error) time.Duration { func countStateDiffEnd(start time.Time, logger log.Logger, err *error) time.Duration {
duration := time.Since(start) duration := time.Since(start)
defaultStatediffMetrics.underway.Dec(1) defaultStatediffMetrics.underway.Dec(1)
if nil == err { failed := nil != err && nil != *err
defaultStatediffMetrics.succeeded.Inc(1) if failed {
} else {
defaultStatediffMetrics.failed.Inc(1) defaultStatediffMetrics.failed.Inc(1)
} else {
defaultStatediffMetrics.succeeded.Inc(1)
} }
defaultStatediffMetrics.totalProcessingTime.Inc(duration.Milliseconds()) defaultStatediffMetrics.totalProcessingTime.Inc(duration.Milliseconds())
logger.Debug(fmt.Sprintf("writeStateDiff END (duration=%dms, err=%t) [underway=%d, succeeded=%d, failed=%d, total_time=%dms]", logger.Debug(fmt.Sprintf("writeStateDiff END (duration=%dms, err=%t) [underway=%d, succeeded=%d, failed=%d, total_time=%dms]",
duration.Milliseconds(), nil != err, duration.Milliseconds(), failed,
defaultStatediffMetrics.underway.Count(), defaultStatediffMetrics.underway.Count(),
defaultStatediffMetrics.succeeded.Count(), defaultStatediffMetrics.succeeded.Count(),
defaultStatediffMetrics.failed.Count(), defaultStatediffMetrics.failed.Count(),

View File

@ -815,7 +815,7 @@ func (sds *Service) writeStateDiff(block *types.Block, parentRoot common.Hash, p
var err error var err error
var tx interfaces.Batch var tx interfaces.Batch
start, logger := countStateDiffBegin(block) start, logger := countStateDiffBegin(block)
defer countStateDiffEnd(start, logger, err) defer countStateDiffEnd(start, logger, &err)
if params.IncludeTD { if params.IncludeTD {
totalDifficulty = sds.BlockChain.GetTd(block.Hash(), block.NumberU64()) totalDifficulty = sds.BlockChain.GetTd(block.Hash(), block.NumberU64())
} }
@ -847,7 +847,7 @@ func (sds *Service) writeStateDiff(block *types.Block, parentRoot common.Hash, p
BlockNumber: block.Number(), BlockNumber: block.Number(),
}, params, output, ipldOutput) }, params, output, ipldOutput)
// TODO this anti-pattern needs to be sorted out eventually // TODO this anti-pattern needs to be sorted out eventually
if err := tx.Submit(err); err != nil { if err = tx.Submit(err); err != nil {
return fmt.Errorf("batch transaction submission failed: %w", err) return fmt.Errorf("batch transaction submission failed: %w", err)
} }