Store pre-computed leaf keys for watched addresses in a map

This commit is contained in:
Prathamesh Musale 2022-02-07 12:26:33 +05:30
parent 6cb4731d8d
commit 3c6aa6a9cc
8 changed files with 108 additions and 61 deletions

View File

@ -202,7 +202,7 @@ func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args StateRoots, pa
// a map of their leafkey to all the accounts that were touched and exist at A
diffAccountsAtA, err := sdb.deletedOrUpdatedState(
oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}),
diffPathsAtB, params.WatchedAddresses, output)
diffPathsAtB, params.watchedAddressesLeafKeys, output)
if err != nil {
return fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err)
}
@ -247,7 +247,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args StateRoots,
// and a slice of all the paths for the nodes in both of the above sets
diffAccountsAtB, diffPathsAtB, err := sdb.createdAndUpdatedState(
oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}),
params.WatchedAddresses)
params.watchedAddressesLeafKeys)
if err != nil {
return fmt.Errorf("error collecting createdAndUpdatedNodes: %v", err)
}
@ -256,7 +256,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args StateRoots,
// a map of their leafkey to all the accounts that were touched and exist at A
diffAccountsAtA, err := sdb.deletedOrUpdatedState(
oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}),
diffPathsAtB, params.WatchedAddresses, output)
diffPathsAtB, params.watchedAddressesLeafKeys, output)
if err != nil {
return fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err)
}
@ -289,7 +289,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args StateRoots,
// createdAndUpdatedState returns
// a mapping of their leafkeys to all the accounts that exist in a different state at B than A
// and a slice of the paths for all of the nodes included in both
func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddresses []common.Address) (AccountMap, map[string]bool, error) {
func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddressesLeafKeys map[common.Hash]struct{}) (AccountMap, map[string]bool, error) {
diffPathsAtB := make(map[string]bool)
diffAcountsAtB := make(AccountMap)
it, _ := trie.NewDifferenceIterator(a, b)
@ -313,7 +313,7 @@ func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddres
valueNodePath := append(node.Path, partialPath...)
encodedPath := trie.HexToCompact(valueNodePath)
leafKey := encodedPath[1:]
if isWatchedAddress(watchedAddresses, leafKey) {
if isWatchedAddress(watchedAddressesLeafKeys, leafKey) {
diffAcountsAtB[common.Bytes2Hex(leafKey)] = accountWrapper{
NodeType: node.NodeType,
Path: node.Path,
@ -386,7 +386,7 @@ func (sdb *builder) createdAndUpdatedStateWithIntermediateNodes(a, b trie.NodeIt
// deletedOrUpdatedState returns a slice of all the pathes that are emptied at B
// and a mapping of their leafkeys to all the accounts that exist in a different state at A than B
func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB map[string]bool, watchedAddresses []common.Address, output StateNodeSink) (AccountMap, error) {
func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB map[string]bool, watchedAddressesLeafKeys map[common.Hash]struct{}, output StateNodeSink) (AccountMap, error) {
diffAccountAtA := make(AccountMap)
it, _ := trie.NewDifferenceIterator(b, a)
for it.Next(true) {
@ -409,7 +409,7 @@ func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB m
valueNodePath := append(node.Path, partialPath...)
encodedPath := trie.HexToCompact(valueNodePath)
leafKey := encodedPath[1:]
if isWatchedAddress(watchedAddresses, leafKey) {
if isWatchedAddress(watchedAddressesLeafKeys, leafKey) {
diffAccountAtA[common.Bytes2Hex(leafKey)] = accountWrapper{
NodeType: node.NodeType,
Path: node.Path,
@ -713,16 +713,12 @@ func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB
}
// isWatchedAddress is used to check if a state account corresponds to one of the addresses the builder is configured to watch
func isWatchedAddress(watchedAddresses []common.Address, stateLeafKey []byte) bool {
func isWatchedAddress(watchedAddressesLeafKeys map[common.Hash]struct{}, stateLeafKey []byte) bool {
// If we aren't watching any specific addresses, we are watching everything
if len(watchedAddresses) == 0 {
if len(watchedAddressesLeafKeys) == 0 {
return true
}
for _, addr := range watchedAddresses {
addrHashKey := crypto.Keccak256(addr.Bytes())
if bytes.Equal(addrHashKey, stateLeafKey) {
return true
}
}
return false
_, ok := watchedAddressesLeafKeys[common.BytesToHash(stateLeafKey)]
return ok
}

View File

@ -987,6 +987,7 @@ func TestBuilderWithWatchedAddressList(t *testing.T) {
params := statediff.Params{
WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.ContractAddr},
}
params.ComputeWatchedAddressesLeafKeys()
builder = statediff.NewBuilder(chain.StateCache())
var tests = []struct {
@ -1566,6 +1567,7 @@ func TestBuilderWithRemovedNonWatchedAccount(t *testing.T) {
params := statediff.Params{
WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.Account2Addr},
}
params.ComputeWatchedAddressesLeafKeys()
builder = statediff.NewBuilder(chain.StateCache())
var tests = []struct {

View File

@ -24,8 +24,11 @@ import (
"sort"
"strings"
"github.com/thoas/go-funk"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/statediff/indexer/postgres"
. "github.com/ethereum/go-ethereum/statediff/types"
)
func sortKeys(data AccountMap) []string {
@ -96,6 +99,19 @@ func loadWatchedAddresses(db *postgres.DB) error {
writeLoopParams.Lock()
defer writeLoopParams.Unlock()
writeLoopParams.WatchedAddresses = watchedAddresses
writeLoopParams.ComputeWatchedAddressesLeafKeys()
return nil
}
// MapWatchAddressArgsToAddresses maps []WatchAddressArg to corresponding []common.Address
func MapWatchAddressArgsToAddresses(args []WatchAddressArg) ([]common.Address, error) {
addresses, ok := funk.Map(args, func(arg WatchAddressArg) common.Address {
return common.HexToAddress(arg.Address)
}).([]common.Address)
if !ok {
return nil, fmt.Errorf(typeAssertionFailed)
}
return addresses, nil
}

View File

@ -423,6 +423,9 @@ func (sds *Service) streamStateDiff(currentBlock *types.Block, parentRoot common
func (sds *Service) StateDiffAt(blockNumber uint64, params Params) (*Payload, error) {
currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber)
log.Info("sending state diff", "block height", blockNumber)
params.ComputeWatchedAddressesLeafKeys()
if blockNumber == 0 {
return sds.processStateDiff(currentBlock, common.Hash{}, params)
}
@ -435,6 +438,9 @@ func (sds *Service) StateDiffAt(blockNumber uint64, params Params) (*Payload, er
func (sds *Service) StateDiffFor(blockHash common.Hash, params Params) (*Payload, error) {
currentBlock := sds.BlockChain.GetBlockByHash(blockHash)
log.Info("sending state diff", "block hash", blockHash)
params.ComputeWatchedAddressesLeafKeys()
if currentBlock.NumberU64() == 0 {
return sds.processStateDiff(currentBlock, common.Hash{}, params)
}
@ -493,6 +499,9 @@ func (sds *Service) newPayload(stateObject []byte, block *types.Block, params Pa
func (sds *Service) StateTrieAt(blockNumber uint64, params Params) (*Payload, error) {
currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber)
log.Info("sending state trie", "block height", blockNumber)
params.ComputeWatchedAddressesLeafKeys()
return sds.processStateTrie(currentBlock, params)
}
@ -515,6 +524,9 @@ func (sds *Service) Subscribe(id rpc.ID, sub chan<- Payload, quitChan chan<- boo
if atomic.CompareAndSwapInt32(&sds.subscribers, 0, 1) {
log.Info("State diffing subscription received; beginning statediff processing")
}
params.ComputeWatchedAddressesLeafKeys()
// Subscription type is defined as the hash of the rlp-serialized subscription params
by, err := rlp.EncodeToBytes(params)
if err != nil {
@ -661,6 +673,8 @@ func (sds *Service) StreamCodeAndCodeHash(blockNumber uint64, outChan chan<- Cod
// This operation cannot be performed back past the point of db pruning; it requires an archival node
// for historical data
func (sds *Service) WriteStateDiffAt(blockNumber uint64, params Params) error {
params.ComputeWatchedAddressesLeafKeys()
currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber)
parentRoot := common.Hash{}
if blockNumber != 0 {
@ -674,6 +688,8 @@ func (sds *Service) WriteStateDiffAt(blockNumber uint64, params Params) error {
// This operation cannot be performed back past the point of db pruning; it requires an archival node
// for historical data
func (sds *Service) WriteStateDiffFor(blockHash common.Hash, params Params) error {
params.ComputeWatchedAddressesLeafKeys()
currentBlock := sds.BlockChain.GetBlockByHash(blockHash)
parentRoot := common.Hash{}
if currentBlock.NumberU64() != 0 {
@ -736,7 +752,7 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo
}
// Performs one of following operations on the watched addresses in writeLoopParams and the db:
// Add | Remove | Set | Clear
// add | remove | set | clear
func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error {
// lock writeLoopParams for a write
writeLoopParams.Lock()
@ -756,65 +772,66 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg
return true
}).([]WatchAddressArg)
if !ok {
return fmt.Errorf("Add: filtered args %s", typeAssertionFailed)
return fmt.Errorf("add: filtered args %s", typeAssertionFailed)
}
// get addresses from the filtered args
filteredAddresses, ok := funk.Map(filteredArgs, func(arg WatchAddressArg) common.Address {
return common.HexToAddress(arg.Address)
}).([]common.Address)
if !ok {
return fmt.Errorf("Add: filtered addresses %s", typeAssertionFailed)
filteredAddresses, err := MapWatchAddressArgsToAddresses(filteredArgs)
if err != nil {
return fmt.Errorf("add: filtered addresses %s", err.Error())
}
// update the db
err := sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber)
err = sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber)
if err != nil {
return err
}
// update in-memory params
writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...)
funk.ForEach(filteredAddresses, func(address common.Address) {
writeLoopParams.watchedAddressesLeafKeys[crypto.Keccak256Hash(address.Bytes())] = struct{}{}
})
case Remove:
// get addresses from args
argAddresses, ok := funk.Map(args, func(arg WatchAddressArg) common.Address {
return common.HexToAddress(arg.Address)
}).([]common.Address)
if !ok {
return fmt.Errorf("Remove: mapped addresses %s", typeAssertionFailed)
argAddresses, err := MapWatchAddressArgsToAddresses(args)
if err != nil {
return fmt.Errorf("remove: mapped addresses %s", err.Error())
}
// remove the provided addresses from currently watched addresses
addresses, ok := funk.Subtract(writeLoopParams.WatchedAddresses, argAddresses).([]common.Address)
if !ok {
return fmt.Errorf("Remove: filtered addresses %s", typeAssertionFailed)
return fmt.Errorf("remove: filtered addresses %s", typeAssertionFailed)
}
// update the db
err := sds.indexer.RemoveWatchedAddresses(args)
err = sds.indexer.RemoveWatchedAddresses(args)
if err != nil {
return err
}
// update in-memory params
writeLoopParams.WatchedAddresses = addresses
funk.ForEach(argAddresses, func(address common.Address) {
delete(writeLoopParams.watchedAddressesLeafKeys, crypto.Keccak256Hash(address.Bytes()))
})
case Set:
// get addresses from args
argAddresses, ok := funk.Map(args, func(arg WatchAddressArg) common.Address {
return common.HexToAddress(arg.Address)
}).([]common.Address)
if !ok {
return fmt.Errorf("Set: mapped addresses %s", typeAssertionFailed)
argAddresses, err := MapWatchAddressArgsToAddresses(args)
if err != nil {
return fmt.Errorf("set: mapped addresses %s", err.Error())
}
// update the db
err := sds.indexer.SetWatchedAddresses(args, currentBlockNumber)
err = sds.indexer.SetWatchedAddresses(args, currentBlockNumber)
if err != nil {
return err
}
// update in-memory params
writeLoopParams.WatchedAddresses = argAddresses
writeLoopParams.ComputeWatchedAddressesLeafKeys()
case Clear:
// update the db
err := sds.indexer.ClearWatchedAddresses()
@ -824,6 +841,7 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg
// update in-memory params
writeLoopParams.WatchedAddresses = []common.Address{}
writeLoopParams.ComputeWatchedAddressesLeafKeys()
default:
return fmt.Errorf("%s %s", unexpectedOperation, operation)

View File

@ -144,6 +144,7 @@ func testErrorInChainEventLoop(t *testing.T) {
}
}
defaultParams.ComputeWatchedAddressesLeafKeys()
if !reflect.DeepEqual(builder.Params, defaultParams) {
t.Error("Test failure:", t.Name())
t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams)
@ -195,6 +196,8 @@ func testErrorInBlockLoop(t *testing.T) {
}
}()
service.Loop(eventsChannel)
defaultParams.ComputeWatchedAddressesLeafKeys()
if !reflect.DeepEqual(builder.Params, defaultParams) {
t.Error("Test failure:", t.Name())
t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams)
@ -268,6 +271,8 @@ func testErrorInStateDiffAt(t *testing.T) {
if err != nil {
t.Error(err)
}
defaultParams.ComputeWatchedAddressesLeafKeys()
if !reflect.DeepEqual(builder.Params, defaultParams) {
t.Error("Test failure:", t.Name())
t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams)

View File

@ -362,65 +362,62 @@ func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType,
return true
}).([]sdtypes.WatchAddressArg)
if !ok {
return fmt.Errorf("Add: filtered args %s", typeAssertionFailed)
return fmt.Errorf("add: filtered args %s", typeAssertionFailed)
}
// get addresses from the filtered args
filteredAddresses, ok := funk.Map(filteredArgs, func(arg sdtypes.WatchAddressArg) common.Address {
return common.HexToAddress(arg.Address)
}).([]common.Address)
if !ok {
return fmt.Errorf("Add: filtered addresses %s", typeAssertionFailed)
filteredAddresses, err := statediff.MapWatchAddressArgsToAddresses(filteredArgs)
if err != nil {
return fmt.Errorf("add: filtered addresses %s", err.Error())
}
// update the db
err := sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber)
err = sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber)
if err != nil {
return err
}
// update in-memory params
sds.writeLoopParams.WatchedAddresses = append(sds.writeLoopParams.WatchedAddresses, filteredAddresses...)
sds.writeLoopParams.ComputeWatchedAddressesLeafKeys()
case statediff.Remove:
// get addresses from args
argAddresses, ok := funk.Map(args, func(arg sdtypes.WatchAddressArg) common.Address {
return common.HexToAddress(arg.Address)
}).([]common.Address)
if !ok {
return fmt.Errorf("Remove: mapped addresses %s", typeAssertionFailed)
argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args)
if err != nil {
return fmt.Errorf("remove: mapped addresses %s", err.Error())
}
// remove the provided addresses from currently watched addresses
addresses, ok := funk.Subtract(sds.writeLoopParams.WatchedAddresses, argAddresses).([]common.Address)
if !ok {
return fmt.Errorf("Remove: filtered addresses %s", typeAssertionFailed)
return fmt.Errorf("remove: filtered addresses %s", typeAssertionFailed)
}
// update the db
err := sds.Indexer.RemoveWatchedAddresses(args)
err = sds.Indexer.RemoveWatchedAddresses(args)
if err != nil {
return err
}
// update in-memory params
sds.writeLoopParams.WatchedAddresses = addresses
sds.writeLoopParams.ComputeWatchedAddressesLeafKeys()
case statediff.Set:
// get addresses from args
argAddresses, ok := funk.Map(args, func(arg sdtypes.WatchAddressArg) common.Address {
return common.HexToAddress(arg.Address)
}).([]common.Address)
if !ok {
return fmt.Errorf("Set: mapped addresses %s", typeAssertionFailed)
argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args)
if err != nil {
return fmt.Errorf("set: mapped addresses %s", err.Error())
}
// update the db
err := sds.Indexer.SetWatchedAddresses(args, currentBlockNumber)
err = sds.Indexer.SetWatchedAddresses(args, currentBlockNumber)
if err != nil {
return err
}
// update in-memory params
sds.writeLoopParams.WatchedAddresses = argAddresses
sds.writeLoopParams.ComputeWatchedAddressesLeafKeys()
case statediff.Clear:
// update the db
err := sds.Indexer.ClearWatchedAddresses()
@ -430,6 +427,7 @@ func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType,
// update in-memory params
sds.writeLoopParams.WatchedAddresses = []common.Address{}
sds.writeLoopParams.ComputeWatchedAddressesLeafKeys()
default:
return fmt.Errorf("%s %s", unexpectedOperation, operation)

View File

@ -506,6 +506,7 @@ func testWatchAddressAPI(t *testing.T) {
mockService.writeLoopParams = statediff.ParamsWithMutex{
Params: test.startingParams,
}
mockService.writeLoopParams.ComputeWatchedAddressesLeafKeys()
// make the API call to change watched addresses
err := mockService.WatchAddress(test.operation, test.args)
@ -522,6 +523,7 @@ func testWatchAddressAPI(t *testing.T) {
}
// check updated indexing params
test.expectedParams.ComputeWatchedAddressesLeafKeys()
updatedParams := mockService.writeLoopParams.Params
if !reflect.DeepEqual(updatedParams, test.expectedParams) {
t.Logf("Test failed: %s", test.name)

View File

@ -26,6 +26,7 @@ import (
"github.com/ethereum/go-ethereum/common"
ctypes "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/statediff/types"
)
@ -51,6 +52,15 @@ type Params struct {
IncludeTD bool
IncludeCode bool
WatchedAddresses []common.Address
watchedAddressesLeafKeys map[common.Hash]struct{}
}
// ComputeWatchedAddressesLeafKeys populates a map with keys (Keccak256Hash) of each of the WatchedAddresses
func (p *Params) ComputeWatchedAddressesLeafKeys() {
p.watchedAddressesLeafKeys = make(map[common.Hash]struct{}, len(p.WatchedAddresses))
for _, address := range p.WatchedAddresses {
p.watchedAddressesLeafKeys[crypto.Keccak256Hash(address.Bytes())] = struct{}{}
}
}
// ParamsWithMutex allows to lock the parameters while they are being updated | read from
@ -122,8 +132,8 @@ type accountWrapper struct {
type OperationType string
const (
Add OperationType = "Add"
Remove OperationType = "Remove"
Set OperationType = "Set"
Clear OperationType = "Clear"
Add OperationType = "add"
Remove OperationType = "remove"
Set OperationType = "set"
Clear OperationType = "clear"
)