From 355ad83b88e534262087f27e3f53dfc98ed28f9f Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Wed, 19 Jan 2022 10:33:21 +0530 Subject: [PATCH] Rollback db transactions on an error and other fixes --- statediff/helpers.go | 20 ++++++--------- statediff/indexer/indexer.go | 35 +++++++++++++++++++++++++- statediff/service.go | 30 +++++++++------------- statediff/testhelpers/mocks/service.go | 2 +- statediff/types.go | 2 +- 5 files changed, 56 insertions(+), 33 deletions(-) diff --git a/statediff/helpers.go b/statediff/helpers.go index f97f34062..1f8523fc3 100644 --- a/statediff/helpers.go +++ b/statediff/helpers.go @@ -78,25 +78,21 @@ func findIntersection(a, b []string) []string { // loadWatchedAddresses is used to load watched addresses to the in-memory write loop params from the db func loadWatchedAddresses(db *postgres.DB) error { - rows, err := db.Query("SELECT address FROM eth.watched_addresses") + var watchedAddressStrings []string + pgStr := "SELECT address FROM eth.watched_addresses" + err := db.Select(&watchedAddressStrings, pgStr) if err != nil { return fmt.Errorf("error loading watched addresses: %v", err) } var watchedAddresses []common.Address - for rows.Next() { - var addressHex string - err := rows.Scan(&addressHex) - if err != nil { - return err - } - - watchedAddresses = append(watchedAddresses, common.HexToAddress(addressHex)) + for _, watchedAddressString := range watchedAddressStrings { + watchedAddresses = append(watchedAddresses, common.HexToAddress(watchedAddressString)) } - writeLoopParams.mu.Lock() + writeLoopParams.Lock() + defer writeLoopParams.Unlock() writeLoopParams.WatchedAddresses = watchedAddresses - writeLoopParams.mu.Unlock() return nil } @@ -126,7 +122,7 @@ func containsAddress(addresses []common.Address, address common.Address) int { } // getArgAddresses is used to get the list of addresses from a list of WatchAddressArgs -func getArgAddresses(args []types.WatchAddressArg) []common.Address { +func getAddresses(args []types.WatchAddressArg) []common.Address { addresses := make([]common.Address, len(args)) for idx, arg := range args { addresses[idx] = arg.Address diff --git a/statediff/indexer/indexer.go b/statediff/indexer/indexer.go index cda12e4bb..c85d00b2e 100644 --- a/statediff/indexer/indexer.go +++ b/statediff/indexer/indexer.go @@ -59,8 +59,11 @@ type Indexer interface { PushStateNode(tx *BlockTx, stateNode sdtypes.StateNode) error PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sdtypes.CodeAndCodeHash) error ReportDBMetrics(delay time.Duration, quit <-chan bool) + + // Methods used by WatchAddress API/functionality. InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error RemoveWatchedAddresses(addresses []common.Address) error + SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error ClearWatchedAddresses() error } @@ -559,9 +562,10 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressA if err != nil { return err } + defer tx.Rollback() for _, arg := range args { - _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at)VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, + _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, arg.Address.Hex(), arg.CreatedAt, currentBlockNumber.Uint64()) if err != nil { return fmt.Errorf("error inserting watched_addresses entry: %v", err) @@ -582,6 +586,7 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) if err != nil { return err } + defer tx.Rollback() for _, address := range addresses { _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, address.Hex()) @@ -598,6 +603,34 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) return nil } +func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { + tx, err := sdi.dbWriter.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.Exec(`DELETE FROM eth.watched_addresses`) + if err != nil { + return fmt.Errorf("error setting watched_addresses table: %v", err) + } + + for _, arg := range args { + _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, + arg.Address.Hex(), arg.CreatedAt, currentBlockNumber.Uint64()) + if err != nil { + return fmt.Errorf("error setting watched_addresses table: %v", err) + } + } + + err = tx.Commit() + if err != nil { + return err + } + + return nil +} + // ClearWatchedAddresses clears all the addresses from the database func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { _, err := sdi.dbWriter.db.Exec(`DELETE FROM eth.watched_addresses`) diff --git a/statediff/service.go b/statediff/service.go index cc0059f00..17727f14a 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -70,13 +70,13 @@ var statediffMetrics = RegisterStatediffMetrics(metrics.DefaultRegistry) type blockChain interface { SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription + CurrentBlock() *types.Block GetBlockByHash(hash common.Hash) *types.Block GetBlockByNumber(number uint64) *types.Block GetReceiptsByHash(hash common.Hash) types.Receipts GetTd(hash common.Hash, number uint64) *big.Int UnlockTrie(root common.Hash) StateCache() state.Database - CurrentBlock() *types.Block } // IService is the state-diffing service interface @@ -108,7 +108,7 @@ type IService interface { // Method to change the addresses being watched in write loop params WatchAddress(operation OperationType, args []WatchAddressArg) error // Method to get currently watched addresses from write loop params - GetWathchedAddresses() []common.Address + GetWatchedAddresses() []common.Address } // Wraps consructor parameters @@ -294,9 +294,9 @@ func (sds *Service) WriteLoop(chainEventCh chan core.ChainEvent) { func (sds *Service) writeGenesisStateDiff(currBlock *types.Block, workerId uint) { // For genesis block we need to return the entire state trie hence we diff it with an empty trie. log.Info("Writing state diff", "block height", genesisBlockNumber, "worker", workerId) - writeLoopParams.mu.RLock() + writeLoopParams.RLock() err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams.Params) - writeLoopParams.mu.RUnlock() + writeLoopParams.RUnlock() if err != nil { log.Error("statediff.Service.WriteLoop: processing error", "block height", genesisBlockNumber, "error", err.Error(), "worker", workerId) @@ -325,9 +325,9 @@ func (sds *Service) writeLoopWorker(params workerParams) { } log.Info("Writing state diff", "block height", currentBlock.Number().Uint64(), "worker", params.id) - writeLoopParams.mu.RLock() + writeLoopParams.RLock() err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams.Params) - writeLoopParams.mu.RUnlock() + writeLoopParams.RUnlock() if err != nil { log.Error("statediff.Service.WriteLoop: processing error", "block height", currentBlock.Number().Uint64(), "error", err.Error(), "worker", params.id) continue @@ -737,8 +737,8 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo // Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error { // lock writeLoopParams for a write - writeLoopParams.mu.Lock() - defer writeLoopParams.mu.Unlock() + writeLoopParams.Lock() + defer writeLoopParams.Unlock() // get the current block number currentBlockNumber := sds.BlockChain.CurrentBlock().Number() @@ -750,7 +750,6 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg // Check if address is already being watched // Throw a warning and continue if found if containsAddress(writeLoopParams.WatchedAddresses, arg.Address) != -1 { - // log.Warn(fmt.Sprint("Address ", address.Hex(), " already being watched")) log.Warn("Address already being watched", "address", arg.Address.Hex()) addressesToRemove = append(addressesToRemove, arg.Address) continue @@ -767,7 +766,7 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...) case Remove: - addresses := getArgAddresses(args) + addresses := getAddresses(args) err := sds.indexer.RemoveWatchedAddresses(addresses) if err != nil { @@ -776,17 +775,12 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg writeLoopParams.WatchedAddresses = removeAddresses(writeLoopParams.WatchedAddresses, addresses) case Set: - err := sds.indexer.ClearWatchedAddresses() + err := sds.indexer.SetWatchedAddresses(args, currentBlockNumber) if err != nil { return err } - err = sds.indexer.InsertWatchedAddresses(args, currentBlockNumber) - if err != nil { - return err - } - - addresses := getArgAddresses(args) + addresses := getAddresses(args) writeLoopParams.WatchedAddresses = addresses case Clear: err := sds.indexer.ClearWatchedAddresses() @@ -803,6 +797,6 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg } // Gets currently watched addresses from the in-memory write loop params -func (sds *Service) GetWathchedAddresses() []common.Address { +func (sds *Service) GetWatchedAddresses() []common.Address { return writeLoopParams.WatchedAddresses } diff --git a/statediff/testhelpers/mocks/service.go b/statediff/testhelpers/mocks/service.go index e2f4d3cb9..686f0c2ba 100644 --- a/statediff/testhelpers/mocks/service.go +++ b/statediff/testhelpers/mocks/service.go @@ -337,6 +337,6 @@ func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, return nil } -func (sds *MockStateDiffService) GetWathchedAddresses() []common.Address { +func (sds *MockStateDiffService) GetWatchedAddresses() []common.Address { return []common.Address{} } diff --git a/statediff/types.go b/statediff/types.go index d17d11067..b179f1a00 100644 --- a/statediff/types.go +++ b/statediff/types.go @@ -57,7 +57,7 @@ type Params struct { // ParamsWithMutex allows to lock the parameters while they are being updated | read from type ParamsWithMutex struct { Params - mu sync.RWMutex + sync.RWMutex } // Args bundles the arguments for the state diff builder