From 381c61c51713219c45bb2b4631d6f533be35995f Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Mon, 31 Jan 2022 17:34:32 +0530 Subject: [PATCH] Add a fix in builder for removal of a non-watched address --- statediff/builder.go | 40 +++++----- statediff/builder_test.go | 149 +++++++++++++++++++++++++++++++++++++- 2 files changed, 169 insertions(+), 20 deletions(-) diff --git a/statediff/builder.go b/statediff/builder.go index 7befb6b3c..63b354a4c 100644 --- a/statediff/builder.go +++ b/statediff/builder.go @@ -202,7 +202,7 @@ func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args StateRoots, pa // a map of their leafkey to all the accounts that were touched and exist at A diffAccountsAtA, err := sdb.deletedOrUpdatedState( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - diffPathsAtB, output) + diffPathsAtB, params.WatchedAddresses, output) if err != nil { return fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err) } @@ -256,7 +256,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args StateRoots, // a map of their leafkey to all the accounts that were touched and exist at A diffAccountsAtA, err := sdb.deletedOrUpdatedState( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - diffPathsAtB, output) + diffPathsAtB, params.WatchedAddresses, output) if err != nil { return fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err) } @@ -386,7 +386,7 @@ func (sdb *builder) createdAndUpdatedStateWithIntermediateNodes(a, b trie.NodeIt // deletedOrUpdatedState returns a slice of all the pathes that are emptied at B // and a mapping of their leafkeys to all the accounts that exist in a different state at A than B -func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB map[string]bool, output StateNodeSink) (AccountMap, error) { +func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB map[string]bool, watchedAddresses []common.Address, output StateNodeSink) (AccountMap, error) { diffAccountAtA := make(AccountMap) it, _ := trie.NewDifferenceIterator(b, a) for it.Next(true) { @@ -409,24 +409,26 @@ func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB m valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - diffAccountAtA[common.Bytes2Hex(leafKey)] = accountWrapper{ - NodeType: node.NodeType, - Path: node.Path, - NodeValue: node.NodeValue, - LeafKey: leafKey, - Account: &account, - } - // if this node's path did not show up in diffPathsAtB - // that means the node at this path was deleted (or moved) in B - // emit an empty "removed" diff to signify as such - if _, ok := diffPathsAtB[common.Bytes2Hex(node.Path)]; !ok { - if err := output(StateNode{ + if isWatchedAddress(watchedAddresses, leafKey) { + diffAccountAtA[common.Bytes2Hex(leafKey)] = accountWrapper{ + NodeType: node.NodeType, Path: node.Path, - NodeValue: []byte{}, - NodeType: Removed, + NodeValue: node.NodeValue, LeafKey: leafKey, - }); err != nil { - return nil, err + Account: &account, + } + // if this node's path did not show up in diffPathsAtB + // that means the node at this path was deleted (or moved) in B + // emit an empty "removed" diff to signify as such + if _, ok := diffPathsAtB[common.Bytes2Hex(node.Path)]; !ok { + if err := output(StateNode{ + Path: node.Path, + NodeValue: []byte{}, + NodeType: Removed, + LeafKey: leafKey, + }); err != nil { + return nil, err + } } } case Extension, Branch: diff --git a/statediff/builder_test.go b/statediff/builder_test.go index 6a88bbba0..189295518 100644 --- a/statediff/builder_test.go +++ b/statediff/builder_test.go @@ -1152,13 +1152,14 @@ func TestBuilderWithWatchedAddressList(t *testing.T) { } func TestBuilderWithWatchedAddressAndStorageKeyList(t *testing.T) { - blocks, chain := testhelpers.MakeChain(3, testhelpers.Genesis, testhelpers.TestChainGen) + blocks, chain := testhelpers.MakeChain(4, testhelpers.Genesis, testhelpers.TestChainGen) contractLeafKey = testhelpers.AddressToLeafKey(testhelpers.ContractAddr) defer chain.Stop() block0 = testhelpers.Genesis block1 = blocks[0] block2 = blocks[1] block3 = blocks[2] + block4 = blocks[3] params := statediff.Params{ WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.ContractAddr}, WatchedStorageSlots: []common.Hash{slot1StorageKey}, @@ -1290,6 +1291,35 @@ func TestBuilderWithWatchedAddressAndStorageKeyList(t *testing.T) { }, }, }, + { + "testBlock4", + statediff.Args{ + OldStateRoot: block3.Root(), + NewStateRoot: block4.Root(), + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), + }, + &statediff.StateObject{ + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), + Nodes: []sdtypes.StateNode{ + { + Path: []byte{'\x06'}, + NodeType: sdtypes.Leaf, + LeafKey: contractLeafKey, + NodeValue: contractAccountAtBlock4LeafNode, + StorageNodes: []sdtypes.StorageNode{ + { + Path: []byte{'\x0b'}, + NodeType: sdtypes.Removed, + LeafKey: slot1StorageKey.Bytes(), + NodeValue: []byte{}, + }, + }, + }, + }, + }, + }, } for _, test := range tests { @@ -1718,6 +1748,123 @@ func TestBuilderWithRemovedAccountAndStorageWithoutIntermediateNodes(t *testing. } } +func TestBuilderWithRemovedNonWatchedAccount(t *testing.T) { + blocks, chain := testhelpers.MakeChain(6, testhelpers.Genesis, testhelpers.TestChainGen) + contractLeafKey = testhelpers.AddressToLeafKey(testhelpers.ContractAddr) + defer chain.Stop() + block3 = blocks[2] + block4 = blocks[3] + block5 = blocks[4] + block6 = blocks[5] + params := statediff.Params{ + WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.Account2Addr}, + } + builder = statediff.NewBuilder(chain.StateCache()) + + var tests = []struct { + name string + startingArguments statediff.Args + expected *statediff.StateObject + }{ + { + "testBlock4", + statediff.Args{ + OldStateRoot: block3.Root(), + NewStateRoot: block4.Root(), + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), + }, + &statediff.StateObject{ + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), + Nodes: []sdtypes.StateNode{ + { + Path: []byte{'\x0c'}, + NodeType: sdtypes.Leaf, + LeafKey: testhelpers.Account2LeafKey, + NodeValue: account2AtBlock4LeafNode, + StorageNodes: emptyStorage, + }, + }, + }, + }, + { + "testBlock5", + statediff.Args{ + OldStateRoot: block4.Root(), + NewStateRoot: block5.Root(), + BlockNumber: block5.Number(), + BlockHash: block5.Hash(), + }, + &statediff.StateObject{ + BlockNumber: block5.Number(), + BlockHash: block5.Hash(), + Nodes: []sdtypes.StateNode{ + { + Path: []byte{'\x0e'}, + NodeType: sdtypes.Leaf, + LeafKey: testhelpers.Account1LeafKey, + NodeValue: account1AtBlock5LeafNode, + StorageNodes: emptyStorage, + }, + }, + }, + }, + { + "testBlock6", + statediff.Args{ + OldStateRoot: block5.Root(), + NewStateRoot: block6.Root(), + BlockNumber: block6.Number(), + BlockHash: block6.Hash(), + }, + &statediff.StateObject{ + BlockNumber: block6.Number(), + BlockHash: block6.Hash(), + Nodes: []sdtypes.StateNode{ + { + Path: []byte{'\x0c'}, + NodeType: sdtypes.Leaf, + LeafKey: testhelpers.Account2LeafKey, + NodeValue: account2AtBlock6LeafNode, + StorageNodes: emptyStorage, + }, + { + Path: []byte{'\x0e'}, + NodeType: sdtypes.Leaf, + LeafKey: testhelpers.Account1LeafKey, + NodeValue: account1AtBlock6LeafNode, + StorageNodes: emptyStorage, + }, + }, + }, + }, + } + + for _, test := range tests { + diff, err := builder.BuildStateDiffObject(test.startingArguments, params) + if err != nil { + t.Error(err) + } + receivedStateDiffRlp, err := rlp.EncodeToBytes(diff) + if err != nil { + t.Error(err) + } + + expectedStateDiffRlp, err := rlp.EncodeToBytes(test.expected) + if err != nil { + t.Error(err) + } + + sort.Slice(receivedStateDiffRlp, func(i, j int) bool { return receivedStateDiffRlp[i] < receivedStateDiffRlp[j] }) + sort.Slice(expectedStateDiffRlp, func(i, j int) bool { return expectedStateDiffRlp[i] < expectedStateDiffRlp[j] }) + if !bytes.Equal(receivedStateDiffRlp, expectedStateDiffRlp) { + t.Logf("Test failed: %s", test.name) + t.Errorf("actual state diff: %+v\r\n\r\n\r\nexpected state diff: %+v", diff, test.expected) + } + } +} + var ( slot00StorageValue = common.Hex2Bytes("9471562b71999873db5b286df957af199ec94617f7") // prefixed TestBankAddress