Add support for changing watched storage slots

This commit is contained in:
Prathamesh Musale 2022-01-19 15:59:35 +05:30
parent 355ad83b88
commit ebd43fb857
7 changed files with 194 additions and 53 deletions

View File

@ -150,7 +150,7 @@ func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash
return api.sds.WriteStateDiffFor(blockHash, params) return api.sds.WriteStateDiffFor(blockHash, params)
} }
// WatchAddress changes the list of watched addresses to which the direct indexing is restricted according to given operation // WatchAddress changes the list of watched addresses | storage slots to which the direct indexing is restricted according to given operation
func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, args []WatchAddressArg) error { func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, args []WatchAddressArg) error {
return api.sds.WatchAddress(operation, args) return api.sds.WatchAddress(operation, args)
} }

View File

@ -76,23 +76,36 @@ func findIntersection(a, b []string) []string {
} }
} }
// loadWatchedAddresses is used to load watched addresses to the in-memory write loop params from the db // loadWatched is used to load watched addresses and storage slots to the in-memory write loop params from the db
func loadWatchedAddresses(db *postgres.DB) error { func loadWatched(db *postgres.DB) error {
var watchedAddressStrings []string type Watched struct {
pgStr := "SELECT address FROM eth.watched_addresses" Address string `db:"address"`
err := db.Select(&watchedAddressStrings, pgStr) Kind int `db:"kind"`
}
var watched []Watched
pgStr := "SELECT address, kind FROM eth.watched_addresses"
err := db.Select(&watched, pgStr)
if err != nil { if err != nil {
return fmt.Errorf("error loading watched addresses: %v", err) return fmt.Errorf("error loading watched addresses: %v", err)
} }
var watchedAddresses []common.Address var watchedAddresses []common.Address
for _, watchedAddressString := range watchedAddressStrings { var watchedStorageSlots []common.Hash
watchedAddresses = append(watchedAddresses, common.HexToAddress(watchedAddressString)) for _, entry := range watched {
switch entry.Kind {
case types.WatchedAddress.Int():
watchedAddresses = append(watchedAddresses, common.HexToAddress(entry.Address))
case types.WatchedStorageSlot.Int():
watchedStorageSlots = append(watchedStorageSlots, common.HexToHash(entry.Address))
default:
return fmt.Errorf("Unexpected kind %d", entry.Kind)
}
} }
writeLoopParams.Lock() writeLoopParams.Lock()
defer writeLoopParams.Unlock() defer writeLoopParams.Unlock()
writeLoopParams.WatchedAddresses = watchedAddresses writeLoopParams.WatchedAddresses = watchedAddresses
writeLoopParams.WatchedStorageSlots = watchedStorageSlots
return nil return nil
} }
@ -110,6 +123,19 @@ func removeAddresses(addresses []common.Address, addressesToRemove []common.Addr
return filteredAddresses return filteredAddresses
} }
// removeAddresses is used to remove given storage slots from a list of storage slots
func removeStorageSlots(storageSlots []common.Hash, storageSlotsToRemove []common.Hash) []common.Hash {
filteredStorageSlots := []common.Hash{}
for _, address := range storageSlots {
if idx := containsStorageSlot(storageSlotsToRemove, address); idx == -1 {
filteredStorageSlots = append(filteredStorageSlots, address)
}
}
return filteredStorageSlots
}
// containsAddress is used to check if an address is present in the provided list of addresses // containsAddress is used to check if an address is present in the provided list of addresses
// return the index if found else -1 // return the index if found else -1
func containsAddress(addresses []common.Address, address common.Address) int { func containsAddress(addresses []common.Address, address common.Address) int {
@ -121,27 +147,65 @@ func containsAddress(addresses []common.Address, address common.Address) int {
return -1 return -1
} }
// getArgAddresses is used to get the list of addresses from a list of WatchAddressArgs // containsAddress is used to check if a storage slot is present in the provided list of storage slots
// return the index if found else -1
func containsStorageSlot(storageSlots []common.Hash, storageSlot common.Hash) int {
for idx, slot := range storageSlots {
if slot == storageSlot {
return idx
}
}
return -1
}
// getAddresses is used to get the list of addresses from a list of WatchAddressArgs
func getAddresses(args []types.WatchAddressArg) []common.Address { func getAddresses(args []types.WatchAddressArg) []common.Address {
addresses := make([]common.Address, len(args)) addresses := make([]common.Address, len(args))
for idx, arg := range args { for idx, arg := range args {
addresses[idx] = arg.Address addresses[idx] = common.HexToAddress(arg.Address)
} }
return addresses return addresses
} }
// filterArgs filters out the args having an address from a given list of addresses // getStorageSlots is used to get the list of storage slots from a list of WatchAddressArgs
func filterArgs(args []types.WatchAddressArg, addressesToRemove []common.Address) ([]types.WatchAddressArg, []common.Address) { func getStorageSlots(args []types.WatchAddressArg) []common.Hash {
storageSlots := make([]common.Hash, len(args))
for idx, arg := range args {
storageSlots[idx] = common.HexToHash(arg.Address)
}
return storageSlots
}
// filterAddressArgs filters out the args having an address from a given list of addresses
func filterAddressArgs(args []types.WatchAddressArg, addressesToRemove []common.Address) ([]types.WatchAddressArg, []common.Address) {
filteredArgs := []types.WatchAddressArg{} filteredArgs := []types.WatchAddressArg{}
filteredAddresses := []common.Address{} filteredAddresses := []common.Address{}
for _, arg := range args { for _, arg := range args {
if idx := containsAddress(addressesToRemove, arg.Address); idx == -1 { address := common.HexToAddress(arg.Address)
if idx := containsAddress(addressesToRemove, address); idx == -1 {
filteredArgs = append(filteredArgs, arg) filteredArgs = append(filteredArgs, arg)
filteredAddresses = append(filteredAddresses, arg.Address) filteredAddresses = append(filteredAddresses, address)
} }
} }
return filteredArgs, filteredAddresses return filteredArgs, filteredAddresses
} }
// filterStorageSlotArgs filters out the args having a storage slot from a given list of storage slots
func filterStorageSlotArgs(args []types.WatchAddressArg, storageSlotsToRemove []common.Hash) ([]types.WatchAddressArg, []common.Hash) {
filteredArgs := []types.WatchAddressArg{}
filteredStorageSlots := []common.Hash{}
for _, arg := range args {
storageSlot := common.HexToHash(arg.Address)
if idx := containsStorageSlot(storageSlotsToRemove, storageSlot); idx == -1 {
filteredArgs = append(filteredArgs, arg)
filteredStorageSlots = append(filteredStorageSlots, storageSlot)
}
}
return filteredArgs, filteredStorageSlots
}

View File

@ -61,10 +61,10 @@ type Indexer interface {
ReportDBMetrics(delay time.Duration, quit <-chan bool) ReportDBMetrics(delay time.Duration, quit <-chan bool)
// Methods used by WatchAddress API/functionality. // Methods used by WatchAddress API/functionality.
InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error InsertWatched(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int, kind sdtypes.WatchedAddressType) error
RemoveWatchedAddresses(addresses []common.Address) error RemoveWatched(addresses []sdtypes.WatchAddressArg, kind sdtypes.WatchedAddressType) error
SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error SetWatched(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error
ClearWatchedAddresses() error ClearWatched(kind sdtypes.WatchedAddressType) error
} }
// StateDiffIndexer satisfies the Indexer interface for ethereum statediff objects // StateDiffIndexer satisfies the Indexer interface for ethereum statediff objects
@ -556,8 +556,8 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sd
return nil return nil
} }
// InsertWatchedAddresses inserts the given addresses in the database // InsertWatchedAddresses inserts the given addresses | storage slots in the database
func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { func (sdi *StateDiffIndexer) InsertWatched(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error {
tx, err := sdi.dbWriter.db.Begin() tx, err := sdi.dbWriter.db.Begin()
if err != nil { if err != nil {
return err return err
@ -565,8 +565,8 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressA
defer tx.Rollback() defer tx.Rollback()
for _, arg := range args { for _, arg := range args {
_, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, kind, created_at, watched_at) VALUES ($1, $2, $3, $4) ON CONFLICT (address) DO NOTHING`,
arg.Address.Hex(), arg.CreatedAt, currentBlockNumber.Uint64()) arg.Address, kind.Int(), arg.CreatedAt, currentBlockNumber.Uint64())
if err != nil { if err != nil {
return fmt.Errorf("error inserting watched_addresses entry: %v", err) return fmt.Errorf("error inserting watched_addresses entry: %v", err)
} }
@ -580,16 +580,16 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressA
return nil return nil
} }
// RemoveWatchedAddresses removes the given addresses from the database // RemoveWatchedAddresses removes the given addresses | storage slots from the database
func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) error { func (sdi *StateDiffIndexer) RemoveWatched(args []sdtypes.WatchAddressArg, kind sdtypes.WatchedAddressType) error {
tx, err := sdi.dbWriter.db.Begin() tx, err := sdi.dbWriter.db.Begin()
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback() defer tx.Rollback()
for _, address := range addresses { for _, arg := range args {
_, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, address.Hex()) _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1 AND kind = $2`, arg.Address, kind.Int())
if err != nil { if err != nil {
return fmt.Errorf("error removing watched_addresses entry: %v", err) return fmt.Errorf("error removing watched_addresses entry: %v", err)
} }
@ -603,21 +603,22 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address)
return nil return nil
} }
func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { // SetWatched clears and inserts the given addresses | storage slots in the database
func (sdi *StateDiffIndexer) SetWatched(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error {
tx, err := sdi.dbWriter.db.Begin() tx, err := sdi.dbWriter.db.Begin()
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback() defer tx.Rollback()
_, err = tx.Exec(`DELETE FROM eth.watched_addresses`) _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE kind = $1`, kind.Int())
if err != nil { if err != nil {
return fmt.Errorf("error setting watched_addresses table: %v", err) return fmt.Errorf("error setting watched_addresses table: %v", err)
} }
for _, arg := range args { for _, arg := range args {
_, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, kind, created_at, watched_at) VALUES ($1, $2, $3, $4) ON CONFLICT (address) DO NOTHING`,
arg.Address.Hex(), arg.CreatedAt, currentBlockNumber.Uint64()) arg.Address, kind.Int(), arg.CreatedAt, currentBlockNumber.Uint64())
if err != nil { if err != nil {
return fmt.Errorf("error setting watched_addresses table: %v", err) return fmt.Errorf("error setting watched_addresses table: %v", err)
} }
@ -631,9 +632,9 @@ func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg,
return nil return nil
} }
// ClearWatchedAddresses clears all the addresses from the database // ClearWatchedAddresses clears all the addresses | storage slots from the database
func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { func (sdi *StateDiffIndexer) ClearWatched(kind sdtypes.WatchedAddressType) error {
_, err := sdi.dbWriter.db.Exec(`DELETE FROM eth.watched_addresses`) _, err := sdi.dbWriter.db.Exec(`DELETE FROM eth.watched_addresses WHERE kind = $1`, kind.Int())
if err != nil { if err != nil {
return fmt.Errorf("error clearing watched_addresses table: %v", err) return fmt.Errorf("error clearing watched_addresses table: %v", err)
} }

View File

@ -211,7 +211,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params
stack.RegisterLifecycle(sds) stack.RegisterLifecycle(sds)
stack.RegisterAPIs(sds.APIs()) stack.RegisterAPIs(sds.APIs())
err = loadWatchedAddresses(db) err = loadWatched(db)
if err != nil { if err != nil {
return err return err
} }
@ -734,7 +734,9 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo
return err return err
} }
// Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses // Performs one of foll. operations on the watched addresses | storage slots in writeLoopParams and the db:
// AddAddresses | RemoveAddresses | SetAddresses | ClearAddresses
// AddStorageSlots | RemoveStorageSlots | SetStorageSlots | ClearStorageSlots
func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error { func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error {
// lock writeLoopParams for a write // lock writeLoopParams for a write
writeLoopParams.Lock() writeLoopParams.Lock()
@ -744,51 +746,100 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg
currentBlockNumber := sds.BlockChain.CurrentBlock().Number() currentBlockNumber := sds.BlockChain.CurrentBlock().Number()
switch operation { switch operation {
case Add: case AddAddresses:
addressesToRemove := []common.Address{} addressesToRemove := []common.Address{}
for _, arg := range args { for _, arg := range args {
// Check if address is already being watched // Check if address is already being watched
// Throw a warning and continue if found // Throw a warning and continue if found
if containsAddress(writeLoopParams.WatchedAddresses, arg.Address) != -1 { address := common.HexToAddress(arg.Address)
log.Warn("Address already being watched", "address", arg.Address.Hex()) if containsAddress(writeLoopParams.WatchedAddresses, address) != -1 {
addressesToRemove = append(addressesToRemove, arg.Address) log.Warn("Address already being watched", "address", arg.Address)
addressesToRemove = append(addressesToRemove, address)
continue continue
} }
} }
// remove already watched addresses // remove already watched addresses
filteredArgs, filteredAddresses := filterArgs(args, addressesToRemove) filteredArgs, filteredAddresses := filterAddressArgs(args, addressesToRemove)
err := sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) err := sds.indexer.InsertWatched(filteredArgs, currentBlockNumber, WatchedAddress)
if err != nil { if err != nil {
return err return err
} }
writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...) writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...)
case Remove: case RemoveAddresses:
addresses := getAddresses(args) addresses := getAddresses(args)
err := sds.indexer.RemoveWatchedAddresses(addresses) err := sds.indexer.RemoveWatched(args, WatchedAddress)
if err != nil { if err != nil {
return err return err
} }
writeLoopParams.WatchedAddresses = removeAddresses(writeLoopParams.WatchedAddresses, addresses) writeLoopParams.WatchedAddresses = removeAddresses(writeLoopParams.WatchedAddresses, addresses)
case Set: case SetAddresses:
err := sds.indexer.SetWatchedAddresses(args, currentBlockNumber) err := sds.indexer.SetWatched(args, currentBlockNumber, WatchedAddress)
if err != nil { if err != nil {
return err return err
} }
addresses := getAddresses(args) addresses := getAddresses(args)
writeLoopParams.WatchedAddresses = addresses writeLoopParams.WatchedAddresses = addresses
case Clear: case ClearAddresses:
err := sds.indexer.ClearWatchedAddresses() err := sds.indexer.ClearWatched(WatchedAddress)
if err != nil { if err != nil {
return err return err
} }
writeLoopParams.WatchedAddresses = nil writeLoopParams.WatchedAddresses = nil
case AddStorageSlots:
storageSlotsToRemove := []common.Hash{}
for _, arg := range args {
// Check if address is already being watched
// Throw a warning and continue if found
storageSlot := common.HexToHash(arg.Address)
if containsStorageSlot(writeLoopParams.WatchedStorageSlots, storageSlot) != -1 {
log.Warn("StorageSlot already being watched", "storage slot", arg.Address)
storageSlotsToRemove = append(storageSlotsToRemove, storageSlot)
continue
}
}
// remove already watched addresses
filteredArgs, filteredStorageSlots := filterStorageSlotArgs(args, storageSlotsToRemove)
err := sds.indexer.InsertWatched(filteredArgs, currentBlockNumber, WatchedStorageSlot)
if err != nil {
return err
}
writeLoopParams.WatchedStorageSlots = append(writeLoopParams.WatchedStorageSlots, filteredStorageSlots...)
case RemoveStorageSlots:
storageSlots := getStorageSlots(args)
err := sds.indexer.RemoveWatched(args, WatchedStorageSlot)
if err != nil {
return err
}
writeLoopParams.WatchedStorageSlots = removeStorageSlots(writeLoopParams.WatchedStorageSlots, storageSlots)
case SetStorageSlots:
err := sds.indexer.SetWatched(args, currentBlockNumber, WatchedStorageSlot)
if err != nil {
return err
}
storageSlots := getStorageSlots(args)
writeLoopParams.WatchedStorageSlots = storageSlots
case ClearStorageSlots:
err := sds.indexer.ClearWatched(WatchedStorageSlot)
if err != nil {
return err
}
writeLoopParams.WatchedStorageSlots = nil
default: default:
return fmt.Errorf("Unexpected operation %s", operation) return fmt.Errorf("Unexpected operation %s", operation)
} }

