add ability to watch specific storage slots (leaf keys) only
This commit is contained in:
parent
347a7ba0e8
commit
aa7c5b0869
@ -54,10 +54,10 @@ func (sdb *builder) BuildStateDiff(args Args, params Params) (StateDiff, error)
|
||||
if !params.IntermediateStateNodes || len(params.WatchedAddresses) > 0 { // if we are watching only specific accounts then we are only diffing leaf nodes
|
||||
return sdb.buildStateDiffWithoutIntermediateStateNodes(args, params)
|
||||
}
|
||||
return sdb.buildStateDiffWithIntermediateStateNodes(args, params.IntermediateStorageNodes)
|
||||
return sdb.buildStateDiffWithIntermediateStateNodes(args, params)
|
||||
}
|
||||
|
||||
func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args Args, intermediateStorageNodes bool) (StateDiff, error) {
|
||||
func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args Args, params Params) (StateDiff, error) {
|
||||
// Generate tries for old and new states
|
||||
oldTrie, err := sdb.stateCache.OpenTrie(args.OldStateRoot)
|
||||
if err != nil {
|
||||
@ -84,11 +84,11 @@ func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args Args, intermed
|
||||
updatedKeys := findIntersection(createKeys, deleteKeys)
|
||||
|
||||
// Build and return the statediff
|
||||
updatedAccounts, err := sdb.buildAccountUpdates(diffAccountsAtB, diffAccountsAtA, updatedKeys, intermediateStorageNodes)
|
||||
updatedAccounts, err := sdb.buildAccountUpdates(diffAccountsAtB, diffAccountsAtA, updatedKeys, params.WatchedStorageSlots, params.IntermediateStorageNodes)
|
||||
if err != nil {
|
||||
return StateDiff{}, fmt.Errorf("error building diff for updated accounts: %v", err)
|
||||
}
|
||||
createdAccounts, err := sdb.buildAccountCreations(diffAccountsAtB, intermediateStorageNodes)
|
||||
createdAccounts, err := sdb.buildAccountCreations(diffAccountsAtB, params.WatchedStorageSlots, params.IntermediateStorageNodes)
|
||||
if err != nil {
|
||||
return StateDiff{}, fmt.Errorf("error building diff for created accounts: %v", err)
|
||||
}
|
||||
@ -131,11 +131,11 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args Args, param
|
||||
updatedKeys := findIntersection(createKeys, deleteKeys)
|
||||
|
||||
// Build and return the statediff
|
||||
updatedAccounts, err := sdb.buildAccountUpdates(diffAccountsAtB, diffAccountsAtA, updatedKeys, params.IntermediateStorageNodes)
|
||||
updatedAccounts, err := sdb.buildAccountUpdates(diffAccountsAtB, diffAccountsAtA, updatedKeys, params.WatchedStorageSlots, params.IntermediateStorageNodes)
|
||||
if err != nil {
|
||||
return StateDiff{}, fmt.Errorf("error building diff for updated accounts: %v", err)
|
||||
}
|
||||
createdAccounts, err := sdb.buildAccountCreations(diffAccountsAtB, params.IntermediateStorageNodes)
|
||||
createdAccounts, err := sdb.buildAccountCreations(diffAccountsAtB, params.WatchedStorageSlots, params.IntermediateStorageNodes)
|
||||
if err != nil {
|
||||
return StateDiff{}, fmt.Errorf("error building diff for created accounts: %v", err)
|
||||
}
|
||||
@ -151,7 +151,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args Args, param
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (sdb *builder) collectDiffAccounts(a, b trie.NodeIterator, watchedAddresses []string) (AccountMap, error) {
|
||||
func (sdb *builder) collectDiffAccounts(a, b trie.NodeIterator, watchedAddresses []common.Address) (AccountMap, error) {
|
||||
diffAcountsAtB := make(AccountMap)
|
||||
it, _ := trie.NewDifferenceIterator(a, b)
|
||||
for it.Next(true) {
|
||||
@ -207,15 +207,14 @@ func (sdb *builder) collectDiffAccounts(a, b trie.NodeIterator, watchedAddresses
|
||||
}
|
||||
|
||||
// isWatchedAddress is used to check if a state account corresponds to one of the addresses the builder is configured to watch
|
||||
func isWatchedAddress(watchedAddresses []string, hashKey []byte) bool {
|
||||
func isWatchedAddress(watchedAddresses []common.Address, stateLeafKey []byte) bool {
|
||||
// If we aren't watching any specific addresses, we are watching everything
|
||||
if len(watchedAddresses) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, addrStr := range watchedAddresses {
|
||||
addr := common.HexToAddress(addrStr)
|
||||
addrHashKey := crypto.Keccak256(addr[:])
|
||||
if bytes.Equal(addrHashKey, hashKey) {
|
||||
for _, addr := range watchedAddresses {
|
||||
addrHashKey := crypto.Keccak256(addr.Bytes())
|
||||
if bytes.Equal(addrHashKey, stateLeafKey) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@ -353,7 +352,7 @@ func (sdb *builder) deletedOrUpdatedNodes(a, b trie.NodeIterator, diffPathsAtB m
|
||||
|
||||
// needs to be called before building account creations and deletions as this mutattes
|
||||
// those account maps to remove the accounts which were updated
|
||||
func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updatedKeys []string, intermediateStorageNodes bool) ([]StateNode, error) {
|
||||
func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updatedKeys []string, watchedStorageKeys []common.Hash, intermediateStorageNodes bool) ([]StateNode, error) {
|
||||
updatedAccounts := make([]StateNode, 0, len(updatedKeys))
|
||||
var err error
|
||||
for _, key := range updatedKeys {
|
||||
@ -363,7 +362,7 @@ func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updated
|
||||
if deletedAcc.Account != nil && createdAcc.Account != nil {
|
||||
oldSR := deletedAcc.Account.Root
|
||||
newSR := createdAcc.Account.Root
|
||||
storageDiffs, err = sdb.buildStorageNodesIncremental(oldSR, newSR, intermediateStorageNodes)
|
||||
storageDiffs, err = sdb.buildStorageNodesIncremental(oldSR, newSR, watchedStorageKeys, intermediateStorageNodes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed building incremental storage diffs for account with leafkey %s\r\nerror: %v", key, err)
|
||||
}
|
||||
@ -382,11 +381,11 @@ func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updated
|
||||
return updatedAccounts, nil
|
||||
}
|
||||
|
||||
func (sdb *builder) buildAccountCreations(accounts AccountMap, intermediateStorageNodes bool) ([]StateNode, error) {
|
||||
func (sdb *builder) buildAccountCreations(accounts AccountMap, watchedStorageKeys []common.Hash, intermediateStorageNodes bool) ([]StateNode, error) {
|
||||
accountDiffs := make([]StateNode, 0, len(accounts))
|
||||
for _, val := range accounts {
|
||||
// For account creations, any storage node contained is a diff
|
||||
storageDiffs, err := sdb.buildStorageNodesEventual(val.Account.Root, intermediateStorageNodes)
|
||||
storageDiffs, err := sdb.buildStorageNodesEventual(val.Account.Root, watchedStorageKeys, intermediateStorageNodes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed building eventual storage diffs for node %x\r\nerror: %v", val.Path, err)
|
||||
}
|
||||
@ -416,7 +415,7 @@ func (sdb *builder) buildAccountDeletions(accounts AccountMap) ([]StateNode, err
|
||||
return accountDiffs, nil
|
||||
}
|
||||
|
||||
func (sdb *builder) buildStorageNodesEventual(sr common.Hash, intermediateNodes bool) ([]StorageNode, error) {
|
||||
func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys []common.Hash, intermediateNodes bool) ([]StorageNode, error) {
|
||||
log.Debug("Storage Root For Eventual Diff", "root", sr.Hex())
|
||||
sTrie, err := sdb.stateCache.OpenTrie(sr)
|
||||
if err != nil {
|
||||
@ -424,10 +423,10 @@ func (sdb *builder) buildStorageNodesEventual(sr common.Hash, intermediateNodes
|
||||
return nil, err
|
||||
}
|
||||
it := sTrie.NodeIterator(make([]byte, 0))
|
||||
return sdb.buildStorageNodesFromTrie(it, intermediateNodes)
|
||||
return sdb.buildStorageNodesFromTrie(it, watchedStorageKeys, intermediateNodes)
|
||||
}
|
||||
|
||||
func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common.Hash, intermediateNodes bool) ([]StorageNode, error) {
|
||||
func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common.Hash, watchedStorageKeys []common.Hash, intermediateNodes bool) ([]StorageNode, error) {
|
||||
log.Debug("Storage Roots for Incremental Diff", "old", oldSR.Hex(), "new", newSR.Hex())
|
||||
oldTrie, err := sdb.stateCache.OpenTrie(oldSR)
|
||||
if err != nil {
|
||||
@ -441,10 +440,10 @@ func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common
|
||||
oldIt := oldTrie.NodeIterator(make([]byte, 0))
|
||||
newIt := newTrie.NodeIterator(make([]byte, 0))
|
||||
it, _ := trie.NewDifferenceIterator(oldIt, newIt)
|
||||
return sdb.buildStorageNodesFromTrie(it, intermediateNodes)
|
||||
return sdb.buildStorageNodesFromTrie(it, watchedStorageKeys, intermediateNodes)
|
||||
}
|
||||
|
||||
func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, intermediateNodes bool) ([]StorageNode, error) {
|
||||
func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStorageKeys []common.Hash, intermediateNodes bool) ([]StorageNode, error) {
|
||||
storageDiffs := make([]StorageNode, 0)
|
||||
for it.Next(true) {
|
||||
// skip value nodes
|
||||
@ -474,12 +473,14 @@ func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, intermediate
|
||||
valueNodePath := append(nodePath, partialPath...)
|
||||
encodedPath := trie.HexToCompact(valueNodePath)
|
||||
leafKey := encodedPath[1:]
|
||||
storageDiffs = append(storageDiffs, StorageNode{
|
||||
NodeType: ty,
|
||||
Path: nodePath,
|
||||
NodeValue: node,
|
||||
LeafKey: leafKey,
|
||||
})
|
||||
if isWatchedStorageKey(watchedStorageKeys, leafKey) {
|
||||
storageDiffs = append(storageDiffs, StorageNode{
|
||||
NodeType: ty,
|
||||
Path: nodePath,
|
||||
NodeValue: node,
|
||||
LeafKey: leafKey,
|
||||
})
|
||||
}
|
||||
case Extension, Branch:
|
||||
if intermediateNodes {
|
||||
storageDiffs = append(storageDiffs, StorageNode{
|
||||
@ -494,3 +495,17 @@ func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, intermediate
|
||||
}
|
||||
return storageDiffs, nil
|
||||
}
|
||||
|
||||
// isWatchedStorageKey is used to check if a storage leaf corresponds to one of the storage slots the builder is configured to watch
|
||||
func isWatchedStorageKey(watchedKeys []common.Hash, storageLeafKey []byte) bool {
|
||||
// If we aren't watching any specific addresses, we are watching everything
|
||||
if len(watchedKeys) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, hashKey := range watchedKeys {
|
||||
if bytes.Equal(hashKey.Bytes(), storageLeafKey) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -751,7 +751,7 @@ func TestBuilderWithWatchedAddressList(t *testing.T) {
|
||||
block2 = blockMap[BlockHashes[1]]
|
||||
block3 = blockMap[BlockHashes[0]]
|
||||
params := statediff.Params{
|
||||
WatchedAddresses: []string{testhelpers.Account1Addr.Hex(), testhelpers.ContractAddr.Hex()},
|
||||
WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.ContractAddr},
|
||||
}
|
||||
builder = statediff.NewBuilder(chain.StateCache())
|
||||
|
||||
@ -911,6 +911,8 @@ func TestBuilderWithWatchedAddressList(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Write a test that tests when accounts are deleted, or moved to a new path
|
||||
|
||||
/*
|
||||
contract test {
|
||||
|
||||
|
@ -41,7 +41,8 @@ type Params struct {
|
||||
IncludeBlock bool
|
||||
IncludeReceipts bool
|
||||
IncludeTD bool
|
||||
WatchedAddresses []string
|
||||
WatchedAddresses []common.Address
|
||||
WatchedStorageSlots []common.Hash
|
||||
}
|
||||
|
||||
// Args bundles the arguments for the state diff builder
|
||||
|
Loading…
Reference in New Issue
Block a user