From 9b234ef5b4d943dfb82ddb94761d8ed9acee62d3 Mon Sep 17 00:00:00 2001 From: prathamesh0 Date: Wed, 16 Mar 2022 15:35:33 +0530 Subject: [PATCH] Statediff API to change addresses being watched in direct indexing --- go.mod | 1 + go.sum | 2 + statediff/README.md | 4 +- statediff/api.go | 5 + statediff/builder.go | 115 +++++------- statediff/config.go | 18 +- statediff/indexer/database/sql/indexer.go | 116 ++++++++++++ statediff/indexer/interfaces/interfaces.go | 9 + statediff/service.go | 197 +++++++++++++++++++-- statediff/types/types.go | 17 ++ 10 files changed, 396 insertions(+), 88 deletions(-) diff --git a/go.mod b/go.mod index 669d85d16..30b3c97ed 100644 --- a/go.mod +++ b/go.mod @@ -70,6 +70,7 @@ require ( github.com/status-im/keycard-go v0.0.0-20190316090335-8537d3370df4 github.com/stretchr/testify v1.7.0 github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 + github.com/thoas/go-funk v0.9.2 github.com/tklauser/go-sysconf v0.3.5 // indirect github.com/tyler-smith/go-bip39 v1.0.1-0.20181017060643-dbb3b84ba2ef golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 diff --git a/go.sum b/go.sum index 136d31806..dc748a005 100644 --- a/go.sum +++ b/go.sum @@ -585,6 +585,8 @@ github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5Cc github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70Z7CTTCmYQn2CKbY8j86K7/FAIr141uY= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= +github.com/thoas/go-funk v0.9.2 h1:oKlNYv0AY5nyf9g+/GhMgS/UO2ces0QRdPKwkhY3VCk= +github.com/thoas/go-funk v0.9.2/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q= github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= github.com/tklauser/go-sysconf v0.3.5 h1:uu3Xl4nkLzQfXNsWn15rPc/HQCJKObbt1dKJeWp3vU4= github.com/tklauser/go-sysconf v0.3.5/go.mod h1:MkWzOF4RMCshBAMXuhXJs64Rte09mITnppBXY/rYEFI= diff --git a/statediff/README.md b/statediff/README.md index ef9509a7d..39ba24775 100644 --- a/statediff/README.md +++ b/statediff/README.md @@ -148,15 +148,13 @@ type Params struct { IncludeTD bool IncludeCode bool WatchedAddresses []common.Address - WatchedStorageSlots []common.Hash } ``` Using these params we can tell the service whether to include state and/or storage intermediate nodes; whether to include the associated block (header, uncles, and transactions); whether to include the associated receipts; whether to include the total difficulty for this block; whether to include the set of code hashes and code for -contracts deployed in this block; whether to limit the diffing process to a list of specific addresses; and/or -whether to limit the diffing process to a list of specific storage slot keys. +contracts deployed in this block; whether to limit the diffing process to a list of specific addresses. #### Subscription endpoint diff --git a/statediff/api.go b/statediff/api.go index 5c534cddb..0a7c5bba8 100644 --- a/statediff/api.go +++ b/statediff/api.go @@ -149,3 +149,8 @@ func (api *PublicStateDiffAPI) WriteStateDiffAt(ctx context.Context, blockNumber func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash common.Hash, params Params) error { return api.sds.WriteStateDiffFor(blockHash, params) } + +// WatchAddress changes the list of watched addresses to which the direct indexing is restricted according to given operation +func (api *PublicStateDiffAPI) WatchAddress(operation types.OperationType, args []types.WatchAddressArg) error { + return api.sds.WatchAddress(operation, args) +} diff --git a/statediff/builder.go b/statediff/builder.go index 7811c3e82..b56db2f5a 100644 --- a/statediff/builder.go +++ b/statediff/builder.go @@ -123,7 +123,7 @@ func (sdb *builder) buildStateTrie(it trie.NodeIterator) ([]types2.StateNode, [] node.LeafKey = leafKey if !bytes.Equal(account.CodeHash, nullCodeHash) { var storageNodes []types2.StorageNode - err := sdb.buildStorageNodesEventual(account.Root, nil, true, storageNodeAppender(&storageNodes)) + err := sdb.buildStorageNodesEventual(account.Root, true, storageNodeAppender(&storageNodes)) if err != nil { return nil, nil, fmt.Errorf("failed building eventual storage diffs for account %+v\r\nerror: %v", account, err) } @@ -220,12 +220,12 @@ func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args types2.StateRo // build the diff nodes for the updated accounts using the mappings at both A and B as directed by the keys found as the intersection of the two err = sdb.buildAccountUpdates( diffAccountsAtB, diffAccountsAtA, updatedKeys, - params.WatchedStorageSlots, params.IntermediateStorageNodes, output) + params.IntermediateStorageNodes, output) if err != nil { return fmt.Errorf("error building diff for updated accounts: %v", err) } // build the diff nodes for created accounts - err = sdb.buildAccountCreations(diffAccountsAtB, params.WatchedStorageSlots, params.IntermediateStorageNodes, output, codeOutput) + err = sdb.buildAccountCreations(diffAccountsAtB, params.IntermediateStorageNodes, output, codeOutput) if err != nil { return fmt.Errorf("error building diff for created accounts: %v", err) } @@ -247,7 +247,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args types2.Stat // and a slice of all the paths for the nodes in both of the above sets diffAccountsAtB, diffPathsAtB, err := sdb.createdAndUpdatedState( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - params.WatchedAddresses) + params.watchedAddressesLeafKeys) if err != nil { return fmt.Errorf("error collecting createdAndUpdatedNodes: %v", err) } @@ -274,12 +274,12 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args types2.Stat // build the diff nodes for the updated accounts using the mappings at both A and B as directed by the keys found as the intersection of the two err = sdb.buildAccountUpdates( diffAccountsAtB, diffAccountsAtA, updatedKeys, - params.WatchedStorageSlots, params.IntermediateStorageNodes, output) + params.IntermediateStorageNodes, output) if err != nil { return fmt.Errorf("error building diff for updated accounts: %v", err) } // build the diff nodes for created accounts - err = sdb.buildAccountCreations(diffAccountsAtB, params.WatchedStorageSlots, params.IntermediateStorageNodes, output, codeOutput) + err = sdb.buildAccountCreations(diffAccountsAtB, params.IntermediateStorageNodes, output, codeOutput) if err != nil { return fmt.Errorf("error building diff for created accounts: %v", err) } @@ -289,7 +289,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args types2.Stat // createdAndUpdatedState returns // a mapping of their leafkeys to all the accounts that exist in a different state at B than A // and a slice of the paths for all of the nodes included in both -func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddresses []common.Address) (types2.AccountMap, map[string]bool, error) { +func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddressesLeafKeys map[common.Hash]struct{}) (types2.AccountMap, map[string]bool, error) { diffPathsAtB := make(map[string]bool) diffAcountsAtB := make(types2.AccountMap) it, _ := trie.NewDifferenceIterator(a, b) @@ -313,7 +313,7 @@ func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddres valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedAddress(watchedAddresses, leafKey) { + if isWatchedAddress(watchedAddressesLeafKeys, leafKey) { diffAcountsAtB[common.Bytes2Hex(leafKey)] = types2.AccountWrapper{ NodeType: node.NodeType, Path: node.Path, @@ -454,8 +454,7 @@ func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB m // to generate the statediff node objects for all of the accounts that existed at both A and B but in different states // needs to be called before building account creations and deletions as this mutates // those account maps to remove the accounts which were updated -func (sdb *builder) buildAccountUpdates(creations, deletions types2.AccountMap, updatedKeys []string, - watchedStorageKeys []common.Hash, intermediateStorageNodes bool, output types2.StateNodeSink) error { +func (sdb *builder) buildAccountUpdates(creations, deletions types2.AccountMap, updatedKeys []string, intermediateStorageNodes bool, output types2.StateNodeSink) error { var err error for _, key := range updatedKeys { createdAcc := creations[key] @@ -465,7 +464,7 @@ func (sdb *builder) buildAccountUpdates(creations, deletions types2.AccountMap, oldSR := deletedAcc.Account.Root newSR := createdAcc.Account.Root err = sdb.buildStorageNodesIncremental( - oldSR, newSR, watchedStorageKeys, intermediateStorageNodes, + oldSR, newSR, intermediateStorageNodes, storageNodeAppender(&storageDiffs)) if err != nil { return fmt.Errorf("failed building incremental storage diffs for account with leafkey %s\r\nerror: %v", key, err) @@ -489,7 +488,7 @@ func (sdb *builder) buildAccountUpdates(creations, deletions types2.AccountMap, // buildAccountCreations returns the statediff node objects for all the accounts that exist at B but not at A // it also returns the code and codehash for created contract accounts -func (sdb *builder) buildAccountCreations(accounts types2.AccountMap, watchedStorageKeys []common.Hash, intermediateStorageNodes bool, output types2.StateNodeSink, codeOutput types2.CodeSink) error { +func (sdb *builder) buildAccountCreations(accounts types2.AccountMap, intermediateStorageNodes bool, output types2.StateNodeSink, codeOutput types2.CodeSink) error { for _, val := range accounts { diff := types2.StateNode{ NodeType: val.NodeType, @@ -500,7 +499,7 @@ func (sdb *builder) buildAccountCreations(accounts types2.AccountMap, watchedSto if !bytes.Equal(val.Account.CodeHash, nullCodeHash) { // For contract creations, any storage node contained is a diff var storageDiffs []types2.StorageNode - err := sdb.buildStorageNodesEventual(val.Account.Root, watchedStorageKeys, intermediateStorageNodes, storageNodeAppender(&storageDiffs)) + err := sdb.buildStorageNodesEventual(val.Account.Root, intermediateStorageNodes, storageNodeAppender(&storageDiffs)) if err != nil { return fmt.Errorf("failed building eventual storage diffs for node %x\r\nerror: %v", val.Path, err) } @@ -528,7 +527,7 @@ func (sdb *builder) buildAccountCreations(accounts types2.AccountMap, watchedSto // buildStorageNodesEventual builds the storage diff node objects for a created account // i.e. it returns all the storage nodes at this state, since there is no previous state -func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys []common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error { +func (sdb *builder) buildStorageNodesEventual(sr common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error { if bytes.Equal(sr.Bytes(), emptyContractRoot.Bytes()) { return nil } @@ -539,7 +538,7 @@ func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys return err } it := sTrie.NodeIterator(make([]byte, 0)) - err = sdb.buildStorageNodesFromTrie(it, watchedStorageKeys, intermediateNodes, output) + err = sdb.buildStorageNodesFromTrie(it, intermediateNodes, output) if err != nil { return err } @@ -549,7 +548,7 @@ func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys // buildStorageNodesFromTrie returns all the storage diff node objects in the provided node interator // if any storage keys are provided it will only return those leaf nodes // including intermediate nodes can be turned on or off -func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStorageKeys []common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error { +func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, intermediateNodes bool, output types2.StorageNodeSink) error { for it.Next(true) { // skip value nodes if it.Leaf() || bytes.Equal(nullHashBytes, it.Hash().Bytes()) { @@ -565,15 +564,13 @@ func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStora valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedStorageKey(watchedStorageKeys, leafKey) { - if err := output(types2.StorageNode{ - NodeType: node.NodeType, - Path: node.Path, - NodeValue: node.NodeValue, - LeafKey: leafKey, - }); err != nil { - return err - } + if err := output(types2.StorageNode{ + NodeType: node.NodeType, + Path: node.Path, + NodeValue: node.NodeValue, + LeafKey: leafKey, + }); err != nil { + return err } case types2.Extension, types2.Branch: if intermediateNodes { @@ -593,7 +590,7 @@ func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStora } // buildStorageNodesIncremental builds the storage diff node objects for all nodes that exist in a different state at B than A -func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common.Hash, watchedStorageKeys []common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error { +func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error { if bytes.Equal(newSR.Bytes(), oldSR.Bytes()) { return nil } @@ -609,19 +606,19 @@ func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common diffPathsAtB, err := sdb.createdAndUpdatedStorage( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - watchedStorageKeys, intermediateNodes, output) + intermediateNodes, output) if err != nil { return err } err = sdb.deletedOrUpdatedStorage(oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - diffPathsAtB, watchedStorageKeys, intermediateNodes, output) + diffPathsAtB, intermediateNodes, output) if err != nil { return err } return nil } -func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, watchedKeys []common.Hash, intermediateNodes bool, output types2.StorageNodeSink) (map[string]bool, error) { +func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, intermediateNodes bool, output types2.StorageNodeSink) (map[string]bool, error) { diffPathsAtB := make(map[string]bool) it, _ := trie.NewDifferenceIterator(a, b) for it.Next(true) { @@ -639,15 +636,13 @@ func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, watchedKeys valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedStorageKey(watchedKeys, leafKey) { - if err := output(types2.StorageNode{ - NodeType: node.NodeType, - Path: node.Path, - NodeValue: node.NodeValue, - LeafKey: leafKey, - }); err != nil { - return nil, err - } + if err := output(types2.StorageNode{ + NodeType: node.NodeType, + Path: node.Path, + NodeValue: node.NodeValue, + LeafKey: leafKey, + }); err != nil { + return nil, err } case types2.Extension, types2.Branch: if intermediateNodes { @@ -667,7 +662,7 @@ func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, watchedKeys return diffPathsAtB, it.Error() } -func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB map[string]bool, watchedKeys []common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error { +func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB map[string]bool, intermediateNodes bool, output types2.StorageNodeSink) error { it, _ := trie.NewDifferenceIterator(b, a) for it.Next(true) { // skip value nodes @@ -690,15 +685,13 @@ func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedStorageKey(watchedKeys, leafKey) { - if err := output(types2.StorageNode{ - NodeType: types2.Removed, - Path: node.Path, - NodeValue: []byte{}, - LeafKey: leafKey, - }); err != nil { - return err - } + if err := output(types2.StorageNode{ + NodeType: types2.Removed, + Path: node.Path, + NodeValue: []byte{}, + LeafKey: leafKey, + }); err != nil { + return err } case types2.Extension, types2.Branch: if intermediateNodes { @@ -718,30 +711,12 @@ func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB } // isWatchedAddress is used to check if a state account corresponds to one of the addresses the builder is configured to watch -func isWatchedAddress(watchedAddresses []common.Address, stateLeafKey []byte) bool { +func isWatchedAddress(watchedAddressesLeafKeys map[common.Hash]struct{}, stateLeafKey []byte) bool { // If we aren't watching any specific addresses, we are watching everything - if len(watchedAddresses) == 0 { + if len(watchedAddressesLeafKeys) == 0 { return true } - for _, addr := range watchedAddresses { - addrHashKey := crypto.Keccak256(addr.Bytes()) - if bytes.Equal(addrHashKey, stateLeafKey) { - return true - } - } - return false -} -// isWatchedStorageKey is used to check if a storage leaf corresponds to one of the storage slots the builder is configured to watch -func isWatchedStorageKey(watchedKeys []common.Hash, storageLeafKey []byte) bool { - // If we aren't watching any specific addresses, we are watching everything - if len(watchedKeys) == 0 { - return true - } - for _, hashKey := range watchedKeys { - if bytes.Equal(hashKey.Bytes(), storageLeafKey) { - return true - } - } - return false + _, ok := watchedAddressesLeafKeys[common.BytesToHash(stateLeafKey)] + return ok } diff --git a/statediff/config.go b/statediff/config.go index f20f3267e..b4905ab5a 100644 --- a/statediff/config.go +++ b/statediff/config.go @@ -19,8 +19,10 @@ package statediff import ( "context" "math/big" + "sync" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/statediff/indexer/interfaces" ) @@ -53,7 +55,21 @@ type Params struct { IncludeTD bool IncludeCode bool WatchedAddresses []common.Address - WatchedStorageSlots []common.Hash + watchedAddressesLeafKeys map[common.Hash]struct{} +} + +// ComputeWatchedAddressesLeafKeys populates a map with keys (Keccak256Hash) of each of the WatchedAddresses +func (p *Params) ComputeWatchedAddressesLeafKeys() { + p.watchedAddressesLeafKeys = make(map[common.Hash]struct{}, len(p.WatchedAddresses)) + for _, address := range p.WatchedAddresses { + p.watchedAddressesLeafKeys[crypto.Keccak256Hash(address.Bytes())] = struct{}{} + } +} + +// ParamsWithMutex allows to lock the parameters while they are being updated | read from +type ParamsWithMutex struct { + Params + sync.RWMutex } // Args bundles the arguments for the state diff builder diff --git a/statediff/indexer/database/sql/indexer.go b/statediff/indexer/database/sql/indexer.go index 4f2b52434..8f39f1f34 100644 --- a/statediff/indexer/database/sql/indexer.go +++ b/statediff/indexer/database/sql/indexer.go @@ -555,3 +555,119 @@ func (sdi *StateDiffIndexer) Close() error { } // Update the known gaps table with the gap information. + +// LoadWatchedAddresses reads watched addresses from the database +func (sdi *StateDiffIndexer) LoadWatchedAddresses() ([]common.Address, error) { + + addressStrings := make([]string, 0) + pgStr := "SELECT address FROM eth_meta.watched_addresses" + err := sdi.dbWriter.db.Select(sdi.dbWriter.db.Context(), &addressStrings, pgStr) + if err != nil { + return nil, fmt.Errorf("error loading watched addresses: %v", err) + } + + watchedAddresses := []common.Address{} + for _, addressString := range addressStrings { + watchedAddresses = append(watchedAddresses, common.HexToAddress(addressString)) + } + + return watchedAddresses, nil +} + +// InsertWatchedAddresses inserts the given addresses in the database +func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { + tx, err := sdi.dbWriter.db.Begin(sdi.ctx) + if err != nil { + return err + } + defer func() { + if p := recover(); p != nil { + rollback(sdi.ctx, tx) + panic(p) + } else if err != nil { + rollback(sdi.ctx, tx) + } else { + err = tx.Commit(sdi.ctx) + } + }() + + for _, arg := range args { + _, err = tx.Exec(sdi.dbWriter.db.Context(), `INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, + arg.Address, arg.CreatedAt, currentBlockNumber.Uint64()) + if err != nil { + return fmt.Errorf("error inserting watched_addresses entry: %v", err) + } + } + + return err +} + +// RemoveWatchedAddresses removes the given watched addresses from the database +func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressArg) error { + tx, err := sdi.dbWriter.db.Begin(sdi.ctx) + if err != nil { + return err + } + defer func() { + if p := recover(); p != nil { + rollback(sdi.ctx, tx) + panic(p) + } else if err != nil { + rollback(sdi.ctx, tx) + } else { + err = tx.Commit(sdi.ctx) + } + }() + + for _, arg := range args { + _, err = tx.Exec(sdi.dbWriter.db.Context(), `DELETE FROM eth_meta.watched_addresses WHERE address = $1`, arg.Address) + if err != nil { + return fmt.Errorf("error removing watched_addresses entry: %v", err) + } + } + + return err +} + +// SetWatchedAddresses clears and inserts the given addresses in the database +func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { + tx, err := sdi.dbWriter.db.Begin(sdi.ctx) + if err != nil { + return err + } + defer func() { + if p := recover(); p != nil { + rollback(sdi.ctx, tx) + panic(p) + } else if err != nil { + rollback(sdi.ctx, tx) + } else { + err = tx.Commit(sdi.ctx) + } + }() + + _, err = tx.Exec(sdi.dbWriter.db.Context(), `DELETE FROM eth_meta.watched_addresses`) + if err != nil { + return fmt.Errorf("error setting watched_addresses table: %v", err) + } + + for _, arg := range args { + _, err = tx.Exec(sdi.dbWriter.db.Context(), `INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, + arg.Address, arg.CreatedAt, currentBlockNumber.Uint64()) + if err != nil { + return fmt.Errorf("error setting watched_addresses table: %v", err) + } + } + + return err +} + +// ClearWatchedAddresses clears all the watched addresses from the database +func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { + _, err := sdi.dbWriter.db.Exec(sdi.dbWriter.db.Context(), `DELETE FROM eth_meta.watched_addresses`) + if err != nil { + return fmt.Errorf("error clearing watched_addresses table: %v", err) + } + + return nil +} diff --git a/statediff/indexer/interfaces/interfaces.go b/statediff/indexer/interfaces/interfaces.go index 8f951230d..6910e3f49 100644 --- a/statediff/indexer/interfaces/interfaces.go +++ b/statediff/indexer/interfaces/interfaces.go @@ -21,6 +21,7 @@ import ( "math/big" "time" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/statediff/indexer/shared" sdtypes "github.com/ethereum/go-ethereum/statediff/types" @@ -32,6 +33,14 @@ type StateDiffIndexer interface { PushStateNode(tx Batch, stateNode sdtypes.StateNode, headerID string) error PushCodeAndCodeHash(tx Batch, codeAndCodeHash sdtypes.CodeAndCodeHash) error ReportDBMetrics(delay time.Duration, quit <-chan bool) + + // Methods used by WatchAddress API/functionality + LoadWatchedAddresses() ([]common.Address, error) + InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error + RemoveWatchedAddresses(addresses []sdtypes.WatchAddressArg) error + SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error + ClearWatchedAddresses() error + io.Closer } diff --git a/statediff/service.go b/statediff/service.go index 960f776f8..ce36c86d4 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -18,6 +18,7 @@ package statediff import ( "bytes" + "fmt" "math/big" "strconv" "strings" @@ -47,28 +48,34 @@ import ( "github.com/ethereum/go-ethereum/statediff/indexer/shared" types2 "github.com/ethereum/go-ethereum/statediff/types" "github.com/ethereum/go-ethereum/trie" + "github.com/thoas/go-funk" ) const ( - chainEventChanSize = 20000 - genesisBlockNumber = 0 - defaultRetryLimit = 3 // default retry limit once deadlock is detected. - deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html + chainEventChanSize = 20000 + genesisBlockNumber = 0 + defaultRetryLimit = 3 // default retry limit once deadlock is detected. + deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html + typeAssertionFailed = "type assertion failed" + unexpectedOperation = "unexpected operation" ) -var writeLoopParams = Params{ - IntermediateStateNodes: true, - IntermediateStorageNodes: true, - IncludeBlock: true, - IncludeReceipts: true, - IncludeTD: true, - IncludeCode: true, +var writeLoopParams = ParamsWithMutex{ + Params: Params{ + IntermediateStateNodes: true, + IntermediateStorageNodes: true, + IncludeBlock: true, + IncludeReceipts: true, + IncludeTD: true, + IncludeCode: true, + }, } var statediffMetrics = RegisterStatediffMetrics(metrics.DefaultRegistry) type blockChain interface { SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription + CurrentBlock() *types.Block GetBlockByHash(hash common.Hash) *types.Block GetBlockByNumber(number uint64) *types.Block GetReceiptsByHash(hash common.Hash) types.Receipts @@ -103,6 +110,8 @@ type IService interface { WriteStateDiffFor(blockHash common.Hash, params Params) error // WriteLoop event loop for progressively processing and writing diffs directly to DB WriteLoop(chainEventCh chan core.ChainEvent) + // Method to change the addresses being watched in write loop params + WatchAddress(operation types2.OperationType, args []types2.WatchAddressArg) error } // Service is the underlying struct for the state diffing service @@ -159,6 +168,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params blockChain := ethServ.BlockChain() var indexer interfaces.StateDiffIndexer var db sql.Database + var err error quitCh := make(chan bool) if params.IndexerConfig != nil { info := nodeinfo.Info{ @@ -215,6 +225,12 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params } stack.RegisterLifecycle(sds) stack.RegisterAPIs(sds.APIs()) + + err = loadWatchedAddresses(indexer) + if err != nil { + return err + } + return nil } @@ -304,7 +320,9 @@ func (sds *Service) WriteLoop(chainEventCh chan core.ChainEvent) { func (sds *Service) writeGenesisStateDiff(currBlock *types.Block, workerId uint) { // For genesis block we need to return the entire state trie hence we diff it with an empty trie. log.Info("Writing state diff", "block height", genesisBlockNumber, "worker", workerId) - err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams) + writeLoopParams.RLock() + err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams.Params) + writeLoopParams.RUnlock() if err != nil { log.Error("statediff.Service.WriteLoop: processing error", "block height", genesisBlockNumber, "error", err.Error(), "worker", workerId) @@ -341,7 +359,9 @@ func (sds *Service) writeLoopWorker(params workerParams) { } log.Info("Writing state diff", "block height", currentBlock.Number().Uint64(), "worker", params.id) - err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams) + writeLoopParams.RLock() + err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams.Params) + writeLoopParams.RUnlock() if err != nil { log.Error("statediff.Service.WriteLoop: processing error", "block height", currentBlock.Number().Uint64(), "error", err.Error(), "worker", params.id) sds.KnownGaps.errorState = true @@ -456,6 +476,10 @@ func (sds *Service) streamStateDiff(currentBlock *types.Block, parentRoot common func (sds *Service) StateDiffAt(blockNumber uint64, params Params) (*Payload, error) { currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) log.Info("sending state diff", "block height", blockNumber) + + // compute leaf keys of watched addresses in the params + params.ComputeWatchedAddressesLeafKeys() + if blockNumber == 0 { return sds.processStateDiff(currentBlock, common.Hash{}, params) } @@ -468,6 +492,10 @@ func (sds *Service) StateDiffAt(blockNumber uint64, params Params) (*Payload, er func (sds *Service) StateDiffFor(blockHash common.Hash, params Params) (*Payload, error) { currentBlock := sds.BlockChain.GetBlockByHash(blockHash) log.Info("sending state diff", "block hash", blockHash) + + // compute leaf keys of watched addresses in the params + params.ComputeWatchedAddressesLeafKeys() + if currentBlock.NumberU64() == 0 { return sds.processStateDiff(currentBlock, common.Hash{}, params) } @@ -526,6 +554,10 @@ func (sds *Service) newPayload(stateObject []byte, block *types.Block, params Pa func (sds *Service) StateTrieAt(blockNumber uint64, params Params) (*Payload, error) { currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) log.Info("sending state trie", "block height", blockNumber) + + // compute leaf keys of watched addresses in the params + params.ComputeWatchedAddressesLeafKeys() + return sds.processStateTrie(currentBlock, params) } @@ -548,6 +580,10 @@ func (sds *Service) Subscribe(id rpc.ID, sub chan<- Payload, quitChan chan<- boo if atomic.CompareAndSwapInt32(&sds.subscribers, 0, 1) { log.Info("State diffing subscription received; beginning statediff processing") } + + // compute leaf keys of watched addresses in the params + params.ComputeWatchedAddressesLeafKeys() + // Subscription type is defined as the hash of the rlp-serialized subscription params by, err := rlp.EncodeToBytes(params) if err != nil { @@ -644,7 +680,7 @@ func (sds *Service) Start() error { go sds.Loop(chainEventCh) if sds.enableWriteLoop { - log.Info("Starting statediff DB write loop", "params", writeLoopParams) + log.Info("Starting statediff DB write loop", "params", writeLoopParams.Params) chainEventCh := make(chan core.ChainEvent, chainEventChanSize) go sds.WriteLoop(chainEventCh) } @@ -741,6 +777,9 @@ func (sds *Service) StreamCodeAndCodeHash(blockNumber uint64, outChan chan<- typ // This operation cannot be performed back past the point of db pruning; it requires an archival node // for historical data func (sds *Service) WriteStateDiffAt(blockNumber uint64, params Params) error { + // compute leaf keys of watched addresses in the params + params.ComputeWatchedAddressesLeafKeys() + currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) parentRoot := common.Hash{} if blockNumber != 0 { @@ -754,6 +793,9 @@ func (sds *Service) WriteStateDiffAt(blockNumber uint64, params Params) error { // This operation cannot be performed back past the point of db pruning; it requires an archival node // for historical data func (sds *Service) WriteStateDiffFor(blockHash common.Hash, params Params) error { + // compute leaf keys of watched addresses in the params + params.ComputeWatchedAddressesLeafKeys() + currentBlock := sds.BlockChain.GetBlockByHash(blockHash) parentRoot := common.Hash{} if currentBlock.NumberU64() != 0 { @@ -821,3 +863,130 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo } return err } + +// Performs one of following operations on the watched addresses in writeLoopParams and the db: +// add | remove | set | clear +func (sds *Service) WatchAddress(operation types2.OperationType, args []types2.WatchAddressArg) error { + // lock writeLoopParams for a write + writeLoopParams.Lock() + defer writeLoopParams.Unlock() + + // get the current block number + currentBlockNumber := sds.BlockChain.CurrentBlock().Number() + + switch operation { + case types2.Add: + // filter out args having an already watched address with a warning + filteredArgs, ok := funk.Filter(args, func(arg types2.WatchAddressArg) bool { + if funk.Contains(writeLoopParams.WatchedAddresses, common.HexToAddress(arg.Address)) { + log.Warn("Address already being watched", "address", arg.Address) + return false + } + return true + }).([]types2.WatchAddressArg) + if !ok { + return fmt.Errorf("add: filtered args %s", typeAssertionFailed) + } + + // get addresses from the filtered args + filteredAddresses, err := mapWatchAddressArgsToAddresses(filteredArgs) + if err != nil { + return fmt.Errorf("add: filtered addresses %s", err.Error()) + } + + // update the db + err = sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) + if err != nil { + return err + } + + // update in-memory params + writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...) + funk.ForEach(filteredAddresses, func(address common.Address) { + writeLoopParams.watchedAddressesLeafKeys[crypto.Keccak256Hash(address.Bytes())] = struct{}{} + }) + case types2.Remove: + // get addresses from args + argAddresses, err := mapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("remove: mapped addresses %s", err.Error()) + } + + // remove the provided addresses from currently watched addresses + addresses, ok := funk.Subtract(writeLoopParams.WatchedAddresses, argAddresses).([]common.Address) + if !ok { + return fmt.Errorf("remove: filtered addresses %s", typeAssertionFailed) + } + + // update the db + err = sds.indexer.RemoveWatchedAddresses(args) + if err != nil { + return err + } + + // update in-memory params + writeLoopParams.WatchedAddresses = addresses + funk.ForEach(argAddresses, func(address common.Address) { + delete(writeLoopParams.watchedAddressesLeafKeys, crypto.Keccak256Hash(address.Bytes())) + }) + case types2.Set: + // get addresses from args + argAddresses, err := mapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("set: mapped addresses %s", err.Error()) + } + + // update the db + err = sds.indexer.SetWatchedAddresses(args, currentBlockNumber) + if err != nil { + return err + } + + // update in-memory params + writeLoopParams.WatchedAddresses = argAddresses + writeLoopParams.ComputeWatchedAddressesLeafKeys() + case types2.Clear: + // update the db + err := sds.indexer.ClearWatchedAddresses() + if err != nil { + return err + } + + // update in-memory params + writeLoopParams.WatchedAddresses = []common.Address{} + writeLoopParams.ComputeWatchedAddressesLeafKeys() + + default: + return fmt.Errorf("%s %s", unexpectedOperation, operation) + } + + return nil +} + +// loadWatchedAddresses loads watched addresses to in-memory write loop params +func loadWatchedAddresses(indexer interfaces.StateDiffIndexer) error { + watchedAddresses, err := indexer.LoadWatchedAddresses() + if err != nil { + return err + } + + writeLoopParams.Lock() + defer writeLoopParams.Unlock() + + writeLoopParams.WatchedAddresses = watchedAddresses + writeLoopParams.ComputeWatchedAddressesLeafKeys() + + return nil +} + +// mapWatchAddressArgsToAddresses maps []WatchAddressArg to corresponding []common.Address +func mapWatchAddressArgsToAddresses(args []types2.WatchAddressArg) ([]common.Address, error) { + addresses, ok := funk.Map(args, func(arg types2.WatchAddressArg) common.Address { + return common.HexToAddress(arg.Address) + }).([]common.Address) + if !ok { + return nil, fmt.Errorf(typeAssertionFailed) + } + + return addresses, nil +} diff --git a/statediff/types/types.go b/statediff/types/types.go index 36008a784..0a29adaf8 100644 --- a/statediff/types/types.go +++ b/statediff/types/types.go @@ -101,3 +101,20 @@ type CodeAndCodeHash struct { type StateNodeSink func(StateNode) error type StorageNodeSink func(StorageNode) error type CodeSink func(CodeAndCodeHash) error + +// OperationType for type of WatchAddress operation +type OperationType string + +const ( + Add OperationType = "add" + Remove OperationType = "remove" + Set OperationType = "set" + Clear OperationType = "clear" +) + +// WatchAddressArg is a arg type for WatchAddress API +type WatchAddressArg struct { + // Address represents common.Address + Address string + CreatedAt uint64 +}