Add a creation block arg in the watch address API

This commit is contained in:
Prathamesh Musale 2022-01-18 15:05:45 +05:30
parent bbb1759886
commit d7704d2f98
7 changed files with 59 additions and 128 deletions

View File

@ -151,6 +151,6 @@ func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash
} }
// WatchAddress changes the list of watched addresses to which the direct indexing is restricted according to given operation // WatchAddress changes the list of watched addresses to which the direct indexing is restricted according to given operation
func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, addresses []common.Address) error { func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, args []WatchAddressArg) error {
return api.sds.WatchAddress(operation, addresses) return api.sds.WatchAddress(operation, args)
} }

View File

@ -1,102 +0,0 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package statediff_test
import (
"fmt"
"os"
"reflect"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/statediff"
)
func init() {
if os.Getenv("MODE") != "statediff" {
fmt.Println("Skipping statediff test")
os.Exit(0)
}
}
var (
address1Hex = "0x1ca7c995f8eF0A2989BbcE08D5B7Efe50A584aa1"
address2Hex = "0xe799eE0191652c864E49F3A3344CE62535B15afe"
address1 = common.HexToAddress(address1Hex)
address2 = common.HexToAddress(address2Hex)
watchedAddresses0 []common.Address
watchedAddresses1 = []common.Address{address1}
watchedAddresses2 = []common.Address{address1, address2}
expectedError = fmt.Errorf("Address %s already watched", address1Hex)
service = statediff.Service{}
)
// TODO: Update tests for the updated API
func TestWatchAddress(t *testing.T) {
watchedAddresses := service.GetWathchedAddresses()
if !reflect.DeepEqual(watchedAddresses, watchedAddresses0) {
t.Error("Test failure:", t.Name())
t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses0)
}
testWatchUnwatchedAddress(t)
testWatchWatchedAddress(t)
}
func testWatchUnwatchedAddress(t *testing.T) {
err := service.WatchAddress(statediff.Add, []common.Address{address1})
if err != nil {
t.Error("Test failure:", t.Name())
t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error())
}
watchedAddresses := service.GetWathchedAddresses()
if !reflect.DeepEqual(watchedAddresses, watchedAddresses1) {
t.Error("Test failure:", t.Name())
t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses1)
}
err = service.WatchAddress(statediff.Add, []common.Address{address2})
if err != nil {
t.Error("Test failure:", t.Name())
t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error())
}
watchedAddresses = service.GetWathchedAddresses()
if !reflect.DeepEqual(watchedAddresses, watchedAddresses2) {
t.Error("Test failure:", t.Name())
t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses2)
}
}
func testWatchWatchedAddress(t *testing.T) {
err := service.WatchAddress(statediff.Add, []common.Address{address1})
if err == nil {
t.Error("Test failure:", t.Name())
t.Logf("Expected error %s not thrown on an attempt to watch an already watched address.", expectedError.Error())
}
if err.Error() != expectedError.Error() {
t.Error("Test failure:", t.Name())
t.Logf("Actual thrown error not equal expected error.\nactual: %+v\nexpected: %+v", err.Error(), expectedError.Error())
}
watchedAddresses := service.GetWathchedAddresses()
if !reflect.DeepEqual(watchedAddresses, watchedAddresses2) {
t.Error("Test failure:", t.Name())
t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses2)
}
}

View File

@ -26,6 +26,7 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/statediff/indexer/postgres" "github.com/ethereum/go-ethereum/statediff/indexer/postgres"
"github.com/ethereum/go-ethereum/statediff/types"
) )
func sortKeys(data AccountMap) []string { func sortKeys(data AccountMap) []string {
@ -102,16 +103,15 @@ func loadWatchedAddresses(db *postgres.DB) error {
// removeAddresses is used to remove given addresses from a list of addresses // removeAddresses is used to remove given addresses from a list of addresses
func removeAddresses(addresses []common.Address, addressesToRemove []common.Address) []common.Address { func removeAddresses(addresses []common.Address, addressesToRemove []common.Address) []common.Address {
addressesCopy := make([]common.Address, len(addresses)) filteredAddresses := []common.Address{}
copy(addressesCopy, addresses)
for _, address := range addressesToRemove { for _, address := range addresses {
if idx := containsAddress(addressesCopy, address); idx != -1 { if idx := containsAddress(addressesToRemove, address); idx == -1 {
addressesCopy = append(addressesCopy[:idx], addressesCopy[idx+1:]...) filteredAddresses = append(filteredAddresses, address)
} }
} }
return addressesCopy return filteredAddresses
} }
// containsAddress is used to check if an address is present in the provided list of addresses // containsAddress is used to check if an address is present in the provided list of addresses
@ -124,3 +124,28 @@ func containsAddress(addresses []common.Address, address common.Address) int {
} }
return -1 return -1
} }
// getArgAddresses is used to get the list of addresses from a list of WatchAddressArgs
func getArgAddresses(args []types.WatchAddressArg) []common.Address {
addresses := make([]common.Address, len(args))
for idx, arg := range args {
addresses[idx] = arg.Address
}
return addresses
}
// filterArgs filters out the args having an address from a given list of addresses
func filterArgs(args []types.WatchAddressArg, addressesToRemove []common.Address) ([]types.WatchAddressArg, []common.Address) {
filteredArgs := []types.WatchAddressArg{}
filteredAddresses := []common.Address{}
for _, arg := range args {
if idx := containsAddress(addressesToRemove, arg.Address); idx == -1 {
filteredArgs = append(filteredArgs, arg)
filteredAddresses = append(filteredAddresses, arg.Address)
}
}
return filteredArgs, filteredAddresses
}

