diff --git a/Makefile b/Makefile index 975628922..c3c95eeae 100644 --- a/Makefile +++ b/Makefile @@ -52,6 +52,7 @@ ios: .PHONY: statedifftest statedifftest: | $(GOOSE) + go get github.com/stretchr/testify/assert@v1.7.0 MODE=statediff go test ./statediff/... -v test: all diff --git a/docker-compose.yml b/docker-compose.yml index f1a37ddcb..8083b4d87 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,7 +3,7 @@ version: '3.2' services: ipld-eth-db: restart: always - image: vulcanize/ipld-eth-db:v0.2.0 + image: vulcanize/ipld-eth-db:v2.1.1 environment: POSTGRES_USER: "vdbm" POSTGRES_DB: "vulcanize_public" diff --git a/go.mod b/go.mod index 0f94c2611..3baf72200 100644 --- a/go.mod +++ b/go.mod @@ -68,6 +68,7 @@ require ( github.com/status-im/keycard-go v0.0.0-20190316090335-8537d3370df4 github.com/stretchr/testify v1.7.0 github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 + github.com/thoas/go-funk v0.9.1 github.com/tklauser/go-sysconf v0.3.5 // indirect github.com/tyler-smith/go-bip39 v1.0.1-0.20181017060643-dbb3b84ba2ef golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a diff --git a/go.sum b/go.sum index 0097b96f7..76d8234e5 100644 --- a/go.sum +++ b/go.sum @@ -470,6 +470,8 @@ github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5Cc github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70Z7CTTCmYQn2CKbY8j86K7/FAIr141uY= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= +github.com/thoas/go-funk v0.9.1 h1:O549iLZqPpTUQ10ykd26sZhzD+rmR5pWhuElrhbC20M= +github.com/thoas/go-funk v0.9.1/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q= github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= github.com/tklauser/go-sysconf v0.3.5 h1:uu3Xl4nkLzQfXNsWn15rPc/HQCJKObbt1dKJeWp3vU4= github.com/tklauser/go-sysconf v0.3.5/go.mod h1:MkWzOF4RMCshBAMXuhXJs64Rte09mITnppBXY/rYEFI= diff --git a/statediff/README.md b/statediff/README.md index 8e673efa6..7cfb6d0f1 100644 --- a/statediff/README.md +++ b/statediff/README.md @@ -106,15 +106,13 @@ type Params struct { IncludeTD bool IncludeCode bool WatchedAddresses []common.Address - WatchedStorageSlots []common.Hash } ``` Using these params we can tell the service whether to include state and/or storage intermediate nodes; whether to include the associated block (header, uncles, and transactions); whether to include the associated receipts; whether to include the total difficulty for this block; whether to include the set of code hashes and code for -contracts deployed in this block; whether to limit the diffing process to a list of specific addresses; and/or -whether to limit the diffing process to a list of specific storage slot keys. +contracts deployed in this block; whether to limit the diffing process to a list of specific addresses. #### Subscription endpoint A websocket supporting RPC endpoint is exposed for subscribing to state diff `StateObjects` that come off the head of the chain while the geth node syncs. diff --git a/statediff/api.go b/statediff/api.go index 923a0073f..ed9cc3c06 100644 --- a/statediff/api.go +++ b/statediff/api.go @@ -149,3 +149,8 @@ func (api *PublicStateDiffAPI) WriteStateDiffAt(ctx context.Context, blockNumber func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash common.Hash, params Params) error { return api.sds.WriteStateDiffFor(blockHash, params) } + +// WatchAddress changes the list of watched addresses to which the direct indexing is restricted according to given operation +func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, args []WatchAddressArg) error { + return api.sds.WatchAddress(operation, args) +} diff --git a/statediff/builder.go b/statediff/builder.go index 7befb6b3c..46546c1d5 100644 --- a/statediff/builder.go +++ b/statediff/builder.go @@ -123,7 +123,7 @@ func (sdb *builder) buildStateTrie(it trie.NodeIterator) ([]StateNode, []CodeAnd node.LeafKey = leafKey if !bytes.Equal(account.CodeHash, nullCodeHash) { var storageNodes []StorageNode - err := sdb.buildStorageNodesEventual(account.Root, nil, true, storageNodeAppender(&storageNodes)) + err := sdb.buildStorageNodesEventual(account.Root, true, storageNodeAppender(&storageNodes)) if err != nil { return nil, nil, fmt.Errorf("failed building eventual storage diffs for account %+v\r\nerror: %v", account, err) } @@ -202,7 +202,7 @@ func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args StateRoots, pa // a map of their leafkey to all the accounts that were touched and exist at A diffAccountsAtA, err := sdb.deletedOrUpdatedState( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - diffPathsAtB, output) + diffPathsAtB, params.watchedAddressesLeafKeys, output) if err != nil { return fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err) } @@ -220,12 +220,12 @@ func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args StateRoots, pa // build the diff nodes for the updated accounts using the mappings at both A and B as directed by the keys found as the intersection of the two err = sdb.buildAccountUpdates( diffAccountsAtB, diffAccountsAtA, updatedKeys, - params.WatchedStorageSlots, params.IntermediateStorageNodes, output) + params.IntermediateStorageNodes, output) if err != nil { return fmt.Errorf("error building diff for updated accounts: %v", err) } // build the diff nodes for created accounts - err = sdb.buildAccountCreations(diffAccountsAtB, params.WatchedStorageSlots, params.IntermediateStorageNodes, output, codeOutput) + err = sdb.buildAccountCreations(diffAccountsAtB, params.IntermediateStorageNodes, output, codeOutput) if err != nil { return fmt.Errorf("error building diff for created accounts: %v", err) } @@ -247,7 +247,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args StateRoots, // and a slice of all the paths for the nodes in both of the above sets diffAccountsAtB, diffPathsAtB, err := sdb.createdAndUpdatedState( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - params.WatchedAddresses) + params.watchedAddressesLeafKeys) if err != nil { return fmt.Errorf("error collecting createdAndUpdatedNodes: %v", err) } @@ -256,7 +256,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args StateRoots, // a map of their leafkey to all the accounts that were touched and exist at A diffAccountsAtA, err := sdb.deletedOrUpdatedState( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - diffPathsAtB, output) + diffPathsAtB, params.watchedAddressesLeafKeys, output) if err != nil { return fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err) } @@ -274,12 +274,12 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args StateRoots, // build the diff nodes for the updated accounts using the mappings at both A and B as directed by the keys found as the intersection of the two err = sdb.buildAccountUpdates( diffAccountsAtB, diffAccountsAtA, updatedKeys, - params.WatchedStorageSlots, params.IntermediateStorageNodes, output) + params.IntermediateStorageNodes, output) if err != nil { return fmt.Errorf("error building diff for updated accounts: %v", err) } // build the diff nodes for created accounts - err = sdb.buildAccountCreations(diffAccountsAtB, params.WatchedStorageSlots, params.IntermediateStorageNodes, output, codeOutput) + err = sdb.buildAccountCreations(diffAccountsAtB, params.IntermediateStorageNodes, output, codeOutput) if err != nil { return fmt.Errorf("error building diff for created accounts: %v", err) } @@ -289,7 +289,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args StateRoots, // createdAndUpdatedState returns // a mapping of their leafkeys to all the accounts that exist in a different state at B than A // and a slice of the paths for all of the nodes included in both -func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddresses []common.Address) (AccountMap, map[string]bool, error) { +func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddressesLeafKeys map[common.Hash]struct{}) (AccountMap, map[string]bool, error) { diffPathsAtB := make(map[string]bool) diffAcountsAtB := make(AccountMap) it, _ := trie.NewDifferenceIterator(a, b) @@ -313,7 +313,7 @@ func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddres valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedAddress(watchedAddresses, leafKey) { + if isWatchedAddress(watchedAddressesLeafKeys, leafKey) { diffAcountsAtB[common.Bytes2Hex(leafKey)] = accountWrapper{ NodeType: node.NodeType, Path: node.Path, @@ -386,7 +386,7 @@ func (sdb *builder) createdAndUpdatedStateWithIntermediateNodes(a, b trie.NodeIt // deletedOrUpdatedState returns a slice of all the pathes that are emptied at B // and a mapping of their leafkeys to all the accounts that exist in a different state at A than B -func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB map[string]bool, output StateNodeSink) (AccountMap, error) { +func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB map[string]bool, watchedAddressesLeafKeys map[common.Hash]struct{}, output StateNodeSink) (AccountMap, error) { diffAccountAtA := make(AccountMap) it, _ := trie.NewDifferenceIterator(b, a) for it.Next(true) { @@ -409,24 +409,26 @@ func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB m valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - diffAccountAtA[common.Bytes2Hex(leafKey)] = accountWrapper{ - NodeType: node.NodeType, - Path: node.Path, - NodeValue: node.NodeValue, - LeafKey: leafKey, - Account: &account, - } - // if this node's path did not show up in diffPathsAtB - // that means the node at this path was deleted (or moved) in B - // emit an empty "removed" diff to signify as such - if _, ok := diffPathsAtB[common.Bytes2Hex(node.Path)]; !ok { - if err := output(StateNode{ + if isWatchedAddress(watchedAddressesLeafKeys, leafKey) { + diffAccountAtA[common.Bytes2Hex(leafKey)] = accountWrapper{ + NodeType: node.NodeType, Path: node.Path, - NodeValue: []byte{}, - NodeType: Removed, + NodeValue: node.NodeValue, LeafKey: leafKey, - }); err != nil { - return nil, err + Account: &account, + } + // if this node's path did not show up in diffPathsAtB + // that means the node at this path was deleted (or moved) in B + // emit an empty "removed" diff to signify as such + if _, ok := diffPathsAtB[common.Bytes2Hex(node.Path)]; !ok { + if err := output(StateNode{ + Path: node.Path, + NodeValue: []byte{}, + NodeType: Removed, + LeafKey: leafKey, + }); err != nil { + return nil, err + } } } case Extension, Branch: @@ -454,8 +456,7 @@ func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB m // to generate the statediff node objects for all of the accounts that existed at both A and B but in different states // needs to be called before building account creations and deletions as this mutates // those account maps to remove the accounts which were updated -func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updatedKeys []string, - watchedStorageKeys []common.Hash, intermediateStorageNodes bool, output StateNodeSink) error { +func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updatedKeys []string, intermediateStorageNodes bool, output StateNodeSink) error { var err error for _, key := range updatedKeys { createdAcc := creations[key] @@ -465,7 +466,7 @@ func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updated oldSR := deletedAcc.Account.Root newSR := createdAcc.Account.Root err = sdb.buildStorageNodesIncremental( - oldSR, newSR, watchedStorageKeys, intermediateStorageNodes, + oldSR, newSR, intermediateStorageNodes, storageNodeAppender(&storageDiffs)) if err != nil { return fmt.Errorf("failed building incremental storage diffs for account with leafkey %s\r\nerror: %v", key, err) @@ -489,7 +490,7 @@ func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updated // buildAccountCreations returns the statediff node objects for all the accounts that exist at B but not at A // it also returns the code and codehash for created contract accounts -func (sdb *builder) buildAccountCreations(accounts AccountMap, watchedStorageKeys []common.Hash, intermediateStorageNodes bool, output StateNodeSink, codeOutput CodeSink) error { +func (sdb *builder) buildAccountCreations(accounts AccountMap, intermediateStorageNodes bool, output StateNodeSink, codeOutput CodeSink) error { for _, val := range accounts { diff := StateNode{ NodeType: val.NodeType, @@ -500,7 +501,7 @@ func (sdb *builder) buildAccountCreations(accounts AccountMap, watchedStorageKey if !bytes.Equal(val.Account.CodeHash, nullCodeHash) { // For contract creations, any storage node contained is a diff var storageDiffs []StorageNode - err := sdb.buildStorageNodesEventual(val.Account.Root, watchedStorageKeys, intermediateStorageNodes, storageNodeAppender(&storageDiffs)) + err := sdb.buildStorageNodesEventual(val.Account.Root, intermediateStorageNodes, storageNodeAppender(&storageDiffs)) if err != nil { return fmt.Errorf("failed building eventual storage diffs for node %x\r\nerror: %v", val.Path, err) } @@ -528,7 +529,7 @@ func (sdb *builder) buildAccountCreations(accounts AccountMap, watchedStorageKey // buildStorageNodesEventual builds the storage diff node objects for a created account // i.e. it returns all the storage nodes at this state, since there is no previous state -func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys []common.Hash, intermediateNodes bool, output StorageNodeSink) error { +func (sdb *builder) buildStorageNodesEventual(sr common.Hash, intermediateNodes bool, output StorageNodeSink) error { if bytes.Equal(sr.Bytes(), emptyContractRoot.Bytes()) { return nil } @@ -539,7 +540,7 @@ func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys return err } it := sTrie.NodeIterator(make([]byte, 0)) - err = sdb.buildStorageNodesFromTrie(it, watchedStorageKeys, intermediateNodes, output) + err = sdb.buildStorageNodesFromTrie(it, intermediateNodes, output) if err != nil { return err } @@ -549,7 +550,7 @@ func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys // buildStorageNodesFromTrie returns all the storage diff node objects in the provided node interator // if any storage keys are provided it will only return those leaf nodes // including intermediate nodes can be turned on or off -func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStorageKeys []common.Hash, intermediateNodes bool, output StorageNodeSink) error { +func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, intermediateNodes bool, output StorageNodeSink) error { for it.Next(true) { // skip value nodes if it.Leaf() || bytes.Equal(nullHashBytes, it.Hash().Bytes()) { @@ -565,15 +566,13 @@ func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStora valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedStorageKey(watchedStorageKeys, leafKey) { - if err := output(StorageNode{ - NodeType: node.NodeType, - Path: node.Path, - NodeValue: node.NodeValue, - LeafKey: leafKey, - }); err != nil { - return err - } + if err := output(StorageNode{ + NodeType: node.NodeType, + Path: node.Path, + NodeValue: node.NodeValue, + LeafKey: leafKey, + }); err != nil { + return err } case Extension, Branch: if intermediateNodes { @@ -593,7 +592,7 @@ func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStora } // buildStorageNodesIncremental builds the storage diff node objects for all nodes that exist in a different state at B than A -func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common.Hash, watchedStorageKeys []common.Hash, intermediateNodes bool, output StorageNodeSink) error { +func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common.Hash, intermediateNodes bool, output StorageNodeSink) error { if bytes.Equal(newSR.Bytes(), oldSR.Bytes()) { return nil } @@ -609,19 +608,19 @@ func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common diffPathsAtB, err := sdb.createdAndUpdatedStorage( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - watchedStorageKeys, intermediateNodes, output) + intermediateNodes, output) if err != nil { return err } err = sdb.deletedOrUpdatedStorage(oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - diffPathsAtB, watchedStorageKeys, intermediateNodes, output) + diffPathsAtB, intermediateNodes, output) if err != nil { return err } return nil } -func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, watchedKeys []common.Hash, intermediateNodes bool, output StorageNodeSink) (map[string]bool, error) { +func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, intermediateNodes bool, output StorageNodeSink) (map[string]bool, error) { diffPathsAtB := make(map[string]bool) it, _ := trie.NewDifferenceIterator(a, b) for it.Next(true) { @@ -639,15 +638,13 @@ func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, watchedKeys valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedStorageKey(watchedKeys, leafKey) { - if err := output(StorageNode{ - NodeType: node.NodeType, - Path: node.Path, - NodeValue: node.NodeValue, - LeafKey: leafKey, - }); err != nil { - return nil, err - } + if err := output(StorageNode{ + NodeType: node.NodeType, + Path: node.Path, + NodeValue: node.NodeValue, + LeafKey: leafKey, + }); err != nil { + return nil, err } case Extension, Branch: if intermediateNodes { @@ -667,7 +664,7 @@ func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, watchedKeys return diffPathsAtB, it.Error() } -func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB map[string]bool, watchedKeys []common.Hash, intermediateNodes bool, output StorageNodeSink) error { +func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB map[string]bool, intermediateNodes bool, output StorageNodeSink) error { it, _ := trie.NewDifferenceIterator(b, a) for it.Next(true) { // skip value nodes @@ -690,15 +687,13 @@ func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedStorageKey(watchedKeys, leafKey) { - if err := output(StorageNode{ - NodeType: Removed, - Path: node.Path, - NodeValue: []byte{}, - LeafKey: leafKey, - }); err != nil { - return err - } + if err := output(StorageNode{ + NodeType: Removed, + Path: node.Path, + NodeValue: []byte{}, + LeafKey: leafKey, + }); err != nil { + return err } case Extension, Branch: if intermediateNodes { @@ -718,30 +713,12 @@ func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB } // isWatchedAddress is used to check if a state account corresponds to one of the addresses the builder is configured to watch -func isWatchedAddress(watchedAddresses []common.Address, stateLeafKey []byte) bool { +func isWatchedAddress(watchedAddressesLeafKeys map[common.Hash]struct{}, stateLeafKey []byte) bool { // If we aren't watching any specific addresses, we are watching everything - if len(watchedAddresses) == 0 { + if len(watchedAddressesLeafKeys) == 0 { return true } - for _, addr := range watchedAddresses { - addrHashKey := crypto.Keccak256(addr.Bytes()) - if bytes.Equal(addrHashKey, stateLeafKey) { - return true - } - } - return false -} -// isWatchedStorageKey is used to check if a storage leaf corresponds to one of the storage slots the builder is configured to watch -func isWatchedStorageKey(watchedKeys []common.Hash, storageLeafKey []byte) bool { - // If we aren't watching any specific addresses, we are watching everything - if len(watchedKeys) == 0 { - return true - } - for _, hashKey := range watchedKeys { - if bytes.Equal(hashKey.Bytes(), storageLeafKey) { - return true - } - } - return false + _, ok := watchedAddressesLeafKeys[common.BytesToHash(stateLeafKey)] + return ok } diff --git a/statediff/builder_test.go b/statediff/builder_test.go index 6a88bbba0..e7257ef6a 100644 --- a/statediff/builder_test.go +++ b/statediff/builder_test.go @@ -987,6 +987,7 @@ func TestBuilderWithWatchedAddressList(t *testing.T) { params := statediff.Params{ WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.ContractAddr}, } + params.ComputeWatchedAddressesLeafKeys() builder = statediff.NewBuilder(chain.StateCache()) var tests = []struct { @@ -1151,169 +1152,6 @@ func TestBuilderWithWatchedAddressList(t *testing.T) { } } -func TestBuilderWithWatchedAddressAndStorageKeyList(t *testing.T) { - blocks, chain := testhelpers.MakeChain(3, testhelpers.Genesis, testhelpers.TestChainGen) - contractLeafKey = testhelpers.AddressToLeafKey(testhelpers.ContractAddr) - defer chain.Stop() - block0 = testhelpers.Genesis - block1 = blocks[0] - block2 = blocks[1] - block3 = blocks[2] - params := statediff.Params{ - WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.ContractAddr}, - WatchedStorageSlots: []common.Hash{slot1StorageKey}, - } - builder = statediff.NewBuilder(chain.StateCache()) - - var tests = []struct { - name string - startingArguments statediff.Args - expected *statediff.StateObject - }{ - { - "testEmptyDiff", - statediff.Args{ - OldStateRoot: block0.Root(), - NewStateRoot: block0.Root(), - BlockNumber: block0.Number(), - BlockHash: block0.Hash(), - }, - &statediff.StateObject{ - BlockNumber: block0.Number(), - BlockHash: block0.Hash(), - Nodes: emptyDiffs, - }, - }, - { - "testBlock0", - //10000 transferred from testBankAddress to account1Addr - statediff.Args{ - OldStateRoot: testhelpers.NullHash, - NewStateRoot: block0.Root(), - BlockNumber: block0.Number(), - BlockHash: block0.Hash(), - }, - &statediff.StateObject{ - BlockNumber: block0.Number(), - BlockHash: block0.Hash(), - Nodes: emptyDiffs, - }, - }, - { - "testBlock1", - //10000 transferred from testBankAddress to account1Addr - statediff.Args{ - OldStateRoot: block0.Root(), - NewStateRoot: block1.Root(), - BlockNumber: block1.Number(), - BlockHash: block1.Hash(), - }, - &statediff.StateObject{ - BlockNumber: block1.Number(), - BlockHash: block1.Hash(), - Nodes: []sdtypes.StateNode{ - { - Path: []byte{'\x0e'}, - NodeType: sdtypes.Leaf, - LeafKey: testhelpers.Account1LeafKey, - NodeValue: account1AtBlock1LeafNode, - StorageNodes: emptyStorage, - }, - }, - }, - }, - { - "testBlock2", - //1000 transferred from testBankAddress to account1Addr - //1000 transferred from account1Addr to account2Addr - statediff.Args{ - OldStateRoot: block1.Root(), - NewStateRoot: block2.Root(), - BlockNumber: block2.Number(), - BlockHash: block2.Hash(), - }, - &statediff.StateObject{ - BlockNumber: block2.Number(), - BlockHash: block2.Hash(), - Nodes: []sdtypes.StateNode{ - { - Path: []byte{'\x06'}, - NodeType: sdtypes.Leaf, - LeafKey: contractLeafKey, - NodeValue: contractAccountAtBlock2LeafNode, - StorageNodes: []sdtypes.StorageNode{ - { - Path: []byte{'\x0b'}, - NodeType: sdtypes.Leaf, - LeafKey: slot1StorageKey.Bytes(), - NodeValue: slot1StorageLeafNode, - }, - }, - }, - { - Path: []byte{'\x0e'}, - NodeType: sdtypes.Leaf, - LeafKey: testhelpers.Account1LeafKey, - NodeValue: account1AtBlock2LeafNode, - StorageNodes: emptyStorage, - }, - }, - CodeAndCodeHashes: []sdtypes.CodeAndCodeHash{ - { - Hash: testhelpers.CodeHash, - Code: testhelpers.ByteCodeAfterDeployment, - }, - }, - }, - }, - { - "testBlock3", - //the contract's storage is changed - //and the block is mined by account 2 - statediff.Args{ - OldStateRoot: block2.Root(), - NewStateRoot: block3.Root(), - BlockNumber: block3.Number(), - BlockHash: block3.Hash(), - }, - &statediff.StateObject{ - BlockNumber: block3.Number(), - BlockHash: block3.Hash(), - Nodes: []sdtypes.StateNode{ - { - Path: []byte{'\x06'}, - NodeType: sdtypes.Leaf, - LeafKey: contractLeafKey, - NodeValue: contractAccountAtBlock3LeafNode, - StorageNodes: emptyStorage, - }, - }, - }, - }, - } - - for _, test := range tests { - diff, err := builder.BuildStateDiffObject(test.startingArguments, params) - if err != nil { - t.Error(err) - } - receivedStateDiffRlp, err := rlp.EncodeToBytes(diff) - if err != nil { - t.Error(err) - } - expectedStateDiffRlp, err := rlp.EncodeToBytes(test.expected) - if err != nil { - t.Error(err) - } - sort.Slice(receivedStateDiffRlp, func(i, j int) bool { return receivedStateDiffRlp[i] < receivedStateDiffRlp[j] }) - sort.Slice(expectedStateDiffRlp, func(i, j int) bool { return expectedStateDiffRlp[i] < expectedStateDiffRlp[j] }) - if !bytes.Equal(receivedStateDiffRlp, expectedStateDiffRlp) { - t.Logf("Test failed: %s", test.name) - t.Errorf("actual state diff: %+v\nexpected state diff: %+v", diff, test.expected) - } - } -} - func TestBuilderWithRemovedAccountAndStorage(t *testing.T) { blocks, chain := testhelpers.MakeChain(6, testhelpers.Genesis, testhelpers.TestChainGen) contractLeafKey = testhelpers.AddressToLeafKey(testhelpers.ContractAddr) @@ -1718,6 +1556,286 @@ func TestBuilderWithRemovedAccountAndStorageWithoutIntermediateNodes(t *testing. } } +func TestBuilderWithRemovedNonWatchedAccount(t *testing.T) { + blocks, chain := testhelpers.MakeChain(6, testhelpers.Genesis, testhelpers.TestChainGen) + contractLeafKey = testhelpers.AddressToLeafKey(testhelpers.ContractAddr) + defer chain.Stop() + block3 = blocks[2] + block4 = blocks[3] + block5 = blocks[4] + block6 = blocks[5] + params := statediff.Params{ + WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.Account2Addr}, + } + params.ComputeWatchedAddressesLeafKeys() + builder = statediff.NewBuilder(chain.StateCache()) + + var tests = []struct { + name string + startingArguments statediff.Args + expected *statediff.StateObject + }{ + { + "testBlock4", + statediff.Args{ + OldStateRoot: block3.Root(), + NewStateRoot: block4.Root(), + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), + }, + &statediff.StateObject{ + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), + Nodes: []sdtypes.StateNode{ + { + Path: []byte{'\x0c'}, + NodeType: sdtypes.Leaf, + LeafKey: testhelpers.Account2LeafKey, + NodeValue: account2AtBlock4LeafNode, + StorageNodes: emptyStorage, + }, + }, + }, + }, + { + "testBlock5", + statediff.Args{ + OldStateRoot: block4.Root(), + NewStateRoot: block5.Root(), + BlockNumber: block5.Number(), + BlockHash: block5.Hash(), + }, + &statediff.StateObject{ + BlockNumber: block5.Number(), + BlockHash: block5.Hash(), + Nodes: []sdtypes.StateNode{ + { + Path: []byte{'\x0e'}, + NodeType: sdtypes.Leaf, + LeafKey: testhelpers.Account1LeafKey, + NodeValue: account1AtBlock5LeafNode, + StorageNodes: emptyStorage, + }, + }, + }, + }, + { + "testBlock6", + statediff.Args{ + OldStateRoot: block5.Root(), + NewStateRoot: block6.Root(), + BlockNumber: block6.Number(), + BlockHash: block6.Hash(), + }, + &statediff.StateObject{ + BlockNumber: block6.Number(), + BlockHash: block6.Hash(), + Nodes: []sdtypes.StateNode{ + { + Path: []byte{'\x0c'}, + NodeType: sdtypes.Leaf, + LeafKey: testhelpers.Account2LeafKey, + NodeValue: account2AtBlock6LeafNode, + StorageNodes: emptyStorage, + }, + { + Path: []byte{'\x0e'}, + NodeType: sdtypes.Leaf, + LeafKey: testhelpers.Account1LeafKey, + NodeValue: account1AtBlock6LeafNode, + StorageNodes: emptyStorage, + }, + }, + }, + }, + } + + for _, test := range tests { + diff, err := builder.BuildStateDiffObject(test.startingArguments, params) + if err != nil { + t.Error(err) + } + receivedStateDiffRlp, err := rlp.EncodeToBytes(diff) + if err != nil { + t.Error(err) + } + + expectedStateDiffRlp, err := rlp.EncodeToBytes(test.expected) + if err != nil { + t.Error(err) + } + + sort.Slice(receivedStateDiffRlp, func(i, j int) bool { return receivedStateDiffRlp[i] < receivedStateDiffRlp[j] }) + sort.Slice(expectedStateDiffRlp, func(i, j int) bool { return expectedStateDiffRlp[i] < expectedStateDiffRlp[j] }) + if !bytes.Equal(receivedStateDiffRlp, expectedStateDiffRlp) { + t.Logf("Test failed: %s", test.name) + t.Errorf("actual state diff: %+v\r\n\r\n\r\nexpected state diff: %+v", diff, test.expected) + } + } +} + +func TestBuilderWithRemovedWatchedAccount(t *testing.T) { + blocks, chain := testhelpers.MakeChain(6, testhelpers.Genesis, testhelpers.TestChainGen) + contractLeafKey = testhelpers.AddressToLeafKey(testhelpers.ContractAddr) + defer chain.Stop() + block3 = blocks[2] + block4 = blocks[3] + block5 = blocks[4] + block6 = blocks[5] + params := statediff.Params{ + WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.ContractAddr}, + } + params.ComputeWatchedAddressesLeafKeys() + builder = statediff.NewBuilder(chain.StateCache()) + + var tests = []struct { + name string + startingArguments statediff.Args + expected *statediff.StateObject + }{ + { + "testBlock4", + statediff.Args{ + OldStateRoot: block3.Root(), + NewStateRoot: block4.Root(), + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), + }, + &statediff.StateObject{ + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), + Nodes: []sdtypes.StateNode{ + { + Path: []byte{'\x06'}, + NodeType: sdtypes.Leaf, + LeafKey: contractLeafKey, + NodeValue: contractAccountAtBlock4LeafNode, + StorageNodes: []sdtypes.StorageNode{ + { + Path: []byte{'\x04'}, + NodeType: sdtypes.Leaf, + LeafKey: slot2StorageKey.Bytes(), + NodeValue: slot2StorageLeafNode, + }, + { + Path: []byte{'\x0b'}, + NodeType: sdtypes.Removed, + LeafKey: slot1StorageKey.Bytes(), + NodeValue: []byte{}, + }, + { + Path: []byte{'\x0c'}, + NodeType: sdtypes.Removed, + LeafKey: slot3StorageKey.Bytes(), + NodeValue: []byte{}, + }, + }, + }, + }, + }, + }, + { + "testBlock5", + statediff.Args{ + OldStateRoot: block4.Root(), + NewStateRoot: block5.Root(), + BlockNumber: block5.Number(), + BlockHash: block5.Hash(), + }, + &statediff.StateObject{ + BlockNumber: block5.Number(), + BlockHash: block5.Hash(), + Nodes: []sdtypes.StateNode{ + { + Path: []byte{'\x06'}, + NodeType: sdtypes.Leaf, + LeafKey: contractLeafKey, + NodeValue: contractAccountAtBlock5LeafNode, + StorageNodes: []sdtypes.StorageNode{ + { + Path: []byte{}, + NodeType: sdtypes.Leaf, + LeafKey: slot0StorageKey.Bytes(), + NodeValue: slot0StorageLeafRootNode, + }, + { + Path: []byte{'\x02'}, + NodeType: sdtypes.Removed, + LeafKey: slot0StorageKey.Bytes(), + NodeValue: []byte{}, + }, + { + Path: []byte{'\x04'}, + NodeType: sdtypes.Removed, + LeafKey: slot2StorageKey.Bytes(), + NodeValue: []byte{}, + }, + }, + }, + { + Path: []byte{'\x0e'}, + NodeType: sdtypes.Leaf, + LeafKey: testhelpers.Account1LeafKey, + NodeValue: account1AtBlock5LeafNode, + StorageNodes: emptyStorage, + }, + }, + }, + }, + { + "testBlock6", + statediff.Args{ + OldStateRoot: block5.Root(), + NewStateRoot: block6.Root(), + BlockNumber: block6.Number(), + BlockHash: block6.Hash(), + }, + &statediff.StateObject{ + BlockNumber: block6.Number(), + BlockHash: block6.Hash(), + Nodes: []sdtypes.StateNode{ + { + Path: []byte{'\x06'}, + NodeType: sdtypes.Removed, + LeafKey: contractLeafKey, + NodeValue: []byte{}, + }, + { + Path: []byte{'\x0e'}, + NodeType: sdtypes.Leaf, + LeafKey: testhelpers.Account1LeafKey, + NodeValue: account1AtBlock6LeafNode, + StorageNodes: emptyStorage, + }, + }, + }, + }, + } + + for _, test := range tests { + diff, err := builder.BuildStateDiffObject(test.startingArguments, params) + if err != nil { + t.Error(err) + } + receivedStateDiffRlp, err := rlp.EncodeToBytes(diff) + if err != nil { + t.Error(err) + } + + expectedStateDiffRlp, err := rlp.EncodeToBytes(test.expected) + if err != nil { + t.Error(err) + } + + sort.Slice(receivedStateDiffRlp, func(i, j int) bool { return receivedStateDiffRlp[i] < receivedStateDiffRlp[j] }) + sort.Slice(expectedStateDiffRlp, func(i, j int) bool { return expectedStateDiffRlp[i] < expectedStateDiffRlp[j] }) + if !bytes.Equal(receivedStateDiffRlp, expectedStateDiffRlp) { + t.Logf("Test failed: %s", test.name) + t.Errorf("actual state diff: %+v\r\n\r\n\r\nexpected state diff: %+v", diff, test.expected) + } + } +} + var ( slot00StorageValue = common.Hex2Bytes("9471562b71999873db5b286df957af199ec94617f7") // prefixed TestBankAddress diff --git a/statediff/helpers.go b/statediff/helpers.go index eb5060c51..8870855bd 100644 --- a/statediff/helpers.go +++ b/statediff/helpers.go @@ -20,8 +20,15 @@ package statediff import ( + "fmt" "sort" "strings" + + "github.com/thoas/go-funk" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/statediff/indexer/postgres" + . "github.com/ethereum/go-ethereum/statediff/types" ) func sortKeys(data AccountMap) []string { @@ -69,5 +76,42 @@ func findIntersection(a, b []string) []string { } } } - +} + +// loadWatchedAddresses is used to load watched addresses to in-memory write loop params from the db +func loadWatchedAddresses(db *postgres.DB) error { + type Watched struct { + Address string `db:"address"` + } + var watched []Watched + + pgStr := "SELECT address FROM eth_meta.watched_addresses" + err := db.Select(&watched, pgStr) + if err != nil { + return fmt.Errorf("error loading watched addresses: %v", err) + } + + watchedAddresses := []common.Address{} + for _, entry := range watched { + watchedAddresses = append(watchedAddresses, common.HexToAddress(entry.Address)) + } + + writeLoopParams.Lock() + defer writeLoopParams.Unlock() + writeLoopParams.WatchedAddresses = watchedAddresses + writeLoopParams.ComputeWatchedAddressesLeafKeys() + + return nil +} + +// MapWatchAddressArgsToAddresses maps []WatchAddressArg to corresponding []common.Address +func MapWatchAddressArgsToAddresses(args []WatchAddressArg) ([]common.Address, error) { + addresses, ok := funk.Map(args, func(arg WatchAddressArg) common.Address { + return common.HexToAddress(arg.Address) + }).([]common.Address) + if !ok { + return nil, fmt.Errorf(typeAssertionFailed) + } + + return addresses, nil } diff --git a/statediff/indexer/indexer.go b/statediff/indexer/indexer.go index 60d69f932..65a5e950f 100644 --- a/statediff/indexer/indexer.go +++ b/statediff/indexer/indexer.go @@ -59,6 +59,12 @@ type Indexer interface { PushStateNode(tx *BlockTx, stateNode sdtypes.StateNode) error PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sdtypes.CodeAndCodeHash) error ReportDBMetrics(delay time.Duration, quit <-chan bool) + + // Methods used by WatchAddress API/functionality. + InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error + RemoveWatchedAddresses(addresses []sdtypes.WatchAddressArg) error + SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error + ClearWatchedAddresses() error } // StateDiffIndexer satisfies the Indexer interface for ethereum statediff objects @@ -549,3 +555,101 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sd } return nil } + +// InsertWatchedAddresses inserts the given addresses in the database +func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { + tx, err := sdi.dbWriter.db.Beginx() + if err != nil { + return err + } + defer func() { + if p := recover(); p != nil { + shared.Rollback(tx) + panic(p) + } else if err != nil { + shared.Rollback(tx) + } else { + err = tx.Commit() + } + }() + + for _, arg := range args { + _, err = tx.Exec(`INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, + arg.Address, arg.CreatedAt, currentBlockNumber.Uint64()) + if err != nil { + return fmt.Errorf("error inserting watched_addresses entry: %v", err) + } + } + + return err +} + +// RemoveWatchedAddresses removes the given watched addresses from the database +func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressArg) error { + tx, err := sdi.dbWriter.db.Beginx() + if err != nil { + return err + } + defer func() { + if p := recover(); p != nil { + shared.Rollback(tx) + panic(p) + } else if err != nil { + shared.Rollback(tx) + } else { + err = tx.Commit() + } + }() + + for _, arg := range args { + _, err = tx.Exec(`DELETE FROM eth_meta.watched_addresses WHERE address = $1`, arg.Address) + if err != nil { + return fmt.Errorf("error removing watched_addresses entry: %v", err) + } + } + + return err +} + +// SetWatchedAddresses clears and inserts the given addresses in the database +func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { + tx, err := sdi.dbWriter.db.Beginx() + if err != nil { + return err + } + defer func() { + if p := recover(); p != nil { + shared.Rollback(tx) + panic(p) + } else if err != nil { + shared.Rollback(tx) + } else { + err = tx.Commit() + } + }() + + _, err = tx.Exec(`DELETE FROM eth_meta.watched_addresses`) + if err != nil { + return fmt.Errorf("error setting watched_addresses table: %v", err) + } + + for _, arg := range args { + _, err = tx.Exec(`INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, + arg.Address, arg.CreatedAt, currentBlockNumber.Uint64()) + if err != nil { + return fmt.Errorf("error setting watched_addresses table: %v", err) + } + } + + return err +} + +// ClearWatchedAddresses clears all the watched addresses from the database +func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { + _, err := sdi.dbWriter.db.Exec(`DELETE FROM eth_meta.watched_addresses`) + if err != nil { + return fmt.Errorf("error clearing watched_addresses table: %v", err) + } + + return nil +} diff --git a/statediff/indexer/indexer_test.go b/statediff/indexer/indexer_test.go index 705c33406..011738692 100644 --- a/statediff/indexer/indexer_test.go +++ b/statediff/indexer/indexer_test.go @@ -19,6 +19,7 @@ package indexer_test import ( "bytes" "fmt" + "math/big" "os" "testing" @@ -32,6 +33,7 @@ import ( "github.com/ethereum/go-ethereum/statediff/indexer/models" "github.com/ethereum/go-ethereum/statediff/indexer/postgres" "github.com/ethereum/go-ethereum/statediff/indexer/shared" + sdtypes "github.com/ethereum/go-ethereum/statediff/types" "github.com/ipfs/go-cid" blockstore "github.com/ipfs/go-ipfs-blockstore" dshelp "github.com/ipfs/go-ipfs-ds-help" @@ -45,12 +47,15 @@ var ( ind *indexer.StateDiffIndexer ipfsPgGet = `SELECT data FROM public.blocks WHERE key = $1` - tx1, tx2, tx3, tx4, tx5, rct1, rct2, rct3, rct4, rct5 []byte - mockBlock *types.Block - headerCID, trx1CID, trx2CID, trx3CID, trx4CID, trx5CID cid.Cid - rct1CID, rct2CID, rct3CID, rct4CID, rct5CID cid.Cid - rctLeaf1, rctLeaf2, rctLeaf3, rctLeaf4, rctLeaf5 []byte - state1CID, state2CID, storageCID cid.Cid + tx1, tx2, tx3, tx4, tx5, rct1, rct2, rct3, rct4, rct5 []byte + mockBlock *types.Block + headerCID, trx1CID, trx2CID, trx3CID, trx4CID, trx5CID cid.Cid + rct1CID, rct2CID, rct3CID, rct4CID, rct5CID cid.Cid + rctLeaf1, rctLeaf2, rctLeaf3, rctLeaf4, rctLeaf5 []byte + state1CID, state2CID, storageCID cid.Cid + contract1Address, contract2Address, contract3Address, contract4Address string + contract1CreatedAt, contract2CreatedAt, contract3CreatedAt, contract4CreatedAt uint64 + lastFilledAt, watchedAt1, watchedAt2, watchedAt3 uint64 ) func expectTrue(t *testing.T, value bool) { @@ -161,15 +166,33 @@ func init() { rctLeaf3 = orderedRctLeafNodes[2] rctLeaf4 = orderedRctLeafNodes[3] rctLeaf5 = orderedRctLeafNodes[4] + + contract1Address = "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F" + contract2Address = "0x6Eb7e5C66DB8af2E96159AC440cbc8CDB7fbD26B" + contract3Address = "0xcfeB164C328CA13EFd3C77E1980d94975aDfedfc" + contract4Address = "0x0Edf0c4f393a628DE4828B228C48175b3EA297fc" + contract1CreatedAt = uint64(1) + contract2CreatedAt = uint64(2) + contract3CreatedAt = uint64(3) + contract4CreatedAt = uint64(4) + + lastFilledAt = uint64(0) + watchedAt1 = uint64(10) + watchedAt2 = uint64(15) + watchedAt3 = uint64(20) } -func setup(t *testing.T) { +func setupIndexer(t *testing.T) { db, err = shared.SetupDB() if err != nil { t.Fatal(err) } ind, err = indexer.NewStateDiffIndexer(mocks.TestConfig, db) require.NoError(t, err) +} + +func setup(t *testing.T) { + setupIndexer(t) var tx *indexer.BlockTx tx, err = ind.PushBlock( mockBlock, @@ -654,3 +677,297 @@ func TestPublishAndIndexer(t *testing.T) { shared.ExpectEqual(t, data, []byte{}) }) } + +func TestWatchAddressMethods(t *testing.T) { + setupIndexer(t) + defer tearDown(t) + + type res struct { + Address string `db:"address"` + CreatedAt uint64 `db:"created_at"` + WatchedAt uint64 `db:"watched_at"` + LastFilledAt uint64 `db:"last_filled_at"` + } + pgStr := "SELECT * FROM eth_meta.watched_addresses" + + t.Run("Insert watched addresses", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + } + + ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt1))) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Insert watched addresses (some already watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + } + + ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt2))) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Remove watched addresses", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + } + + ind.RemoveWatchedAddresses(args) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Remove watched addresses (some non-watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{} + + ind.RemoveWatchedAddresses(args) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Set watched addresses", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + } + + ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt2))) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Set watched addresses (some already watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract4Address, + CreatedAt: contract4CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract4Address, + CreatedAt: contract4CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + } + + ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt3))) + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Clear watched addresses", func(t *testing.T) { + expectedData := []res{} + + ind.ClearWatchedAddresses() + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Clear watched addresses (empty table)", func(t *testing.T) { + expectedData := []res{} + + ind.ClearWatchedAddresses() + + rows := []res{} + err = db.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + shared.ExpectEqual(t, row, expectedData[idx]) + } + }) +} diff --git a/statediff/indexer/test_helpers.go b/statediff/indexer/test_helpers.go index 024bb58f0..9745cfd02 100644 --- a/statediff/indexer/test_helpers.go +++ b/statediff/indexer/test_helpers.go @@ -53,6 +53,10 @@ func TearDownDB(t *testing.T, db *postgres.DB) { if err != nil { t.Fatal(err) } + _, err = tx.Exec(`DELETE FROM eth_meta.watched_addresses`) + if err != nil { + t.Fatal(err) + } err = tx.Commit() if err != nil { t.Fatal(err) diff --git a/statediff/service.go b/statediff/service.go index 6411ba68e..998a8e85a 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -18,6 +18,7 @@ package statediff import ( "bytes" + "fmt" "math/big" "strconv" "strings" @@ -40,6 +41,7 @@ import ( "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/trie" + "github.com/thoas/go-funk" ind "github.com/ethereum/go-ethereum/statediff/indexer" nodeinfo "github.com/ethereum/go-ethereum/statediff/indexer/node" @@ -48,25 +50,30 @@ import ( ) const ( - chainEventChanSize = 20000 - genesisBlockNumber = 0 - defaultRetryLimit = 3 // default retry limit once deadlock is detected. - deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html + chainEventChanSize = 20000 + genesisBlockNumber = 0 + defaultRetryLimit = 3 // default retry limit once deadlock is detected. + deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html + typeAssertionFailed = "type assertion failed" + unexpectedOperation = "unexpected operation" ) -var writeLoopParams = Params{ - IntermediateStateNodes: true, - IntermediateStorageNodes: true, - IncludeBlock: true, - IncludeReceipts: true, - IncludeTD: true, - IncludeCode: true, +var writeLoopParams = ParamsWithMutex{ + Params: Params{ + IntermediateStateNodes: true, + IntermediateStorageNodes: true, + IncludeBlock: true, + IncludeReceipts: true, + IncludeTD: true, + IncludeCode: true, + }, } var statediffMetrics = RegisterStatediffMetrics(metrics.DefaultRegistry) type blockChain interface { SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription + CurrentBlock() *types.Block GetBlockByHash(hash common.Hash) *types.Block GetBlockByNumber(number uint64) *types.Block GetReceiptsByHash(hash common.Hash) types.Receipts @@ -101,6 +108,8 @@ type IService interface { WriteStateDiffFor(blockHash common.Hash, params Params) error // Event loop for progressively processing and writing diffs directly to DB WriteLoop(chainEventCh chan core.ChainEvent) + // Method to change the addresses being watched in write loop params + WatchAddress(operation OperationType, args []WatchAddressArg) error } // Wraps consructor parameters @@ -159,6 +168,8 @@ func NewBlockCache(max uint) blockCache { func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params ServiceParams) error { blockChain := ethServ.BlockChain() var indexer ind.Indexer + var db *postgres.DB + var err error quitCh := make(chan bool) if params.DBParams != nil { info := nodeinfo.Info{ @@ -170,7 +181,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params } // TODO: pass max idle, open, lifetime? - db, err := postgres.NewDB(params.DBParams.ConnectionURL, postgres.ConnectionConfig{}, info) + db, err = postgres.NewDB(params.DBParams.ConnectionURL, postgres.ConnectionConfig{}, info) if err != nil { return err } @@ -200,6 +211,12 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params } stack.RegisterLifecycle(sds) stack.RegisterAPIs(sds.APIs()) + + err = loadWatchedAddresses(db) + if err != nil { + return err + } + return nil } @@ -278,7 +295,9 @@ func (sds *Service) WriteLoop(chainEventCh chan core.ChainEvent) { func (sds *Service) writeGenesisStateDiff(currBlock *types.Block, workerId uint) { // For genesis block we need to return the entire state trie hence we diff it with an empty trie. log.Info("Writing state diff", "block height", genesisBlockNumber, "worker", workerId) - err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams) + writeLoopParams.RLock() + err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams.Params) + writeLoopParams.RUnlock() if err != nil { log.Error("statediff.Service.WriteLoop: processing error", "block height", genesisBlockNumber, "error", err.Error(), "worker", workerId) @@ -307,7 +326,9 @@ func (sds *Service) writeLoopWorker(params workerParams) { } log.Info("Writing state diff", "block height", currentBlock.Number().Uint64(), "worker", params.id) - err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams) + writeLoopParams.RLock() + err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams.Params) + writeLoopParams.RUnlock() if err != nil { log.Error("statediff.Service.WriteLoop: processing error", "block height", currentBlock.Number().Uint64(), "error", err.Error(), "worker", params.id) continue @@ -402,6 +423,9 @@ func (sds *Service) streamStateDiff(currentBlock *types.Block, parentRoot common func (sds *Service) StateDiffAt(blockNumber uint64, params Params) (*Payload, error) { currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) log.Info("sending state diff", "block height", blockNumber) + + params.ComputeWatchedAddressesLeafKeys() + if blockNumber == 0 { return sds.processStateDiff(currentBlock, common.Hash{}, params) } @@ -414,6 +438,9 @@ func (sds *Service) StateDiffAt(blockNumber uint64, params Params) (*Payload, er func (sds *Service) StateDiffFor(blockHash common.Hash, params Params) (*Payload, error) { currentBlock := sds.BlockChain.GetBlockByHash(blockHash) log.Info("sending state diff", "block hash", blockHash) + + params.ComputeWatchedAddressesLeafKeys() + if currentBlock.NumberU64() == 0 { return sds.processStateDiff(currentBlock, common.Hash{}, params) } @@ -472,6 +499,9 @@ func (sds *Service) newPayload(stateObject []byte, block *types.Block, params Pa func (sds *Service) StateTrieAt(blockNumber uint64, params Params) (*Payload, error) { currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) log.Info("sending state trie", "block height", blockNumber) + + params.ComputeWatchedAddressesLeafKeys() + return sds.processStateTrie(currentBlock, params) } @@ -494,6 +524,9 @@ func (sds *Service) Subscribe(id rpc.ID, sub chan<- Payload, quitChan chan<- boo if atomic.CompareAndSwapInt32(&sds.subscribers, 0, 1) { log.Info("State diffing subscription received; beginning statediff processing") } + + params.ComputeWatchedAddressesLeafKeys() + // Subscription type is defined as the hash of the rlp-serialized subscription params by, err := rlp.EncodeToBytes(params) if err != nil { @@ -543,7 +576,7 @@ func (sds *Service) Start() error { go sds.Loop(chainEventCh) if sds.enableWriteLoop { - log.Info("Starting statediff DB write loop", "params", writeLoopParams) + log.Info("Starting statediff DB write loop", "params", writeLoopParams.Params) chainEventCh := make(chan core.ChainEvent, chainEventChanSize) go sds.WriteLoop(chainEventCh) } @@ -640,6 +673,8 @@ func (sds *Service) StreamCodeAndCodeHash(blockNumber uint64, outChan chan<- Cod // This operation cannot be performed back past the point of db pruning; it requires an archival node // for historical data func (sds *Service) WriteStateDiffAt(blockNumber uint64, params Params) error { + params.ComputeWatchedAddressesLeafKeys() + currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) parentRoot := common.Hash{} if blockNumber != 0 { @@ -653,6 +688,8 @@ func (sds *Service) WriteStateDiffAt(blockNumber uint64, params Params) error { // This operation cannot be performed back past the point of db pruning; it requires an archival node // for historical data func (sds *Service) WriteStateDiffFor(blockHash common.Hash, params Params) error { + params.ComputeWatchedAddressesLeafKeys() + currentBlock := sds.BlockChain.GetBlockByHash(blockHash) parentRoot := common.Hash{} if currentBlock.NumberU64() != 0 { @@ -713,3 +750,102 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo } return err } + +// Performs one of following operations on the watched addresses in writeLoopParams and the db: +// add | remove | set | clear +func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error { + // lock writeLoopParams for a write + writeLoopParams.Lock() + defer writeLoopParams.Unlock() + + // get the current block number + currentBlockNumber := sds.BlockChain.CurrentBlock().Number() + + switch operation { + case Add: + // filter out args having an already watched address with a warning + filteredArgs, ok := funk.Filter(args, func(arg WatchAddressArg) bool { + if funk.Contains(writeLoopParams.WatchedAddresses, common.HexToAddress(arg.Address)) { + log.Warn("Address already being watched", "address", arg.Address) + return false + } + return true + }).([]WatchAddressArg) + if !ok { + return fmt.Errorf("add: filtered args %s", typeAssertionFailed) + } + + // get addresses from the filtered args + filteredAddresses, err := MapWatchAddressArgsToAddresses(filteredArgs) + if err != nil { + return fmt.Errorf("add: filtered addresses %s", err.Error()) + } + + // update the db + err = sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) + if err != nil { + return err + } + + // update in-memory params + writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...) + funk.ForEach(filteredAddresses, func(address common.Address) { + writeLoopParams.watchedAddressesLeafKeys[crypto.Keccak256Hash(address.Bytes())] = struct{}{} + }) + case Remove: + // get addresses from args + argAddresses, err := MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("remove: mapped addresses %s", err.Error()) + } + + // remove the provided addresses from currently watched addresses + addresses, ok := funk.Subtract(writeLoopParams.WatchedAddresses, argAddresses).([]common.Address) + if !ok { + return fmt.Errorf("remove: filtered addresses %s", typeAssertionFailed) + } + + // update the db + err = sds.indexer.RemoveWatchedAddresses(args) + if err != nil { + return err + } + + // update in-memory params + writeLoopParams.WatchedAddresses = addresses + funk.ForEach(argAddresses, func(address common.Address) { + delete(writeLoopParams.watchedAddressesLeafKeys, crypto.Keccak256Hash(address.Bytes())) + }) + case Set: + // get addresses from args + argAddresses, err := MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("set: mapped addresses %s", err.Error()) + } + + // update the db + err = sds.indexer.SetWatchedAddresses(args, currentBlockNumber) + if err != nil { + return err + } + + // update in-memory params + writeLoopParams.WatchedAddresses = argAddresses + writeLoopParams.ComputeWatchedAddressesLeafKeys() + case Clear: + // update the db + err := sds.indexer.ClearWatchedAddresses() + if err != nil { + return err + } + + // update in-memory params + writeLoopParams.WatchedAddresses = []common.Address{} + writeLoopParams.ComputeWatchedAddressesLeafKeys() + + default: + return fmt.Errorf("%s %s", unexpectedOperation, operation) + } + + return nil +} diff --git a/statediff/service_test.go b/statediff/service_test.go index ca9a483a5..a3a5ccca4 100644 --- a/statediff/service_test.go +++ b/statediff/service_test.go @@ -144,6 +144,7 @@ func testErrorInChainEventLoop(t *testing.T) { } } + defaultParams.ComputeWatchedAddressesLeafKeys() if !reflect.DeepEqual(builder.Params, defaultParams) { t.Error("Test failure:", t.Name()) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) @@ -195,6 +196,8 @@ func testErrorInBlockLoop(t *testing.T) { } }() service.Loop(eventsChannel) + + defaultParams.ComputeWatchedAddressesLeafKeys() if !reflect.DeepEqual(builder.Params, defaultParams) { t.Error("Test failure:", t.Name()) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) @@ -268,6 +271,8 @@ func testErrorInStateDiffAt(t *testing.T) { if err != nil { t.Error(err) } + + defaultParams.ComputeWatchedAddressesLeafKeys() if !reflect.DeepEqual(builder.Params, defaultParams) { t.Error("Test failure:", t.Name()) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) diff --git a/statediff/testhelpers/mocks/blockchain.go b/statediff/testhelpers/mocks/blockchain.go index b4b1f3694..f2a77af64 100644 --- a/statediff/testhelpers/mocks/blockchain.go +++ b/statediff/testhelpers/mocks/blockchain.go @@ -39,6 +39,7 @@ type BlockChain struct { Receipts map[common.Hash]types.Receipts TDByHash map[common.Hash]*big.Int TDByNum map[uint64]*big.Int + currentBlock *types.Block } // SetBlocksForHashes mock method @@ -128,6 +129,16 @@ func (bc *BlockChain) GetTd(hash common.Hash, blockNum uint64) *big.Int { return nil } +// SetCurrentBlock test method +func (bc *BlockChain) SetCurrentBlock(block *types.Block) { + bc.currentBlock = block +} + +// CurrentBlock mock method +func (bc *BlockChain) CurrentBlock() *types.Block { + return bc.currentBlock +} + func (bc *BlockChain) SetTd(hash common.Hash, blockNum uint64, td *big.Int) { if bc.TDByHash == nil { bc.TDByHash = make(map[common.Hash]*big.Int) diff --git a/statediff/testhelpers/mocks/indexer.go b/statediff/testhelpers/mocks/indexer.go new file mode 100644 index 000000000..90ea40ca0 --- /dev/null +++ b/statediff/testhelpers/mocks/indexer.go @@ -0,0 +1,59 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package mocks + +import ( + "math/big" + "time" + + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/statediff/indexer" + sdtypes "github.com/ethereum/go-ethereum/statediff/types" +) + +// Indexer is a mock state diff indexer +type Indexer struct{} + +func (sdi *Indexer) PushBlock(block *types.Block, receipts types.Receipts, totalDifficulty *big.Int) (*indexer.BlockTx, error) { + return nil, nil +} + +func (sdi *Indexer) PushStateNode(tx *indexer.BlockTx, stateNode sdtypes.StateNode) error { + return nil +} + +func (sdi *Indexer) PushCodeAndCodeHash(tx *indexer.BlockTx, codeAndCodeHash sdtypes.CodeAndCodeHash) error { + return nil +} + +func (sdi *Indexer) ReportDBMetrics(delay time.Duration, quit <-chan bool) {} + +func (sdi *Indexer) InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error { + return nil +} + +func (sdi *Indexer) RemoveWatchedAddresses(addresses []sdtypes.WatchAddressArg) error { + return nil +} + +func (sdi *Indexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { + return nil +} + +func (sdi *Indexer) ClearWatchedAddresses() error { + return nil +} diff --git a/statediff/testhelpers/mocks/service.go b/statediff/testhelpers/mocks/service.go index f10017df4..e14565774 100644 --- a/statediff/testhelpers/mocks/service.go +++ b/statediff/testhelpers/mocks/service.go @@ -25,6 +25,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" + "github.com/thoas/go-funk" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" @@ -32,9 +33,15 @@ import ( "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/statediff" + ind "github.com/ethereum/go-ethereum/statediff/indexer" sdtypes "github.com/ethereum/go-ethereum/statediff/types" ) +var ( + typeAssertionFailed = "type assertion failed" + unexpectedOperation = "unexpected operation" +) + // MockStateDiffService is a mock state diff service type MockStateDiffService struct { sync.Mutex @@ -47,6 +54,8 @@ type MockStateDiffService struct { QuitChan chan bool Subscriptions map[common.Hash]map[rpc.ID]statediff.Subscription SubscriptionTypes map[common.Hash]statediff.Params + Indexer ind.Indexer + writeLoopParams statediff.ParamsWithMutex } // Protocols mock method @@ -332,3 +341,97 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) { log.Info("unable to close subscription %s; channel has no receiver", id) } } + +// WatchAddress mock method +func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, args []sdtypes.WatchAddressArg) error { + // lock writeLoopParams for a write + sds.writeLoopParams.Lock() + defer sds.writeLoopParams.Unlock() + + // get the current block number + currentBlockNumber := sds.BlockChain.CurrentBlock().Number() + + switch operation { + case statediff.Add: + // filter out args having an already watched address with a warning + filteredArgs, ok := funk.Filter(args, func(arg sdtypes.WatchAddressArg) bool { + if funk.Contains(sds.writeLoopParams.WatchedAddresses, common.HexToAddress(arg.Address)) { + log.Warn("Address already being watched", "address", arg.Address) + return false + } + return true + }).([]sdtypes.WatchAddressArg) + if !ok { + return fmt.Errorf("add: filtered args %s", typeAssertionFailed) + } + + // get addresses from the filtered args + filteredAddresses, err := statediff.MapWatchAddressArgsToAddresses(filteredArgs) + if err != nil { + return fmt.Errorf("add: filtered addresses %s", err.Error()) + } + + // update the db + err = sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedAddresses = append(sds.writeLoopParams.WatchedAddresses, filteredAddresses...) + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() + case statediff.Remove: + // get addresses from args + argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("remove: mapped addresses %s", err.Error()) + } + + // remove the provided addresses from currently watched addresses + addresses, ok := funk.Subtract(sds.writeLoopParams.WatchedAddresses, argAddresses).([]common.Address) + if !ok { + return fmt.Errorf("remove: filtered addresses %s", typeAssertionFailed) + } + + // update the db + err = sds.Indexer.RemoveWatchedAddresses(args) + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedAddresses = addresses + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() + case statediff.Set: + // get addresses from args + argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("set: mapped addresses %s", err.Error()) + } + + // update the db + err = sds.Indexer.SetWatchedAddresses(args, currentBlockNumber) + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedAddresses = argAddresses + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() + case statediff.Clear: + // update the db + err := sds.Indexer.ClearWatchedAddresses() + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedAddresses = []common.Address{} + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() + + default: + return fmt.Errorf("%s %s", unexpectedOperation, operation) + } + + return nil +} diff --git a/statediff/testhelpers/mocks/service_test.go b/statediff/testhelpers/mocks/service_test.go index 8c1fd49cf..dbd059def 100644 --- a/statediff/testhelpers/mocks/service_test.go +++ b/statediff/testhelpers/mocks/service_test.go @@ -21,6 +21,7 @@ import ( "fmt" "math/big" "os" + "reflect" "sort" "sync" "testing" @@ -87,6 +88,7 @@ func init() { func TestAPI(t *testing.T) { testSubscriptionAPI(t) testHTTPAPI(t) + testWatchAddressAPI(t) } func testSubscriptionAPI(t *testing.T) { @@ -246,3 +248,286 @@ func testHTTPAPI(t *testing.T) { t.Errorf("paylaod does not have the expected total difficulty\r\nactual td: %d\r\nexpected td: %d", payload.TotalDifficulty.Int64(), mockTotalDifficulty.Int64()) } } + +func testWatchAddressAPI(t *testing.T) { + blocks, chain := testhelpers.MakeChain(6, testhelpers.Genesis, testhelpers.TestChainGen) + defer chain.Stop() + block6 := blocks[5] + + mockBlockChain := &BlockChain{} + mockBlockChain.SetCurrentBlock(block6) + mockIndexer := Indexer{} + mockService := MockStateDiffService{ + BlockChain: mockBlockChain, + Indexer: &mockIndexer, + } + + // test data + var ( + contract1Address = "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F" + contract2Address = "0x6Eb7e5C66DB8af2E96159AC440cbc8CDB7fbD26B" + contract3Address = "0xcfeB164C328CA13EFd3C77E1980d94975aDfedfc" + contract4Address = "0x0Edf0c4f393a628DE4828B228C48175b3EA297fc" + contract1CreatedAt = uint64(1) + contract2CreatedAt = uint64(2) + contract3CreatedAt = uint64(3) + contract4CreatedAt = uint64(4) + + args1 = []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + startingParams1 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + expectedParams1 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract1Address), + common.HexToAddress(contract2Address), + }, + } + + args2 = []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + startingParams2 = expectedParams1 + expectedParams2 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract1Address), + common.HexToAddress(contract2Address), + common.HexToAddress(contract3Address), + }, + } + + args3 = []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + startingParams3 = expectedParams2 + expectedParams3 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract1Address), + }, + } + + args4 = []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + startingParams4 = expectedParams3 + expectedParams4 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + + args5 = []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + startingParams5 = expectedParams4 + expectedParams5 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract1Address), + common.HexToAddress(contract2Address), + common.HexToAddress(contract3Address), + }, + } + + args6 = []sdtypes.WatchAddressArg{ + { + Address: contract4Address, + CreatedAt: contract4CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + startingParams6 = expectedParams5 + expectedParams6 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract4Address), + common.HexToAddress(contract2Address), + common.HexToAddress(contract3Address), + }, + } + + args7 = []sdtypes.WatchAddressArg{} + startingParams7 = expectedParams6 + expectedParams7 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + + args8 = []sdtypes.WatchAddressArg{} + startingParams8 = expectedParams6 + expectedParams8 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + + args9 = []sdtypes.WatchAddressArg{} + startingParams9 = expectedParams8 + expectedParams9 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + ) + + tests := []struct { + name string + operation statediff.OperationType + args []sdtypes.WatchAddressArg + startingParams statediff.Params + expectedParams statediff.Params + expectedErr error + }{ + { + "testAddAddresses", + statediff.Add, + args1, + startingParams1, + expectedParams1, + nil, + }, + { + "testAddAddressesSomeWatched", + statediff.Add, + args2, + startingParams2, + expectedParams2, + nil, + }, + { + "testRemoveAddresses", + statediff.Remove, + args3, + startingParams3, + expectedParams3, + nil, + }, + { + "testRemoveAddressesSomeWatched", + statediff.Remove, + args4, + startingParams4, + expectedParams4, + nil, + }, + { + "testSetAddresses", + statediff.Set, + args5, + startingParams5, + expectedParams5, + nil, + }, + { + "testSetAddressesSomeWatched", + statediff.Set, + args6, + startingParams6, + expectedParams6, + nil, + }, + { + "testSetAddressesEmtpyArgs", + statediff.Set, + args7, + startingParams7, + expectedParams7, + nil, + }, + { + "testClearAddresses", + statediff.Clear, + args8, + startingParams8, + expectedParams8, + nil, + }, + { + "testClearAddressesEmpty", + statediff.Clear, + args9, + startingParams9, + expectedParams9, + nil, + }, + + // invalid args + { + "testInvalidOperation", + "WrongOp", + args9, + startingParams9, + statediff.Params{}, + fmt.Errorf("%s WrongOp", unexpectedOperation), + }, + } + + for _, test := range tests { + // set indexing params + mockService.writeLoopParams = statediff.ParamsWithMutex{ + Params: test.startingParams, + } + mockService.writeLoopParams.ComputeWatchedAddressesLeafKeys() + + // make the API call to change watched addresses + err := mockService.WatchAddress(test.operation, test.args) + if test.expectedErr != nil { + if err.Error() != test.expectedErr.Error() { + t.Logf("Test failed: %s", test.name) + t.Errorf("actual err: %+v\nexpected err: %+v", err, test.expectedErr) + } + + continue + } + if err != nil { + t.Error(err) + } + + // check updated indexing params + test.expectedParams.ComputeWatchedAddressesLeafKeys() + updatedParams := mockService.writeLoopParams.Params + if !reflect.DeepEqual(updatedParams, test.expectedParams) { + t.Logf("Test failed: %s", test.name) + t.Errorf("actual params: %+v\nexpected params: %+v", updatedParams, test.expectedParams) + } + } +} diff --git a/statediff/types.go b/statediff/types.go index ef8256041..e33193637 100644 --- a/statediff/types.go +++ b/statediff/types.go @@ -22,9 +22,11 @@ package statediff import ( "encoding/json" "math/big" + "sync" "github.com/ethereum/go-ethereum/common" ctypes "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/statediff/types" ) @@ -50,7 +52,21 @@ type Params struct { IncludeTD bool IncludeCode bool WatchedAddresses []common.Address - WatchedStorageSlots []common.Hash + watchedAddressesLeafKeys map[common.Hash]struct{} +} + +// ComputeWatchedAddressesLeafKeys populates a map with keys (Keccak256Hash) of each of the WatchedAddresses +func (p *Params) ComputeWatchedAddressesLeafKeys() { + p.watchedAddressesLeafKeys = make(map[common.Hash]struct{}, len(p.WatchedAddresses)) + for _, address := range p.WatchedAddresses { + p.watchedAddressesLeafKeys[crypto.Keccak256Hash(address.Bytes())] = struct{}{} + } +} + +// ParamsWithMutex allows to lock the parameters while they are being updated | read from +type ParamsWithMutex struct { + Params + sync.RWMutex } // Args bundles the arguments for the state diff builder @@ -111,3 +127,13 @@ type accountWrapper struct { NodeValue []byte LeafKey []byte } + +// OperationType for type of WatchAddress operation +type OperationType string + +const ( + Add OperationType = "add" + Remove OperationType = "remove" + Set OperationType = "set" + Clear OperationType = "clear" +) diff --git a/statediff/types/types.go b/statediff/types/types.go index 56babfb5b..bdda68c65 100644 --- a/statediff/types/types.go +++ b/statediff/types/types.go @@ -74,3 +74,10 @@ type CodeAndCodeHash struct { type StateNodeSink func(StateNode) error type StorageNodeSink func(StorageNode) error type CodeSink func(CodeAndCodeHash) error + +// WatchAddressArg is a arg type for WatchAddress API +type WatchAddressArg struct { + // Address represents common.Address + Address string + CreatedAt uint64 +}