Add tests for the API to change addresses being watched

This commit is contained in:
Prathamesh Musale 2022-03-16 19:25:10 +05:30
parent af4dbed9d2
commit 7f38afe542
6 changed files with 480 additions and 5 deletions

View File

@ -889,7 +889,7 @@ func (sds *Service) WatchAddress(operation types2.OperationType, args []types2.W
}
// get addresses from the filtered args
filteredAddresses, err := mapWatchAddressArgsToAddresses(filteredArgs)
filteredAddresses, err := MapWatchAddressArgsToAddresses(filteredArgs)
if err != nil {
return fmt.Errorf("add: filtered addresses %s", err.Error())
}
@ -907,7 +907,7 @@ func (sds *Service) WatchAddress(operation types2.OperationType, args []types2.W
})
case types2.Remove:
// get addresses from args
argAddresses, err := mapWatchAddressArgsToAddresses(args)
argAddresses, err := MapWatchAddressArgsToAddresses(args)
if err != nil {
return fmt.Errorf("remove: mapped addresses %s", err.Error())
}
@ -931,7 +931,7 @@ func (sds *Service) WatchAddress(operation types2.OperationType, args []types2.W
})
case types2.Set:
// get addresses from args
argAddresses, err := mapWatchAddressArgsToAddresses(args)
argAddresses, err := MapWatchAddressArgsToAddresses(args)
if err != nil {
return fmt.Errorf("set: mapped addresses %s", err.Error())
}
@ -979,8 +979,8 @@ func loadWatchedAddresses(indexer interfaces.StateDiffIndexer) error {
return nil
}
// mapWatchAddressArgsToAddresses maps []WatchAddressArg to corresponding []common.Address
func mapWatchAddressArgsToAddresses(args []types2.WatchAddressArg) ([]common.Address, error) {
// MapWatchAddressArgsToAddresses maps []WatchAddressArg to corresponding []common.Address
func MapWatchAddressArgsToAddresses(args []types2.WatchAddressArg) ([]common.Address, error) {
addresses, ok := funk.Map(args, func(arg types2.WatchAddressArg) common.Address {
return common.HexToAddress(arg.Address)
}).([]common.Address)

View File

@ -146,6 +146,7 @@ func testErrorInChainEventLoop(t *testing.T) {
}
}
defaultParams.ComputeWatchedAddressesLeafKeys()
if !reflect.DeepEqual(builder.Params, defaultParams) {
t.Error("Test failure:", t.Name())
t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams)
@ -197,6 +198,8 @@ func testErrorInBlockLoop(t *testing.T) {
}
}()
service.Loop(eventsChannel)
defaultParams.ComputeWatchedAddressesLeafKeys()
if !reflect.DeepEqual(builder.Params, defaultParams) {
t.Error("Test failure:", t.Name())
t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams)
@ -270,6 +273,8 @@ func testErrorInStateDiffAt(t *testing.T) {
if err != nil {
t.Error(err)
}
defaultParams.ComputeWatchedAddressesLeafKeys()
if !reflect.DeepEqual(builder.Params, defaultParams) {
t.Error("Test failure:", t.Name())
t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams)

View File

@ -39,6 +39,7 @@ type BlockChain struct {
Receipts map[common.Hash]types.Receipts
TDByHash map[common.Hash]*big.Int
TDByNum map[uint64]*big.Int
currentBlock *types.Block
}
// SetBlocksForHashes mock method
@ -128,6 +129,16 @@ func (bc *BlockChain) GetTd(hash common.Hash, blockNum uint64) *big.Int {
return nil
}
// SetCurrentBlock test method
func (bc *BlockChain) SetCurrentBlock(block *types.Block) {
bc.currentBlock = block
}
// CurrentBlock mock method
func (bc *BlockChain) CurrentBlock() *types.Block {
return bc.currentBlock
}
func (bc *BlockChain) SetTd(hash common.Hash, blockNum uint64, td *big.Int) {
if bc.TDByHash == nil {
bc.TDByHash = make(map[common.Hash]*big.Int)

View File

@ -0,0 +1,70 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package mocks
import (
"math/big"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/statediff/indexer/interfaces"
sdtypes "github.com/ethereum/go-ethereum/statediff/types"
)
var _ interfaces.StateDiffIndexer = &StateDiffIndexer{}
// StateDiffIndexer is a mock state diff indexer
type StateDiffIndexer struct{}
func (sdi *StateDiffIndexer) PushBlock(block *types.Block, receipts types.Receipts, totalDifficulty *big.Int) (interfaces.Batch, error) {
return nil, nil
}
func (sdi *StateDiffIndexer) PushStateNode(tx interfaces.Batch, stateNode sdtypes.StateNode, headerID string) error {
return nil
}
func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx interfaces.Batch, codeAndCodeHash sdtypes.CodeAndCodeHash) error {
return nil
}
func (sdi *StateDiffIndexer) ReportDBMetrics(delay time.Duration, quit <-chan bool) {}
func (sdi *StateDiffIndexer) LoadWatchedAddresses() ([]common.Address, error) {
return nil, nil
}
func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error {
return nil
}
func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []sdtypes.WatchAddressArg) error {
return nil
}
func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error {
return nil
}
func (sdi *StateDiffIndexer) ClearWatchedAddresses() error {
return nil
}
func (sdi *StateDiffIndexer) Close() error {
return nil
}

View File

@ -25,6 +25,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp"
"github.com/thoas/go-funk"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/types"
@ -32,9 +33,15 @@ import (
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rpc"
"github.com/ethereum/go-ethereum/statediff"
"github.com/ethereum/go-ethereum/statediff/indexer/interfaces"
sdtypes "github.com/ethereum/go-ethereum/statediff/types"
)
var (
typeAssertionFailed = "type assertion failed"
unexpectedOperation = "unexpected operation"
)
// MockStateDiffService is a mock state diff service
type MockStateDiffService struct {
sync.Mutex
@ -47,6 +54,8 @@ type MockStateDiffService struct {
QuitChan chan bool
Subscriptions map[common.Hash]map[rpc.ID]statediff.Subscription
SubscriptionTypes map[common.Hash]statediff.Params
Indexer interfaces.StateDiffIndexer
writeLoopParams statediff.ParamsWithMutex
}
// Protocols mock method
@ -332,3 +341,98 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) {
log.Info("unable to close subscription %s; channel has no receiver", id)
}
}
// Performs one of following operations on the watched addresses in writeLoopParams and the db:
// add | remove | set | clear
func (sds *MockStateDiffService) WatchAddress(operation sdtypes.OperationType, args []sdtypes.WatchAddressArg) error {
// lock writeLoopParams for a write
sds.writeLoopParams.Lock()
defer sds.writeLoopParams.Unlock()
// get the current block number
currentBlockNumber := sds.BlockChain.CurrentBlock().Number()
switch operation {
case sdtypes.Add:
// filter out args having an already watched address with a warning
filteredArgs, ok := funk.Filter(args, func(arg sdtypes.WatchAddressArg) bool {
if funk.Contains(sds.writeLoopParams.WatchedAddresses, common.HexToAddress(arg.Address)) {
log.Warn("Address already being watched", "address", arg.Address)
return false
}
return true
}).([]sdtypes.WatchAddressArg)
if !ok {
return fmt.Errorf("add: filtered args %s", typeAssertionFailed)
}
// get addresses from the filtered args
filteredAddresses, err := statediff.MapWatchAddressArgsToAddresses(filteredArgs)
if err != nil {
return fmt.Errorf("add: filtered addresses %s", err.Error())
}
// update the db
err = sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber)
if err != nil {
return err
}
// update in-memory params
sds.writeLoopParams.WatchedAddresses = append(sds.writeLoopParams.WatchedAddresses, filteredAddresses...)
sds.writeLoopParams.ComputeWatchedAddressesLeafKeys()
case sdtypes.Remove:
// get addresses from args
argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args)
if err != nil {
return fmt.Errorf("remove: mapped addresses %s", err.Error())
}
// remove the provided addresses from currently watched addresses
addresses, ok := funk.Subtract(sds.writeLoopParams.WatchedAddresses, argAddresses).([]common.Address)
if !ok {
return fmt.Errorf("remove: filtered addresses %s", typeAssertionFailed)
}
// update the db
err = sds.Indexer.RemoveWatchedAddresses(args)
if err != nil {
return err
}
// update in-memory params
sds.writeLoopParams.WatchedAddresses = addresses
sds.writeLoopParams.ComputeWatchedAddressesLeafKeys()
case sdtypes.Set:
// get addresses from args
argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args)
if err != nil {
return fmt.Errorf("set: mapped addresses %s", err.Error())
}
// update the db
err = sds.Indexer.SetWatchedAddresses(args, currentBlockNumber)
if err != nil {
return err
}
// update in-memory params
sds.writeLoopParams.WatchedAddresses = argAddresses
sds.writeLoopParams.ComputeWatchedAddressesLeafKeys()
case sdtypes.Clear:
// update the db
err := sds.Indexer.ClearWatchedAddresses()
if err != nil {
return err
}
// update in-memory params
sds.writeLoopParams.WatchedAddresses = []common.Address{}
sds.writeLoopParams.ComputeWatchedAddressesLeafKeys()
default:
return fmt.Errorf("%s %s", unexpectedOperation, operation)
}
return nil
}

