Use current block and mutex while changing watched addresses

This commit is contained in:
Prathamesh Musale 2022-01-17 13:50:52 +05:30
parent 5fa002c0d0
commit bbb1759886
6 changed files with 73 additions and 43 deletions

View File

@ -150,7 +150,7 @@ func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash
return api.sds.WriteStateDiffFor(blockHash, params) return api.sds.WriteStateDiffFor(blockHash, params)
} }
// WatchAddress adds the given address to a list of watched addresses to which the direct statediff process is restricted // WatchAddress changes the list of watched addresses to which the direct indexing is restricted according to given operation
func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, addresses []common.Address) error { func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, addresses []common.Address) error {
return api.sds.WatchAddress(operation, addresses) return api.sds.WatchAddress(operation, addresses)
} }

View File

@ -93,28 +93,31 @@ func loadWatchedAddresses(db *postgres.DB) error {
watchedAddresses = append(watchedAddresses, common.HexToAddress(addressHex)) watchedAddresses = append(watchedAddresses, common.HexToAddress(addressHex))
} }
writeLoopParams.mu.Lock()
writeLoopParams.WatchedAddresses = watchedAddresses writeLoopParams.WatchedAddresses = watchedAddresses
writeLoopParams.mu.Unlock()
return nil return nil
} }
func removeWatchedAddresses(watchedAddresses []common.Address, addressesToRemove []common.Address) []common.Address { // removeAddresses is used to remove given addresses from a list of addresses
addresses := make([]common.Address, len(addressesToRemove)) func removeAddresses(addresses []common.Address, addressesToRemove []common.Address) []common.Address {
copy(addresses, watchedAddresses) addressesCopy := make([]common.Address, len(addresses))
copy(addressesCopy, addresses)
for _, address := range addressesToRemove { for _, address := range addressesToRemove {
if idx := containsAddress(addresses, address); idx != -1 { if idx := containsAddress(addressesCopy, address); idx != -1 {
addresses = append(addresses[:idx], addresses[idx+1:]...) addressesCopy = append(addressesCopy[:idx], addressesCopy[idx+1:]...)
} }
} }
return addresses return addressesCopy
} }
// containsAddress is used to check if an address is present in the provided list of watched addresses // containsAddress is used to check if an address is present in the provided list of addresses
// return the index if found else -1 // return the index if found else -1
func containsAddress(watchedAddresses []common.Address, address common.Address) int { func containsAddress(addresses []common.Address, address common.Address) int {
for idx, addr := range watchedAddresses { for idx, addr := range addresses {
if addr == address { if addr == address {
return idx return idx
} }

View File

@ -553,6 +553,7 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sd
return nil return nil
} }
// InsertWatchedAddresses inserts the given addresses in the database
func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error { func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error {
tx, err := sdi.dbWriter.db.Begin() tx, err := sdi.dbWriter.db.Begin()
if err != nil { if err != nil {
@ -560,7 +561,8 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []common.Address,
} }
for _, address := range addresses { for _, address := range addresses {
_, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, added_at) VALUES ($1, $2)`, address, currentBlock) _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, added_at)VALUES ($1, $2) ON CONFLICT (address) DO NOTHING`,
address.Hex(), currentBlock.Uint64())
if err != nil { if err != nil {
return fmt.Errorf("error inserting watched_addresses entry: %v", err) return fmt.Errorf("error inserting watched_addresses entry: %v", err)
} }
@ -574,6 +576,7 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []common.Address,
return nil return nil
} }
// RemoveWatchedAddresses removes the given addresses from the database
func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) error { func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) error {
tx, err := sdi.dbWriter.db.Begin() tx, err := sdi.dbWriter.db.Begin()
if err != nil { if err != nil {
@ -581,7 +584,7 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address)
} }
for _, address := range addresses { for _, address := range addresses {
_, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, address) _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, address.Hex())
if err != nil { if err != nil {
return fmt.Errorf("error removing watched_addresses entry: %v", err) return fmt.Errorf("error removing watched_addresses entry: %v", err)
} }
@ -595,21 +598,12 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address)
return nil return nil
} }
// ClearWatchedAddresses clears all the addresses from the database
func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { func (sdi *StateDiffIndexer) ClearWatchedAddresses() error {
tx, err := sdi.dbWriter.db.Begin() _, err := sdi.dbWriter.db.Exec(`DELETE FROM eth.watched_addresses`)
if err != nil {
return err
}
_, err = tx.Exec(`DELETE FROM eth.watched_addresses`)
if err != nil { if err != nil {
return fmt.Errorf("error clearing watched_addresses table: %v", err) return fmt.Errorf("error clearing watched_addresses table: %v", err)
} }
err = tx.Commit()
if err != nil {
return err
}
return nil return nil
} }

