diff --git a/statediff/builder.go b/statediff/builder.go index aee8f71ff..46546c1d5 100644 --- a/statediff/builder.go +++ b/statediff/builder.go @@ -202,7 +202,7 @@ func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args StateRoots, pa // a map of their leafkey to all the accounts that were touched and exist at A diffAccountsAtA, err := sdb.deletedOrUpdatedState( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - diffPathsAtB, params.WatchedAddresses, output) + diffPathsAtB, params.watchedAddressesLeafKeys, output) if err != nil { return fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err) } @@ -247,7 +247,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args StateRoots, // and a slice of all the paths for the nodes in both of the above sets diffAccountsAtB, diffPathsAtB, err := sdb.createdAndUpdatedState( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - params.WatchedAddresses) + params.watchedAddressesLeafKeys) if err != nil { return fmt.Errorf("error collecting createdAndUpdatedNodes: %v", err) } @@ -256,7 +256,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args StateRoots, // a map of their leafkey to all the accounts that were touched and exist at A diffAccountsAtA, err := sdb.deletedOrUpdatedState( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - diffPathsAtB, params.WatchedAddresses, output) + diffPathsAtB, params.watchedAddressesLeafKeys, output) if err != nil { return fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err) } @@ -289,7 +289,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args StateRoots, // createdAndUpdatedState returns // a mapping of their leafkeys to all the accounts that exist in a different state at B than A // and a slice of the paths for all of the nodes included in both -func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddresses []common.Address) (AccountMap, map[string]bool, error) { +func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddressesLeafKeys map[common.Hash]struct{}) (AccountMap, map[string]bool, error) { diffPathsAtB := make(map[string]bool) diffAcountsAtB := make(AccountMap) it, _ := trie.NewDifferenceIterator(a, b) @@ -313,7 +313,7 @@ func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddres valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedAddress(watchedAddresses, leafKey) { + if isWatchedAddress(watchedAddressesLeafKeys, leafKey) { diffAcountsAtB[common.Bytes2Hex(leafKey)] = accountWrapper{ NodeType: node.NodeType, Path: node.Path, @@ -386,7 +386,7 @@ func (sdb *builder) createdAndUpdatedStateWithIntermediateNodes(a, b trie.NodeIt // deletedOrUpdatedState returns a slice of all the pathes that are emptied at B // and a mapping of their leafkeys to all the accounts that exist in a different state at A than B -func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB map[string]bool, watchedAddresses []common.Address, output StateNodeSink) (AccountMap, error) { +func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB map[string]bool, watchedAddressesLeafKeys map[common.Hash]struct{}, output StateNodeSink) (AccountMap, error) { diffAccountAtA := make(AccountMap) it, _ := trie.NewDifferenceIterator(b, a) for it.Next(true) { @@ -409,7 +409,7 @@ func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB m valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedAddress(watchedAddresses, leafKey) { + if isWatchedAddress(watchedAddressesLeafKeys, leafKey) { diffAccountAtA[common.Bytes2Hex(leafKey)] = accountWrapper{ NodeType: node.NodeType, Path: node.Path, @@ -713,16 +713,12 @@ func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB } // isWatchedAddress is used to check if a state account corresponds to one of the addresses the builder is configured to watch -func isWatchedAddress(watchedAddresses []common.Address, stateLeafKey []byte) bool { +func isWatchedAddress(watchedAddressesLeafKeys map[common.Hash]struct{}, stateLeafKey []byte) bool { // If we aren't watching any specific addresses, we are watching everything - if len(watchedAddresses) == 0 { + if len(watchedAddressesLeafKeys) == 0 { return true } - for _, addr := range watchedAddresses { - addrHashKey := crypto.Keccak256(addr.Bytes()) - if bytes.Equal(addrHashKey, stateLeafKey) { - return true - } - } - return false + + _, ok := watchedAddressesLeafKeys[common.BytesToHash(stateLeafKey)] + return ok } diff --git a/statediff/builder_test.go b/statediff/builder_test.go index 741605d41..945e3799d 100644 --- a/statediff/builder_test.go +++ b/statediff/builder_test.go @@ -987,6 +987,7 @@ func TestBuilderWithWatchedAddressList(t *testing.T) { params := statediff.Params{ WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.ContractAddr}, } + params.ComputeWatchedAddressesLeafKeys() builder = statediff.NewBuilder(chain.StateCache()) var tests = []struct { @@ -1566,6 +1567,7 @@ func TestBuilderWithRemovedNonWatchedAccount(t *testing.T) { params := statediff.Params{ WatchedAddresses: []common.Address{testhelpers.Account1Addr, testhelpers.Account2Addr}, } + params.ComputeWatchedAddressesLeafKeys() builder = statediff.NewBuilder(chain.StateCache()) var tests = []struct { diff --git a/statediff/helpers.go b/statediff/helpers.go index 73ac084e8..8870855bd 100644 --- a/statediff/helpers.go +++ b/statediff/helpers.go @@ -24,8 +24,11 @@ import ( "sort" "strings" + "github.com/thoas/go-funk" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/statediff/indexer/postgres" + . "github.com/ethereum/go-ethereum/statediff/types" ) func sortKeys(data AccountMap) []string { @@ -96,6 +99,19 @@ func loadWatchedAddresses(db *postgres.DB) error { writeLoopParams.Lock() defer writeLoopParams.Unlock() writeLoopParams.WatchedAddresses = watchedAddresses + writeLoopParams.ComputeWatchedAddressesLeafKeys() return nil } + +// MapWatchAddressArgsToAddresses maps []WatchAddressArg to corresponding []common.Address +func MapWatchAddressArgsToAddresses(args []WatchAddressArg) ([]common.Address, error) { + addresses, ok := funk.Map(args, func(arg WatchAddressArg) common.Address { + return common.HexToAddress(arg.Address) + }).([]common.Address) + if !ok { + return nil, fmt.Errorf(typeAssertionFailed) + } + + return addresses, nil +} diff --git a/statediff/service.go b/statediff/service.go index ac6739ed6..998a8e85a 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -423,6 +423,9 @@ func (sds *Service) streamStateDiff(currentBlock *types.Block, parentRoot common func (sds *Service) StateDiffAt(blockNumber uint64, params Params) (*Payload, error) { currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) log.Info("sending state diff", "block height", blockNumber) + + params.ComputeWatchedAddressesLeafKeys() + if blockNumber == 0 { return sds.processStateDiff(currentBlock, common.Hash{}, params) } @@ -435,6 +438,9 @@ func (sds *Service) StateDiffAt(blockNumber uint64, params Params) (*Payload, er func (sds *Service) StateDiffFor(blockHash common.Hash, params Params) (*Payload, error) { currentBlock := sds.BlockChain.GetBlockByHash(blockHash) log.Info("sending state diff", "block hash", blockHash) + + params.ComputeWatchedAddressesLeafKeys() + if currentBlock.NumberU64() == 0 { return sds.processStateDiff(currentBlock, common.Hash{}, params) } @@ -493,6 +499,9 @@ func (sds *Service) newPayload(stateObject []byte, block *types.Block, params Pa func (sds *Service) StateTrieAt(blockNumber uint64, params Params) (*Payload, error) { currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) log.Info("sending state trie", "block height", blockNumber) + + params.ComputeWatchedAddressesLeafKeys() + return sds.processStateTrie(currentBlock, params) } @@ -515,6 +524,9 @@ func (sds *Service) Subscribe(id rpc.ID, sub chan<- Payload, quitChan chan<- boo if atomic.CompareAndSwapInt32(&sds.subscribers, 0, 1) { log.Info("State diffing subscription received; beginning statediff processing") } + + params.ComputeWatchedAddressesLeafKeys() + // Subscription type is defined as the hash of the rlp-serialized subscription params by, err := rlp.EncodeToBytes(params) if err != nil { @@ -661,6 +673,8 @@ func (sds *Service) StreamCodeAndCodeHash(blockNumber uint64, outChan chan<- Cod // This operation cannot be performed back past the point of db pruning; it requires an archival node // for historical data func (sds *Service) WriteStateDiffAt(blockNumber uint64, params Params) error { + params.ComputeWatchedAddressesLeafKeys() + currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) parentRoot := common.Hash{} if blockNumber != 0 { @@ -674,6 +688,8 @@ func (sds *Service) WriteStateDiffAt(blockNumber uint64, params Params) error { // This operation cannot be performed back past the point of db pruning; it requires an archival node // for historical data func (sds *Service) WriteStateDiffFor(blockHash common.Hash, params Params) error { + params.ComputeWatchedAddressesLeafKeys() + currentBlock := sds.BlockChain.GetBlockByHash(blockHash) parentRoot := common.Hash{} if currentBlock.NumberU64() != 0 { @@ -736,7 +752,7 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo } // Performs one of following operations on the watched addresses in writeLoopParams and the db: -// Add | Remove | Set | Clear +// add | remove | set | clear func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg) error { // lock writeLoopParams for a write writeLoopParams.Lock() @@ -756,65 +772,66 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg return true }).([]WatchAddressArg) if !ok { - return fmt.Errorf("Add: filtered args %s", typeAssertionFailed) + return fmt.Errorf("add: filtered args %s", typeAssertionFailed) } // get addresses from the filtered args - filteredAddresses, ok := funk.Map(filteredArgs, func(arg WatchAddressArg) common.Address { - return common.HexToAddress(arg.Address) - }).([]common.Address) - if !ok { - return fmt.Errorf("Add: filtered addresses %s", typeAssertionFailed) + filteredAddresses, err := MapWatchAddressArgsToAddresses(filteredArgs) + if err != nil { + return fmt.Errorf("add: filtered addresses %s", err.Error()) } // update the db - err := sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) + err = sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) if err != nil { return err } // update in-memory params writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...) + funk.ForEach(filteredAddresses, func(address common.Address) { + writeLoopParams.watchedAddressesLeafKeys[crypto.Keccak256Hash(address.Bytes())] = struct{}{} + }) case Remove: // get addresses from args - argAddresses, ok := funk.Map(args, func(arg WatchAddressArg) common.Address { - return common.HexToAddress(arg.Address) - }).([]common.Address) - if !ok { - return fmt.Errorf("Remove: mapped addresses %s", typeAssertionFailed) + argAddresses, err := MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("remove: mapped addresses %s", err.Error()) } // remove the provided addresses from currently watched addresses addresses, ok := funk.Subtract(writeLoopParams.WatchedAddresses, argAddresses).([]common.Address) if !ok { - return fmt.Errorf("Remove: filtered addresses %s", typeAssertionFailed) + return fmt.Errorf("remove: filtered addresses %s", typeAssertionFailed) } // update the db - err := sds.indexer.RemoveWatchedAddresses(args) + err = sds.indexer.RemoveWatchedAddresses(args) if err != nil { return err } // update in-memory params writeLoopParams.WatchedAddresses = addresses + funk.ForEach(argAddresses, func(address common.Address) { + delete(writeLoopParams.watchedAddressesLeafKeys, crypto.Keccak256Hash(address.Bytes())) + }) case Set: // get addresses from args - argAddresses, ok := funk.Map(args, func(arg WatchAddressArg) common.Address { - return common.HexToAddress(arg.Address) - }).([]common.Address) - if !ok { - return fmt.Errorf("Set: mapped addresses %s", typeAssertionFailed) + argAddresses, err := MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("set: mapped addresses %s", err.Error()) } // update the db - err := sds.indexer.SetWatchedAddresses(args, currentBlockNumber) + err = sds.indexer.SetWatchedAddresses(args, currentBlockNumber) if err != nil { return err } // update in-memory params writeLoopParams.WatchedAddresses = argAddresses + writeLoopParams.ComputeWatchedAddressesLeafKeys() case Clear: // update the db err := sds.indexer.ClearWatchedAddresses() @@ -824,6 +841,7 @@ func (sds *Service) WatchAddress(operation OperationType, args []WatchAddressArg // update in-memory params writeLoopParams.WatchedAddresses = []common.Address{} + writeLoopParams.ComputeWatchedAddressesLeafKeys() default: return fmt.Errorf("%s %s", unexpectedOperation, operation) diff --git a/statediff/service_test.go b/statediff/service_test.go index ca9a483a5..a3a5ccca4 100644 --- a/statediff/service_test.go +++ b/statediff/service_test.go @@ -144,6 +144,7 @@ func testErrorInChainEventLoop(t *testing.T) { } } + defaultParams.ComputeWatchedAddressesLeafKeys() if !reflect.DeepEqual(builder.Params, defaultParams) { t.Error("Test failure:", t.Name()) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) @@ -195,6 +196,8 @@ func testErrorInBlockLoop(t *testing.T) { } }() service.Loop(eventsChannel) + + defaultParams.ComputeWatchedAddressesLeafKeys() if !reflect.DeepEqual(builder.Params, defaultParams) { t.Error("Test failure:", t.Name()) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) @@ -268,6 +271,8 @@ func testErrorInStateDiffAt(t *testing.T) { if err != nil { t.Error(err) } + + defaultParams.ComputeWatchedAddressesLeafKeys() if !reflect.DeepEqual(builder.Params, defaultParams) { t.Error("Test failure:", t.Name()) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) diff --git a/statediff/testhelpers/mocks/service.go b/statediff/testhelpers/mocks/service.go index 61513d75c..e14565774 100644 --- a/statediff/testhelpers/mocks/service.go +++ b/statediff/testhelpers/mocks/service.go @@ -362,65 +362,62 @@ func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, return true }).([]sdtypes.WatchAddressArg) if !ok { - return fmt.Errorf("Add: filtered args %s", typeAssertionFailed) + return fmt.Errorf("add: filtered args %s", typeAssertionFailed) } // get addresses from the filtered args - filteredAddresses, ok := funk.Map(filteredArgs, func(arg sdtypes.WatchAddressArg) common.Address { - return common.HexToAddress(arg.Address) - }).([]common.Address) - if !ok { - return fmt.Errorf("Add: filtered addresses %s", typeAssertionFailed) + filteredAddresses, err := statediff.MapWatchAddressArgsToAddresses(filteredArgs) + if err != nil { + return fmt.Errorf("add: filtered addresses %s", err.Error()) } // update the db - err := sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) + err = sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) if err != nil { return err } // update in-memory params sds.writeLoopParams.WatchedAddresses = append(sds.writeLoopParams.WatchedAddresses, filteredAddresses...) + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() case statediff.Remove: // get addresses from args - argAddresses, ok := funk.Map(args, func(arg sdtypes.WatchAddressArg) common.Address { - return common.HexToAddress(arg.Address) - }).([]common.Address) - if !ok { - return fmt.Errorf("Remove: mapped addresses %s", typeAssertionFailed) + argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("remove: mapped addresses %s", err.Error()) } // remove the provided addresses from currently watched addresses addresses, ok := funk.Subtract(sds.writeLoopParams.WatchedAddresses, argAddresses).([]common.Address) if !ok { - return fmt.Errorf("Remove: filtered addresses %s", typeAssertionFailed) + return fmt.Errorf("remove: filtered addresses %s", typeAssertionFailed) } // update the db - err := sds.Indexer.RemoveWatchedAddresses(args) + err = sds.Indexer.RemoveWatchedAddresses(args) if err != nil { return err } // update in-memory params sds.writeLoopParams.WatchedAddresses = addresses + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() case statediff.Set: // get addresses from args - argAddresses, ok := funk.Map(args, func(arg sdtypes.WatchAddressArg) common.Address { - return common.HexToAddress(arg.Address) - }).([]common.Address) - if !ok { - return fmt.Errorf("Set: mapped addresses %s", typeAssertionFailed) + argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("set: mapped addresses %s", err.Error()) } // update the db - err := sds.Indexer.SetWatchedAddresses(args, currentBlockNumber) + err = sds.Indexer.SetWatchedAddresses(args, currentBlockNumber) if err != nil { return err } // update in-memory params sds.writeLoopParams.WatchedAddresses = argAddresses + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() case statediff.Clear: // update the db err := sds.Indexer.ClearWatchedAddresses() @@ -430,6 +427,7 @@ func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, // update in-memory params sds.writeLoopParams.WatchedAddresses = []common.Address{} + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() default: return fmt.Errorf("%s %s", unexpectedOperation, operation) diff --git a/statediff/testhelpers/mocks/service_test.go b/statediff/testhelpers/mocks/service_test.go index f1be5e317..dbd059def 100644 --- a/statediff/testhelpers/mocks/service_test.go +++ b/statediff/testhelpers/mocks/service_test.go @@ -506,6 +506,7 @@ func testWatchAddressAPI(t *testing.T) { mockService.writeLoopParams = statediff.ParamsWithMutex{ Params: test.startingParams, } + mockService.writeLoopParams.ComputeWatchedAddressesLeafKeys() // make the API call to change watched addresses err := mockService.WatchAddress(test.operation, test.args) @@ -522,6 +523,7 @@ func testWatchAddressAPI(t *testing.T) { } // check updated indexing params + test.expectedParams.ComputeWatchedAddressesLeafKeys() updatedParams := mockService.writeLoopParams.Params if !reflect.DeepEqual(updatedParams, test.expectedParams) { t.Logf("Test failed: %s", test.name) diff --git a/statediff/types.go b/statediff/types.go index 9e90ecabc..e33193637 100644 --- a/statediff/types.go +++ b/statediff/types.go @@ -26,6 +26,7 @@ import ( "github.com/ethereum/go-ethereum/common" ctypes "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/statediff/types" ) @@ -51,6 +52,15 @@ type Params struct { IncludeTD bool IncludeCode bool WatchedAddresses []common.Address + watchedAddressesLeafKeys map[common.Hash]struct{} +} + +// ComputeWatchedAddressesLeafKeys populates a map with keys (Keccak256Hash) of each of the WatchedAddresses +func (p *Params) ComputeWatchedAddressesLeafKeys() { + p.watchedAddressesLeafKeys = make(map[common.Hash]struct{}, len(p.WatchedAddresses)) + for _, address := range p.WatchedAddresses { + p.watchedAddressesLeafKeys[crypto.Keccak256Hash(address.Bytes())] = struct{}{} + } } // ParamsWithMutex allows to lock the parameters while they are being updated | read from @@ -122,8 +132,8 @@ type accountWrapper struct { type OperationType string const ( - Add OperationType = "Add" - Remove OperationType = "Remove" - Set OperationType = "Set" - Clear OperationType = "Clear" + Add OperationType = "add" + Remove OperationType = "remove" + Set OperationType = "set" + Clear OperationType = "clear" )