From 98c52a02a846076703dd71289528d4eac2658396 Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Wed, 12 Jan 2022 15:09:14 +0530 Subject: [PATCH 01/17] Statediff API to add an address to be watched --- statediff/api.go | 5 ++ statediff/api_test.go | 101 ++++++++++++++++++++++++ statediff/service.go | 105 +++++++++++++++++++++++++ statediff/testhelpers/mocks/service.go | 8 ++ 4 files changed, 219 insertions(+) create mode 100644 statediff/api_test.go diff --git a/statediff/api.go b/statediff/api.go index 923a0073f..2271ae6d0 100644 --- a/statediff/api.go +++ b/statediff/api.go @@ -149,3 +149,8 @@ func (api *PublicStateDiffAPI) WriteStateDiffAt(ctx context.Context, blockNumber func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash common.Hash, params Params) error { return api.sds.WriteStateDiffFor(blockHash, params) } + +// WatchAddress adds the given address to a list of watched addresses to which the direct statediff process is restricted +func (api *PublicStateDiffAPI) WatchAddress(address common.Address) error { + return api.sds.WatchAddress(address) +} diff --git a/statediff/api_test.go b/statediff/api_test.go new file mode 100644 index 000000000..9ea8a9ede --- /dev/null +++ b/statediff/api_test.go @@ -0,0 +1,101 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package statediff_test + +import ( + "fmt" + "os" + "reflect" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/statediff" +) + +func init() { + if os.Getenv("MODE") != "statediff" { + fmt.Println("Skipping statediff test") + os.Exit(0) + } +} + +var ( + address1Hex = "0x1ca7c995f8eF0A2989BbcE08D5B7Efe50A584aa1" + address2Hex = "0xe799eE0191652c864E49F3A3344CE62535B15afe" + address1 = common.HexToAddress(address1Hex) + address2 = common.HexToAddress(address2Hex) + + watchedAddresses0 []common.Address + watchedAddresses1 = []common.Address{address1} + watchedAddresses2 = []common.Address{address1, address2} + + expectedError = fmt.Errorf("Address %s already watched", address1Hex) + + service = statediff.Service{} +) + +func TestWatchAddress(t *testing.T) { + watchedAddresses := service.GetWathchedAddresses() + if !reflect.DeepEqual(watchedAddresses, watchedAddresses0) { + t.Error("Test failure:", t.Name()) + t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses0) + } + + testWatchUnwatchedAddress(t) + testWatchWatchedAddress(t) +} + +func testWatchUnwatchedAddress(t *testing.T) { + err := service.WatchAddress(address1) + if err != nil { + t.Error("Test failure:", t.Name()) + t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error()) + } + watchedAddresses := service.GetWathchedAddresses() + if !reflect.DeepEqual(watchedAddresses, watchedAddresses1) { + t.Error("Test failure:", t.Name()) + t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses1) + } + + err = service.WatchAddress(address2) + if err != nil { + t.Error("Test failure:", t.Name()) + t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error()) + } + watchedAddresses = service.GetWathchedAddresses() + if !reflect.DeepEqual(watchedAddresses, watchedAddresses2) { + t.Error("Test failure:", t.Name()) + t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses2) + } +} + +func testWatchWatchedAddress(t *testing.T) { + err := service.WatchAddress(address1) + if err == nil { + t.Error("Test failure:", t.Name()) + t.Logf("Expected error %s not thrown on an attempt to watch an already watched address.", expectedError.Error()) + } + if err.Error() != expectedError.Error() { + t.Error("Test failure:", t.Name()) + t.Logf("Actual thrown error not equal expected error.\nactual: %+v\nexpected: %+v", err.Error(), expectedError.Error()) + } + watchedAddresses := service.GetWathchedAddresses() + if !reflect.DeepEqual(watchedAddresses, watchedAddresses2) { + t.Error("Test failure:", t.Name()) + t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses2) + } +} diff --git a/statediff/service.go b/statediff/service.go index 6411ba68e..9d8b6ecf0 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -18,7 +18,11 @@ package statediff import ( "bytes" + "encoding/json" + "fmt" + "io/ioutil" "math/big" + "os" "strconv" "strings" "sync" @@ -54,6 +58,9 @@ const ( deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html ) +// TODO: Take the watched addresses file path as a CLI arg. +const watchedAddressesFile = "./watched-addresses.json" + var writeLoopParams = Params{ IntermediateStateNodes: true, IntermediateStorageNodes: true, @@ -101,6 +108,10 @@ type IService interface { WriteStateDiffFor(blockHash common.Hash, params Params) error // Event loop for progressively processing and writing diffs directly to DB WriteLoop(chainEventCh chan core.ChainEvent) + // Method to add an address to be watched to write loop params + WatchAddress(address common.Address) error + // Method to get currently watched addresses from write loop params + GetWathchedAddresses() []common.Address } // Wraps consructor parameters @@ -200,6 +211,12 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params } stack.RegisterLifecycle(sds) stack.RegisterAPIs(sds.APIs()) + + err := loadWatchedAddresses() + if err != nil { + return err + } + return nil } @@ -713,3 +730,91 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo } return err } + +// Adds the provided address to the list of watched addresses in write loop params and to the watched addresses file +func (sds *Service) WatchAddress(address common.Address) error { + // Check if address is already being watched + if containsAddress(writeLoopParams.WatchedAddresses, address) { + return fmt.Errorf("Address %s already watched", address) + } + + // Check if the watched addresses file exists + fileExists, err := doesFileExist(watchedAddressesFile) + if err != nil { + return err + } + + // Create the watched addresses file if doesn't exist + if !fileExists { + _, err := os.Create(watchedAddressesFile) + if err != nil { + return err + } + } + + watchedAddresses := append(writeLoopParams.WatchedAddresses, address) + + // Write the updated list of watched address to a json file + content, err := json.Marshal(watchedAddresses) + err = ioutil.WriteFile(watchedAddressesFile, content, 0644) + if err != nil { + return err + } + + // Update the in-memory params as well + writeLoopParams.WatchedAddresses = watchedAddresses + + return nil +} + +// Gets currently watched addresses from the in-memory write loop params +func (sds *Service) GetWathchedAddresses() []common.Address { + return writeLoopParams.WatchedAddresses +} + +// loadWatchedAddresses is used to load watched addresses to the in-memory write loop params from a json file if it exists +func loadWatchedAddresses() error { + // Check if the watched addresses file exists + fileExists, err := doesFileExist(watchedAddressesFile) + if err != nil { + return err + } + + if fileExists { + content, err := ioutil.ReadFile(watchedAddressesFile) + if err != nil { + return err + } + + var watchedAddresses []common.Address + err = json.Unmarshal(content, &watchedAddresses) + if err != nil { + return err + } + + writeLoopParams.WatchedAddresses = watchedAddresses + } + + return nil +} + +// containsAddress is used to check if an address is present in the provided list of watched addresses +func containsAddress(watchedAddresses []common.Address, address common.Address) bool { + for _, addr := range watchedAddresses { + if addr == address { + return true + } + } + return false +} + +// doesFileExist is used to check if file at a given path exists +func doesFileExist(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } else if os.IsNotExist(err) { + return false, nil + } + return false, err +} diff --git a/statediff/testhelpers/mocks/service.go b/statediff/testhelpers/mocks/service.go index f10017df4..3c47cbea6 100644 --- a/statediff/testhelpers/mocks/service.go +++ b/statediff/testhelpers/mocks/service.go @@ -332,3 +332,11 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) { log.Info("unable to close subscription %s; channel has no receiver", id) } } + +func (sds *MockStateDiffService) WatchAddress(address common.Address) error { + return nil +} + +func (sds *MockStateDiffService) GetWathchedAddresses() []common.Address { + return []common.Address{} +} -- 2.45.2 From 5fa002c0d023f1e9f23d84b55765430779ecc7ce Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Thu, 13 Jan 2022 18:30:34 +0530 Subject: [PATCH 02/17] Statediff API to change addresses being watched in direct indexing --- statediff/api.go | 4 +- statediff/api_test.go | 7 +- statediff/helpers.go | 52 +++++++++- statediff/indexer/indexer.go | 64 ++++++++++++ statediff/service.go | 129 +++++++++---------------- statediff/testhelpers/mocks/service.go | 2 +- statediff/types.go | 10 ++ 7 files changed, 179 insertions(+), 89 deletions(-) diff --git a/statediff/api.go b/statediff/api.go index 2271ae6d0..3dfed6c26 100644 --- a/statediff/api.go +++ b/statediff/api.go @@ -151,6 +151,6 @@ func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash } // WatchAddress adds the given address to a list of watched addresses to which the direct statediff process is restricted -func (api *PublicStateDiffAPI) WatchAddress(address common.Address) error { - return api.sds.WatchAddress(address) +func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, addresses []common.Address) error { + return api.sds.WatchAddress(operation, addresses) } diff --git a/statediff/api_test.go b/statediff/api_test.go index 9ea8a9ede..1e2bc2898 100644 --- a/statediff/api_test.go +++ b/statediff/api_test.go @@ -48,6 +48,7 @@ var ( service = statediff.Service{} ) +// TODO: Update tests for the updated API func TestWatchAddress(t *testing.T) { watchedAddresses := service.GetWathchedAddresses() if !reflect.DeepEqual(watchedAddresses, watchedAddresses0) { @@ -60,7 +61,7 @@ func TestWatchAddress(t *testing.T) { } func testWatchUnwatchedAddress(t *testing.T) { - err := service.WatchAddress(address1) + err := service.WatchAddress(statediff.Add, []common.Address{address1}) if err != nil { t.Error("Test failure:", t.Name()) t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error()) @@ -71,7 +72,7 @@ func testWatchUnwatchedAddress(t *testing.T) { t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses1) } - err = service.WatchAddress(address2) + err = service.WatchAddress(statediff.Add, []common.Address{address2}) if err != nil { t.Error("Test failure:", t.Name()) t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error()) @@ -84,7 +85,7 @@ func testWatchUnwatchedAddress(t *testing.T) { } func testWatchWatchedAddress(t *testing.T) { - err := service.WatchAddress(address1) + err := service.WatchAddress(statediff.Add, []common.Address{address1}) if err == nil { t.Error("Test failure:", t.Name()) t.Logf("Expected error %s not thrown on an attempt to watch an already watched address.", expectedError.Error()) diff --git a/statediff/helpers.go b/statediff/helpers.go index eb5060c51..1a766aae3 100644 --- a/statediff/helpers.go +++ b/statediff/helpers.go @@ -20,8 +20,12 @@ package statediff import ( + "fmt" "sort" "strings" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/statediff/indexer/postgres" ) func sortKeys(data AccountMap) []string { @@ -69,5 +73,51 @@ func findIntersection(a, b []string) []string { } } } - +} + +// loadWatchedAddresses is used to load watched addresses to the in-memory write loop params from the db +func loadWatchedAddresses(db *postgres.DB) error { + rows, err := db.Query("SELECT address FROM eth.watched_addresses") + if err != nil { + return fmt.Errorf("error loading watched addresses: %v", err) + } + + var watchedAddresses []common.Address + for rows.Next() { + var addressHex string + err := rows.Scan(&addressHex) + if err != nil { + return err + } + + watchedAddresses = append(watchedAddresses, common.HexToAddress(addressHex)) + } + + writeLoopParams.WatchedAddresses = watchedAddresses + + return nil +} + +func removeWatchedAddresses(watchedAddresses []common.Address, addressesToRemove []common.Address) []common.Address { + addresses := make([]common.Address, len(addressesToRemove)) + copy(addresses, watchedAddresses) + + for _, address := range addressesToRemove { + if idx := containsAddress(addresses, address); idx != -1 { + addresses = append(addresses[:idx], addresses[idx+1:]...) + } + } + + return addresses +} + +// containsAddress is used to check if an address is present in the provided list of watched addresses +// return the index if found else -1 +func containsAddress(watchedAddresses []common.Address, address common.Address) int { + for idx, addr := range watchedAddresses { + if addr == address { + return idx + } + } + return -1 } diff --git a/statediff/indexer/indexer.go b/statediff/indexer/indexer.go index 60d69f932..8d628d31e 100644 --- a/statediff/indexer/indexer.go +++ b/statediff/indexer/indexer.go @@ -59,6 +59,9 @@ type Indexer interface { PushStateNode(tx *BlockTx, stateNode sdtypes.StateNode) error PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sdtypes.CodeAndCodeHash) error ReportDBMetrics(delay time.Duration, quit <-chan bool) + InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error + RemoveWatchedAddresses(addresses []common.Address) error + ClearWatchedAddresses() error } // StateDiffIndexer satisfies the Indexer interface for ethereum statediff objects @@ -549,3 +552,64 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sd } return nil } + +func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error { + tx, err := sdi.dbWriter.db.Begin() + if err != nil { + return err + } + + for _, address := range addresses { + _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, added_at) VALUES ($1, $2)`, address, currentBlock) + if err != nil { + return fmt.Errorf("error inserting watched_addresses entry: %v", err) + } + } + + err = tx.Commit() + if err != nil { + return err + } + + return nil +} + +func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) error { + tx, err := sdi.dbWriter.db.Begin() + if err != nil { + return err + } + + for _, address := range addresses { + _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, address) + if err != nil { + return fmt.Errorf("error removing watched_addresses entry: %v", err) + } + } + + err = tx.Commit() + if err != nil { + return err + } + + return nil +} + +func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { + tx, err := sdi.dbWriter.db.Begin() + if err != nil { + return err + } + + _, err = tx.Exec(`DELETE FROM eth.watched_addresses`) + if err != nil { + return fmt.Errorf("error clearing watched_addresses table: %v", err) + } + + err = tx.Commit() + if err != nil { + return err + } + + return nil +} diff --git a/statediff/service.go b/statediff/service.go index 9d8b6ecf0..ba599b05f 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -18,11 +18,8 @@ package statediff import ( "bytes" - "encoding/json" "fmt" - "io/ioutil" "math/big" - "os" "strconv" "strings" "sync" @@ -58,9 +55,6 @@ const ( deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html ) -// TODO: Take the watched addresses file path as a CLI arg. -const watchedAddressesFile = "./watched-addresses.json" - var writeLoopParams = Params{ IntermediateStateNodes: true, IntermediateStorageNodes: true, @@ -109,7 +103,7 @@ type IService interface { // Event loop for progressively processing and writing diffs directly to DB WriteLoop(chainEventCh chan core.ChainEvent) // Method to add an address to be watched to write loop params - WatchAddress(address common.Address) error + WatchAddress(operation OperationType, addresses []common.Address) error // Method to get currently watched addresses from write loop params GetWathchedAddresses() []common.Address } @@ -170,6 +164,7 @@ func NewBlockCache(max uint) blockCache { func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params ServiceParams) error { blockChain := ethServ.BlockChain() var indexer ind.Indexer + var db *postgres.DB quitCh := make(chan bool) if params.DBParams != nil { info := nodeinfo.Info{ @@ -212,7 +207,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params stack.RegisterLifecycle(sds) stack.RegisterAPIs(sds.APIs()) - err := loadWatchedAddresses() + err := loadWatchedAddresses(db) if err != nil { return err } @@ -731,39 +726,56 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo return err } -// Adds the provided address to the list of watched addresses in write loop params and to the watched addresses file -func (sds *Service) WatchAddress(address common.Address) error { - // Check if address is already being watched - if containsAddress(writeLoopParams.WatchedAddresses, address) { - return fmt.Errorf("Address %s already watched", address) - } +// Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses +func (sds *Service) WatchAddress(operation OperationType, addresses []common.Address) error { + // check operation + switch operation { + case Add: + for _, address := range addresses { + // Check if address is already being watched + if containsAddress(writeLoopParams.WatchedAddresses, address) != -1 { + return fmt.Errorf("Address %s already watched", address) + } + } - // Check if the watched addresses file exists - fileExists, err := doesFileExist(watchedAddressesFile) - if err != nil { - return err - } - - // Create the watched addresses file if doesn't exist - if !fileExists { - _, err := os.Create(watchedAddressesFile) + // TODO: Make sure WriteLoop doesn't call statediffing before the params are updated for the current block + // TODO: Get the current block + err := sds.indexer.InsertWatchedAddresses(addresses, common.Big1) if err != nil { return err } + + writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, addresses...) + case Remove: + err := sds.indexer.RemoveWatchedAddresses(addresses) + if err != nil { + return err + } + + removeWatchedAddresses(writeLoopParams.WatchedAddresses, addresses) + case Set: + err := sds.indexer.ClearWatchedAddresses() + if err != nil { + return err + } + + err = sds.indexer.InsertWatchedAddresses(addresses, common.Big1) + if err != nil { + return err + } + + writeLoopParams.WatchedAddresses = addresses + case Clear: + err := sds.indexer.ClearWatchedAddresses() + if err != nil { + return err + } + + writeLoopParams.WatchedAddresses = nil + default: + return fmt.Errorf("Unexpected operation %s", operation) } - watchedAddresses := append(writeLoopParams.WatchedAddresses, address) - - // Write the updated list of watched address to a json file - content, err := json.Marshal(watchedAddresses) - err = ioutil.WriteFile(watchedAddressesFile, content, 0644) - if err != nil { - return err - } - - // Update the in-memory params as well - writeLoopParams.WatchedAddresses = watchedAddresses - return nil } @@ -771,50 +783,3 @@ func (sds *Service) WatchAddress(address common.Address) error { func (sds *Service) GetWathchedAddresses() []common.Address { return writeLoopParams.WatchedAddresses } - -// loadWatchedAddresses is used to load watched addresses to the in-memory write loop params from a json file if it exists -func loadWatchedAddresses() error { - // Check if the watched addresses file exists - fileExists, err := doesFileExist(watchedAddressesFile) - if err != nil { - return err - } - - if fileExists { - content, err := ioutil.ReadFile(watchedAddressesFile) - if err != nil { - return err - } - - var watchedAddresses []common.Address - err = json.Unmarshal(content, &watchedAddresses) - if err != nil { - return err - } - - writeLoopParams.WatchedAddresses = watchedAddresses - } - - return nil -} - -// containsAddress is used to check if an address is present in the provided list of watched addresses -func containsAddress(watchedAddresses []common.Address, address common.Address) bool { - for _, addr := range watchedAddresses { - if addr == address { - return true - } - } - return false -} - -// doesFileExist is used to check if file at a given path exists -func doesFileExist(path string) (bool, error) { - _, err := os.Stat(path) - if err == nil { - return true, nil - } else if os.IsNotExist(err) { - return false, nil - } - return false, err -} diff --git a/statediff/testhelpers/mocks/service.go b/statediff/testhelpers/mocks/service.go index 3c47cbea6..943720094 100644 --- a/statediff/testhelpers/mocks/service.go +++ b/statediff/testhelpers/mocks/service.go @@ -333,7 +333,7 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) { } } -func (sds *MockStateDiffService) WatchAddress(address common.Address) error { +func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, addresses []common.Address) error { return nil } diff --git a/statediff/types.go b/statediff/types.go index ef8256041..fd0877b02 100644 --- a/statediff/types.go +++ b/statediff/types.go @@ -111,3 +111,13 @@ type accountWrapper struct { NodeValue []byte LeafKey []byte } + +// OperationType for type of WatchAddress operation +type OperationType string + +const ( + Add OperationType = "Add" + Remove OperationType = "Remove" + Set OperationType = "Set" + Clear OperationType = "Clear" +) -- 2.45.2 From bbb1759886dbeafc2e4e191150419eafb1ffa113 Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Mon, 17 Jan 2022 13:50:52 +0530 Subject: [PATCH 03/17] Use current block and mutex while changing watched addresses --- statediff/api.go | 2 +- statediff/helpers.go | 21 ++++---- statediff/indexer/indexer.go | 20 +++----- statediff/service.go | 61 +++++++++++++++-------- statediff/testhelpers/mocks/blockchain.go | 5 ++ statediff/types.go | 7 +++ 6 files changed, 73 insertions(+), 43 deletions(-) diff --git a/statediff/api.go b/statediff/api.go index 3dfed6c26..d6cdfd008 100644 --- a/statediff/api.go +++ b/statediff/api.go @@ -150,7 +150,7 @@ func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash return api.sds.WriteStateDiffFor(blockHash, params) } -// WatchAddress adds the given address to a list of watched addresses to which the direct statediff process is restricted +// WatchAddress changes the list of watched addresses to which the direct indexing is restricted according to given operation func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, addresses []common.Address) error { return api.sds.WatchAddress(operation, addresses) } diff --git a/statediff/helpers.go b/statediff/helpers.go index 1a766aae3..1dccbfbc2 100644 --- a/statediff/helpers.go +++ b/statediff/helpers.go @@ -93,28 +93,31 @@ func loadWatchedAddresses(db *postgres.DB) error { watchedAddresses = append(watchedAddresses, common.HexToAddress(addressHex)) } + writeLoopParams.mu.Lock() writeLoopParams.WatchedAddresses = watchedAddresses + writeLoopParams.mu.Unlock() return nil } -func removeWatchedAddresses(watchedAddresses []common.Address, addressesToRemove []common.Address) []common.Address { - addresses := make([]common.Address, len(addressesToRemove)) - copy(addresses, watchedAddresses) +// removeAddresses is used to remove given addresses from a list of addresses +func removeAddresses(addresses []common.Address, addressesToRemove []common.Address) []common.Address { + addressesCopy := make([]common.Address, len(addresses)) + copy(addressesCopy, addresses) for _, address := range addressesToRemove { - if idx := containsAddress(addresses, address); idx != -1 { - addresses = append(addresses[:idx], addresses[idx+1:]...) + if idx := containsAddress(addressesCopy, address); idx != -1 { + addressesCopy = append(addressesCopy[:idx], addressesCopy[idx+1:]...) } } - return addresses + return addressesCopy } -// containsAddress is used to check if an address is present in the provided list of watched addresses +// containsAddress is used to check if an address is present in the provided list of addresses // return the index if found else -1 -func containsAddress(watchedAddresses []common.Address, address common.Address) int { - for idx, addr := range watchedAddresses { +func containsAddress(addresses []common.Address, address common.Address) int { + for idx, addr := range addresses { if addr == address { return idx } diff --git a/statediff/indexer/indexer.go b/statediff/indexer/indexer.go index 8d628d31e..fd3beec73 100644 --- a/statediff/indexer/indexer.go +++ b/statediff/indexer/indexer.go @@ -553,6 +553,7 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sd return nil } +// InsertWatchedAddresses inserts the given addresses in the database func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error { tx, err := sdi.dbWriter.db.Begin() if err != nil { @@ -560,7 +561,8 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []common.Address, } for _, address := range addresses { - _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, added_at) VALUES ($1, $2)`, address, currentBlock) + _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, added_at)VALUES ($1, $2) ON CONFLICT (address) DO NOTHING`, + address.Hex(), currentBlock.Uint64()) if err != nil { return fmt.Errorf("error inserting watched_addresses entry: %v", err) } @@ -574,6 +576,7 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []common.Address, return nil } +// RemoveWatchedAddresses removes the given addresses from the database func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) error { tx, err := sdi.dbWriter.db.Begin() if err != nil { @@ -581,7 +584,7 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) } for _, address := range addresses { - _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, address) + _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, address.Hex()) if err != nil { return fmt.Errorf("error removing watched_addresses entry: %v", err) } @@ -595,21 +598,12 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) return nil } +// ClearWatchedAddresses clears all the addresses from the database func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { - tx, err := sdi.dbWriter.db.Begin() - if err != nil { - return err - } - - _, err = tx.Exec(`DELETE FROM eth.watched_addresses`) + _, err := sdi.dbWriter.db.Exec(`DELETE FROM eth.watched_addresses`) if err != nil { return fmt.Errorf("error clearing watched_addresses table: %v", err) } - err = tx.Commit() - if err != nil { - return err - } - return nil } diff --git a/statediff/service.go b/statediff/service.go index ba599b05f..446388f66 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -55,13 +55,15 @@ const ( deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html ) -var writeLoopParams = Params{ - IntermediateStateNodes: true, - IntermediateStorageNodes: true, - IncludeBlock: true, - IncludeReceipts: true, - IncludeTD: true, - IncludeCode: true, +var writeLoopParams = ParamsWithMutex{ + Params: Params{ + IntermediateStateNodes: true, + IntermediateStorageNodes: true, + IncludeBlock: true, + IncludeReceipts: true, + IncludeTD: true, + IncludeCode: true, + }, } var statediffMetrics = RegisterStatediffMetrics(metrics.DefaultRegistry) @@ -74,6 +76,7 @@ type blockChain interface { GetTd(hash common.Hash, number uint64) *big.Int UnlockTrie(root common.Hash) StateCache() state.Database + CurrentBlock() *types.Block } // IService is the state-diffing service interface @@ -102,7 +105,7 @@ type IService interface { WriteStateDiffFor(blockHash common.Hash, params Params) error // Event loop for progressively processing and writing diffs directly to DB WriteLoop(chainEventCh chan core.ChainEvent) - // Method to add an address to be watched to write loop params + // Method to change the addresses being watched in write loop params WatchAddress(operation OperationType, addresses []common.Address) error // Method to get currently watched addresses from write loop params GetWathchedAddresses() []common.Address @@ -165,6 +168,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params blockChain := ethServ.BlockChain() var indexer ind.Indexer var db *postgres.DB + var err error quitCh := make(chan bool) if params.DBParams != nil { info := nodeinfo.Info{ @@ -176,7 +180,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params } // TODO: pass max idle, open, lifetime? - db, err := postgres.NewDB(params.DBParams.ConnectionURL, postgres.ConnectionConfig{}, info) + db, err = postgres.NewDB(params.DBParams.ConnectionURL, postgres.ConnectionConfig{}, info) if err != nil { return err } @@ -207,7 +211,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params stack.RegisterLifecycle(sds) stack.RegisterAPIs(sds.APIs()) - err := loadWatchedAddresses(db) + err = loadWatchedAddresses(db) if err != nil { return err } @@ -290,7 +294,9 @@ func (sds *Service) WriteLoop(chainEventCh chan core.ChainEvent) { func (sds *Service) writeGenesisStateDiff(currBlock *types.Block, workerId uint) { // For genesis block we need to return the entire state trie hence we diff it with an empty trie. log.Info("Writing state diff", "block height", genesisBlockNumber, "worker", workerId) - err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams) + writeLoopParams.mu.RLock() + err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams.Params) + writeLoopParams.mu.RUnlock() if err != nil { log.Error("statediff.Service.WriteLoop: processing error", "block height", genesisBlockNumber, "error", err.Error(), "worker", workerId) @@ -319,7 +325,9 @@ func (sds *Service) writeLoopWorker(params workerParams) { } log.Info("Writing state diff", "block height", currentBlock.Number().Uint64(), "worker", params.id) - err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams) + writeLoopParams.mu.RLock() + err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams.Params) + writeLoopParams.mu.RUnlock() if err != nil { log.Error("statediff.Service.WriteLoop: processing error", "block height", currentBlock.Number().Uint64(), "error", err.Error(), "worker", params.id) continue @@ -555,7 +563,7 @@ func (sds *Service) Start() error { go sds.Loop(chainEventCh) if sds.enableWriteLoop { - log.Info("Starting statediff DB write loop", "params", writeLoopParams) + log.Info("Starting statediff DB write loop", "params", writeLoopParams.Params) chainEventCh := make(chan core.ChainEvent, chainEventChanSize) go sds.WriteLoop(chainEventCh) } @@ -728,19 +736,32 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo // Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses func (sds *Service) WatchAddress(operation OperationType, addresses []common.Address) error { - // check operation + // lock writeLoopParams for a write + writeLoopParams.mu.Lock() + defer writeLoopParams.mu.Unlock() + + // get the current block number + currentBlock := sds.BlockChain.CurrentBlock() + currentBlockNumber := currentBlock.Number() + switch operation { case Add: + addressesToRemove := []common.Address{} for _, address := range addresses { // Check if address is already being watched + // Throw a warning and continue if found if containsAddress(writeLoopParams.WatchedAddresses, address) != -1 { - return fmt.Errorf("Address %s already watched", address) + // log.Warn(fmt.Sprint("Address ", address.Hex(), " already being watched")) + log.Warn("Address already being watched", "address", address.Hex()) + addressesToRemove = append(addressesToRemove, address) + continue } } - // TODO: Make sure WriteLoop doesn't call statediffing before the params are updated for the current block - // TODO: Get the current block - err := sds.indexer.InsertWatchedAddresses(addresses, common.Big1) + // remove already watched addresses + addresses = removeAddresses(addresses, addressesToRemove) + + err := sds.indexer.InsertWatchedAddresses(addresses, currentBlockNumber) if err != nil { return err } @@ -752,14 +773,14 @@ func (sds *Service) WatchAddress(operation OperationType, addresses []common.Add return err } - removeWatchedAddresses(writeLoopParams.WatchedAddresses, addresses) + writeLoopParams.WatchedAddresses = removeAddresses(writeLoopParams.WatchedAddresses, addresses) case Set: err := sds.indexer.ClearWatchedAddresses() if err != nil { return err } - err = sds.indexer.InsertWatchedAddresses(addresses, common.Big1) + err = sds.indexer.InsertWatchedAddresses(addresses, currentBlockNumber) if err != nil { return err } diff --git a/statediff/testhelpers/mocks/blockchain.go b/statediff/testhelpers/mocks/blockchain.go index b4b1f3694..f2834a4a8 100644 --- a/statediff/testhelpers/mocks/blockchain.go +++ b/statediff/testhelpers/mocks/blockchain.go @@ -128,6 +128,11 @@ func (bc *BlockChain) GetTd(hash common.Hash, blockNum uint64) *big.Int { return nil } +// CurrentBlock mock method +func (bc *BlockChain) CurrentBlock() *types.Block { + return nil +} + func (bc *BlockChain) SetTd(hash common.Hash, blockNum uint64, td *big.Int) { if bc.TDByHash == nil { bc.TDByHash = make(map[common.Hash]*big.Int) diff --git a/statediff/types.go b/statediff/types.go index fd0877b02..d17d11067 100644 --- a/statediff/types.go +++ b/statediff/types.go @@ -22,6 +22,7 @@ package statediff import ( "encoding/json" "math/big" + "sync" "github.com/ethereum/go-ethereum/common" ctypes "github.com/ethereum/go-ethereum/core/types" @@ -53,6 +54,12 @@ type Params struct { WatchedStorageSlots []common.Hash } +// ParamsWithMutex allows to lock the parameters while they are being updated | read from +type ParamsWithMutex struct { + Params + mu sync.RWMutex +} + // Args bundles the arguments for the state diff builder type Args struct { OldStateRoot, NewStateRoot, BlockHash common.Hash -- 2.45.2 From d7704d2f98b51b28e7cb4bb04a1bf9d0a1712e75 Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Tue, 18 Jan 2022 15:05:45 +0530 Subject: [PATCH 04/17] Add a creation block arg in the watch address API --- statediff/api.go | 4 +- statediff/api_test.go | 102 ------------------------- statediff/helpers.go | 37 +++++++-- statediff/indexer/indexer.go | 10 +-- statediff/service.go | 26 ++++--- statediff/testhelpers/mocks/service.go | 2 +- statediff/types/types.go | 6 ++ 7 files changed, 59 insertions(+), 128 deletions(-) delete mode 100644 statediff/api_test.go diff --git a/statediff/api.go b/statediff/api.go index d6cdfd008..ed9cc3c06 100644 --- a/statediff/api.go +++ b/statediff/api.go @@ -151,6 +151,6 @@ func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash } // WatchAddress changes the list of watched addresses to which the direct indexing is restricted according to given operation -func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, addresses []common.Address) error { - return api.sds.WatchAddress(operation, addresses) +func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, args []WatchAddressArg) error { + return api.sds.WatchAddress(operation, args) } diff --git a/statediff/api_test.go b/statediff/api_test.go deleted file mode 100644 index 1e2bc2898..000000000 --- a/statediff/api_test.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2019 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package statediff_test - -import ( - "fmt" - "os" - "reflect" - "testing" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/statediff" -) - -func init() { - if os.Getenv("MODE") != "statediff" { - fmt.Println("Skipping statediff test") - os.Exit(0) - } -} - -var ( - address1Hex = "0x1ca7c995f8eF0A2989BbcE08D5B7Efe50A584aa1" - address2Hex = "0xe799eE0191652c864E49F3A3344CE62535B15afe" - address1 = common.HexToAddress(address1Hex) - address2 = common.HexToAddress(address2Hex) - - watchedAddresses0 []common.Address - watchedAddresses1 = []common.Address{address1} - watchedAddresses2 = []common.Address{address1, address2} - - expectedError = fmt.Errorf("Address %s already watched", address1Hex) - - service = statediff.Service{} -) - -// TODO: Update tests for the updated API -func TestWatchAddress(t *testing.T) { - watchedAddresses := service.GetWathchedAddresses() - if !reflect.DeepEqual(watchedAddresses, watchedAddresses0) { - t.Error("Test failure:", t.Name()) - t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses0) - } - - testWatchUnwatchedAddress(t) - testWatchWatchedAddress(t) -} - -func testWatchUnwatchedAddress(t *testing.T) { - err := service.WatchAddress(statediff.Add, []common.Address{address1}) - if err != nil { - t.Error("Test failure:", t.Name()) - t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error()) - } - watchedAddresses := service.GetWathchedAddresses() - if !reflect.DeepEqual(watchedAddresses, watchedAddresses1) { - t.Error("Test failure:", t.Name()) - t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses1) - } - - err = service.WatchAddress(statediff.Add, []common.Address{address2}) - if err != nil { - t.Error("Test failure:", t.Name()) - t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error()) - } - watchedAddresses = service.GetWathchedAddresses() - if !reflect.DeepEqual(watchedAddresses, watchedAddresses2) { - t.Error("Test failure:", t.Name()) - t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses2) - } -} - -func testWatchWatchedAddress(t *testing.T) { - err := service.WatchAddress(statediff.Add, []common.Address{address1}) - if err == nil { - t.Error("Test failure:", t.Name()) - t.Logf("Expected error %s not thrown on an attempt to watch an already watched address.", expectedError.Error()) - } - if err.Error() != expectedError.Error() { - t.Error("Test failure:", t.Name()) - t.Logf("Actual thrown error not equal expected error.\nactual: %+v\nexpected: %+v", err.Error(), expectedError.Error()) - } - watchedAddresses := service.GetWathchedAddresses() - if !reflect.DeepEqual(watchedAddresses, watchedAddresses2) { - t.Error("Test failure:", t.Name()) - t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses2) - } -} diff --git a/statediff/helpers.go b/statediff/helpers.go index 1dccbfbc2..f97f34062 100644 --- a/statediff/helpers.go +++ b/statediff/helpers.go @@ -26,6 +26,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/statediff/indexer/postgres" + "github.com/ethereum/go-ethereum/statediff/types" ) func sortKeys(data AccountMap) []string { @@ -102,16 +103,15 @@ func loadWatchedAddresses(db *postgres.DB) error { // removeAddresses is used to remove given addresses from a list of addresses func removeAddresses(addresses []common.Address, addressesToRemove []common.Address) []common.Address { - addressesCopy := make([]common.Address, len(addresses)) - copy(addressesCopy, addresses) + filteredAddresses := []common.Address{} - for _, address := range addressesToRemove { - if idx := containsAddress(addressesCopy, address); idx != -1 { - addressesCopy = append(addressesCopy[:idx], addressesCopy[idx+1:]...) + for _, address := range addresses { + if idx := containsAddress(addressesToRemove, address); idx == -1 { + filteredAddresses = append(filteredAddresses, address) } } - return addressesCopy + return filteredAddresses } // containsAddress is used to check if an address is present in the provided list of addresses @@ -124,3 +124,28 @@ func containsAddress(addresses []common.Address, address common.Address) int { } return -1 } + +// getArgAddresses is used to get the list of addresses from a list of WatchAddressArgs +func getArgAddresses(args []types.WatchAddressArg) []common.Address { + addresses := make([]common.Address, len(args)) + for idx, arg := range args { + addresses[idx] = arg.Address + } + + return addresses +} + +// filterArgs filters out the args having an address from a given list of addresses +func filterArgs(args []types.WatchAddressArg, addressesToRemove []common.Address) ([]types.WatchAddressArg, []common.Address) { + filteredArgs := []types.WatchAddressArg{} + filteredAddresses := []common.Address{} + + for _, arg := range args { + if idx := containsAddress(addressesToRemove, arg.Address); idx == -1 { + filteredArgs = append(filteredArgs, arg) + filteredAddresses = append(filteredAddresses, arg.Address) + } + } + + return filteredArgs, filteredAddresses +} diff --git a/statediff/indexer/indexer.go b/statediff/indexer/indexer.go index fd3beec73..cda12e4bb 100644 --- a/statediff/indexer/indexer.go +++ b/statediff/indexer/indexer.go @@ -59,7 +59,7 @@ type Indexer interface { PushStateNode(tx *BlockTx, stateNode sdtypes.StateNode) error PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sdtypes.CodeAndCodeHash) error ReportDBMetrics(delay time.Duration, quit <-chan bool) - InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error + InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error RemoveWatchedAddresses(addresses []common.Address) error ClearWatchedAddresses() error } @@ -554,15 +554,15 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sd } // InsertWatchedAddresses inserts the given addresses in the database -func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error { +func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { tx, err := sdi.dbWriter.db.Begin() if err != nil { return err } - for _, address := range addresses { - _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, added_at)VALUES ($1, $2) ON CONFLICT (address) DO NOTHING`, - address.Hex(), currentBlock.Uint64()) + for _, arg := range args { + _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at)VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, + arg.Address.Hex(), arg.CreatedAt, currentBlockNumber.Uint64()) if err != nil { return fmt.Errorf("error inserting watched_addresses entry: %v", err) } diff --git a/statediff/service.go b/statediff/service.go index 446388f66..cc0059f00 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -106,7 +106,7 @@ type IService interface { // Event loop for progressively processing and writing diffs directly to DB WriteLoop(chainEventCh chan core.ChainEvent) // Method to change the addresses being watched in write loop params - WatchAddress(operation OperationType, addresses []common.Address) error + WatchAddress(operation OperationType, args []WatchAddressArg) error // Method to get currently watched addresses from write loop params GetWathchedAddresses() []common.Address } @@ -735,39 +735,40 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo } // Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses -func (sds *Service) WatchAddress(operation OperationType, addresses []common.Address) error { +func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error { // lock writeLoopParams for a write writeLoopParams.mu.Lock() defer writeLoopParams.mu.Unlock() // get the current block number - currentBlock := sds.BlockChain.CurrentBlock() - currentBlockNumber := currentBlock.Number() + currentBlockNumber := sds.BlockChain.CurrentBlock().Number() switch operation { case Add: addressesToRemove := []common.Address{} - for _, address := range addresses { + for _, arg := range args { // Check if address is already being watched // Throw a warning and continue if found - if containsAddress(writeLoopParams.WatchedAddresses, address) != -1 { + if containsAddress(writeLoopParams.WatchedAddresses, arg.Address) != -1 { // log.Warn(fmt.Sprint("Address ", address.Hex(), " already being watched")) - log.Warn("Address already being watched", "address", address.Hex()) - addressesToRemove = append(addressesToRemove, address) + log.Warn("Address already being watched", "address", arg.Address.Hex()) + addressesToRemove = append(addressesToRemove, arg.Address) continue } } // remove already watched addresses - addresses = removeAddresses(addresses, addressesToRemove) + filteredArgs, filteredAddresses := filterArgs(args, addressesToRemove) - err := sds.indexer.InsertWatchedAddresses(addresses, currentBlockNumber) + err := sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) if err != nil { return err } - writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, addresses...) + writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...) case Remove: + addresses := getArgAddresses(args) + err := sds.indexer.RemoveWatchedAddresses(addresses) if err != nil { return err @@ -780,11 +781,12 @@ func (sds *Service) WatchAddress(operation OperationType, addresses []common.Add return err } - err = sds.indexer.InsertWatchedAddresses(addresses, currentBlockNumber) + err = sds.indexer.InsertWatchedAddresses(args, currentBlockNumber) if err != nil { return err } + addresses := getArgAddresses(args) writeLoopParams.WatchedAddresses = addresses case Clear: err := sds.indexer.ClearWatchedAddresses() diff --git a/statediff/testhelpers/mocks/service.go b/statediff/testhelpers/mocks/service.go index 943720094..e2f4d3cb9 100644 --- a/statediff/testhelpers/mocks/service.go +++ b/statediff/testhelpers/mocks/service.go @@ -333,7 +333,7 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) { } } -func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, addresses []common.Address) error { +func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, addresses []sdtypes.WatchAddressArg) error { return nil } diff --git a/statediff/types/types.go b/statediff/types/types.go index 56babfb5b..cab4cb880 100644 --- a/statediff/types/types.go +++ b/statediff/types/types.go @@ -74,3 +74,9 @@ type CodeAndCodeHash struct { type StateNodeSink func(StateNode) error type StorageNodeSink func(StorageNode) error type CodeSink func(CodeAndCodeHash) error + +// WatchAddressArg is a arg type for WatchAddress API +type WatchAddressArg struct { + Address common.Address + CreatedAt uint64 +} -- 2.45.2 From 355ad83b88e534262087f27e3f53dfc98ed28f9f Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Wed, 19 Jan 2022 10:33:21 +0530 Subject: [PATCH 05/17] Rollback db transactions on an error and other fixes --- statediff/helpers.go | 20 ++++++--------- statediff/indexer/indexer.go | 35 +++++++++++++++++++++++++- statediff/service.go | 30 +++++++++------------- statediff/testhelpers/mocks/service.go | 2 +- statediff/types.go | 2 +- 5 files changed, 56 insertions(+), 33 deletions(-) diff --git a/statediff/helpers.go b/statediff/helpers.go index f97f34062..1f8523fc3 100644 --- a/statediff/helpers.go +++ b/statediff/helpers.go @@ -78,25 +78,21 @@ func findIntersection(a, b []string) []string { // loadWatchedAddresses is used to load watched addresses to the in-memory write loop params from the db func loadWatchedAddresses(db *postgres.DB) error { - rows, err := db.Query("SELECT address FROM eth.watched_addresses") + var watchedAddressStrings []string + pgStr := "SELECT address FROM eth.watched_addresses" + err := db.Select(&watchedAddressStrings, pgStr) if err != nil { return fmt.Errorf("error loading watched addresses: %v", err) } var watchedAddresses []common.Address - for rows.Next() { - var addressHex string - err := rows.Scan(&addressHex) - if err != nil { - return err - } - - watchedAddresses = append(watchedAddresses, common.HexToAddress(addressHex)) + for _, watchedAddressString := range watchedAddressStrings { + watchedAddresses = append(watchedAddresses, common.HexToAddress(watchedAddressString)) } - writeLoopParams.mu.Lock() + writeLoopParams.Lock() + defer writeLoopParams.Unlock() writeLoopParams.WatchedAddresses = watchedAddresses - writeLoopParams.mu.Unlock() return nil } @@ -126,7 +122,7 @@ func containsAddress(addresses []common.Address, address common.Address) int { } // getArgAddresses is used to get the list of addresses from a list of WatchAddressArgs -func getArgAddresses(args []types.WatchAddressArg) []common.Address { +func getAddresses(args []types.WatchAddressArg) []common.Address { addresses := make([]common.Address, len(args)) for idx, arg := range args { addresses[idx] = arg.Address diff --git a/statediff/indexer/indexer.go b/statediff/indexer/indexer.go index cda12e4bb..c85d00b2e 100644 --- a/statediff/indexer/indexer.go +++ b/statediff/indexer/indexer.go @@ -59,8 +59,11 @@ type Indexer interface { PushStateNode(tx *BlockTx, stateNode sdtypes.StateNode) error PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sdtypes.CodeAndCodeHash) error ReportDBMetrics(delay time.Duration, quit <-chan bool) + + // Methods used by WatchAddress API/functionality. InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error RemoveWatchedAddresses(addresses []common.Address) error + SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error ClearWatchedAddresses() error } @@ -559,9 +562,10 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressA if err != nil { return err } + defer tx.Rollback() for _, arg := range args { - _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at)VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, + _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, arg.Address.Hex(), arg.CreatedAt, currentBlockNumber.Uint64()) if err != nil { return fmt.Errorf("error inserting watched_addresses entry: %v", err) @@ -582,6 +586,7 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) if err != nil { return err } + defer tx.Rollback() for _, address := range addresses { _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, address.Hex()) @@ -598,6 +603,34 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) return nil } +func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { + tx, err := sdi.dbWriter.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.Exec(`DELETE FROM eth.watched_addresses`) + if err != nil { + return fmt.Errorf("error setting watched_addresses table: %v", err) + } + + for _, arg := range args { + _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, + arg.Address.Hex(), arg.CreatedAt, currentBlockNumber.Uint64()) + if err != nil { + return fmt.Errorf("error setting watched_addresses table: %v", err) + } + } + + err = tx.Commit() + if err != nil { + return err + } + + return nil +} + // ClearWatchedAddresses clears all the addresses from the database func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { _, err := sdi.dbWriter.db.Exec(`DELETE FROM eth.watched_addresses`) diff --git a/statediff/service.go b/statediff/service.go index cc0059f00..17727f14a 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -70,13 +70,13 @@ var statediffMetrics = RegisterStatediffMetrics(metrics.DefaultRegistry) type blockChain interface { SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription + CurrentBlock() *types.Block GetBlockByHash(hash common.Hash) *types.Block GetBlockByNumber(number uint64) *types.Block GetReceiptsByHash(hash common.Hash) types.Receipts GetTd(hash common.Hash, number uint64) *big.Int UnlockTrie(root common.Hash) StateCache() state.Database - CurrentBlock() *types.Block } // IService is the state-diffing service interface @@ -108,7 +108,7 @@ type IService interface { // Method to change the addresses being watched in write loop params WatchAddress(operation OperationType, args []WatchAddressArg) error // Method to get currently watched addresses from write loop params - GetWathchedAddresses() []common.Address + GetWatchedAddresses() []common.Address } // Wraps consructor parameters @@ -294,9 +294,9 @@ func (sds *Service) WriteLoop(chainEventCh chan core.ChainEvent) { func (sds *Service) writeGenesisStateDiff(currBlock *types.Block, workerId uint) { // For genesis block we need to return the entire state trie hence we diff it with an empty trie. log.Info("Writing state diff", "block height", genesisBlockNumber, "worker", workerId) - writeLoopParams.mu.RLock() + writeLoopParams.RLock() err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams.Params) - writeLoopParams.mu.RUnlock() + writeLoopParams.RUnlock() if err != nil { log.Error("statediff.Service.WriteLoop: processing error", "block height", genesisBlockNumber, "error", err.Error(), "worker", workerId) @@ -325,9 +325,9 @@ func (sds *Service) writeLoopWorker(params workerParams) { } log.Info("Writing state diff", "block height", currentBlock.Number().Uint64(), "worker", params.id) - writeLoopParams.mu.RLock() + writeLoopParams.RLock() err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams.Params) - writeLoopParams.mu.RUnlock() + writeLoopParams.RUnlock() if err != nil { log.Error("statediff.Service.WriteLoop: processing error", "block height", currentBlock.Number().Uint64(), "error", err.Error(), "worker", params.id) continue @@ -737,8 +737,8 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo // Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error { // lock writeLoopParams for a write - writeLoopParams.mu.Lock() - defer writeLoopParams.mu.Unlock() + writeLoopParams.Lock() + defer writeLoopParams.Unlock() // get the current block number currentBlockNumber := sds.BlockChain.CurrentBlock().Number() @@ -750,7 +750,6 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg // Check if address is already being watched // Throw a warning and continue if found if containsAddress(writeLoopParams.WatchedAddresses, arg.Address) != -1 { - // log.Warn(fmt.Sprint("Address ", address.Hex(), " already being watched")) log.Warn("Address already being watched", "address", arg.Address.Hex()) addressesToRemove = append(addressesToRemove, arg.Address) continue @@ -767,7 +766,7 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...) case Remove: - addresses := getArgAddresses(args) + addresses := getAddresses(args) err := sds.indexer.RemoveWatchedAddresses(addresses) if err != nil { @@ -776,17 +775,12 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg writeLoopParams.WatchedAddresses = removeAddresses(writeLoopParams.WatchedAddresses, addresses) case Set: - err := sds.indexer.ClearWatchedAddresses() + err := sds.indexer.SetWatchedAddresses(args, currentBlockNumber) if err != nil { return err } - err = sds.indexer.InsertWatchedAddresses(args, currentBlockNumber) - if err != nil { - return err - } - - addresses := getArgAddresses(args) + addresses := getAddresses(args) writeLoopParams.WatchedAddresses = addresses case Clear: err := sds.indexer.ClearWatchedAddresses() @@ -803,6 +797,6 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg } // Gets currently watched addresses from the in-memory write loop params -func (sds *Service) GetWathchedAddresses() []common.Address { +func (sds *Service) GetWatchedAddresses() []common.Address { return writeLoopParams.WatchedAddresses } diff --git a/statediff/testhelpers/mocks/service.go b/statediff/testhelpers/mocks/service.go index e2f4d3cb9..686f0c2ba 100644 --- a/statediff/testhelpers/mocks/service.go +++ b/statediff/testhelpers/mocks/service.go @@ -337,6 +337,6 @@ func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, return nil } -func (sds *MockStateDiffService) GetWathchedAddresses() []common.Address { +func (sds *MockStateDiffService) GetWatchedAddresses() []common.Address { return []common.Address{} } diff --git a/statediff/types.go b/statediff/types.go index d17d11067..b179f1a00 100644 --- a/statediff/types.go +++ b/statediff/types.go @@ -57,7 +57,7 @@ type Params struct { // ParamsWithMutex allows to lock the parameters while they are being updated | read from type ParamsWithMutex struct { Params - mu sync.RWMutex + sync.RWMutex } // Args bundles the arguments for the state diff builder -- 2.45.2 From ebd43fb85711834cbc901900d7cf76000fe10458 Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Wed, 19 Jan 2022 15:59:35 +0530 Subject: [PATCH 06/17] Add support for changing watched storage slots --- statediff/api.go | 2 +- statediff/helpers.go | 90 ++++++++++++++++++++++---- statediff/indexer/indexer.go | 39 +++++------ statediff/service.go | 79 ++++++++++++++++++---- statediff/testhelpers/mocks/service.go | 2 +- statediff/types.go | 13 ++-- statediff/types/types.go | 22 ++++++- 7 files changed, 194 insertions(+), 53 deletions(-) diff --git a/statediff/api.go b/statediff/api.go index ed9cc3c06..3686728f2 100644 --- a/statediff/api.go +++ b/statediff/api.go @@ -150,7 +150,7 @@ func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash return api.sds.WriteStateDiffFor(blockHash, params) } -// WatchAddress changes the list of watched addresses to which the direct indexing is restricted according to given operation +// WatchAddress changes the list of watched addresses | storage slots to which the direct indexing is restricted according to given operation func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, args []WatchAddressArg) error { return api.sds.WatchAddress(operation, args) } diff --git a/statediff/helpers.go b/statediff/helpers.go index 1f8523fc3..209e1d724 100644 --- a/statediff/helpers.go +++ b/statediff/helpers.go @@ -76,23 +76,36 @@ func findIntersection(a, b []string) []string { } } -// loadWatchedAddresses is used to load watched addresses to the in-memory write loop params from the db -func loadWatchedAddresses(db *postgres.DB) error { - var watchedAddressStrings []string - pgStr := "SELECT address FROM eth.watched_addresses" - err := db.Select(&watchedAddressStrings, pgStr) +// loadWatched is used to load watched addresses and storage slots to the in-memory write loop params from the db +func loadWatched(db *postgres.DB) error { + type Watched struct { + Address string `db:"address"` + Kind int `db:"kind"` + } + var watched []Watched + pgStr := "SELECT address, kind FROM eth.watched_addresses" + err := db.Select(&watched, pgStr) if err != nil { return fmt.Errorf("error loading watched addresses: %v", err) } var watchedAddresses []common.Address - for _, watchedAddressString := range watchedAddressStrings { - watchedAddresses = append(watchedAddresses, common.HexToAddress(watchedAddressString)) + var watchedStorageSlots []common.Hash + for _, entry := range watched { + switch entry.Kind { + case types.WatchedAddress.Int(): + watchedAddresses = append(watchedAddresses, common.HexToAddress(entry.Address)) + case types.WatchedStorageSlot.Int(): + watchedStorageSlots = append(watchedStorageSlots, common.HexToHash(entry.Address)) + default: + return fmt.Errorf("Unexpected kind %d", entry.Kind) + } } writeLoopParams.Lock() defer writeLoopParams.Unlock() writeLoopParams.WatchedAddresses = watchedAddresses + writeLoopParams.WatchedStorageSlots = watchedStorageSlots return nil } @@ -110,6 +123,19 @@ func removeAddresses(addresses []common.Address, addressesToRemove []common.Addr return filteredAddresses } +// removeAddresses is used to remove given storage slots from a list of storage slots +func removeStorageSlots(storageSlots []common.Hash, storageSlotsToRemove []common.Hash) []common.Hash { + filteredStorageSlots := []common.Hash{} + + for _, address := range storageSlots { + if idx := containsStorageSlot(storageSlotsToRemove, address); idx == -1 { + filteredStorageSlots = append(filteredStorageSlots, address) + } + } + + return filteredStorageSlots +} + // containsAddress is used to check if an address is present in the provided list of addresses // return the index if found else -1 func containsAddress(addresses []common.Address, address common.Address) int { @@ -121,27 +147,65 @@ func containsAddress(addresses []common.Address, address common.Address) int { return -1 } -// getArgAddresses is used to get the list of addresses from a list of WatchAddressArgs +// containsAddress is used to check if a storage slot is present in the provided list of storage slots +// return the index if found else -1 +func containsStorageSlot(storageSlots []common.Hash, storageSlot common.Hash) int { + for idx, slot := range storageSlots { + if slot == storageSlot { + return idx + } + } + return -1 +} + +// getAddresses is used to get the list of addresses from a list of WatchAddressArgs func getAddresses(args []types.WatchAddressArg) []common.Address { addresses := make([]common.Address, len(args)) for idx, arg := range args { - addresses[idx] = arg.Address + addresses[idx] = common.HexToAddress(arg.Address) } return addresses } -// filterArgs filters out the args having an address from a given list of addresses -func filterArgs(args []types.WatchAddressArg, addressesToRemove []common.Address) ([]types.WatchAddressArg, []common.Address) { +// getStorageSlots is used to get the list of storage slots from a list of WatchAddressArgs +func getStorageSlots(args []types.WatchAddressArg) []common.Hash { + storageSlots := make([]common.Hash, len(args)) + for idx, arg := range args { + storageSlots[idx] = common.HexToHash(arg.Address) + } + + return storageSlots +} + +// filterAddressArgs filters out the args having an address from a given list of addresses +func filterAddressArgs(args []types.WatchAddressArg, addressesToRemove []common.Address) ([]types.WatchAddressArg, []common.Address) { filteredArgs := []types.WatchAddressArg{} filteredAddresses := []common.Address{} for _, arg := range args { - if idx := containsAddress(addressesToRemove, arg.Address); idx == -1 { + address := common.HexToAddress(arg.Address) + if idx := containsAddress(addressesToRemove, address); idx == -1 { filteredArgs = append(filteredArgs, arg) - filteredAddresses = append(filteredAddresses, arg.Address) + filteredAddresses = append(filteredAddresses, address) } } return filteredArgs, filteredAddresses } + +// filterStorageSlotArgs filters out the args having a storage slot from a given list of storage slots +func filterStorageSlotArgs(args []types.WatchAddressArg, storageSlotsToRemove []common.Hash) ([]types.WatchAddressArg, []common.Hash) { + filteredArgs := []types.WatchAddressArg{} + filteredStorageSlots := []common.Hash{} + + for _, arg := range args { + storageSlot := common.HexToHash(arg.Address) + if idx := containsStorageSlot(storageSlotsToRemove, storageSlot); idx == -1 { + filteredArgs = append(filteredArgs, arg) + filteredStorageSlots = append(filteredStorageSlots, storageSlot) + } + } + + return filteredArgs, filteredStorageSlots +} diff --git a/statediff/indexer/indexer.go b/statediff/indexer/indexer.go index c85d00b2e..8ec34549b 100644 --- a/statediff/indexer/indexer.go +++ b/statediff/indexer/indexer.go @@ -61,10 +61,10 @@ type Indexer interface { ReportDBMetrics(delay time.Duration, quit <-chan bool) // Methods used by WatchAddress API/functionality. - InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error - RemoveWatchedAddresses(addresses []common.Address) error - SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error - ClearWatchedAddresses() error + InsertWatched(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int, kind sdtypes.WatchedAddressType) error + RemoveWatched(addresses []sdtypes.WatchAddressArg, kind sdtypes.WatchedAddressType) error + SetWatched(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error + ClearWatched(kind sdtypes.WatchedAddressType) error } // StateDiffIndexer satisfies the Indexer interface for ethereum statediff objects @@ -556,8 +556,8 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sd return nil } -// InsertWatchedAddresses inserts the given addresses in the database -func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { +// InsertWatchedAddresses inserts the given addresses | storage slots in the database +func (sdi *StateDiffIndexer) InsertWatched(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error { tx, err := sdi.dbWriter.db.Begin() if err != nil { return err @@ -565,8 +565,8 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressA defer tx.Rollback() for _, arg := range args { - _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, - arg.Address.Hex(), arg.CreatedAt, currentBlockNumber.Uint64()) + _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, kind, created_at, watched_at) VALUES ($1, $2, $3, $4) ON CONFLICT (address) DO NOTHING`, + arg.Address, kind.Int(), arg.CreatedAt, currentBlockNumber.Uint64()) if err != nil { return fmt.Errorf("error inserting watched_addresses entry: %v", err) } @@ -580,16 +580,16 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressA return nil } -// RemoveWatchedAddresses removes the given addresses from the database -func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) error { +// RemoveWatchedAddresses removes the given addresses | storage slots from the database +func (sdi *StateDiffIndexer) RemoveWatched(args []sdtypes.WatchAddressArg, kind sdtypes.WatchedAddressType) error { tx, err := sdi.dbWriter.db.Begin() if err != nil { return err } defer tx.Rollback() - for _, address := range addresses { - _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, address.Hex()) + for _, arg := range args { + _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1 AND kind = $2`, arg.Address, kind.Int()) if err != nil { return fmt.Errorf("error removing watched_addresses entry: %v", err) } @@ -603,21 +603,22 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) return nil } -func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { +// SetWatched clears and inserts the given addresses | storage slots in the database +func (sdi *StateDiffIndexer) SetWatched(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error { tx, err := sdi.dbWriter.db.Begin() if err != nil { return err } defer tx.Rollback() - _, err = tx.Exec(`DELETE FROM eth.watched_addresses`) + _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE kind = $1`, kind.Int()) if err != nil { return fmt.Errorf("error setting watched_addresses table: %v", err) } for _, arg := range args { - _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, - arg.Address.Hex(), arg.CreatedAt, currentBlockNumber.Uint64()) + _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, kind, created_at, watched_at) VALUES ($1, $2, $3, $4) ON CONFLICT (address) DO NOTHING`, + arg.Address, kind.Int(), arg.CreatedAt, currentBlockNumber.Uint64()) if err != nil { return fmt.Errorf("error setting watched_addresses table: %v", err) } @@ -631,9 +632,9 @@ func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, return nil } -// ClearWatchedAddresses clears all the addresses from the database -func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { - _, err := sdi.dbWriter.db.Exec(`DELETE FROM eth.watched_addresses`) +// ClearWatchedAddresses clears all the addresses | storage slots from the database +func (sdi *StateDiffIndexer) ClearWatched(kind sdtypes.WatchedAddressType) error { + _, err := sdi.dbWriter.db.Exec(`DELETE FROM eth.watched_addresses WHERE kind = $1`, kind.Int()) if err != nil { return fmt.Errorf("error clearing watched_addresses table: %v", err) } diff --git a/statediff/service.go b/statediff/service.go index 17727f14a..75268596c 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -211,7 +211,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params stack.RegisterLifecycle(sds) stack.RegisterAPIs(sds.APIs()) - err = loadWatchedAddresses(db) + err = loadWatched(db) if err != nil { return err } @@ -734,7 +734,9 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo return err } -// Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses +// Performs one of foll. operations on the watched addresses | storage slots in writeLoopParams and the db: +// AddAddresses | RemoveAddresses | SetAddresses | ClearAddresses +// AddStorageSlots | RemoveStorageSlots | SetStorageSlots | ClearStorageSlots func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error { // lock writeLoopParams for a write writeLoopParams.Lock() @@ -744,51 +746,100 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg currentBlockNumber := sds.BlockChain.CurrentBlock().Number() switch operation { - case Add: + case AddAddresses: addressesToRemove := []common.Address{} for _, arg := range args { // Check if address is already being watched // Throw a warning and continue if found - if containsAddress(writeLoopParams.WatchedAddresses, arg.Address) != -1 { - log.Warn("Address already being watched", "address", arg.Address.Hex()) - addressesToRemove = append(addressesToRemove, arg.Address) + address := common.HexToAddress(arg.Address) + if containsAddress(writeLoopParams.WatchedAddresses, address) != -1 { + log.Warn("Address already being watched", "address", arg.Address) + addressesToRemove = append(addressesToRemove, address) continue } } // remove already watched addresses - filteredArgs, filteredAddresses := filterArgs(args, addressesToRemove) + filteredArgs, filteredAddresses := filterAddressArgs(args, addressesToRemove) - err := sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) + err := sds.indexer.InsertWatched(filteredArgs, currentBlockNumber, WatchedAddress) if err != nil { return err } writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...) - case Remove: + case RemoveAddresses: addresses := getAddresses(args) - err := sds.indexer.RemoveWatchedAddresses(addresses) + err := sds.indexer.RemoveWatched(args, WatchedAddress) if err != nil { return err } writeLoopParams.WatchedAddresses = removeAddresses(writeLoopParams.WatchedAddresses, addresses) - case Set: - err := sds.indexer.SetWatchedAddresses(args, currentBlockNumber) + case SetAddresses: + err := sds.indexer.SetWatched(args, currentBlockNumber, WatchedAddress) if err != nil { return err } addresses := getAddresses(args) writeLoopParams.WatchedAddresses = addresses - case Clear: - err := sds.indexer.ClearWatchedAddresses() + case ClearAddresses: + err := sds.indexer.ClearWatched(WatchedAddress) if err != nil { return err } writeLoopParams.WatchedAddresses = nil + + case AddStorageSlots: + storageSlotsToRemove := []common.Hash{} + for _, arg := range args { + // Check if address is already being watched + // Throw a warning and continue if found + storageSlot := common.HexToHash(arg.Address) + if containsStorageSlot(writeLoopParams.WatchedStorageSlots, storageSlot) != -1 { + log.Warn("StorageSlot already being watched", "storage slot", arg.Address) + storageSlotsToRemove = append(storageSlotsToRemove, storageSlot) + continue + } + } + + // remove already watched addresses + filteredArgs, filteredStorageSlots := filterStorageSlotArgs(args, storageSlotsToRemove) + + err := sds.indexer.InsertWatched(filteredArgs, currentBlockNumber, WatchedStorageSlot) + if err != nil { + return err + } + + writeLoopParams.WatchedStorageSlots = append(writeLoopParams.WatchedStorageSlots, filteredStorageSlots...) + case RemoveStorageSlots: + storageSlots := getStorageSlots(args) + + err := sds.indexer.RemoveWatched(args, WatchedStorageSlot) + if err != nil { + return err + } + + writeLoopParams.WatchedStorageSlots = removeStorageSlots(writeLoopParams.WatchedStorageSlots, storageSlots) + case SetStorageSlots: + err := sds.indexer.SetWatched(args, currentBlockNumber, WatchedStorageSlot) + if err != nil { + return err + } + + storageSlots := getStorageSlots(args) + writeLoopParams.WatchedStorageSlots = storageSlots + case ClearStorageSlots: + err := sds.indexer.ClearWatched(WatchedStorageSlot) + if err != nil { + return err + } + + writeLoopParams.WatchedStorageSlots = nil + default: return fmt.Errorf("Unexpected operation %s", operation) } diff --git a/statediff/testhelpers/mocks/service.go b/statediff/testhelpers/mocks/service.go index 686f0c2ba..f8d1b6b20 100644 --- a/statediff/testhelpers/mocks/service.go +++ b/statediff/testhelpers/mocks/service.go @@ -333,7 +333,7 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) { } } -func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, addresses []sdtypes.WatchAddressArg) error { +func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, args []sdtypes.WatchAddressArg) error { return nil } diff --git a/statediff/types.go b/statediff/types.go index b179f1a00..69eaf0d0a 100644 --- a/statediff/types.go +++ b/statediff/types.go @@ -123,8 +123,13 @@ type accountWrapper struct { type OperationType string const ( - Add OperationType = "Add" - Remove OperationType = "Remove" - Set OperationType = "Set" - Clear OperationType = "Clear" + AddAddresses OperationType = "AddAddresses" + RemoveAddresses OperationType = "RemoveAddresses" + SetAddresses OperationType = "SetAddresses" + ClearAddresses OperationType = "ClearAddresses" + + AddStorageSlots OperationType = "AddStorageSlots" + RemoveStorageSlots OperationType = "RemoveStorageSlots" + SetStorageSlots OperationType = "SetStorageSlots" + ClearStorageSlots OperationType = "ClearStorageSlots" ) diff --git a/statediff/types/types.go b/statediff/types/types.go index cab4cb880..18d9d0b31 100644 --- a/statediff/types/types.go +++ b/statediff/types/types.go @@ -77,6 +77,26 @@ type CodeSink func(CodeAndCodeHash) error // WatchAddressArg is a arg type for WatchAddress API type WatchAddressArg struct { - Address common.Address + // Address represents common.Address | common.Hash + Address string CreatedAt uint64 } + +// WatchedAddressType for denoting watched: address | storage slot +type WatchedAddressType string + +const ( + WatchedAddress WatchedAddressType = "WatchedAddress" + WatchedStorageSlot WatchedAddressType = "WatchedStorageSlot" +) + +func (n WatchedAddressType) Int() int { + switch n { + case WatchedAddress: + return 0 + case WatchedStorageSlot: + return 1 + default: + return -1 + } +} -- 2.45.2 From 234974f4114b5670884908c5e90e3fb3147e09bc Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Thu, 20 Jan 2022 11:20:39 +0530 Subject: [PATCH 07/17] Use an utility library for common operations --- go.mod | 1 + go.sum | 2 + statediff/helpers.go | 105 +--------------------- statediff/indexer/indexer.go | 18 ++-- statediff/service.go | 168 ++++++++++++++++++++++++----------- 5 files changed, 131 insertions(+), 163 deletions(-) diff --git a/go.mod b/go.mod index 0f94c2611..3baf72200 100644 --- a/go.mod +++ b/go.mod @@ -68,6 +68,7 @@ require ( github.com/status-im/keycard-go v0.0.0-20190316090335-8537d3370df4 github.com/stretchr/testify v1.7.0 github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 + github.com/thoas/go-funk v0.9.1 github.com/tklauser/go-sysconf v0.3.5 // indirect github.com/tyler-smith/go-bip39 v1.0.1-0.20181017060643-dbb3b84ba2ef golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a diff --git a/go.sum b/go.sum index 0097b96f7..76d8234e5 100644 --- a/go.sum +++ b/go.sum @@ -470,6 +470,8 @@ github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5Cc github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70Z7CTTCmYQn2CKbY8j86K7/FAIr141uY= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= +github.com/thoas/go-funk v0.9.1 h1:O549iLZqPpTUQ10ykd26sZhzD+rmR5pWhuElrhbC20M= +github.com/thoas/go-funk v0.9.1/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q= github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= github.com/tklauser/go-sysconf v0.3.5 h1:uu3Xl4nkLzQfXNsWn15rPc/HQCJKObbt1dKJeWp3vU4= github.com/tklauser/go-sysconf v0.3.5/go.mod h1:MkWzOF4RMCshBAMXuhXJs64Rte09mITnppBXY/rYEFI= diff --git a/statediff/helpers.go b/statediff/helpers.go index 209e1d724..4d9b886e4 100644 --- a/statediff/helpers.go +++ b/statediff/helpers.go @@ -76,13 +76,14 @@ func findIntersection(a, b []string) []string { } } -// loadWatched is used to load watched addresses and storage slots to the in-memory write loop params from the db -func loadWatched(db *postgres.DB) error { +// loadWatchedAddresses is used to load watched addresses and storage slots to the in-memory write loop params from the db +func loadWatchedAddresses(db *postgres.DB) error { type Watched struct { Address string `db:"address"` Kind int `db:"kind"` } var watched []Watched + pgStr := "SELECT address, kind FROM eth.watched_addresses" err := db.Select(&watched, pgStr) if err != nil { @@ -109,103 +110,3 @@ func loadWatched(db *postgres.DB) error { return nil } - -// removeAddresses is used to remove given addresses from a list of addresses -func removeAddresses(addresses []common.Address, addressesToRemove []common.Address) []common.Address { - filteredAddresses := []common.Address{} - - for _, address := range addresses { - if idx := containsAddress(addressesToRemove, address); idx == -1 { - filteredAddresses = append(filteredAddresses, address) - } - } - - return filteredAddresses -} - -// removeAddresses is used to remove given storage slots from a list of storage slots -func removeStorageSlots(storageSlots []common.Hash, storageSlotsToRemove []common.Hash) []common.Hash { - filteredStorageSlots := []common.Hash{} - - for _, address := range storageSlots { - if idx := containsStorageSlot(storageSlotsToRemove, address); idx == -1 { - filteredStorageSlots = append(filteredStorageSlots, address) - } - } - - return filteredStorageSlots -} - -// containsAddress is used to check if an address is present in the provided list of addresses -// return the index if found else -1 -func containsAddress(addresses []common.Address, address common.Address) int { - for idx, addr := range addresses { - if addr == address { - return idx - } - } - return -1 -} - -// containsAddress is used to check if a storage slot is present in the provided list of storage slots -// return the index if found else -1 -func containsStorageSlot(storageSlots []common.Hash, storageSlot common.Hash) int { - for idx, slot := range storageSlots { - if slot == storageSlot { - return idx - } - } - return -1 -} - -// getAddresses is used to get the list of addresses from a list of WatchAddressArgs -func getAddresses(args []types.WatchAddressArg) []common.Address { - addresses := make([]common.Address, len(args)) - for idx, arg := range args { - addresses[idx] = common.HexToAddress(arg.Address) - } - - return addresses -} - -// getStorageSlots is used to get the list of storage slots from a list of WatchAddressArgs -func getStorageSlots(args []types.WatchAddressArg) []common.Hash { - storageSlots := make([]common.Hash, len(args)) - for idx, arg := range args { - storageSlots[idx] = common.HexToHash(arg.Address) - } - - return storageSlots -} - -// filterAddressArgs filters out the args having an address from a given list of addresses -func filterAddressArgs(args []types.WatchAddressArg, addressesToRemove []common.Address) ([]types.WatchAddressArg, []common.Address) { - filteredArgs := []types.WatchAddressArg{} - filteredAddresses := []common.Address{} - - for _, arg := range args { - address := common.HexToAddress(arg.Address) - if idx := containsAddress(addressesToRemove, address); idx == -1 { - filteredArgs = append(filteredArgs, arg) - filteredAddresses = append(filteredAddresses, address) - } - } - - return filteredArgs, filteredAddresses -} - -// filterStorageSlotArgs filters out the args having a storage slot from a given list of storage slots -func filterStorageSlotArgs(args []types.WatchAddressArg, storageSlotsToRemove []common.Hash) ([]types.WatchAddressArg, []common.Hash) { - filteredArgs := []types.WatchAddressArg{} - filteredStorageSlots := []common.Hash{} - - for _, arg := range args { - storageSlot := common.HexToHash(arg.Address) - if idx := containsStorageSlot(storageSlotsToRemove, storageSlot); idx == -1 { - filteredArgs = append(filteredArgs, arg) - filteredStorageSlots = append(filteredStorageSlots, storageSlot) - } - } - - return filteredArgs, filteredStorageSlots -} diff --git a/statediff/indexer/indexer.go b/statediff/indexer/indexer.go index 8ec34549b..5132f4942 100644 --- a/statediff/indexer/indexer.go +++ b/statediff/indexer/indexer.go @@ -61,10 +61,10 @@ type Indexer interface { ReportDBMetrics(delay time.Duration, quit <-chan bool) // Methods used by WatchAddress API/functionality. - InsertWatched(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int, kind sdtypes.WatchedAddressType) error - RemoveWatched(addresses []sdtypes.WatchAddressArg, kind sdtypes.WatchedAddressType) error - SetWatched(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error - ClearWatched(kind sdtypes.WatchedAddressType) error + InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int, kind sdtypes.WatchedAddressType) error + RemoveWatchedAddresses(addresses []sdtypes.WatchAddressArg, kind sdtypes.WatchedAddressType) error + SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error + ClearWatchedAddresses(kind sdtypes.WatchedAddressType) error } // StateDiffIndexer satisfies the Indexer interface for ethereum statediff objects @@ -557,7 +557,7 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sd } // InsertWatchedAddresses inserts the given addresses | storage slots in the database -func (sdi *StateDiffIndexer) InsertWatched(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error { +func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error { tx, err := sdi.dbWriter.db.Begin() if err != nil { return err @@ -581,7 +581,7 @@ func (sdi *StateDiffIndexer) InsertWatched(args []sdtypes.WatchAddressArg, curre } // RemoveWatchedAddresses removes the given addresses | storage slots from the database -func (sdi *StateDiffIndexer) RemoveWatched(args []sdtypes.WatchAddressArg, kind sdtypes.WatchedAddressType) error { +func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressArg, kind sdtypes.WatchedAddressType) error { tx, err := sdi.dbWriter.db.Begin() if err != nil { return err @@ -603,8 +603,8 @@ func (sdi *StateDiffIndexer) RemoveWatched(args []sdtypes.WatchAddressArg, kind return nil } -// SetWatched clears and inserts the given addresses | storage slots in the database -func (sdi *StateDiffIndexer) SetWatched(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error { +// SetWatchedAddresses clears and inserts the given addresses | storage slots in the database +func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error { tx, err := sdi.dbWriter.db.Begin() if err != nil { return err @@ -633,7 +633,7 @@ func (sdi *StateDiffIndexer) SetWatched(args []sdtypes.WatchAddressArg, currentB } // ClearWatchedAddresses clears all the addresses | storage slots from the database -func (sdi *StateDiffIndexer) ClearWatched(kind sdtypes.WatchedAddressType) error { +func (sdi *StateDiffIndexer) ClearWatchedAddresses(kind sdtypes.WatchedAddressType) error { _, err := sdi.dbWriter.db.Exec(`DELETE FROM eth.watched_addresses WHERE kind = $1`, kind.Int()) if err != nil { return fmt.Errorf("error clearing watched_addresses table: %v", err) diff --git a/statediff/service.go b/statediff/service.go index 75268596c..39e874d98 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -41,6 +41,7 @@ import ( "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/trie" + "github.com/thoas/go-funk" ind "github.com/ethereum/go-ethereum/statediff/indexer" nodeinfo "github.com/ethereum/go-ethereum/statediff/indexer/node" @@ -49,10 +50,11 @@ import ( ) const ( - chainEventChanSize = 20000 - genesisBlockNumber = 0 - defaultRetryLimit = 3 // default retry limit once deadlock is detected. - deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html + chainEventChanSize = 20000 + genesisBlockNumber = 0 + defaultRetryLimit = 3 // default retry limit once deadlock is detected. + deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html + typeAssertionFailed = "type assertion failed" ) var writeLoopParams = ParamsWithMutex{ @@ -211,7 +213,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params stack.RegisterLifecycle(sds) stack.RegisterAPIs(sds.APIs()) - err = loadWatched(db) + err = loadWatchedAddresses(db) if err != nil { return err } @@ -734,7 +736,7 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo return err } -// Performs one of foll. operations on the watched addresses | storage slots in writeLoopParams and the db: +// Performs one of following operations on the watched addresses | storage slots in writeLoopParams and the db: // AddAddresses | RemoveAddresses | SetAddresses | ClearAddresses // AddStorageSlots | RemoveStorageSlots | SetStorageSlots | ClearStorageSlots func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error { @@ -747,93 +749,155 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg switch operation { case AddAddresses: - addressesToRemove := []common.Address{} - for _, arg := range args { - // Check if address is already being watched - // Throw a warning and continue if found - address := common.HexToAddress(arg.Address) - if containsAddress(writeLoopParams.WatchedAddresses, address) != -1 { + // filter out args having an already watched address with a warning + filteredArgs, ok := funk.Filter(args, func(arg WatchAddressArg) bool { + if funk.Contains(writeLoopParams.WatchedAddresses, common.HexToAddress(arg.Address)) { log.Warn("Address already being watched", "address", arg.Address) - addressesToRemove = append(addressesToRemove, address) - continue + return false } + return true + }).([]WatchAddressArg) + if !ok { + return fmt.Errorf("AddAddresses: filtered args %s", typeAssertionFailed) } - // remove already watched addresses - filteredArgs, filteredAddresses := filterAddressArgs(args, addressesToRemove) + // get addresses from the filtered args + filteredAddresses, ok := funk.Map(filteredArgs, func(arg WatchAddressArg) common.Address { + return common.HexToAddress(arg.Address) + }).([]common.Address) + if !ok { + return fmt.Errorf("AddAddresses: filtered addresses %s", typeAssertionFailed) + } - err := sds.indexer.InsertWatched(filteredArgs, currentBlockNumber, WatchedAddress) + // update the db + err := sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber, WatchedAddress) if err != nil { return err } + // update in-memory params writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...) case RemoveAddresses: - addresses := getAddresses(args) + // get addresses from args + argAddresses, ok := funk.Map(args, func(arg WatchAddressArg) common.Address { + return common.HexToAddress(arg.Address) + }).([]common.Address) + if !ok { + return fmt.Errorf("RemoveAddresses: mapped addresses %s", typeAssertionFailed) + } - err := sds.indexer.RemoveWatched(args, WatchedAddress) + // remove the provided addresses from currently watched addresses + addresses, ok := funk.Subtract(writeLoopParams.WatchedAddresses, argAddresses).([]common.Address) + if !ok { + return fmt.Errorf("RemoveAddresses: filtered addresses %s", typeAssertionFailed) + } + + // update the db + err := sds.indexer.RemoveWatchedAddresses(args, WatchedAddress) if err != nil { return err } - writeLoopParams.WatchedAddresses = removeAddresses(writeLoopParams.WatchedAddresses, addresses) - case SetAddresses: - err := sds.indexer.SetWatched(args, currentBlockNumber, WatchedAddress) - if err != nil { - return err - } - - addresses := getAddresses(args) + // update in-memory params writeLoopParams.WatchedAddresses = addresses - case ClearAddresses: - err := sds.indexer.ClearWatched(WatchedAddress) + case SetAddresses: + // get addresses from args + argAddresses, ok := funk.Map(args, func(arg WatchAddressArg) common.Address { + return common.HexToAddress(arg.Address) + }).([]common.Address) + if !ok { + return fmt.Errorf("SetAddresses: mapped addresses %s", typeAssertionFailed) + } + + // update the db + err := sds.indexer.SetWatchedAddresses(args, currentBlockNumber, WatchedAddress) if err != nil { return err } + // update in-memory params + writeLoopParams.WatchedAddresses = argAddresses + case ClearAddresses: + // update the db + err := sds.indexer.ClearWatchedAddresses(WatchedAddress) + if err != nil { + return err + } + + // update in-memory params writeLoopParams.WatchedAddresses = nil case AddStorageSlots: - storageSlotsToRemove := []common.Hash{} - for _, arg := range args { - // Check if address is already being watched - // Throw a warning and continue if found - storageSlot := common.HexToHash(arg.Address) - if containsStorageSlot(writeLoopParams.WatchedStorageSlots, storageSlot) != -1 { - log.Warn("StorageSlot already being watched", "storage slot", arg.Address) - storageSlotsToRemove = append(storageSlotsToRemove, storageSlot) - continue + // filter out args having an already watched storage slot with a warning + filteredArgs, ok := funk.Filter(args, func(arg WatchAddressArg) bool { + if funk.Contains(writeLoopParams.WatchedStorageSlots, common.HexToHash(arg.Address)) { + log.Warn("StorageSlot already being watched", "address", arg.Address) + return false } + return true + }).([]WatchAddressArg) + if !ok { + return fmt.Errorf("AddStorageSlots: filtered args %s", typeAssertionFailed) } - // remove already watched addresses - filteredArgs, filteredStorageSlots := filterStorageSlotArgs(args, storageSlotsToRemove) + // get storage slots from the filtered args + filteredStorageSlots, ok := funk.Map(filteredArgs, func(arg WatchAddressArg) common.Hash { + return common.HexToHash(arg.Address) + }).([]common.Hash) + if !ok { + return fmt.Errorf("AddStorageSlots: filtered storage slots %s", typeAssertionFailed) + } - err := sds.indexer.InsertWatched(filteredArgs, currentBlockNumber, WatchedStorageSlot) + // update the db + err := sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber, WatchedStorageSlot) if err != nil { return err } + // update in-memory params writeLoopParams.WatchedStorageSlots = append(writeLoopParams.WatchedStorageSlots, filteredStorageSlots...) case RemoveStorageSlots: - storageSlots := getStorageSlots(args) + // get storage slots from args + argStorageSlots, ok := funk.Map(args, func(arg WatchAddressArg) common.Hash { + return common.HexToHash(arg.Address) + }).([]common.Hash) + if !ok { + return fmt.Errorf("RemoveStorageSlots: mapped storage slots %s", typeAssertionFailed) + } - err := sds.indexer.RemoveWatched(args, WatchedStorageSlot) + // remove the provided storage slots from currently watched storage slots + storageSlots, ok := funk.Subtract(writeLoopParams.WatchedStorageSlots, argStorageSlots).([]common.Hash) + if !ok { + return fmt.Errorf("RemoveStorageSlots: filtered storage slots %s", typeAssertionFailed) + } + + // update the db + err := sds.indexer.RemoveWatchedAddresses(args, WatchedStorageSlot) if err != nil { return err } - writeLoopParams.WatchedStorageSlots = removeStorageSlots(writeLoopParams.WatchedStorageSlots, storageSlots) - case SetStorageSlots: - err := sds.indexer.SetWatched(args, currentBlockNumber, WatchedStorageSlot) - if err != nil { - return err - } - - storageSlots := getStorageSlots(args) + // update in-memory params writeLoopParams.WatchedStorageSlots = storageSlots + case SetStorageSlots: + // get storage slots from args + argStorageSlots, ok := funk.Map(args, func(arg WatchAddressArg) common.Hash { + return common.HexToHash(arg.Address) + }).([]common.Hash) + if !ok { + return fmt.Errorf("SetStorageSlots: mapped storage slots %s", typeAssertionFailed) + } + + // update the db + err := sds.indexer.SetWatchedAddresses(args, currentBlockNumber, WatchedStorageSlot) + if err != nil { + return err + } + + // update in-memory params + writeLoopParams.WatchedStorageSlots = argStorageSlots case ClearStorageSlots: - err := sds.indexer.ClearWatched(WatchedStorageSlot) + err := sds.indexer.ClearWatchedAddresses(WatchedStorageSlot) if err != nil { return err } -- 2.45.2 From 381c61c51713219c45bb2b4631d6f533be35995f Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Mon, 31 Jan 2022 17:34:32 +0530 Subject: [PATCH 08/17] Add a fix in builder for removal of a non-watched address --- statediff/builder.go | 40 +++++----- statediff/builder_test.go | 149 +++++++++++++++++++++++++++++++++++++- 2 files changed, 169 insertions(+), 20 deletions(-) diff --git a/statediff/builder.go b/statediff/builder.go index 7befb6b3c..63b354a4c 100644 --- a/statediff/builder.go +++ b/statediff/builder.go @@ -202,7 +202,7 @@ func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args StateRoots, pa // a map of their leafkey to all the accounts that were touched and exist at A diffAccountsAtA, err := sdb.deletedOrUpdatedState( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - diffPathsAtB, output) + diffPathsAtB, params.WatchedAddresses, output) if err != nil { return fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err) } @@ -256,7 +256,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args StateRoots, // a map of their leafkey to all the accounts that were touched and exist at A diffAccountsAtA, err := sdb.deletedOrUpdatedState( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - diffPathsAtB, output) + diffPathsAtB, params.WatchedAddresses, output) if err != nil { return fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err) } @@ -386,7 +386,7 @@ func (sdb *builder) createdAndUpdatedStateWithIntermediateNodes(a, b trie.NodeIt // deletedOrUpdatedState returns a slice of all the pathes that are emptied at B // and a mapping of their leafkeys to all the accounts that exist in a different state at A than B -func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB map[string]bool, output StateNodeSink) (AccountMap, error) { +func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB map[string]bool, watchedAddresses []common.Address, output StateNodeSink) (AccountMap, error) { diffAccountAtA := make(AccountMap) it, _ := trie.NewDifferenceIterator(b, a) for it.Next(true) { @@ -409,24 +409,26 @@ func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB m valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - diffAccountAtA[common.Bytes2Hex(leafKey)] = accountWrapper{ - NodeType: node.NodeType, - Path: node.Path, - NodeValue: node.NodeValue, - LeafKey: leafKey, - Account: &account, - } - // if this node's path did not show up in diffPathsAtB - // that means the node at this path was deleted (or moved) in B - // emit an empty "removed" diff to signify as such - if _, ok := diffPathsAtB[common.Bytes2Hex(node.Path)]; !ok { - if err := output(StateNode{ + if isWatchedAddress(watchedAddresses, leafKey) { + diffAccountAtA[common.Bytes2Hex(leafKey)] = accountWrapper{ + NodeType: node.NodeType, Path: node.Path, - NodeValue: []byte{}, - NodeType: Removed, + NodeValue: node.NodeValue, LeafKey: leafKey, - }); err != nil { - return nil, err + Account: &account, + } + // if this node's path did not show up in diffPathsAtB + // that means the node at this path was deleted (or moved) in B + // emit an empty "removed" diff to signify as such + if _, ok := diffPathsAtB[common.Bytes2Hex(node.Path)]; !ok { + if err := output(StateNode{ + Path: node.Path, + NodeValue: []byte{}, + NodeType: Removed, + LeafKey: leafKey, + }); err != nil { + return nil, err + } } } case Extension, Branch: diff --git a/statediff/builder_test.go b/statediff/builder_test.go index 6a88bbba0..189295518 100644 --- a/statediff/builder_test.go +++ b/statediff/builder_test.go @@ -1152,13 +1152,14 @@ func TestBuilderWithWatchedAddressList(t *testing.T) { } func TestBuilderWithWatchedAddressAndStorageKeyList(t *testing.T) { - blocks, chain := testhelpers.MakeChain(3, testhelpers.Genesis, testhelpers.TestChainGen) + blocks, chain := testhelpers.MakeChain(4, testhelpers.Genesis, testhelpers.TestChainGen) contractLeafKey = testhelpers.AddressToLeafKey(testhelpers.ContractAddr) defer chain.Stop() block0 = testhelpers.Genesis block1 = blocks[0] block2 = blocks[1] block3 = blocks[2] + block4 = blocks[3] params := statediff.Params{ WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.ContractAddr}, WatchedStorageSlots: []common.Hash{slot1StorageKey}, @@ -1290,6 +1291,35 @@ func TestBuilderWithWatchedAddressAndStorageKeyList(t *testing.T) { }, }, }, + { + "testBlock4", + statediff.Args{ + OldStateRoot: block3.Root(), + NewStateRoot: block4.Root(), + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), + }, + &statediff.StateObject{ + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), + Nodes: []sdtypes.StateNode{ + { + Path: []byte{'\x06'}, + NodeType: sdtypes.Leaf, + LeafKey: contractLeafKey, + NodeValue: contractAccountAtBlock4LeafNode, + StorageNodes: []sdtypes.StorageNode{ + { + Path: []byte{'\x0b'}, + NodeType: sdtypes.Removed, + LeafKey: slot1StorageKey.Bytes(), + NodeValue: []byte{}, + }, + }, + }, + }, + }, + }, } for _, test := range tests { @@ -1718,6 +1748,123 @@ func TestBuilderWithRemovedAccountAndStorageWithoutIntermediateNodes(t *testing. } } +func TestBuilderWithRemovedNonWatchedAccount(t *testing.T) { + blocks, chain := testhelpers.MakeChain(6, testhelpers.Genesis, testhelpers.TestChainGen) + contractLeafKey = testhelpers.AddressToLeafKey(testhelpers.ContractAddr) + defer chain.Stop() + block3 = blocks[2] + block4 = blocks[3] + block5 = blocks[4] + block6 = blocks[5] + params := statediff.Params{ + WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.Account2Addr}, + } + builder = statediff.NewBuilder(chain.StateCache()) + + var tests = []struct { + name string + startingArguments statediff.Args + expected *statediff.StateObject + }{ + { + "testBlock4", + statediff.Args{ + OldStateRoot: block3.Root(), + NewStateRoot: block4.Root(), + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), + }, + &statediff.StateObject{ + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), + Nodes: []sdtypes.StateNode{ + { + Path: []byte{'\x0c'}, + NodeType: sdtypes.Leaf, + LeafKey: testhelpers.Account2LeafKey, + NodeValue: account2AtBlock4LeafNode, + StorageNodes: emptyStorage, + }, + }, + }, + }, + { + "testBlock5", + statediff.Args{ + OldStateRoot: block4.Root(), + NewStateRoot: block5.Root(), + BlockNumber: block5.Number(), + BlockHash: block5.Hash(), + }, + &statediff.StateObject{ + BlockNumber: block5.Number(), + BlockHash: block5.Hash(), + Nodes: []sdtypes.StateNode{ + { + Path: []byte{'\x0e'}, + NodeType: sdtypes.Leaf, + LeafKey: testhelpers.Account1LeafKey, + NodeValue: account1AtBlock5LeafNode, + StorageNodes: emptyStorage, + }, + }, + }, + }, + { + "testBlock6", + statediff.Args{ + OldStateRoot: block5.Root(), + NewStateRoot: block6.Root(), + BlockNumber: block6.Number(), + BlockHash: block6.Hash(), + }, + &statediff.StateObject{ + BlockNumber: block6.Number(), + BlockHash: block6.Hash(), + Nodes: []sdtypes.StateNode{ + { + Path: []byte{'\x0c'}, + NodeType: sdtypes.Leaf, + LeafKey: testhelpers.Account2LeafKey, + NodeValue: account2AtBlock6LeafNode, + StorageNodes: emptyStorage, + }, + { + Path: []byte{'\x0e'}, + NodeType: sdtypes.Leaf, + LeafKey: testhelpers.Account1LeafKey, + NodeValue: account1AtBlock6LeafNode, + StorageNodes: emptyStorage, + }, + }, + }, + }, + } + + for _, test := range tests { + diff, err := builder.BuildStateDiffObject(test.startingArguments, params) + if err != nil { + t.Error(err) + } + receivedStateDiffRlp, err := rlp.EncodeToBytes(diff) + if err != nil { + t.Error(err) + } + + expectedStateDiffRlp, err := rlp.EncodeToBytes(test.expected) + if err != nil { + t.Error(err) + } + + sort.Slice(receivedStateDiffRlp, func(i, j int) bool { return receivedStateDiffRlp[i] < receivedStateDiffRlp[j] }) + sort.Slice(expectedStateDiffRlp, func(i, j int) bool { return expectedStateDiffRlp[i] < expectedStateDiffRlp[j] }) + if !bytes.Equal(receivedStateDiffRlp, expectedStateDiffRlp) { + t.Logf("Test failed: %s", test.name) + t.Errorf("actual state diff: %+v\r\n\r\n\r\nexpected state diff: %+v", diff, test.expected) + } + } +} + var ( slot00StorageValue = common.Hex2Bytes("9471562b71999873db5b286df957af199ec94617f7") // prefixed TestBankAddress -- 2.45.2 From 19aff30d614a53b0e6a8da66fb835e1a852bb5ca Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Tue, 1 Feb 2022 15:32:39 +0530 Subject: [PATCH 09/17] Add tests for the API to change addresses being watched in direct indexing --- statediff/helpers.go | 6 +- statediff/service.go | 11 +- statediff/testhelpers/mocks/blockchain.go | 8 +- statediff/testhelpers/mocks/indexer.go | 59 +++ statediff/testhelpers/mocks/service.go | 179 ++++++- statediff/testhelpers/mocks/service_test.go | 507 ++++++++++++++++++++ 6 files changed, 754 insertions(+), 16 deletions(-) create mode 100644 statediff/testhelpers/mocks/indexer.go diff --git a/statediff/helpers.go b/statediff/helpers.go index 4d9b886e4..ba809fc06 100644 --- a/statediff/helpers.go +++ b/statediff/helpers.go @@ -90,8 +90,10 @@ func loadWatchedAddresses(db *postgres.DB) error { return fmt.Errorf("error loading watched addresses: %v", err) } - var watchedAddresses []common.Address - var watchedStorageSlots []common.Hash + var ( + watchedAddresses = []common.Address{} + watchedStorageSlots = []common.Hash{} + ) for _, entry := range watched { switch entry.Kind { case types.WatchedAddress.Int(): diff --git a/statediff/service.go b/statediff/service.go index 39e874d98..edff82d67 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -109,8 +109,6 @@ type IService interface { WriteLoop(chainEventCh chan core.ChainEvent) // Method to change the addresses being watched in write loop params WatchAddress(operation OperationType, args []WatchAddressArg) error - // Method to get currently watched addresses from write loop params - GetWatchedAddresses() []common.Address } // Wraps consructor parameters @@ -825,7 +823,7 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg } // update in-memory params - writeLoopParams.WatchedAddresses = nil + writeLoopParams.WatchedAddresses = []common.Address{} case AddStorageSlots: // filter out args having an already watched storage slot with a warning @@ -902,7 +900,7 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg return err } - writeLoopParams.WatchedStorageSlots = nil + writeLoopParams.WatchedStorageSlots = []common.Hash{} default: return fmt.Errorf("Unexpected operation %s", operation) @@ -910,8 +908,3 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg return nil } - -// Gets currently watched addresses from the in-memory write loop params -func (sds *Service) GetWatchedAddresses() []common.Address { - return writeLoopParams.WatchedAddresses -} diff --git a/statediff/testhelpers/mocks/blockchain.go b/statediff/testhelpers/mocks/blockchain.go index f2834a4a8..f2a77af64 100644 --- a/statediff/testhelpers/mocks/blockchain.go +++ b/statediff/testhelpers/mocks/blockchain.go @@ -39,6 +39,7 @@ type BlockChain struct { Receipts map[common.Hash]types.Receipts TDByHash map[common.Hash]*big.Int TDByNum map[uint64]*big.Int + currentBlock *types.Block } // SetBlocksForHashes mock method @@ -128,9 +129,14 @@ func (bc *BlockChain) GetTd(hash common.Hash, blockNum uint64) *big.Int { return nil } +// SetCurrentBlock test method +func (bc *BlockChain) SetCurrentBlock(block *types.Block) { + bc.currentBlock = block +} + // CurrentBlock mock method func (bc *BlockChain) CurrentBlock() *types.Block { - return nil + return bc.currentBlock } func (bc *BlockChain) SetTd(hash common.Hash, blockNum uint64, td *big.Int) { diff --git a/statediff/testhelpers/mocks/indexer.go b/statediff/testhelpers/mocks/indexer.go new file mode 100644 index 000000000..89459b558 --- /dev/null +++ b/statediff/testhelpers/mocks/indexer.go @@ -0,0 +1,59 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package mocks + +import ( + "math/big" + "time" + + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/statediff/indexer" + sdtypes "github.com/ethereum/go-ethereum/statediff/types" +) + +// Indexer is a mock state diff indexer +type Indexer struct{} + +func (sdi *Indexer) PushBlock(block *types.Block, receipts types.Receipts, totalDifficulty *big.Int) (*indexer.BlockTx, error) { + return nil, nil +} + +func (sdi *Indexer) PushStateNode(tx *indexer.BlockTx, stateNode sdtypes.StateNode) error { + return nil +} + +func (sdi *Indexer) PushCodeAndCodeHash(tx *indexer.BlockTx, codeAndCodeHash sdtypes.CodeAndCodeHash) error { + return nil +} + +func (sdi *Indexer) ReportDBMetrics(delay time.Duration, quit <-chan bool) {} + +func (sdi *Indexer) InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int, kind sdtypes.WatchedAddressType) error { + return nil +} + +func (sdi *Indexer) RemoveWatchedAddresses(addresses []sdtypes.WatchAddressArg, kind sdtypes.WatchedAddressType) error { + return nil +} + +func (sdi *Indexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error { + return nil +} + +func (sdi *Indexer) ClearWatchedAddresses(kind sdtypes.WatchedAddressType) error { + return nil +} diff --git a/statediff/testhelpers/mocks/service.go b/statediff/testhelpers/mocks/service.go index f8d1b6b20..1e5d3c1ba 100644 --- a/statediff/testhelpers/mocks/service.go +++ b/statediff/testhelpers/mocks/service.go @@ -25,6 +25,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" + "github.com/thoas/go-funk" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" @@ -32,9 +33,12 @@ import ( "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/statediff" + ind "github.com/ethereum/go-ethereum/statediff/indexer" sdtypes "github.com/ethereum/go-ethereum/statediff/types" ) +var typeAssertionFailed = "type assertion failed" + // MockStateDiffService is a mock state diff service type MockStateDiffService struct { sync.Mutex @@ -47,6 +51,8 @@ type MockStateDiffService struct { QuitChan chan bool Subscriptions map[common.Hash]map[rpc.ID]statediff.Subscription SubscriptionTypes map[common.Hash]statediff.Params + Indexer ind.Indexer + writeLoopParams statediff.ParamsWithMutex } // Protocols mock method @@ -333,10 +339,175 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) { } } +// WatchAddress mock method func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, args []sdtypes.WatchAddressArg) error { + // lock writeLoopParams for a write + sds.writeLoopParams.Lock() + defer sds.writeLoopParams.Unlock() + + // get the current block number + currentBlockNumber := sds.BlockChain.CurrentBlock().Number() + + switch operation { + case statediff.AddAddresses: + // filter out args having an already watched address with a warning + filteredArgs, ok := funk.Filter(args, func(arg sdtypes.WatchAddressArg) bool { + if funk.Contains(sds.writeLoopParams.WatchedAddresses, common.HexToAddress(arg.Address)) { + log.Warn("Address already being watched", "address", arg.Address) + return false + } + return true + }).([]sdtypes.WatchAddressArg) + if !ok { + return fmt.Errorf("AddAddresses: filtered args %s", typeAssertionFailed) + } + + // get addresses from the filtered args + filteredAddresses, ok := funk.Map(filteredArgs, func(arg sdtypes.WatchAddressArg) common.Address { + return common.HexToAddress(arg.Address) + }).([]common.Address) + if !ok { + return fmt.Errorf("AddAddresses: filtered addresses %s", typeAssertionFailed) + } + + // update the db + err := sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber, sdtypes.WatchedAddress) + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedAddresses = append(sds.writeLoopParams.WatchedAddresses, filteredAddresses...) + case statediff.RemoveAddresses: + // get addresses from args + argAddresses, ok := funk.Map(args, func(arg sdtypes.WatchAddressArg) common.Address { + return common.HexToAddress(arg.Address) + }).([]common.Address) + if !ok { + return fmt.Errorf("RemoveAddresses: mapped addresses %s", typeAssertionFailed) + } + + // remove the provided addresses from currently watched addresses + addresses, ok := funk.Subtract(sds.writeLoopParams.WatchedAddresses, argAddresses).([]common.Address) + if !ok { + return fmt.Errorf("RemoveAddresses: filtered addresses %s", typeAssertionFailed) + } + + // update the db + err := sds.Indexer.RemoveWatchedAddresses(args, sdtypes.WatchedAddress) + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedAddresses = addresses + case statediff.SetAddresses: + // get addresses from args + argAddresses, ok := funk.Map(args, func(arg sdtypes.WatchAddressArg) common.Address { + return common.HexToAddress(arg.Address) + }).([]common.Address) + if !ok { + return fmt.Errorf("SetAddresses: mapped addresses %s", typeAssertionFailed) + } + + // update the db + err := sds.Indexer.SetWatchedAddresses(args, currentBlockNumber, sdtypes.WatchedAddress) + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedAddresses = argAddresses + case statediff.ClearAddresses: + // update the db + err := sds.Indexer.ClearWatchedAddresses(sdtypes.WatchedAddress) + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedAddresses = []common.Address{} + + case statediff.AddStorageSlots: + // filter out args having an already watched storage slot with a warning + filteredArgs, ok := funk.Filter(args, func(arg sdtypes.WatchAddressArg) bool { + if funk.Contains(sds.writeLoopParams.WatchedStorageSlots, common.HexToHash(arg.Address)) { + log.Warn("StorageSlot already being watched", "address", arg.Address) + return false + } + return true + }).([]sdtypes.WatchAddressArg) + if !ok { + return fmt.Errorf("AddStorageSlots: filtered args %s", typeAssertionFailed) + } + + // get storage slots from the filtered args + filteredStorageSlots, ok := funk.Map(filteredArgs, func(arg sdtypes.WatchAddressArg) common.Hash { + return common.HexToHash(arg.Address) + }).([]common.Hash) + if !ok { + return fmt.Errorf("AddStorageSlots: filtered storage slots %s", typeAssertionFailed) + } + + // update the db + err := sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber, sdtypes.WatchedStorageSlot) + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedStorageSlots = append(sds.writeLoopParams.WatchedStorageSlots, filteredStorageSlots...) + case statediff.RemoveStorageSlots: + // get storage slots from args + argStorageSlots, ok := funk.Map(args, func(arg sdtypes.WatchAddressArg) common.Hash { + return common.HexToHash(arg.Address) + }).([]common.Hash) + if !ok { + return fmt.Errorf("RemoveStorageSlots: mapped storage slots %s", typeAssertionFailed) + } + + // remove the provided storage slots from currently watched storage slots + storageSlots, ok := funk.Subtract(sds.writeLoopParams.WatchedStorageSlots, argStorageSlots).([]common.Hash) + if !ok { + return fmt.Errorf("RemoveStorageSlots: filtered storage slots %s", typeAssertionFailed) + } + + // update the db + err := sds.Indexer.RemoveWatchedAddresses(args, sdtypes.WatchedStorageSlot) + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedStorageSlots = storageSlots + case statediff.SetStorageSlots: + // get storage slots from args + argStorageSlots, ok := funk.Map(args, func(arg sdtypes.WatchAddressArg) common.Hash { + return common.HexToHash(arg.Address) + }).([]common.Hash) + if !ok { + return fmt.Errorf("SetStorageSlots: mapped storage slots %s", typeAssertionFailed) + } + + // update the db + err := sds.Indexer.SetWatchedAddresses(args, currentBlockNumber, sdtypes.WatchedStorageSlot) + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedStorageSlots = argStorageSlots + case statediff.ClearStorageSlots: + err := sds.Indexer.ClearWatchedAddresses(sdtypes.WatchedStorageSlot) + if err != nil { + return err + } + + sds.writeLoopParams.WatchedStorageSlots = []common.Hash{} + + default: + return fmt.Errorf("Unexpected operation %s", operation) + } + return nil } - -func (sds *MockStateDiffService) GetWatchedAddresses() []common.Address { - return []common.Address{} -} diff --git a/statediff/testhelpers/mocks/service_test.go b/statediff/testhelpers/mocks/service_test.go index 8c1fd49cf..bc64e2a01 100644 --- a/statediff/testhelpers/mocks/service_test.go +++ b/statediff/testhelpers/mocks/service_test.go @@ -21,12 +21,14 @@ import ( "fmt" "math/big" "os" + "reflect" "sort" "sync" "testing" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/statediff" @@ -87,6 +89,7 @@ func init() { func TestAPI(t *testing.T) { testSubscriptionAPI(t) testHTTPAPI(t) + testWatchAddressAPI(t) } func testSubscriptionAPI(t *testing.T) { @@ -246,3 +249,507 @@ func testHTTPAPI(t *testing.T) { t.Errorf("paylaod does not have the expected total difficulty\r\nactual td: %d\r\nexpected td: %d", payload.TotalDifficulty.Int64(), mockTotalDifficulty.Int64()) } } + +func testWatchAddressAPI(t *testing.T) { + blocks, chain := testhelpers.MakeChain(6, testhelpers.Genesis, testhelpers.TestChainGen) + defer chain.Stop() + block6 := blocks[5] + + mockBlockChain := &BlockChain{} + mockBlockChain.SetCurrentBlock(block6) + mockIndexer := Indexer{} + mockService := MockStateDiffService{ + BlockChain: mockBlockChain, + Indexer: &mockIndexer, + } + + var ( + contract1Address = "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F" + contract2Address = "0x6Eb7e5C66DB8af2E96159AC440cbc8CDB7fbD26B" + contract3Address = "0xcfeB164C328CA13EFd3C77E1980d94975aDfedfc" + contract4Address = "0x0Edf0c4f393a628DE4828B228C48175b3EA297fc" + contract1CreatedAt = uint64(1) + contract2CreatedAt = uint64(2) + contract3CreatedAt = uint64(3) + contract4CreatedAt = uint64(4) + + slot1 = common.HexToHash("1") + slot2 = common.HexToHash("2") + slot3 = common.HexToHash("3") + slot4 = common.HexToHash("4") + slot1StorageKey = crypto.Keccak256Hash(slot1.Bytes()) + slot2StorageKey = crypto.Keccak256Hash(slot2.Bytes()) + slot3StorageKey = crypto.Keccak256Hash(slot3.Bytes()) + slot4StorageKey = crypto.Keccak256Hash(slot4.Bytes()) + slot1StorageKeyHex = crypto.Keccak256Hash(slot1.Bytes()).Hex() + slot2StorageKeyHex = crypto.Keccak256Hash(slot2.Bytes()).Hex() + slot3StorageKeyHex = crypto.Keccak256Hash(slot3.Bytes()).Hex() + slot4StorageKeyHex = crypto.Keccak256Hash(slot4.Bytes()).Hex() + slot1CreatedAt = uint64(1) + slot2CreatedAt = uint64(2) + slot3CreatedAt = uint64(3) + slot4CreatedAt = uint64(4) + + args1 = []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + startingParams1 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + expectedParams1 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract1Address), + common.HexToAddress(contract2Address), + }, + } + + args2 = []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + startingParams2 = expectedParams1 + expectedParams2 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract1Address), + common.HexToAddress(contract2Address), + common.HexToAddress(contract3Address), + }, + } + + args3 = []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + startingParams3 = expectedParams2 + expectedParams3 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract1Address), + }, + } + + args4 = []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + startingParams4 = expectedParams3 + expectedParams4 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + + args5 = []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + startingParams5 = expectedParams4 + expectedParams5 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract1Address), + common.HexToAddress(contract2Address), + common.HexToAddress(contract3Address), + }, + } + + args6 = []sdtypes.WatchAddressArg{ + { + Address: contract4Address, + CreatedAt: contract4CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + startingParams6 = expectedParams5 + expectedParams6 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract4Address), + common.HexToAddress(contract2Address), + common.HexToAddress(contract3Address), + }, + } + + args7 = []sdtypes.WatchAddressArg{} + startingParams7 = expectedParams6 + expectedParams7 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + + args8 = []sdtypes.WatchAddressArg{} + startingParams8 = expectedParams6 + expectedParams8 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + + args9 = []sdtypes.WatchAddressArg{} + startingParams9 = expectedParams8 + expectedParams9 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + + args10 = []sdtypes.WatchAddressArg{ + { + Address: slot1StorageKeyHex, + CreatedAt: slot1CreatedAt, + }, + { + Address: slot2StorageKeyHex, + CreatedAt: slot2CreatedAt, + }, + } + startingParams10 = statediff.Params{ + WatchedStorageSlots: []common.Hash{}, + } + expectedParams10 = statediff.Params{ + WatchedStorageSlots: []common.Hash{ + slot1StorageKey, + slot2StorageKey, + }, + } + + args11 = []sdtypes.WatchAddressArg{ + { + Address: slot3StorageKeyHex, + CreatedAt: slot3CreatedAt, + }, + { + Address: slot2StorageKeyHex, + CreatedAt: slot2CreatedAt, + }, + } + startingParams11 = expectedParams10 + expectedParams11 = statediff.Params{ + WatchedStorageSlots: []common.Hash{ + slot1StorageKey, + slot2StorageKey, + slot3StorageKey, + }, + } + + args12 = []sdtypes.WatchAddressArg{ + { + Address: slot3StorageKeyHex, + CreatedAt: slot3CreatedAt, + }, + { + Address: slot2StorageKeyHex, + CreatedAt: slot2CreatedAt, + }, + } + startingParams12 = expectedParams11 + expectedParams12 = statediff.Params{ + WatchedStorageSlots: []common.Hash{ + slot1StorageKey, + }, + } + + args13 = []sdtypes.WatchAddressArg{ + { + Address: slot1StorageKeyHex, + CreatedAt: slot1CreatedAt, + }, + { + Address: slot2StorageKeyHex, + CreatedAt: slot2CreatedAt, + }, + } + startingParams13 = expectedParams12 + expectedParams13 = statediff.Params{ + WatchedStorageSlots: []common.Hash{}, + } + + args14 = []sdtypes.WatchAddressArg{ + { + Address: slot1StorageKeyHex, + CreatedAt: slot1CreatedAt, + }, + { + Address: slot2StorageKeyHex, + CreatedAt: slot2CreatedAt, + }, + { + Address: slot3StorageKeyHex, + CreatedAt: slot3CreatedAt, + }, + } + startingParams14 = expectedParams13 + expectedParams14 = statediff.Params{ + WatchedStorageSlots: []common.Hash{ + slot1StorageKey, + slot2StorageKey, + slot3StorageKey, + }, + } + + args15 = []sdtypes.WatchAddressArg{ + { + Address: slot4StorageKeyHex, + CreatedAt: slot4CreatedAt, + }, + { + Address: slot2StorageKeyHex, + CreatedAt: slot2CreatedAt, + }, + { + Address: slot3StorageKeyHex, + CreatedAt: slot3CreatedAt, + }, + } + startingParams15 = expectedParams14 + expectedParams15 = statediff.Params{ + WatchedStorageSlots: []common.Hash{ + slot4StorageKey, + slot2StorageKey, + slot3StorageKey, + }, + } + + args16 = []sdtypes.WatchAddressArg{} + startingParams16 = expectedParams15 + expectedParams16 = statediff.Params{ + WatchedStorageSlots: []common.Hash{}, + } + + args17 = []sdtypes.WatchAddressArg{} + startingParams17 = expectedParams15 + expectedParams17 = statediff.Params{ + WatchedStorageSlots: []common.Hash{}, + } + + args18 = []sdtypes.WatchAddressArg{} + startingParams18 = expectedParams17 + expectedParams18 = statediff.Params{ + WatchedStorageSlots: []common.Hash{}, + } + ) + + tests := []struct { + name string + operation statediff.OperationType + args []sdtypes.WatchAddressArg + startingParams statediff.Params + expectedParams statediff.Params + expectedErr error + }{ + // addresses tests + { + "testAddAddresses", + statediff.AddAddresses, + args1, + startingParams1, + expectedParams1, + nil, + }, + { + "testAddAddressesSomeWatched", + statediff.AddAddresses, + args2, + startingParams2, + expectedParams2, + nil, + }, + { + "testRemoveAddresses", + statediff.RemoveAddresses, + args3, + startingParams3, + expectedParams3, + nil, + }, + { + "testRemoveAddressesSomeWatched", + statediff.RemoveAddresses, + args4, + startingParams4, + expectedParams4, + nil, + }, + { + "testSetAddresses", + statediff.SetAddresses, + args5, + startingParams5, + expectedParams5, + nil, + }, + { + "testSetAddressesSomeWatched", + statediff.SetAddresses, + args6, + startingParams6, + expectedParams6, + nil, + }, + { + "testSetAddressesEmtpyArgs", + statediff.SetAddresses, + args7, + startingParams7, + expectedParams7, + nil, + }, + { + "testClearAddresses", + statediff.ClearAddresses, + args8, + startingParams8, + expectedParams8, + nil, + }, + { + "testClearAddressesEmpty", + statediff.ClearAddresses, + args9, + startingParams9, + expectedParams9, + nil, + }, + + // storage slots tests + { + "testAddStorageSlots", + statediff.AddStorageSlots, + args10, + startingParams10, + expectedParams10, + nil, + }, + { + "testAddStorageSlotsSomeWatched", + statediff.AddStorageSlots, + args11, + startingParams11, + expectedParams11, + nil, + }, + { + "testRemoveStorageSlots", + statediff.RemoveStorageSlots, + args12, + startingParams12, + expectedParams12, + nil, + }, + { + "testRemoveStorageSlotsSomeWatched", + statediff.RemoveStorageSlots, + args13, + startingParams13, + expectedParams13, + nil, + }, + { + "testSetStorageSlots", + statediff.SetStorageSlots, + args14, + startingParams14, + expectedParams14, + nil, + }, + { + "testSetStorageSlotsSomeWatched", + statediff.SetStorageSlots, + args15, + startingParams15, + expectedParams15, + nil, + }, + { + "testSetStorageSlotsEmtpyArgs", + statediff.SetStorageSlots, + args16, + startingParams16, + expectedParams16, + nil, + }, + { + "testClearStorageSlots", + statediff.ClearStorageSlots, + args17, + startingParams17, + expectedParams17, + nil, + }, + { + "testClearStorageSlotsEmpty", + statediff.ClearStorageSlots, + args18, + startingParams18, + expectedParams18, + nil, + }, + + // invalid args + { + "testInvalidOperation", + "WrongOp", + args18, + startingParams18, + statediff.Params{}, + fmt.Errorf("Unexpected operation WrongOp"), + }, + } + + for _, test := range tests { + mockService.writeLoopParams = statediff.ParamsWithMutex{ + Params: test.startingParams, + } + + err := mockService.WatchAddress(test.operation, test.args) + if test.expectedErr != nil { + if err.Error() != test.expectedErr.Error() { + t.Logf("Test failed: %s", test.name) + t.Errorf("actual err: %+v\nexpected err: %+v", err, test.expectedErr) + } + + continue + } + if err != nil { + t.Error(err) + } + + updatedParams := mockService.writeLoopParams.Params + if !reflect.DeepEqual(updatedParams, test.expectedParams) { + t.Logf("Test failed: %s", test.name) + t.Errorf("actual params: %+v\nexpected params: %+v", updatedParams, test.expectedParams) + } + } +} -- 2.45.2 From 4e96d4f4447347af9b3513e6fd670a51fc1ce7c4 Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Tue, 1 Feb 2022 18:11:13 +0530 Subject: [PATCH 10/17] Add tests for indexer methods used to change addresses being watched --- statediff/indexer/indexer_test.go | 654 +++++++++++++++++++++++++++++- statediff/indexer/test_helpers.go | 4 + 2 files changed, 651 insertions(+), 7 deletions(-) diff --git a/statediff/indexer/indexer_test.go b/statediff/indexer/indexer_test.go index 705c33406..aa6ec9f64 100644 --- a/statediff/indexer/indexer_test.go +++ b/statediff/indexer/indexer_test.go @@ -19,11 +19,13 @@ package indexer_test import ( "bytes" "fmt" + "math/big" "os" "testing" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/statediff/indexer" "github.com/ethereum/go-ethereum/statediff/indexer/ipfs" @@ -32,6 +34,7 @@ import ( "github.com/ethereum/go-ethereum/statediff/indexer/models" "github.com/ethereum/go-ethereum/statediff/indexer/postgres" "github.com/ethereum/go-ethereum/statediff/indexer/shared" + sdtypes "github.com/ethereum/go-ethereum/statediff/types" "github.com/ipfs/go-cid" blockstore "github.com/ipfs/go-ipfs-blockstore" dshelp "github.com/ipfs/go-ipfs-ds-help" @@ -45,12 +48,17 @@ var ( ind *indexer.StateDiffIndexer ipfsPgGet = `SELECT data FROM public.blocks WHERE key = $1` - tx1, tx2, tx3, tx4, tx5, rct1, rct2, rct3, rct4, rct5 []byte - mockBlock *types.Block - headerCID, trx1CID, trx2CID, trx3CID, trx4CID, trx5CID cid.Cid - rct1CID, rct2CID, rct3CID, rct4CID, rct5CID cid.Cid - rctLeaf1, rctLeaf2, rctLeaf3, rctLeaf4, rctLeaf5 []byte - state1CID, state2CID, storageCID cid.Cid + tx1, tx2, tx3, tx4, tx5, rct1, rct2, rct3, rct4, rct5 []byte + mockBlock *types.Block + headerCID, trx1CID, trx2CID, trx3CID, trx4CID, trx5CID cid.Cid + rct1CID, rct2CID, rct3CID, rct4CID, rct5CID cid.Cid + rctLeaf1, rctLeaf2, rctLeaf3, rctLeaf4, rctLeaf5 []byte + state1CID, state2CID, storageCID cid.Cid + contract1Address, contract2Address, contract3Address, contract4Address string + contract1CreatedAt, contract2CreatedAt, contract3CreatedAt, contract4CreatedAt uint64 + slot1StorageKeyHex, slot2StorageKeyHex, slot3StorageKeyHex, slot4StorageKeyHex string + slot1CreatedAt, slot2CreatedAt, slot3CreatedAt, slot4CreatedAt uint64 + lastFilledAt, watchedAt1, watchedAt2, watchedAt3 uint64 ) func expectTrue(t *testing.T, value bool) { @@ -161,15 +169,42 @@ func init() { rctLeaf3 = orderedRctLeafNodes[2] rctLeaf4 = orderedRctLeafNodes[3] rctLeaf5 = orderedRctLeafNodes[4] + + contract1Address = "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F" + contract2Address = "0x6Eb7e5C66DB8af2E96159AC440cbc8CDB7fbD26B" + contract3Address = "0xcfeB164C328CA13EFd3C77E1980d94975aDfedfc" + contract4Address = "0x0Edf0c4f393a628DE4828B228C48175b3EA297fc" + contract1CreatedAt = uint64(1) + contract2CreatedAt = uint64(2) + contract3CreatedAt = uint64(3) + contract4CreatedAt = uint64(4) + + slot1StorageKeyHex = crypto.Keccak256Hash(common.HexToHash("1").Bytes()).Hex() + slot2StorageKeyHex = crypto.Keccak256Hash(common.HexToHash("2").Bytes()).Hex() + slot3StorageKeyHex = crypto.Keccak256Hash(common.HexToHash("3").Bytes()).Hex() + slot4StorageKeyHex = crypto.Keccak256Hash(common.HexToHash("4").Bytes()).Hex() + slot1CreatedAt = uint64(1) + slot2CreatedAt = uint64(2) + slot3CreatedAt = uint64(3) + slot4CreatedAt = uint64(4) + + lastFilledAt = uint64(0) + watchedAt1 = uint64(10) + watchedAt2 = uint64(15) + watchedAt3 = uint64(20) } -func setup(t *testing.T) { +func setupIndexer(t *testing.T) { db, err = shared.SetupDB() if err != nil { t.Fatal(err) } ind, err = indexer.NewStateDiffIndexer(mocks.TestConfig, db) require.NoError(t, err) +} + +func setup(t *testing.T) { + setupIndexer(t) var tx *indexer.BlockTx tx, err = ind.PushBlock( mockBlock, @@ -654,3 +689,608 @@ func TestPublishAndIndexer(t *testing.T) { shared.ExpectEqual(t, data, []byte{}) }) } + +func TestWatchAddressMethods(t *testing.T) { + setupIndexer(t) + defer tearDown(t) + + type res struct { + Address string `db:"address"` + Kind int `db:"kind"` + CreatedAt uint64 `db:"created_at"` + WatchedAt uint64 `db:"watched_at"` + LastFilledAt uint64 `db:"last_filled_at"` + } + pgStr := "SELECT * FROM eth.watched_addresses" + + // Watched addresses + t.Run("Insert watched addresses", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + Kind: sdtypes.WatchedAddress.Int(), + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + Kind: sdtypes.WatchedAddress.Int(), + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + } + + ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt1)), sdtypes.WatchedAddress) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Insert watched addresses (some already watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + Kind: sdtypes.WatchedAddress.Int(), + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + Kind: sdtypes.WatchedAddress.Int(), + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: contract3Address, + Kind: sdtypes.WatchedAddress.Int(), + CreatedAt: contract3CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + } + + ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt2)), sdtypes.WatchedAddress) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Remove watched addresses", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + Kind: sdtypes.WatchedAddress.Int(), + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + } + + ind.RemoveWatchedAddresses(args, sdtypes.WatchedAddress) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Remove watched addresses (some non-watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{} + + ind.RemoveWatchedAddresses(args, sdtypes.WatchedAddress) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Set watched addresses", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + Kind: sdtypes.WatchedAddress.Int(), + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + Kind: sdtypes.WatchedAddress.Int(), + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + { + Address: contract3Address, + Kind: sdtypes.WatchedAddress.Int(), + CreatedAt: contract3CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + } + + ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt2)), sdtypes.WatchedAddress) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Set watched addresses (some already watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract4Address, + CreatedAt: contract4CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract4Address, + Kind: sdtypes.WatchedAddress.Int(), + CreatedAt: contract4CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + Kind: sdtypes.WatchedAddress.Int(), + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + { + Address: contract3Address, + Kind: sdtypes.WatchedAddress.Int(), + CreatedAt: contract3CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + } + + ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt3)), sdtypes.WatchedAddress) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Clear watched addresses", func(t *testing.T) { + expectedData := []res{} + + ind.ClearWatchedAddresses(sdtypes.WatchedAddress) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Clear watched addresses (empty table)", func(t *testing.T) { + expectedData := []res{} + + ind.ClearWatchedAddresses(sdtypes.WatchedAddress) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + // Watched storage slots + // Reset the db. + tearDown(t) + + t.Run("Insert watched storage slots", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: slot1StorageKeyHex, + CreatedAt: slot1CreatedAt, + }, + { + Address: slot2StorageKeyHex, + CreatedAt: slot2CreatedAt, + }, + } + expectedData := []res{ + { + Address: slot1StorageKeyHex, + Kind: sdtypes.WatchedStorageSlot.Int(), + CreatedAt: slot1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: slot2StorageKeyHex, + Kind: sdtypes.WatchedStorageSlot.Int(), + CreatedAt: slot2CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + } + + ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt1)), sdtypes.WatchedStorageSlot) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Insert watched storage slots (some already watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: slot3StorageKeyHex, + CreatedAt: slot3CreatedAt, + }, + { + Address: slot2StorageKeyHex, + CreatedAt: slot2CreatedAt, + }, + } + expectedData := []res{ + { + Address: slot1StorageKeyHex, + Kind: sdtypes.WatchedStorageSlot.Int(), + CreatedAt: slot1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: slot2StorageKeyHex, + Kind: sdtypes.WatchedStorageSlot.Int(), + CreatedAt: slot2CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: slot3StorageKeyHex, + Kind: sdtypes.WatchedStorageSlot.Int(), + CreatedAt: slot3CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + } + + ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt2)), sdtypes.WatchedStorageSlot) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Remove watched storage slots", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: slot3StorageKeyHex, + CreatedAt: slot3CreatedAt, + }, + { + Address: slot2StorageKeyHex, + CreatedAt: slot2CreatedAt, + }, + } + expectedData := []res{ + { + Address: slot1StorageKeyHex, + Kind: sdtypes.WatchedStorageSlot.Int(), + CreatedAt: slot1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + } + + ind.RemoveWatchedAddresses(args, sdtypes.WatchedStorageSlot) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Remove watched storage slots (some non-watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: slot1StorageKeyHex, + CreatedAt: slot1CreatedAt, + }, + { + Address: slot2StorageKeyHex, + CreatedAt: slot2CreatedAt, + }, + } + expectedData := []res{} + + ind.RemoveWatchedAddresses(args, sdtypes.WatchedStorageSlot) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Set watched storage slots", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: slot1StorageKeyHex, + CreatedAt: slot1CreatedAt, + }, + { + Address: slot2StorageKeyHex, + CreatedAt: slot2CreatedAt, + }, + { + Address: slot3StorageKeyHex, + CreatedAt: slot3CreatedAt, + }, + } + expectedData := []res{ + { + Address: slot1StorageKeyHex, + Kind: sdtypes.WatchedStorageSlot.Int(), + CreatedAt: slot1CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + { + Address: slot2StorageKeyHex, + Kind: sdtypes.WatchedStorageSlot.Int(), + CreatedAt: slot2CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + { + Address: slot3StorageKeyHex, + Kind: sdtypes.WatchedStorageSlot.Int(), + CreatedAt: slot3CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + } + + ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt2)), sdtypes.WatchedStorageSlot) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Set watched storage slots (some already watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: slot4StorageKeyHex, + CreatedAt: slot4CreatedAt, + }, + { + Address: slot2StorageKeyHex, + CreatedAt: slot2CreatedAt, + }, + { + Address: slot3StorageKeyHex, + CreatedAt: slot3CreatedAt, + }, + } + expectedData := []res{ + { + Address: slot4StorageKeyHex, + Kind: sdtypes.WatchedStorageSlot.Int(), + CreatedAt: slot4CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + { + Address: slot2StorageKeyHex, + Kind: sdtypes.WatchedStorageSlot.Int(), + CreatedAt: slot2CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + { + Address: slot3StorageKeyHex, + Kind: sdtypes.WatchedStorageSlot.Int(), + CreatedAt: slot3CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + } + + ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt3)), sdtypes.WatchedStorageSlot) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Clear watched storage slots", func(t *testing.T) { + expectedData := []res{} + + ind.ClearWatchedAddresses(sdtypes.WatchedStorageSlot) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Clear watched storage slots (empty table)", func(t *testing.T) { + expectedData := []res{} + + ind.ClearWatchedAddresses(sdtypes.WatchedStorageSlot) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) +} diff --git a/statediff/indexer/test_helpers.go b/statediff/indexer/test_helpers.go index 024bb58f0..43bb3023d 100644 --- a/statediff/indexer/test_helpers.go +++ b/statediff/indexer/test_helpers.go @@ -53,6 +53,10 @@ func TearDownDB(t *testing.T, db *postgres.DB) { if err != nil { t.Fatal(err) } + _, err = tx.Exec(`DELETE FROM eth.watched_addresses`) + if err != nil { + t.Fatal(err) + } err = tx.Commit() if err != nil { t.Fatal(err) -- 2.45.2 From 8088141a4051d833fab15090a8c0f12073e1b76e Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Thu, 3 Feb 2022 11:12:51 +0530 Subject: [PATCH 11/17] Remove support for watched storage slots --- statediff/README.md | 4 +- statediff/api.go | 2 +- statediff/builder.go | 97 +++--- statediff/builder_test.go | 193 ----------- statediff/helpers.go | 21 +- statediff/indexer/indexer.go | 38 +-- statediff/indexer/indexer_test.go | 338 +------------------- statediff/service.go | 111 +------ statediff/testhelpers/mocks/indexer.go | 8 +- statediff/testhelpers/mocks/service.go | 110 ++----- statediff/testhelpers/mocks/service_test.go | 251 +-------------- statediff/types.go | 14 +- statediff/types/types.go | 21 +- 13 files changed, 127 insertions(+), 1081 deletions(-) diff --git a/statediff/README.md b/statediff/README.md index 8e673efa6..7cfb6d0f1 100644 --- a/statediff/README.md +++ b/statediff/README.md @@ -106,15 +106,13 @@ type Params struct { IncludeTD bool IncludeCode bool WatchedAddresses []common.Address - WatchedStorageSlots []common.Hash } ``` Using these params we can tell the service whether to include state and/or storage intermediate nodes; whether to include the associated block (header, uncles, and transactions); whether to include the associated receipts; whether to include the total difficulty for this block; whether to include the set of code hashes and code for -contracts deployed in this block; whether to limit the diffing process to a list of specific addresses; and/or -whether to limit the diffing process to a list of specific storage slot keys. +contracts deployed in this block; whether to limit the diffing process to a list of specific addresses. #### Subscription endpoint A websocket supporting RPC endpoint is exposed for subscribing to state diff `StateObjects` that come off the head of the chain while the geth node syncs. diff --git a/statediff/api.go b/statediff/api.go index 3686728f2..ed9cc3c06 100644 --- a/statediff/api.go +++ b/statediff/api.go @@ -150,7 +150,7 @@ func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash return api.sds.WriteStateDiffFor(blockHash, params) } -// WatchAddress changes the list of watched addresses | storage slots to which the direct indexing is restricted according to given operation +// WatchAddress changes the list of watched addresses to which the direct indexing is restricted according to given operation func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, args []WatchAddressArg) error { return api.sds.WatchAddress(operation, args) } diff --git a/statediff/builder.go b/statediff/builder.go index 63b354a4c..aee8f71ff 100644 --- a/statediff/builder.go +++ b/statediff/builder.go @@ -123,7 +123,7 @@ func (sdb *builder) buildStateTrie(it trie.NodeIterator) ([]StateNode, []CodeAnd node.LeafKey = leafKey if !bytes.Equal(account.CodeHash, nullCodeHash) { var storageNodes []StorageNode - err := sdb.buildStorageNodesEventual(account.Root, nil, true, storageNodeAppender(&storageNodes)) + err := sdb.buildStorageNodesEventual(account.Root, true, storageNodeAppender(&storageNodes)) if err != nil { return nil, nil, fmt.Errorf("failed building eventual storage diffs for account %+v\r\nerror: %v", account, err) } @@ -220,12 +220,12 @@ func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args StateRoots, pa // build the diff nodes for the updated accounts using the mappings at both A and B as directed by the keys found as the intersection of the two err = sdb.buildAccountUpdates( diffAccountsAtB, diffAccountsAtA, updatedKeys, - params.WatchedStorageSlots, params.IntermediateStorageNodes, output) + params.IntermediateStorageNodes, output) if err != nil { return fmt.Errorf("error building diff for updated accounts: %v", err) } // build the diff nodes for created accounts - err = sdb.buildAccountCreations(diffAccountsAtB, params.WatchedStorageSlots, params.IntermediateStorageNodes, output, codeOutput) + err = sdb.buildAccountCreations(diffAccountsAtB, params.IntermediateStorageNodes, output, codeOutput) if err != nil { return fmt.Errorf("error building diff for created accounts: %v", err) } @@ -274,12 +274,12 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args StateRoots, // build the diff nodes for the updated accounts using the mappings at both A and B as directed by the keys found as the intersection of the two err = sdb.buildAccountUpdates( diffAccountsAtB, diffAccountsAtA, updatedKeys, - params.WatchedStorageSlots, params.IntermediateStorageNodes, output) + params.IntermediateStorageNodes, output) if err != nil { return fmt.Errorf("error building diff for updated accounts: %v", err) } // build the diff nodes for created accounts - err = sdb.buildAccountCreations(diffAccountsAtB, params.WatchedStorageSlots, params.IntermediateStorageNodes, output, codeOutput) + err = sdb.buildAccountCreations(diffAccountsAtB, params.IntermediateStorageNodes, output, codeOutput) if err != nil { return fmt.Errorf("error building diff for created accounts: %v", err) } @@ -456,8 +456,7 @@ func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB m // to generate the statediff node objects for all of the accounts that existed at both A and B but in different states // needs to be called before building account creations and deletions as this mutates // those account maps to remove the accounts which were updated -func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updatedKeys []string, - watchedStorageKeys []common.Hash, intermediateStorageNodes bool, output StateNodeSink) error { +func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updatedKeys []string, intermediateStorageNodes bool, output StateNodeSink) error { var err error for _, key := range updatedKeys { createdAcc := creations[key] @@ -467,7 +466,7 @@ func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updated oldSR := deletedAcc.Account.Root newSR := createdAcc.Account.Root err = sdb.buildStorageNodesIncremental( - oldSR, newSR, watchedStorageKeys, intermediateStorageNodes, + oldSR, newSR, intermediateStorageNodes, storageNodeAppender(&storageDiffs)) if err != nil { return fmt.Errorf("failed building incremental storage diffs for account with leafkey %s\r\nerror: %v", key, err) @@ -491,7 +490,7 @@ func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updated // buildAccountCreations returns the statediff node objects for all the accounts that exist at B but not at A // it also returns the code and codehash for created contract accounts -func (sdb *builder) buildAccountCreations(accounts AccountMap, watchedStorageKeys []common.Hash, intermediateStorageNodes bool, output StateNodeSink, codeOutput CodeSink) error { +func (sdb *builder) buildAccountCreations(accounts AccountMap, intermediateStorageNodes bool, output StateNodeSink, codeOutput CodeSink) error { for _, val := range accounts { diff := StateNode{ NodeType: val.NodeType, @@ -502,7 +501,7 @@ func (sdb *builder) buildAccountCreations(accounts AccountMap, watchedStorageKey if !bytes.Equal(val.Account.CodeHash, nullCodeHash) { // For contract creations, any storage node contained is a diff var storageDiffs []StorageNode - err := sdb.buildStorageNodesEventual(val.Account.Root, watchedStorageKeys, intermediateStorageNodes, storageNodeAppender(&storageDiffs)) + err := sdb.buildStorageNodesEventual(val.Account.Root, intermediateStorageNodes, storageNodeAppender(&storageDiffs)) if err != nil { return fmt.Errorf("failed building eventual storage diffs for node %x\r\nerror: %v", val.Path, err) } @@ -530,7 +529,7 @@ func (sdb *builder) buildAccountCreations(accounts AccountMap, watchedStorageKey // buildStorageNodesEventual builds the storage diff node objects for a created account // i.e. it returns all the storage nodes at this state, since there is no previous state -func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys []common.Hash, intermediateNodes bool, output StorageNodeSink) error { +func (sdb *builder) buildStorageNodesEventual(sr common.Hash, intermediateNodes bool, output StorageNodeSink) error { if bytes.Equal(sr.Bytes(), emptyContractRoot.Bytes()) { return nil } @@ -541,7 +540,7 @@ func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys return err } it := sTrie.NodeIterator(make([]byte, 0)) - err = sdb.buildStorageNodesFromTrie(it, watchedStorageKeys, intermediateNodes, output) + err = sdb.buildStorageNodesFromTrie(it, intermediateNodes, output) if err != nil { return err } @@ -551,7 +550,7 @@ func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys // buildStorageNodesFromTrie returns all the storage diff node objects in the provided node interator // if any storage keys are provided it will only return those leaf nodes // including intermediate nodes can be turned on or off -func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStorageKeys []common.Hash, intermediateNodes bool, output StorageNodeSink) error { +func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, intermediateNodes bool, output StorageNodeSink) error { for it.Next(true) { // skip value nodes if it.Leaf() || bytes.Equal(nullHashBytes, it.Hash().Bytes()) { @@ -567,15 +566,13 @@ func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStora valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedStorageKey(watchedStorageKeys, leafKey) { - if err := output(StorageNode{ - NodeType: node.NodeType, - Path: node.Path, - NodeValue: node.NodeValue, - LeafKey: leafKey, - }); err != nil { - return err - } + if err := output(StorageNode{ + NodeType: node.NodeType, + Path: node.Path, + NodeValue: node.NodeValue, + LeafKey: leafKey, + }); err != nil { + return err } case Extension, Branch: if intermediateNodes { @@ -595,7 +592,7 @@ func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStora } // buildStorageNodesIncremental builds the storage diff node objects for all nodes that exist in a different state at B than A -func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common.Hash, watchedStorageKeys []common.Hash, intermediateNodes bool, output StorageNodeSink) error { +func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common.Hash, intermediateNodes bool, output StorageNodeSink) error { if bytes.Equal(newSR.Bytes(), oldSR.Bytes()) { return nil } @@ -611,19 +608,19 @@ func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common diffPathsAtB, err := sdb.createdAndUpdatedStorage( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - watchedStorageKeys, intermediateNodes, output) + intermediateNodes, output) if err != nil { return err } err = sdb.deletedOrUpdatedStorage(oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - diffPathsAtB, watchedStorageKeys, intermediateNodes, output) + diffPathsAtB, intermediateNodes, output) if err != nil { return err } return nil } -func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, watchedKeys []common.Hash, intermediateNodes bool, output StorageNodeSink) (map[string]bool, error) { +func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, intermediateNodes bool, output StorageNodeSink) (map[string]bool, error) { diffPathsAtB := make(map[string]bool) it, _ := trie.NewDifferenceIterator(a, b) for it.Next(true) { @@ -641,15 +638,13 @@ func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, watchedKeys valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedStorageKey(watchedKeys, leafKey) { - if err := output(StorageNode{ - NodeType: node.NodeType, - Path: node.Path, - NodeValue: node.NodeValue, - LeafKey: leafKey, - }); err != nil { - return nil, err - } + if err := output(StorageNode{ + NodeType: node.NodeType, + Path: node.Path, + NodeValue: node.NodeValue, + LeafKey: leafKey, + }); err != nil { + return nil, err } case Extension, Branch: if intermediateNodes { @@ -669,7 +664,7 @@ func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, watchedKeys return diffPathsAtB, it.Error() } -func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB map[string]bool, watchedKeys []common.Hash, intermediateNodes bool, output StorageNodeSink) error { +func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB map[string]bool, intermediateNodes bool, output StorageNodeSink) error { it, _ := trie.NewDifferenceIterator(b, a) for it.Next(true) { // skip value nodes @@ -692,15 +687,13 @@ func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedStorageKey(watchedKeys, leafKey) { - if err := output(StorageNode{ - NodeType: Removed, - Path: node.Path, - NodeValue: []byte{}, - LeafKey: leafKey, - }); err != nil { - return err - } + if err := output(StorageNode{ + NodeType: Removed, + Path: node.Path, + NodeValue: []byte{}, + LeafKey: leafKey, + }); err != nil { + return err } case Extension, Branch: if intermediateNodes { @@ -733,17 +726,3 @@ func isWatchedAddress(watchedAddresses []common.Address, stateLeafKey []byte) bo } return false } - -// isWatchedStorageKey is used to check if a storage leaf corresponds to one of the storage slots the builder is configured to watch -func isWatchedStorageKey(watchedKeys []common.Hash, storageLeafKey []byte) bool { - // If we aren't watching any specific addresses, we are watching everything - if len(watchedKeys) == 0 { - return true - } - for _, hashKey := range watchedKeys { - if bytes.Equal(hashKey.Bytes(), storageLeafKey) { - return true - } - } - return false -} diff --git a/statediff/builder_test.go b/statediff/builder_test.go index 189295518..741605d41 100644 --- a/statediff/builder_test.go +++ b/statediff/builder_test.go @@ -1151,199 +1151,6 @@ func TestBuilderWithWatchedAddressList(t *testing.T) { } } -func TestBuilderWithWatchedAddressAndStorageKeyList(t *testing.T) { - blocks, chain := testhelpers.MakeChain(4, testhelpers.Genesis, testhelpers.TestChainGen) - contractLeafKey = testhelpers.AddressToLeafKey(testhelpers.ContractAddr) - defer chain.Stop() - block0 = testhelpers.Genesis - block1 = blocks[0] - block2 = blocks[1] - block3 = blocks[2] - block4 = blocks[3] - params := statediff.Params{ - WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.ContractAddr}, - WatchedStorageSlots: []common.Hash{slot1StorageKey}, - } - builder = statediff.NewBuilder(chain.StateCache()) - - var tests = []struct { - name string - startingArguments statediff.Args - expected *statediff.StateObject - }{ - { - "testEmptyDiff", - statediff.Args{ - OldStateRoot: block0.Root(), - NewStateRoot: block0.Root(), - BlockNumber: block0.Number(), - BlockHash: block0.Hash(), - }, - &statediff.StateObject{ - BlockNumber: block0.Number(), - BlockHash: block0.Hash(), - Nodes: emptyDiffs, - }, - }, - { - "testBlock0", - //10000 transferred from testBankAddress to account1Addr - statediff.Args{ - OldStateRoot: testhelpers.NullHash, - NewStateRoot: block0.Root(), - BlockNumber: block0.Number(), - BlockHash: block0.Hash(), - }, - &statediff.StateObject{ - BlockNumber: block0.Number(), - BlockHash: block0.Hash(), - Nodes: emptyDiffs, - }, - }, - { - "testBlock1", - //10000 transferred from testBankAddress to account1Addr - statediff.Args{ - OldStateRoot: block0.Root(), - NewStateRoot: block1.Root(), - BlockNumber: block1.Number(), - BlockHash: block1.Hash(), - }, - &statediff.StateObject{ - BlockNumber: block1.Number(), - BlockHash: block1.Hash(), - Nodes: []sdtypes.StateNode{ - { - Path: []byte{'\x0e'}, - NodeType: sdtypes.Leaf, - LeafKey: testhelpers.Account1LeafKey, - NodeValue: account1AtBlock1LeafNode, - StorageNodes: emptyStorage, - }, - }, - }, - }, - { - "testBlock2", - //1000 transferred from testBankAddress to account1Addr - //1000 transferred from account1Addr to account2Addr - statediff.Args{ - OldStateRoot: block1.Root(), - NewStateRoot: block2.Root(), - BlockNumber: block2.Number(), - BlockHash: block2.Hash(), - }, - &statediff.StateObject{ - BlockNumber: block2.Number(), - BlockHash: block2.Hash(), - Nodes: []sdtypes.StateNode{ - { - Path: []byte{'\x06'}, - NodeType: sdtypes.Leaf, - LeafKey: contractLeafKey, - NodeValue: contractAccountAtBlock2LeafNode, - StorageNodes: []sdtypes.StorageNode{ - { - Path: []byte{'\x0b'}, - NodeType: sdtypes.Leaf, - LeafKey: slot1StorageKey.Bytes(), - NodeValue: slot1StorageLeafNode, - }, - }, - }, - { - Path: []byte{'\x0e'}, - NodeType: sdtypes.Leaf, - LeafKey: testhelpers.Account1LeafKey, - NodeValue: account1AtBlock2LeafNode, - StorageNodes: emptyStorage, - }, - }, - CodeAndCodeHashes: []sdtypes.CodeAndCodeHash{ - { - Hash: testhelpers.CodeHash, - Code: testhelpers.ByteCodeAfterDeployment, - }, - }, - }, - }, - { - "testBlock3", - //the contract's storage is changed - //and the block is mined by account 2 - statediff.Args{ - OldStateRoot: block2.Root(), - NewStateRoot: block3.Root(), - BlockNumber: block3.Number(), - BlockHash: block3.Hash(), - }, - &statediff.StateObject{ - BlockNumber: block3.Number(), - BlockHash: block3.Hash(), - Nodes: []sdtypes.StateNode{ - { - Path: []byte{'\x06'}, - NodeType: sdtypes.Leaf, - LeafKey: contractLeafKey, - NodeValue: contractAccountAtBlock3LeafNode, - StorageNodes: emptyStorage, - }, - }, - }, - }, - { - "testBlock4", - statediff.Args{ - OldStateRoot: block3.Root(), - NewStateRoot: block4.Root(), - BlockNumber: block4.Number(), - BlockHash: block4.Hash(), - }, - &statediff.StateObject{ - BlockNumber: block4.Number(), - BlockHash: block4.Hash(), - Nodes: []sdtypes.StateNode{ - { - Path: []byte{'\x06'}, - NodeType: sdtypes.Leaf, - LeafKey: contractLeafKey, - NodeValue: contractAccountAtBlock4LeafNode, - StorageNodes: []sdtypes.StorageNode{ - { - Path: []byte{'\x0b'}, - NodeType: sdtypes.Removed, - LeafKey: slot1StorageKey.Bytes(), - NodeValue: []byte{}, - }, - }, - }, - }, - }, - }, - } - - for _, test := range tests { - diff, err := builder.BuildStateDiffObject(test.startingArguments, params) - if err != nil { - t.Error(err) - } - receivedStateDiffRlp, err := rlp.EncodeToBytes(diff) - if err != nil { - t.Error(err) - } - expectedStateDiffRlp, err := rlp.EncodeToBytes(test.expected) - if err != nil { - t.Error(err) - } - sort.Slice(receivedStateDiffRlp, func(i, j int) bool { return receivedStateDiffRlp[i] < receivedStateDiffRlp[j] }) - sort.Slice(expectedStateDiffRlp, func(i, j int) bool { return expectedStateDiffRlp[i] < expectedStateDiffRlp[j] }) - if !bytes.Equal(receivedStateDiffRlp, expectedStateDiffRlp) { - t.Logf("Test failed: %s", test.name) - t.Errorf("actual state diff: %+v\nexpected state diff: %+v", diff, test.expected) - } - } -} - func TestBuilderWithRemovedAccountAndStorage(t *testing.T) { blocks, chain := testhelpers.MakeChain(6, testhelpers.Genesis, testhelpers.TestChainGen) contractLeafKey = testhelpers.AddressToLeafKey(testhelpers.ContractAddr) diff --git a/statediff/helpers.go b/statediff/helpers.go index ba809fc06..e2cf7365a 100644 --- a/statediff/helpers.go +++ b/statediff/helpers.go @@ -26,7 +26,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/statediff/indexer/postgres" - "github.com/ethereum/go-ethereum/statediff/types" ) func sortKeys(data AccountMap) []string { @@ -76,39 +75,27 @@ func findIntersection(a, b []string) []string { } } -// loadWatchedAddresses is used to load watched addresses and storage slots to the in-memory write loop params from the db +// loadWatchedAddresses is used to load watched addresses to in-memory write loop params from the db func loadWatchedAddresses(db *postgres.DB) error { type Watched struct { Address string `db:"address"` - Kind int `db:"kind"` } var watched []Watched - pgStr := "SELECT address, kind FROM eth.watched_addresses" + pgStr := "SELECT address FROM eth.watched_addresses" err := db.Select(&watched, pgStr) if err != nil { return fmt.Errorf("error loading watched addresses: %v", err) } - var ( - watchedAddresses = []common.Address{} - watchedStorageSlots = []common.Hash{} - ) + watchedAddresses := []common.Address{} for _, entry := range watched { - switch entry.Kind { - case types.WatchedAddress.Int(): - watchedAddresses = append(watchedAddresses, common.HexToAddress(entry.Address)) - case types.WatchedStorageSlot.Int(): - watchedStorageSlots = append(watchedStorageSlots, common.HexToHash(entry.Address)) - default: - return fmt.Errorf("Unexpected kind %d", entry.Kind) - } + watchedAddresses = append(watchedAddresses, common.HexToAddress(entry.Address)) } writeLoopParams.Lock() defer writeLoopParams.Unlock() writeLoopParams.WatchedAddresses = watchedAddresses - writeLoopParams.WatchedStorageSlots = watchedStorageSlots return nil } diff --git a/statediff/indexer/indexer.go b/statediff/indexer/indexer.go index 5132f4942..70ab1339a 100644 --- a/statediff/indexer/indexer.go +++ b/statediff/indexer/indexer.go @@ -61,10 +61,10 @@ type Indexer interface { ReportDBMetrics(delay time.Duration, quit <-chan bool) // Methods used by WatchAddress API/functionality. - InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int, kind sdtypes.WatchedAddressType) error - RemoveWatchedAddresses(addresses []sdtypes.WatchAddressArg, kind sdtypes.WatchedAddressType) error - SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error - ClearWatchedAddresses(kind sdtypes.WatchedAddressType) error + InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error + RemoveWatchedAddresses(addresses []sdtypes.WatchAddressArg) error + SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error + ClearWatchedAddresses() error } // StateDiffIndexer satisfies the Indexer interface for ethereum statediff objects @@ -556,8 +556,8 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sd return nil } -// InsertWatchedAddresses inserts the given addresses | storage slots in the database -func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error { +// InsertWatchedAddresses inserts the given addresses in the database +func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { tx, err := sdi.dbWriter.db.Begin() if err != nil { return err @@ -565,8 +565,8 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressA defer tx.Rollback() for _, arg := range args { - _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, kind, created_at, watched_at) VALUES ($1, $2, $3, $4) ON CONFLICT (address) DO NOTHING`, - arg.Address, kind.Int(), arg.CreatedAt, currentBlockNumber.Uint64()) + _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, + arg.Address, arg.CreatedAt, currentBlockNumber.Uint64()) if err != nil { return fmt.Errorf("error inserting watched_addresses entry: %v", err) } @@ -580,8 +580,8 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressA return nil } -// RemoveWatchedAddresses removes the given addresses | storage slots from the database -func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressArg, kind sdtypes.WatchedAddressType) error { +// RemoveWatchedAddresses removes the given watched addresses from the database +func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressArg) error { tx, err := sdi.dbWriter.db.Begin() if err != nil { return err @@ -589,7 +589,7 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressA defer tx.Rollback() for _, arg := range args { - _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1 AND kind = $2`, arg.Address, kind.Int()) + _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, arg.Address) if err != nil { return fmt.Errorf("error removing watched_addresses entry: %v", err) } @@ -603,22 +603,22 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressA return nil } -// SetWatchedAddresses clears and inserts the given addresses | storage slots in the database -func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error { +// SetWatchedAddresses clears and inserts the given addresses in the database +func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { tx, err := sdi.dbWriter.db.Begin() if err != nil { return err } defer tx.Rollback() - _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE kind = $1`, kind.Int()) + _, err = tx.Exec(`DELETE FROM eth.watched_addresses`) if err != nil { return fmt.Errorf("error setting watched_addresses table: %v", err) } for _, arg := range args { - _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, kind, created_at, watched_at) VALUES ($1, $2, $3, $4) ON CONFLICT (address) DO NOTHING`, - arg.Address, kind.Int(), arg.CreatedAt, currentBlockNumber.Uint64()) + _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, + arg.Address, arg.CreatedAt, currentBlockNumber.Uint64()) if err != nil { return fmt.Errorf("error setting watched_addresses table: %v", err) } @@ -632,9 +632,9 @@ func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, return nil } -// ClearWatchedAddresses clears all the addresses | storage slots from the database -func (sdi *StateDiffIndexer) ClearWatchedAddresses(kind sdtypes.WatchedAddressType) error { - _, err := sdi.dbWriter.db.Exec(`DELETE FROM eth.watched_addresses WHERE kind = $1`, kind.Int()) +// ClearWatchedAddresses clears all the watched addresses from the database +func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { + _, err := sdi.dbWriter.db.Exec(`DELETE FROM eth.watched_addresses`) if err != nil { return fmt.Errorf("error clearing watched_addresses table: %v", err) } diff --git a/statediff/indexer/indexer_test.go b/statediff/indexer/indexer_test.go index aa6ec9f64..f71e2ead2 100644 --- a/statediff/indexer/indexer_test.go +++ b/statediff/indexer/indexer_test.go @@ -25,7 +25,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/statediff/indexer" "github.com/ethereum/go-ethereum/statediff/indexer/ipfs" @@ -56,8 +55,6 @@ var ( state1CID, state2CID, storageCID cid.Cid contract1Address, contract2Address, contract3Address, contract4Address string contract1CreatedAt, contract2CreatedAt, contract3CreatedAt, contract4CreatedAt uint64 - slot1StorageKeyHex, slot2StorageKeyHex, slot3StorageKeyHex, slot4StorageKeyHex string - slot1CreatedAt, slot2CreatedAt, slot3CreatedAt, slot4CreatedAt uint64 lastFilledAt, watchedAt1, watchedAt2, watchedAt3 uint64 ) @@ -179,15 +176,6 @@ func init() { contract3CreatedAt = uint64(3) contract4CreatedAt = uint64(4) - slot1StorageKeyHex = crypto.Keccak256Hash(common.HexToHash("1").Bytes()).Hex() - slot2StorageKeyHex = crypto.Keccak256Hash(common.HexToHash("2").Bytes()).Hex() - slot3StorageKeyHex = crypto.Keccak256Hash(common.HexToHash("3").Bytes()).Hex() - slot4StorageKeyHex = crypto.Keccak256Hash(common.HexToHash("4").Bytes()).Hex() - slot1CreatedAt = uint64(1) - slot2CreatedAt = uint64(2) - slot3CreatedAt = uint64(3) - slot4CreatedAt = uint64(4) - lastFilledAt = uint64(0) watchedAt1 = uint64(10) watchedAt2 = uint64(15) @@ -696,7 +684,6 @@ func TestWatchAddressMethods(t *testing.T) { type res struct { Address string `db:"address"` - Kind int `db:"kind"` CreatedAt uint64 `db:"created_at"` WatchedAt uint64 `db:"watched_at"` LastFilledAt uint64 `db:"last_filled_at"` @@ -718,21 +705,19 @@ func TestWatchAddressMethods(t *testing.T) { expectedData := []res{ { Address: contract1Address, - Kind: sdtypes.WatchedAddress.Int(), CreatedAt: contract1CreatedAt, WatchedAt: watchedAt1, LastFilledAt: lastFilledAt, }, { Address: contract2Address, - Kind: sdtypes.WatchedAddress.Int(), CreatedAt: contract2CreatedAt, WatchedAt: watchedAt1, LastFilledAt: lastFilledAt, }, } - ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt1)), sdtypes.WatchedAddress) + ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt1))) rows := []res{} err = db.Select(&rows, pgStr) @@ -760,28 +745,25 @@ func TestWatchAddressMethods(t *testing.T) { expectedData := []res{ { Address: contract1Address, - Kind: sdtypes.WatchedAddress.Int(), CreatedAt: contract1CreatedAt, WatchedAt: watchedAt1, LastFilledAt: lastFilledAt, }, { Address: contract2Address, - Kind: sdtypes.WatchedAddress.Int(), CreatedAt: contract2CreatedAt, WatchedAt: watchedAt1, LastFilledAt: lastFilledAt, }, { Address: contract3Address, - Kind: sdtypes.WatchedAddress.Int(), CreatedAt: contract3CreatedAt, WatchedAt: watchedAt2, LastFilledAt: lastFilledAt, }, } - ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt2)), sdtypes.WatchedAddress) + ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt2))) rows := []res{} err = db.Select(&rows, pgStr) @@ -809,14 +791,13 @@ func TestWatchAddressMethods(t *testing.T) { expectedData := []res{ { Address: contract1Address, - Kind: sdtypes.WatchedAddress.Int(), CreatedAt: contract1CreatedAt, WatchedAt: watchedAt1, LastFilledAt: lastFilledAt, }, } - ind.RemoveWatchedAddresses(args, sdtypes.WatchedAddress) + ind.RemoveWatchedAddresses(args) rows := []res{} err = db.Select(&rows, pgStr) @@ -843,7 +824,7 @@ func TestWatchAddressMethods(t *testing.T) { } expectedData := []res{} - ind.RemoveWatchedAddresses(args, sdtypes.WatchedAddress) + ind.RemoveWatchedAddresses(args) rows := []res{} err = db.Select(&rows, pgStr) @@ -875,28 +856,25 @@ func TestWatchAddressMethods(t *testing.T) { expectedData := []res{ { Address: contract1Address, - Kind: sdtypes.WatchedAddress.Int(), CreatedAt: contract1CreatedAt, WatchedAt: watchedAt2, LastFilledAt: lastFilledAt, }, { Address: contract2Address, - Kind: sdtypes.WatchedAddress.Int(), CreatedAt: contract2CreatedAt, WatchedAt: watchedAt2, LastFilledAt: lastFilledAt, }, { Address: contract3Address, - Kind: sdtypes.WatchedAddress.Int(), CreatedAt: contract3CreatedAt, WatchedAt: watchedAt2, LastFilledAt: lastFilledAt, }, } - ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt2)), sdtypes.WatchedAddress) + ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt2))) rows := []res{} err = db.Select(&rows, pgStr) @@ -928,28 +906,25 @@ func TestWatchAddressMethods(t *testing.T) { expectedData := []res{ { Address: contract4Address, - Kind: sdtypes.WatchedAddress.Int(), CreatedAt: contract4CreatedAt, WatchedAt: watchedAt3, LastFilledAt: lastFilledAt, }, { Address: contract2Address, - Kind: sdtypes.WatchedAddress.Int(), CreatedAt: contract2CreatedAt, WatchedAt: watchedAt3, LastFilledAt: lastFilledAt, }, { Address: contract3Address, - Kind: sdtypes.WatchedAddress.Int(), CreatedAt: contract3CreatedAt, WatchedAt: watchedAt3, LastFilledAt: lastFilledAt, }, } - ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt3)), sdtypes.WatchedAddress) + ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt3))) rows := []res{} err = db.Select(&rows, pgStr) @@ -966,7 +941,7 @@ func TestWatchAddressMethods(t *testing.T) { t.Run("Clear watched addresses", func(t *testing.T) { expectedData := []res{} - ind.ClearWatchedAddresses(sdtypes.WatchedAddress) + ind.ClearWatchedAddresses() rows := []res{} err = db.Select(&rows, pgStr) @@ -983,304 +958,7 @@ func TestWatchAddressMethods(t *testing.T) { t.Run("Clear watched addresses (empty table)", func(t *testing.T) { expectedData := []res{} - ind.ClearWatchedAddresses(sdtypes.WatchedAddress) - - rows := []res{} - err = db.Select(&rows, pgStr) - if err != nil { - t.Fatal(err) - } - - expectTrue(t, len(rows) == len(expectedData)) - for idx, row := range rows { - shared.ExpectEqual(t, row, expectedData[idx]) - } - }) - - // Watched storage slots - // Reset the db. - tearDown(t) - - t.Run("Insert watched storage slots", func(t *testing.T) { - args := []sdtypes.WatchAddressArg{ - { - Address: slot1StorageKeyHex, - CreatedAt: slot1CreatedAt, - }, - { - Address: slot2StorageKeyHex, - CreatedAt: slot2CreatedAt, - }, - } - expectedData := []res{ - { - Address: slot1StorageKeyHex, - Kind: sdtypes.WatchedStorageSlot.Int(), - CreatedAt: slot1CreatedAt, - WatchedAt: watchedAt1, - LastFilledAt: lastFilledAt, - }, - { - Address: slot2StorageKeyHex, - Kind: sdtypes.WatchedStorageSlot.Int(), - CreatedAt: slot2CreatedAt, - WatchedAt: watchedAt1, - LastFilledAt: lastFilledAt, - }, - } - - ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt1)), sdtypes.WatchedStorageSlot) - - rows := []res{} - err = db.Select(&rows, pgStr) - if err != nil { - t.Fatal(err) - } - - expectTrue(t, len(rows) == len(expectedData)) - for idx, row := range rows { - shared.ExpectEqual(t, row, expectedData[idx]) - } - }) - - t.Run("Insert watched storage slots (some already watched)", func(t *testing.T) { - args := []sdtypes.WatchAddressArg{ - { - Address: slot3StorageKeyHex, - CreatedAt: slot3CreatedAt, - }, - { - Address: slot2StorageKeyHex, - CreatedAt: slot2CreatedAt, - }, - } - expectedData := []res{ - { - Address: slot1StorageKeyHex, - Kind: sdtypes.WatchedStorageSlot.Int(), - CreatedAt: slot1CreatedAt, - WatchedAt: watchedAt1, - LastFilledAt: lastFilledAt, - }, - { - Address: slot2StorageKeyHex, - Kind: sdtypes.WatchedStorageSlot.Int(), - CreatedAt: slot2CreatedAt, - WatchedAt: watchedAt1, - LastFilledAt: lastFilledAt, - }, - { - Address: slot3StorageKeyHex, - Kind: sdtypes.WatchedStorageSlot.Int(), - CreatedAt: slot3CreatedAt, - WatchedAt: watchedAt2, - LastFilledAt: lastFilledAt, - }, - } - - ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt2)), sdtypes.WatchedStorageSlot) - - rows := []res{} - err = db.Select(&rows, pgStr) - if err != nil { - t.Fatal(err) - } - - expectTrue(t, len(rows) == len(expectedData)) - for idx, row := range rows { - shared.ExpectEqual(t, row, expectedData[idx]) - } - }) - - t.Run("Remove watched storage slots", func(t *testing.T) { - args := []sdtypes.WatchAddressArg{ - { - Address: slot3StorageKeyHex, - CreatedAt: slot3CreatedAt, - }, - { - Address: slot2StorageKeyHex, - CreatedAt: slot2CreatedAt, - }, - } - expectedData := []res{ - { - Address: slot1StorageKeyHex, - Kind: sdtypes.WatchedStorageSlot.Int(), - CreatedAt: slot1CreatedAt, - WatchedAt: watchedAt1, - LastFilledAt: lastFilledAt, - }, - } - - ind.RemoveWatchedAddresses(args, sdtypes.WatchedStorageSlot) - - rows := []res{} - err = db.Select(&rows, pgStr) - if err != nil { - t.Fatal(err) - } - - expectTrue(t, len(rows) == len(expectedData)) - for idx, row := range rows { - shared.ExpectEqual(t, row, expectedData[idx]) - } - }) - - t.Run("Remove watched storage slots (some non-watched)", func(t *testing.T) { - args := []sdtypes.WatchAddressArg{ - { - Address: slot1StorageKeyHex, - CreatedAt: slot1CreatedAt, - }, - { - Address: slot2StorageKeyHex, - CreatedAt: slot2CreatedAt, - }, - } - expectedData := []res{} - - ind.RemoveWatchedAddresses(args, sdtypes.WatchedStorageSlot) - - rows := []res{} - err = db.Select(&rows, pgStr) - if err != nil { - t.Fatal(err) - } - - expectTrue(t, len(rows) == len(expectedData)) - for idx, row := range rows { - shared.ExpectEqual(t, row, expectedData[idx]) - } - }) - - t.Run("Set watched storage slots", func(t *testing.T) { - args := []sdtypes.WatchAddressArg{ - { - Address: slot1StorageKeyHex, - CreatedAt: slot1CreatedAt, - }, - { - Address: slot2StorageKeyHex, - CreatedAt: slot2CreatedAt, - }, - { - Address: slot3StorageKeyHex, - CreatedAt: slot3CreatedAt, - }, - } - expectedData := []res{ - { - Address: slot1StorageKeyHex, - Kind: sdtypes.WatchedStorageSlot.Int(), - CreatedAt: slot1CreatedAt, - WatchedAt: watchedAt2, - LastFilledAt: lastFilledAt, - }, - { - Address: slot2StorageKeyHex, - Kind: sdtypes.WatchedStorageSlot.Int(), - CreatedAt: slot2CreatedAt, - WatchedAt: watchedAt2, - LastFilledAt: lastFilledAt, - }, - { - Address: slot3StorageKeyHex, - Kind: sdtypes.WatchedStorageSlot.Int(), - CreatedAt: slot3CreatedAt, - WatchedAt: watchedAt2, - LastFilledAt: lastFilledAt, - }, - } - - ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt2)), sdtypes.WatchedStorageSlot) - - rows := []res{} - err = db.Select(&rows, pgStr) - if err != nil { - t.Fatal(err) - } - - expectTrue(t, len(rows) == len(expectedData)) - for idx, row := range rows { - shared.ExpectEqual(t, row, expectedData[idx]) - } - }) - - t.Run("Set watched storage slots (some already watched)", func(t *testing.T) { - args := []sdtypes.WatchAddressArg{ - { - Address: slot4StorageKeyHex, - CreatedAt: slot4CreatedAt, - }, - { - Address: slot2StorageKeyHex, - CreatedAt: slot2CreatedAt, - }, - { - Address: slot3StorageKeyHex, - CreatedAt: slot3CreatedAt, - }, - } - expectedData := []res{ - { - Address: slot4StorageKeyHex, - Kind: sdtypes.WatchedStorageSlot.Int(), - CreatedAt: slot4CreatedAt, - WatchedAt: watchedAt3, - LastFilledAt: lastFilledAt, - }, - { - Address: slot2StorageKeyHex, - Kind: sdtypes.WatchedStorageSlot.Int(), - CreatedAt: slot2CreatedAt, - WatchedAt: watchedAt3, - LastFilledAt: lastFilledAt, - }, - { - Address: slot3StorageKeyHex, - Kind: sdtypes.WatchedStorageSlot.Int(), - CreatedAt: slot3CreatedAt, - WatchedAt: watchedAt3, - LastFilledAt: lastFilledAt, - }, - } - - ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt3)), sdtypes.WatchedStorageSlot) - - rows := []res{} - err = db.Select(&rows, pgStr) - if err != nil { - t.Fatal(err) - } - - expectTrue(t, len(rows) == len(expectedData)) - for idx, row := range rows { - shared.ExpectEqual(t, row, expectedData[idx]) - } - }) - - t.Run("Clear watched storage slots", func(t *testing.T) { - expectedData := []res{} - - ind.ClearWatchedAddresses(sdtypes.WatchedStorageSlot) - - rows := []res{} - err = db.Select(&rows, pgStr) - if err != nil { - t.Fatal(err) - } - - expectTrue(t, len(rows) == len(expectedData)) - for idx, row := range rows { - shared.ExpectEqual(t, row, expectedData[idx]) - } - }) - - t.Run("Clear watched storage slots (empty table)", func(t *testing.T) { - expectedData := []res{} - - ind.ClearWatchedAddresses(sdtypes.WatchedStorageSlot) + ind.ClearWatchedAddresses() rows := []res{} err = db.Select(&rows, pgStr) diff --git a/statediff/service.go b/statediff/service.go index edff82d67..ac6739ed6 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -55,6 +55,7 @@ const ( defaultRetryLimit = 3 // default retry limit once deadlock is detected. deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html typeAssertionFailed = "type assertion failed" + unexpectedOperation = "unexpected operation" ) var writeLoopParams = ParamsWithMutex{ @@ -734,9 +735,8 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo return err } -// Performs one of following operations on the watched addresses | storage slots in writeLoopParams and the db: -// AddAddresses | RemoveAddresses | SetAddresses | ClearAddresses -// AddStorageSlots | RemoveStorageSlots | SetStorageSlots | ClearStorageSlots +// Performs one of following operations on the watched addresses in writeLoopParams and the db: +// Add | Remove | Set | Clear func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error { // lock writeLoopParams for a write writeLoopParams.Lock() @@ -746,7 +746,7 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg currentBlockNumber := sds.BlockChain.CurrentBlock().Number() switch operation { - case AddAddresses: + case Add: // filter out args having an already watched address with a warning filteredArgs, ok := funk.Filter(args, func(arg WatchAddressArg) bool { if funk.Contains(writeLoopParams.WatchedAddresses, common.HexToAddress(arg.Address)) { @@ -756,7 +756,7 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg return true }).([]WatchAddressArg) if !ok { - return fmt.Errorf("AddAddresses: filtered args %s", typeAssertionFailed) + return fmt.Errorf("Add: filtered args %s", typeAssertionFailed) } // get addresses from the filtered args @@ -764,60 +764,60 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg return common.HexToAddress(arg.Address) }).([]common.Address) if !ok { - return fmt.Errorf("AddAddresses: filtered addresses %s", typeAssertionFailed) + return fmt.Errorf("Add: filtered addresses %s", typeAssertionFailed) } // update the db - err := sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber, WatchedAddress) + err := sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) if err != nil { return err } // update in-memory params writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...) - case RemoveAddresses: + case Remove: // get addresses from args argAddresses, ok := funk.Map(args, func(arg WatchAddressArg) common.Address { return common.HexToAddress(arg.Address) }).([]common.Address) if !ok { - return fmt.Errorf("RemoveAddresses: mapped addresses %s", typeAssertionFailed) + return fmt.Errorf("Remove: mapped addresses %s", typeAssertionFailed) } // remove the provided addresses from currently watched addresses addresses, ok := funk.Subtract(writeLoopParams.WatchedAddresses, argAddresses).([]common.Address) if !ok { - return fmt.Errorf("RemoveAddresses: filtered addresses %s", typeAssertionFailed) + return fmt.Errorf("Remove: filtered addresses %s", typeAssertionFailed) } // update the db - err := sds.indexer.RemoveWatchedAddresses(args, WatchedAddress) + err := sds.indexer.RemoveWatchedAddresses(args) if err != nil { return err } // update in-memory params writeLoopParams.WatchedAddresses = addresses - case SetAddresses: + case Set: // get addresses from args argAddresses, ok := funk.Map(args, func(arg WatchAddressArg) common.Address { return common.HexToAddress(arg.Address) }).([]common.Address) if !ok { - return fmt.Errorf("SetAddresses: mapped addresses %s", typeAssertionFailed) + return fmt.Errorf("Set: mapped addresses %s", typeAssertionFailed) } // update the db - err := sds.indexer.SetWatchedAddresses(args, currentBlockNumber, WatchedAddress) + err := sds.indexer.SetWatchedAddresses(args, currentBlockNumber) if err != nil { return err } // update in-memory params writeLoopParams.WatchedAddresses = argAddresses - case ClearAddresses: + case Clear: // update the db - err := sds.indexer.ClearWatchedAddresses(WatchedAddress) + err := sds.indexer.ClearWatchedAddresses() if err != nil { return err } @@ -825,85 +825,8 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg // update in-memory params writeLoopParams.WatchedAddresses = []common.Address{} - case AddStorageSlots: - // filter out args having an already watched storage slot with a warning - filteredArgs, ok := funk.Filter(args, func(arg WatchAddressArg) bool { - if funk.Contains(writeLoopParams.WatchedStorageSlots, common.HexToHash(arg.Address)) { - log.Warn("StorageSlot already being watched", "address", arg.Address) - return false - } - return true - }).([]WatchAddressArg) - if !ok { - return fmt.Errorf("AddStorageSlots: filtered args %s", typeAssertionFailed) - } - - // get storage slots from the filtered args - filteredStorageSlots, ok := funk.Map(filteredArgs, func(arg WatchAddressArg) common.Hash { - return common.HexToHash(arg.Address) - }).([]common.Hash) - if !ok { - return fmt.Errorf("AddStorageSlots: filtered storage slots %s", typeAssertionFailed) - } - - // update the db - err := sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber, WatchedStorageSlot) - if err != nil { - return err - } - - // update in-memory params - writeLoopParams.WatchedStorageSlots = append(writeLoopParams.WatchedStorageSlots, filteredStorageSlots...) - case RemoveStorageSlots: - // get storage slots from args - argStorageSlots, ok := funk.Map(args, func(arg WatchAddressArg) common.Hash { - return common.HexToHash(arg.Address) - }).([]common.Hash) - if !ok { - return fmt.Errorf("RemoveStorageSlots: mapped storage slots %s", typeAssertionFailed) - } - - // remove the provided storage slots from currently watched storage slots - storageSlots, ok := funk.Subtract(writeLoopParams.WatchedStorageSlots, argStorageSlots).([]common.Hash) - if !ok { - return fmt.Errorf("RemoveStorageSlots: filtered storage slots %s", typeAssertionFailed) - } - - // update the db - err := sds.indexer.RemoveWatchedAddresses(args, WatchedStorageSlot) - if err != nil { - return err - } - - // update in-memory params - writeLoopParams.WatchedStorageSlots = storageSlots - case SetStorageSlots: - // get storage slots from args - argStorageSlots, ok := funk.Map(args, func(arg WatchAddressArg) common.Hash { - return common.HexToHash(arg.Address) - }).([]common.Hash) - if !ok { - return fmt.Errorf("SetStorageSlots: mapped storage slots %s", typeAssertionFailed) - } - - // update the db - err := sds.indexer.SetWatchedAddresses(args, currentBlockNumber, WatchedStorageSlot) - if err != nil { - return err - } - - // update in-memory params - writeLoopParams.WatchedStorageSlots = argStorageSlots - case ClearStorageSlots: - err := sds.indexer.ClearWatchedAddresses(WatchedStorageSlot) - if err != nil { - return err - } - - writeLoopParams.WatchedStorageSlots = []common.Hash{} - default: - return fmt.Errorf("Unexpected operation %s", operation) + return fmt.Errorf("%s %s", unexpectedOperation, operation) } return nil diff --git a/statediff/testhelpers/mocks/indexer.go b/statediff/testhelpers/mocks/indexer.go index 89459b558..90ea40ca0 100644 --- a/statediff/testhelpers/mocks/indexer.go +++ b/statediff/testhelpers/mocks/indexer.go @@ -42,18 +42,18 @@ func (sdi *Indexer) PushCodeAndCodeHash(tx *indexer.BlockTx, codeAndCodeHash sdt func (sdi *Indexer) ReportDBMetrics(delay time.Duration, quit <-chan bool) {} -func (sdi *Indexer) InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int, kind sdtypes.WatchedAddressType) error { +func (sdi *Indexer) InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error { return nil } -func (sdi *Indexer) RemoveWatchedAddresses(addresses []sdtypes.WatchAddressArg, kind sdtypes.WatchedAddressType) error { +func (sdi *Indexer) RemoveWatchedAddresses(addresses []sdtypes.WatchAddressArg) error { return nil } -func (sdi *Indexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int, kind sdtypes.WatchedAddressType) error { +func (sdi *Indexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { return nil } -func (sdi *Indexer) ClearWatchedAddresses(kind sdtypes.WatchedAddressType) error { +func (sdi *Indexer) ClearWatchedAddresses() error { return nil } diff --git a/statediff/testhelpers/mocks/service.go b/statediff/testhelpers/mocks/service.go index 1e5d3c1ba..61513d75c 100644 --- a/statediff/testhelpers/mocks/service.go +++ b/statediff/testhelpers/mocks/service.go @@ -37,7 +37,10 @@ import ( sdtypes "github.com/ethereum/go-ethereum/statediff/types" ) -var typeAssertionFailed = "type assertion failed" +var ( + typeAssertionFailed = "type assertion failed" + unexpectedOperation = "unexpected operation" +) // MockStateDiffService is a mock state diff service type MockStateDiffService struct { @@ -349,7 +352,7 @@ func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, currentBlockNumber := sds.BlockChain.CurrentBlock().Number() switch operation { - case statediff.AddAddresses: + case statediff.Add: // filter out args having an already watched address with a warning filteredArgs, ok := funk.Filter(args, func(arg sdtypes.WatchAddressArg) bool { if funk.Contains(sds.writeLoopParams.WatchedAddresses, common.HexToAddress(arg.Address)) { @@ -359,7 +362,7 @@ func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, return true }).([]sdtypes.WatchAddressArg) if !ok { - return fmt.Errorf("AddAddresses: filtered args %s", typeAssertionFailed) + return fmt.Errorf("Add: filtered args %s", typeAssertionFailed) } // get addresses from the filtered args @@ -367,60 +370,60 @@ func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, return common.HexToAddress(arg.Address) }).([]common.Address) if !ok { - return fmt.Errorf("AddAddresses: filtered addresses %s", typeAssertionFailed) + return fmt.Errorf("Add: filtered addresses %s", typeAssertionFailed) } // update the db - err := sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber, sdtypes.WatchedAddress) + err := sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) if err != nil { return err } // update in-memory params sds.writeLoopParams.WatchedAddresses = append(sds.writeLoopParams.WatchedAddresses, filteredAddresses...) - case statediff.RemoveAddresses: + case statediff.Remove: // get addresses from args argAddresses, ok := funk.Map(args, func(arg sdtypes.WatchAddressArg) common.Address { return common.HexToAddress(arg.Address) }).([]common.Address) if !ok { - return fmt.Errorf("RemoveAddresses: mapped addresses %s", typeAssertionFailed) + return fmt.Errorf("Remove: mapped addresses %s", typeAssertionFailed) } // remove the provided addresses from currently watched addresses addresses, ok := funk.Subtract(sds.writeLoopParams.WatchedAddresses, argAddresses).([]common.Address) if !ok { - return fmt.Errorf("RemoveAddresses: filtered addresses %s", typeAssertionFailed) + return fmt.Errorf("Remove: filtered addresses %s", typeAssertionFailed) } // update the db - err := sds.Indexer.RemoveWatchedAddresses(args, sdtypes.WatchedAddress) + err := sds.Indexer.RemoveWatchedAddresses(args) if err != nil { return err } // update in-memory params sds.writeLoopParams.WatchedAddresses = addresses - case statediff.SetAddresses: + case statediff.Set: // get addresses from args argAddresses, ok := funk.Map(args, func(arg sdtypes.WatchAddressArg) common.Address { return common.HexToAddress(arg.Address) }).([]common.Address) if !ok { - return fmt.Errorf("SetAddresses: mapped addresses %s", typeAssertionFailed) + return fmt.Errorf("Set: mapped addresses %s", typeAssertionFailed) } // update the db - err := sds.Indexer.SetWatchedAddresses(args, currentBlockNumber, sdtypes.WatchedAddress) + err := sds.Indexer.SetWatchedAddresses(args, currentBlockNumber) if err != nil { return err } // update in-memory params sds.writeLoopParams.WatchedAddresses = argAddresses - case statediff.ClearAddresses: + case statediff.Clear: // update the db - err := sds.Indexer.ClearWatchedAddresses(sdtypes.WatchedAddress) + err := sds.Indexer.ClearWatchedAddresses() if err != nil { return err } @@ -428,85 +431,8 @@ func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, // update in-memory params sds.writeLoopParams.WatchedAddresses = []common.Address{} - case statediff.AddStorageSlots: - // filter out args having an already watched storage slot with a warning - filteredArgs, ok := funk.Filter(args, func(arg sdtypes.WatchAddressArg) bool { - if funk.Contains(sds.writeLoopParams.WatchedStorageSlots, common.HexToHash(arg.Address)) { - log.Warn("StorageSlot already being watched", "address", arg.Address) - return false - } - return true - }).([]sdtypes.WatchAddressArg) - if !ok { - return fmt.Errorf("AddStorageSlots: filtered args %s", typeAssertionFailed) - } - - // get storage slots from the filtered args - filteredStorageSlots, ok := funk.Map(filteredArgs, func(arg sdtypes.WatchAddressArg) common.Hash { - return common.HexToHash(arg.Address) - }).([]common.Hash) - if !ok { - return fmt.Errorf("AddStorageSlots: filtered storage slots %s", typeAssertionFailed) - } - - // update the db - err := sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber, sdtypes.WatchedStorageSlot) - if err != nil { - return err - } - - // update in-memory params - sds.writeLoopParams.WatchedStorageSlots = append(sds.writeLoopParams.WatchedStorageSlots, filteredStorageSlots...) - case statediff.RemoveStorageSlots: - // get storage slots from args - argStorageSlots, ok := funk.Map(args, func(arg sdtypes.WatchAddressArg) common.Hash { - return common.HexToHash(arg.Address) - }).([]common.Hash) - if !ok { - return fmt.Errorf("RemoveStorageSlots: mapped storage slots %s", typeAssertionFailed) - } - - // remove the provided storage slots from currently watched storage slots - storageSlots, ok := funk.Subtract(sds.writeLoopParams.WatchedStorageSlots, argStorageSlots).([]common.Hash) - if !ok { - return fmt.Errorf("RemoveStorageSlots: filtered storage slots %s", typeAssertionFailed) - } - - // update the db - err := sds.Indexer.RemoveWatchedAddresses(args, sdtypes.WatchedStorageSlot) - if err != nil { - return err - } - - // update in-memory params - sds.writeLoopParams.WatchedStorageSlots = storageSlots - case statediff.SetStorageSlots: - // get storage slots from args - argStorageSlots, ok := funk.Map(args, func(arg sdtypes.WatchAddressArg) common.Hash { - return common.HexToHash(arg.Address) - }).([]common.Hash) - if !ok { - return fmt.Errorf("SetStorageSlots: mapped storage slots %s", typeAssertionFailed) - } - - // update the db - err := sds.Indexer.SetWatchedAddresses(args, currentBlockNumber, sdtypes.WatchedStorageSlot) - if err != nil { - return err - } - - // update in-memory params - sds.writeLoopParams.WatchedStorageSlots = argStorageSlots - case statediff.ClearStorageSlots: - err := sds.Indexer.ClearWatchedAddresses(sdtypes.WatchedStorageSlot) - if err != nil { - return err - } - - sds.writeLoopParams.WatchedStorageSlots = []common.Hash{} - default: - return fmt.Errorf("Unexpected operation %s", operation) + return fmt.Errorf("%s %s", unexpectedOperation, operation) } return nil diff --git a/statediff/testhelpers/mocks/service_test.go b/statediff/testhelpers/mocks/service_test.go index bc64e2a01..3af1a4b11 100644 --- a/statediff/testhelpers/mocks/service_test.go +++ b/statediff/testhelpers/mocks/service_test.go @@ -28,7 +28,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/statediff" @@ -273,23 +272,6 @@ func testWatchAddressAPI(t *testing.T) { contract3CreatedAt = uint64(3) contract4CreatedAt = uint64(4) - slot1 = common.HexToHash("1") - slot2 = common.HexToHash("2") - slot3 = common.HexToHash("3") - slot4 = common.HexToHash("4") - slot1StorageKey = crypto.Keccak256Hash(slot1.Bytes()) - slot2StorageKey = crypto.Keccak256Hash(slot2.Bytes()) - slot3StorageKey = crypto.Keccak256Hash(slot3.Bytes()) - slot4StorageKey = crypto.Keccak256Hash(slot4.Bytes()) - slot1StorageKeyHex = crypto.Keccak256Hash(slot1.Bytes()).Hex() - slot2StorageKeyHex = crypto.Keccak256Hash(slot2.Bytes()).Hex() - slot3StorageKeyHex = crypto.Keccak256Hash(slot3.Bytes()).Hex() - slot4StorageKeyHex = crypto.Keccak256Hash(slot4.Bytes()).Hex() - slot1CreatedAt = uint64(1) - slot2CreatedAt = uint64(2) - slot3CreatedAt = uint64(3) - slot4CreatedAt = uint64(4) - args1 = []sdtypes.WatchAddressArg{ { Address: contract1Address, @@ -424,141 +406,6 @@ func testWatchAddressAPI(t *testing.T) { expectedParams9 = statediff.Params{ WatchedAddresses: []common.Address{}, } - - args10 = []sdtypes.WatchAddressArg{ - { - Address: slot1StorageKeyHex, - CreatedAt: slot1CreatedAt, - }, - { - Address: slot2StorageKeyHex, - CreatedAt: slot2CreatedAt, - }, - } - startingParams10 = statediff.Params{ - WatchedStorageSlots: []common.Hash{}, - } - expectedParams10 = statediff.Params{ - WatchedStorageSlots: []common.Hash{ - slot1StorageKey, - slot2StorageKey, - }, - } - - args11 = []sdtypes.WatchAddressArg{ - { - Address: slot3StorageKeyHex, - CreatedAt: slot3CreatedAt, - }, - { - Address: slot2StorageKeyHex, - CreatedAt: slot2CreatedAt, - }, - } - startingParams11 = expectedParams10 - expectedParams11 = statediff.Params{ - WatchedStorageSlots: []common.Hash{ - slot1StorageKey, - slot2StorageKey, - slot3StorageKey, - }, - } - - args12 = []sdtypes.WatchAddressArg{ - { - Address: slot3StorageKeyHex, - CreatedAt: slot3CreatedAt, - }, - { - Address: slot2StorageKeyHex, - CreatedAt: slot2CreatedAt, - }, - } - startingParams12 = expectedParams11 - expectedParams12 = statediff.Params{ - WatchedStorageSlots: []common.Hash{ - slot1StorageKey, - }, - } - - args13 = []sdtypes.WatchAddressArg{ - { - Address: slot1StorageKeyHex, - CreatedAt: slot1CreatedAt, - }, - { - Address: slot2StorageKeyHex, - CreatedAt: slot2CreatedAt, - }, - } - startingParams13 = expectedParams12 - expectedParams13 = statediff.Params{ - WatchedStorageSlots: []common.Hash{}, - } - - args14 = []sdtypes.WatchAddressArg{ - { - Address: slot1StorageKeyHex, - CreatedAt: slot1CreatedAt, - }, - { - Address: slot2StorageKeyHex, - CreatedAt: slot2CreatedAt, - }, - { - Address: slot3StorageKeyHex, - CreatedAt: slot3CreatedAt, - }, - } - startingParams14 = expectedParams13 - expectedParams14 = statediff.Params{ - WatchedStorageSlots: []common.Hash{ - slot1StorageKey, - slot2StorageKey, - slot3StorageKey, - }, - } - - args15 = []sdtypes.WatchAddressArg{ - { - Address: slot4StorageKeyHex, - CreatedAt: slot4CreatedAt, - }, - { - Address: slot2StorageKeyHex, - CreatedAt: slot2CreatedAt, - }, - { - Address: slot3StorageKeyHex, - CreatedAt: slot3CreatedAt, - }, - } - startingParams15 = expectedParams14 - expectedParams15 = statediff.Params{ - WatchedStorageSlots: []common.Hash{ - slot4StorageKey, - slot2StorageKey, - slot3StorageKey, - }, - } - - args16 = []sdtypes.WatchAddressArg{} - startingParams16 = expectedParams15 - expectedParams16 = statediff.Params{ - WatchedStorageSlots: []common.Hash{}, - } - - args17 = []sdtypes.WatchAddressArg{} - startingParams17 = expectedParams15 - expectedParams17 = statediff.Params{ - WatchedStorageSlots: []common.Hash{}, - } - - args18 = []sdtypes.WatchAddressArg{} - startingParams18 = expectedParams17 - expectedParams18 = statediff.Params{ - WatchedStorageSlots: []common.Hash{}, - } ) tests := []struct { @@ -572,7 +419,7 @@ func testWatchAddressAPI(t *testing.T) { // addresses tests { "testAddAddresses", - statediff.AddAddresses, + statediff.Add, args1, startingParams1, expectedParams1, @@ -580,7 +427,7 @@ func testWatchAddressAPI(t *testing.T) { }, { "testAddAddressesSomeWatched", - statediff.AddAddresses, + statediff.Add, args2, startingParams2, expectedParams2, @@ -588,7 +435,7 @@ func testWatchAddressAPI(t *testing.T) { }, { "testRemoveAddresses", - statediff.RemoveAddresses, + statediff.Remove, args3, startingParams3, expectedParams3, @@ -596,7 +443,7 @@ func testWatchAddressAPI(t *testing.T) { }, { "testRemoveAddressesSomeWatched", - statediff.RemoveAddresses, + statediff.Remove, args4, startingParams4, expectedParams4, @@ -604,7 +451,7 @@ func testWatchAddressAPI(t *testing.T) { }, { "testSetAddresses", - statediff.SetAddresses, + statediff.Set, args5, startingParams5, expectedParams5, @@ -612,7 +459,7 @@ func testWatchAddressAPI(t *testing.T) { }, { "testSetAddressesSomeWatched", - statediff.SetAddresses, + statediff.Set, args6, startingParams6, expectedParams6, @@ -620,7 +467,7 @@ func testWatchAddressAPI(t *testing.T) { }, { "testSetAddressesEmtpyArgs", - statediff.SetAddresses, + statediff.Set, args7, startingParams7, expectedParams7, @@ -628,7 +475,7 @@ func testWatchAddressAPI(t *testing.T) { }, { "testClearAddresses", - statediff.ClearAddresses, + statediff.Clear, args8, startingParams8, expectedParams8, @@ -636,95 +483,21 @@ func testWatchAddressAPI(t *testing.T) { }, { "testClearAddressesEmpty", - statediff.ClearAddresses, + statediff.Clear, args9, startingParams9, expectedParams9, nil, }, - // storage slots tests - { - "testAddStorageSlots", - statediff.AddStorageSlots, - args10, - startingParams10, - expectedParams10, - nil, - }, - { - "testAddStorageSlotsSomeWatched", - statediff.AddStorageSlots, - args11, - startingParams11, - expectedParams11, - nil, - }, - { - "testRemoveStorageSlots", - statediff.RemoveStorageSlots, - args12, - startingParams12, - expectedParams12, - nil, - }, - { - "testRemoveStorageSlotsSomeWatched", - statediff.RemoveStorageSlots, - args13, - startingParams13, - expectedParams13, - nil, - }, - { - "testSetStorageSlots", - statediff.SetStorageSlots, - args14, - startingParams14, - expectedParams14, - nil, - }, - { - "testSetStorageSlotsSomeWatched", - statediff.SetStorageSlots, - args15, - startingParams15, - expectedParams15, - nil, - }, - { - "testSetStorageSlotsEmtpyArgs", - statediff.SetStorageSlots, - args16, - startingParams16, - expectedParams16, - nil, - }, - { - "testClearStorageSlots", - statediff.ClearStorageSlots, - args17, - startingParams17, - expectedParams17, - nil, - }, - { - "testClearStorageSlotsEmpty", - statediff.ClearStorageSlots, - args18, - startingParams18, - expectedParams18, - nil, - }, - // invalid args { "testInvalidOperation", "WrongOp", - args18, - startingParams18, + args9, + startingParams9, statediff.Params{}, - fmt.Errorf("Unexpected operation WrongOp"), + fmt.Errorf("%s WrongOp", unexpectedOperation), }, } diff --git a/statediff/types.go b/statediff/types.go index 69eaf0d0a..9e90ecabc 100644 --- a/statediff/types.go +++ b/statediff/types.go @@ -51,7 +51,6 @@ type Params struct { IncludeTD bool IncludeCode bool WatchedAddresses []common.Address - WatchedStorageSlots []common.Hash } // ParamsWithMutex allows to lock the parameters while they are being updated | read from @@ -123,13 +122,8 @@ type accountWrapper struct { type OperationType string const ( - AddAddresses OperationType = "AddAddresses" - RemoveAddresses OperationType = "RemoveAddresses" - SetAddresses OperationType = "SetAddresses" - ClearAddresses OperationType = "ClearAddresses" - - AddStorageSlots OperationType = "AddStorageSlots" - RemoveStorageSlots OperationType = "RemoveStorageSlots" - SetStorageSlots OperationType = "SetStorageSlots" - ClearStorageSlots OperationType = "ClearStorageSlots" + Add OperationType = "Add" + Remove OperationType = "Remove" + Set OperationType = "Set" + Clear OperationType = "Clear" ) diff --git a/statediff/types/types.go b/statediff/types/types.go index 18d9d0b31..bdda68c65 100644 --- a/statediff/types/types.go +++ b/statediff/types/types.go @@ -77,26 +77,7 @@ type CodeSink func(CodeAndCodeHash) error // WatchAddressArg is a arg type for WatchAddress API type WatchAddressArg struct { - // Address represents common.Address | common.Hash + // Address represents common.Address Address string CreatedAt uint64 } - -// WatchedAddressType for denoting watched: address | storage slot -type WatchedAddressType string - -const ( - WatchedAddress WatchedAddressType = "WatchedAddress" - WatchedStorageSlot WatchedAddressType = "WatchedStorageSlot" -) - -func (n WatchedAddressType) Int() int { - switch n { - case WatchedAddress: - return 0 - case WatchedStorageSlot: - return 1 - default: - return -1 - } -} -- 2.45.2 From 6cb4731d8d3ac14d12e448015fa4a676aba87137 Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Fri, 4 Feb 2022 15:00:31 +0530 Subject: [PATCH 12/17] Changes for updated database schema --- statediff/helpers.go | 2 +- statediff/indexer/indexer.go | 10 +++++----- statediff/indexer/indexer_test.go | 3 +-- statediff/indexer/test_helpers.go | 2 +- statediff/testhelpers/mocks/service_test.go | 5 ++++- 5 files changed, 12 insertions(+), 10 deletions(-) diff --git a/statediff/helpers.go b/statediff/helpers.go index e2cf7365a..73ac084e8 100644 --- a/statediff/helpers.go +++ b/statediff/helpers.go @@ -82,7 +82,7 @@ func loadWatchedAddresses(db *postgres.DB) error { } var watched []Watched - pgStr := "SELECT address FROM eth.watched_addresses" + pgStr := "SELECT address FROM eth_meta.watched_addresses" err := db.Select(&watched, pgStr) if err != nil { return fmt.Errorf("error loading watched addresses: %v", err) diff --git a/statediff/indexer/indexer.go b/statediff/indexer/indexer.go index 70ab1339a..58e0bafc5 100644 --- a/statediff/indexer/indexer.go +++ b/statediff/indexer/indexer.go @@ -565,7 +565,7 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressA defer tx.Rollback() for _, arg := range args { - _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, + _, err = tx.Exec(`INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, arg.Address, arg.CreatedAt, currentBlockNumber.Uint64()) if err != nil { return fmt.Errorf("error inserting watched_addresses entry: %v", err) @@ -589,7 +589,7 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressA defer tx.Rollback() for _, arg := range args { - _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, arg.Address) + _, err = tx.Exec(`DELETE FROM eth_meta.watched_addresses WHERE address = $1`, arg.Address) if err != nil { return fmt.Errorf("error removing watched_addresses entry: %v", err) } @@ -611,13 +611,13 @@ func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, } defer tx.Rollback() - _, err = tx.Exec(`DELETE FROM eth.watched_addresses`) + _, err = tx.Exec(`DELETE FROM eth_meta.watched_addresses`) if err != nil { return fmt.Errorf("error setting watched_addresses table: %v", err) } for _, arg := range args { - _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, + _, err = tx.Exec(`INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, arg.Address, arg.CreatedAt, currentBlockNumber.Uint64()) if err != nil { return fmt.Errorf("error setting watched_addresses table: %v", err) @@ -634,7 +634,7 @@ func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, // ClearWatchedAddresses clears all the watched addresses from the database func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { - _, err := sdi.dbWriter.db.Exec(`DELETE FROM eth.watched_addresses`) + _, err := sdi.dbWriter.db.Exec(`DELETE FROM eth_meta.watched_addresses`) if err != nil { return fmt.Errorf("error clearing watched_addresses table: %v", err) } diff --git a/statediff/indexer/indexer_test.go b/statediff/indexer/indexer_test.go index f71e2ead2..011738692 100644 --- a/statediff/indexer/indexer_test.go +++ b/statediff/indexer/indexer_test.go @@ -688,9 +688,8 @@ func TestWatchAddressMethods(t *testing.T) { WatchedAt uint64 `db:"watched_at"` LastFilledAt uint64 `db:"last_filled_at"` } - pgStr := "SELECT * FROM eth.watched_addresses" + pgStr := "SELECT * FROM eth_meta.watched_addresses" - // Watched addresses t.Run("Insert watched addresses", func(t *testing.T) { args := []sdtypes.WatchAddressArg{ { diff --git a/statediff/indexer/test_helpers.go b/statediff/indexer/test_helpers.go index 43bb3023d..9745cfd02 100644 --- a/statediff/indexer/test_helpers.go +++ b/statediff/indexer/test_helpers.go @@ -53,7 +53,7 @@ func TearDownDB(t *testing.T, db *postgres.DB) { if err != nil { t.Fatal(err) } - _, err = tx.Exec(`DELETE FROM eth.watched_addresses`) + _, err = tx.Exec(`DELETE FROM eth_meta.watched_addresses`) if err != nil { t.Fatal(err) } diff --git a/statediff/testhelpers/mocks/service_test.go b/statediff/testhelpers/mocks/service_test.go index 3af1a4b11..f1be5e317 100644 --- a/statediff/testhelpers/mocks/service_test.go +++ b/statediff/testhelpers/mocks/service_test.go @@ -262,6 +262,7 @@ func testWatchAddressAPI(t *testing.T) { Indexer: &mockIndexer, } + // test data var ( contract1Address = "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F" contract2Address = "0x6Eb7e5C66DB8af2E96159AC440cbc8CDB7fbD26B" @@ -416,7 +417,6 @@ func testWatchAddressAPI(t *testing.T) { expectedParams statediff.Params expectedErr error }{ - // addresses tests { "testAddAddresses", statediff.Add, @@ -502,10 +502,12 @@ func testWatchAddressAPI(t *testing.T) { } for _, test := range tests { + // set indexing params mockService.writeLoopParams = statediff.ParamsWithMutex{ Params: test.startingParams, } + // make the API call to change watched addresses err := mockService.WatchAddress(test.operation, test.args) if test.expectedErr != nil { if err.Error() != test.expectedErr.Error() { @@ -519,6 +521,7 @@ func testWatchAddressAPI(t *testing.T) { t.Error(err) } + // check updated indexing params updatedParams := mockService.writeLoopParams.Params if !reflect.DeepEqual(updatedParams, test.expectedParams) { t.Logf("Test failed: %s", test.name) -- 2.45.2 From 3c6aa6a9ccfdef44b753ade5f76944bc7af71491 Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Mon, 7 Feb 2022 12:26:33 +0530 Subject: [PATCH 13/17] Store pre-computed leaf keys for watched addresses in a map --- statediff/builder.go | 28 +++++----- statediff/builder_test.go | 2 + statediff/helpers.go | 16 ++++++ statediff/service.go | 60 +++++++++++++-------- statediff/service_test.go | 5 ++ statediff/testhelpers/mocks/service.go | 38 +++++++------ statediff/testhelpers/mocks/service_test.go | 2 + statediff/types.go | 18 +++++-- 8 files changed, 108 insertions(+), 61 deletions(-) diff --git a/statediff/builder.go b/statediff/builder.go index aee8f71ff..46546c1d5 100644 --- a/statediff/builder.go +++ b/statediff/builder.go @@ -202,7 +202,7 @@ func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args StateRoots, pa // a map of their leafkey to all the accounts that were touched and exist at A diffAccountsAtA, err := sdb.deletedOrUpdatedState( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - diffPathsAtB, params.WatchedAddresses, output) + diffPathsAtB, params.watchedAddressesLeafKeys, output) if err != nil { return fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err) } @@ -247,7 +247,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args StateRoots, // and a slice of all the paths for the nodes in both of the above sets diffAccountsAtB, diffPathsAtB, err := sdb.createdAndUpdatedState( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - params.WatchedAddresses) + params.watchedAddressesLeafKeys) if err != nil { return fmt.Errorf("error collecting createdAndUpdatedNodes: %v", err) } @@ -256,7 +256,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args StateRoots, // a map of their leafkey to all the accounts that were touched and exist at A diffAccountsAtA, err := sdb.deletedOrUpdatedState( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - diffPathsAtB, params.WatchedAddresses, output) + diffPathsAtB, params.watchedAddressesLeafKeys, output) if err != nil { return fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err) } @@ -289,7 +289,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args StateRoots, // createdAndUpdatedState returns // a mapping of their leafkeys to all the accounts that exist in a different state at B than A // and a slice of the paths for all of the nodes included in both -func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddresses []common.Address) (AccountMap, map[string]bool, error) { +func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddressesLeafKeys map[common.Hash]struct{}) (AccountMap, map[string]bool, error) { diffPathsAtB := make(map[string]bool) diffAcountsAtB := make(AccountMap) it, _ := trie.NewDifferenceIterator(a, b) @@ -313,7 +313,7 @@ func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddres valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedAddress(watchedAddresses, leafKey) { + if isWatchedAddress(watchedAddressesLeafKeys, leafKey) { diffAcountsAtB[common.Bytes2Hex(leafKey)] = accountWrapper{ NodeType: node.NodeType, Path: node.Path, @@ -386,7 +386,7 @@ func (sdb *builder) createdAndUpdatedStateWithIntermediateNodes(a, b trie.NodeIt // deletedOrUpdatedState returns a slice of all the pathes that are emptied at B // and a mapping of their leafkeys to all the accounts that exist in a different state at A than B -func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB map[string]bool, watchedAddresses []common.Address, output StateNodeSink) (AccountMap, error) { +func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB map[string]bool, watchedAddressesLeafKeys map[common.Hash]struct{}, output StateNodeSink) (AccountMap, error) { diffAccountAtA := make(AccountMap) it, _ := trie.NewDifferenceIterator(b, a) for it.Next(true) { @@ -409,7 +409,7 @@ func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB m valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedAddress(watchedAddresses, leafKey) { + if isWatchedAddress(watchedAddressesLeafKeys, leafKey) { diffAccountAtA[common.Bytes2Hex(leafKey)] = accountWrapper{ NodeType: node.NodeType, Path: node.Path, @@ -713,16 +713,12 @@ func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB } // isWatchedAddress is used to check if a state account corresponds to one of the addresses the builder is configured to watch -func isWatchedAddress(watchedAddresses []common.Address, stateLeafKey []byte) bool { +func isWatchedAddress(watchedAddressesLeafKeys map[common.Hash]struct{}, stateLeafKey []byte) bool { // If we aren't watching any specific addresses, we are watching everything - if len(watchedAddresses) == 0 { + if len(watchedAddressesLeafKeys) == 0 { return true } - for _, addr := range watchedAddresses { - addrHashKey := crypto.Keccak256(addr.Bytes()) - if bytes.Equal(addrHashKey, stateLeafKey) { - return true - } - } - return false + + _, ok := watchedAddressesLeafKeys[common.BytesToHash(stateLeafKey)] + return ok } diff --git a/statediff/builder_test.go b/statediff/builder_test.go index 741605d41..945e3799d 100644 --- a/statediff/builder_test.go +++ b/statediff/builder_test.go @@ -987,6 +987,7 @@ func TestBuilderWithWatchedAddressList(t *testing.T) { params := statediff.Params{ WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.ContractAddr}, } + params.ComputeWatchedAddressesLeafKeys() builder = statediff.NewBuilder(chain.StateCache()) var tests = []struct { @@ -1566,6 +1567,7 @@ func TestBuilderWithRemovedNonWatchedAccount(t *testing.T) { params := statediff.Params{ WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.Account2Addr}, } + params.ComputeWatchedAddressesLeafKeys() builder = statediff.NewBuilder(chain.StateCache()) var tests = []struct { diff --git a/statediff/helpers.go b/statediff/helpers.go index 73ac084e8..8870855bd 100644 --- a/statediff/helpers.go +++ b/statediff/helpers.go @@ -24,8 +24,11 @@ import ( "sort" "strings" + "github.com/thoas/go-funk" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/statediff/indexer/postgres" + . "github.com/ethereum/go-ethereum/statediff/types" ) func sortKeys(data AccountMap) []string { @@ -96,6 +99,19 @@ func loadWatchedAddresses(db *postgres.DB) error { writeLoopParams.Lock() defer writeLoopParams.Unlock() writeLoopParams.WatchedAddresses = watchedAddresses + writeLoopParams.ComputeWatchedAddressesLeafKeys() return nil } + +// MapWatchAddressArgsToAddresses maps []WatchAddressArg to corresponding []common.Address +func MapWatchAddressArgsToAddresses(args []WatchAddressArg) ([]common.Address, error) { + addresses, ok := funk.Map(args, func(arg WatchAddressArg) common.Address { + return common.HexToAddress(arg.Address) + }).([]common.Address) + if !ok { + return nil, fmt.Errorf(typeAssertionFailed) + } + + return addresses, nil +} diff --git a/statediff/service.go b/statediff/service.go index ac6739ed6..998a8e85a 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -423,6 +423,9 @@ func (sds *Service) streamStateDiff(currentBlock *types.Block, parentRoot common func (sds *Service) StateDiffAt(blockNumber uint64, params Params) (*Payload, error) { currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) log.Info("sending state diff", "block height", blockNumber) + + params.ComputeWatchedAddressesLeafKeys() + if blockNumber == 0 { return sds.processStateDiff(currentBlock, common.Hash{}, params) } @@ -435,6 +438,9 @@ func (sds *Service) StateDiffAt(blockNumber uint64, params Params) (*Payload, er func (sds *Service) StateDiffFor(blockHash common.Hash, params Params) (*Payload, error) { currentBlock := sds.BlockChain.GetBlockByHash(blockHash) log.Info("sending state diff", "block hash", blockHash) + + params.ComputeWatchedAddressesLeafKeys() + if currentBlock.NumberU64() == 0 { return sds.processStateDiff(currentBlock, common.Hash{}, params) } @@ -493,6 +499,9 @@ func (sds *Service) newPayload(stateObject []byte, block *types.Block, params Pa func (sds *Service) StateTrieAt(blockNumber uint64, params Params) (*Payload, error) { currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) log.Info("sending state trie", "block height", blockNumber) + + params.ComputeWatchedAddressesLeafKeys() + return sds.processStateTrie(currentBlock, params) } @@ -515,6 +524,9 @@ func (sds *Service) Subscribe(id rpc.ID, sub chan<- Payload, quitChan chan<- boo if atomic.CompareAndSwapInt32(&sds.subscribers, 0, 1) { log.Info("State diffing subscription received; beginning statediff processing") } + + params.ComputeWatchedAddressesLeafKeys() + // Subscription type is defined as the hash of the rlp-serialized subscription params by, err := rlp.EncodeToBytes(params) if err != nil { @@ -661,6 +673,8 @@ func (sds *Service) StreamCodeAndCodeHash(blockNumber uint64, outChan chan<- Cod // This operation cannot be performed back past the point of db pruning; it requires an archival node // for historical data func (sds *Service) WriteStateDiffAt(blockNumber uint64, params Params) error { + params.ComputeWatchedAddressesLeafKeys() + currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) parentRoot := common.Hash{} if blockNumber != 0 { @@ -674,6 +688,8 @@ func (sds *Service) WriteStateDiffAt(blockNumber uint64, params Params) error { // This operation cannot be performed back past the point of db pruning; it requires an archival node // for historical data func (sds *Service) WriteStateDiffFor(blockHash common.Hash, params Params) error { + params.ComputeWatchedAddressesLeafKeys() + currentBlock := sds.BlockChain.GetBlockByHash(blockHash) parentRoot := common.Hash{} if currentBlock.NumberU64() != 0 { @@ -736,7 +752,7 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo } // Performs one of following operations on the watched addresses in writeLoopParams and the db: -// Add | Remove | Set | Clear +// add | remove | set | clear func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error { // lock writeLoopParams for a write writeLoopParams.Lock() @@ -756,65 +772,66 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg return true }).([]WatchAddressArg) if !ok { - return fmt.Errorf("Add: filtered args %s", typeAssertionFailed) + return fmt.Errorf("add: filtered args %s", typeAssertionFailed) } // get addresses from the filtered args - filteredAddresses, ok := funk.Map(filteredArgs, func(arg WatchAddressArg) common.Address { - return common.HexToAddress(arg.Address) - }).([]common.Address) - if !ok { - return fmt.Errorf("Add: filtered addresses %s", typeAssertionFailed) + filteredAddresses, err := MapWatchAddressArgsToAddresses(filteredArgs) + if err != nil { + return fmt.Errorf("add: filtered addresses %s", err.Error()) } // update the db - err := sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) + err = sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) if err != nil { return err } // update in-memory params writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...) + funk.ForEach(filteredAddresses, func(address common.Address) { + writeLoopParams.watchedAddressesLeafKeys[crypto.Keccak256Hash(address.Bytes())] = struct{}{} + }) case Remove: // get addresses from args - argAddresses, ok := funk.Map(args, func(arg WatchAddressArg) common.Address { - return common.HexToAddress(arg.Address) - }).([]common.Address) - if !ok { - return fmt.Errorf("Remove: mapped addresses %s", typeAssertionFailed) + argAddresses, err := MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("remove: mapped addresses %s", err.Error()) } // remove the provided addresses from currently watched addresses addresses, ok := funk.Subtract(writeLoopParams.WatchedAddresses, argAddresses).([]common.Address) if !ok { - return fmt.Errorf("Remove: filtered addresses %s", typeAssertionFailed) + return fmt.Errorf("remove: filtered addresses %s", typeAssertionFailed) } // update the db - err := sds.indexer.RemoveWatchedAddresses(args) + err = sds.indexer.RemoveWatchedAddresses(args) if err != nil { return err } // update in-memory params writeLoopParams.WatchedAddresses = addresses + funk.ForEach(argAddresses, func(address common.Address) { + delete(writeLoopParams.watchedAddressesLeafKeys, crypto.Keccak256Hash(address.Bytes())) + }) case Set: // get addresses from args - argAddresses, ok := funk.Map(args, func(arg WatchAddressArg) common.Address { - return common.HexToAddress(arg.Address) - }).([]common.Address) - if !ok { - return fmt.Errorf("Set: mapped addresses %s", typeAssertionFailed) + argAddresses, err := MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("set: mapped addresses %s", err.Error()) } // update the db - err := sds.indexer.SetWatchedAddresses(args, currentBlockNumber) + err = sds.indexer.SetWatchedAddresses(args, currentBlockNumber) if err != nil { return err } // update in-memory params writeLoopParams.WatchedAddresses = argAddresses + writeLoopParams.ComputeWatchedAddressesLeafKeys() case Clear: // update the db err := sds.indexer.ClearWatchedAddresses() @@ -824,6 +841,7 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg // update in-memory params writeLoopParams.WatchedAddresses = []common.Address{} + writeLoopParams.ComputeWatchedAddressesLeafKeys() default: return fmt.Errorf("%s %s", unexpectedOperation, operation) diff --git a/statediff/service_test.go b/statediff/service_test.go index ca9a483a5..a3a5ccca4 100644 --- a/statediff/service_test.go +++ b/statediff/service_test.go @@ -144,6 +144,7 @@ func testErrorInChainEventLoop(t *testing.T) { } } + defaultParams.ComputeWatchedAddressesLeafKeys() if !reflect.DeepEqual(builder.Params, defaultParams) { t.Error("Test failure:", t.Name()) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) @@ -195,6 +196,8 @@ func testErrorInBlockLoop(t *testing.T) { } }() service.Loop(eventsChannel) + + defaultParams.ComputeWatchedAddressesLeafKeys() if !reflect.DeepEqual(builder.Params, defaultParams) { t.Error("Test failure:", t.Name()) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) @@ -268,6 +271,8 @@ func testErrorInStateDiffAt(t *testing.T) { if err != nil { t.Error(err) } + + defaultParams.ComputeWatchedAddressesLeafKeys() if !reflect.DeepEqual(builder.Params, defaultParams) { t.Error("Test failure:", t.Name()) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) diff --git a/statediff/testhelpers/mocks/service.go b/statediff/testhelpers/mocks/service.go index 61513d75c..e14565774 100644 --- a/statediff/testhelpers/mocks/service.go +++ b/statediff/testhelpers/mocks/service.go @@ -362,65 +362,62 @@ func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, return true }).([]sdtypes.WatchAddressArg) if !ok { - return fmt.Errorf("Add: filtered args %s", typeAssertionFailed) + return fmt.Errorf("add: filtered args %s", typeAssertionFailed) } // get addresses from the filtered args - filteredAddresses, ok := funk.Map(filteredArgs, func(arg sdtypes.WatchAddressArg) common.Address { - return common.HexToAddress(arg.Address) - }).([]common.Address) - if !ok { - return fmt.Errorf("Add: filtered addresses %s", typeAssertionFailed) + filteredAddresses, err := statediff.MapWatchAddressArgsToAddresses(filteredArgs) + if err != nil { + return fmt.Errorf("add: filtered addresses %s", err.Error()) } // update the db - err := sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) + err = sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) if err != nil { return err } // update in-memory params sds.writeLoopParams.WatchedAddresses = append(sds.writeLoopParams.WatchedAddresses, filteredAddresses...) + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() case statediff.Remove: // get addresses from args - argAddresses, ok := funk.Map(args, func(arg sdtypes.WatchAddressArg) common.Address { - return common.HexToAddress(arg.Address) - }).([]common.Address) - if !ok { - return fmt.Errorf("Remove: mapped addresses %s", typeAssertionFailed) + argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("remove: mapped addresses %s", err.Error()) } // remove the provided addresses from currently watched addresses addresses, ok := funk.Subtract(sds.writeLoopParams.WatchedAddresses, argAddresses).([]common.Address) if !ok { - return fmt.Errorf("Remove: filtered addresses %s", typeAssertionFailed) + return fmt.Errorf("remove: filtered addresses %s", typeAssertionFailed) } // update the db - err := sds.Indexer.RemoveWatchedAddresses(args) + err = sds.Indexer.RemoveWatchedAddresses(args) if err != nil { return err } // update in-memory params sds.writeLoopParams.WatchedAddresses = addresses + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() case statediff.Set: // get addresses from args - argAddresses, ok := funk.Map(args, func(arg sdtypes.WatchAddressArg) common.Address { - return common.HexToAddress(arg.Address) - }).([]common.Address) - if !ok { - return fmt.Errorf("Set: mapped addresses %s", typeAssertionFailed) + argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("set: mapped addresses %s", err.Error()) } // update the db - err := sds.Indexer.SetWatchedAddresses(args, currentBlockNumber) + err = sds.Indexer.SetWatchedAddresses(args, currentBlockNumber) if err != nil { return err } // update in-memory params sds.writeLoopParams.WatchedAddresses = argAddresses + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() case statediff.Clear: // update the db err := sds.Indexer.ClearWatchedAddresses() @@ -430,6 +427,7 @@ func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, // update in-memory params sds.writeLoopParams.WatchedAddresses = []common.Address{} + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() default: return fmt.Errorf("%s %s", unexpectedOperation, operation) diff --git a/statediff/testhelpers/mocks/service_test.go b/statediff/testhelpers/mocks/service_test.go index f1be5e317..dbd059def 100644 --- a/statediff/testhelpers/mocks/service_test.go +++ b/statediff/testhelpers/mocks/service_test.go @@ -506,6 +506,7 @@ func testWatchAddressAPI(t *testing.T) { mockService.writeLoopParams = statediff.ParamsWithMutex{ Params: test.startingParams, } + mockService.writeLoopParams.ComputeWatchedAddressesLeafKeys() // make the API call to change watched addresses err := mockService.WatchAddress(test.operation, test.args) @@ -522,6 +523,7 @@ func testWatchAddressAPI(t *testing.T) { } // check updated indexing params + test.expectedParams.ComputeWatchedAddressesLeafKeys() updatedParams := mockService.writeLoopParams.Params if !reflect.DeepEqual(updatedParams, test.expectedParams) { t.Logf("Test failed: %s", test.name) diff --git a/statediff/types.go b/statediff/types.go index 9e90ecabc..e33193637 100644 --- a/statediff/types.go +++ b/statediff/types.go @@ -26,6 +26,7 @@ import ( "github.com/ethereum/go-ethereum/common" ctypes "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/statediff/types" ) @@ -51,6 +52,15 @@ type Params struct { IncludeTD bool IncludeCode bool WatchedAddresses []common.Address + watchedAddressesLeafKeys map[common.Hash]struct{} +} + +// ComputeWatchedAddressesLeafKeys populates a map with keys (Keccak256Hash) of each of the WatchedAddresses +func (p *Params) ComputeWatchedAddressesLeafKeys() { + p.watchedAddressesLeafKeys = make(map[common.Hash]struct{}, len(p.WatchedAddresses)) + for _, address := range p.WatchedAddresses { + p.watchedAddressesLeafKeys[crypto.Keccak256Hash(address.Bytes())] = struct{}{} + } } // ParamsWithMutex allows to lock the parameters while they are being updated | read from @@ -122,8 +132,8 @@ type accountWrapper struct { type OperationType string const ( - Add OperationType = "Add" - Remove OperationType = "Remove" - Set OperationType = "Set" - Clear OperationType = "Clear" + Add OperationType = "add" + Remove OperationType = "remove" + Set OperationType = "set" + Clear OperationType = "clear" ) -- 2.45.2 From 53913e0cb9a550c523e9ba979740b917f322c587 Mon Sep 17 00:00:00 2001 From: nabarun Date: Wed, 9 Mar 2022 18:07:51 +0530 Subject: [PATCH 14/17] Add test for removal of watched address --- statediff/builder_test.go | 162 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) diff --git a/statediff/builder_test.go b/statediff/builder_test.go index 945e3799d..e7257ef6a 100644 --- a/statediff/builder_test.go +++ b/statediff/builder_test.go @@ -1674,6 +1674,168 @@ func TestBuilderWithRemovedNonWatchedAccount(t *testing.T) { } } +func TestBuilderWithRemovedWatchedAccount(t *testing.T) { + blocks, chain := testhelpers.MakeChain(6, testhelpers.Genesis, testhelpers.TestChainGen) + contractLeafKey = testhelpers.AddressToLeafKey(testhelpers.ContractAddr) + defer chain.Stop() + block3 = blocks[2] + block4 = blocks[3] + block5 = blocks[4] + block6 = blocks[5] + params := statediff.Params{ + WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.ContractAddr}, + } + params.ComputeWatchedAddressesLeafKeys() + builder = statediff.NewBuilder(chain.StateCache()) + + var tests = []struct { + name string + startingArguments statediff.Args + expected *statediff.StateObject + }{ + { + "testBlock4", + statediff.Args{ + OldStateRoot: block3.Root(), + NewStateRoot: block4.Root(), + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), + }, + &statediff.StateObject{ + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), + Nodes: []sdtypes.StateNode{ + { + Path: []byte{'\x06'}, + NodeType: sdtypes.Leaf, + LeafKey: contractLeafKey, + NodeValue: contractAccountAtBlock4LeafNode, + StorageNodes: []sdtypes.StorageNode{ + { + Path: []byte{'\x04'}, + NodeType: sdtypes.Leaf, + LeafKey: slot2StorageKey.Bytes(), + NodeValue: slot2StorageLeafNode, + }, + { + Path: []byte{'\x0b'}, + NodeType: sdtypes.Removed, + LeafKey: slot1StorageKey.Bytes(), + NodeValue: []byte{}, + }, + { + Path: []byte{'\x0c'}, + NodeType: sdtypes.Removed, + LeafKey: slot3StorageKey.Bytes(), + NodeValue: []byte{}, + }, + }, + }, + }, + }, + }, + { + "testBlock5", + statediff.Args{ + OldStateRoot: block4.Root(), + NewStateRoot: block5.Root(), + BlockNumber: block5.Number(), + BlockHash: block5.Hash(), + }, + &statediff.StateObject{ + BlockNumber: block5.Number(), + BlockHash: block5.Hash(), + Nodes: []sdtypes.StateNode{ + { + Path: []byte{'\x06'}, + NodeType: sdtypes.Leaf, + LeafKey: contractLeafKey, + NodeValue: contractAccountAtBlock5LeafNode, + StorageNodes: []sdtypes.StorageNode{ + { + Path: []byte{}, + NodeType: sdtypes.Leaf, + LeafKey: slot0StorageKey.Bytes(), + NodeValue: slot0StorageLeafRootNode, + }, + { + Path: []byte{'\x02'}, + NodeType: sdtypes.Removed, + LeafKey: slot0StorageKey.Bytes(), + NodeValue: []byte{}, + }, + { + Path: []byte{'\x04'}, + NodeType: sdtypes.Removed, + LeafKey: slot2StorageKey.Bytes(), + NodeValue: []byte{}, + }, + }, + }, + { + Path: []byte{'\x0e'}, + NodeType: sdtypes.Leaf, + LeafKey: testhelpers.Account1LeafKey, + NodeValue: account1AtBlock5LeafNode, + StorageNodes: emptyStorage, + }, + }, + }, + }, + { + "testBlock6", + statediff.Args{ + OldStateRoot: block5.Root(), + NewStateRoot: block6.Root(), + BlockNumber: block6.Number(), + BlockHash: block6.Hash(), + }, + &statediff.StateObject{ + BlockNumber: block6.Number(), + BlockHash: block6.Hash(), + Nodes: []sdtypes.StateNode{ + { + Path: []byte{'\x06'}, + NodeType: sdtypes.Removed, + LeafKey: contractLeafKey, + NodeValue: []byte{}, + }, + { + Path: []byte{'\x0e'}, + NodeType: sdtypes.Leaf, + LeafKey: testhelpers.Account1LeafKey, + NodeValue: account1AtBlock6LeafNode, + StorageNodes: emptyStorage, + }, + }, + }, + }, + } + + for _, test := range tests { + diff, err := builder.BuildStateDiffObject(test.startingArguments, params) + if err != nil { + t.Error(err) + } + receivedStateDiffRlp, err := rlp.EncodeToBytes(diff) + if err != nil { + t.Error(err) + } + + expectedStateDiffRlp, err := rlp.EncodeToBytes(test.expected) + if err != nil { + t.Error(err) + } + + sort.Slice(receivedStateDiffRlp, func(i, j int) bool { return receivedStateDiffRlp[i] < receivedStateDiffRlp[j] }) + sort.Slice(expectedStateDiffRlp, func(i, j int) bool { return expectedStateDiffRlp[i] < expectedStateDiffRlp[j] }) + if !bytes.Equal(receivedStateDiffRlp, expectedStateDiffRlp) { + t.Logf("Test failed: %s", test.name) + t.Errorf("actual state diff: %+v\r\n\r\n\r\nexpected state diff: %+v", diff, test.expected) + } + } +} + var ( slot00StorageValue = common.Hex2Bytes("9471562b71999873db5b286df957af199ec94617f7") // prefixed TestBankAddress -- 2.45.2 From 2d790cea0345ffbbbae68fb5845e8d852127b54d Mon Sep 17 00:00:00 2001 From: nabarun Date: Thu, 10 Mar 2022 18:37:10 +0530 Subject: [PATCH 15/17] Refactor database transaction defer to match pattern elsewhere --- statediff/indexer/indexer.go | 60 +++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 24 deletions(-) diff --git a/statediff/indexer/indexer.go b/statediff/indexer/indexer.go index 58e0bafc5..65a5e950f 100644 --- a/statediff/indexer/indexer.go +++ b/statediff/indexer/indexer.go @@ -558,11 +558,20 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sd // InsertWatchedAddresses inserts the given addresses in the database func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { - tx, err := sdi.dbWriter.db.Begin() + tx, err := sdi.dbWriter.db.Beginx() if err != nil { return err } - defer tx.Rollback() + defer func() { + if p := recover(); p != nil { + shared.Rollback(tx) + panic(p) + } else if err != nil { + shared.Rollback(tx) + } else { + err = tx.Commit() + } + }() for _, arg := range args { _, err = tx.Exec(`INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, @@ -572,21 +581,25 @@ func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressA } } - err = tx.Commit() - if err != nil { - return err - } - - return nil + return err } // RemoveWatchedAddresses removes the given watched addresses from the database func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressArg) error { - tx, err := sdi.dbWriter.db.Begin() + tx, err := sdi.dbWriter.db.Beginx() if err != nil { return err } - defer tx.Rollback() + defer func() { + if p := recover(); p != nil { + shared.Rollback(tx) + panic(p) + } else if err != nil { + shared.Rollback(tx) + } else { + err = tx.Commit() + } + }() for _, arg := range args { _, err = tx.Exec(`DELETE FROM eth_meta.watched_addresses WHERE address = $1`, arg.Address) @@ -595,21 +608,25 @@ func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressA } } - err = tx.Commit() - if err != nil { - return err - } - - return nil + return err } // SetWatchedAddresses clears and inserts the given addresses in the database func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { - tx, err := sdi.dbWriter.db.Begin() + tx, err := sdi.dbWriter.db.Beginx() if err != nil { return err } - defer tx.Rollback() + defer func() { + if p := recover(); p != nil { + shared.Rollback(tx) + panic(p) + } else if err != nil { + shared.Rollback(tx) + } else { + err = tx.Commit() + } + }() _, err = tx.Exec(`DELETE FROM eth_meta.watched_addresses`) if err != nil { @@ -624,12 +641,7 @@ func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, } } - err = tx.Commit() - if err != nil { - return err - } - - return nil + return err } // ClearWatchedAddresses clears all the watched addresses from the database -- 2.45.2 From 436bc78c098e18642974bf9a7cc3fc39ac7b8737 Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Fri, 1 Apr 2022 13:31:04 +0530 Subject: [PATCH 16/17] Update ipld-eth-db image source in docker-compose --- docker-compose.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.yml b/docker-compose.yml index f1a37ddcb..8083b4d87 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,7 +3,7 @@ version: '3.2' services: ipld-eth-db: restart: always - image: vulcanize/ipld-eth-db:v0.2.0 + image: vulcanize/ipld-eth-db:v2.1.1 environment: POSTGRES_USER: "vdbm" POSTGRES_DB: "vulcanize_public" -- 2.45.2 From 2ec883cdcc30a8335f2e5db5fd357e253033a6c2 Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Fri, 1 Apr 2022 14:02:09 +0530 Subject: [PATCH 17/17] Update unit tests command to get assert package --- Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/Makefile b/Makefile index 975628922..c3c95eeae 100644 --- a/Makefile +++ b/Makefile @@ -52,6 +52,7 @@ ios: .PHONY: statedifftest statedifftest: | $(GOOSE) + go get github.com/stretchr/testify/assert@v1.7.0 MODE=statediff go test ./statediff/... -v test: all -- 2.45.2