add ability to watch specific storage slots (leaf keys) only

This commit is contained in:
Ian Norden 2020-05-13 14:28:35 -05:00
parent 347a7ba0e8
commit aa7c5b0869
3 changed files with 47 additions and 29 deletions

View File

@ -54,10 +54,10 @@ func (sdb *builder) BuildStateDiff(args Args, params Params) (StateDiff, error)
if !params.IntermediateStateNodes || len(params.WatchedAddresses) > 0 { // if we are watching only specific accounts then we are only diffing leaf nodes
return sdb.buildStateDiffWithoutIntermediateStateNodes(args, params)
}
return sdb.buildStateDiffWithIntermediateStateNodes(args, params.IntermediateStorageNodes)
return sdb.buildStateDiffWithIntermediateStateNodes(args, params)
}
func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args Args, intermediateStorageNodes bool) (StateDiff, error) {
func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args Args, params Params) (StateDiff, error) {
// Generate tries for old and new states
oldTrie, err := sdb.stateCache.OpenTrie(args.OldStateRoot)
if err != nil {
@ -84,11 +84,11 @@ func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args Args, intermed
updatedKeys := findIntersection(createKeys, deleteKeys)
// Build and return the statediff
updatedAccounts, err := sdb.buildAccountUpdates(diffAccountsAtB, diffAccountsAtA, updatedKeys, intermediateStorageNodes)
updatedAccounts, err := sdb.buildAccountUpdates(diffAccountsAtB, diffAccountsAtA, updatedKeys, params.WatchedStorageSlots, params.IntermediateStorageNodes)
if err != nil {
return StateDiff{}, fmt.Errorf("error building diff for updated accounts: %v", err)
}
createdAccounts, err := sdb.buildAccountCreations(diffAccountsAtB, intermediateStorageNodes)
createdAccounts, err := sdb.buildAccountCreations(diffAccountsAtB, params.WatchedStorageSlots, params.IntermediateStorageNodes)
if err != nil {
return StateDiff{}, fmt.Errorf("error building diff for created accounts: %v", err)
}
@ -131,11 +131,11 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args Args, param
updatedKeys := findIntersection(createKeys, deleteKeys)
// Build and return the statediff
updatedAccounts, err := sdb.buildAccountUpdates(diffAccountsAtB, diffAccountsAtA, updatedKeys, params.IntermediateStorageNodes)
updatedAccounts, err := sdb.buildAccountUpdates(diffAccountsAtB, diffAccountsAtA, updatedKeys, params.WatchedStorageSlots, params.IntermediateStorageNodes)
if err != nil {
return StateDiff{}, fmt.Errorf("error building diff for updated accounts: %v", err)
}
createdAccounts, err := sdb.buildAccountCreations(diffAccountsAtB, params.IntermediateStorageNodes)
createdAccounts, err := sdb.buildAccountCreations(diffAccountsAtB, params.WatchedStorageSlots, params.IntermediateStorageNodes)
if err != nil {
return StateDiff{}, fmt.Errorf("error building diff for created accounts: %v", err)
}
@ -151,7 +151,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args Args, param
}, nil
}
func (sdb *builder) collectDiffAccounts(a, b trie.NodeIterator, watchedAddresses []string) (AccountMap, error) {
func (sdb *builder) collectDiffAccounts(a, b trie.NodeIterator, watchedAddresses []common.Address) (AccountMap, error) {
diffAcountsAtB := make(AccountMap)
it, _ := trie.NewDifferenceIterator(a, b)
for it.Next(true) {
@ -207,15 +207,14 @@ func (sdb *builder) collectDiffAccounts(a, b trie.NodeIterator, watchedAddresses
}
// isWatchedAddress is used to check if a state account corresponds to one of the addresses the builder is configured to watch
func isWatchedAddress(watchedAddresses []string, hashKey []byte) bool {
func isWatchedAddress(watchedAddresses []common.Address, stateLeafKey []byte) bool {
// If we aren't watching any specific addresses, we are watching everything
if len(watchedAddresses) == 0 {
return true
}
for _, addrStr := range watchedAddresses {
addr := common.HexToAddress(addrStr)
addrHashKey := crypto.Keccak256(addr[:])
if bytes.Equal(addrHashKey, hashKey) {
for _, addr := range watchedAddresses {
addrHashKey := crypto.Keccak256(addr.Bytes())
if bytes.Equal(addrHashKey, stateLeafKey) {
return true
}
}
@ -353,7 +352,7 @@ func (sdb *builder) deletedOrUpdatedNodes(a, b trie.NodeIterator, diffPathsAtB m
// needs to be called before building account creations and deletions as this mutattes
// those account maps to remove the accounts which were updated
func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updatedKeys []string, intermediateStorageNodes bool) ([]StateNode, error) {
func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updatedKeys []string, watchedStorageKeys []common.Hash, intermediateStorageNodes bool) ([]StateNode, error) {
updatedAccounts := make([]StateNode, 0, len(updatedKeys))
var err error
for _, key := range updatedKeys {
@ -363,7 +362,7 @@ func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updated
if deletedAcc.Account != nil && createdAcc.Account != nil {
oldSR := deletedAcc.Account.Root
newSR := createdAcc.Account.Root
storageDiffs, err = sdb.buildStorageNodesIncremental(oldSR, newSR, intermediateStorageNodes)
storageDiffs, err = sdb.buildStorageNodesIncremental(oldSR, newSR, watchedStorageKeys, intermediateStorageNodes)
if err != nil {
return nil, fmt.Errorf("failed building incremental storage diffs for account with leafkey %s\r\nerror: %v", key, err)
}
@ -382,11 +381,11 @@ func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updated
return updatedAccounts, nil
}
func (sdb *builder) buildAccountCreations(accounts AccountMap, intermediateStorageNodes bool) ([]StateNode, error) {
func (sdb *builder) buildAccountCreations(accounts AccountMap, watchedStorageKeys []common.Hash, intermediateStorageNodes bool) ([]StateNode, error) {
accountDiffs := make([]StateNode, 0, len(accounts))
for _, val := range accounts {
// For account creations, any storage node contained is a diff
storageDiffs, err := sdb.buildStorageNodesEventual(val.Account.Root, intermediateStorageNodes)
storageDiffs, err := sdb.buildStorageNodesEventual(val.Account.Root, watchedStorageKeys, intermediateStorageNodes)
if err != nil {
return nil, fmt.Errorf("failed building eventual storage diffs for node %x\r\nerror: %v", val.Path, err)
}
@ -416,7 +415,7 @@ func (sdb *builder) buildAccountDeletions(accounts AccountMap) ([]StateNode, err
return accountDiffs, nil
}
func (sdb *builder) buildStorageNodesEventual(sr common.Hash, intermediateNodes bool) ([]StorageNode, error) {
func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys []common.Hash, intermediateNodes bool) ([]StorageNode, error) {
log.Debug("Storage Root For Eventual Diff", "root", sr.Hex())
sTrie, err := sdb.stateCache.OpenTrie(sr)
if err != nil {
@ -424,10 +423,10 @@ func (sdb *builder) buildStorageNodesEventual(sr common.Hash, intermediateNodes
return nil, err
}
it := sTrie.NodeIterator(make([]byte, 0))
return sdb.buildStorageNodesFromTrie(it, intermediateNodes)
return sdb.buildStorageNodesFromTrie(it, watchedStorageKeys, intermediateNodes)
}
func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common.Hash, intermediateNodes bool) ([]StorageNode, error) {
func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common.Hash, watchedStorageKeys []common.Hash, intermediateNodes bool) ([]StorageNode, error) {
log.Debug("Storage Roots for Incremental Diff", "old", oldSR.Hex(), "new", newSR.Hex())
oldTrie, err := sdb.stateCache.OpenTrie(oldSR)
if err != nil {
@ -441,10 +440,10 @@ func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common
oldIt := oldTrie.NodeIterator(make([]byte, 0))
newIt := newTrie.NodeIterator(make([]byte, 0))
it, _ := trie.NewDifferenceIterator(oldIt, newIt)
return sdb.buildStorageNodesFromTrie(it, intermediateNodes)
return sdb.buildStorageNodesFromTrie(it, watchedStorageKeys, intermediateNodes)
}
func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, intermediateNodes bool) ([]StorageNode, error) {
func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStorageKeys []common.Hash, intermediateNodes bool) ([]StorageNode, error) {
storageDiffs := make([]StorageNode, 0)
for it.Next(true) {
// skip value nodes
@ -474,12 +473,14 @@ func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, intermediate
valueNodePath := append(nodePath, partialPath...)
encodedPath := trie.HexToCompact(valueNodePath)
leafKey := encodedPath[1:]
storageDiffs = append(storageDiffs, StorageNode{
NodeType: ty,
Path: nodePath,
NodeValue: node,
LeafKey: leafKey,
})
if isWatchedStorageKey(watchedStorageKeys, leafKey) {
storageDiffs = append(storageDiffs, StorageNode{
NodeType: ty,
Path: nodePath,
NodeValue: node,
LeafKey: leafKey,
})
}
case Extension, Branch:
if intermediateNodes {
storageDiffs = append(storageDiffs, StorageNode{
@ -494,3 +495,17 @@ func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, intermediate
}
return storageDiffs, nil
}
// isWatchedStorageKey is used to check if a storage leaf corresponds to one of the storage slots the builder is configured to watch
func isWatchedStorageKey(watchedKeys []common.Hash, storageLeafKey []byte) bool {
// If we aren't watching any specific addresses, we are watching everything
if len(watchedKeys) == 0 {
return true
}
for _, hashKey := range watchedKeys {
if bytes.Equal(hashKey.Bytes(), storageLeafKey) {
return true
}
}
return false
}

View File

@ -751,7 +751,7 @@ func TestBuilderWithWatchedAddressList(t *testing.T) {
block2 = blockMap[BlockHashes[1]]
block3 = blockMap[BlockHashes[0]]
params := statediff.Params{
WatchedAddresses: []string{testhelpers.Account1Addr.Hex(), testhelpers.ContractAddr.Hex()},
WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.ContractAddr},
}
builder = statediff.NewBuilder(chain.StateCache())
@ -911,6 +911,8 @@ func TestBuilderWithWatchedAddressList(t *testing.T) {
}
}
// Write a test that tests when accounts are deleted, or moved to a new path
/*
contract test {

View File

@ -41,7 +41,8 @@ type Params struct {
IncludeBlock bool
IncludeReceipts bool
IncludeTD bool
WatchedAddresses []string
WatchedAddresses []common.Address
WatchedStorageSlots []common.Hash
}
// Args bundles the arguments for the state diff builder