diff --git a/statediff/api.go b/statediff/api.go index 52c604f97..06dab7ec7 100644 --- a/statediff/api.go +++ b/statediff/api.go @@ -89,3 +89,8 @@ func (api *PublicStateDiffAPI) Stream(ctx context.Context) (*rpc.Subscription, e return rpcSub, nil } + +// StateDiffAt returns a statediff payload at the specific blockheight +func (api *PublicStateDiffAPI) StateDiffAt(ctx context.Context, blockNumber uint64) (*Payload, error) { + return api.sds.StateDiffAt(blockNumber) +} diff --git a/statediff/service.go b/statediff/service.go index d3eab1065..f7bf30cd4 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -39,6 +39,7 @@ const chainEventChanSize = 20000 type blockChain interface { SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription GetBlockByHash(hash common.Hash) *types.Block + GetBlockByNumber(number uint64) *types.Block AddToStateDiffProcessedCollection(hash common.Hash) GetReceiptsByHash(hash common.Hash) types.Receipts } @@ -53,6 +54,8 @@ type IService interface { Subscribe(id rpc.ID, sub chan<- Payload, quitChan chan<- bool) // Method to unsubscribe from state diff processing Unsubscribe(id rpc.ID) error + // Method to get statediff at specific block + StateDiffAt(blockNumber uint64) (*Payload, error) } // Service is the underlying struct for the state diffing service @@ -132,9 +135,11 @@ func (sds *Service) Loop(chainEventCh chan core.ChainEvent) { log.Error(fmt.Sprintf("Parent block is nil, skipping this block (%d)", currentBlock.Number())) continue } - if err := sds.processStateDiff(currentBlock, parentBlock); err != nil { + payload, err := sds.processStateDiff(currentBlock, parentBlock) + if err != nil { log.Error(fmt.Sprintf("Error building statediff for block %d; error: ", currentBlock.Number()) + err.Error()) } + sds.send(*payload) case err := <-errCh: log.Warn("Error from chain event subscription, breaking loop", "error", err) sds.close() @@ -148,14 +153,14 @@ func (sds *Service) Loop(chainEventCh chan core.ChainEvent) { } // processStateDiff method builds the state diff payload from the current and parent block before sending it to listening subscriptions -func (sds *Service) processStateDiff(currentBlock, parentBlock *types.Block) error { +func (sds *Service) processStateDiff(currentBlock, parentBlock *types.Block) (*Payload, error) { stateDiff, err := sds.Builder.BuildStateDiff(parentBlock.Root(), currentBlock.Root(), currentBlock.Number(), currentBlock.Hash()) if err != nil { - return err + return nil, err } stateDiffRlp, err := rlp.EncodeToBytes(stateDiff) if err != nil { - return err + return nil, err } payload := Payload{ StateDiffRlp: stateDiffRlp, @@ -163,19 +168,17 @@ func (sds *Service) processStateDiff(currentBlock, parentBlock *types.Block) err if sds.StreamBlock { blockBuff := new(bytes.Buffer) if err = currentBlock.EncodeRLP(blockBuff); err != nil { - return err + return nil, err } payload.BlockRlp = blockBuff.Bytes() receiptBuff := new(bytes.Buffer) receipts := sds.BlockChain.GetReceiptsByHash(currentBlock.Hash()) if err = rlp.Encode(receiptBuff, receipts); err != nil { - return err + return nil, err } payload.ReceiptsRlp = receiptBuff.Bytes() } - - sds.send(payload) - return nil + return &payload, nil } // Subscribe is used by the API to subscribe to the service loop @@ -269,3 +272,12 @@ func (sds *Service) close() { } sds.Unlock() } + +// StateDiffAt returns a statediff payload at the specific blockheight +// This operation cannot be performed back past the point of db pruning; it requires an archival node +func (sds *Service) StateDiffAt(blockNumber uint64) (*Payload, error) { + currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) + parentBlock := sds.BlockChain.GetBlockByHash(currentBlock.ParentHash()) + log.Info(fmt.Sprintf("sending state diff at %d", blockNumber)) + return sds.processStateDiff(currentBlock, parentBlock) +} diff --git a/statediff/service_test.go b/statediff/service_test.go index 6119f6ecb..81d68344f 100644 --- a/statediff/service_test.go +++ b/statediff/service_test.go @@ -194,3 +194,75 @@ func testErrorInBlockLoop(t *testing.T) { t.Logf("Actual does not equal expected.\nactual:%+v\nexpected: %+v", builder.NewStateRoot, testBlock1.Root()) } } + +func TestGetStateDiffAt(t *testing.T) { + testErrorInStateDiffAt(t) +} + +func testErrorInStateDiffAt(t *testing.T) { + mockStateDiff := statediff.StateDiff{ + BlockNumber: testBlock1.Number(), + BlockHash: testBlock1.Hash(), + } + expectedStateDiffRlp, err := rlp.EncodeToBytes(mockStateDiff) + if err != nil { + t.Error(err) + } + expectedReceiptsRlp, err := rlp.EncodeToBytes(testReceipts1) + if err != nil { + t.Error(err) + } + expectedBlockRlp, err := rlp.EncodeToBytes(testBlock1) + if err != nil { + t.Error(err) + } + expectedStateDiffPayload := statediff.Payload{ + StateDiffRlp: expectedStateDiffRlp, + ReceiptsRlp: expectedReceiptsRlp, + BlockRlp: expectedBlockRlp, + } + expectedStateDiffPayloadRlp, err := rlp.EncodeToBytes(expectedStateDiffPayload) + if err != nil { + t.Error(err) + } + builder := mocks.Builder{} + builder.SetStateDiffToBuild(mockStateDiff) + blockChain := mocks.BlockChain{} + blockMapping := make(map[common.Hash]*types.Block) + blockMapping[parentBlock1.Hash()] = parentBlock1 + blockChain.SetParentBlocksToReturn(blockMapping) + blockChain.SetBlockForNumber(testBlock1, testBlock1.NumberU64()) + blockChain.SetReceiptsForHash(testBlock1.Hash(), testReceipts1) + service := statediff.Service{ + Mutex: sync.Mutex{}, + Builder: &builder, + BlockChain: &blockChain, + QuitChan: make(chan bool), + Subscriptions: make(map[rpc.ID]statediff.Subscription), + StreamBlock: true, + } + stateDiffPayload, err := service.StateDiffAt(testBlock1.NumberU64()) + if err != nil { + t.Error(err) + } + stateDiffPayloadRlp, err := rlp.EncodeToBytes(stateDiffPayload) + if err != nil { + t.Error(err) + } + if !bytes.Equal(builder.BlockHash.Bytes(), testBlock1.Hash().Bytes()) { + t.Error("Test failure:", t.Name()) + t.Logf("Actual does not equal expected.\nactual:%+v\nexpected: %+v", builder.BlockHash, testBlock1.Hash()) + } + if !bytes.Equal(builder.OldStateRoot.Bytes(), parentBlock1.Root().Bytes()) { + t.Error("Test failure:", t.Name()) + t.Logf("Actual does not equal expected.\nactual:%+v\nexpected: %+v", builder.OldStateRoot, parentBlock1.Root()) + } + if !bytes.Equal(builder.NewStateRoot.Bytes(), testBlock1.Root().Bytes()) { + t.Error("Test failure:", t.Name()) + t.Logf("Actual does not equal expected.\nactual:%+v\nexpected: %+v", builder.NewStateRoot, testBlock1.Root()) + } + if !bytes.Equal(expectedStateDiffPayloadRlp, stateDiffPayloadRlp) { + t.Error("Test failure:", t.Name()) + t.Logf("Actual does not equal expected.\nactual:%+v\nexpected: %+v", expectedStateDiffPayload, stateDiffPayload) + } +} diff --git a/statediff/testhelpers/mocks/api.go b/statediff/testhelpers/mocks/api.go index 3b43ab7dd..12ebf0744 100644 --- a/statediff/testhelpers/mocks/api.go +++ b/statediff/testhelpers/mocks/api.go @@ -185,3 +185,8 @@ func (sds *MockStateDiffService) Stop() error { close(sds.QuitChan) return nil } + +// StateDiffAt mock method +func (sds *MockStateDiffService) StateDiffAt(blockNumber uint64) (*statediff.Payload, error) { + panic("implement me") +} diff --git a/statediff/testhelpers/mocks/blockchain.go b/statediff/testhelpers/mocks/blockchain.go index 508435236..4cbf8cb11 100644 --- a/statediff/testhelpers/mocks/blockchain.go +++ b/statediff/testhelpers/mocks/blockchain.go @@ -29,11 +29,12 @@ import ( // BlockChain is a mock blockchain for testing type BlockChain struct { - ParentHashesLookedUp []common.Hash - parentBlocksToReturn map[common.Hash]*types.Block - callCount int - ChainEvents []core.ChainEvent - Receipts map[common.Hash]types.Receipts + ParentHashesLookedUp []common.Hash + parentBlocksToReturn map[common.Hash]*types.Block + parentBlocksToReturnByNumber map[uint64]*types.Block + callCount int + ChainEvents []core.ChainEvent + Receipts map[common.Hash]types.Receipts } // AddToStateDiffProcessedCollection mock method @@ -100,3 +101,16 @@ func (blockChain *BlockChain) SetReceiptsForHash(hash common.Hash, receipts type func (blockChain *BlockChain) GetReceiptsByHash(hash common.Hash) types.Receipts { return blockChain.Receipts[hash] } + +// SetBlockForNumber mock method +func (blockChain *BlockChain) SetBlockForNumber(block *types.Block, number uint64) { + if blockChain.parentBlocksToReturnByNumber == nil { + blockChain.parentBlocksToReturnByNumber = make(map[uint64]*types.Block) + } + blockChain.parentBlocksToReturnByNumber[number] = block +} + +// GetBlockByNumber mock method +func (blockChain *BlockChain) GetBlockByNumber(number uint64) *types.Block { + return blockChain.parentBlocksToReturnByNumber[number] +}