433 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			433 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright 2019 The go-ethereum Authors
 | |
| // This file is part of the go-ethereum library.
 | |
| //
 | |
| // The go-ethereum library is free software: you can redistribute it and/or modify
 | |
| // it under the terms of the GNU Lesser General Public License as published by
 | |
| // the Free Software Foundation, either version 3 of the License, or
 | |
| // (at your option) any later version.
 | |
| //
 | |
| // The go-ethereum library is distributed in the hope that it will be useful,
 | |
| // but WITHOUT ANY WARRANTY; without even the implied warranty of
 | |
| // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 | |
| // GNU Lesser General Public License for more details.
 | |
| //
 | |
| // You should have received a copy of the GNU Lesser General Public License
 | |
| // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
 | |
| 
 | |
| package mocks
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"sync"
 | |
| 
 | |
| 	"github.com/ethereum/go-ethereum/common"
 | |
| 	"github.com/ethereum/go-ethereum/crypto"
 | |
| 	"github.com/ethereum/go-ethereum/rlp"
 | |
| 	"github.com/thoas/go-funk"
 | |
| 
 | |
| 	"github.com/ethereum/go-ethereum/core"
 | |
| 	"github.com/ethereum/go-ethereum/core/types"
 | |
| 	"github.com/ethereum/go-ethereum/log"
 | |
| 	"github.com/ethereum/go-ethereum/p2p"
 | |
| 	"github.com/ethereum/go-ethereum/rpc"
 | |
| 	"github.com/ethereum/go-ethereum/statediff"
 | |
| 	"github.com/ethereum/go-ethereum/statediff/indexer/interfaces"
 | |
| 	sdtypes "github.com/ethereum/go-ethereum/statediff/types"
 | |
| )
 | |
| 
 | |
// typeAssertionFailed is the message fragment used when a funk helper returns
// an unexpected concrete type.
var typeAssertionFailed = "type assertion failed"

// unexpectedOperation is the message fragment used when WatchAddress receives
// an operation it does not handle.
var unexpectedOperation = "unexpected operation"
 | |
| 
 | |
| var _ statediff.IService = &MockStateDiffService{}
 | |
| 
 | |
| // MockStateDiffService is a mock state diff service
 | |
| type MockStateDiffService struct {
 | |
| 	sync.Mutex
 | |
| 	Builder           statediff.Builder
 | |
| 	BlockChain        *BlockChain
 | |
| 	ReturnProtocol    []p2p.Protocol
 | |
| 	ReturnAPIs        []rpc.API
 | |
| 	BlockChan         chan *types.Block
 | |
| 	ParentBlockChan   chan *types.Block
 | |
| 	QuitChan          chan bool
 | |
| 	Subscriptions     map[common.Hash]map[rpc.ID]statediff.Subscription
 | |
| 	SubscriptionTypes map[common.Hash]statediff.Params
 | |
| 	Indexer           interfaces.StateDiffIndexer
 | |
| 	writeLoopParams   statediff.ParamsWithMutex
 | |
| }
 | |
| 
 | |
| // Protocols mock method
 | |
| func (sds *MockStateDiffService) Protocols() []p2p.Protocol {
 | |
| 	return []p2p.Protocol{}
 | |
| }
 | |
| 
 | |
| // APIs mock method
 | |
