Merge pull request #212 from deep-stack/pm-watched-addresses-v3

Statediff API (v3) to change addresses watched in direct indexing mode
Authored by Ashwin Phatak on 2022-04-05 09:33:51 +05:30, committed by GitHub
commit 2aaf6bcda3
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
30 changed files with 2519 additions and 308 deletions


@@ -195,6 +195,7 @@ func makeFullNode(ctx *cli.Context) (*node.Node, ethapi.Backend) {
case shared.FILE:
indexerConfig = file.Config{
FilePath: ctx.GlobalString(utils.StateDiffFilePath.Name),
WatchedAddressesFilePath: ctx.GlobalString(utils.StateDiffWatchedAddressesFilePath.Name),
}
case shared.POSTGRES:
driverTypeStr := ctx.GlobalString(utils.StateDiffDBDriverTypeFlag.Name)


@@ -178,6 +178,7 @@ var (
utils.StateDiffFilePath,
utils.StateDiffKnownGapsFilePath,
utils.StateDiffWaitForSync,
utils.StateDiffWatchedAddressesFilePath,
configFileFlag,
}


@@ -248,6 +248,7 @@ var AppHelpFlagGroups = []flags.FlagGroup{
utils.StateDiffFilePath,
utils.StateDiffKnownGapsFilePath,
utils.StateDiffWaitForSync,
utils.StateDiffWatchedAddressesFilePath,
},
},
{


@@ -872,6 +872,10 @@ var (
Usage: "Full path (including filename) to write knownGaps statements when the DB is unavailable.",
Value: "./known_gaps.sql",
}
StateDiffWatchedAddressesFilePath = cli.StringFlag{
Name: "statediff.file.wapath",
Usage: "Full path (including filename) to write statediff watched addresses out to when operating in file mode",
}
StateDiffDBClientNameFlag = cli.StringFlag{
Name: "statediff.db.clientname",
Usage: "Client name to use when writing state diffs to database",

go.mod

@@ -63,6 +63,7 @@ require (
github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416
github.com/olekukonko/tablewriter v0.0.5
github.com/peterh/liner v1.1.1-0.20190123174540-a2c9a5303de7
github.com/pganalyze/pg_query_go/v2 v2.1.0
github.com/prometheus/tsdb v0.7.1
github.com/rjeczalik/notify v0.9.1
github.com/rs/cors v1.7.0
@@ -70,6 +71,7 @@ require (
github.com/status-im/keycard-go v0.0.0-20190316090335-8537d3370df4
github.com/stretchr/testify v1.7.0
github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7
github.com/thoas/go-funk v0.9.2
github.com/tklauser/go-sysconf v0.3.5 // indirect
github.com/tyler-smith/go-bip39 v1.0.1-0.20181017060643-dbb3b84ba2ef
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97

go.sum

@@ -222,6 +222,7 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.4 h1:L8R9j+yAqZuZjsqh/z+F1NCffTKKLShY6zXTItVIZ8M=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.1.1-0.20200604201612-c04b05f3adfa h1:Q75Upo5UN4JbPFURXZ8nLKYUvF85dyFRop/vQ0Rv+64=
@@ -516,6 +517,8 @@ github.com/paulbellamy/ratecounter v0.2.0/go.mod h1:Hfx1hDpSGoqxkVVpBi/IlYD7kChl
github.com/peterh/liner v1.0.1-0.20180619022028-8c1271fcf47f/go.mod h1:xIteQHvHuaLYG9IFj6mSxM0fCKrs34IrEQUhOYuGPHc=
github.com/peterh/liner v1.1.1-0.20190123174540-a2c9a5303de7 h1:oYW+YCJ1pachXTQmzR3rNLYGGz4g/UgFcjb28p/viDM=
github.com/peterh/liner v1.1.1-0.20190123174540-a2c9a5303de7/go.mod h1:CRroGNssyjTd/qIG2FyxByd2S8JEAZXBl4qUrZf8GS0=
github.com/pganalyze/pg_query_go/v2 v2.1.0 h1:donwPZ4G/X+kMs7j5eYtKjdziqyOLVp3pkUrzb9lDl8=
github.com/pganalyze/pg_query_go/v2 v2.1.0/go.mod h1:XAxmVqz1tEGqizcQ3YSdN90vCOHBWjJi8URL1er5+cA=
github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU=
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -585,6 +588,8 @@ github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5Cc
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70Z7CTTCmYQn2CKbY8j86K7/FAIr141uY=
github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc=
github.com/thoas/go-funk v0.9.2 h1:oKlNYv0AY5nyf9g+/GhMgS/UO2ces0QRdPKwkhY3VCk=
github.com/thoas/go-funk v0.9.2/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q=
github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
github.com/tklauser/go-sysconf v0.3.5 h1:uu3Xl4nkLzQfXNsWn15rPc/HQCJKObbt1dKJeWp3vU4=
github.com/tklauser/go-sysconf v0.3.5/go.mod h1:MkWzOF4RMCshBAMXuhXJs64Rte09mITnppBXY/rYEFI=


@@ -120,6 +120,8 @@ This service introduces a CLI flag namespace `statediff`
`--statediff.file.path` full path (including filename) to write statediff data out to when operating in file mode
`--statediff.file.wapath` full path (including filename) to write statediff watched addresses out to when operating in file mode
The service can only operate in full sync mode (`--syncmode=full`), but only the historical RPC endpoints require an archive node (`--gcmode=archive`)
e.g.
@@ -148,15 +150,13 @@ type Params struct {
IncludeTD bool
IncludeCode bool
WatchedAddresses []common.Address
WatchedStorageSlots []common.Hash
}
```
Using these params we can tell the service whether to include state and/or storage intermediate nodes; whether
to include the associated block (header, uncles, and transactions); whether to include the associated receipts;
whether to include the total difficulty for this block; whether to include the set of code hashes and code for
contracts deployed in this block; whether to limit the diffing process to a list of specific addresses; and/or
whether to limit the diffing process to a list of specific storage slot keys.
contracts deployed in this block; whether to limit the diffing process to a list of specific addresses.
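With storage-slot filtering removed, diffs can now only be restricted by account address. A minimal sketch of filling out the new Params; only fields visible in this diff are set, and the address value is illustrative:

```go
params := statediff.Params{
	IncludeTD:        true,
	IncludeCode:      true,
	WatchedAddresses: []common.Address{common.HexToAddress("0x5d663F5269090bD2A7DC2390c911dF6083D7b28F")},
}
// Derive the unexported watchedAddressesLeafKeys map that the builder consumes.
params.ComputeWatchedAddressesLeafKeys()
```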
#### Subscription endpoint


@@ -149,3 +149,8 @@ func (api *PublicStateDiffAPI) WriteStateDiffAt(ctx context.Context, blockNumber
func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash common.Hash, params Params) error {
return api.sds.WriteStateDiffFor(blockHash, params)
}
// WatchAddress changes the list of watched addresses to which the direct indexing is restricted according to given operation
func (api *PublicStateDiffAPI) WatchAddress(operation types.OperationType, args []types.WatchAddressArg) error {
return api.sds.WatchAddress(operation, args)
}
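A client would reach this method through the statediff RPC namespace. A hedged sketch of such a call: the method name `statediff_watchAddress` and the `"add"` operation value are assumptions based on geth's usual RPC naming, while the WatchAddressArg fields (Address, CreatedAt) come from this PR:

```go
package main

import (
	"log"

	"github.com/ethereum/go-ethereum/rpc"
	sdtypes "github.com/ethereum/go-ethereum/statediff/types"
)

func main() {
	// Assumed local node endpoint with the statediff API exposed over websockets.
	client, err := rpc.Dial("ws://127.0.0.1:8546")
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	args := []sdtypes.WatchAddressArg{
		{Address: "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F", CreatedAt: 1},
	}
	// Add the address to the watched set; other operations would remove, set, or clear it.
	if err := client.Call(nil, "statediff_watchAddress", "add", args); err != nil {
		log.Fatal(err)
	}
}
```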


@@ -123,7 +123,7 @@ func (sdb *builder) buildStateTrie(it trie.NodeIterator) ([]types2.StateNode, []
node.LeafKey = leafKey
if !bytes.Equal(account.CodeHash, nullCodeHash) {
var storageNodes []types2.StorageNode
err := sdb.buildStorageNodesEventual(account.Root, nil, true, storageNodeAppender(&storageNodes))
err := sdb.buildStorageNodesEventual(account.Root, true, storageNodeAppender(&storageNodes))
if err != nil {
return nil, nil, fmt.Errorf("failed building eventual storage diffs for account %+v\r\nerror: %v", account, err)
}
@@ -202,7 +202,7 @@ func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args types2.StateRo
// a map of their leafkey to all the accounts that were touched and exist at A
diffAccountsAtA, err := sdb.deletedOrUpdatedState(
oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}),
diffPathsAtB, output)
diffPathsAtB, params.watchedAddressesLeafKeys, output)
if err != nil {
return fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err)
}
@@ -220,12 +220,12 @@ func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args types2.StateRo
// build the diff nodes for the updated accounts using the mappings at both A and B as directed by the keys found as the intersection of the two
err = sdb.buildAccountUpdates(
diffAccountsAtB, diffAccountsAtA, updatedKeys,
params.WatchedStorageSlots, params.IntermediateStorageNodes, output)
params.IntermediateStorageNodes, output)
if err != nil {
return fmt.Errorf("error building diff for updated accounts: %v", err)
}
// build the diff nodes for created accounts
err = sdb.buildAccountCreations(diffAccountsAtB, params.WatchedStorageSlots, params.IntermediateStorageNodes, output, codeOutput)
err = sdb.buildAccountCreations(diffAccountsAtB, params.IntermediateStorageNodes, output, codeOutput)
if err != nil {
return fmt.Errorf("error building diff for created accounts: %v", err)
}
@@ -247,7 +247,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args types2.Stat
// and a slice of all the paths for the nodes in both of the above sets
diffAccountsAtB, diffPathsAtB, err := sdb.createdAndUpdatedState(
oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}),
params.WatchedAddresses)
params.watchedAddressesLeafKeys)
if err != nil {
return fmt.Errorf("error collecting createdAndUpdatedNodes: %v", err)
}
@@ -256,7 +256,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args types2.Stat
// a map of their leafkey to all the accounts that were touched and exist at A
diffAccountsAtA, err := sdb.deletedOrUpdatedState(
oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}),
diffPathsAtB, output)
diffPathsAtB, params.watchedAddressesLeafKeys, output)
if err != nil {
return fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err)
}
@@ -274,12 +274,12 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args types2.Stat
// build the diff nodes for the updated accounts using the mappings at both A and B as directed by the keys found as the intersection of the two
err = sdb.buildAccountUpdates(
diffAccountsAtB, diffAccountsAtA, updatedKeys,
params.WatchedStorageSlots, params.IntermediateStorageNodes, output)
params.IntermediateStorageNodes, output)
if err != nil {
return fmt.Errorf("error building diff for updated accounts: %v", err)
}
// build the diff nodes for created accounts
err = sdb.buildAccountCreations(diffAccountsAtB, params.WatchedStorageSlots, params.IntermediateStorageNodes, output, codeOutput)
err = sdb.buildAccountCreations(diffAccountsAtB, params.IntermediateStorageNodes, output, codeOutput)
if err != nil {
return fmt.Errorf("error building diff for created accounts: %v", err)
}
@@ -289,7 +289,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args types2.Stat
// createdAndUpdatedState returns
// a mapping of their leafkeys to all the accounts that exist in a different state at B than A
// and a slice of the paths for all of the nodes included in both
func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddresses []common.Address) (types2.AccountMap, map[string]bool, error) {
func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddressesLeafKeys map[common.Hash]struct{}) (types2.AccountMap, map[string]bool, error) {
diffPathsAtB := make(map[string]bool)
diffAcountsAtB := make(types2.AccountMap)
it, _ := trie.NewDifferenceIterator(a, b)
@@ -313,7 +313,7 @@ func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddres
valueNodePath := append(node.Path, partialPath...)
encodedPath := trie.HexToCompact(valueNodePath)
leafKey := encodedPath[1:]
if isWatchedAddress(watchedAddresses, leafKey) {
if isWatchedAddress(watchedAddressesLeafKeys, leafKey) {
diffAcountsAtB[common.Bytes2Hex(leafKey)] = types2.AccountWrapper{
NodeType: node.NodeType,
Path: node.Path,
@@ -386,7 +386,7 @@ func (sdb *builder) createdAndUpdatedStateWithIntermediateNodes(a, b trie.NodeIt
// deletedOrUpdatedState returns a slice of all the pathes that are emptied at B
// and a mapping of their leafkeys to all the accounts that exist in a different state at A than B
func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB map[string]bool, output types2.StateNodeSink) (types2.AccountMap, error) {
func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB map[string]bool, watchedAddressesLeafKeys map[common.Hash]struct{}, output types2.StateNodeSink) (types2.AccountMap, error) {
diffAccountAtA := make(types2.AccountMap)
it, _ := trie.NewDifferenceIterator(b, a)
for it.Next(true) {
@@ -409,6 +409,7 @@ func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB m
valueNodePath := append(node.Path, partialPath...)
encodedPath := trie.HexToCompact(valueNodePath)
leafKey := encodedPath[1:]
if isWatchedAddress(watchedAddressesLeafKeys, leafKey) {
diffAccountAtA[common.Bytes2Hex(leafKey)] = types2.AccountWrapper{
NodeType: node.NodeType,
Path: node.Path,
@@ -429,6 +430,7 @@ func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB m
return nil, err
}
}
}
case types2.Extension, types2.Branch:
// if this node's path did not show up in diffPathsAtB
// that means the node at this path was deleted (or moved) in B
@@ -454,8 +456,7 @@ func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB m
// to generate the statediff node objects for all of the accounts that existed at both A and B but in different states
// needs to be called before building account creations and deletions as this mutates
// those account maps to remove the accounts which were updated
func (sdb *builder) buildAccountUpdates(creations, deletions types2.AccountMap, updatedKeys []string,
watchedStorageKeys []common.Hash, intermediateStorageNodes bool, output types2.StateNodeSink) error {
func (sdb *builder) buildAccountUpdates(creations, deletions types2.AccountMap, updatedKeys []string, intermediateStorageNodes bool, output types2.StateNodeSink) error {
var err error
for _, key := range updatedKeys {
createdAcc := creations[key]
@@ -465,7 +466,7 @@ func (sdb *builder) buildAccountUpdates(creations, deletions types2.AccountMap,
oldSR := deletedAcc.Account.Root
newSR := createdAcc.Account.Root
err = sdb.buildStorageNodesIncremental(
oldSR, newSR, watchedStorageKeys, intermediateStorageNodes,
oldSR, newSR, intermediateStorageNodes,
storageNodeAppender(&storageDiffs))
if err != nil {
return fmt.Errorf("failed building incremental storage diffs for account with leafkey %s\r\nerror: %v", key, err)
@@ -489,7 +490,7 @@ func (sdb *builder) buildAccountUpdates(creations, deletions types2.AccountMap,
// buildAccountCreations returns the statediff node objects for all the accounts that exist at B but not at A
// it also returns the code and codehash for created contract accounts
func (sdb *builder) buildAccountCreations(accounts types2.AccountMap, watchedStorageKeys []common.Hash, intermediateStorageNodes bool, output types2.StateNodeSink, codeOutput types2.CodeSink) error {
func (sdb *builder) buildAccountCreations(accounts types2.AccountMap, intermediateStorageNodes bool, output types2.StateNodeSink, codeOutput types2.CodeSink) error {
for _, val := range accounts {
diff := types2.StateNode{
NodeType: val.NodeType,
@@ -500,7 +501,7 @@ func (sdb *builder) buildAccountCreations(accounts types2.AccountMap, watchedSto
if !bytes.Equal(val.Account.CodeHash, nullCodeHash) {
// For contract creations, any storage node contained is a diff
var storageDiffs []types2.StorageNode
err := sdb.buildStorageNodesEventual(val.Account.Root, watchedStorageKeys, intermediateStorageNodes, storageNodeAppender(&storageDiffs))
err := sdb.buildStorageNodesEventual(val.Account.Root, intermediateStorageNodes, storageNodeAppender(&storageDiffs))
if err != nil {
return fmt.Errorf("failed building eventual storage diffs for node %x\r\nerror: %v", val.Path, err)
}
@@ -528,7 +529,7 @@ func (sdb *builder) buildAccountCreations(accounts types2.AccountMap, watchedSto
// buildStorageNodesEventual builds the storage diff node objects for a created account
// i.e. it returns all the storage nodes at this state, since there is no previous state
func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys []common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error {
func (sdb *builder) buildStorageNodesEventual(sr common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error {
if bytes.Equal(sr.Bytes(), emptyContractRoot.Bytes()) {
return nil
}
@@ -539,7 +540,7 @@ func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys
return err
}
it := sTrie.NodeIterator(make([]byte, 0))
err = sdb.buildStorageNodesFromTrie(it, watchedStorageKeys, intermediateNodes, output)
err = sdb.buildStorageNodesFromTrie(it, intermediateNodes, output)
if err != nil {
return err
}
@@ -549,7 +550,7 @@ func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys
// buildStorageNodesFromTrie returns all the storage diff node objects in the provided node interator
// if any storage keys are provided it will only return those leaf nodes
// including intermediate nodes can be turned on or off
func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStorageKeys []common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error {
func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, intermediateNodes bool, output types2.StorageNodeSink) error {
for it.Next(true) {
// skip value nodes
if it.Leaf() || bytes.Equal(nullHashBytes, it.Hash().Bytes()) {
@@ -565,7 +566,6 @@ func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStora
valueNodePath := append(node.Path, partialPath...)
encodedPath := trie.HexToCompact(valueNodePath)
leafKey := encodedPath[1:]
if isWatchedStorageKey(watchedStorageKeys, leafKey) {
if err := output(types2.StorageNode{
NodeType: node.NodeType,
Path: node.Path,
@@ -574,7 +574,6 @@ func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStora
}); err != nil {
return err
}
}
case types2.Extension, types2.Branch:
if intermediateNodes {
if err := output(types2.StorageNode{
@@ -593,7 +592,7 @@ func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStora
}
// buildStorageNodesIncremental builds the storage diff node objects for all nodes that exist in a different state at B than A
func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common.Hash, watchedStorageKeys []common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error {
func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error {
if bytes.Equal(newSR.Bytes(), oldSR.Bytes()) {
return nil
}
@@ -609,19 +608,19 @@ func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common
diffPathsAtB, err := sdb.createdAndUpdatedStorage(
oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}),
watchedStorageKeys, intermediateNodes, output)
intermediateNodes, output)
if err != nil {
return err
}
err = sdb.deletedOrUpdatedStorage(oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}),
diffPathsAtB, watchedStorageKeys, intermediateNodes, output)
diffPathsAtB, intermediateNodes, output)
if err != nil {
return err
}
return nil
}
func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, watchedKeys []common.Hash, intermediateNodes bool, output types2.StorageNodeSink) (map[string]bool, error) {
func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, intermediateNodes bool, output types2.StorageNodeSink) (map[string]bool, error) {
diffPathsAtB := make(map[string]bool)
it, _ := trie.NewDifferenceIterator(a, b)
for it.Next(true) {
@@ -639,7 +638,6 @@ func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, watchedKeys
valueNodePath := append(node.Path, partialPath...)
encodedPath := trie.HexToCompact(valueNodePath)
leafKey := encodedPath[1:]
if isWatchedStorageKey(watchedKeys, leafKey) {
if err := output(types2.StorageNode{
NodeType: node.NodeType,
Path: node.Path,
@@ -648,7 +646,6 @@ func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, watchedKeys
}); err != nil {
return nil, err
}
}
case types2.Extension, types2.Branch:
if intermediateNodes {
if err := output(types2.StorageNode{
@@ -667,7 +664,7 @@ func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, watchedKeys
return diffPathsAtB, it.Error()
}
func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB map[string]bool, watchedKeys []common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error {
func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB map[string]bool, intermediateNodes bool, output types2.StorageNodeSink) error {
it, _ := trie.NewDifferenceIterator(b, a)
for it.Next(true) {
// skip value nodes
@@ -690,7 +687,6 @@ func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB
valueNodePath := append(node.Path, partialPath...)
encodedPath := trie.HexToCompact(valueNodePath)
leafKey := encodedPath[1:]
if isWatchedStorageKey(watchedKeys, leafKey) {
if err := output(types2.StorageNode{
NodeType: types2.Removed,
Path: node.Path,
@@ -699,7 +695,6 @@ func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB
}); err != nil {
return err
}
}
case types2.Extension, types2.Branch:
if intermediateNodes {
if err := output(types2.StorageNode{
@@ -718,30 +713,12 @@ func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB
}
// isWatchedAddress is used to check if a state account corresponds to one of the addresses the builder is configured to watch
func isWatchedAddress(watchedAddresses []common.Address, stateLeafKey []byte) bool {
func isWatchedAddress(watchedAddressesLeafKeys map[common.Hash]struct{}, stateLeafKey []byte) bool {
// If we aren't watching any specific addresses, we are watching everything
if len(watchedAddresses) == 0 {
if len(watchedAddressesLeafKeys) == 0 {
return true
}
for _, addr := range watchedAddresses {
addrHashKey := crypto.Keccak256(addr.Bytes())
if bytes.Equal(addrHashKey, stateLeafKey) {
return true
}
}
return false
}
// isWatchedStorageKey is used to check if a storage leaf corresponds to one of the storage slots the builder is configured to watch
func isWatchedStorageKey(watchedKeys []common.Hash, storageLeafKey []byte) bool {
// If we aren't watching any specific addresses, we are watching everything
if len(watchedKeys) == 0 {
return true
}
for _, hashKey := range watchedKeys {
if bytes.Equal(hashKey.Bytes(), storageLeafKey) {
return true
}
}
return false
_, ok := watchedAddressesLeafKeys[common.BytesToHash(stateLeafKey)]
return ok
}
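The change above replaces a per-leaf linear scan (hash every watched address and compare) with a single lookup in the precomputed leaf-key set. A minimal sketch of the equivalence, using an illustrative address:

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/crypto"
)

func main() {
	// The set is keyed by Keccak256(address), which is what ComputeWatchedAddressesLeafKeys stores.
	addr := common.HexToAddress("0x5d663F5269090bD2A7DC2390c911dF6083D7b28F") // illustrative
	watched := map[common.Hash]struct{}{
		crypto.Keccak256Hash(addr.Bytes()): {},
	}

	// A state leaf key seen during trie iteration is checked with one map lookup,
	// instead of re-hashing every watched address for every leaf.
	stateLeafKey := crypto.Keccak256(addr.Bytes())
	_, ok := watched[common.BytesToHash(stateLeafKey)]
	fmt.Println(ok) // true
}
```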


@@ -988,6 +988,7 @@ func TestBuilderWithWatchedAddressList(t *testing.T) {
params := statediff.Params{
WatchedAddresses: []common.Address{test_helpers.Account1Addr, test_helpers.ContractAddr},
}
params.ComputeWatchedAddressesLeafKeys()
builder = statediff.NewBuilder(chain.StateCache())
var tests = []struct {
@@ -1152,169 +1153,6 @@ func TestBuilderWithWatchedAddressList(t *testing.T) {
}
}
func TestBuilderWithWatchedAddressAndStorageKeyList(t *testing.T) {
blocks, chain := test_helpers.MakeChain(3, test_helpers.Genesis, test_helpers.TestChainGen)
contractLeafKey = test_helpers.AddressToLeafKey(test_helpers.ContractAddr)
defer chain.Stop()
block0 = test_helpers.Genesis
block1 = blocks[0]
block2 = blocks[1]
block3 = blocks[2]
params := statediff.Params{
WatchedAddresses: []common.Address{test_helpers.Account1Addr, test_helpers.ContractAddr},
WatchedStorageSlots: []common.Hash{slot1StorageKey},
}
builder = statediff.NewBuilder(chain.StateCache())
var tests = []struct {
name string
startingArguments statediff.Args
expected *types2.StateObject
}{
{
"testEmptyDiff",
statediff.Args{
OldStateRoot: block0.Root(),
NewStateRoot: block0.Root(),
BlockNumber: block0.Number(),
BlockHash: block0.Hash(),
},
&types2.StateObject{
BlockNumber: block0.Number(),
BlockHash: block0.Hash(),
Nodes: emptyDiffs,
},
},
{
"testBlock0",
//10000 transferred from testBankAddress to account1Addr
statediff.Args{
OldStateRoot: test_helpers.NullHash,
NewStateRoot: block0.Root(),
BlockNumber: block0.Number(),
BlockHash: block0.Hash(),
},
&types2.StateObject{
BlockNumber: block0.Number(),
BlockHash: block0.Hash(),
Nodes: emptyDiffs,
},
},
{
"testBlock1",
//10000 transferred from testBankAddress to account1Addr
statediff.Args{
OldStateRoot: block0.Root(),
NewStateRoot: block1.Root(),
BlockNumber: block1.Number(),
BlockHash: block1.Hash(),
},
&types2.StateObject{
BlockNumber: block1.Number(),
BlockHash: block1.Hash(),
Nodes: []types2.StateNode{
{
Path: []byte{'\x0e'},
NodeType: types2.Leaf,
LeafKey: test_helpers.Account1LeafKey,
NodeValue: account1AtBlock1LeafNode,
StorageNodes: emptyStorage,
},
},
},
},
{
"testBlock2",
//1000 transferred from testBankAddress to account1Addr
//1000 transferred from account1Addr to account2Addr
statediff.Args{
OldStateRoot: block1.Root(),
NewStateRoot: block2.Root(),
BlockNumber: block2.Number(),
BlockHash: block2.Hash(),
},
&types2.StateObject{
BlockNumber: block2.Number(),
BlockHash: block2.Hash(),
Nodes: []types2.StateNode{
{
Path: []byte{'\x06'},
NodeType: types2.Leaf,
LeafKey: contractLeafKey,
NodeValue: contractAccountAtBlock2LeafNode,
StorageNodes: []types2.StorageNode{
{
Path: []byte{'\x0b'},
NodeType: types2.Leaf,
LeafKey: slot1StorageKey.Bytes(),
NodeValue: slot1StorageLeafNode,
},
},
},
{
Path: []byte{'\x0e'},
NodeType: types2.Leaf,
LeafKey: test_helpers.Account1LeafKey,
NodeValue: account1AtBlock2LeafNode,
StorageNodes: emptyStorage,
},
},
CodeAndCodeHashes: []types2.CodeAndCodeHash{
{
Hash: test_helpers.CodeHash,
Code: test_helpers.ByteCodeAfterDeployment,
},
},
},
},
{
"testBlock3",
//the contract's storage is changed
//and the block is mined by account 2
statediff.Args{
OldStateRoot: block2.Root(),
NewStateRoot: block3.Root(),
BlockNumber: block3.Number(),
BlockHash: block3.Hash(),
},
&types2.StateObject{
BlockNumber: block3.Number(),
BlockHash: block3.Hash(),
Nodes: []types2.StateNode{
{
Path: []byte{'\x06'},
NodeType: types2.Leaf,
LeafKey: contractLeafKey,
NodeValue: contractAccountAtBlock3LeafNode,
StorageNodes: emptyStorage,
},
},
},
},
}
for _, test := range tests {
diff, err := builder.BuildStateDiffObject(test.startingArguments, params)
if err != nil {
t.Error(err)
}
receivedStateDiffRlp, err := rlp.EncodeToBytes(diff)
if err != nil {
t.Error(err)
}
expectedStateDiffRlp, err := rlp.EncodeToBytes(test.expected)
if err != nil {
t.Error(err)
}
sort.Slice(receivedStateDiffRlp, func(i, j int) bool { return receivedStateDiffRlp[i] < receivedStateDiffRlp[j] })
sort.Slice(expectedStateDiffRlp, func(i, j int) bool { return expectedStateDiffRlp[i] < expectedStateDiffRlp[j] })
if !bytes.Equal(receivedStateDiffRlp, expectedStateDiffRlp) {
t.Logf("Test failed: %s", test.name)
t.Errorf("actual state diff: %+v\nexpected state diff: %+v", diff, test.expected)
}
}
}
func TestBuilderWithRemovedAccountAndStorage(t *testing.T) {
blocks, chain := test_helpers.MakeChain(6, test_helpers.Genesis, test_helpers.TestChainGen)
contractLeafKey = test_helpers.AddressToLeafKey(test_helpers.ContractAddr)
@@ -1719,6 +1557,286 @@ func TestBuilderWithRemovedAccountAndStorageWithoutIntermediateNodes(t *testing.
}
}
func TestBuilderWithRemovedNonWatchedAccount(t *testing.T) {
blocks, chain := test_helpers.MakeChain(6, test_helpers.Genesis, test_helpers.TestChainGen)
contractLeafKey = test_helpers.AddressToLeafKey(test_helpers.ContractAddr)
defer chain.Stop()
block3 = blocks[2]
block4 = blocks[3]
block5 = blocks[4]
block6 = blocks[5]
params := statediff.Params{
WatchedAddresses: []common.Address{test_helpers.Account1Addr, test_helpers.Account2Addr},
}
params.ComputeWatchedAddressesLeafKeys()
builder = statediff.NewBuilder(chain.StateCache())
var tests = []struct {
name string
startingArguments statediff.Args
expected *types2.StateObject
}{
{
"testBlock4",
statediff.Args{
OldStateRoot: block3.Root(),
NewStateRoot: block4.Root(),
BlockNumber: block4.Number(),
BlockHash: block4.Hash(),
},
&types2.StateObject{
BlockNumber: block4.Number(),
BlockHash: block4.Hash(),
Nodes: []types2.StateNode{
{
Path: []byte{'\x0c'},
NodeType: types2.Leaf,
LeafKey: test_helpers.Account2LeafKey,
NodeValue: account2AtBlock4LeafNode,
StorageNodes: emptyStorage,
},
},
},
},
{
"testBlock5",
statediff.Args{
OldStateRoot: block4.Root(),
NewStateRoot: block5.Root(),
BlockNumber: block5.Number(),
BlockHash: block5.Hash(),
},
&types2.StateObject{
BlockNumber: block5.Number(),
BlockHash: block5.Hash(),
Nodes: []types2.StateNode{
{
Path: []byte{'\x0e'},
NodeType: types2.Leaf,
LeafKey: test_helpers.Account1LeafKey,
NodeValue: account1AtBlock5LeafNode,
StorageNodes: emptyStorage,
},
},
},
},
{
"testBlock6",
statediff.Args{
OldStateRoot: block5.Root(),
NewStateRoot: block6.Root(),
BlockNumber: block6.Number(),
BlockHash: block6.Hash(),
},
&types2.StateObject{
BlockNumber: block6.Number(),
BlockHash: block6.Hash(),
Nodes: []types2.StateNode{
{
Path: []byte{'\x0c'},
NodeType: types2.Leaf,
LeafKey: test_helpers.Account2LeafKey,
NodeValue: account2AtBlock6LeafNode,
StorageNodes: emptyStorage,
},
{
Path: []byte{'\x0e'},
NodeType: types2.Leaf,
LeafKey: test_helpers.Account1LeafKey,
NodeValue: account1AtBlock6LeafNode,
StorageNodes: emptyStorage,
},
},
},
},
}
for _, test := range tests {
diff, err := builder.BuildStateDiffObject(test.startingArguments, params)
if err != nil {
t.Error(err)
}
receivedStateDiffRlp, err := rlp.EncodeToBytes(diff)
if err != nil {
t.Error(err)
}
expectedStateDiffRlp, err := rlp.EncodeToBytes(test.expected)
if err != nil {
t.Error(err)
}
sort.Slice(receivedStateDiffRlp, func(i, j int) bool { return receivedStateDiffRlp[i] < receivedStateDiffRlp[j] })
sort.Slice(expectedStateDiffRlp, func(i, j int) bool { return expectedStateDiffRlp[i] < expectedStateDiffRlp[j] })
if !bytes.Equal(receivedStateDiffRlp, expectedStateDiffRlp) {
t.Logf("Test failed: %s", test.name)
t.Errorf("actual state diff: %+v\r\n\r\n\r\nexpected state diff: %+v", diff, test.expected)
}
}
}
func TestBuilderWithRemovedWatchedAccount(t *testing.T) {
blocks, chain := test_helpers.MakeChain(6, test_helpers.Genesis, test_helpers.TestChainGen)
contractLeafKey = test_helpers.AddressToLeafKey(test_helpers.ContractAddr)
defer chain.Stop()
block3 = blocks[2]
block4 = blocks[3]
block5 = blocks[4]
block6 = blocks[5]
params := statediff.Params{
WatchedAddresses: []common.Address{test_helpers.Account1Addr, test_helpers.ContractAddr},
}
params.ComputeWatchedAddressesLeafKeys()
builder = statediff.NewBuilder(chain.StateCache())
var tests = []struct {
name string
startingArguments statediff.Args
expected *types2.StateObject
}{
{
"testBlock4",
statediff.Args{
OldStateRoot: block3.Root(),
NewStateRoot: block4.Root(),
BlockNumber: block4.Number(),
BlockHash: block4.Hash(),
},
&types2.StateObject{
BlockNumber: block4.Number(),
BlockHash: block4.Hash(),
Nodes: []types2.StateNode{
{
Path: []byte{'\x06'},
NodeType: types2.Leaf,
LeafKey: contractLeafKey,
NodeValue: contractAccountAtBlock4LeafNode,
StorageNodes: []types2.StorageNode{
{
Path: []byte{'\x04'},
NodeType: types2.Leaf,
LeafKey: slot2StorageKey.Bytes(),
NodeValue: slot2StorageLeafNode,
},
{
Path: []byte{'\x0b'},
NodeType: types2.Removed,
LeafKey: slot1StorageKey.Bytes(),
NodeValue: []byte{},
},
{
Path: []byte{'\x0c'},
NodeType: types2.Removed,
LeafKey: slot3StorageKey.Bytes(),
NodeValue: []byte{},
},
},
},
},
},
},
{
"testBlock5",
statediff.Args{
OldStateRoot: block4.Root(),
NewStateRoot: block5.Root(),
BlockNumber: block5.Number(),
BlockHash: block5.Hash(),
},
&types2.StateObject{
BlockNumber: block5.Number(),
BlockHash: block5.Hash(),
Nodes: []types2.StateNode{
{
Path: []byte{'\x06'},
NodeType: types2.Leaf,
LeafKey: contractLeafKey,
NodeValue: contractAccountAtBlock5LeafNode,
StorageNodes: []types2.StorageNode{
{
Path: []byte{},
NodeType: types2.Leaf,
LeafKey: slot0StorageKey.Bytes(),
NodeValue: slot0StorageLeafRootNode,
},
{
Path: []byte{'\x02'},
NodeType: types2.Removed,
LeafKey: slot0StorageKey.Bytes(),
NodeValue: []byte{},
},
{
Path: []byte{'\x04'},
NodeType: types2.Removed,
LeafKey: slot2StorageKey.Bytes(),
NodeValue: []byte{},
},
},
},
{
Path: []byte{'\x0e'},
NodeType: types2.Leaf,
LeafKey: test_helpers.Account1LeafKey,
NodeValue: account1AtBlock5LeafNode,
StorageNodes: emptyStorage,
},
},
},
},
{
"testBlock6",
statediff.Args{
OldStateRoot: block5.Root(),
NewStateRoot: block6.Root(),
BlockNumber: block6.Number(),
BlockHash: block6.Hash(),
},
&types2.StateObject{
BlockNumber: block6.Number(),
BlockHash: block6.Hash(),
Nodes: []types2.StateNode{
{
Path: []byte{'\x06'},
NodeType: types2.Removed,
LeafKey: contractLeafKey,
NodeValue: []byte{},
},
{
Path: []byte{'\x0e'},
NodeType: types2.Leaf,
LeafKey: test_helpers.Account1LeafKey,
NodeValue: account1AtBlock6LeafNode,
StorageNodes: emptyStorage,
},
},
},
},
}
for _, test := range tests {
diff, err := builder.BuildStateDiffObject(test.startingArguments, params)
if err != nil {
t.Error(err)
}
receivedStateDiffRlp, err := rlp.EncodeToBytes(diff)
if err != nil {
t.Error(err)
}
expectedStateDiffRlp, err := rlp.EncodeToBytes(test.expected)
if err != nil {
t.Error(err)
}
sort.Slice(receivedStateDiffRlp, func(i, j int) bool { return receivedStateDiffRlp[i] < receivedStateDiffRlp[j] })
sort.Slice(expectedStateDiffRlp, func(i, j int) bool { return expectedStateDiffRlp[i] < expectedStateDiffRlp[j] })
if !bytes.Equal(receivedStateDiffRlp, expectedStateDiffRlp) {
t.Logf("Test failed: %s", test.name)
t.Errorf("actual state diff: %+v\r\n\r\n\r\nexpected state diff: %+v", diff, test.expected)
}
}
}
var (
slot00StorageValue = common.Hex2Bytes("9471562b71999873db5b286df957af199ec94617f7") // prefixed TestBankAddress


@@ -19,8 +19,10 @@ package statediff
import (
"context"
"math/big"
"sync"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/statediff/indexer/interfaces"
)
@@ -53,7 +55,21 @@ type Params struct {
IncludeTD bool
IncludeCode bool
WatchedAddresses []common.Address
WatchedStorageSlots []common.Hash
watchedAddressesLeafKeys map[common.Hash]struct{}
}
// ComputeWatchedAddressesLeafKeys populates a map with keys (Keccak256Hash) of each of the WatchedAddresses
func (p *Params) ComputeWatchedAddressesLeafKeys() {
p.watchedAddressesLeafKeys = make(map[common.Hash]struct{}, len(p.WatchedAddresses))
for _, address := range p.WatchedAddresses {
p.watchedAddressesLeafKeys[crypto.Keccak256Hash(address.Bytes())] = struct{}{}
}
}
// ParamsWithMutex allows to lock the parameters while they are being updated | read from
type ParamsWithMutex struct {
Params
sync.RWMutex
}
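A minimal sketch of how ParamsWithMutex might be wired up; the package-level variable name here is hypothetical, and only the promoted fields and methods shown in this diff are relied on:

```go
package main

import (
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/statediff"
)

// watchParams is a hypothetical shared parameter set guarded by the embedded RWMutex.
var watchParams statediff.ParamsWithMutex

// addWatched mutates the watched-address list under the write lock and
// recomputes the derived leaf-key set before releasing it.
func addWatched(addr common.Address) {
	watchParams.Lock()
	defer watchParams.Unlock()
	watchParams.WatchedAddresses = append(watchParams.WatchedAddresses, addr)
	watchParams.ComputeWatchedAddressesLeafKeys()
}

// snapshot copies the params under the read lock for use in a single diff run.
func snapshot() statediff.Params {
	watchParams.RLock()
	defer watchParams.RUnlock()
	return watchParams.Params
}
```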
// Args bundles the arguments for the state diff builder


@@ -496,3 +496,28 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(batch interfaces.Batch, codeAnd
func (sdi *StateDiffIndexer) Close() error {
return sdi.dump.Close()
}
// LoadWatchedAddresses satisfies the interfaces.StateDiffIndexer interface
func (sdi *StateDiffIndexer) LoadWatchedAddresses() ([]common.Address, error) {
return nil, nil
}
// InsertWatchedAddresses satisfies the interfaces.StateDiffIndexer interface
func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error {
return nil
}
// RemoveWatchedAddresses satisfies the interfaces.StateDiffIndexer interface
func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressArg) error {
return nil
}
// SetWatchedAddresses satisfies the interfaces.StateDiffIndexer interface
func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error {
return nil
}
// ClearWatchedAddresses satisfies the interfaces.StateDiffIndexer interface
func (sdi *StateDiffIndexer) ClearWatchedAddresses() error {
return nil
}


@@ -24,6 +24,7 @@ import (
// Config holds params for writing sql statements out to a file
type Config struct {
FilePath string
WatchedAddressesFilePath string
NodeInfo node.Info
}
@@ -35,6 +36,7 @@ func (c Config) Type() shared.DBType {
// TestConfig config for unit tests
var TestConfig = Config{
FilePath: "./statediffing_test_file.sql",
WatchedAddressesFilePath: "./statediffing_watched_addresses_test_file.sql",
NodeInfo: node.Info{
GenesisBlock: "0xd4e56740f876aef8c010b86a40d5f56745a118d0906a34e69aec8c0db1cb8fa3",
NetworkID: "1",


@@ -17,6 +17,7 @@
package file
import (
"bufio"
"context"
"errors"
"fmt"
@@ -28,6 +29,8 @@ import (
"github.com/ipfs/go-cid"
node "github.com/ipfs/go-ipld-format"
"github.com/multiformats/go-multihash"
pg_query "github.com/pganalyze/pg_query_go/v2"
"github.com/thoas/go-funk"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
@@ -44,6 +47,9 @@ import (
)
const defaultFilePath = "./statediff.sql"
const defaultWatchedAddressesFilePath = "./statediff-watched-addresses.sql"
const watchedAddressesInsert = "INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at) VALUES ('%s', '%d', '%d') ON CONFLICT (address) DO NOTHING;"
var _ interfaces.StateDiffIndexer = &StateDiffIndexer{}
@@ -57,6 +63,8 @@ type StateDiffIndexer struct {
chainConfig *params.ChainConfig
nodeID string
wg *sync.WaitGroup
watchedAddressesFilePath string
}
// NewStateDiffIndexer creates a void implementation of interfaces.StateDiffIndexer
@@ -73,6 +81,13 @@ func NewStateDiffIndexer(ctx context.Context, chainConfig *params.ChainConfig, c
return nil, fmt.Errorf("unable to create file (%s), err: %v", filePath, err)
}
log.Info("Writing statediff SQL statements to file", "file", filePath)
watchedAddressesFilePath := config.WatchedAddressesFilePath
if watchedAddressesFilePath == "" {
watchedAddressesFilePath = defaultWatchedAddressesFilePath
}
log.Info("Writing watched addresses SQL statements to file", "file", watchedAddressesFilePath)
w := NewSQLWriter(file)
wg := new(sync.WaitGroup)
w.Loop()
@@ -83,6 +98,7 @@ func NewStateDiffIndexer(ctx context.Context, chainConfig *params.ChainConfig, c
chainConfig: chainConfig,
nodeID: config.NodeInfo.ID,
wg: wg,
watchedAddressesFilePath: watchedAddressesFilePath,
}, nil
}
@@ -478,3 +494,165 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(batch interfaces.Batch, codeAnd
func (sdi *StateDiffIndexer) Close() error {
return sdi.fileWriter.Close()
}
// LoadWatchedAddresses loads watched addresses from a file
func (sdi *StateDiffIndexer) LoadWatchedAddresses() ([]common.Address, error) {
// load sql statements from watched addresses file
stmts, err := loadWatchedAddressesStatements(sdi.watchedAddressesFilePath)
if err != nil {
return nil, err
}
// extract addresses from the sql statements
watchedAddresses := []common.Address{}
for _, stmt := range stmts {
addressString, err := parseWatchedAddressStatement(stmt)
if err != nil {
return nil, err
}
watchedAddresses = append(watchedAddresses, common.HexToAddress(addressString))
}
return watchedAddresses, nil
}
// InsertWatchedAddresses inserts the given addresses in a file
func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error {
// load sql statements from watched addresses file
stmts, err := loadWatchedAddressesStatements(sdi.watchedAddressesFilePath)
if err != nil {
return err
}
// get already watched addresses
var watchedAddresses []string
for _, stmt := range stmts {
addressString, err := parseWatchedAddressStatement(stmt)
if err != nil {
return err
}
watchedAddresses = append(watchedAddresses, addressString)
}
// append statements for new addresses to existing statements
for _, arg := range args {
// ignore if already watched
if funk.Contains(watchedAddresses, arg.Address) {
continue
}
stmt := fmt.Sprintf(watchedAddressesInsert, arg.Address, arg.CreatedAt, currentBlockNumber.Uint64())
stmts = append(stmts, stmt)
}
return dumpWatchedAddressesStatements(sdi.watchedAddressesFilePath, stmts)
}
// RemoveWatchedAddresses removes the given watched addresses from a file
func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressArg) error {
// load sql statements from watched addresses file
stmts, err := loadWatchedAddressesStatements(sdi.watchedAddressesFilePath)
if err != nil {
return err
}
// get rid of statements having addresses to be removed
var filteredStmts []string
for _, stmt := range stmts {
addressString, err := parseWatchedAddressStatement(stmt)
if err != nil {
return err
}
toRemove := funk.Contains(args, func(arg sdtypes.WatchAddressArg) bool {
return arg.Address == addressString
})
if !toRemove {
filteredStmts = append(filteredStmts, stmt)
}
}
return dumpWatchedAddressesStatements(sdi.watchedAddressesFilePath, filteredStmts)
}
// SetWatchedAddresses clears and inserts the given addresses in a file
func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error {
var stmts []string
for _, arg := range args {
stmt := fmt.Sprintf(watchedAddressesInsert, arg.Address, arg.CreatedAt, currentBlockNumber.Uint64())
stmts = append(stmts, stmt)
}
return dumpWatchedAddressesStatements(sdi.watchedAddressesFilePath, stmts)
}
// ClearWatchedAddresses clears all the watched addresses from a file
func (sdi *StateDiffIndexer) ClearWatchedAddresses() error {
return sdi.SetWatchedAddresses([]sdtypes.WatchAddressArg{}, big.NewInt(0))
}
// loadWatchedAddressesStatements loads sql statements from the given file in a string slice
func loadWatchedAddressesStatements(filePath string) ([]string, error) {
file, err := os.Open(filePath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return []string{}, nil
}
return nil, fmt.Errorf("error opening watched addresses file: %v", err)
}
defer file.Close()
stmts := []string{}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
stmts = append(stmts, scanner.Text())
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error loading watched addresses: %v", err)
}
return stmts, nil
}
// dumpWatchedAddressesStatements dumps sql statements to the given file
func dumpWatchedAddressesStatements(filePath string, stmts []string) error {
file, err := os.Create(filePath)
if err != nil {
return fmt.Errorf("error creating watched addresses file: %v", err)
}
defer file.Close()
for _, stmt := range stmts {
_, err := file.Write([]byte(stmt + "\n"))
if err != nil {
return fmt.Errorf("error inserting watched_addresses entry: %v", err)
}
}
return nil
}
// parseWatchedAddressStatement parses given sql insert statement to extract the address argument
func parseWatchedAddressStatement(stmt string) (string, error) {
parseResult, err := pg_query.Parse(stmt)
if err != nil {
return "", fmt.Errorf("error parsing sql stmt: %v", err)
}
// extract address argument from parse output for a SQL statement of form
// "INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at)
// VALUES ('0xabc', '123', '130') ON CONFLICT (address) DO NOTHING;"
addressString := parseResult.Stmts[0].Stmt.GetInsertStmt().
SelectStmt.GetSelectStmt().
ValuesLists[0].GetList().
Items[0].GetAConst().
GetVal().
GetString_().
Str
return addressString, nil
}
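A small usage sketch of the parse path above, assuming pg_query_go v2 is available; the statement mirrors the watchedAddressesInsert template and the address value is illustrative:

```go
package main

import (
	"fmt"
	"log"

	pg_query "github.com/pganalyze/pg_query_go/v2"
)

func main() {
	stmt := "INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at) " +
		"VALUES ('0x5d663F5269090bD2A7DC2390c911dF6083D7b28F', '1', '10') ON CONFLICT (address) DO NOTHING;"

	parseResult, err := pg_query.Parse(stmt)
	if err != nil {
		log.Fatal(err)
	}
	// Walk the parse tree the same way parseWatchedAddressStatement does:
	// first statement -> INSERT -> VALUES list -> first item -> string constant.
	address := parseResult.Stmts[0].Stmt.GetInsertStmt().
		SelectStmt.GetSelectStmt().
		ValuesLists[0].GetList().
		Items[0].GetAConst().
		GetVal().
		GetString_().
		Str
	fmt.Println(address) // 0x5d663F5269090bD2A7DC2390c911dF6083D7b28F
}
```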


@@ -81,7 +81,7 @@ func setupLegacy(t *testing.T) {
}
}
func dumpData(t *testing.T) {
func dumpFileData(t *testing.T) {
sqlFileBytes, err := os.ReadFile(file.TestConfig.FilePath)
require.NoError(t, err)
@@ -89,10 +89,36 @@ func dumpData(t *testing.T) {
require.NoError(t, err)
}
func resetAndDumpWatchedAddressesFileData(t *testing.T) {
resetDB(t)
sqlFileBytes, err := os.ReadFile(file.TestConfig.WatchedAddressesFilePath)
require.NoError(t, err)
_, err = sqlxdb.Exec(string(sqlFileBytes))
require.NoError(t, err)
}
func resetDB(t *testing.T) {
file.TearDownDB(t, sqlxdb)
connStr := postgres.DefaultConfig.DbConnectionString()
sqlxdb, err = sqlx.Connect("postgres", connStr)
if err != nil {
t.Fatalf("failed to connect to db with connection string: %s err: %v", connStr, err)
}
}
func tearDown(t *testing.T) {
file.TearDownDB(t, sqlxdb)
err := os.Remove(file.TestConfig.FilePath)
require.NoError(t, err)
if err := os.Remove(file.TestConfig.WatchedAddressesFilePath); !errors.Is(err, os.ErrNotExist) {
require.NoError(t, err)
}
err = sqlxdb.Close()
require.NoError(t, err)
}
@@ -106,7 +132,7 @@ func expectTrue(t *testing.T, value bool) {
func TestFileIndexerLegacy(t *testing.T) {
t.Run("Publish and index header IPLDs", func(t *testing.T) {
setupLegacy(t)
dumpData(t)
dumpFileData(t)
defer tearDown(t)
pgStr := `SELECT cid, td, reward, block_hash, coinbase
FROM eth.header_cids

View File

@ -21,6 +21,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"math/big"
"os" "os"
"testing" "testing"
@ -28,6 +29,7 @@ import (
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/statediff/indexer/models" "github.com/ethereum/go-ethereum/statediff/indexer/models"
"github.com/ethereum/go-ethereum/statediff/indexer/shared" "github.com/ethereum/go-ethereum/statediff/indexer/shared"
sdtypes "github.com/ethereum/go-ethereum/statediff/types"
"github.com/ipfs/go-cid" "github.com/ipfs/go-cid"
blockstore "github.com/ipfs/go-ipfs-blockstore" blockstore "github.com/ipfs/go-ipfs-blockstore"
@ -57,6 +59,9 @@ var (
rct1CID, rct2CID, rct3CID, rct4CID, rct5CID cid.Cid rct1CID, rct2CID, rct3CID, rct4CID, rct5CID cid.Cid
rctLeaf1, rctLeaf2, rctLeaf3, rctLeaf4, rctLeaf5 []byte rctLeaf1, rctLeaf2, rctLeaf3, rctLeaf4, rctLeaf5 []byte
state1CID, state2CID, storageCID cid.Cid state1CID, state2CID, storageCID cid.Cid
contract1Address, contract2Address, contract3Address, contract4Address string
contract1CreatedAt, contract2CreatedAt, contract3CreatedAt, contract4CreatedAt uint64
lastFilledAt, watchedAt1, watchedAt2, watchedAt3 uint64
) )
func init() { func init() {
@ -161,15 +166,45 @@ func init() {
rctLeaf3 = orderedRctLeafNodes[2] rctLeaf3 = orderedRctLeafNodes[2]
rctLeaf4 = orderedRctLeafNodes[3] rctLeaf4 = orderedRctLeafNodes[3]
rctLeaf5 = orderedRctLeafNodes[4] rctLeaf5 = orderedRctLeafNodes[4]
contract1Address = "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F"
contract2Address = "0x6Eb7e5C66DB8af2E96159AC440cbc8CDB7fbD26B"
contract3Address = "0xcfeB164C328CA13EFd3C77E1980d94975aDfedfc"
contract4Address = "0x0Edf0c4f393a628DE4828B228C48175b3EA297fc"
contract1CreatedAt = uint64(1)
contract2CreatedAt = uint64(2)
contract3CreatedAt = uint64(3)
contract4CreatedAt = uint64(4)
lastFilledAt = uint64(0)
watchedAt1 = uint64(10)
watchedAt2 = uint64(15)
watchedAt3 = uint64(20)
} }
func setup(t *testing.T) { func setupIndexer(t *testing.T) {
if _, err := os.Stat(file.TestConfig.FilePath); !errors.Is(err, os.ErrNotExist) { if _, err := os.Stat(file.TestConfig.FilePath); !errors.Is(err, os.ErrNotExist) {
err := os.Remove(file.TestConfig.FilePath) err := os.Remove(file.TestConfig.FilePath)
require.NoError(t, err) require.NoError(t, err)
} }
if _, err := os.Stat(file.TestConfig.WatchedAddressesFilePath); !errors.Is(err, os.ErrNotExist) {
err := os.Remove(file.TestConfig.WatchedAddressesFilePath)
require.NoError(t, err)
}
ind, err = file.NewStateDiffIndexer(context.Background(), mocks.TestConfig, file.TestConfig) ind, err = file.NewStateDiffIndexer(context.Background(), mocks.TestConfig, file.TestConfig)
require.NoError(t, err) require.NoError(t, err)
connStr := postgres.DefaultConfig.DbConnectionString()
sqlxdb, err = sqlx.Connect("postgres", connStr)
if err != nil {
t.Fatalf("failed to connect to db with connection string: %s err: %v", connStr, err)
}
}
func setup(t *testing.T) {
setupIndexer(t)
var tx interfaces.Batch var tx interfaces.Batch
tx, err = ind.PushBlock( tx, err = ind.PushBlock(
mockBlock, mockBlock,
@ -192,19 +227,12 @@ func setup(t *testing.T) {
} }
test_helpers.ExpectEqual(t, tx.(*file.BatchTx).BlockNumber, mocks.BlockNumber.Uint64()) test_helpers.ExpectEqual(t, tx.(*file.BatchTx).BlockNumber, mocks.BlockNumber.Uint64())
connStr := postgres.DefaultConfig.DbConnectionString()
sqlxdb, err = sqlx.Connect("postgres", connStr)
if err != nil {
t.Fatalf("failed to connect to db with connection string: %s err: %v", connStr, err)
}
} }
func TestFileIndexer(t *testing.T) { func TestFileIndexer(t *testing.T) {
t.Run("Publish and index header IPLDs in a single tx", func(t *testing.T) { t.Run("Publish and index header IPLDs in a single tx", func(t *testing.T) {
setup(t) setup(t)
dumpData(t) dumpFileData(t)
defer tearDown(t) defer tearDown(t)
pgStr := `SELECT cid, td, reward, block_hash, coinbase pgStr := `SELECT cid, td, reward, block_hash, coinbase
FROM eth.header_cids FROM eth.header_cids
@ -242,7 +270,7 @@ func TestFileIndexer(t *testing.T) {
}) })
t.Run("Publish and index transaction IPLDs in a single tx", func(t *testing.T) { t.Run("Publish and index transaction IPLDs in a single tx", func(t *testing.T) {
setup(t) setup(t)
dumpData(t) dumpFileData(t)
defer tearDown(t) defer tearDown(t)
// check that txs were properly indexed and published // check that txs were properly indexed and published
@ -370,7 +398,7 @@ func TestFileIndexer(t *testing.T) {
t.Run("Publish and index log IPLDs for multiple receipt of a specific block", func(t *testing.T) { t.Run("Publish and index log IPLDs for multiple receipt of a specific block", func(t *testing.T) {
setup(t) setup(t)
dumpData(t) dumpFileData(t)
defer tearDown(t) defer tearDown(t)
rcts := make([]string, 0) rcts := make([]string, 0)
@ -426,7 +454,7 @@ func TestFileIndexer(t *testing.T) {
t.Run("Publish and index receipt IPLDs in a single tx", func(t *testing.T) { t.Run("Publish and index receipt IPLDs in a single tx", func(t *testing.T) {
setup(t) setup(t)
dumpData(t) dumpFileData(t)
defer tearDown(t) defer tearDown(t)
// check receipts were properly indexed and published // check receipts were properly indexed and published
@ -527,7 +555,7 @@ func TestFileIndexer(t *testing.T) {
t.Run("Publish and index state IPLDs in a single tx", func(t *testing.T) { t.Run("Publish and index state IPLDs in a single tx", func(t *testing.T) {
setup(t) setup(t)
dumpData(t) dumpFileData(t)
defer tearDown(t) defer tearDown(t)
// check that state nodes were properly indexed and published // check that state nodes were properly indexed and published
@ -618,7 +646,7 @@ func TestFileIndexer(t *testing.T) {
t.Run("Publish and index storage IPLDs in a single tx", func(t *testing.T) { t.Run("Publish and index storage IPLDs in a single tx", func(t *testing.T) {
setup(t) setup(t)
dumpData(t) dumpFileData(t)
defer tearDown(t) defer tearDown(t)
// check that storage nodes were properly indexed // check that storage nodes were properly indexed
@ -688,3 +716,341 @@ func TestFileIndexer(t *testing.T) {
test_helpers.ExpectEqual(t, data, []byte{}) test_helpers.ExpectEqual(t, data, []byte{})
}) })
} }
func TestFileWatchAddressMethods(t *testing.T) {
setupIndexer(t)
defer tearDown(t)
type res struct {
Address string `db:"address"`
CreatedAt uint64 `db:"created_at"`
WatchedAt uint64 `db:"watched_at"`
LastFilledAt uint64 `db:"last_filled_at"`
}
pgStr := "SELECT * FROM eth_meta.watched_addresses"
t.Run("Load watched addresses (empty table)", func(t *testing.T) {
expectedData := []common.Address{}
rows, err := ind.LoadWatchedAddresses()
require.NoError(t, err)
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Insert watched addresses", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
expectedData := []res{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
WatchedAt: watchedAt1,
LastFilledAt: lastFilledAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
WatchedAt: watchedAt1,
LastFilledAt: lastFilledAt,
},
}
err = ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt1)))
require.NoError(t, err)
resetAndDumpWatchedAddressesFileData(t)
rows := []res{}
err = sqlxdb.Select(&rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Insert watched addresses (some already watched)", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
expectedData := []res{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
WatchedAt: watchedAt1,
LastFilledAt: lastFilledAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
WatchedAt: watchedAt1,
LastFilledAt: lastFilledAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
WatchedAt: watchedAt2,
LastFilledAt: lastFilledAt,
},
}
err = ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt2)))
require.NoError(t, err)
resetAndDumpWatchedAddressesFileData(t)
rows := []res{}
err = sqlxdb.Select(&rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Remove watched addresses", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
expectedData := []res{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
WatchedAt: watchedAt1,
LastFilledAt: lastFilledAt,
},
}
err = ind.RemoveWatchedAddresses(args)
require.NoError(t, err)
resetAndDumpWatchedAddressesFileData(t)
rows := []res{}
err = sqlxdb.Select(&rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Remove watched addresses (some non-watched)", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
expectedData := []res{}
err = ind.RemoveWatchedAddresses(args)
require.NoError(t, err)
resetAndDumpWatchedAddressesFileData(t)
rows := []res{}
err = sqlxdb.Select(&rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Set watched addresses", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
}
expectedData := []res{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
WatchedAt: watchedAt2,
LastFilledAt: lastFilledAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
WatchedAt: watchedAt2,
LastFilledAt: lastFilledAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
WatchedAt: watchedAt2,
LastFilledAt: lastFilledAt,
},
}
err = ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt2)))
require.NoError(t, err)
resetAndDumpWatchedAddressesFileData(t)
rows := []res{}
err = sqlxdb.Select(&rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Set watched addresses (some already watched)", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract4Address,
CreatedAt: contract4CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
}
expectedData := []res{
{
Address: contract4Address,
CreatedAt: contract4CreatedAt,
WatchedAt: watchedAt3,
LastFilledAt: lastFilledAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
WatchedAt: watchedAt3,
LastFilledAt: lastFilledAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
WatchedAt: watchedAt3,
LastFilledAt: lastFilledAt,
},
}
err = ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt3)))
require.NoError(t, err)
resetAndDumpWatchedAddressesFileData(t)
rows := []res{}
err = sqlxdb.Select(&rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Load watched addresses", func(t *testing.T) {
expectedData := []common.Address{
common.HexToAddress(contract4Address),
common.HexToAddress(contract2Address),
common.HexToAddress(contract3Address),
}
rows, err := ind.LoadWatchedAddresses()
require.NoError(t, err)
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Clear watched addresses", func(t *testing.T) {
expectedData := []res{}
err = ind.ClearWatchedAddresses()
require.NoError(t, err)
resetAndDumpWatchedAddressesFileData(t)
rows := []res{}
err = sqlxdb.Select(&rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Clear watched addresses (empty table)", func(t *testing.T) {
expectedData := []res{}
err = ind.ClearWatchedAddresses()
require.NoError(t, err)
resetAndDumpWatchedAddressesFileData(t)
rows := []res{}
err = sqlxdb.Select(&rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
}

View File

@ -57,6 +57,10 @@ func TearDownDB(t *testing.T, db *sqlx.DB) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, err = tx.Exec(`DELETE FROM eth_meta.watched_addresses`)
if err != nil {
t.Fatal(err)
}
err = tx.Commit() err = tx.Commit()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -555,3 +555,118 @@ func (sdi *StateDiffIndexer) Close() error {
} }
// Update the known gaps table with the gap information. // Update the known gaps table with the gap information.
// LoadWatchedAddresses reads watched addresses from the database
func (sdi *StateDiffIndexer) LoadWatchedAddresses() ([]common.Address, error) {
addressStrings := make([]string, 0)
pgStr := "SELECT address FROM eth_meta.watched_addresses"
err := sdi.dbWriter.db.Select(sdi.ctx, &addressStrings, pgStr)
if err != nil {
return nil, fmt.Errorf("error loading watched addresses: %v", err)
}
watchedAddresses := []common.Address{}
for _, addressString := range addressStrings {
watchedAddresses = append(watchedAddresses, common.HexToAddress(addressString))
}
return watchedAddresses, nil
}
// InsertWatchedAddresses inserts the given addresses in the database
func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error {
tx, err := sdi.dbWriter.db.Begin(sdi.ctx)
if err != nil {
return err
}
defer func() {
if p := recover(); p != nil {
rollback(sdi.ctx, tx)
panic(p)
} else if err != nil {
rollback(sdi.ctx, tx)
} else {
err = tx.Commit(sdi.ctx)
}
}()
for _, arg := range args {
_, err = tx.Exec(sdi.ctx, `INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`,
arg.Address, arg.CreatedAt, currentBlockNumber.Uint64())
if err != nil {
return fmt.Errorf("error inserting watched_addresses entry: %v", err)
}
}
return err
}
// RemoveWatchedAddresses removes the given watched addresses from the database
func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressArg) error {
tx, err := sdi.dbWriter.db.Begin(sdi.ctx)
if err != nil {
return err
}
defer func() {
if p := recover(); p != nil {
rollback(sdi.ctx, tx)
panic(p)
} else if err != nil {
rollback(sdi.ctx, tx)
} else {
err = tx.Commit(sdi.ctx)
}
}()
for _, arg := range args {
_, err = tx.Exec(sdi.ctx, `DELETE FROM eth_meta.watched_addresses WHERE address = $1`, arg.Address)
if err != nil {
return fmt.Errorf("error removing watched_addresses entry: %v", err)
}
}
return err
}
// SetWatchedAddresses clears and inserts the given addresses in the database
func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error {
tx, err := sdi.dbWriter.db.Begin(sdi.ctx)
if err != nil {
return err
}
defer func() {
if p := recover(); p != nil {
rollback(sdi.ctx, tx)
panic(p)
} else if err != nil {
rollback(sdi.ctx, tx)
} else {
err = tx.Commit(sdi.ctx)
}
}()
_, err = tx.Exec(sdi.ctx, `DELETE FROM eth_meta.watched_addresses`)
if err != nil {
return fmt.Errorf("error setting watched_addresses table: %v", err)
}
for _, arg := range args {
_, err = tx.Exec(sdi.ctx, `INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`,
arg.Address, arg.CreatedAt, currentBlockNumber.Uint64())
if err != nil {
return fmt.Errorf("error setting watched_addresses table: %v", err)
}
}
return err
}
// ClearWatchedAddresses clears all the watched addresses from the database
func (sdi *StateDiffIndexer) ClearWatchedAddresses() error {
_, err := sdi.dbWriter.db.Exec(sdi.ctx, `DELETE FROM eth_meta.watched_addresses`)
if err != nil {
return fmt.Errorf("error clearing watched_addresses table: %v", err)
}
return nil
}
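As a rough usage sketch (not part of this commit's changes), the new methods can be driven through the interfaces.StateDiffIndexer interface regardless of which backend is configured; exampleWatchAddressUsage is a hypothetical helper, the addresses are the test fixtures used elsewhere in this diff, and the geth log, math/big, sdtypes and interfaces imports are assumed:

func exampleWatchAddressUsage(ind interfaces.StateDiffIndexer) error {
	args := []sdtypes.WatchAddressArg{
		{Address: "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F", CreatedAt: 1},
		{Address: "0x6Eb7e5C66DB8af2E96159AC440cbc8CDB7fbD26B", CreatedAt: 2},
	}
	// Record the addresses as watched from block 10 onwards; addresses that
	// are already watched are left untouched by the ON CONFLICT clause.
	if err := ind.InsertWatchedAddresses(args, big.NewInt(10)); err != nil {
		return err
	}
	// Read the full watch list back as common.Address values.
	watched, err := ind.LoadWatchedAddresses()
	if err != nil {
		return err
	}
	log.Info("currently watched addresses", "count", len(watched))
	return nil
}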

View File

@ -30,6 +30,9 @@ var (
rct1CID, rct2CID, rct3CID, rct4CID, rct5CID cid.Cid rct1CID, rct2CID, rct3CID, rct4CID, rct5CID cid.Cid
rctLeaf1, rctLeaf2, rctLeaf3, rctLeaf4, rctLeaf5 []byte rctLeaf1, rctLeaf2, rctLeaf3, rctLeaf4, rctLeaf5 []byte
state1CID, state2CID, storageCID cid.Cid state1CID, state2CID, storageCID cid.Cid
contract1Address, contract2Address, contract3Address, contract4Address string
contract1CreatedAt, contract2CreatedAt, contract3CreatedAt, contract4CreatedAt uint64
lastFilledAt, watchedAt1, watchedAt2, watchedAt3 uint64
) )
func init() { func init() {
@ -134,6 +137,20 @@ func init() {
rctLeaf3 = orderedRctLeafNodes[2] rctLeaf3 = orderedRctLeafNodes[2]
rctLeaf4 = orderedRctLeafNodes[3] rctLeaf4 = orderedRctLeafNodes[3]
rctLeaf5 = orderedRctLeafNodes[4] rctLeaf5 = orderedRctLeafNodes[4]
contract1Address = "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F"
contract2Address = "0x6Eb7e5C66DB8af2E96159AC440cbc8CDB7fbD26B"
contract3Address = "0xcfeB164C328CA13EFd3C77E1980d94975aDfedfc"
contract4Address = "0x0Edf0c4f393a628DE4828B228C48175b3EA297fc"
contract1CreatedAt = uint64(1)
contract2CreatedAt = uint64(2)
contract3CreatedAt = uint64(3)
contract4CreatedAt = uint64(4)
lastFilledAt = uint64(0)
watchedAt1 = uint64(10)
watchedAt2 = uint64(15)
watchedAt3 = uint64(20)
} }
func expectTrue(t *testing.T, value bool) { func expectTrue(t *testing.T, value bool) {

View File

@ -18,6 +18,7 @@ package sql_test
import ( import (
"context" "context"
"math/big"
"testing" "testing"
"github.com/ipfs/go-cid" "github.com/ipfs/go-cid"
@ -35,15 +36,20 @@ import (
"github.com/ethereum/go-ethereum/statediff/indexer/models" "github.com/ethereum/go-ethereum/statediff/indexer/models"
"github.com/ethereum/go-ethereum/statediff/indexer/shared" "github.com/ethereum/go-ethereum/statediff/indexer/shared"
"github.com/ethereum/go-ethereum/statediff/indexer/test_helpers" "github.com/ethereum/go-ethereum/statediff/indexer/test_helpers"
sdtypes "github.com/ethereum/go-ethereum/statediff/types"
) )
func setupPGX(t *testing.T) { func setupPGXIndexer(t *testing.T) {
db, err = postgres.SetupPGXDB() db, err = postgres.SetupPGXDB()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ind, err = sql.NewStateDiffIndexer(context.Background(), mocks.TestConfig, db) ind, err = sql.NewStateDiffIndexer(context.Background(), mocks.TestConfig, db)
require.NoError(t, err) require.NoError(t, err)
}
func setupPGX(t *testing.T) {
setupPGXIndexer(t)
var tx interfaces.Batch var tx interfaces.Batch
tx, err = ind.PushBlock( tx, err = ind.PushBlock(
mockBlock, mockBlock,
@ -557,3 +563,334 @@ func TestPGXIndexer(t *testing.T) {
test_helpers.ExpectEqual(t, data, []byte{}) test_helpers.ExpectEqual(t, data, []byte{})
}) })
} }
func TestPGXWatchAddressMethods(t *testing.T) {
setupPGXIndexer(t)
defer tearDown(t)
defer checkTxClosure(t, 1, 0, 1)
type res struct {
Address string `db:"address"`
CreatedAt uint64 `db:"created_at"`
WatchedAt uint64 `db:"watched_at"`
LastFilledAt uint64 `db:"last_filled_at"`
}
pgStr := "SELECT * FROM eth_meta.watched_addresses"
t.Run("Load watched addresses (empty table)", func(t *testing.T) {
expectedData := []common.Address{}
rows, err := ind.LoadWatchedAddresses()
require.NoError(t, err)
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Insert watched addresses", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
expectedData := []res{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
WatchedAt: watchedAt1,
LastFilledAt: lastFilledAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
WatchedAt: watchedAt1,
LastFilledAt: lastFilledAt,
},
}
err = ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt1)))
require.NoError(t, err)
rows := []res{}
err = db.Select(context.Background(), &rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Insert watched addresses (some already watched)", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
expectedData := []res{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
WatchedAt: watchedAt1,
LastFilledAt: lastFilledAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
WatchedAt: watchedAt1,
LastFilledAt: lastFilledAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
WatchedAt: watchedAt2,
LastFilledAt: lastFilledAt,
},
}
err = ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt2)))
require.NoError(t, err)
rows := []res{}
err = db.Select(context.Background(), &rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Remove watched addresses", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
expectedData := []res{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
WatchedAt: watchedAt1,
LastFilledAt: lastFilledAt,
},
}
err = ind.RemoveWatchedAddresses(args)
require.NoError(t, err)
rows := []res{}
err = db.Select(context.Background(), &rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Remove watched addresses (some non-watched)", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
expectedData := []res{}
err = ind.RemoveWatchedAddresses(args)
require.NoError(t, err)
rows := []res{}
err = db.Select(context.Background(), &rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Set watched addresses", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
}
expectedData := []res{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
WatchedAt: watchedAt2,
LastFilledAt: lastFilledAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
WatchedAt: watchedAt2,
LastFilledAt: lastFilledAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
WatchedAt: watchedAt2,
LastFilledAt: lastFilledAt,
},
}
err = ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt2)))
require.NoError(t, err)
rows := []res{}
err = db.Select(context.Background(), &rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Set watched addresses (some already watched)", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract4Address,
CreatedAt: contract4CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
}
expectedData := []res{
{
Address: contract4Address,
CreatedAt: contract4CreatedAt,
WatchedAt: watchedAt3,
LastFilledAt: lastFilledAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
WatchedAt: watchedAt3,
LastFilledAt: lastFilledAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
WatchedAt: watchedAt3,
LastFilledAt: lastFilledAt,
},
}
err = ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt3)))
require.NoError(t, err)
rows := []res{}
err = db.Select(context.Background(), &rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Load watched addresses", func(t *testing.T) {
expectedData := []common.Address{
common.HexToAddress(contract4Address),
common.HexToAddress(contract2Address),
common.HexToAddress(contract3Address),
}
rows, err := ind.LoadWatchedAddresses()
require.NoError(t, err)
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Clear watched addresses", func(t *testing.T) {
expectedData := []res{}
err = ind.ClearWatchedAddresses()
require.NoError(t, err)
rows := []res{}
err = db.Select(context.Background(), &rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Clear watched addresses (empty table)", func(t *testing.T) {
expectedData := []res{}
err = ind.ClearWatchedAddresses()
require.NoError(t, err)
rows := []res{}
err = db.Select(context.Background(), &rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
}

View File

@ -18,6 +18,7 @@ package sql_test
import ( import (
"context" "context"
"math/big"
"testing" "testing"
"github.com/ipfs/go-cid" "github.com/ipfs/go-cid"
@ -36,15 +37,20 @@ import (
"github.com/ethereum/go-ethereum/statediff/indexer/models" "github.com/ethereum/go-ethereum/statediff/indexer/models"
"github.com/ethereum/go-ethereum/statediff/indexer/shared" "github.com/ethereum/go-ethereum/statediff/indexer/shared"
"github.com/ethereum/go-ethereum/statediff/indexer/test_helpers" "github.com/ethereum/go-ethereum/statediff/indexer/test_helpers"
sdtypes "github.com/ethereum/go-ethereum/statediff/types"
) )
func setupSQLX(t *testing.T) { func setupSQLXIndexer(t *testing.T) {
db, err = postgres.SetupSQLXDB() db, err = postgres.SetupSQLXDB()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ind, err = sql.NewStateDiffIndexer(context.Background(), mocks.TestConfig, db) ind, err = sql.NewStateDiffIndexer(context.Background(), mocks.TestConfig, db)
require.NoError(t, err) require.NoError(t, err)
}
func setupSQLX(t *testing.T) {
setupSQLXIndexer(t)
var tx interfaces.Batch var tx interfaces.Batch
tx, err = ind.PushBlock( tx, err = ind.PushBlock(
mockBlock, mockBlock,
@ -550,3 +556,334 @@ func TestSQLXIndexer(t *testing.T) {
test_helpers.ExpectEqual(t, data, []byte{}) test_helpers.ExpectEqual(t, data, []byte{})
}) })
} }
func TestSQLXWatchAddressMethods(t *testing.T) {
setupSQLXIndexer(t)
defer tearDown(t)
defer checkTxClosure(t, 0, 0, 0)
type res struct {
Address string `db:"address"`
CreatedAt uint64 `db:"created_at"`
WatchedAt uint64 `db:"watched_at"`
LastFilledAt uint64 `db:"last_filled_at"`
}
pgStr := "SELECT * FROM eth_meta.watched_addresses"
t.Run("Load watched addresses (empty table)", func(t *testing.T) {
expectedData := []common.Address{}
rows, err := ind.LoadWatchedAddresses()
require.NoError(t, err)
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Insert watched addresses", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
expectedData := []res{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
WatchedAt: watchedAt1,
LastFilledAt: lastFilledAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
WatchedAt: watchedAt1,
LastFilledAt: lastFilledAt,
},
}
err = ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt1)))
require.NoError(t, err)
rows := []res{}
err = db.Select(context.Background(), &rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Insert watched addresses (some already watched)", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
expectedData := []res{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
WatchedAt: watchedAt1,
LastFilledAt: lastFilledAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
WatchedAt: watchedAt1,
LastFilledAt: lastFilledAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
WatchedAt: watchedAt2,
LastFilledAt: lastFilledAt,
},
}
err = ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt2)))
require.NoError(t, err)
rows := []res{}
err = db.Select(context.Background(), &rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Remove watched addresses", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
expectedData := []res{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
WatchedAt: watchedAt1,
LastFilledAt: lastFilledAt,
},
}
err = ind.RemoveWatchedAddresses(args)
require.NoError(t, err)
rows := []res{}
err = db.Select(context.Background(), &rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Remove watched addresses (some non-watched)", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
expectedData := []res{}
err = ind.RemoveWatchedAddresses(args)
require.NoError(t, err)
rows := []res{}
err = db.Select(context.Background(), &rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Set watched addresses", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
}
expectedData := []res{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
WatchedAt: watchedAt2,
LastFilledAt: lastFilledAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
WatchedAt: watchedAt2,
LastFilledAt: lastFilledAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
WatchedAt: watchedAt2,
LastFilledAt: lastFilledAt,
},
}
err = ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt2)))
require.NoError(t, err)
rows := []res{}
err = db.Select(context.Background(), &rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Set watched addresses (some already watched)", func(t *testing.T) {
args := []sdtypes.WatchAddressArg{
{
Address: contract4Address,
CreatedAt: contract4CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
}
expectedData := []res{
{
Address: contract4Address,
CreatedAt: contract4CreatedAt,
WatchedAt: watchedAt3,
LastFilledAt: lastFilledAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
WatchedAt: watchedAt3,
LastFilledAt: lastFilledAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
WatchedAt: watchedAt3,
LastFilledAt: lastFilledAt,
},
}
err = ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt3)))
require.NoError(t, err)
rows := []res{}
err = db.Select(context.Background(), &rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Load watched addresses", func(t *testing.T) {
expectedData := []common.Address{
common.HexToAddress(contract4Address),
common.HexToAddress(contract2Address),
common.HexToAddress(contract3Address),
}
rows, err := ind.LoadWatchedAddresses()
require.NoError(t, err)
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Clear watched addresses", func(t *testing.T) {
expectedData := []res{}
err = ind.ClearWatchedAddresses()
require.NoError(t, err)
rows := []res{}
err = db.Select(context.Background(), &rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
t.Run("Clear watched addresses (empty table)", func(t *testing.T) {
expectedData := []res{}
err = ind.ClearWatchedAddresses()
require.NoError(t, err)
rows := []res{}
err = db.Select(context.Background(), &rows, pgStr)
if err != nil {
t.Fatal(err)
}
expectTrue(t, len(rows) == len(expectedData))
for idx, row := range rows {
test_helpers.ExpectEqual(t, row, expectedData[idx])
}
})
}

View File

@ -73,6 +73,10 @@ func TearDownDB(t *testing.T, db Database) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, err = tx.Exec(ctx, `DELETE FROM eth_meta.watched_addresses`)
if err != nil {
t.Fatal(err)
}
err = tx.Commit(ctx) err = tx.Commit(ctx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -21,6 +21,7 @@ import (
"math/big" "math/big"
"time" "time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/statediff/indexer/shared" "github.com/ethereum/go-ethereum/statediff/indexer/shared"
sdtypes "github.com/ethereum/go-ethereum/statediff/types" sdtypes "github.com/ethereum/go-ethereum/statediff/types"
@ -32,6 +33,14 @@ type StateDiffIndexer interface {
PushStateNode(tx Batch, stateNode sdtypes.StateNode, headerID string) error PushStateNode(tx Batch, stateNode sdtypes.StateNode, headerID string) error
PushCodeAndCodeHash(tx Batch, codeAndCodeHash sdtypes.CodeAndCodeHash) error PushCodeAndCodeHash(tx Batch, codeAndCodeHash sdtypes.CodeAndCodeHash) error
ReportDBMetrics(delay time.Duration, quit <-chan bool) ReportDBMetrics(delay time.Duration, quit <-chan bool)
// Methods used by WatchAddress API/functionality
LoadWatchedAddresses() ([]common.Address, error)
InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error
RemoveWatchedAddresses(addresses []sdtypes.WatchAddressArg) error
SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error
ClearWatchedAddresses() error
io.Closer io.Closer
} }

View File

@ -18,6 +18,7 @@ package statediff
import ( import (
"bytes" "bytes"
"fmt"
"math/big" "math/big"
"strconv" "strconv"
"strings" "strings"
@ -47,6 +48,7 @@ import (
"github.com/ethereum/go-ethereum/statediff/indexer/shared" "github.com/ethereum/go-ethereum/statediff/indexer/shared"
types2 "github.com/ethereum/go-ethereum/statediff/types" types2 "github.com/ethereum/go-ethereum/statediff/types"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
"github.com/thoas/go-funk"
) )
const ( const (
@ -54,21 +56,26 @@ const (
genesisBlockNumber = 0 genesisBlockNumber = 0
defaultRetryLimit = 3 // default retry limit once deadlock is detected. defaultRetryLimit = 3 // default retry limit once deadlock is detected.
deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html
typeAssertionFailed = "type assertion failed"
unexpectedOperation = "unexpected operation"
) )
var writeLoopParams = Params{ var writeLoopParams = ParamsWithMutex{
Params: Params{
IntermediateStateNodes: true, IntermediateStateNodes: true,
IntermediateStorageNodes: true, IntermediateStorageNodes: true,
IncludeBlock: true, IncludeBlock: true,
IncludeReceipts: true, IncludeReceipts: true,
IncludeTD: true, IncludeTD: true,
IncludeCode: true, IncludeCode: true,
},
} }
var statediffMetrics = RegisterStatediffMetrics(metrics.DefaultRegistry) var statediffMetrics = RegisterStatediffMetrics(metrics.DefaultRegistry)
type blockChain interface { type blockChain interface {
SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription
CurrentBlock() *types.Block
GetBlockByHash(hash common.Hash) *types.Block GetBlockByHash(hash common.Hash) *types.Block
GetBlockByNumber(number uint64) *types.Block GetBlockByNumber(number uint64) *types.Block
GetReceiptsByHash(hash common.Hash) types.Receipts GetReceiptsByHash(hash common.Hash) types.Receipts
@ -103,6 +110,8 @@ type IService interface {
WriteStateDiffFor(blockHash common.Hash, params Params) error WriteStateDiffFor(blockHash common.Hash, params Params) error
// WriteLoop event loop for progressively processing and writing diffs directly to DB // WriteLoop event loop for progressively processing and writing diffs directly to DB
WriteLoop(chainEventCh chan core.ChainEvent) WriteLoop(chainEventCh chan core.ChainEvent)
// Method to change the addresses being watched in write loop params
WatchAddress(operation types2.OperationType, args []types2.WatchAddressArg) error
} }
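The RPC-facing wiring for this method lives in other files of the commit, outside this excerpt. Purely as a hedged illustration, if the method is exposed through the existing statediff namespace (the method name and endpoint below are assumptions, not confirmed here), a client call could look roughly like:

package main

import (
	"log"

	"github.com/ethereum/go-ethereum/rpc"
	sdtypes "github.com/ethereum/go-ethereum/statediff/types"
)

func main() {
	// Assumed endpoint and RPC method name; both are illustrative only.
	client, err := rpc.Dial("ws://localhost:8546")
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	args := []sdtypes.WatchAddressArg{
		{Address: "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F", CreatedAt: 1},
	}
	// Ask the write loop to start watching the given address.
	if err := client.Call(nil, "statediff_watchAddress", sdtypes.Add, args); err != nil {
		log.Fatal(err)
	}
}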
// Service is the underlying struct for the state diffing service // Service is the underlying struct for the state diffing service
@ -159,6 +168,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params
blockChain := ethServ.BlockChain() blockChain := ethServ.BlockChain()
var indexer interfaces.StateDiffIndexer var indexer interfaces.StateDiffIndexer
var db sql.Database var db sql.Database
var err error
quitCh := make(chan bool) quitCh := make(chan bool)
if params.IndexerConfig != nil { if params.IndexerConfig != nil {
info := nodeinfo.Info{ info := nodeinfo.Info{
@ -215,6 +225,12 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params
} }
stack.RegisterLifecycle(sds) stack.RegisterLifecycle(sds)
stack.RegisterAPIs(sds.APIs()) stack.RegisterAPIs(sds.APIs())
err = loadWatchedAddresses(indexer)
if err != nil {
return err
}
return nil return nil
} }
@ -304,7 +320,9 @@ func (sds *Service) WriteLoop(chainEventCh chan core.ChainEvent) {
func (sds *Service) writeGenesisStateDiff(currBlock *types.Block, workerId uint) { func (sds *Service) writeGenesisStateDiff(currBlock *types.Block, workerId uint) {
// For genesis block we need to return the entire state trie hence we diff it with an empty trie. // For genesis block we need to return the entire state trie hence we diff it with an empty trie.
log.Info("Writing state diff", "block height", genesisBlockNumber, "worker", workerId) log.Info("Writing state diff", "block height", genesisBlockNumber, "worker", workerId)
err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams) writeLoopParams.RLock()
err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams.Params)
writeLoopParams.RUnlock()
if err != nil { if err != nil {
log.Error("statediff.Service.WriteLoop: processing error", "block height", log.Error("statediff.Service.WriteLoop: processing error", "block height",
genesisBlockNumber, "error", err.Error(), "worker", workerId) genesisBlockNumber, "error", err.Error(), "worker", workerId)
@ -341,7 +359,9 @@ func (sds *Service) writeLoopWorker(params workerParams) {
} }
log.Info("Writing state diff", "block height", currentBlock.Number().Uint64(), "worker", params.id) log.Info("Writing state diff", "block height", currentBlock.Number().Uint64(), "worker", params.id)
err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams) writeLoopParams.RLock()
err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams.Params)
writeLoopParams.RUnlock()
if err != nil { if err != nil {
log.Error("statediff.Service.WriteLoop: processing error", "block height", currentBlock.Number().Uint64(), "error", err.Error(), "worker", params.id) log.Error("statediff.Service.WriteLoop: processing error", "block height", currentBlock.Number().Uint64(), "error", err.Error(), "worker", params.id)
sds.KnownGaps.errorState = true sds.KnownGaps.errorState = true
@ -456,6 +476,10 @@ func (sds *Service) streamStateDiff(currentBlock *types.Block, parentRoot common
func (sds *Service) StateDiffAt(blockNumber uint64, params Params) (*Payload, error) { func (sds *Service) StateDiffAt(blockNumber uint64, params Params) (*Payload, error) {
currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber)
log.Info("sending state diff", "block height", blockNumber) log.Info("sending state diff", "block height", blockNumber)
// compute leaf keys of watched addresses in the params
params.ComputeWatchedAddressesLeafKeys()
if blockNumber == 0 { if blockNumber == 0 {
return sds.processStateDiff(currentBlock, common.Hash{}, params) return sds.processStateDiff(currentBlock, common.Hash{}, params)
} }
@ -468,6 +492,10 @@ func (sds *Service) StateDiffAt(blockNumber uint64, params Params) (*Payload, er
func (sds *Service) StateDiffFor(blockHash common.Hash, params Params) (*Payload, error) { func (sds *Service) StateDiffFor(blockHash common.Hash, params Params) (*Payload, error) {
currentBlock := sds.BlockChain.GetBlockByHash(blockHash) currentBlock := sds.BlockChain.GetBlockByHash(blockHash)
log.Info("sending state diff", "block hash", blockHash) log.Info("sending state diff", "block hash", blockHash)
// compute leaf keys of watched addresses in the params
params.ComputeWatchedAddressesLeafKeys()
if currentBlock.NumberU64() == 0 { if currentBlock.NumberU64() == 0 {
return sds.processStateDiff(currentBlock, common.Hash{}, params) return sds.processStateDiff(currentBlock, common.Hash{}, params)
} }
@ -526,6 +554,10 @@ func (sds *Service) newPayload(stateObject []byte, block *types.Block, params Pa
func (sds *Service) StateTrieAt(blockNumber uint64, params Params) (*Payload, error) { func (sds *Service) StateTrieAt(blockNumber uint64, params Params) (*Payload, error) {
currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber)
log.Info("sending state trie", "block height", blockNumber) log.Info("sending state trie", "block height", blockNumber)
// compute leaf keys of watched addresses in the params
params.ComputeWatchedAddressesLeafKeys()
return sds.processStateTrie(currentBlock, params) return sds.processStateTrie(currentBlock, params)
} }
@ -548,6 +580,10 @@ func (sds *Service) Subscribe(id rpc.ID, sub chan<- Payload, quitChan chan<- boo
if atomic.CompareAndSwapInt32(&sds.subscribers, 0, 1) { if atomic.CompareAndSwapInt32(&sds.subscribers, 0, 1) {
log.Info("State diffing subscription received; beginning statediff processing") log.Info("State diffing subscription received; beginning statediff processing")
} }
// compute leaf keys of watched addresses in the params
params.ComputeWatchedAddressesLeafKeys()
// Subscription type is defined as the hash of the rlp-serialized subscription params // Subscription type is defined as the hash of the rlp-serialized subscription params
by, err := rlp.EncodeToBytes(params) by, err := rlp.EncodeToBytes(params)
if err != nil { if err != nil {
@ -644,7 +680,7 @@ func (sds *Service) Start() error {
go sds.Loop(chainEventCh) go sds.Loop(chainEventCh)
if sds.enableWriteLoop { if sds.enableWriteLoop {
log.Info("Starting statediff DB write loop", "params", writeLoopParams) log.Info("Starting statediff DB write loop", "params", writeLoopParams.Params)
chainEventCh := make(chan core.ChainEvent, chainEventChanSize) chainEventCh := make(chan core.ChainEvent, chainEventChanSize)
go sds.WriteLoop(chainEventCh) go sds.WriteLoop(chainEventCh)
} }
@ -741,6 +777,9 @@ func (sds *Service) StreamCodeAndCodeHash(blockNumber uint64, outChan chan<- typ
// This operation cannot be performed back past the point of db pruning; it requires an archival node // This operation cannot be performed back past the point of db pruning; it requires an archival node
// for historical data // for historical data
func (sds *Service) WriteStateDiffAt(blockNumber uint64, params Params) error { func (sds *Service) WriteStateDiffAt(blockNumber uint64, params Params) error {
// compute leaf keys of watched addresses in the params
params.ComputeWatchedAddressesLeafKeys()
currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber)
parentRoot := common.Hash{} parentRoot := common.Hash{}
if blockNumber != 0 { if blockNumber != 0 {
@ -754,6 +793,9 @@ func (sds *Service) WriteStateDiffAt(blockNumber uint64, params Params) error {
// This operation cannot be performed back past the point of db pruning; it requires an archival node // This operation cannot be performed back past the point of db pruning; it requires an archival node
// for historical data // for historical data
func (sds *Service) WriteStateDiffFor(blockHash common.Hash, params Params) error { func (sds *Service) WriteStateDiffFor(blockHash common.Hash, params Params) error {
// compute leaf keys of watched addresses in the params
params.ComputeWatchedAddressesLeafKeys()
currentBlock := sds.BlockChain.GetBlockByHash(blockHash) currentBlock := sds.BlockChain.GetBlockByHash(blockHash)
parentRoot := common.Hash{} parentRoot := common.Hash{}
if currentBlock.NumberU64() != 0 { if currentBlock.NumberU64() != 0 {
@ -821,3 +863,130 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo
} }
return err return err
} }
// WatchAddress performs one of the following operations on the watched addresses in writeLoopParams and the db:
// add | remove | set | clear
func (sds *Service) WatchAddress(operation types2.OperationType, args []types2.WatchAddressArg) error {
// lock writeLoopParams for a write
writeLoopParams.Lock()
defer writeLoopParams.Unlock()
// get the current block number
currentBlockNumber := sds.BlockChain.CurrentBlock().Number()
switch operation {
case types2.Add:
// filter out args whose address is already being watched, logging a warning
filteredArgs, ok := funk.Filter(args, func(arg types2.WatchAddressArg) bool {
if funk.Contains(writeLoopParams.WatchedAddresses, common.HexToAddress(arg.Address)) {
log.Warn("Address already being watched", "address", arg.Address)
return false
}
return true
}).([]types2.WatchAddressArg)
if !ok {
return fmt.Errorf("add: filtered args %s", typeAssertionFailed)
}
// get addresses from the filtered args
filteredAddresses, err := MapWatchAddressArgsToAddresses(filteredArgs)
if err != nil {
return fmt.Errorf("add: filtered addresses %s", err.Error())
}
// update the db
err = sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber)
if err != nil {
return err
}
// update in-memory params
writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...)
funk.ForEach(filteredAddresses, func(address common.Address) {
writeLoopParams.watchedAddressesLeafKeys[crypto.Keccak256Hash(address.Bytes())] = struct{}{}
})
case types2.Remove:
// get addresses from args
argAddresses, err := MapWatchAddressArgsToAddresses(args)
if err != nil {
return fmt.Errorf("remove: mapped addresses %s", err.Error())
}
// remove the provided addresses from currently watched addresses
addresses, ok := funk.Subtract(writeLoopParams.WatchedAddresses, argAddresses).([]common.Address)
if !ok {
return fmt.Errorf("remove: filtered addresses %s", typeAssertionFailed)
}
// update the db
err = sds.indexer.RemoveWatchedAddresses(args)
if err != nil {
return err
}
// update in-memory params
writeLoopParams.WatchedAddresses = addresses
funk.ForEach(argAddresses, func(address common.Address) {
delete(writeLoopParams.watchedAddressesLeafKeys, crypto.Keccak256Hash(address.Bytes()))
})
case types2.Set:
// get addresses from args
argAddresses, err := MapWatchAddressArgsToAddresses(args)
if err != nil {
return fmt.Errorf("set: mapped addresses %s", err.Error())
}
// update the db
err = sds.indexer.SetWatchedAddresses(args, currentBlockNumber)
if err != nil {
return err
}
// update in-memory params
writeLoopParams.WatchedAddresses = argAddresses
writeLoopParams.ComputeWatchedAddressesLeafKeys()
case types2.Clear:
// update the db
err := sds.indexer.ClearWatchedAddresses()
if err != nil {
return err
}
// update in-memory params
writeLoopParams.WatchedAddresses = []common.Address{}
writeLoopParams.ComputeWatchedAddressesLeafKeys()
default:
return fmt.Errorf("%s %s", unexpectedOperation, operation)
}
return nil
}
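In-process, the same flow is reachable directly on the service; a small sketch of the four operations, where sds is assumed to be an initialized *Service and types2 is the alias for statediff/types used in this file (errors elided for brevity, and not part of this commit's changes):

// Add two addresses; ones already being watched are skipped with a warning.
_ = sds.WatchAddress(types2.Add, []types2.WatchAddressArg{
	{Address: "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F", CreatedAt: 1},
	{Address: "0x6Eb7e5C66DB8af2E96159AC440cbc8CDB7fbD26B", CreatedAt: 2},
})
// Remove one of them again.
_ = sds.WatchAddress(types2.Remove, []types2.WatchAddressArg{
	{Address: "0x6Eb7e5C66DB8af2E96159AC440cbc8CDB7fbD26B", CreatedAt: 2},
})
// Replace the whole watch list, then clear it.
_ = sds.WatchAddress(types2.Set, []types2.WatchAddressArg{
	{Address: "0xcfeB164C328CA13EFd3C77E1980d94975aDfedfc", CreatedAt: 3},
})
_ = sds.WatchAddress(types2.Clear, nil)

Each branch updates both the database (through the indexer) and the in-memory writeLoopParams under its mutex, so the write loop picks up the change on the next block.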
// loadWatchedAddresses loads watched addresses into the in-memory write loop params
func loadWatchedAddresses(indexer interfaces.StateDiffIndexer) error {
watchedAddresses, err := indexer.LoadWatchedAddresses()
if err != nil {
return err
}
writeLoopParams.Lock()
defer writeLoopParams.Unlock()
writeLoopParams.WatchedAddresses = watchedAddresses
writeLoopParams.ComputeWatchedAddressesLeafKeys()
return nil
}
// MapWatchAddressArgsToAddresses maps []WatchAddressArg to corresponding []common.Address
func MapWatchAddressArgsToAddresses(args []types2.WatchAddressArg) ([]common.Address, error) {
addresses, ok := funk.Map(args, func(arg types2.WatchAddressArg) common.Address {
return common.HexToAddress(arg.Address)
}).([]common.Address)
if !ok {
return nil, fmt.Errorf(typeAssertionFailed)
}
return addresses, nil
}

View File

@ -146,6 +146,7 @@ func testErrorInChainEventLoop(t *testing.T) {
} }
} }
defaultParams.ComputeWatchedAddressesLeafKeys()
if !reflect.DeepEqual(builder.Params, defaultParams) { if !reflect.DeepEqual(builder.Params, defaultParams) {
t.Error("Test failure:", t.Name()) t.Error("Test failure:", t.Name())
t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams)
@ -197,6 +198,8 @@ func testErrorInBlockLoop(t *testing.T) {
} }
}() }()
service.Loop(eventsChannel) service.Loop(eventsChannel)
defaultParams.ComputeWatchedAddressesLeafKeys()
if !reflect.DeepEqual(builder.Params, defaultParams) { if !reflect.DeepEqual(builder.Params, defaultParams) {
t.Error("Test failure:", t.Name()) t.Error("Test failure:", t.Name())
t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams)
@ -270,6 +273,8 @@ func testErrorInStateDiffAt(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
defaultParams.ComputeWatchedAddressesLeafKeys()
if !reflect.DeepEqual(builder.Params, defaultParams) { if !reflect.DeepEqual(builder.Params, defaultParams) {
t.Error("Test failure:", t.Name()) t.Error("Test failure:", t.Name())
t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams)

View File

@ -39,6 +39,7 @@ type BlockChain struct {
Receipts map[common.Hash]types.Receipts Receipts map[common.Hash]types.Receipts
TDByHash map[common.Hash]*big.Int TDByHash map[common.Hash]*big.Int
TDByNum map[uint64]*big.Int TDByNum map[uint64]*big.Int
currentBlock *types.Block
} }
// SetBlocksForHashes mock method // SetBlocksForHashes mock method
@ -128,6 +129,16 @@ func (bc *BlockChain) GetTd(hash common.Hash, blockNum uint64) *big.Int {
return nil return nil
} }
// SetCurrentBlock test method
func (bc *BlockChain) SetCurrentBlock(block *types.Block) {
bc.currentBlock = block
}
// CurrentBlock mock method
func (bc *BlockChain) CurrentBlock() *types.Block {
return bc.currentBlock
}
func (bc *BlockChain) SetTd(hash common.Hash, blockNum uint64, td *big.Int) { func (bc *BlockChain) SetTd(hash common.Hash, blockNum uint64, td *big.Int) {
if bc.TDByHash == nil { if bc.TDByHash == nil {
bc.TDByHash = make(map[common.Hash]*big.Int) bc.TDByHash = make(map[common.Hash]*big.Int)

View File

@ -0,0 +1,70 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package mocks
import (
"math/big"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/statediff/indexer/interfaces"
sdtypes "github.com/ethereum/go-ethereum/statediff/types"
)
var _ interfaces.StateDiffIndexer = &StateDiffIndexer{}
// StateDiffIndexer is a mock state diff indexer
type StateDiffIndexer struct{}
func (sdi *StateDiffIndexer) PushBlock(block *types.Block, receipts types.Receipts, totalDifficulty *big.Int) (interfaces.Batch, error) {
return nil, nil
}
func (sdi *StateDiffIndexer) PushStateNode(tx interfaces.Batch, stateNode sdtypes.StateNode, headerID string) error {
return nil
}
func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx interfaces.Batch, codeAndCodeHash sdtypes.CodeAndCodeHash) error {
return nil
}
func (sdi *StateDiffIndexer) ReportDBMetrics(delay time.Duration, quit <-chan bool) {}
func (sdi *StateDiffIndexer) LoadWatchedAddresses() ([]common.Address, error) {
return nil, nil
}
func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error {
return nil
}
func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []sdtypes.WatchAddressArg) error {
return nil
}
func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error {
return nil
}
func (sdi *StateDiffIndexer) ClearWatchedAddresses() error {
return nil
}
func (sdi *StateDiffIndexer) Close() error {
return nil
}

View File

@ -25,6 +25,7 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/thoas/go-funk"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
@ -32,9 +33,15 @@ import (
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
"github.com/ethereum/go-ethereum/statediff" "github.com/ethereum/go-ethereum/statediff"
"github.com/ethereum/go-ethereum/statediff/indexer/interfaces"
sdtypes "github.com/ethereum/go-ethereum/statediff/types" sdtypes "github.com/ethereum/go-ethereum/statediff/types"
) )
var (
typeAssertionFailed = "type assertion failed"
unexpectedOperation = "unexpected operation"
)
// MockStateDiffService is a mock state diff service
type MockStateDiffService struct {
sync.Mutex
@ -47,6 +54,8 @@ type MockStateDiffService struct {
QuitChan chan bool
Subscriptions map[common.Hash]map[rpc.ID]statediff.Subscription
SubscriptionTypes map[common.Hash]statediff.Params
Indexer interfaces.StateDiffIndexer
writeLoopParams statediff.ParamsWithMutex
}
// Protocols mock method
@ -332,3 +341,98 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) {
log.Info("unable to close subscription %s; channel has no receiver", id)
}
}
// Performs one of the following operations on the watched addresses in writeLoopParams and the db:
// add | remove | set | clear
func (sds *MockStateDiffService) WatchAddress(operation sdtypes.OperationType, args []sdtypes.WatchAddressArg) error {
// lock writeLoopParams for a write
sds.writeLoopParams.Lock()
defer sds.writeLoopParams.Unlock()
// get the current block number
currentBlockNumber := sds.BlockChain.CurrentBlock().Number()
switch operation {
case sdtypes.Add:
// filter out (with a warning) any args whose address is already being watched
filteredArgs, ok := funk.Filter(args, func(arg sdtypes.WatchAddressArg) bool {
if funk.Contains(sds.writeLoopParams.WatchedAddresses, common.HexToAddress(arg.Address)) {
log.Warn("Address already being watched", "address", arg.Address)
return false
}
return true
}).([]sdtypes.WatchAddressArg)
if !ok {
return fmt.Errorf("add: filtered args %s", typeAssertionFailed)
}
// get addresses from the filtered args
filteredAddresses, err := statediff.MapWatchAddressArgsToAddresses(filteredArgs)
if err != nil {
return fmt.Errorf("add: filtered addresses %s", err.Error())
}
// update the db
err = sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber)
if err != nil {
return err
}
// update in-memory params
sds.writeLoopParams.WatchedAddresses = append(sds.writeLoopParams.WatchedAddresses, filteredAddresses...)
sds.writeLoopParams.ComputeWatchedAddressesLeafKeys()
case sdtypes.Remove:
// get addresses from args
argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args)
if err != nil {
return fmt.Errorf("remove: mapped addresses %s", err.Error())
}
// remove the provided addresses from currently watched addresses
addresses, ok := funk.Subtract(sds.writeLoopParams.WatchedAddresses, argAddresses).([]common.Address)
if !ok {
return fmt.Errorf("remove: filtered addresses %s", typeAssertionFailed)
}
// update the db
err = sds.Indexer.RemoveWatchedAddresses(args)
if err != nil {
return err
}
// update in-memory params
sds.writeLoopParams.WatchedAddresses = addresses
sds.writeLoopParams.ComputeWatchedAddressesLeafKeys()
case sdtypes.Set:
// get addresses from args
argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args)
if err != nil {
return fmt.Errorf("set: mapped addresses %s", err.Error())
}
// update the db
err = sds.Indexer.SetWatchedAddresses(args, currentBlockNumber)
if err != nil {
return err
}
// update in-memory params
sds.writeLoopParams.WatchedAddresses = argAddresses
sds.writeLoopParams.ComputeWatchedAddressesLeafKeys()
case sdtypes.Clear:
// update the db
err := sds.Indexer.ClearWatchedAddresses()
if err != nil {
return err
}
// update in-memory params
sds.writeLoopParams.WatchedAddresses = []common.Address{}
sds.writeLoopParams.ComputeWatchedAddressesLeafKeys()
default:
return fmt.Errorf("%s %s", unexpectedOperation, operation)
}
return nil
}
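A minimal sketch of exercising WatchAddress against these mocks, assuming it sits in the same mocks package; it mirrors the wiring used by testWatchAddressAPI later in this diff. The function name, block number, and address are illustrative only, and constructing the current block with types.NewBlockWithHeader is a shortcut for the test_helpers.MakeChain setup used in the real test:

func exampleWatchAddressWiring() error { // hypothetical helper, not part of this commit
	// give the mock chain a current block so WatchAddress can read its number
	chain := &BlockChain{}
	chain.SetCurrentBlock(types.NewBlockWithHeader(&types.Header{Number: big.NewInt(6)}))

	// wire in the no-op indexer so no database is needed
	service := MockStateDiffService{
		BlockChain: chain,
		Indexer:    &StateDiffIndexer{},
	}

	// add a single (arbitrary) address to the watched set
	return service.WatchAddress(sdtypes.Add, []sdtypes.WatchAddressArg{
		{Address: "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F", CreatedAt: 1},
	})
}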

View File

@ -21,6 +21,7 @@ import (
"fmt" "fmt"
"math/big" "math/big"
"os" "os"
"reflect"
"sort" "sort"
"sync" "sync"
"testing" "testing"
@ -88,6 +89,7 @@ func init() {
func TestAPI(t *testing.T) {
testSubscriptionAPI(t)
testHTTPAPI(t)
testWatchAddressAPI(t)
}
func testSubscriptionAPI(t *testing.T) {
@ -253,3 +255,286 @@ func testHTTPAPI(t *testing.T) {
t.Errorf("paylaod does not have the expected total difficulty\r\nactual td: %d\r\nexpected td: %d", payload.TotalDifficulty.Int64(), mockTotalDifficulty.Int64()) t.Errorf("paylaod does not have the expected total difficulty\r\nactual td: %d\r\nexpected td: %d", payload.TotalDifficulty.Int64(), mockTotalDifficulty.Int64())
} }
} }
func testWatchAddressAPI(t *testing.T) {
blocks, chain := test_helpers.MakeChain(6, test_helpers.Genesis, test_helpers.TestChainGen)
defer chain.Stop()
block6 := blocks[5]
mockBlockChain := &BlockChain{}
mockBlockChain.SetCurrentBlock(block6)
mockIndexer := StateDiffIndexer{}
mockService := MockStateDiffService{
BlockChain: mockBlockChain,
Indexer: &mockIndexer,
}
// test data
var (
contract1Address = "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F"
contract2Address = "0x6Eb7e5C66DB8af2E96159AC440cbc8CDB7fbD26B"
contract3Address = "0xcfeB164C328CA13EFd3C77E1980d94975aDfedfc"
contract4Address = "0x0Edf0c4f393a628DE4828B228C48175b3EA297fc"
contract1CreatedAt = uint64(1)
contract2CreatedAt = uint64(2)
contract3CreatedAt = uint64(3)
contract4CreatedAt = uint64(4)
args1 = []sdtypes.WatchAddressArg{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
startingParams1 = statediff.Params{
WatchedAddresses: []common.Address{},
}
expectedParams1 = statediff.Params{
WatchedAddresses: []common.Address{
common.HexToAddress(contract1Address),
common.HexToAddress(contract2Address),
},
}
args2 = []sdtypes.WatchAddressArg{
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
startingParams2 = expectedParams1
expectedParams2 = statediff.Params{
WatchedAddresses: []common.Address{
common.HexToAddress(contract1Address),
common.HexToAddress(contract2Address),
common.HexToAddress(contract3Address),
},
}
args3 = []sdtypes.WatchAddressArg{
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
startingParams3 = expectedParams2
expectedParams3 = statediff.Params{
WatchedAddresses: []common.Address{
common.HexToAddress(contract1Address),
},
}
args4 = []sdtypes.WatchAddressArg{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
}
startingParams4 = expectedParams3
expectedParams4 = statediff.Params{
WatchedAddresses: []common.Address{},
}
args5 = []sdtypes.WatchAddressArg{
{
Address: contract1Address,
CreatedAt: contract1CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
}
startingParams5 = expectedParams4
expectedParams5 = statediff.Params{
WatchedAddresses: []common.Address{
common.HexToAddress(contract1Address),
common.HexToAddress(contract2Address),
common.HexToAddress(contract3Address),
},
}
args6 = []sdtypes.WatchAddressArg{
{
Address: contract4Address,
CreatedAt: contract4CreatedAt,
},
{
Address: contract2Address,
CreatedAt: contract2CreatedAt,
},
{
Address: contract3Address,
CreatedAt: contract3CreatedAt,
},
}
startingParams6 = expectedParams5
expectedParams6 = statediff.Params{
WatchedAddresses: []common.Address{
common.HexToAddress(contract4Address),
common.HexToAddress(contract2Address),
common.HexToAddress(contract3Address),
},
}
args7 = []sdtypes.WatchAddressArg{}
startingParams7 = expectedParams6
expectedParams7 = statediff.Params{
WatchedAddresses: []common.Address{},
}
args8 = []sdtypes.WatchAddressArg{}
startingParams8 = expectedParams6
expectedParams8 = statediff.Params{
WatchedAddresses: []common.Address{},
}
args9 = []sdtypes.WatchAddressArg{}
startingParams9 = expectedParams8
expectedParams9 = statediff.Params{
WatchedAddresses: []common.Address{},
}
)
tests := []struct {
name string
operation sdtypes.OperationType
args []sdtypes.WatchAddressArg
startingParams statediff.Params
expectedParams statediff.Params
expectedErr error
}{
{
"testAddAddresses",
sdtypes.Add,
args1,
startingParams1,
expectedParams1,
nil,
},
{
"testAddAddressesSomeWatched",
sdtypes.Add,
args2,
startingParams2,
expectedParams2,
nil,
},
{
"testRemoveAddresses",
sdtypes.Remove,
args3,
startingParams3,
expectedParams3,
nil,
},
{
"testRemoveAddressesSomeWatched",
sdtypes.Remove,
args4,
startingParams4,
expectedParams4,
nil,
},
{
"testSetAddresses",
sdtypes.Set,
args5,
startingParams5,
expectedParams5,
nil,
},
{
"testSetAddressesSomeWatched",
sdtypes.Set,
args6,
startingParams6,
expectedParams6,
nil,
},
{
"testSetAddressesEmtpyArgs",
sdtypes.Set,
args7,
startingParams7,
expectedParams7,
nil,
},
{
"testClearAddresses",
sdtypes.Clear,
args8,
startingParams8,
expectedParams8,
nil,
},
{
"testClearAddressesEmpty",
sdtypes.Clear,
args9,
startingParams9,
expectedParams9,
nil,
},
// invalid args
{
"testInvalidOperation",
"WrongOp",
args9,
startingParams9,
statediff.Params{},
fmt.Errorf("%s WrongOp", unexpectedOperation),
},
}
for _, test := range tests {
// set indexing params
mockService.writeLoopParams = statediff.ParamsWithMutex{
Params: test.startingParams,
}
mockService.writeLoopParams.ComputeWatchedAddressesLeafKeys()
// make the API call to change watched addresses
err := mockService.WatchAddress(test.operation, test.args)
if test.expectedErr != nil {
if err.Error() != test.expectedErr.Error() {
t.Logf("Test failed: %s", test.name)
t.Errorf("actual err: %+v\nexpected err: %+v", err, test.expectedErr)
}
continue
}
if err != nil {
t.Error(err)
}
// check updated indexing params
test.expectedParams.ComputeWatchedAddressesLeafKeys()
updatedParams := mockService.writeLoopParams.Params
if !reflect.DeepEqual(updatedParams, test.expectedParams) {
t.Logf("Test failed: %s", test.name)
t.Errorf("actual params: %+v\nexpected params: %+v", updatedParams, test.expectedParams)
}
}
}

View File

@ -101,3 +101,20 @@ type CodeAndCodeHash struct {
type StateNodeSink func(StateNode) error
type StorageNodeSink func(StorageNode) error
type CodeSink func(CodeAndCodeHash) error
// OperationType is the type of operation performed by the WatchAddress API
type OperationType string
const (
Add OperationType = "add"
Remove OperationType = "remove"
Set OperationType = "set"
Clear OperationType = "clear"
)
// WatchAddressArg is an arg type for the WatchAddress API
type WatchAddressArg struct {
// Address represents common.Address
Address string
CreatedAt uint64
}
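For illustration only (not part of this diff), a small self-contained sketch of how these types look on the wire: with no json struct tags, the exported field names are marshaled verbatim, and the operation constants are plain strings. The address value is arbitrary.

package main

import (
	"encoding/json"
	"fmt"

	sdtypes "github.com/ethereum/go-ethereum/statediff/types"
)

func main() {
	// arbitrary example argument
	args := []sdtypes.WatchAddressArg{
		{Address: "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F", CreatedAt: 1},
	}
	b, err := json.Marshal(args)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b))
	// [{"Address":"0x5d663F5269090bD2A7DC2390c911dF6083D7b28F","CreatedAt":1}]
	fmt.Println(sdtypes.Add, sdtypes.Remove, sdtypes.Set, sdtypes.Clear)
	// add remove set clear
}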