Use current block and mutex while changing watched addresses

This commit is contained in:
Prathamesh Musale 2022-01-17 13:50:52 +05:30
parent 5fa002c0d0
commit bbb1759886
6 changed files with 73 additions and 43 deletions

View File

@ -150,7 +150,7 @@ func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash
return api.sds.WriteStateDiffFor(blockHash, params)
}
// WatchAddress adds the given address to a list of watched addresses to which the direct statediff process is restricted
// WatchAddress changes the list of watched addresses to which the direct indexing is restricted according to given operation
func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, addresses []common.Address) error {
return api.sds.WatchAddress(operation, addresses)
}

View File

@ -93,28 +93,31 @@ func loadWatchedAddresses(db *postgres.DB) error {
watchedAddresses = append(watchedAddresses, common.HexToAddress(addressHex))
}
writeLoopParams.mu.Lock()
writeLoopParams.WatchedAddresses = watchedAddresses
writeLoopParams.mu.Unlock()
return nil
}
func removeWatchedAddresses(watchedAddresses []common.Address, addressesToRemove []common.Address) []common.Address {
addresses := make([]common.Address, len(addressesToRemove))
copy(addresses, watchedAddresses)
// removeAddresses is used to remove given addresses from a list of addresses
func removeAddresses(addresses []common.Address, addressesToRemove []common.Address) []common.Address {
addressesCopy := make([]common.Address, len(addresses))
copy(addressesCopy, addresses)
for _, address := range addressesToRemove {
if idx := containsAddress(addresses, address); idx != -1 {
addresses = append(addresses[:idx], addresses[idx+1:]...)
if idx := containsAddress(addressesCopy, address); idx != -1 {
addressesCopy = append(addressesCopy[:idx], addressesCopy[idx+1:]...)
}
}
return addresses
return addressesCopy
}
// containsAddress is used to check if an address is present in the provided list of watched addresses
// containsAddress is used to check if an address is present in the provided list of addresses
// return the index if found else -1
func containsAddress(watchedAddresses []common.Address, address common.Address) int {
for idx, addr := range watchedAddresses {
func containsAddress(addresses []common.Address, address common.Address) int {
for idx, addr := range addresses {
if addr == address {
return idx
}

View File

@ -553,6 +553,7 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sd
return nil
}
// InsertWatchedAddresses inserts the given addresses in the database
func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error {
tx, err := sdi.dbWriter.db.Begin()
if err != nil {
@ -560,7 +561,8 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []common.Address,
}
for _, address := range addresses {
_, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, added_at) VALUES ($1, $2)`, address, currentBlock)
_, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, added_at)VALUES ($1, $2) ON CONFLICT (address) DO NOTHING`,
address.Hex(), currentBlock.Uint64())
if err != nil {
return fmt.Errorf("error inserting watched_addresses entry: %v", err)
}
@ -574,6 +576,7 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []common.Address,
return nil
}
// RemoveWatchedAddresses removes the given addresses from the database
func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) error {
tx, err := sdi.dbWriter.db.Begin()
if err != nil {
@ -581,7 +584,7 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address)
}
for _, address := range addresses {
_, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, address)
_, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, address.Hex())
if err != nil {
return fmt.Errorf("error removing watched_addresses entry: %v", err)
}
@ -595,21 +598,12 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address)
return nil
}
// ClearWatchedAddresses clears all the addresses from the database
func (sdi *StateDiffIndexer) ClearWatchedAddresses() error {
tx, err := sdi.dbWriter.db.Begin()
if err != nil {
return err
}
_, err = tx.Exec(`DELETE FROM eth.watched_addresses`)
_, err := sdi.dbWriter.db.Exec(`DELETE FROM eth.watched_addresses`)
if err != nil {
return fmt.Errorf("error clearing watched_addresses table: %v", err)
}
err = tx.Commit()
if err != nil {
return err
}
return nil
}

View File