View File

@ -59,7 +59,7 @@ type Indexer interface {
PushStateNode(tx *BlockTx, stateNode sdtypes.StateNode) error PushStateNode(tx *BlockTx, stateNode sdtypes.StateNode) error
PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sdtypes.CodeAndCodeHash) error PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sdtypes.CodeAndCodeHash) error
ReportDBMetrics(delay time.Duration, quit <-chan bool) ReportDBMetrics(delay time.Duration, quit <-chan bool)
InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error
RemoveWatchedAddresses(addresses []common.Address) error RemoveWatchedAddresses(addresses []common.Address) error
ClearWatchedAddresses() error ClearWatchedAddresses() error
} }
@ -554,15 +554,15 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sd
} }
// InsertWatchedAddresses inserts the given addresses in the database // InsertWatchedAddresses inserts the given addresses in the database
func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error { func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error {
tx, err := sdi.dbWriter.db.Begin() tx, err := sdi.dbWriter.db.Begin()
if err != nil { if err != nil {
return err return err
} }
for _, address := range addresses { for _, arg := range args {
_, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, added_at)VALUES ($1, $2) ON CONFLICT (address) DO NOTHING`, _, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at)VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`,
address.Hex(), currentBlock.Uint64()) arg.Address.Hex(), arg.CreatedAt, currentBlockNumber.Uint64())
if err != nil { if err != nil {
return fmt.Errorf("error inserting watched_addresses entry: %v", err) return fmt.Errorf("error inserting watched_addresses entry: %v", err)
} }

View File

@ -106,7 +106,7 @@ type IService interface {
// Event loop for progressively processing and writing diffs directly to DB // Event loop for progressively processing and writing diffs directly to DB
WriteLoop(chainEventCh chan core.ChainEvent) WriteLoop(chainEventCh chan core.ChainEvent)
// Method to change the addresses being watched in write loop params // Method to change the addresses being watched in write loop params
WatchAddress(operation OperationType, addresses []common.Address) error WatchAddress(operation OperationType, args []WatchAddressArg) error
// Method to get currently watched addresses from write loop params // Method to get currently watched addresses from write loop params
GetWathchedAddresses() []common.Address GetWathchedAddresses() []common.Address
} }
@ -735,39 +735,40 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo
} }
// Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses // Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses
func (sds *Service) WatchAddress(operation OperationType, addresses []common.Address) error { func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error {
// lock writeLoopParams for a write // lock writeLoopParams for a write
writeLoopParams.mu.Lock() writeLoopParams.mu.Lock()
defer writeLoopParams.mu.Unlock() defer writeLoopParams.mu.Unlock()
// get the current block number // get the current block number
currentBlock := sds.BlockChain.CurrentBlock() currentBlockNumber := sds.BlockChain.CurrentBlock().Number()
currentBlockNumber := currentBlock.Number()
switch operation { switch operation {
case Add: case Add:
addressesToRemove := []common.Address{} addressesToRemove := []common.Address{}
for _, address := range addresses { for _, arg := range args {
// Check if address is already being watched // Check if address is already being watched
// Throw a warning and continue if found // Throw a warning and continue if found
if containsAddress(writeLoopParams.WatchedAddresses, address) != -1 { if containsAddress(writeLoopParams.WatchedAddresses, arg.Address) != -1 {
// log.Warn(fmt.Sprint("Address ", address.Hex(), " already being watched")) // log.Warn(fmt.Sprint("Address ", address.Hex(), " already being watched"))
log.Warn("Address already being watched", "address", address.Hex()) log.Warn("Address already being watched", "address", arg.Address.Hex())
addressesToRemove = append(addressesToRemove, address) addressesToRemove = append(addressesToRemove, arg.Address)
continue continue
} }
} }
// remove already watched addresses // remove already watched addresses
addresses = removeAddresses(addresses, addressesToRemove) filteredArgs, filteredAddresses := filterArgs(args, addressesToRemove)
err := sds.indexer.InsertWatchedAddresses(addresses, currentBlockNumber) err := sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber)
if err != nil { if err != nil {
return err return err
} }
writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, addresses...) writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...)
case Remove: case Remove:
addresses := getArgAddresses(args)
err := sds.indexer.RemoveWatchedAddresses(addresses) err := sds.indexer.RemoveWatchedAddresses(addresses)
if err != nil { if err != nil {
return err return err
@ -780,11 +781,12 @@ func (sds *Service) WatchAddress(operation OperationType, addresses []common.Add
return err return err
} }
err = sds.indexer.InsertWatchedAddresses(addresses, currentBlockNumber) err = sds.indexer.InsertWatchedAddresses(args, currentBlockNumber)
if err != nil { if err != nil {
return err return err
} }
addresses := getArgAddresses(args)
writeLoopParams.WatchedAddresses = addresses writeLoopParams.WatchedAddresses = addresses
case Clear: case Clear:
err := sds.indexer.ClearWatchedAddresses() err := sds.indexer.ClearWatchedAddresses()

View File

@ -333,7 +333,7 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) {
} }
} }
func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, addresses []common.Address) error { func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, addresses []sdtypes.WatchAddressArg) error {
return nil return nil
} }

View File

@ -74,3 +74,9 @@ type CodeAndCodeHash struct {
type StateNodeSink func(StateNode) error type StateNodeSink func(StateNode) error
type StorageNodeSink func(StorageNode) error type StorageNodeSink func(StorageNode) error
type CodeSink func(CodeAndCodeHash) error type CodeSink func(CodeAndCodeHash) error
// WatchAddressArg is a arg type for WatchAddress API
type WatchAddressArg struct {
Address common.Address
CreatedAt uint64
}