View File

@ -333,7 +333,7 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) {
} }
} }
func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, addresses []sdtypes.WatchAddressArg) error { func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, args []sdtypes.WatchAddressArg) error {
return nil return nil
} }

View File

@ -123,8 +123,13 @@ type accountWrapper struct {
type OperationType string type OperationType string
const ( const (
Add OperationType = "Add" AddAddresses OperationType = "AddAddresses"
Remove OperationType = "Remove" RemoveAddresses OperationType = "RemoveAddresses"
Set OperationType = "Set" SetAddresses OperationType = "SetAddresses"
Clear OperationType = "Clear" ClearAddresses OperationType = "ClearAddresses"
AddStorageSlots OperationType = "AddStorageSlots"
RemoveStorageSlots OperationType = "RemoveStorageSlots"
SetStorageSlots OperationType = "SetStorageSlots"
ClearStorageSlots OperationType = "ClearStorageSlots"
) )

View File

@ -77,6 +77,26 @@ type CodeSink func(CodeAndCodeHash) error
// WatchAddressArg is a arg type for WatchAddress API // WatchAddressArg is a arg type for WatchAddress API
type WatchAddressArg struct { type WatchAddressArg struct {
Address common.Address // Address represents common.Address | common.Hash
Address string
CreatedAt uint64 CreatedAt uint64
} }
// WatchedAddressType for denoting watched: address | storage slot
type WatchedAddressType string
const (
WatchedAddress WatchedAddressType = "WatchedAddress"
WatchedStorageSlot WatchedAddressType = "WatchedStorageSlot"
)
func (n WatchedAddressType) Int() int {
switch n {
case WatchedAddress:
return 0
case WatchedStorageSlot:
return 1
default:
return -1
}
}