Rollback db transactions on an error and other fixes

This commit is contained in:
Prathamesh Musale 2022-01-19 10:33:21 +05:30
parent d7704d2f98
commit 355ad83b88
5 changed files with 56 additions and 33 deletions

View File

@ -78,25 +78,21 @@ func findIntersection(a, b []string) []string {
// loadWatchedAddresses is used to load watched addresses to the in-memory write loop params from the db // loadWatchedAddresses is used to load watched addresses to the in-memory write loop params from the db
func loadWatchedAddresses(db *postgres.DB) error { func loadWatchedAddresses(db *postgres.DB) error {
rows, err := db.Query("SELECT address FROM eth.watched_addresses") var watchedAddressStrings []string
pgStr := "SELECT address FROM eth.watched_addresses"
err := db.Select(&watchedAddressStrings, pgStr)
if err != nil { if err != nil {
return fmt.Errorf("error loading watched addresses: %v", err) return fmt.Errorf("error loading watched addresses: %v", err)
} }
var watchedAddresses []common.Address var watchedAddresses []common.Address
for rows.Next() { for _, watchedAddressString := range watchedAddressStrings {
var addressHex string watchedAddresses = append(watchedAddresses, common.HexToAddress(watchedAddressString))
err := rows.Scan(&addressHex)
if err != nil {
return err
} }
watchedAddresses = append(watchedAddresses, common.HexToAddress(addressHex)) writeLoopParams.Lock()
} defer writeLoopParams.Unlock()
writeLoopParams.mu.Lock()
writeLoopParams.WatchedAddresses = watchedAddresses writeLoopParams.WatchedAddresses = watchedAddresses
writeLoopParams.mu.Unlock()
return nil return nil
} }
@ -126,7 +122,7 @@ func containsAddress(addresses []common.Address, address common.Address) int {
} }
// getArgAddresses is used to get the list of addresses from a list of WatchAddressArgs // getArgAddresses is used to get the list of addresses from a list of WatchAddressArgs
func getArgAddresses(args []types.WatchAddressArg) []common.Address { func getAddresses(args []types.WatchAddressArg) []common.Address {
addresses := make([]common.Address, len(args)) addresses := make([]common.Address, len(args))
for idx, arg := range args { for idx, arg := range args {
addresses[idx] = arg.Address addresses[idx] = arg.Address

View File

@ -59,8 +59,11 @@ type Indexer interface {
PushStateNode(tx *BlockTx, stateNode sdtypes.StateNode) error PushStateNode(tx *BlockTx, stateNode sdtypes.StateNode) error
PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sdtypes.CodeAndCodeHash) error PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sdtypes.CodeAndCodeHash) error
ReportDBMetrics(delay time.Duration, quit <-chan bool) ReportDBMetrics(delay time.Duration, quit <-chan bool)
// Methods used by WatchAddress API/functionality.
InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error
RemoveWatchedAddresses(addresses []common.Address) error RemoveWatchedAddresses(addresses []common.Address) error
SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error
ClearWatchedAddresses() error ClearWatchedAddresses() error
} }
@ -559,6 +562,7 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressA
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback()
for _, arg := range args { for _, arg := range args {
_, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`,
@ -582,6 +586,7 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address)
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback()
for _, address := range addresses { for _, address := range addresses {
_, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, address.Hex()) _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, address.Hex())
@ -598,6 +603,34 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address)
return nil return nil
} }
func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error {
tx, err := sdi.dbWriter.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
_, err = tx.Exec(`DELETE FROM eth.watched_addresses`)
if err != nil {
return fmt.Errorf("error setting watched_addresses table: %v", err)
}
for _, arg := range args {
_, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`,
arg.Address.Hex(), arg.CreatedAt, currentBlockNumber.Uint64())
if err != nil {
return fmt.Errorf("error setting watched_addresses table: %v", err)
}
}
err = tx.Commit()
if err != nil {
return err
}
return nil
}
// ClearWatchedAddresses clears all the addresses from the database // ClearWatchedAddresses clears all the addresses from the database
func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { func (sdi *StateDiffIndexer) ClearWatchedAddresses() error {
_, err := sdi.dbWriter.db.Exec(`DELETE FROM eth.watched_addresses`) _, err := sdi.dbWriter.db.Exec(`DELETE FROM eth.watched_addresses`)

View File

@ -70,13 +70,13 @@ var statediffMetrics = RegisterStatediffMetrics(metrics.DefaultRegistry)
type blockChain interface { type blockChain interface {
SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription
CurrentBlock() *types.Block
GetBlockByHash(hash common.Hash) *types.Block GetBlockByHash(hash common.Hash) *types.Block
GetBlockByNumber(number uint64) *types.Block GetBlockByNumber(number uint64) *types.Block
GetReceiptsByHash(hash common.Hash) types.Receipts GetReceiptsByHash(hash common.Hash) types.Receipts
GetTd(hash common.Hash, number uint64) *big.Int GetTd(hash common.Hash, number uint64) *big.Int
UnlockTrie(root common.Hash) UnlockTrie(root common.Hash)
StateCache() state.Database StateCache() state.Database
CurrentBlock() *types.Block
} }
// IService is the state-diffing service interface // IService is the state-diffing service interface
@ -108,7 +108,7 @@ type IService interface {
// Method to change the addresses being watched in write loop params // Method to change the addresses being watched in write loop params
WatchAddress(operation OperationType, args []WatchAddressArg) error WatchAddress(operation OperationType, args []WatchAddressArg) error
// Method to get currently watched addresses from write loop params // Method to get currently watched addresses from write loop params
GetWathchedAddresses() []common.Address GetWatchedAddresses() []common.Address
} }
// Wraps consructor parameters // Wraps consructor parameters
@ -294,9 +294,9 @@ func (sds *Service) WriteLoop(chainEventCh chan core.ChainEvent) {
func (sds *Service) writeGenesisStateDiff(currBlock *types.Block, workerId uint) { func (sds *Service) writeGenesisStateDiff(currBlock *types.Block, workerId uint) {
// For genesis block we need to return the entire state trie hence we diff it with an empty trie. // For genesis block we need to return the entire state trie hence we diff it with an empty trie.
log.Info("Writing state diff", "block height", genesisBlockNumber, "worker", workerId) log.Info("Writing state diff", "block height", genesisBlockNumber, "worker", workerId)
writeLoopParams.mu.RLock() writeLoopParams.RLock()
err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams.Params) err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams.Params)
writeLoopParams.mu.RUnlock() writeLoopParams.RUnlock()
if err != nil { if err != nil {
log.Error("statediff.Service.WriteLoop: processing error", "block height", log.Error("statediff.Service.WriteLoop: processing error", "block height",
genesisBlockNumber, "error", err.Error(), "worker", workerId) genesisBlockNumber, "error", err.Error(), "worker", workerId)
@ -325,9 +325,9 @@ func (sds *Service) writeLoopWorker(params workerParams) {
} }
log.Info("Writing state diff", "block height", currentBlock.Number().Uint64(), "worker", params.id) log.Info("Writing state diff", "block height", currentBlock.Number().Uint64(), "worker", params.id)
writeLoopParams.mu.RLock() writeLoopParams.RLock()
err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams.Params) err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams.Params)
writeLoopParams.mu.RUnlock() writeLoopParams.RUnlock()
if err != nil { if err != nil {
log.Error("statediff.Service.WriteLoop: processing error", "block height", currentBlock.Number().Uint64(), "error", err.Error(), "worker", params.id) log.Error("statediff.Service.WriteLoop: processing error", "block height", currentBlock.Number().Uint64(), "error", err.Error(), "worker", params.id)
continue continue
@ -737,8 +737,8 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo
// Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses // Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses
func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error { func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error {
// lock writeLoopParams for a write // lock writeLoopParams for a write
writeLoopParams.mu.Lock() writeLoopParams.Lock()
defer writeLoopParams.mu.Unlock() defer writeLoopParams.Unlock()
// get the current block number // get the current block number
currentBlockNumber := sds.BlockChain.CurrentBlock().Number() currentBlockNumber := sds.BlockChain.CurrentBlock().Number()
@ -750,7 +750,6 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg
// Check if address is already being watched // Check if address is already being watched
// Throw a warning and continue if found // Throw a warning and continue if found
if containsAddress(writeLoopParams.WatchedAddresses, arg.Address) != -1 { if containsAddress(writeLoopParams.WatchedAddresses, arg.Address) != -1 {
// log.Warn(fmt.Sprint("Address ", address.Hex(), " already being watched"))
log.Warn("Address already being watched", "address", arg.Address.Hex()) log.Warn("Address already being watched", "address", arg.Address.Hex())
addressesToRemove = append(addressesToRemove, arg.Address) addressesToRemove = append(addressesToRemove, arg.Address)
continue continue
@ -767,7 +766,7 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg
writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...) writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...)
case Remove: case Remove:
addresses := getArgAddresses(args) addresses := getAddresses(args)
err := sds.indexer.RemoveWatchedAddresses(addresses) err := sds.indexer.RemoveWatchedAddresses(addresses)
if err != nil { if err != nil {
@ -776,17 +775,12 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg
writeLoopParams.WatchedAddresses = removeAddresses(writeLoopParams.WatchedAddresses, addresses) writeLoopParams.WatchedAddresses = removeAddresses(writeLoopParams.WatchedAddresses, addresses)
case Set: case Set:
err := sds.indexer.ClearWatchedAddresses() err := sds.indexer.SetWatchedAddresses(args, currentBlockNumber)
if err != nil { if err != nil {
return err return err
} }
err = sds.indexer.InsertWatchedAddresses(args, currentBlockNumber) addresses := getAddresses(args)
if err != nil {
return err
}
addresses := getArgAddresses(args)
writeLoopParams.WatchedAddresses = addresses writeLoopParams.WatchedAddresses = addresses
case Clear: case Clear:
err := sds.indexer.ClearWatchedAddresses() err := sds.indexer.ClearWatchedAddresses()
@ -803,6 +797,6 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg
} }
// Gets currently watched addresses from the in-memory write loop params // Gets currently watched addresses from the in-memory write loop params
func (sds *Service) GetWathchedAddresses() []common.Address { func (sds *Service) GetWatchedAddresses() []common.Address {
return writeLoopParams.WatchedAddresses return writeLoopParams.WatchedAddresses
} }

View File

@ -337,6 +337,6 @@ func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType,
return nil return nil
} }
func (sds *MockStateDiffService) GetWathchedAddresses() []common.Address { func (sds *MockStateDiffService) GetWatchedAddresses() []common.Address {
return []common.Address{} return []common.Address{}
} }

View File

@ -57,7 +57,7 @@ type Params struct {
// ParamsWithMutex allows to lock the parameters while they are being updated | read from // ParamsWithMutex allows to lock the parameters while they are being updated | read from
type ParamsWithMutex struct { type ParamsWithMutex struct {
Params Params
mu sync.RWMutex sync.RWMutex
} }
// Args bundles the arguments for the state diff builder // Args bundles the arguments for the state diff builder