From 7f38afe542a7831ac07ff94db5c5facf7f825088 Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Wed, 16 Mar 2022 19:25:10 +0530 Subject: [PATCH] Add tests for the API to change addresses being watched --- statediff/service.go | 10 +- statediff/service_test.go | 5 + statediff/test_helpers/mocks/blockchain.go | 11 + statediff/test_helpers/mocks/indexer.go | 70 +++++ statediff/test_helpers/mocks/service.go | 104 +++++++ statediff/test_helpers/mocks/service_test.go | 285 +++++++++++++++++++ 6 files changed, 480 insertions(+), 5 deletions(-) create mode 100644 statediff/test_helpers/mocks/indexer.go diff --git a/statediff/service.go b/statediff/service.go index ce36c86d4..2a6c3d45c 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -889,7 +889,7 @@ func (sds *Service) WatchAddress(operation types2.OperationType, args []types2.W } // get addresses from the filtered args - filteredAddresses, err := mapWatchAddressArgsToAddresses(filteredArgs) + filteredAddresses, err := MapWatchAddressArgsToAddresses(filteredArgs) if err != nil { return fmt.Errorf("add: filtered addresses %s", err.Error()) } @@ -907,7 +907,7 @@ func (sds *Service) WatchAddress(operation types2.OperationType, args []types2.W }) case types2.Remove: // get addresses from args - argAddresses, err := mapWatchAddressArgsToAddresses(args) + argAddresses, err := MapWatchAddressArgsToAddresses(args) if err != nil { return fmt.Errorf("remove: mapped addresses %s", err.Error()) } @@ -931,7 +931,7 @@ func (sds *Service) WatchAddress(operation types2.OperationType, args []types2.W }) case types2.Set: // get addresses from args - argAddresses, err := mapWatchAddressArgsToAddresses(args) + argAddresses, err := MapWatchAddressArgsToAddresses(args) if err != nil { return fmt.Errorf("set: mapped addresses %s", err.Error()) } @@ -979,8 +979,8 @@ func loadWatchedAddresses(indexer interfaces.StateDiffIndexer) error { return nil } -// mapWatchAddressArgsToAddresses maps []WatchAddressArg to corresponding []common.Address -func mapWatchAddressArgsToAddresses(args []types2.WatchAddressArg) ([]common.Address, error) { +// MapWatchAddressArgsToAddresses maps []WatchAddressArg to corresponding []common.Address +func MapWatchAddressArgsToAddresses(args []types2.WatchAddressArg) ([]common.Address, error) { addresses, ok := funk.Map(args, func(arg types2.WatchAddressArg) common.Address { return common.HexToAddress(arg.Address) }).([]common.Address) diff --git a/statediff/service_test.go b/statediff/service_test.go index 96be2da1b..987e1b467 100644 --- a/statediff/service_test.go +++ b/statediff/service_test.go @@ -146,6 +146,7 @@ func testErrorInChainEventLoop(t *testing.T) { } } + defaultParams.ComputeWatchedAddressesLeafKeys() if !reflect.DeepEqual(builder.Params, defaultParams) { t.Error("Test failure:", t.Name()) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) @@ -197,6 +198,8 @@ func testErrorInBlockLoop(t *testing.T) { } }() service.Loop(eventsChannel) + + defaultParams.ComputeWatchedAddressesLeafKeys() if !reflect.DeepEqual(builder.Params, defaultParams) { t.Error("Test failure:", t.Name()) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) @@ -270,6 +273,8 @@ func testErrorInStateDiffAt(t *testing.T) { if err != nil { t.Error(err) } + + defaultParams.ComputeWatchedAddressesLeafKeys() if !reflect.DeepEqual(builder.Params, defaultParams) { t.Error("Test failure:", t.Name()) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) diff --git a/statediff/test_helpers/mocks/blockchain.go b/statediff/test_helpers/mocks/blockchain.go index b4b1f3694..f2a77af64 100644 --- a/statediff/test_helpers/mocks/blockchain.go +++ b/statediff/test_helpers/mocks/blockchain.go @@ -39,6 +39,7 @@ type BlockChain struct { Receipts map[common.Hash]types.Receipts TDByHash map[common.Hash]*big.Int TDByNum map[uint64]*big.Int + currentBlock *types.Block } // SetBlocksForHashes mock method @@ -128,6 +129,16 @@ func (bc *BlockChain) GetTd(hash common.Hash, blockNum uint64) *big.Int { return nil } +// SetCurrentBlock test method +func (bc *BlockChain) SetCurrentBlock(block *types.Block) { + bc.currentBlock = block +} + +// CurrentBlock mock method +func (bc *BlockChain) CurrentBlock() *types.Block { + return bc.currentBlock +} + func (bc *BlockChain) SetTd(hash common.Hash, blockNum uint64, td *big.Int) { if bc.TDByHash == nil { bc.TDByHash = make(map[common.Hash]*big.Int) diff --git a/statediff/test_helpers/mocks/indexer.go b/statediff/test_helpers/mocks/indexer.go new file mode 100644 index 000000000..92005a8b4 --- /dev/null +++ b/statediff/test_helpers/mocks/indexer.go @@ -0,0 +1,70 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package mocks + +import ( + "math/big" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/statediff/indexer/interfaces" + sdtypes "github.com/ethereum/go-ethereum/statediff/types" +) + +var _ interfaces.StateDiffIndexer = &StateDiffIndexer{} + +// StateDiffIndexer is a mock state diff indexer +type StateDiffIndexer struct{} + +func (sdi *StateDiffIndexer) PushBlock(block *types.Block, receipts types.Receipts, totalDifficulty *big.Int) (interfaces.Batch, error) { + return nil, nil +} + +func (sdi *StateDiffIndexer) PushStateNode(tx interfaces.Batch, stateNode sdtypes.StateNode, headerID string) error { + return nil +} + +func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx interfaces.Batch, codeAndCodeHash sdtypes.CodeAndCodeHash) error { + return nil +} + +func (sdi *StateDiffIndexer) ReportDBMetrics(delay time.Duration, quit <-chan bool) {} + +func (sdi *StateDiffIndexer) LoadWatchedAddresses() ([]common.Address, error) { + return nil, nil +} + +func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error { + return nil +} + +func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []sdtypes.WatchAddressArg) error { + return nil +} + +func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { + return nil +} + +func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { + return nil +} + +func (sdi *StateDiffIndexer) Close() error { + return nil +} diff --git a/statediff/test_helpers/mocks/service.go b/statediff/test_helpers/mocks/service.go index f10017df4..1ff6857dd 100644 --- a/statediff/test_helpers/mocks/service.go +++ b/statediff/test_helpers/mocks/service.go @@ -25,6 +25,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" + "github.com/thoas/go-funk" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" @@ -32,9 +33,15 @@ import ( "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/statediff" + "github.com/ethereum/go-ethereum/statediff/indexer/interfaces" sdtypes "github.com/ethereum/go-ethereum/statediff/types" ) +var ( + typeAssertionFailed = "type assertion failed" + unexpectedOperation = "unexpected operation" +) + // MockStateDiffService is a mock state diff service type MockStateDiffService struct { sync.Mutex @@ -47,6 +54,8 @@ type MockStateDiffService struct { QuitChan chan bool Subscriptions map[common.Hash]map[rpc.ID]statediff.Subscription SubscriptionTypes map[common.Hash]statediff.Params + Indexer interfaces.StateDiffIndexer + writeLoopParams statediff.ParamsWithMutex } // Protocols mock method @@ -332,3 +341,98 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) { log.Info("unable to close subscription %s; channel has no receiver", id) } } + +// Performs one of following operations on the watched addresses in writeLoopParams and the db: +// add | remove | set | clear +func (sds *MockStateDiffService) WatchAddress(operation sdtypes.OperationType, args []sdtypes.WatchAddressArg) error { + // lock writeLoopParams for a write + sds.writeLoopParams.Lock() + defer sds.writeLoopParams.Unlock() + + // get the current block number + currentBlockNumber := sds.BlockChain.CurrentBlock().Number() + + switch operation { + case sdtypes.Add: + // filter out args having an already watched address with a warning + filteredArgs, ok := funk.Filter(args, func(arg sdtypes.WatchAddressArg) bool { + if funk.Contains(sds.writeLoopParams.WatchedAddresses, common.HexToAddress(arg.Address)) { + log.Warn("Address already being watched", "address", arg.Address) + return false + } + return true + }).([]sdtypes.WatchAddressArg) + if !ok { + return fmt.Errorf("add: filtered args %s", typeAssertionFailed) + } + + // get addresses from the filtered args + filteredAddresses, err := statediff.MapWatchAddressArgsToAddresses(filteredArgs) + if err != nil { + return fmt.Errorf("add: filtered addresses %s", err.Error()) + } + + // update the db + err = sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedAddresses = append(sds.writeLoopParams.WatchedAddresses, filteredAddresses...) + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() + case sdtypes.Remove: + // get addresses from args + argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("remove: mapped addresses %s", err.Error()) + } + + // remove the provided addresses from currently watched addresses + addresses, ok := funk.Subtract(sds.writeLoopParams.WatchedAddresses, argAddresses).([]common.Address) + if !ok { + return fmt.Errorf("remove: filtered addresses %s", typeAssertionFailed) + } + + // update the db + err = sds.Indexer.RemoveWatchedAddresses(args) + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedAddresses = addresses + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() + case sdtypes.Set: + // get addresses from args + argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("set: mapped addresses %s", err.Error()) + } + + // update the db + err = sds.Indexer.SetWatchedAddresses(args, currentBlockNumber) + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedAddresses = argAddresses + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() + case sdtypes.Clear: + // update the db + err := sds.Indexer.ClearWatchedAddresses() + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedAddresses = []common.Address{} + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() + + default: + return fmt.Errorf("%s %s", unexpectedOperation, operation) + } + + return nil +} diff --git a/statediff/test_helpers/mocks/service_test.go b/statediff/test_helpers/mocks/service_test.go index c236e1fd1..a638f1424 100644 --- a/statediff/test_helpers/mocks/service_test.go +++ b/statediff/test_helpers/mocks/service_test.go @@ -21,6 +21,7 @@ import ( "fmt" "math/big" "os" + "reflect" "sort" "sync" "testing" @@ -88,6 +89,7 @@ func init() { func TestAPI(t *testing.T) { testSubscriptionAPI(t) testHTTPAPI(t) + testWatchAddressAPI(t) } func testSubscriptionAPI(t *testing.T) { @@ -253,3 +255,286 @@ func testHTTPAPI(t *testing.T) { t.Errorf("paylaod does not have the expected total difficulty\r\nactual td: %d\r\nexpected td: %d", payload.TotalDifficulty.Int64(), mockTotalDifficulty.Int64()) } } + +func testWatchAddressAPI(t *testing.T) { + blocks, chain := test_helpers.MakeChain(6, test_helpers.Genesis, test_helpers.TestChainGen) + defer chain.Stop() + block6 := blocks[5] + + mockBlockChain := &BlockChain{} + mockBlockChain.SetCurrentBlock(block6) + mockIndexer := StateDiffIndexer{} + mockService := MockStateDiffService{ + BlockChain: mockBlockChain, + Indexer: &mockIndexer, + } + + // test data + var ( + contract1Address = "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F" + contract2Address = "0x6Eb7e5C66DB8af2E96159AC440cbc8CDB7fbD26B" + contract3Address = "0xcfeB164C328CA13EFd3C77E1980d94975aDfedfc" + contract4Address = "0x0Edf0c4f393a628DE4828B228C48175b3EA297fc" + contract1CreatedAt = uint64(1) + contract2CreatedAt = uint64(2) + contract3CreatedAt = uint64(3) + contract4CreatedAt = uint64(4) + + args1 = []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + startingParams1 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + expectedParams1 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract1Address), + common.HexToAddress(contract2Address), + }, + } + + args2 = []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + startingParams2 = expectedParams1 + expectedParams2 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract1Address), + common.HexToAddress(contract2Address), + common.HexToAddress(contract3Address), + }, + } + + args3 = []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + startingParams3 = expectedParams2 + expectedParams3 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract1Address), + }, + } + + args4 = []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + startingParams4 = expectedParams3 + expectedParams4 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + + args5 = []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + startingParams5 = expectedParams4 + expectedParams5 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract1Address), + common.HexToAddress(contract2Address), + common.HexToAddress(contract3Address), + }, + } + + args6 = []sdtypes.WatchAddressArg{ + { + Address: contract4Address, + CreatedAt: contract4CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + startingParams6 = expectedParams5 + expectedParams6 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract4Address), + common.HexToAddress(contract2Address), + common.HexToAddress(contract3Address), + }, + } + + args7 = []sdtypes.WatchAddressArg{} + startingParams7 = expectedParams6 + expectedParams7 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + + args8 = []sdtypes.WatchAddressArg{} + startingParams8 = expectedParams6 + expectedParams8 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + + args9 = []sdtypes.WatchAddressArg{} + startingParams9 = expectedParams8 + expectedParams9 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + ) + + tests := []struct { + name string + operation sdtypes.OperationType + args []sdtypes.WatchAddressArg + startingParams statediff.Params + expectedParams statediff.Params + expectedErr error + }{ + { + "testAddAddresses", + sdtypes.Add, + args1, + startingParams1, + expectedParams1, + nil, + }, + { + "testAddAddressesSomeWatched", + sdtypes.Add, + args2, + startingParams2, + expectedParams2, + nil, + }, + { + "testRemoveAddresses", + sdtypes.Remove, + args3, + startingParams3, + expectedParams3, + nil, + }, + { + "testRemoveAddressesSomeWatched", + sdtypes.Remove, + args4, + startingParams4, + expectedParams4, + nil, + }, + { + "testSetAddresses", + sdtypes.Set, + args5, + startingParams5, + expectedParams5, + nil, + }, + { + "testSetAddressesSomeWatched", + sdtypes.Set, + args6, + startingParams6, + expectedParams6, + nil, + }, + { + "testSetAddressesEmtpyArgs", + sdtypes.Set, + args7, + startingParams7, + expectedParams7, + nil, + }, + { + "testClearAddresses", + sdtypes.Clear, + args8, + startingParams8, + expectedParams8, + nil, + }, + { + "testClearAddressesEmpty", + sdtypes.Clear, + args9, + startingParams9, + expectedParams9, + nil, + }, + + // invalid args + { + "testInvalidOperation", + "WrongOp", + args9, + startingParams9, + statediff.Params{}, + fmt.Errorf("%s WrongOp", unexpectedOperation), + }, + } + + for _, test := range tests { + // set indexing params + mockService.writeLoopParams = statediff.ParamsWithMutex{ + Params: test.startingParams, + } + mockService.writeLoopParams.ComputeWatchedAddressesLeafKeys() + + // make the API call to change watched addresses + err := mockService.WatchAddress(test.operation, test.args) + if test.expectedErr != nil { + if err.Error() != test.expectedErr.Error() { + t.Logf("Test failed: %s", test.name) + t.Errorf("actual err: %+v\nexpected err: %+v", err, test.expectedErr) + } + + continue + } + if err != nil { + t.Error(err) + } + + // check updated indexing params + test.expectedParams.ComputeWatchedAddressesLeafKeys() + updatedParams := mockService.writeLoopParams.Params + if !reflect.DeepEqual(updatedParams, test.expectedParams) { + t.Logf("Test failed: %s", test.name) + t.Errorf("actual params: %+v\nexpected params: %+v", updatedParams, test.expectedParams) + } + } +}