diff --git a/statediff/api.go b/statediff/api.go index 2271ae6d0..3dfed6c26 100644 --- a/statediff/api.go +++ b/statediff/api.go @@ -151,6 +151,6 @@ func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash } // WatchAddress adds the given address to a list of watched addresses to which the direct statediff process is restricted -func (api *PublicStateDiffAPI) WatchAddress(address common.Address) error { - return api.sds.WatchAddress(address) +func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, addresses []common.Address) error { + return api.sds.WatchAddress(operation, addresses) } diff --git a/statediff/api_test.go b/statediff/api_test.go index 9ea8a9ede..1e2bc2898 100644 --- a/statediff/api_test.go +++ b/statediff/api_test.go @@ -48,6 +48,7 @@ var ( service = statediff.Service{} ) +// TODO: Update tests for the updated API func TestWatchAddress(t *testing.T) { watchedAddresses := service.GetWathchedAddresses() if !reflect.DeepEqual(watchedAddresses, watchedAddresses0) { @@ -60,7 +61,7 @@ func TestWatchAddress(t *testing.T) { } func testWatchUnwatchedAddress(t *testing.T) { - err := service.WatchAddress(address1) + err := service.WatchAddress(statediff.Add, []common.Address{address1}) if err != nil { t.Error("Test failure:", t.Name()) t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error()) @@ -71,7 +72,7 @@ func testWatchUnwatchedAddress(t *testing.T) { t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses1) } - err = service.WatchAddress(address2) + err = service.WatchAddress(statediff.Add, []common.Address{address2}) if err != nil { t.Error("Test failure:", t.Name()) t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error()) @@ -84,7 +85,7 @@ func testWatchUnwatchedAddress(t *testing.T) { } func testWatchWatchedAddress(t *testing.T) { - err := service.WatchAddress(address1) + err := service.WatchAddress(statediff.Add, []common.Address{address1}) if err == nil { t.Error("Test failure:", t.Name()) t.Logf("Expected error %s not thrown on an attempt to watch an already watched address.", expectedError.Error()) diff --git a/statediff/helpers.go b/statediff/helpers.go index eb5060c51..1a766aae3 100644 --- a/statediff/helpers.go +++ b/statediff/helpers.go @@ -20,8 +20,12 @@ package statediff import ( + "fmt" "sort" "strings" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/statediff/indexer/postgres" ) func sortKeys(data AccountMap) []string { @@ -69,5 +73,51 @@ func findIntersection(a, b []string) []string { } } } - +} + +// loadWatchedAddresses is used to load watched addresses to the in-memory write loop params from the db +func loadWatchedAddresses(db *postgres.DB) error { + rows, err := db.Query("SELECT address FROM eth.watched_addresses") + if err != nil { + return fmt.Errorf("error loading watched addresses: %v", err) + } + + var watchedAddresses []common.Address + for rows.Next() { + var addressHex string + err := rows.Scan(&addressHex) + if err != nil { + return err + } + + watchedAddresses = append(watchedAddresses, common.HexToAddress(addressHex)) + } + + writeLoopParams.WatchedAddresses = watchedAddresses + + return nil +} + +func removeWatchedAddresses(watchedAddresses []common.Address, addressesToRemove []common.Address) []common.Address { + addresses := make([]common.Address, len(addressesToRemove)) + copy(addresses, watchedAddresses) + + for _, address := range addressesToRemove { + if idx := containsAddress(addresses, address); idx != -1 { + addresses = append(addresses[:idx], addresses[idx+1:]...) + } + } + + return addresses +} + +// containsAddress is used to check if an address is present in the provided list of watched addresses +// return the index if found else -1 +func containsAddress(watchedAddresses []common.Address, address common.Address) int { + for idx, addr := range watchedAddresses { + if addr == address { + return idx + } + } + return -1 } diff --git a/statediff/indexer/indexer.go b/statediff/indexer/indexer.go index 60d69f932..8d628d31e 100644 --- a/statediff/indexer/indexer.go +++ b/statediff/indexer/indexer.go @@ -59,6 +59,9 @@ type Indexer interface { PushStateNode(tx *BlockTx, stateNode sdtypes.StateNode) error PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sdtypes.CodeAndCodeHash) error ReportDBMetrics(delay time.Duration, quit <-chan bool) + InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error + RemoveWatchedAddresses(addresses []common.Address) error + ClearWatchedAddresses() error } // StateDiffIndexer satisfies the Indexer interface for ethereum statediff objects @@ -549,3 +552,64 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sd } return nil } + +func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error { + tx, err := sdi.dbWriter.db.Begin() + if err != nil { + return err + } + + for _, address := range addresses { + _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, added_at) VALUES ($1, $2)`, address, currentBlock) + if err != nil { + return fmt.Errorf("error inserting watched_addresses entry: %v", err) + } + } + + err = tx.Commit() + if err != nil { + return err + } + + return nil +} + +func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) error { + tx, err := sdi.dbWriter.db.Begin() + if err != nil { + return err + } + + for _, address := range addresses { + _, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, address) + if err != nil { + return fmt.Errorf("error removing watched_addresses entry: %v", err) + } + } + + err = tx.Commit() + if err != nil { + return err + } + + return nil +} + +func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { + tx, err := sdi.dbWriter.db.Begin() + if err != nil { + return err + } + + _, err = tx.Exec(`DELETE FROM eth.watched_addresses`) + if err != nil { + return fmt.Errorf("error clearing watched_addresses table: %v", err) + } + + err = tx.Commit() + if err != nil { + return err + } + + return nil +} diff --git a/statediff/service.go b/statediff/service.go index 9d8b6ecf0..ba599b05f 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -18,11 +18,8 @@ package statediff import ( "bytes" - "encoding/json" "fmt" - "io/ioutil" "math/big" - "os" "strconv" "strings" "sync" @@ -58,9 +55,6 @@ const ( deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html ) -// TODO: Take the watched addresses file path as a CLI arg. -const watchedAddressesFile = "./watched-addresses.json" - var writeLoopParams = Params{ IntermediateStateNodes: true, IntermediateStorageNodes: true, @@ -109,7 +103,7 @@ type IService interface { // Event loop for progressively processing and writing diffs directly to DB WriteLoop(chainEventCh chan core.ChainEvent) // Method to add an address to be watched to write loop params - WatchAddress(address common.Address) error + WatchAddress(operation OperationType, addresses []common.Address) error // Method to get currently watched addresses from write loop params GetWathchedAddresses() []common.Address } @@ -170,6 +164,7 @@ func NewBlockCache(max uint) blockCache { func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params ServiceParams) error { blockChain := ethServ.BlockChain() var indexer ind.Indexer + var db *postgres.DB quitCh := make(chan bool) if params.DBParams != nil { info := nodeinfo.Info{ @@ -212,7 +207,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params stack.RegisterLifecycle(sds) stack.RegisterAPIs(sds.APIs()) - err := loadWatchedAddresses() + err := loadWatchedAddresses(db) if err != nil { return err } @@ -731,39 +726,56 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo return err } -// Adds the provided address to the list of watched addresses in write loop params and to the watched addresses file -func (sds *Service) WatchAddress(address common.Address) error { - // Check if address is already being watched - if containsAddress(writeLoopParams.WatchedAddresses, address) { - return fmt.Errorf("Address %s already watched", address) - } +// Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses +func (sds *Service) WatchAddress(operation OperationType, addresses []common.Address) error { + // check operation + switch operation { + case Add: + for _, address := range addresses { + // Check if address is already being watched + if containsAddress(writeLoopParams.WatchedAddresses, address) != -1 { + return fmt.Errorf("Address %s already watched", address) + } + } - // Check if the watched addresses file exists - fileExists, err := doesFileExist(watchedAddressesFile) - if err != nil { - return err - } - - // Create the watched addresses file if doesn't exist - if !fileExists { - _, err := os.Create(watchedAddressesFile) + // TODO: Make sure WriteLoop doesn't call statediffing before the params are updated for the current block + // TODO: Get the current block + err := sds.indexer.InsertWatchedAddresses(addresses, common.Big1) if err != nil { return err } + + writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, addresses...) + case Remove: + err := sds.indexer.RemoveWatchedAddresses(addresses) + if err != nil { + return err + } + + removeWatchedAddresses(writeLoopParams.WatchedAddresses, addresses) + case Set: + err := sds.indexer.ClearWatchedAddresses() + if err != nil { + return err + } + + err = sds.indexer.InsertWatchedAddresses(addresses, common.Big1) + if err != nil { + return err + } + + writeLoopParams.WatchedAddresses = addresses + case Clear: + err := sds.indexer.ClearWatchedAddresses() + if err != nil { + return err + } + + writeLoopParams.WatchedAddresses = nil + default: + return fmt.Errorf("Unexpected operation %s", operation) } - watchedAddresses := append(writeLoopParams.WatchedAddresses, address) - - // Write the updated list of watched address to a json file - content, err := json.Marshal(watchedAddresses) - err = ioutil.WriteFile(watchedAddressesFile, content, 0644) - if err != nil { - return err - } - - // Update the in-memory params as well - writeLoopParams.WatchedAddresses = watchedAddresses - return nil } @@ -771,50 +783,3 @@ func (sds *Service) WatchAddress(address common.Address) error { func (sds *Service) GetWathchedAddresses() []common.Address { return writeLoopParams.WatchedAddresses } - -// loadWatchedAddresses is used to load watched addresses to the in-memory write loop params from a json file if it exists -func loadWatchedAddresses() error { - // Check if the watched addresses file exists - fileExists, err := doesFileExist(watchedAddressesFile) - if err != nil { - return err - } - - if fileExists { - content, err := ioutil.ReadFile(watchedAddressesFile) - if err != nil { - return err - } - - var watchedAddresses []common.Address - err = json.Unmarshal(content, &watchedAddresses) - if err != nil { - return err - } - - writeLoopParams.WatchedAddresses = watchedAddresses - } - - return nil -} - -// containsAddress is used to check if an address is present in the provided list of watched addresses -func containsAddress(watchedAddresses []common.Address, address common.Address) bool { - for _, addr := range watchedAddresses { - if addr == address { - return true - } - } - return false -} - -// doesFileExist is used to check if file at a given path exists -func doesFileExist(path string) (bool, error) { - _, err := os.Stat(path) - if err == nil { - return true, nil - } else if os.IsNotExist(err) { - return false, nil - } - return false, err -} diff --git a/statediff/testhelpers/mocks/service.go b/statediff/testhelpers/mocks/service.go index 3c47cbea6..943720094 100644 --- a/statediff/testhelpers/mocks/service.go +++ b/statediff/testhelpers/mocks/service.go @@ -333,7 +333,7 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) { } } -func (sds *MockStateDiffService) WatchAddress(address common.Address) error { +func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, addresses []common.Address) error { return nil } diff --git a/statediff/types.go b/statediff/types.go index ef8256041..fd0877b02 100644 --- a/statediff/types.go +++ b/statediff/types.go @@ -111,3 +111,13 @@ type accountWrapper struct { NodeValue []byte LeafKey []byte } + +// OperationType for type of WatchAddress operation +type OperationType string + +const ( + Add OperationType = "Add" + Remove OperationType = "Remove" + Set OperationType = "Set" + Clear OperationType = "Clear" +)