add ability to watch specific storage slots (leaf keys) only

This commit is contained in:
Ian Norden 2020-05-13 14:28:35 -05:00
parent 347a7ba0e8
commit aa7c5b0869
3 changed files with 47 additions and 29 deletions

View File

@ -54,10 +54,10 @@ func (sdb *builder) BuildStateDiff(args Args, params Params) (StateDiff, error)
if !params.IntermediateStateNodes || len(params.WatchedAddresses) > 0 { // if we are watching only specific accounts then we are only diffing leaf nodes if !params.IntermediateStateNodes || len(params.WatchedAddresses) > 0 { // if we are watching only specific accounts then we are only diffing leaf nodes
return sdb.buildStateDiffWithoutIntermediateStateNodes(args, params) return sdb.buildStateDiffWithoutIntermediateStateNodes(args, params)
} }
return sdb.buildStateDiffWithIntermediateStateNodes(args, params.IntermediateStorageNodes) return sdb.buildStateDiffWithIntermediateStateNodes(args, params)
} }
func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args Args, intermediateStorageNodes bool) (StateDiff, error) { func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args Args, params Params) (StateDiff, error) {
// Generate tries for old and new states // Generate tries for old and new states
oldTrie, err := sdb.stateCache.OpenTrie(args.OldStateRoot) oldTrie, err := sdb.stateCache.OpenTrie(args.OldStateRoot)
if err != nil { if err != nil {
@ -84,11 +84,11 @@ func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args Args, intermed
updatedKeys := findIntersection(createKeys, deleteKeys) updatedKeys := findIntersection(createKeys, deleteKeys)
// Build and return the statediff // Build and return the statediff
updatedAccounts, err := sdb.buildAccountUpdates(diffAccountsAtB, diffAccountsAtA, updatedKeys, intermediateStorageNodes) updatedAccounts, err := sdb.buildAccountUpdates(diffAccountsAtB, diffAccountsAtA, updatedKeys, params.WatchedStorageSlots, params.IntermediateStorageNodes)
if err != nil { if err != nil {
return StateDiff{}, fmt.Errorf("error building diff for updated accounts: %v", err) return StateDiff{}, fmt.Errorf("error building diff for updated accounts: %v", err)
} }
createdAccounts, err := sdb.buildAccountCreations(diffAccountsAtB, intermediateStorageNodes) createdAccounts, err := sdb.buildAccountCreations(diffAccountsAtB, params.WatchedStorageSlots, params.IntermediateStorageNodes)
if err != nil { if err != nil {
return StateDiff{}, fmt.Errorf("error building diff for created accounts: %v", err) return StateDiff{}, fmt.Errorf("error building diff for created accounts: %v", err)
} }
@ -131,11 +131,11 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args Args, param
updatedKeys := findIntersection(createKeys, deleteKeys) updatedKeys := findIntersection(createKeys, deleteKeys)
// Build and return the statediff // Build and return the statediff
updatedAccounts, err := sdb.buildAccountUpdates(diffAccountsAtB, diffAccountsAtA, updatedKeys, params.IntermediateStorageNodes) updatedAccounts, err := sdb.buildAccountUpdates(diffAccountsAtB, diffAccountsAtA, updatedKeys, params.WatchedStorageSlots, params.IntermediateStorageNodes)
if err != nil { if err != nil {
return StateDiff{}, fmt.Errorf("error building diff for updated accounts: %v", err) return StateDiff{}, fmt.Errorf("error building diff for updated accounts: %v", err)
} }
createdAccounts, err := sdb.buildAccountCreations(diffAccountsAtB, params.IntermediateStorageNodes) createdAccounts, err := sdb.buildAccountCreations(diffAccountsAtB, params.WatchedStorageSlots, params.IntermediateStorageNodes)
if err != nil { if err != nil {
return StateDiff{}, fmt.Errorf("error building diff for created accounts: %v", err) return StateDiff{}, fmt.Errorf("error building diff for created accounts: %v", err)
} }
@ -151,7 +151,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args Args, param
}, nil }, nil
} }
func (sdb *builder) collectDiffAccounts(a, b trie.NodeIterator, watchedAddresses []string) (AccountMap, error) { func (sdb *builder) collectDiffAccounts(a, b trie.NodeIterator, watchedAddresses []common.Address) (AccountMap, error) {
diffAcountsAtB := make(AccountMap) diffAcountsAtB := make(AccountMap)
it, _ := trie.NewDifferenceIterator(a, b) it, _ := trie.NewDifferenceIterator(a, b)
for it.Next(true) { for it.Next(true) {
@ -207,15 +207,14 @@ func (sdb *builder) collectDiffAccounts(a, b trie.NodeIterator, watchedAddresses
} }
// isWatchedAddress is used to check if a state account corresponds to one of the addresses the builder is configured to watch // isWatchedAddress is used to check if a state account corresponds to one of the addresses the builder is configured to watch
func isWatchedAddress(watchedAddresses []string, hashKey []byte) bool { func isWatchedAddress(watchedAddresses []common.Address, stateLeafKey []byte) bool {
// If we aren't watching any specific addresses, we are watching everything // If we aren't watching any specific addresses, we are watching everything
if len(watchedAddresses) == 0 { if len(watchedAddresses) == 0 {
return true return true
} }
for _, addrStr := range watchedAddresses { for _, addr := range watchedAddresses {
addr := common.HexToAddress(addrStr) addrHashKey := crypto.Keccak256(addr.Bytes())
addrHashKey := crypto.Keccak256(addr[:]) if bytes.Equal(addrHashKey, stateLeafKey) {
if bytes.Equal(addrHashKey, hashKey) {
return true return true
} }
} }
@ -353,7 +352,7 @@ func (sdb *builder) deletedOrUpdatedNodes(a, b trie.NodeIterator, diffPathsAtB m
// needs to be called before building account creations and deletions as this mutattes // needs to be called before building account creations and deletions as this mutattes
// those account maps to remove the accounts which were updated // those account maps to remove the accounts which were updated
func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updatedKeys []string, intermediateStorageNodes bool) ([]StateNode, error) { func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updatedKeys []string, watchedStorageKeys []common.Hash, intermediateStorageNodes bool) ([]StateNode, error) {
updatedAccounts := make([]StateNode, 0, len(updatedKeys)) updatedAccounts := make([]StateNode, 0, len(updatedKeys))
var err error var err error
for _, key := range updatedKeys { for _, key := range updatedKeys {
@ -363,7 +362,7 @@ func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updated
if deletedAcc.Account != nil && createdAcc.Account != nil { if deletedAcc.Account != nil && createdAcc.Account != nil {
oldSR := deletedAcc.Account.Root oldSR := deletedAcc.Account.Root
newSR := createdAcc.Account.Root newSR := createdAcc.Account.Root
storageDiffs, err = sdb.buildStorageNodesIncremental(oldSR, newSR, intermediateStorageNodes) storageDiffs, err = sdb.buildStorageNodesIncremental(oldSR, newSR, watchedStorageKeys, intermediateStorageNodes)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed building incremental storage diffs for account with leafkey %s\r\nerror: %v", key, err) return nil, fmt.Errorf("failed building incremental storage diffs for account with leafkey %s\r\nerror: %v", key, err)
} }
@ -382,11 +381,11 @@ func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updated
return updatedAccounts, nil return updatedAccounts, nil
} }
func (sdb *builder) buildAccountCreations(accounts AccountMap, intermediateStorageNodes bool) ([]StateNode, error) { func (sdb *builder) buildAccountCreations(accounts AccountMap, watchedStorageKeys []common.Hash, intermediateStorageNodes bool) ([]StateNode, error) {
accountDiffs := make([]StateNode, 0, len(accounts)) accountDiffs := make([]StateNode, 0, len(accounts))
for _, val := range accounts { for _, val := range accounts {
// For account creations, any storage node contained is a diff // For account creations, any storage node contained is a diff
storageDiffs, err := sdb.buildStorageNodesEventual(val.Account.Root, intermediateStorageNodes) storageDiffs, err := sdb.buildStorageNodesEventual(val.Account.Root, watchedStorageKeys, intermediateStorageNodes)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed building eventual storage diffs for node %x\r\nerror: %v", val.Path, err) return nil, fmt.Errorf("failed building eventual storage diffs for node %x\r\nerror: %v", val.Path, err)
} }
@ -416,7 +415,7 @@ func (sdb *builder) buildAccountDeletions(accounts AccountMap) ([]StateNode, err
return accountDiffs, nil return accountDiffs, nil
} }
func (sdb *builder) buildStorageNodesEventual(sr common.Hash, intermediateNodes bool) ([]StorageNode, error) { func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys []common.Hash, intermediateNodes bool) ([]StorageNode, error) {
log.Debug("Storage Root For Eventual Diff", "root", sr.Hex()) log.Debug("Storage Root For Eventual Diff", "root", sr.Hex())
sTrie, err := sdb.stateCache.OpenTrie(sr) sTrie, err := sdb.stateCache.OpenTrie(sr)
if err != nil { if err != nil {
@ -424,10 +423,10 @@ func (sdb *builder) buildStorageNodesEventual(sr common.Hash, intermediateNodes
return nil, err return nil, err
} }
it := sTrie.NodeIterator(make([]byte, 0)) it := sTrie.NodeIterator(make([]byte, 0))
return sdb.buildStorageNodesFromTrie(it, intermediateNodes) return sdb.buildStorageNodesFromTrie(it, watchedStorageKeys, intermediateNodes)
} }
func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common.Hash, intermediateNodes bool) ([]StorageNode, error) { func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common.Hash, watchedStorageKeys []common.Hash, intermediateNodes bool) ([]StorageNode, error) {
log.Debug("Storage Roots for Incremental Diff", "old", oldSR.Hex(), "new", newSR.Hex()) log.Debug("Storage Roots for Incremental Diff", "old", oldSR.Hex(), "new", newSR.Hex())
oldTrie, err := sdb.stateCache.OpenTrie(oldSR) oldTrie, err := sdb.stateCache.OpenTrie(oldSR)
if err != nil { if err != nil {
@ -441,10 +440,10 @@ func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common
oldIt := oldTrie.NodeIterator(make([]byte, 0)) oldIt := oldTrie.NodeIterator(make([]byte, 0))
newIt := newTrie.NodeIterator(make([]byte, 0)) newIt := newTrie.NodeIterator(make([]byte, 0))
it, _ := trie.NewDifferenceIterator(oldIt, newIt) it, _ := trie.NewDifferenceIterator(oldIt, newIt)
return sdb.buildStorageNodesFromTrie(it, intermediateNodes) return sdb.buildStorageNodesFromTrie(it, watchedStorageKeys, intermediateNodes)
} }
func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, intermediateNodes bool) ([]StorageNode, error) { func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStorageKeys []common.Hash, intermediateNodes bool) ([]StorageNode, error) {
storageDiffs := make([]StorageNode, 0) storageDiffs := make([]StorageNode, 0)
for it.Next(true) { for it.Next(true) {
// skip value nodes // skip value nodes
@ -474,12 +473,14 @@ func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, intermediate
valueNodePath := append(nodePath, partialPath...) valueNodePath := append(nodePath, partialPath...)
encodedPath := trie.HexToCompact(valueNodePath) encodedPath := trie.HexToCompact(valueNodePath)
leafKey := encodedPath[1:] leafKey := encodedPath[1:]
if isWatchedStorageKey(watchedStorageKeys, leafKey) {
storageDiffs = append(storageDiffs, StorageNode{ storageDiffs = append(storageDiffs, StorageNode{
NodeType: ty, NodeType: ty,
Path: nodePath, Path: nodePath,
NodeValue: node, NodeValue: node,
LeafKey: leafKey, LeafKey: leafKey,
}) })
}
case Extension, Branch: case Extension, Branch:
if intermediateNodes { if intermediateNodes {
storageDiffs = append(storageDiffs, StorageNode{ storageDiffs = append(storageDiffs, StorageNode{
@ -494,3 +495,17 @@ func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, intermediate
} }
return storageDiffs, nil return storageDiffs, nil
} }
// isWatchedStorageKey is used to check if a storage leaf corresponds to one of the storage slots the builder is configured to watch
func isWatchedStorageKey(watchedKeys []common.Hash, storageLeafKey []byte) bool {
// If we aren't watching any specific addresses, we are watching everything
if len(watchedKeys) == 0 {
return true
}
for _, hashKey := range watchedKeys {
if bytes.Equal(hashKey.Bytes(), storageLeafKey) {
return true
}
}
return false
}

View File

@ -751,7 +751,7 @@ func TestBuilderWithWatchedAddressList(t *testing.T) {
block2 = blockMap[BlockHashes[1]] block2 = blockMap[BlockHashes[1]]
block3 = blockMap[BlockHashes[0]] block3 = blockMap[BlockHashes[0]]
params := statediff.Params{ params := statediff.Params{
WatchedAddresses: []string{testhelpers.Account1Addr.Hex(), testhelpers.ContractAddr.Hex()}, WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.ContractAddr},
} }
builder = statediff.NewBuilder(chain.StateCache()) builder = statediff.NewBuilder(chain.StateCache())
@ -911,6 +911,8 @@ func TestBuilderWithWatchedAddressList(t *testing.T) {
} }
} }
// Write a test that tests when accounts are deleted, or moved to a new path
/* /*
contract test { contract test {

View File

@ -41,7 +41,8 @@ type Params struct {
IncludeBlock bool IncludeBlock bool
IncludeReceipts bool IncludeReceipts bool
IncludeTD bool IncludeTD bool
WatchedAddresses []string WatchedAddresses []common.Address
WatchedStorageSlots []common.Hash
} }
// Args bundles the arguments for the state diff builder // Args bundles the arguments for the state diff builder