Statediff API to change addresses being watched in direct indexing
This commit is contained in:
parent
98c52a02a8
commit
5fa002c0d0
@ -151,6 +151,6 @@ func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WatchAddress adds the given address to a list of watched addresses to which the direct statediff process is restricted
|
// WatchAddress adds the given address to a list of watched addresses to which the direct statediff process is restricted
|
||||||
func (api *PublicStateDiffAPI) WatchAddress(address common.Address) error {
|
func (api *PublicStateDiffAPI) WatchAddress(operation OperationType, addresses []common.Address) error {
|
||||||
return api.sds.WatchAddress(address)
|
return api.sds.WatchAddress(operation, addresses)
|
||||||
}
|
}
|
||||||
|
@ -48,6 +48,7 @@ var (
|
|||||||
service = statediff.Service{}
|
service = statediff.Service{}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO: Update tests for the updated API
|
||||||
func TestWatchAddress(t *testing.T) {
|
func TestWatchAddress(t *testing.T) {
|
||||||
watchedAddresses := service.GetWathchedAddresses()
|
watchedAddresses := service.GetWathchedAddresses()
|
||||||
if !reflect.DeepEqual(watchedAddresses, watchedAddresses0) {
|
if !reflect.DeepEqual(watchedAddresses, watchedAddresses0) {
|
||||||
@ -60,7 +61,7 @@ func TestWatchAddress(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func testWatchUnwatchedAddress(t *testing.T) {
|
func testWatchUnwatchedAddress(t *testing.T) {
|
||||||
err := service.WatchAddress(address1)
|
err := service.WatchAddress(statediff.Add, []common.Address{address1})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Test failure:", t.Name())
|
t.Error("Test failure:", t.Name())
|
||||||
t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error())
|
t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error())
|
||||||
@ -71,7 +72,7 @@ func testWatchUnwatchedAddress(t *testing.T) {
|
|||||||
t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses1)
|
t.Logf("Actual watched addresses not equal expected watched addresses.\nactual: %+v\nexpected: %+v", watchedAddresses, watchedAddresses1)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = service.WatchAddress(address2)
|
err = service.WatchAddress(statediff.Add, []common.Address{address2})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Test failure:", t.Name())
|
t.Error("Test failure:", t.Name())
|
||||||
t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error())
|
t.Logf("Unexpected error %s thrown on an attempt to watch an unwatched address.", err.Error())
|
||||||
@ -84,7 +85,7 @@ func testWatchUnwatchedAddress(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func testWatchWatchedAddress(t *testing.T) {
|
func testWatchWatchedAddress(t *testing.T) {
|
||||||
err := service.WatchAddress(address1)
|
err := service.WatchAddress(statediff.Add, []common.Address{address1})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Test failure:", t.Name())
|
t.Error("Test failure:", t.Name())
|
||||||
t.Logf("Expected error %s not thrown on an attempt to watch an already watched address.", expectedError.Error())
|
t.Logf("Expected error %s not thrown on an attempt to watch an already watched address.", expectedError.Error())
|
||||||
|
@ -20,8 +20,12 @@
|
|||||||
package statediff
|
package statediff
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/common"
|
||||||
|
"github.com/ethereum/go-ethereum/statediff/indexer/postgres"
|
||||||
)
|
)
|
||||||
|
|
||||||
func sortKeys(data AccountMap) []string {
|
func sortKeys(data AccountMap) []string {
|
||||||
@ -69,5 +73,51 @@ func findIntersection(a, b []string) []string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadWatchedAddresses is used to load watched addresses to the in-memory write loop params from the db
|
||||||
|
func loadWatchedAddresses(db *postgres.DB) error {
|
||||||
|
rows, err := db.Query("SELECT address FROM eth.watched_addresses")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error loading watched addresses: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var watchedAddresses []common.Address
|
||||||
|
for rows.Next() {
|
||||||
|
var addressHex string
|
||||||
|
err := rows.Scan(&addressHex)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
watchedAddresses = append(watchedAddresses, common.HexToAddress(addressHex))
|
||||||
|
}
|
||||||
|
|
||||||
|
writeLoopParams.WatchedAddresses = watchedAddresses
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeWatchedAddresses(watchedAddresses []common.Address, addressesToRemove []common.Address) []common.Address {
|
||||||
|
addresses := make([]common.Address, len(addressesToRemove))
|
||||||
|
copy(addresses, watchedAddresses)
|
||||||
|
|
||||||
|
for _, address := range addressesToRemove {
|
||||||
|
if idx := containsAddress(addresses, address); idx != -1 {
|
||||||
|
addresses = append(addresses[:idx], addresses[idx+1:]...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return addresses
|
||||||
|
}
|
||||||
|
|
||||||
|
// containsAddress is used to check if an address is present in the provided list of watched addresses
|
||||||
|
// return the index if found else -1
|
||||||
|
func containsAddress(watchedAddresses []common.Address, address common.Address) int {
|
||||||
|
for idx, addr := range watchedAddresses {
|
||||||
|
if addr == address {
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
}
|
}
|
||||||
|
@ -59,6 +59,9 @@ type Indexer interface {
|
|||||||
PushStateNode(tx *BlockTx, stateNode sdtypes.StateNode) error
|
PushStateNode(tx *BlockTx, stateNode sdtypes.StateNode) error
|
||||||
PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sdtypes.CodeAndCodeHash) error
|
PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sdtypes.CodeAndCodeHash) error
|
||||||
ReportDBMetrics(delay time.Duration, quit <-chan bool)
|
ReportDBMetrics(delay time.Duration, quit <-chan bool)
|
||||||
|
InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error
|
||||||
|
RemoveWatchedAddresses(addresses []common.Address) error
|
||||||
|
ClearWatchedAddresses() error
|
||||||
}
|
}
|
||||||
|
|
||||||
// StateDiffIndexer satisfies the Indexer interface for ethereum statediff objects
|
// StateDiffIndexer satisfies the Indexer interface for ethereum statediff objects
|
||||||
@ -549,3 +552,64 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx *BlockTx, codeAndCodeHash sd
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []common.Address, currentBlock *big.Int) error {
|
||||||
|
tx, err := sdi.dbWriter.db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, address := range addresses {
|
||||||
|
_, err = tx.Exec(`INSERT INTO eth.watched_addresses (address, added_at) VALUES ($1, $2)`, address, currentBlock)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error inserting watched_addresses entry: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []common.Address) error {
|
||||||
|
tx, err := sdi.dbWriter.db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, address := range addresses {
|
||||||
|
_, err = tx.Exec(`DELETE FROM eth.watched_addresses WHERE address = $1`, address)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error removing watched_addresses entry: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sdi *StateDiffIndexer) ClearWatchedAddresses() error {
|
||||||
|
tx, err := sdi.dbWriter.db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Exec(`DELETE FROM eth.watched_addresses`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error clearing watched_addresses table: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -18,11 +18,8 @@ package statediff
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
"math/big"
|
"math/big"
|
||||||
"os"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -58,9 +55,6 @@ const (
|
|||||||
deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html
|
deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: Take the watched addresses file path as a CLI arg.
|
|
||||||
const watchedAddressesFile = "./watched-addresses.json"
|
|
||||||
|
|
||||||
var writeLoopParams = Params{
|
var writeLoopParams = Params{
|
||||||
IntermediateStateNodes: true,
|
IntermediateStateNodes: true,
|
||||||
IntermediateStorageNodes: true,
|
IntermediateStorageNodes: true,
|
||||||
@ -109,7 +103,7 @@ type IService interface {
|
|||||||
// Event loop for progressively processing and writing diffs directly to DB
|
// Event loop for progressively processing and writing diffs directly to DB
|
||||||
WriteLoop(chainEventCh chan core.ChainEvent)
|
WriteLoop(chainEventCh chan core.ChainEvent)
|
||||||
// Method to add an address to be watched to write loop params
|
// Method to add an address to be watched to write loop params
|
||||||
WatchAddress(address common.Address) error
|
WatchAddress(operation OperationType, addresses []common.Address) error
|
||||||
// Method to get currently watched addresses from write loop params
|
// Method to get currently watched addresses from write loop params
|
||||||
GetWathchedAddresses() []common.Address
|
GetWathchedAddresses() []common.Address
|
||||||
}
|
}
|
||||||
@ -170,6 +164,7 @@ func NewBlockCache(max uint) blockCache {
|
|||||||
func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params ServiceParams) error {
|
func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params ServiceParams) error {
|
||||||
blockChain := ethServ.BlockChain()
|
blockChain := ethServ.BlockChain()
|
||||||
var indexer ind.Indexer
|
var indexer ind.Indexer
|
||||||
|
var db *postgres.DB
|
||||||
quitCh := make(chan bool)
|
quitCh := make(chan bool)
|
||||||
if params.DBParams != nil {
|
if params.DBParams != nil {
|
||||||
info := nodeinfo.Info{
|
info := nodeinfo.Info{
|
||||||
@ -212,7 +207,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params
|
|||||||
stack.RegisterLifecycle(sds)
|
stack.RegisterLifecycle(sds)
|
||||||
stack.RegisterAPIs(sds.APIs())
|
stack.RegisterAPIs(sds.APIs())
|
||||||
|
|
||||||
err := loadWatchedAddresses()
|
err := loadWatchedAddresses(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -731,39 +726,56 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adds the provided address to the list of watched addresses in write loop params and to the watched addresses file
|
// Performs Add | Remove | Set | Clear operation on the watched addresses in writeLoopParams and the db with provided addresses
|
||||||
func (sds *Service) WatchAddress(address common.Address) error {
|
func (sds *Service) WatchAddress(operation OperationType, addresses []common.Address) error {
|
||||||
// Check if address is already being watched
|
// check operation
|
||||||
if containsAddress(writeLoopParams.WatchedAddresses, address) {
|
switch operation {
|
||||||
return fmt.Errorf("Address %s already watched", address)
|
case Add:
|
||||||
}
|
for _, address := range addresses {
|
||||||
|
// Check if address is already being watched
|
||||||
|
if containsAddress(writeLoopParams.WatchedAddresses, address) != -1 {
|
||||||
|
return fmt.Errorf("Address %s already watched", address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check if the watched addresses file exists
|
// TODO: Make sure WriteLoop doesn't call statediffing before the params are updated for the current block
|
||||||
fileExists, err := doesFileExist(watchedAddressesFile)
|
// TODO: Get the current block
|
||||||
if err != nil {
|
err := sds.indexer.InsertWatchedAddresses(addresses, common.Big1)
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the watched addresses file if doesn't exist
|
|
||||||
if !fileExists {
|
|
||||||
_, err := os.Create(watchedAddressesFile)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, addresses...)
|
||||||
|
case Remove:
|
||||||
|
err := sds.indexer.RemoveWatchedAddresses(addresses)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
removeWatchedAddresses(writeLoopParams.WatchedAddresses, addresses)
|
||||||
|
case Set:
|
||||||
|
err := sds.indexer.ClearWatchedAddresses()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = sds.indexer.InsertWatchedAddresses(addresses, common.Big1)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
writeLoopParams.WatchedAddresses = addresses
|
||||||
|
case Clear:
|
||||||
|
err := sds.indexer.ClearWatchedAddresses()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
writeLoopParams.WatchedAddresses = nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("Unexpected operation %s", operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
watchedAddresses := append(writeLoopParams.WatchedAddresses, address)
|
|
||||||
|
|
||||||
// Write the updated list of watched address to a json file
|
|
||||||
content, err := json.Marshal(watchedAddresses)
|
|
||||||
err = ioutil.WriteFile(watchedAddressesFile, content, 0644)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the in-memory params as well
|
|
||||||
writeLoopParams.WatchedAddresses = watchedAddresses
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -771,50 +783,3 @@ func (sds *Service) WatchAddress(address common.Address) error {
|
|||||||
func (sds *Service) GetWathchedAddresses() []common.Address {
|
func (sds *Service) GetWathchedAddresses() []common.Address {
|
||||||
return writeLoopParams.WatchedAddresses
|
return writeLoopParams.WatchedAddresses
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadWatchedAddresses is used to load watched addresses to the in-memory write loop params from a json file if it exists
|
|
||||||
func loadWatchedAddresses() error {
|
|
||||||
// Check if the watched addresses file exists
|
|
||||||
fileExists, err := doesFileExist(watchedAddressesFile)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if fileExists {
|
|
||||||
content, err := ioutil.ReadFile(watchedAddressesFile)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var watchedAddresses []common.Address
|
|
||||||
err = json.Unmarshal(content, &watchedAddresses)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
writeLoopParams.WatchedAddresses = watchedAddresses
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// containsAddress is used to check if an address is present in the provided list of watched addresses
|
|
||||||
func containsAddress(watchedAddresses []common.Address, address common.Address) bool {
|
|
||||||
for _, addr := range watchedAddresses {
|
|
||||||
if addr == address {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// doesFileExist is used to check if file at a given path exists
|
|
||||||
func doesFileExist(path string) (bool, error) {
|
|
||||||
_, err := os.Stat(path)
|
|
||||||
if err == nil {
|
|
||||||
return true, nil
|
|
||||||
} else if os.IsNotExist(err) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
@ -333,7 +333,7 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sds *MockStateDiffService) WatchAddress(address common.Address) error {
|
func (sds *MockStateDiffService) WatchAddress(operation statediff.OperationType, addresses []common.Address) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,3 +111,13 @@ type accountWrapper struct {
|
|||||||
NodeValue []byte
|
NodeValue []byte
|
||||||
LeafKey []byte
|
LeafKey []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OperationType for type of WatchAddress operation
|
||||||
|
type OperationType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
Add OperationType = "Add"
|
||||||
|
Remove OperationType = "Remove"
|
||||||
|
Set OperationType = "Set"
|
||||||
|
Clear OperationType = "Clear"
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user