Add a creation block arg in the watch address API
This commit is contained in:
parent
bbb1759886
commit
d7704d2f98
@ -151,6 +151,6 @@ func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash
|
||||
}
|
||||
|
||||
// WatchAddress changes the list of watched addresses to which the direct indexing is restricted according to given operation
|
||||
func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, addresses []common.Address) error {
|
||||
return api.sds.WatchAddress(operation, addresses)
|
||||
func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, args []WatchAddressArg) error {
|
||||
return api.sds.WatchAddress(operation, args)
|
||||
}
|
||||
|
@ -1,102 +0,0 @@
|
||||
// Copyright 2019 The go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package statediff_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/statediff"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if os.Getenv("MODE") != "statediff" {
|
||||
fmt.Println("Skipping statediff test")
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
address1Hex = "0x1ca7c995f8eF0A2989BbcE08D5B7Efe50A584aa1"
|
||||
address2Hex = "0xe799eE0191652c864E49F3A3344CE62535B15afe"
|
||||
address1 = common.HexToAddress(address1Hex)
|
||||
address2 = common.HexToAddress(address2Hex)
|
||||
|
||||
watchedAddresses0 []common.Address
|
||||
watchedAddresses1 = []common.Address{address1}
|
||||
watchedAddresses2 = []common.Address{address1, address2}
|
||||
|
||||
expectedError = fmt.Errorf("Address %s already watched", address1Hex)
|
||||
|
||||
service = statediff.Service{}
|
||||
)
|
||||
|
||||
// TODO: Update tests for the updated API
|
||||
func TestWatchAddress(t *testing.T) {
|
||||
watchedAddresses := service.GetWathchedAddresses()
|
||||
if !reflect.DeepEqual(watchedAddresses, watchedAddresses0) {
|
||||
t.Error("Test failure:", t.Name())
|
||||
t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses0)
|
||||
}
|
||||
|
||||
testWatchUnwatchedAddress(t)
|
||||
testWatchWatchedAddress(t)
|
||||
}
|
||||
|
||||
func testWatchUnwatchedAddress(t *testing.T) {
|
||||
err := service.WatchAddress(statediff.Add, []common.Address{address1})
|
||||
if err != nil {
|
||||
t.Error("Test failure:", t.Name())
|
||||
t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error())
|
||||
}
|
||||
watchedAddresses := service.GetWathchedAddresses()
|
||||
if !reflect.DeepEqual(watchedAddresses, watchedAddresses1) {
|
||||
t.Error("Test failure:", t.Name())
|
||||
t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses1)
|
||||
}
|
||||
|
||||
err = service.WatchAddress(statediff.Add, []common.Address{address2})
|
||||
if err != nil {
|
||||
t.Error("Test failure:", t.Name())
|
||||
t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error())
|
||||
}
|
||||
watchedAddresses = service.GetWathchedAddresses()
|
||||
if !reflect.DeepEqual(watchedAddresses, watchedAddresses2) {
|
||||
t.Error("Test failure:", t.Name())
|
||||
t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses2)
|
||||
}
|
||||
}
|
||||
|
||||
func testWatchWatchedAddress(t *testing.T) {
|
||||
err := service.WatchAddress(statediff.Add, []common.Address{address1})
|
||||
if err == nil {
|
||||
t.Error("Test failure:", t.Name())
|
||||
t.Logf("Expected error %s not thrown on an attempt to watch an already watched address.", expectedError.Error())
|
||||
}
|
||||
if err.Error() != expectedError.Error() {
|
||||
t.Error("Test failure:", t.Name())
|
||||
t.Logf("Actual thrown error not equal expected error.\nactual: %+v\nexpected: %+v", err.Error(), expectedError.Error())
|
||||
}
|
||||
watchedAddresses := service.GetWathchedAddresses()
|
||||
if !reflect.DeepEqual(watchedAddresses, watchedAddresses2) {
|
||||
t.Error("Test failure:", t.Name())
|
||||
t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses2)
|
||||
}
|
||||
}
|
@ -26,6 +26,7 @@ import (
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/statediff/indexer/postgres"
|
||||
"github.com/ethereum/go-ethereum/statediff/types"
|
||||
)
|
||||
|
||||
func sortKeys(data AccountMap) []string {
|
||||
@ -102,16 +103,15 @@ func loadWatchedAddresses(db *postgres.DB) error {
|
||||
|
||||
// removeAddresses is used to remove given addresses from a list of addresses
|
||||
func removeAddresses(addresses []common.Address, addressesToRemove []common.Address) []common.Address {
|
||||
addressesCopy := make([]common.Address, len(addresses))
|
||||
copy(addressesCopy, addresses)
|
||||
filteredAddresses := []common.Address{}
|
||||
|
||||
for _, address := range addressesToRemove {
|
||||
if idx := containsAddress(addressesCopy, address); idx != -1 {
|
||||
addressesCopy = append(addressesCopy[:idx], addressesCopy[idx+1:]...)
|
||||
for _, address := range addresses {
|
||||
if idx := containsAddress(addressesToRemove, address); idx == -1 {
|
||||
filteredAddresses = append(filteredAddresses, address)
|
||||
}
|
||||
}
|
||||
|
||||
return addressesCopy
|
||||
return filteredAddresses
|
||||
}
|
||||
|
||||
// containsAddress is used to check if an address is present in the provided list of addresses
|
||||
@ -124,3 +124,28 @@ func containsAddress(addresses []common.Address, address common.Address) int {
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// getArgAddresses is used to get the list of addresses from a list of WatchAddressArgs
|
||||
func getArgAddresses(args []types.WatchAddressArg) []common.Address {
|
||||
addresses := make([]common.Address, len(args))
|
||||
for idx, arg := range args {
|
||||
addresses[idx] = arg.Address
|
||||
}
|
||||
|
||||
return addresses
|
||||
}
|
||||
|
||||
// filterArgs filters out the args having an address from a given list of addresses
|
||||
func filterArgs(args []types.WatchAddressArg, addressesToRemove []common.Address) ([]types.WatchAddressArg, []common.Address) {
|
||||
filteredArgs := []types.WatchAddressArg{}
|
||||
filteredAddresses := []common.Address{}
|
||||
|
||||
for _, arg := range args {
|
||||
if idx := containsAddress(addressesToRemove, arg.Address); idx == -1 {
|
||||
filteredArgs = append(filteredArgs, arg)
|
||||
filteredAddresses = append(filteredAddresses, arg.Address)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredArgs, filteredAddresses
|
||||
}
|
||||
|
@ -59,7 +59,7 @@ type Indexer interface {
|
||||
PushStateNode(tx *BlockTx, stateNode sdtypes.StateNode) error
|
||||
PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sdtypes.CodeAndCodeHash) error
|
||||
ReportDBMetrics(delay time.Duration, quit <-chan bool)
|
||||
InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error
|
||||
InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error
|
||||
RemoveWatchedAddresses(addresses []common.Address) error
|
||||
ClearWatchedAddresses() error
|
||||
}
|
||||
@ -554,15 +554,15 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sd
|
||||
}
|
||||
|
||||
// InsertWatchedAddresses inserts the given addresses in the database
|
||||
func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error {
|
||||
func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error {
|
||||
tx, err := sdi.dbWriter.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, address := range addresses {
|
||||
_, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, added_at)VALUES ($1, $2) ON CONFLICT (address) DO NOTHING`,
|
||||
address.Hex(), currentBlock.Uint64())
|
||||
for _, arg := range args {
|
||||
_, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, created_at, watched_at)VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`,
|
||||
arg.Address.Hex(), arg.CreatedAt, currentBlockNumber.Uint64())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error inserting watched_addresses entry: %v", err)
|
||||
}
|
||||
|
@ -106,7 +106,7 @@ type IService interface {
|
||||
// Event loop for progressively processing and writing diffs directly to DB
|
||||
WriteLoop(chainEventCh chan core.ChainEvent)
|
||||
// Method to change the addresses being watched in write loop params
|
||||
WatchAddress(operation OperationType, addresses []common.Address) error
|
||||
WatchAddress(operation OperationType, args []WatchAddressArg) error
|
||||
// Method to get currently watched addresses from write loop params
|
||||
GetWathchedAddresses() []common.Address
|
||||
}
|
||||
@ -735,39 +735,40 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo
|
||||
}
|
||||
|
||||
// Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses
|
||||
func (sds *Service) WatchAddress(operation OperationType, addresses []common.Address) error {
|
||||
func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error {
|
||||
// lock writeLoopParams for a write
|
||||
writeLoopParams.mu.Lock()
|
||||
defer writeLoopParams.mu.Unlock()
|
||||
|
||||
// get the current block number
|
||||
currentBlock := sds.BlockChain.CurrentBlock()
|
||||
currentBlockNumber := currentBlock.Number()
|
||||
currentBlockNumber := sds.BlockChain.CurrentBlock().Number()
|
||||
|
||||
switch operation {
|
||||
case Add:
|
||||
addressesToRemove := []common.Address{}
|
||||
for _, address := range addresses {
|
||||
for _, arg := range args {
|
||||
// Check if address is already being watched
|
||||
// Throw a warning and continue if found
|
||||
if containsAddress(writeLoopParams.WatchedAddresses, address) != -1 {
|
||||
if containsAddress(writeLoopParams.WatchedAddresses, arg.Address) != -1 {
|
||||
// log.Warn(fmt.Sprint("Address ", address.Hex(), " already being watched"))
|
||||
log.Warn("Address already being watched", "address", address.Hex())
|
||||
addressesToRemove = append(addressesToRemove, address)
|
||||
log.Warn("Address already being watched", "address", arg.Address.Hex())
|
||||
addressesToRemove = append(addressesToRemove, arg.Address)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// remove already watched addresses
|
||||
addresses = removeAddresses(addresses, addressesToRemove)
|
||||
filteredArgs, filteredAddresses := filterArgs(args, addressesToRemove)
|
||||
|
||||
err := sds.indexer.InsertWatchedAddresses(addresses, currentBlockNumber)
|
||||
err := sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, addresses...)
|
||||
writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...)
|
||||
case Remove:
|
||||
addresses := getArgAddresses(args)
|
||||
|
||||
err := sds.indexer.RemoveWatchedAddresses(addresses)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -780,11 +781,12 @@ func (sds *Service) WatchAddress(operation OperationType, addresses []common.Add
|
||||
return err
|
||||
}
|
||||
|
||||
err = sds.indexer.InsertWatchedAddresses(addresses, currentBlockNumber)
|
||||
err = sds.indexer.InsertWatchedAddresses(args, currentBlockNumber)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addresses := getArgAddresses(args)
|
||||
writeLoopParams.WatchedAddresses = addresses
|
||||
case Clear:
|
||||
err := sds.indexer.ClearWatchedAddresses()
|
||||
|
@ -333,7 +333,7 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) {
|
||||
}
|
||||
}
|
||||
|
||||
func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, addresses []common.Address) error {
|
||||
func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, addresses []sdtypes.WatchAddressArg) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -74,3 +74,9 @@ type CodeAndCodeHash struct {
|
||||
type StateNodeSink func(StateNode) error
|
||||
type StorageNodeSink func(StorageNode) error
|
||||
type CodeSink func(CodeAndCodeHash) error
|
||||
|
||||
// WatchAddressArg is a arg type for WatchAddress API
|
||||
type WatchAddressArg struct {
|
||||
Address common.Address
|
||||
CreatedAt uint64
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user