diff --git a/statediff/api.go b/statediff/api.go index 06dab7ec7..275867e50 100644 --- a/statediff/api.go +++ b/statediff/api.go @@ -44,7 +44,7 @@ func NewPublicStateDiffAPI(sds IService) *PublicStateDiffAPI { } // Stream is the public method to setup a subscription that fires off statediff service payloads as they are created -func (api *PublicStateDiffAPI) Stream(ctx context.Context) (*rpc.Subscription, error) { +func (api *PublicStateDiffAPI) Stream(ctx context.Context, params Params) (*rpc.Subscription, error) { // ensure that the RPC connection supports subscriptions notifier, supported := rpc.NotifierFromContext(ctx) if !supported { @@ -58,7 +58,7 @@ func (api *PublicStateDiffAPI) Stream(ctx context.Context) (*rpc.Subscription, e // subscribe to events from the statediff service payloadChannel := make(chan Payload, chainEventChanSize) quitChan := make(chan bool, 1) - api.sds.Subscribe(rpcSub.ID, payloadChannel, quitChan) + api.sds.Subscribe(rpcSub.ID, payloadChannel, quitChan, params) // loop and await payloads and relay them to the subscriber with the notifier for { select { @@ -91,6 +91,6 @@ func (api *PublicStateDiffAPI) Stream(ctx context.Context) (*rpc.Subscription, e } // StateDiffAt returns a statediff payload at the specific blockheight -func (api *PublicStateDiffAPI) StateDiffAt(ctx context.Context, blockNumber uint64) (*Payload, error) { - return api.sds.StateDiffAt(blockNumber) +func (api *PublicStateDiffAPI) StateDiffAt(ctx context.Context, blockNumber uint64, params Params) (*Payload, error) { + return api.sds.StateDiffAt(blockNumber, params) } diff --git a/statediff/service.go b/statediff/service.go index 8de64f996..b8f234008 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -23,6 +23,8 @@ import ( "sync" "sync/atomic" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" @@ -40,7 +42,6 @@ type blockChain interface { SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription GetBlockByHash(hash common.Hash) *types.Block GetBlockByNumber(number uint64) *types.Block - AddToStateDiffProcessedCollection(hash common.Hash) GetReceiptsByHash(hash common.Hash) types.Receipts GetTdByHash(hash common.Hash) *big.Int } @@ -52,11 +53,11 @@ type IService interface { // Main event loop for processing state diffs Loop(chainEventCh chan core.ChainEvent) // Method to subscribe to receive state diff processing output - Subscribe(id rpc.ID, sub chan<- Payload, quitChan chan<- bool) + Subscribe(id rpc.ID, sub chan<- Payload, quitChan chan<- bool, params Params) // Method to unsubscribe from state diff processing Unsubscribe(id rpc.ID) error // Method to get statediff at specific block - StateDiffAt(blockNumber uint64) (*Payload, error) + StateDiffAt(blockNumber uint64, params Params) (*Payload, error) } // Service is the underlying struct for the state diffing service @@ -69,25 +70,25 @@ type Service struct { BlockChain blockChain // Used to signal shutdown of the service QuitChan chan bool - // A mapping of rpc.IDs to their subscription channels - Subscriptions map[rpc.ID]Subscription + // A mapping of rpc.IDs to their subscription channels, mapped to their subscription type (hash of the Params rlp) + Subscriptions map[common.Hash]map[rpc.ID]Subscription + // A mapping of subscription params rlp hash to the corresponding subscription params + SubscriptionTypes map[common.Hash]Params // Cache the last block so that we can avoid having to lookup the next block's parent lastBlock *types.Block - // Whether or not the block data is streamed alongside the state diff data in the subscription payload - StreamBlock bool // Whether or not we have any subscribers; only if we do, do we processes state diffs subscribers int32 } // NewStateDiffService creates a new statediff.Service -func NewStateDiffService(blockChain *core.BlockChain, config Config) (*Service, error) { +func NewStateDiffService(blockChain *core.BlockChain) (*Service, error) { return &Service{ - Mutex: sync.Mutex{}, - BlockChain: blockChain, - Builder: NewBuilder(blockChain, config), - QuitChan: make(chan bool), - Subscriptions: make(map[rpc.ID]Subscription), - StreamBlock: config.StreamBlock, + Mutex: sync.Mutex{}, + BlockChain: blockChain, + Builder: NewBuilder(blockChain.StateCache()), + QuitChan: make(chan bool), + Subscriptions: make(map[common.Hash]map[rpc.ID]Subscription), + SubscriptionTypes: make(map[common.Hash]Params), }, nil } @@ -136,16 +137,9 @@ func (sds *Service) Loop(chainEventCh chan core.ChainEvent) { log.Error(fmt.Sprintf("Parent block is nil, skipping this block (%d)", currentBlock.Number())) continue } - payload, err := sds.processStateDiff(currentBlock, parentBlock.Root()) - if err != nil { - log.Error(fmt.Sprintf("Error building statediff for block %d; error: ", currentBlock.Number()) + err.Error()) - continue - } - sds.send(*payload) + sds.streamStateDiff(currentBlock, parentBlock.Root()) case err := <-errCh: - log.Warn("Error from chain event subscription, breaking loop", "error", err) - sds.close() - return + log.Warn("Error from chain event subscription", "error", err) case <-sds.QuitChan: log.Info("Quitting the statediffing process") sds.close() @@ -154,9 +148,43 @@ func (sds *Service) Loop(chainEventCh chan core.ChainEvent) { } } -// processStateDiff method builds the state diff payload from the current and parent block before sending it to listening subscriptions -func (sds *Service) processStateDiff(currentBlock *types.Block, parentRoot common.Hash) (*Payload, error) { - stateDiff, err := sds.Builder.BuildStateDiff(parentRoot, currentBlock.Root(), currentBlock.Number(), currentBlock.Hash()) +// streamStateDiff method builds the state diff payload for each subscription according to their subscription type and sends them the result +func (sds *Service) streamStateDiff(currentBlock *types.Block, parentRoot common.Hash) { + sds.Lock() + for ty, subs := range sds.Subscriptions { + params, ok := sds.SubscriptionTypes[ty] + if !ok { + log.Error(fmt.Sprintf("subscriptions type %s do not have a parameter set associated with them", ty.Hex())) + sds.closeType(ty) + continue + } + // create payload for this subscription type + payload, err := sds.processStateDiff(currentBlock, parentRoot, params) + if err != nil { + log.Error(fmt.Sprintf("statediff processing error for subscriptions with parameters: %+v", params)) + sds.closeType(ty) + continue + } + for id, sub := range subs { + select { + case sub.PayloadChan <- *payload: + log.Debug(fmt.Sprintf("sending statediff payload to subscription %s", id)) + default: + log.Info(fmt.Sprintf("unable to send statediff payload to subscription %s; channel has no receiver", id)) + } + } + } + sds.Unlock() +} + +// processStateDiff method builds the state diff payload from the current block, parent state root, and provided params +func (sds *Service) processStateDiff(currentBlock *types.Block, parentRoot common.Hash, params Params) (*Payload, error) { + stateDiff, err := sds.Builder.BuildStateDiff(Args{ + NewStateRoot: currentBlock.Root(), + OldStateRoot: parentRoot, + BlockHash: currentBlock.Hash(), + BlockNumber: currentBlock.Number(), + }, params) if err != nil { return nil, err } @@ -167,13 +195,17 @@ func (sds *Service) processStateDiff(currentBlock *types.Block, parentRoot commo payload := Payload{ StateDiffRlp: stateDiffRlp, } - if sds.StreamBlock { + if params.IncludeBlock { blockBuff := new(bytes.Buffer) if err = currentBlock.EncodeRLP(blockBuff); err != nil { return nil, err } payload.BlockRlp = blockBuff.Bytes() + } + if params.IncludeTD { payload.TotalDifficulty = sds.BlockChain.GetTdByHash(currentBlock.Hash()) + } + if params.IncludeReceipts { receiptBuff := new(bytes.Buffer) receipts := sds.BlockChain.GetReceiptsByHash(currentBlock.Hash()) if err = rlp.Encode(receiptBuff, receipts); err != nil { @@ -185,28 +217,43 @@ func (sds *Service) processStateDiff(currentBlock *types.Block, parentRoot commo } // Subscribe is used by the API to subscribe to the service loop -func (sds *Service) Subscribe(id rpc.ID, sub chan<- Payload, quitChan chan<- bool) { +func (sds *Service) Subscribe(id rpc.ID, sub chan<- Payload, quitChan chan<- bool, params Params) { log.Info("Subscribing to the statediff service") if atomic.CompareAndSwapInt32(&sds.subscribers, 0, 1) { log.Info("State diffing subscription received; beginning statediff processing") } + // Subscription type is defined as the hash of the rlp-serialized subscription params + by, err := rlp.EncodeToBytes(params) + if err != nil { + log.Error("State diffing params need to be rlp-serializable") + return + } + subscriptionType := crypto.Keccak256Hash(by) + // Add subscriber sds.Lock() - sds.Subscriptions[id] = Subscription{ + if sds.Subscriptions[subscriptionType] == nil { + sds.Subscriptions[subscriptionType] = make(map[rpc.ID]Subscription) + } + sds.Subscriptions[subscriptionType][id] = Subscription{ PayloadChan: sub, QuitChan: quitChan, } + sds.SubscriptionTypes[subscriptionType] = params sds.Unlock() } // Unsubscribe is used to unsubscribe from the service loop func (sds *Service) Unsubscribe(id rpc.ID) error { - log.Info("Unsubscribing from the statediff service") + log.Info(fmt.Sprintf("Unsubscribing subscription %s from the statediff service", id)) sds.Lock() - _, ok := sds.Subscriptions[id] - if !ok { - return fmt.Errorf("cannot unsubscribe; subscription for id %s does not exist", id) + for ty := range sds.Subscriptions { + delete(sds.Subscriptions[ty], id) + if len(sds.Subscriptions[ty]) == 0 { + // If we removed the last subscription of this type, remove the subscription type outright + delete(sds.Subscriptions, ty) + delete(sds.SubscriptionTypes, ty) + } } - delete(sds.Subscriptions, id) if len(sds.Subscriptions) == 0 { if atomic.CompareAndSwapInt32(&sds.subscribers, 1, 0) { log.Info("No more subscriptions; halting statediff processing") @@ -226,6 +273,18 @@ func (sds *Service) Start(*p2p.Server) error { return nil } +// StateDiffAt returns a statediff payload at the specific blockheight +// This operation cannot be performed back past the point of db pruning; it requires an archival node +func (sds *Service) StateDiffAt(blockNumber uint64, params Params) (*Payload, error) { + currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) + log.Info(fmt.Sprintf("sending state diff at %d", blockNumber)) + if blockNumber == 0 { + return sds.processStateDiff(currentBlock, common.Hash{}, params) + } + parentBlock := sds.BlockChain.GetBlockByHash(currentBlock.ParentHash()) + return sds.processStateDiff(currentBlock, parentBlock.Root(), params) +} + // Stop is used to close down the service func (sds *Service) Stop() error { log.Info("Stopping statediff service") @@ -233,57 +292,41 @@ func (sds *Service) Stop() error { return nil } -// send is used to fan out and serve the payloads to all subscriptions -func (sds *Service) send(payload Payload) { +// close is used to close all listening subscriptions +func (sds *Service) close() { sds.Lock() - for id, sub := range sds.Subscriptions { - select { - case sub.PayloadChan <- payload: - log.Info(fmt.Sprintf("sending state diff payload to subscription %s", id)) - default: - log.Info(fmt.Sprintf("unable to send payload to subscription %s; channel has no receiver", id)) - // in this case, try to close the bad subscription and remove it + for ty, subs := range sds.Subscriptions { + for id, sub := range subs { select { case sub.QuitChan <- true: log.Info(fmt.Sprintf("closing subscription %s", id)) default: log.Info(fmt.Sprintf("unable to close subscription %s; channel has no receiver", id)) } - delete(sds.Subscriptions, id) - } - } - // If after removing all bad subscriptions we have none left, halt processing - if len(sds.Subscriptions) == 0 { - if atomic.CompareAndSwapInt32(&sds.subscribers, 1, 0) { - log.Info("No more subscriptions; halting statediff processing") + delete(sds.Subscriptions[ty], id) } + delete(sds.Subscriptions, ty) + delete(sds.SubscriptionTypes, ty) } sds.Unlock() } -// close is used to close all listening subscriptions -func (sds *Service) close() { - sds.Lock() - for id, sub := range sds.Subscriptions { - select { - case sub.QuitChan <- true: - log.Info(fmt.Sprintf("closing subscription %s", id)) - default: - log.Info(fmt.Sprintf("unable to close subscription %s; channel has no receiver", id)) - } - delete(sds.Subscriptions, id) +// closeType is used to close all subscriptions of given type +// closeType needs to be called with subscription access locked +func (sds *Service) closeType(subType common.Hash) { + subs := sds.Subscriptions[subType] + for id, sub := range subs { + sendNonBlockingQuit(id, sub) } - sds.Unlock() + delete(sds.Subscriptions, subType) + delete(sds.SubscriptionTypes, subType) } -// StateDiffAt returns a statediff payload at the specific blockheight -// This operation cannot be performed back past the point of db pruning; it requires an archival node -func (sds *Service) StateDiffAt(blockNumber uint64) (*Payload, error) { - currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) - log.Info(fmt.Sprintf("sending state diff at %d", blockNumber)) - if blockNumber == 0 { - return sds.processStateDiff(currentBlock, common.Hash{}) +func sendNonBlockingQuit(id rpc.ID, sub Subscription) { + select { + case sub.QuitChan <- true: + log.Info(fmt.Sprintf("closing subscription %s", id)) + default: + log.Info("unable to close subscription %s; channel has no receiver", id) } - parentBlock := sds.BlockChain.GetBlockByHash(currentBlock.ParentHash()) - return sds.processStateDiff(currentBlock, parentBlock.Root()) }