@ -55,13 +55,15 @@ const (
deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html
)
var writeLoopParams = Params{
var writeLoopParams = ParamsWithMutex{
Params: Params{
IntermediateStateNodes: true,
IntermediateStorageNodes: true,
IncludeBlock: true,
IncludeReceipts: true,
IncludeTD: true,
IncludeCode: true,
},
}
var statediffMetrics = RegisterStatediffMetrics(metrics.DefaultRegistry)
@ -74,6 +76,7 @@ type blockChain interface {
GetTd(hash common.Hash, number uint64) *big.Int
UnlockTrie(root common.Hash)
StateCache() state.Database
CurrentBlock() *types.Block
}
// IService is the state-diffing service interface
@ -102,7 +105,7 @@ type IService interface {
WriteStateDiffFor(blockHash common.Hash, params Params) error
// Event loop for progressively processing and writing diffs directly to DB
WriteLoop(chainEventCh chan core.ChainEvent)
// Method to add an address to be watched to write loop params
// Method to change the addresses being watched in write loop params
WatchAddress(operation OperationType, addresses []common.Address) error
// Method to get currently watched addresses from write loop params
GetWathchedAddresses() []common.Address
@ -165,6 +168,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params
blockChain := ethServ.BlockChain()
var indexer ind.Indexer
var db *postgres.DB
var err error
quitCh := make(chan bool)
if params.DBParams != nil {
info := nodeinfo.Info{
@ -176,7 +180,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params
}
// TODO: pass max idle, open, lifetime?
db, err := postgres.NewDB(params.DBParams.ConnectionURL, postgres.ConnectionConfig{}, info)
db, err = postgres.NewDB(params.DBParams.ConnectionURL, postgres.ConnectionConfig{}, info)
if err != nil {
return err
}
@ -207,7 +211,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params
stack.RegisterLifecycle(sds)
stack.RegisterAPIs(sds.APIs())
err := loadWatchedAddresses(db)
err = loadWatchedAddresses(db)
if err != nil {
return err
}
@ -290,7 +294,9 @@ func (sds *Service) WriteLoop(chainEventCh chan core.ChainEvent) {
func (sds *Service) writeGenesisStateDiff(currBlock *types.Block, workerId uint) {
// For genesis block we need to return the entire state trie hence we diff it with an empty trie.
log.Info("Writing state diff", "block height", genesisBlockNumber, "worker", workerId)
err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams)
writeLoopParams.mu.RLock()
err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams.Params)
writeLoopParams.mu.RUnlock()
if err != nil {
log.Error("statediff.Service.WriteLoop: processing error", "block height",
genesisBlockNumber, "error", err.Error(), "worker", workerId)
@ -319,7 +325,9 @@ func (sds *Service) writeLoopWorker(params workerParams) {
}
log.Info("Writing state diff", "block height", currentBlock.Number().Uint64(), "worker", params.id)
err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams)
writeLoopParams.mu.RLock()
err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams.Params)
writeLoopParams.mu.RUnlock()
if err != nil {
log.Error("statediff.Service.WriteLoop: processing error", "block height", currentBlock.Number().Uint64(), "error", err.Error(), "worker", params.id)
continue
@ -555,7 +563,7 @@ func (sds *Service) Start() error {
go sds.Loop(chainEventCh)
if sds.enableWriteLoop {
log.Info("Starting statediff DB write loop", "params", writeLoopParams)
log.Info("Starting statediff DB write loop", "params", writeLoopParams.Params)
chainEventCh := make(chan core.ChainEvent, chainEventChanSize)
go sds.WriteLoop(chainEventCh)
}
@ -728,19 +736,32 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo
// Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses
func (sds *Service) WatchAddress(operation OperationType, addresses []common.Address) error {
// check operation
// lock writeLoopParams for a write
writeLoopParams.mu.Lock()
defer writeLoopParams.mu.Unlock()
// get the current block number
currentBlock := sds.BlockChain.CurrentBlock()
currentBlockNumber := currentBlock.Number()
switch operation {
case Add:
addressesToRemove := []common.Address{}
for _, address := range addresses {
// Check if address is already being watched
// Throw a warning and continue if found
if containsAddress(writeLoopParams.WatchedAddresses, address) != -1 {
return fmt.Errorf("Address %s already watched", address)
// log.Warn(fmt.Sprint("Address ", address.Hex(), " already being watched"))
log.Warn("Address already being watched", "address", address.Hex())
addressesToRemove = append(addressesToRemove, address)
continue
}
}
// TODO: Make sure WriteLoop doesn't call statediffing before the params are updated for the current block
// TODO: Get the current block
err := sds.indexer.InsertWatchedAddresses(addresses, common.Big1)
// remove already watched addresses
addresses = removeAddresses(addresses, addressesToRemove)
err := sds.indexer.InsertWatchedAddresses(addresses, currentBlockNumber)
if err != nil {
return err
}
@ -752,14 +773,14 @@ func (sds *Service) WatchAddress(operation OperationType, addresses []common.Add
return err
}
removeWatchedAddresses(writeLoopParams.WatchedAddresses, addresses)
writeLoopParams.WatchedAddresses = removeAddresses(writeLoopParams.WatchedAddresses, addresses)
case Set:
err := sds.indexer.ClearWatchedAddresses()
if err != nil {
return err
}
err = sds.indexer.InsertWatchedAddresses(addresses, common.Big1)
err = sds.indexer.InsertWatchedAddresses(addresses, currentBlockNumber)
if err != nil {
return err
}

View File

@ -128,6 +128,11 @@ func (bc *BlockChain) GetTd(hash common.Hash, blockNum uint64) *big.Int {
return nil
}
// CurrentBlock mock method
func (bc *BlockChain) CurrentBlock() *types.Block {
return nil
}
func (bc *BlockChain) SetTd(hash common.Hash, blockNum uint64, td *big.Int) {
if bc.TDByHash == nil {
bc.TDByHash = make(map[common.Hash]*big.Int)

View File

@ -22,6 +22,7 @@ package statediff
import (
"encoding/json"
"math/big"
"sync"
"github.com/ethereum/go-ethereum/common"
ctypes "github.com/ethereum/go-ethereum/core/types"
@ -53,6 +54,12 @@ type Params struct {
WatchedStorageSlots []common.Hash
}
// ParamsWithMutex allows to lock the parameters while they are being updated | read from
type ParamsWithMutex struct {
Params
mu sync.RWMutex
}
// Args bundles the arguments for the state diff builder
type Args struct {
OldStateRoot, NewStateRoot, BlockHash common.Hash