From d7704d2f98b51b28e7cb4bb04a1bf9d0a1712e75 Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Tue, 18 Jan 2022 15:05:45 +0530 Subject: [PATCH] Add a creation block arg in the watch address API --- statediff/api.go | 4 +- statediff/api_test.go | 102 ------------------------- statediff/helpers.go | 37 +++++++-- statediff/indexer/indexer.go | 10 +-- statediff/service.go | 26 ++++--- statediff/testhelpers/mocks/service.go | 2 +- statediff/types/types.go | 6 ++ 7 files changed, 59 insertions(+), 128 deletions(-) delete mode 100644 statediff/api_test.go diff --git a/statediff/api.go b/statediff/api.go index d6cdfd008..ed9cc3c06 100644 --- a/statediff/api.go +++ b/statediff/api.go @@ -151,6 +151,6 @@ func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash } // WatchAddress changes the list of watched addresses to which the direct indexing is restricted according to given operation -func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, addresses []common.Address) error { - return api.sds.WatchAddress(operation, addresses) +func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, args []WatchAddressArg) error { + return api.sds.WatchAddress(operation, args) } diff --git a/statediff/api_test.go b/statediff/api_test.go deleted file mode 100644 index 1e2bc2898..000000000 --- a/statediff/api_test.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2019 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package statediff_test - -import ( - "fmt" - "os" - "reflect" - "testing" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/statediff" -) - -func init() { - if os.Getenv("MODE") != "statediff" { - fmt.Println("Skipping statediff test") - os.Exit(0) - } -} - -var ( - address1Hex = "0x1ca7c995f8eF0A2989BbcE08D5B7Efe50A584aa1" - address2Hex = "0xe799eE0191652c864E49F3A3344CE62535B15afe" - address1 = common.HexToAddress(address1Hex) - address2 = common.HexToAddress(address2Hex) - - watchedAddresses0 []common.Address - watchedAddresses1 = []common.Address{address1} - watchedAddresses2 = []common.Address{address1, address2} - - expectedError = fmt.Errorf("Address %s already watched", address1Hex) - - service = statediff.Service{} -) - -// TODO: Update tests for the updated API -func TestWatchAddress(t *testing.T) { - watchedAddresses := service.GetWathchedAddresses() - if !reflect.DeepEqual(watchedAddresses, watchedAddresses0) { - t.Error("Test failure:", t.Name()) - t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses0) - } - - testWatchUnwatchedAddress(t) - testWatchWatchedAddress(t) -} - -func testWatchUnwatchedAddress(t *testing.T) { - err := service.WatchAddress(statediff.Add, []common.Address{address1}) - if err != nil { - t.Error("Test failure:", t.Name()) - t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error()) - } - watchedAddresses := service.GetWathchedAddresses() - if !reflect.DeepEqual(watchedAddresses, watchedAddresses1) { - t.Error("Test failure:", t.Name()) - t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses1) - } - - err = service.WatchAddress(statediff.Add, []common.Address{address2}) - if err != nil { - t.Error("Test failure:", t.Name()) - t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error()) - } - watchedAddresses = service.GetWathchedAddresses() - if !reflect.DeepEqual(watchedAddresses, watchedAddresses2) { - t.Error("Test failure:", t.Name()) - t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses2) - } -} - -func testWatchWatchedAddress(t *testing.T) { - err := service.WatchAddress(statediff.Add, []common.Address{address1}) - if err == nil { - t.Error("Test failure:", t.Name()) - t.Logf("Expected error %s not thrown on an attempt to watch an already watched address.", expectedError.Error()) - } - if err.Error() != expectedError.Error() { - t.Error("Test failure:", t.Name()) - t.Logf("Actual thrown error not equal expected error.\nactual: %+v\nexpected: %+v", err.Error(), expectedError.Error()) - } - watchedAddresses := service.GetWathchedAddresses() - if !reflect.DeepEqual(watchedAddresses, watchedAddresses2) { - t.Error("Test failure:", t.Name()) - t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses2) - } -} diff --git a/statediff/helpers.go b/statediff/helpers.go index 1dccbfbc2..f97f34062 100644 --- a/statediff/helpers.go +++ b/statediff/helpers.go @@ -26,6 +26,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/statediff/indexer/postgres" + "github.com/ethereum/go-ethereum/statediff/types" ) func sortKeys(data AccountMap) []string { @@ -102,16 +103,15 @@ func loadWatchedAddresses(db *postgres.DB) error { // removeAddresses is used to remove given addresses from a list of addresses func removeAddresses(addresses []common.Address, addressesToRemove []common.Address) []common.Address { - addressesCopy := make([]common.Address, len(addresses)) - copy(addressesCopy, addresses) + filteredAddresses := []common.Address{} - for _, address := range addressesToRemove { - if idx := containsAddress(addressesCopy, address); idx != -1 { - addressesCopy = append(addressesCopy[:idx], addressesCopy[idx+1:]...) + for _, address := range addresses { + if idx := containsAddress(addressesToRemove, address); idx == -1 { + filteredAddresses = append(filteredAddresses, address) } } - return addressesCopy + return filteredAddresses } // containsAddress is used to check if an address is present in the provided list of addresses @@ -124,3 +124,28 @@ func containsAddress(addresses []common.Address, address common.Address) int { } return -1 } + +// getArgAddresses is used to get the list of addresses from a list of WatchAddressArgs +func getArgAddresses(args []types.WatchAddressArg) []common.Address { + addresses := make([]common.Address, len(args)) + for idx, arg := range args { + addresses[idx] = arg.Address + } + + return addresses +} + +// filterArgs filters out the args having an address from a given list of addresses +func filterArgs(args []types.WatchAddressArg, addressesToRemove []common.Address) ([]types.WatchAddressArg, []common.Address) { + filteredArgs := []types.WatchAddressArg{} + filteredAddresses := []common.Address{} + + for _, arg := range args { + if idx := containsAddress(addressesToRemove, arg.Address); idx == -1 { + filteredArgs = append(filteredArgs, arg) + filteredAddresses = append(filteredAddresses, arg.Address) + } + } + + return filteredArgs, filteredAddresses +} diff --git a/statediff/indexer/indexer.go b/statediff/indexer/indexer.go index fd3beec73..cda12e4bb 100644 --- a/statediff/indexer/indexer.go +++ b/statediff/indexer/indexer.go @@ -59,7 +59,7 @@ type Indexer interface { PushStateNode(tx *BlockTx, stateNode sdtypes.StateNode) error PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sdtypes.CodeAndCodeHash) error ReportDBMetrics(delay time.Duration, quit <-chan bool) - InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error + InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error RemoveWatchedAddresses(addresses []common.Address) error ClearWatchedAddresses() error } @@ -554,15 +554,15 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sd } // InsertWatchedAddresses inserts the given addresses in the database -func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error { +func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { tx, err := sdi.dbWriter.db.Begin() if err != nil { return err } - for _, address := range addresses { - _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, added_at)VALUES ($1, $2) ON CONFLICT (address) DO NOTHING`, - address.Hex(), currentBlock.Uint64()) + for _, arg := range args { + _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at)VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, + arg.Address.Hex(), arg.CreatedAt, currentBlockNumber.Uint64()) if err != nil { return fmt.Errorf("error inserting watched_addresses entry: %v", err) } diff --git a/statediff/service.go b/statediff/service.go index 446388f66..cc0059f00 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -106,7 +106,7 @@ type IService interface { // Event loop for progressively processing and writing diffs directly to DB WriteLoop(chainEventCh chan core.ChainEvent) // Method to change the addresses being watched in write loop params - WatchAddress(operation OperationType, addresses []common.Address) error + WatchAddress(operation OperationType, args []WatchAddressArg) error // Method to get currently watched addresses from write loop params GetWathchedAddresses() []common.Address } @@ -735,39 +735,40 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo } // Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses -func (sds *Service) WatchAddress(operation OperationType, addresses []common.Address) error { +func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error { // lock writeLoopParams for a write writeLoopParams.mu.Lock() defer writeLoopParams.mu.Unlock() // get the current block number - currentBlock := sds.BlockChain.CurrentBlock() - currentBlockNumber := currentBlock.Number() + currentBlockNumber := sds.BlockChain.CurrentBlock().Number() switch operation { case Add: addressesToRemove := []common.Address{} - for _, address := range addresses { + for _, arg := range args { // Check if address is already being watched // Throw a warning and continue if found - if containsAddress(writeLoopParams.WatchedAddresses, address) != -1 { + if containsAddress(writeLoopParams.WatchedAddresses, arg.Address) != -1 { // log.Warn(fmt.Sprint("Address ", address.Hex(), " already being watched")) - log.Warn("Address already being watched", "address", address.Hex()) - addressesToRemove = append(addressesToRemove, address) + log.Warn("Address already being watched", "address", arg.Address.Hex()) + addressesToRemove = append(addressesToRemove, arg.Address) continue } } // remove already watched addresses - addresses = removeAddresses(addresses, addressesToRemove) + filteredArgs, filteredAddresses := filterArgs(args, addressesToRemove) - err := sds.indexer.InsertWatchedAddresses(addresses, currentBlockNumber) + err := sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) if err != nil { return err } - writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, addresses...) + writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...) case Remove: + addresses := getArgAddresses(args) + err := sds.indexer.RemoveWatchedAddresses(addresses) if err != nil { return err @@ -780,11 +781,12 @@ func (sds *Service) WatchAddress(operation OperationType, addresses []common.Add return err } - err = sds.indexer.InsertWatchedAddresses(addresses, currentBlockNumber) + err = sds.indexer.InsertWatchedAddresses(args, currentBlockNumber) if err != nil { return err } + addresses := getArgAddresses(args) writeLoopParams.WatchedAddresses = addresses case Clear: err := sds.indexer.ClearWatchedAddresses() diff --git a/statediff/testhelpers/mocks/service.go b/statediff/testhelpers/mocks/service.go index 943720094..e2f4d3cb9 100644 --- a/statediff/testhelpers/mocks/service.go +++ b/statediff/testhelpers/mocks/service.go @@ -333,7 +333,7 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) { } } -func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, addresses []common.Address) error { +func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, addresses []sdtypes.WatchAddressArg) error { return nil } diff --git a/statediff/types/types.go b/statediff/types/types.go index 56babfb5b..cab4cb880 100644 --- a/statediff/types/types.go +++ b/statediff/types/types.go @@ -74,3 +74,9 @@ type CodeAndCodeHash struct { type StateNodeSink func(StateNode) error type StorageNodeSink func(StorageNode) error type CodeSink func(CodeAndCodeHash) error + +// WatchAddressArg is a arg type for WatchAddress API +type WatchAddressArg struct { + Address common.Address + CreatedAt uint64 +}