| func (sds *MockStateDiffService) APIs() []rpc.API {
 | |
| 	return []rpc.API{
 | |
| 		{
 | |
| 			Namespace: statediff.APIName,
 | |
| 			Version:   statediff.APIVersion,
 | |
| 			Service:   statediff.NewPublicStateDiffAPI(sds),
 | |
| 			Public:    true,
 | |
| 		},
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Loop mock method
 | |
| func (sds *MockStateDiffService) Loop(chan core.ChainEvent) {
 | |
| 	//loop through chain events until no more
 | |
| 	for {
 | |
| 		select {
 | |
| 		case block := <-sds.BlockChan:
 | |
| 			currentBlock := block
 | |
| 			parentBlock := <-sds.ParentBlockChan
 | |
| 			parentHash := parentBlock.Hash()
 | |
| 			if parentBlock == nil {
 | |
| 				log.Error("Parent block is nil, skipping this block",
 | |
| 					"parent block hash", parentHash.String(),
 | |
| 					"current block number", currentBlock.Number())
 | |
| 				continue
 | |
| 			}
 | |
| 			sds.streamStateDiff(currentBlock, parentBlock.Root())
 | |
| 		case <-sds.QuitChan:
 | |
| 			log.Debug("Quitting the statediff block channel")
 | |
| 			sds.close()
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // streamStateDiff method builds the state diff payload for each subscription according to their subscription type and sends them the result
 | |
| func (sds *MockStateDiffService) streamStateDiff(currentBlock *types.Block, parentRoot common.Hash) {
 | |
| 	sds.Lock()
 | |
| 	for ty, subs := range sds.Subscriptions {
 | |
| 		params, ok := sds.SubscriptionTypes[ty]
 | |
| 		if !ok {
 | |
| 			log.Error(fmt.Sprintf("subscriptions type %s do not have a parameter set associated with them", ty.Hex()))
 | |
| 			sds.closeType(ty)
 | |
| 			continue
 | |
| 		}
 | |
| 		// create payload for this subscription type
 | |
| 		payload, err := sds.processStateDiff(currentBlock, parentRoot, params)
 | |
| 		if err != nil {
 | |
| 			log.Error(fmt.Sprintf("statediff processing error for subscriptions with parameters: %+v", params))
 | |
| 			sds.closeType(ty)
 | |
| 			continue
 | |
| 		}
 | |
| 		for id, sub := range subs {
 | |
| 			select {
 | |
| 			case sub.PayloadChan <- *payload:
 | |
| 				log.Debug(fmt.Sprintf("sending statediff payload to subscription %s", id))
 | |
| 			default:
 | |
| 				log.Info(fmt.Sprintf("unable to send statediff payload to subscription %s; channel has no receiver", id))
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	sds.Unlock()
 | |
| }
 | |
| 
 | |
| // StateDiffAt mock method
 | |
| func (sds *MockStateDiffService) StateDiffAt(blockNumber uint64, params statediff.Params) (*statediff.Payload, error) {
 | |
| 	currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber)
 | |
| 	log.Info(fmt.Sprintf("sending state diff at %d", blockNumber))
 | |
| 	if blockNumber == 0 {
 | |
| 		return sds.processStateDiff(currentBlock, common.Hash{}, params)
 | |
| 	}
 | |
| 	parentBlock := sds.BlockChain.GetBlockByHash(currentBlock.ParentHash())
 | |
| 	return sds.processStateDiff(currentBlock, parentBlock.Root(), params)
 | |
| }
 | |
| 
 | |
| // StateDiffFor mock method
 | |
| func (sds *MockStateDiffService) StateDiffFor(blockHash common.Hash, params statediff.Params) (*statediff.Payload, error) {
 | |
| 	// TODO: something useful here
 | |
| 	return nil, nil
 | |
| }
 | |
| 
 | |
| // processStateDiff method builds the state diff payload from the current block, parent state root, and provided params
 | |
| func (sds *MockStateDiffService) processStateDiff(currentBlock *types.Block, parentRoot common.Hash, params statediff.Params) (*statediff.Payload, error) {
 | |
| 	stateDiff, err := sds.Builder.BuildStateDiffObject(statediff.Args{
 | |
| 		NewStateRoot: currentBlock.Root(),
 | |
| 		OldStateRoot: parentRoot,
 | |
| 		BlockHash:    currentBlock.Hash(),
 | |
| 		BlockNumber:  currentBlock.Number(),
 | |
| 	}, params)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	stateDiffRlp, err := rlp.EncodeToBytes(&stateDiff)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return sds.newPayload(stateDiffRlp, currentBlock, params)
 | |
| }
 | |
| 
 | |
| func (sds *MockStateDiffService) newPayload(stateObject []byte, block *types.Block, params statediff.Params) (*statediff.Payload, error) {
 | |
| 	payload := &statediff.Payload{
 | |
| 		StateObjectRlp: stateObject,
 | |
| 	}
 | |
| 	if params.IncludeBlock {
 | |
| 		blockBuff := new(bytes.Buffer)
 | |
| 		if err := block.EncodeRLP(blockBuff); err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		payload.BlockRlp = blockBuff.Bytes()
 | |
| 	}
 | |
| 	if params.IncludeTD {
 | |
| 		payload.TotalDifficulty = sds.BlockChain.GetTd(block.Hash(), block.NumberU64())
 | |
| 	}
 | |
| 	if params.IncludeReceipts {
 | |
| 		receiptBuff := new(bytes.Buffer)
 | |
| 		receipts := sds.BlockChain.GetReceiptsByHash(block.Hash())
 | |
| 		if err := rlp.Encode(receiptBuff, receipts); err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		payload.ReceiptsRlp = receiptBuff.Bytes()
 | |
| 	}
 | |
| 	return payload, nil
 | |
| }
 | |
| 
 | |
| // WriteStateDiffAt mock method
 | |
| func (sds *MockStateDiffService) WriteStateDiffAt(blockNumber uint64, params statediff.Params) statediff.JobID {
 | |
| 	// TODO: something useful here
 | |
| 	return 0
 | |
| }
 | |
| 
 | |
| // WriteStateDiffFor mock method
 | |
| func (sds *MockStateDiffService) WriteStateDiffFor(blockHash common.Hash, params statediff.Params) error {
 | |
| 	// TODO: something useful here
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // Loop mock method
 | |
| func (sds *MockStateDiffService) WriteLoop(chan core.ChainEvent) {
 | |
| 	//loop through chain events until no more
 | |
| 	for {
 | |
| 		select {
 | |
| 		case block := <-sds.BlockChan:
 | |
| 			currentBlock := block
 | |
| 			parentBlock := <-sds.ParentBlockChan
 | |
| 			parentHash := parentBlock.Hash()
 | |
| 			if parentBlock == nil {
 | |
| 				log.Error("Parent block is nil, skipping this block",
 | |
| 					"parent block hash", parentHash.String(),
 | |
| 					"current block number", currentBlock.Number())
 | |
| 				continue
 | |
| 			}
 | |
| 			// TODO:
 | |
| 			// sds.writeStateDiff(currentBlock, parentBlock.Root(), statediff.Params{})
 | |
| 		case <-sds.QuitChan:
 | |
| 			log.Debug("Quitting the statediff block channel")
 | |
| 			sds.close()
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Subscribe is used by the API to subscribe to the service loop
 | |
| func (sds *MockStateDiffService) Subscribe(id rpc.ID, sub chan<- statediff.Payload, quitChan chan<- bool, params statediff.Params) {
 | |
| 	// Subscription type is defined as the hash of the rlp-serialized subscription params
 | |
| 	by, err := rlp.EncodeToBytes(¶ms)
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 	subscriptionType := crypto.Keccak256Hash(by)
 | |
| 	// Add subscriber
 | |
| 	sds.Lock()
 | |
| 	if sds.Subscriptions[subscriptionType] == nil {
 | |
| 		sds.Subscriptions[subscriptionType] = make(map[rpc.ID]statediff.Subscription)
 | |
| 	}
 | |
| 	sds.Subscriptions[subscriptionType][id] = statediff.Subscription{
 | |
| 		PayloadChan: sub,
 | |
| 		QuitChan:    quitChan,
 | |
| 	}
 | |
| 	sds.SubscriptionTypes[subscriptionType] = params
 | |
| 	sds.Unlock()
 | |
| }
 | |
| 
 | |
| // Unsubscribe is used to unsubscribe from the service loop
 | |
| func (sds *MockStateDiffService) Unsubscribe(id rpc.ID) error {
 | |
| 	sds.Lock()
 | |
| 	for ty := range sds.Subscriptions {
 | |
| 		delete(sds.Subscriptions[ty], id)
 | |
| 		if len(sds.Subscriptions[ty]) == 0 {
 | |
| 			// If we removed the last subscription of this type, remove the subscription type outright
 | |
| 			delete(sds.Subscriptions, ty)
 | |
| 			delete(sds.SubscriptionTypes, ty)
 | |
| 		}
 | |
| 	}
 | |
| 	sds.Unlock()
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // close is used to close all listening subscriptions
 | |
| func (sds *MockStateDiffService) close() {
 | |
| 	sds.Lock()
 | |
| 	for ty, subs := range sds.Subscriptions {
 | |
| 		for id, sub := range subs {
 | |
| 			select {
 | |
| 			case sub.QuitChan <- true:
 | |
| 				log.Info(fmt.Sprintf("closing subscription %s", id))
 | |
| 			default:
 | |
| 				log.Info(fmt.Sprintf("unable to close subscription %s; channel has no receiver", id))
 | |
| 			}
 | |
| 			delete(sds.Subscriptions[ty], id)
 | |
| 		}
 | |
| 		delete(sds.Subscriptions, ty)
 | |
| 		delete(sds.SubscriptionTypes, ty)
 | |
| 	}
 | |
| 	sds.Unlock()
 | |
| }
 | |
| 
 | |
| // Start mock method
 | |
| func (sds *MockStateDiffService) Start() error {
 | |
| 	log.Info("Starting mock statediff service")
 | |
| 	if sds.ParentBlockChan == nil || sds.BlockChan == nil {
 | |
| 		return errors.New("MockStateDiffingService needs to be configured with a MockParentBlockChan and MockBlockChan")
 | |
| 	}
 | |
| 	chainEventCh := make(chan core.ChainEvent, 10)
 | |
| 	go sds.Loop(chainEventCh)
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // Stop mock method
 | |
| func (sds *MockStateDiffService) Stop() error {
 | |
| 	log.Info("Stopping mock statediff service")
 | |
| 	close(sds.QuitChan)
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // closeType is used to close all subscriptions of given type
 | |
| // closeType needs to be called with subscription access locked
 | |
| func (sds *MockStateDiffService) closeType(subType common.Hash) {
 | |
| 	subs := sds.Subscriptions[subType]
 | |
| 	for id, sub := range subs {
 | |
| 		sendNonBlockingQuit(id, sub)
 | |
| 	}
 | |
| 	delete(sds.Subscriptions, subType)
 | |
| 	delete(sds.SubscriptionTypes, subType)
 | |
| }
 | |
| 
 | |
| func (sds *MockStateDiffService) StreamCodeAndCodeHash(blockNumber uint64, outChan chan<- sdtypes.CodeAndCodeHash, quitChan chan<- bool) {
 | |
| 	panic("implement me")
 | |
| }
 | |
| 
 | |
| func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) {
 | |
| 	select {
 | |
| 	case sub.QuitChan <- true:
 | |
| 		log.Info(fmt.Sprintf("closing subscription %s", id))
 | |
| 	default:
 | |
| 		log.Info("unable to close subscription %s; channel has no receiver", id)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Performs one of following operations on the watched addresses in writeLoopParams and the db:
 | |
| // add | remove | set | clear
 | |
| func (sds *MockStateDiffService) WatchAddress(operation sdtypes.OperationType, args []sdtypes.WatchAddressArg) error {
 | |
| 	// lock writeLoopParams for a write
 | |
| 	sds.writeLoopParams.Lock()
 | |
| 	defer sds.writeLoopParams.Unlock()
 | |
| 
 | |
| 	// get the current block number
 | |
| 	currentBlockNumber := sds.BlockChain.CurrentBlock().Number
 | |
| 
 | |
| 	switch operation {
 | |
| 	case sdtypes.Add:
 | |
| 		// filter out args having an already watched address with a warning
 | |
| 		filteredArgs, ok := funk.Filter(args, func(arg sdtypes.WatchAddressArg) bool {
 | |
| 			if funk.Contains(sds.writeLoopParams.WatchedAddresses, common.HexToAddress(arg.Address)) {
 | |
| 				log.Warn("Address already being watched", "address", arg.Address)
 | |
| 				return false
 | |
| 			}
 | |
| 			return true
 | |
| 		}).([]sdtypes.WatchAddressArg)
 | |
| 		if !ok {
 | |
| 			return fmt.Errorf("add: filtered args %s", typeAssertionFailed)
 | |
| 		}
 | |
| 
 | |
| 		// get addresses from the filtered args
 | |
| 		filteredAddresses, err := statediff.MapWatchAddressArgsToAddresses(filteredArgs)
 | |
| 		if err != nil {
 | |
| 			return fmt.Errorf("add: filtered addresses %s", err.Error())
 | |
| 		}
 | |
| 
 | |
| 		// update the db
 | |
| 		err = sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		// update in-memory params
 | |
| 		sds.writeLoopParams.WatchedAddresses = append(sds.writeLoopParams.WatchedAddresses, filteredAddresses...)
 | |
| 		sds.writeLoopParams.ComputeWatchedAddressesLeafPaths()
 | |
| 	case sdtypes.Remove:
 | |
| 		// get addresses from args
 | |
| 		argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args)
 | |
| 		if err != nil {
 | |
| 			return fmt.Errorf("remove: mapped addresses %s", err.Error())
 | |
| 		}
 | |
| 
 | |
| 		// remove the provided addresses from currently watched addresses
 | |
| 		addresses, ok := funk.Subtract(sds.writeLoopParams.WatchedAddresses, argAddresses).([]common.Address)
 | |
| 		if !ok {
 | |
| 			return fmt.Errorf("remove: filtered addresses %s", typeAssertionFailed)
 | |
| 		}
 | |
| 
 | |
| 		// update the db
 | |
| 		err = sds.Indexer.RemoveWatchedAddresses(args)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		// update in-memory params
 | |
| 		sds.writeLoopParams.WatchedAddresses = addresses
 | |
| 		sds.writeLoopParams.ComputeWatchedAddressesLeafPaths()
 | |
| 	case sdtypes.Set:
 | |
| 		// get addresses from args
 | |
| 		argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args)
 | |
| 		if err != nil {
 | |
| 			return fmt.Errorf("set: mapped addresses %s", err.Error())
 | |
| 		}
 | |
| 
 | |
| 		// update the db
 | |
| 		err = sds.Indexer.SetWatchedAddresses(args, currentBlockNumber)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		// update in-memory params
 | |
| 		sds.writeLoopParams.WatchedAddresses = argAddresses
 | |
| 		sds.writeLoopParams.ComputeWatchedAddressesLeafPaths()
 | |
| 	case sdtypes.Clear:
 | |
| 		// update the db
 | |
| 		err := sds.Indexer.ClearWatchedAddresses()
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		// update in-memory params
 | |
| 		sds.writeLoopParams.WatchedAddresses = []common.Address{}
 | |
| 		sds.writeLoopParams.ComputeWatchedAddressesLeafPaths()
 | |
| 
 | |
| 	default:
 | |
| 		return fmt.Errorf("%s %s", unexpectedOperation, operation)
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // SubscribeWriteStatus is used by the API to subscribe to the job status updates
 | |
| func (sds *MockStateDiffService) SubscribeWriteStatus(id rpc.ID, sub chan<- statediff.JobStatus, quitChan chan<- bool) {
 | |
| 	// TODO when WriteStateDiff methods are implemented
 | |
| }
 | |
| 
 | |
| // UnsubscribeWriteStatus is used to unsubscribe from job status updates
 | |
| func (sds *MockStateDiffService) UnsubscribeWriteStatus(id rpc.ID) error {
 | |
| 	// TODO when WriteStateDiff methods are implemented
 | |
| 	return nil
 | |
| }
 |