View File

@ -21,6 +21,7 @@ import (
"fmt"
"math/big"
"os"
"reflect"
"sort"
"sync"
"testing"
@ -88,6 +89,7 @@ func init() {
func TestAPI(t *testing.T) {
testSubscriptionAPI(t)
testHTTPAPI(t)
testWatchAddressAPI(t)
}
func testSubscriptionAPI(t *testing.T) {
@ -253,3 +255,286 @@ func testHTTPAPI(t *testing.T) {
t.Errorf("paylaod does not have the expected total difficulty\r\nactual td: %d\r\nexpected td: %d", payload.TotalDifficulty.Int64(), mockTotalDifficulty.Int64())
}
}
func testWatchAddressAPI(t *testing.T) {
blocks, chain := test_helpers.MakeChain(6, test_helpers.Genesis, test_helpers.TestChainGen)
defer chain.Stop()
block6 := blocks[5]
mockBlockChain := &BlockChain{}
mockBlockChain.SetCurrentBlock(block6)
mockIndexer := StateDiffIndexer{}
mockService := MockStateDiffService{
BlockChain: mockBlockChain,
Indexer: &mockIndexer,
}
// test data
var (
contract1Address = "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F"
contract2Address = "0x6Eb7e5C66DB8af2E96159AC440cbc8CDB7fbD26B"
contract3Address = "0xcfeB164C328CA13EFd3C77E1980d94975aDfedfc"
contract4Address = "0x0Edf0c4f393a628DE4828B228C48175b3EA297fc"
contract1CreatedAt = uint64(1)
contract2CreatedAt = uint64(2)
contract3CreatedAt = uint64(3)
contract4CreatedAt = uint64(4)
args1 = []sdtypes.WatchAddressArg{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
startingParams1 = statediff.Params{
WatchedAddresses: []common.Address{},
}
expectedParams1 = statediff.Params{
WatchedAddresses: []common.Address{
common.HexToAddress(contract1Address),
common.HexToAddress(contract2Address),
},
}
args2 = []sdtypes.WatchAddressArg{
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
startingParams2 = expectedParams1
expectedParams2 = statediff.Params{
WatchedAddresses: []common.Address{
common.HexToAddress(contract1Address),
common.HexToAddress(contract2Address),
common.HexToAddress(contract3Address),
},
}
args3 = []sdtypes.WatchAddressArg{
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
startingParams3 = expectedParams2
expectedParams3 = statediff.Params{
WatchedAddresses: []common.Address{
common.HexToAddress(contract1Address),
},
}
args4 = []sdtypes.WatchAddressArg{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
startingParams4 = expectedParams3
expectedParams4 = statediff.Params{
WatchedAddresses: []common.Address{},
}
args5 = []sdtypes.WatchAddressArg{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
}
startingParams5 = expectedParams4
expectedParams5 = statediff.Params{
WatchedAddresses: []common.Address{
common.HexToAddress(contract1Address),
common.HexToAddress(contract2Address),
common.HexToAddress(contract3Address),
},
}
args6 = []sdtypes.WatchAddressArg{
{
Address: contract4Address,
CreatedAt: contract4CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
}
startingParams6 = expectedParams5
expectedParams6 = statediff.Params{
WatchedAddresses: []common.Address{
common.HexToAddress(contract4Address),
common.HexToAddress(contract2Address),
common.HexToAddress(contract3Address),
},
}
args7 = []sdtypes.WatchAddressArg{}
startingParams7 = expectedParams6
expectedParams7 = statediff.Params{
WatchedAddresses: []common.Address{},
}
args8 = []sdtypes.WatchAddressArg{}
startingParams8 = expectedParams6
expectedParams8 = statediff.Params{
WatchedAddresses: []common.Address{},
}
args9 = []sdtypes.WatchAddressArg{}
startingParams9 = expectedParams8
expectedParams9 = statediff.Params{
WatchedAddresses: []common.Address{},
}
)
tests := []struct {
name string
operation sdtypes.OperationType
args []sdtypes.WatchAddressArg
startingParams statediff.Params
expectedParams statediff.Params
expectedErr error
}{
{
"testAddAddresses",
sdtypes.Add,
args1,
startingParams1,
expectedParams1,
nil,
},
{
"testAddAddressesSomeWatched",
sdtypes.Add,
args2,
startingParams2,
expectedParams2,
nil,
},
{
"testRemoveAddresses",
sdtypes.Remove,
args3,
startingParams3,
expectedParams3,
nil,
},
{
"testRemoveAddressesSomeWatched",
sdtypes.Remove,
args4,
startingParams4,
expectedParams4,
nil,
},
{
"testSetAddresses",
sdtypes.Set,
args5,
startingParams5,
expectedParams5,
nil,
},
{
"testSetAddressesSomeWatched",
sdtypes.Set,
args6,
startingParams6,
expectedParams6,
nil,
},
{
"testSetAddressesEmtpyArgs",
sdtypes.Set,
args7,
startingParams7,
expectedParams7,
nil,
},
{
"testClearAddresses",
sdtypes.Clear,
args8,
startingParams8,
expectedParams8,
nil,
},
{
"testClearAddressesEmpty",
sdtypes.Clear,
args9,
startingParams9,
expectedParams9,
nil,
},
// invalid args
{
"testInvalidOperation",
"WrongOp",
args9,
startingParams9,
statediff.Params{},
fmt.Errorf("%s WrongOp", unexpectedOperation),
},
}
for _, test := range tests {
// set indexing params
mockService.writeLoopParams = statediff.ParamsWithMutex{
Params: test.startingParams,
}
mockService.writeLoopParams.ComputeWatchedAddressesLeafKeys()
// make the API call to change watched addresses
err := mockService.WatchAddress(test.operation, test.args)
if test.expectedErr != nil {
if err.Error() != test.expectedErr.Error() {
t.Logf("Test failed: %s", test.name)
t.Errorf("actual err: %+v\nexpected err: %+v", err, test.expectedErr)
}
continue
}
if err != nil {
t.Error(err)
}
// check updated indexing params
test.expectedParams.ComputeWatchedAddressesLeafKeys()
updatedParams := mockService.writeLoopParams.Params
if !reflect.DeepEqual(updatedParams, test.expectedParams) {
t.Logf("Test failed: %s", test.name)
t.Errorf("actual params: %+v\nexpected params: %+v", updatedParams, test.expectedParams)
}
}
}