View File

@ -55,13 +55,15 @@ const (
deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html
) )
var writeLoopParams = Params{ var writeLoopParams = ParamsWithMutex{
IntermediateStateNodes: true, Params: Params{
IntermediateStorageNodes: true, IntermediateStateNodes: true,
IncludeBlock: true, IntermediateStorageNodes: true,
IncludeReceipts: true, IncludeBlock: true,
IncludeTD: true, IncludeReceipts: true,
IncludeCode: true, IncludeTD: true,
IncludeCode: true,
},
} }
var statediffMetrics = RegisterStatediffMetrics(metrics.DefaultRegistry) var statediffMetrics = RegisterStatediffMetrics(metrics.DefaultRegistry)
@ -74,6 +76,7 @@ type blockChain interface {
GetTd(hash common.Hash, number uint64) *big.Int GetTd(hash common.Hash, number uint64) *big.Int
UnlockTrie(root common.Hash) UnlockTrie(root common.Hash)
StateCache() state.Database StateCache() state.Database
CurrentBlock() *types.Block
} }
// IService is the state-diffing service interface // IService is the state-diffing service interface
@ -102,7 +105,7 @@ type IService interface {
WriteStateDiffFor(blockHash common.Hash, params Params) error WriteStateDiffFor(blockHash common.Hash, params Params) error
// Event loop for progressively processing and writing diffs directly to DB // Event loop for progressively processing and writing diffs directly to DB
WriteLoop(chainEventCh chan core.ChainEvent) WriteLoop(chainEventCh chan core.ChainEvent)
// Method to add an address to be watched to write loop params // Method to change the addresses being watched in write loop params
WatchAddress(operation OperationType, addresses []common.Address) error WatchAddress(operation OperationType, addresses []common.Address) error
// Method to get currently watched addresses from write loop params // Method to get currently watched addresses from write loop params
GetWathchedAddresses() []common.Address GetWathchedAddresses() []common.Address
@ -165,6 +168,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params
blockChain := ethServ.BlockChain() blockChain := ethServ.BlockChain()
var indexer ind.Indexer var indexer ind.Indexer
var db *postgres.DB var db *postgres.DB
var err error
quitCh := make(chan bool) quitCh := make(chan bool)
if params.DBParams != nil { if params.DBParams != nil {
info := nodeinfo.Info{ info := nodeinfo.Info{
@ -176,7 +180,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params
} }
// TODO: pass max idle, open, lifetime? // TODO: pass max idle, open, lifetime?
db, err := postgres.NewDB(params.DBParams.ConnectionURL, postgres.ConnectionConfig{}, info) db, err = postgres.NewDB(params.DBParams.ConnectionURL, postgres.ConnectionConfig{}, info)
if err != nil { if err != nil {
return err return err
} }
@ -207,7 +211,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params
stack.RegisterLifecycle(sds) stack.RegisterLifecycle(sds)
stack.RegisterAPIs(sds.APIs()) stack.RegisterAPIs(sds.APIs())
err := loadWatchedAddresses(db) err = loadWatchedAddresses(db)
if err != nil { if err != nil {
return err return err
} }
@ -290,7 +294,9 @@ func (sds *Service) WriteLoop(chainEventCh chan core.ChainEvent) {
func (sds *Service) writeGenesisStateDiff(currBlock *types.Block, workerId uint) { func (sds *Service) writeGenesisStateDiff(currBlock *types.Block, workerId uint) {
// For genesis block we need to return the entire state trie hence we diff it with an empty trie. // For genesis block we need to return the entire state trie hence we diff it with an empty trie.
log.Info("Writing state diff", "block height", genesisBlockNumber, "worker", workerId) log.Info("Writing state diff", "block height", genesisBlockNumber, "worker", workerId)
err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams) writeLoopParams.mu.RLock()
err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams.Params)
writeLoopParams.mu.RUnlock()
if err != nil { if err != nil {
log.Error("statediff.Service.WriteLoop: processing error", "block height", log.Error("statediff.Service.WriteLoop: processing error", "block height",
genesisBlockNumber, "error", err.Error(), "worker", workerId) genesisBlockNumber, "error", err.Error(), "worker", workerId)
@ -319,7 +325,9 @@ func (sds *Service) writeLoopWorker(params workerParams) {
} }
log.Info("Writing state diff", "block height", currentBlock.Number().Uint64(), "worker", params.id) log.Info("Writing state diff", "block height", currentBlock.Number().Uint64(), "worker", params.id)
err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams) writeLoopParams.mu.RLock()
err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams.Params)
writeLoopParams.mu.RUnlock()
if err != nil { if err != nil {
log.Error("statediff.Service.WriteLoop: processing error", "block height", currentBlock.Number().Uint64(), "error", err.Error(), "worker", params.id) log.Error("statediff.Service.WriteLoop: processing error", "block height", currentBlock.Number().Uint64(), "error", err.Error(), "worker", params.id)
continue continue
@ -555,7 +563,7 @@ func (sds *Service) Start() error {
go sds.Loop(chainEventCh) go sds.Loop(chainEventCh)
if sds.enableWriteLoop { if sds.enableWriteLoop {
log.Info("Starting statediff DB write loop", "params", writeLoopParams) log.Info("Starting statediff DB write loop", "params", writeLoopParams.Params)
chainEventCh := make(chan core.ChainEvent, chainEventChanSize) chainEventCh := make(chan core.ChainEvent, chainEventChanSize)
go sds.WriteLoop(chainEventCh) go sds.WriteLoop(chainEventCh)
} }
@ -728,19 +736,32 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo
// Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses // Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses
func (sds *Service) WatchAddress(operation OperationType, addresses []common.Address) error { func (sds *Service) WatchAddress(operation OperationType, addresses []common.Address) error {
// check operation // lock writeLoopParams for a write
writeLoopParams.mu.Lock()
defer writeLoopParams.mu.Unlock()
// get the current block number
currentBlock := sds.BlockChain.CurrentBlock()
currentBlockNumber := currentBlock.Number()
switch operation { switch operation {
case Add: case Add:
addressesToRemove := []common.Address{}
for _, address := range addresses { for _, address := range addresses {
// Check if address is already being watched // Check if address is already being watched
// Throw a warning and continue if found
if containsAddress(writeLoopParams.WatchedAddresses, address) != -1 { if containsAddress(writeLoopParams.WatchedAddresses, address) != -1 {
return fmt.Errorf("Address %s already watched", address) // log.Warn(fmt.Sprint("Address ", address.Hex(), " already being watched"))
log.Warn("Address already being watched", "address", address.Hex())
addressesToRemove = append(addressesToRemove, address)
continue
} }
} }
// TODO: Make sure WriteLoop doesn't call statediffing before the params are updated for the current block // remove already watched addresses
// TODO: Get the current block addresses = removeAddresses(addresses, addressesToRemove)
err := sds.indexer.InsertWatchedAddresses(addresses, common.Big1)
err := sds.indexer.InsertWatchedAddresses(addresses, currentBlockNumber)
if err != nil { if err != nil {
return err return err
} }
@ -752,14 +773,14 @@ func (sds *Service) WatchAddress(operation OperationType, addresses []common.Add
return err return err
} }
removeWatchedAddresses(writeLoopParams.WatchedAddresses, addresses) writeLoopParams.WatchedAddresses = removeAddresses(writeLoopParams.WatchedAddresses, addresses)
case Set: case Set:
err := sds.indexer.ClearWatchedAddresses() err := sds.indexer.ClearWatchedAddresses()
if err != nil { if err != nil {
return err return err
} }
err = sds.indexer.InsertWatchedAddresses(addresses, common.Big1) err = sds.indexer.InsertWatchedAddresses(addresses, currentBlockNumber)
if err != nil { if err != nil {
return err return err
} }

View File

@ -128,6 +128,11 @@ func (bc *BlockChain) GetTd(hash common.Hash, blockNum uint64) *big.Int {
return nil return nil
} }
// CurrentBlock mock method
func (bc *BlockChain) CurrentBlock() *types.Block {
return nil
}
func (bc *BlockChain) SetTd(hash common.Hash, blockNum uint64, td *big.Int) { func (bc *BlockChain) SetTd(hash common.Hash, blockNum uint64, td *big.Int) {
if bc.TDByHash == nil { if bc.TDByHash == nil {
bc.TDByHash = make(map[common.Hash]*big.Int) bc.TDByHash = make(map[common.Hash]*big.Int)

View File

@ -22,6 +22,7 @@ package statediff
import ( import (
"encoding/json" "encoding/json"
"math/big" "math/big"
"sync"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
ctypes "github.com/ethereum/go-ethereum/core/types" ctypes "github.com/ethereum/go-ethereum/core/types"
@ -53,6 +54,12 @@ type Params struct {
WatchedStorageSlots []common.Hash WatchedStorageSlots []common.Hash
} }
// ParamsWithMutex allows to lock the parameters while they are being updated | read from
type ParamsWithMutex struct {
Params
mu sync.RWMutex
}
// Args bundles the arguments for the state diff builder // Args bundles the arguments for the state diff builder
type Args struct { type Args struct {
OldStateRoot, NewStateRoot, BlockHash common.Hash OldStateRoot, NewStateRoot, BlockHash common.Hash