From ebd43fb85711834cbc901900d7cf76000fe10458 Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Wed, 19 Jan 2022 15:59:35 +0530 Subject: [PATCH] Add support for changing watched storage slots --- statediff/api.go | 2 +- statediff/helpers.go | 90 ++++++++++++++++++++++---- statediff/indexer/indexer.go | 39 +++++------ statediff/service.go | 79 ++++++++++++++++++---- statediff/testhelpers/mocks/service.go | 2 +- statediff/types.go | 13 ++-- statediff/types/types.go | 22 ++++++- 7 files changed, 194 insertions(+), 53 deletions(-) diff --git a/statediff/api.go b/statediff/api.go index ed9cc3c06..3686728f2 100644 --- a/statediff/api.go +++ b/statediff/api.go @@ -150,7 +150,7 @@ func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash return api.sds.WriteStateDiffFor(blockHash, params) } -// WatchAddress changes the list of watched addresses to which the direct indexing is restricted according to given operation +// WatchAddress changes the list of watched addresses | storage slots to which the direct indexing is restricted according to given operation func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, args []WatchAddressArg) error { return api.sds.WatchAddress(operation, args) } diff --git a/statediff/helpers.go b/statediff/helpers.go index 1f8523fc3..209e1d724 100644 --- a/statediff/helpers.go +++ b/statediff/helpers.go @@ -76,23 +76,36 @@ func findIntersection(a, b []string) []string { } } -// loadWatchedAddresses is used to load watched addresses to the in-memory write loop params from the db -func loadWatchedAddresses(db *postgres.DB) error { - var watchedAddressStrings []string - pgStr := "SELECT address FROM eth.watched_addresses" - err := db.Select(&watchedAddressStrings, pgStr) +// loadWatched is used to load watched addresses and storage slots to the in-memory write loop params from the db +func loadWatched(db *postgres.DB) error { + type Watched struct { + Address string `db:"address"` + Kind int `db:"kind"` + } + var watched []Watched + pgStr := "SELECT address, kind FROM eth.watched_addresses" + err := db.Select(&watched, pgStr) if err != nil { return fmt.Errorf("error loading watched addresses: %v", err) } var watchedAddresses []common.Address - for _, watchedAddressString := range watchedAddressStrings { - watchedAddresses = append(watchedAddresses, common.HexToAddress(watchedAddressString)) + var watchedStorageSlots []common.Hash + for _, entry := range watched { + switch entry.Kind { + case types.WatchedAddress.Int(): + watchedAddresses = append(watchedAddresses, common.HexToAddress(entry.Address)) + case types.WatchedStorageSlot.Int(): + watchedStorageSlots = append(watchedStorageSlots, common.HexToHash(entry.Address)) + default: + return fmt.Errorf("Unexpected kind %d", entry.Kind) + } } writeLoopParams.Lock() defer writeLoopParams.Unlock() writeLoopParams.WatchedAddresses = watchedAddresses + writeLoopParams.WatchedStorageSlots = watchedStorageSlots return nil } @@ -110,6 +123,19 @@ func removeAddresses(addresses []common.Address, addressesToRemove []common.Addr return filteredAddresses } +// removeAddresses is used to remove given storage slots from a list of storage slots +func removeStorageSlots(storageSlots []common.Hash, storageSlotsToRemove []common.Hash) []common.Hash { + filteredStorageSlots := []common.Hash{} + + for _, address := range storageSlots { + if idx := containsStorageSlot(storageSlotsToRemove, address); idx == -1 { + filteredStorageSlots = append(filteredStorageSlots, address) + } + } + + return filteredStorageSlots +} + // containsAddress is used to check if an address is present in the provided list of addresses // return the index if found else -1 func containsAddress(addresses []common.Address, address common.Address) int { @@ -121,27 +147,65 @@ func containsAddress(addresses []common.Address, address common.Address) int { return -1 } -// getArgAddresses is used to get the list of addresses from a list of WatchAddressArgs +// containsAddress is used to check if a storage slot is present in the provided list of storage slots +// return the index if found else -1 +func containsStorageSlot(storageSlots []common.Hash, storageSlot common.Hash) int { + for idx, slot := range storageSlots { + if slot == storageSlot { + return idx + } + } + return -1 +} + +// getAddresses is used to get the list of addresses from a list of WatchAddressArgs func getAddresses(args []types.WatchAddressArg) []common.Address { addresses := make([]common.Address, len(args)) for idx, arg := range args { - addresses[idx] = arg.Address + addresses[idx] = common.HexToAddress(arg.Address) } return addresses } -// filterArgs filters out the args having an address from a given list of addresses -func filterArgs(args []types.WatchAddressArg, addressesToRemove []common.Address) ([]types.WatchAddressArg, []common.Address) { +// getStorageSlots is used to get the list of storage slots from a list of WatchAddressArgs +func getStorageSlots(args []types.WatchAddressArg) []common.Hash { + storageSlots := make([]common.Hash, len(args)) + for idx, arg := range args { + storageSlots[idx] = common.HexToHash(arg.Address) + } + + return storageSlots +} + +// filterAddressArgs filters out the args having an address from a given list of addresses +func filterAddressArgs(args []types.WatchAddressArg, addressesToRemove []common.Address) ([]types.WatchAddressArg, []common.Address) { filteredArgs := []types.WatchAddressArg{} filteredAddresses := []common.Address{} for _, arg := range args { - if idx := containsAddress(addressesToRemove, arg.Address); idx == -1 { + address := common.HexToAddress(arg.Address) + if idx := containsAddress(addressesToRemove, address); idx == -1 { filteredArgs = append(filteredArgs, arg) - filteredAddresses = append(filteredAddresses, arg.Address) + filteredAddresses = append(filteredAddresses, address) } } return filteredArgs, filteredAddresses } + +// filterStorageSlotArgs filters out the args having a storage slot from a given list of storage slots +func filterStorageSlotArgs(args []types.WatchAddressArg, storageSlotsToRemove []common.Hash) ([]types.WatchAddressArg, []common.Hash) { + filteredArgs := []types.WatchAddressArg{} + filteredStorageSlots := []common.Hash{} + + for _, arg := range args { + storageSlot := common.HexToHash(arg.Address) + if idx := containsStorageSlot(storageSlotsToRemove, storageSlot); idx == -1 { + filteredArgs = append(filteredArgs, arg) + filteredStorageSlots = append(filteredStorageSlots, storageSlot) + } + } + + return filteredArgs, filteredStorageSlots +} diff --git a/statediff/indexer/indexer.go b/statediff/indexer/indexer.go index c85d00b2e..8ec34549b 100644 --- a/statediff/indexer/indexer.go +++ b/statediff/indexer/indexer.go @@ -61,10 +61,10 @@ type Indexer interface { ReportDBMetrics(delay time.Duration, quit <-chan bool) // Methods used by WatchAddress API/functionality. - InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error - RemoveWatchedAddresses(addresses []common.Address) error - SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error - ClearWatchedAddresses() error + InsertWatched(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int, kind sdtypes.WatchedAddressType) error + RemoveWatched(addresses []sdtypes.WatchAddressArg, kind sdtypes.WatchedAddressType) error + SetWatched(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error + ClearWatched(kind sdtypes.WatchedAddressType) error } // StateDiffIndexer satisfies the Indexer interface for ethereum statediff objects @@ -556,8 +556,8 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sd return nil } -// InsertWatchedAddresses inserts the given addresses in the database -func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { +// InsertWatchedAddresses inserts the given addresses | storage slots in the database +func (sdi *StateDiffIndexer) InsertWatched(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error { tx, err := sdi.dbWriter.db.Begin() if err != nil { return err @@ -565,8 +565,8 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressA defer tx.Rollback() for _, arg := range args { - _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, - arg.Address.Hex(), arg.CreatedAt, currentBlockNumber.Uint64()) + _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, kind, created_at, watched_at) VALUES ($1, $2, $3, $4) ON CONFLICT (address) DO NOTHING`, + arg.Address, kind.Int(), arg.CreatedAt, currentBlockNumber.Uint64()) if err != nil { return fmt.Errorf("error inserting watched_addresses entry: %v", err) } @@ -580,16 +580,16 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressA return nil } -// RemoveWatchedAddresses removes the given addresses from the database -func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) error { +// RemoveWatchedAddresses removes the given addresses | storage slots from the database +func (sdi *StateDiffIndexer) RemoveWatched(args []sdtypes.WatchAddressArg, kind sdtypes.WatchedAddressType) error { tx, err := sdi.dbWriter.db.Begin() if err != nil { return err } defer tx.Rollback() - for _, address := range addresses { - _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, address.Hex()) + for _, arg := range args { + _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1 AND kind = $2`, arg.Address, kind.Int()) if err != nil { return fmt.Errorf("error removing watched_addresses entry: %v", err) } @@ -603,21 +603,22 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) return nil } -func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { +// SetWatched clears and inserts the given addresses | storage slots in the database +func (sdi *StateDiffIndexer) SetWatched(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error { tx, err := sdi.dbWriter.db.Begin() if err != nil { return err } defer tx.Rollback() - _, err = tx.Exec(`DELETE FROM eth.watched_addresses`) + _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE kind = $1`, kind.Int()) if err != nil { return fmt.Errorf("error setting watched_addresses table: %v", err) } for _, arg := range args { - _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, - arg.Address.Hex(), arg.CreatedAt, currentBlockNumber.Uint64()) + _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, kind, created_at, watched_at) VALUES ($1, $2, $3, $4) ON CONFLICT (address) DO NOTHING`, + arg.Address, kind.Int(), arg.CreatedAt, currentBlockNumber.Uint64()) if err != nil { return fmt.Errorf("error setting watched_addresses table: %v", err) } @@ -631,9 +632,9 @@ func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, return nil } -// ClearWatchedAddresses clears all the addresses from the database -func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { - _, err := sdi.dbWriter.db.Exec(`DELETE FROM eth.watched_addresses`) +// ClearWatchedAddresses clears all the addresses | storage slots from the database +func (sdi *StateDiffIndexer) ClearWatched(kind sdtypes.WatchedAddressType) error { + _, err := sdi.dbWriter.db.Exec(`DELETE FROM eth.watched_addresses WHERE kind = $1`, kind.Int()) if err != nil { return fmt.Errorf("error clearing watched_addresses table: %v", err) } diff --git a/statediff/service.go b/statediff/service.go index 17727f14a..75268596c 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -211,7 +211,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params stack.RegisterLifecycle(sds) stack.RegisterAPIs(sds.APIs()) - err = loadWatchedAddresses(db) + err = loadWatched(db) if err != nil { return err } @@ -734,7 +734,9 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo return err } -// Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses +// Performs one of foll. operations on the watched addresses | storage slots in writeLoopParams and the db: +// AddAddresses | RemoveAddresses | SetAddresses | ClearAddresses +// AddStorageSlots | RemoveStorageSlots | SetStorageSlots | ClearStorageSlots func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error { // lock writeLoopParams for a write writeLoopParams.Lock() @@ -744,51 +746,100 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg currentBlockNumber := sds.BlockChain.CurrentBlock().Number() switch operation { - case Add: + case AddAddresses: addressesToRemove := []common.Address{} for _, arg := range args { // Check if address is already being watched // Throw a warning and continue if found - if containsAddress(writeLoopParams.WatchedAddresses, arg.Address) != -1 { - log.Warn("Address already being watched", "address", arg.Address.Hex()) - addressesToRemove = append(addressesToRemove, arg.Address) + address := common.HexToAddress(arg.Address) + if containsAddress(writeLoopParams.WatchedAddresses, address) != -1 { + log.Warn("Address already being watched", "address", arg.Address) + addressesToRemove = append(addressesToRemove, address) continue } } // remove already watched addresses - filteredArgs, filteredAddresses := filterArgs(args, addressesToRemove) + filteredArgs, filteredAddresses := filterAddressArgs(args, addressesToRemove) - err := sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) + err := sds.indexer.InsertWatched(filteredArgs, currentBlockNumber, WatchedAddress) if err != nil { return err } writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...) - case Remove: + case RemoveAddresses: addresses := getAddresses(args) - err := sds.indexer.RemoveWatchedAddresses(addresses) + err := sds.indexer.RemoveWatched(args, WatchedAddress) if err != nil { return err } writeLoopParams.WatchedAddresses = removeAddresses(writeLoopParams.WatchedAddresses, addresses) - case Set: - err := sds.indexer.SetWatchedAddresses(args, currentBlockNumber) + case SetAddresses: + err := sds.indexer.SetWatched(args, currentBlockNumber, WatchedAddress) if err != nil { return err } addresses := getAddresses(args) writeLoopParams.WatchedAddresses = addresses - case Clear: - err := sds.indexer.ClearWatchedAddresses() + case ClearAddresses: + err := sds.indexer.ClearWatched(WatchedAddress) if err != nil { return err } writeLoopParams.WatchedAddresses = nil + + case AddStorageSlots: + storageSlotsToRemove := []common.Hash{} + for _, arg := range args { + // Check if address is already being watched + // Throw a warning and continue if found + storageSlot := common.HexToHash(arg.Address) + if containsStorageSlot(writeLoopParams.WatchedStorageSlots, storageSlot) != -1 { + log.Warn("StorageSlot already being watched", "storage slot", arg.Address) + storageSlotsToRemove = append(storageSlotsToRemove, storageSlot) + continue + } + } + + // remove already watched addresses + filteredArgs, filteredStorageSlots := filterStorageSlotArgs(args, storageSlotsToRemove) + + err := sds.indexer.InsertWatched(filteredArgs, currentBlockNumber, WatchedStorageSlot) + if err != nil { + return err + } + + writeLoopParams.WatchedStorageSlots = append(writeLoopParams.WatchedStorageSlots, filteredStorageSlots...) + case RemoveStorageSlots: + storageSlots := getStorageSlots(args) + + err := sds.indexer.RemoveWatched(args, WatchedStorageSlot) + if err != nil { + return err + } + + writeLoopParams.WatchedStorageSlots = removeStorageSlots(writeLoopParams.WatchedStorageSlots, storageSlots) + case SetStorageSlots: + err := sds.indexer.SetWatched(args, currentBlockNumber, WatchedStorageSlot) + if err != nil { + return err + } + + storageSlots := getStorageSlots(args) + writeLoopParams.WatchedStorageSlots = storageSlots + case ClearStorageSlots: + err := sds.indexer.ClearWatched(WatchedStorageSlot) + if err != nil { + return err + } + + writeLoopParams.WatchedStorageSlots = nil + default: return fmt.Errorf("Unexpected operation %s", operation) } diff --git a/statediff/testhelpers/mocks/service.go b/statediff/testhelpers/mocks/service.go index 686f0c2ba..f8d1b6b20 100644 --- a/statediff/testhelpers/mocks/service.go +++ b/statediff/testhelpers/mocks/service.go @@ -333,7 +333,7 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) { } } -func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, addresses []sdtypes.WatchAddressArg) error { +func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, args []sdtypes.WatchAddressArg) error { return nil } diff --git a/statediff/types.go b/statediff/types.go index b179f1a00..69eaf0d0a 100644 --- a/statediff/types.go +++ b/statediff/types.go @@ -123,8 +123,13 @@ type accountWrapper struct { type OperationType string const ( - Add OperationType = "Add" - Remove OperationType = "Remove" - Set OperationType = "Set" - Clear OperationType = "Clear" + AddAddresses OperationType = "AddAddresses" + RemoveAddresses OperationType = "RemoveAddresses" + SetAddresses OperationType = "SetAddresses" + ClearAddresses OperationType = "ClearAddresses" + + AddStorageSlots OperationType = "AddStorageSlots" + RemoveStorageSlots OperationType = "RemoveStorageSlots" + SetStorageSlots OperationType = "SetStorageSlots" + ClearStorageSlots OperationType = "ClearStorageSlots" ) diff --git a/statediff/types/types.go b/statediff/types/types.go index cab4cb880..18d9d0b31 100644 --- a/statediff/types/types.go +++ b/statediff/types/types.go @@ -77,6 +77,26 @@ type CodeSink func(CodeAndCodeHash) error // WatchAddressArg is a arg type for WatchAddress API type WatchAddressArg struct { - Address common.Address + // Address represents common.Address | common.Hash + Address string CreatedAt uint64 } + +// WatchedAddressType for denoting watched: address | storage slot +type WatchedAddressType string + +const ( + WatchedAddress WatchedAddressType = "WatchedAddress" + WatchedStorageSlot WatchedAddressType = "WatchedStorageSlot" +) + +func (n WatchedAddressType) Int() int { + switch n { + case WatchedAddress: + return 0 + case WatchedStorageSlot: + return 1 + default: + return -1 + } +}