diff --git a/statediff/service/service.go b/statediff/service/service.go index 2b93a1dd1..dbb898915 100644 --- a/statediff/service/service.go +++ b/statediff/service/service.go @@ -49,34 +49,69 @@ func (StateDiffService) APIs() []rpc.API { return []rpc.API{} } -func (sds *StateDiffService) Loop(events chan core.ChainEvent) { - for elem := range events { - currentBlock := elem.Block - parentHash := currentBlock.ParentHash() - parentBlock := sds.BlockChain.GetBlockByHash(parentHash) +func (sds *StateDiffService) Loop(chainEventCh chan core.ChainEvent) { + chainEventSub := sds.BlockChain.SubscribeChainEvent(chainEventCh) + defer chainEventSub.Unsubscribe() - stateDiffLocation, err := sds.Extractor.ExtractStateDiff(*parentBlock, *currentBlock) - if err != nil { - log.Error("Error extracting statediff", "block number", currentBlock.Number(), "error", err) - } else { - log.Info("Statediff extracted", "block number", currentBlock.Number(), "location", stateDiffLocation) + blocksCh := make(chan *types.Block, 10) + errCh := chainEventSub.Err() + quitCh := make(chan struct{}) + + go func() { + HandleChainEventChLoop: + for { + select { + //Notify chain event channel of events + case chainEvent := <-chainEventCh: + log.Debug("Event received from chainEventCh", "event", chainEvent) + blocksCh <- chainEvent.Block + //if node stopped + case err := <-errCh: + log.Warn("Error from chain event subscription, breaking loop.", "error", err) + break HandleChainEventChLoop + } + } + close(quitCh) + }() + + //loop through chain events until no more +HandleBlockChLoop: + for { + select { + case block := <-blocksCh: + currentBlock := block + parentHash := currentBlock.ParentHash() + parentBlock := sds.BlockChain.GetBlockByHash(parentHash) + if parentBlock == nil { + log.Error("Parent block is nil, skipping this block", + "parent block hash", parentHash.String(), + "current block number", currentBlock.Number()) + break HandleBlockChLoop + } + + stateDiffLocation, err := sds.Extractor.ExtractStateDiff(*parentBlock, *currentBlock) + if err != nil { + log.Error("Error extracting statediff", "block number", currentBlock.Number(), "error", err) + } else { + log.Info("Statediff extracted", "block number", currentBlock.Number(), "location", stateDiffLocation) + } + case <-quitCh: + log.Debug("Quitting the statediff block channel") + return } } } -var eventsChannel chan core.ChainEvent - func (sds *StateDiffService) Start(server *p2p.Server) error { log.Info("Starting statediff service") - eventsChannel := make(chan core.ChainEvent, 10) - sds.BlockChain.SubscribeChainEvent(eventsChannel) - go sds.Loop(eventsChannel) + + chainEventCh := make(chan core.ChainEvent, 10) + go sds.Loop(chainEventCh) + return nil } func (StateDiffService) Stop() error { log.Info("Stopping statediff service") - close(eventsChannel) - return nil } diff --git a/statediff/service/service_test.go b/statediff/service/service_test.go index 15b4bc1b2..daf3445c3 100644 --- a/statediff/service/service_test.go +++ b/statediff/service/service_test.go @@ -9,16 +9,17 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" - service2 "github.com/ethereum/go-ethereum/statediff/service" + s "github.com/ethereum/go-ethereum/statediff/service" "github.com/ethereum/go-ethereum/statediff/testhelpers/mocks" ) func TestServiceLoop(t *testing.T) { - testServiceLoop(t) + testErrorInChainEventLoop(t) + testErrorInBlockLoop(t) } var ( - eventsChannel = make(chan core.ChainEvent, 10) + eventsChannel = make(chan core.ChainEvent, 1) parentHeader1 = types.Header{Number: big.NewInt(rand.Int63())} parentHeader2 = types.Header{Number: big.NewInt(rand.Int63())} @@ -31,29 +32,30 @@ var ( header1 = types.Header{ParentHash: parentHash1} header2 = types.Header{ParentHash: parentHash2} + header3 = types.Header{ParentHash: common.HexToHash("parent hash")} block1 = types.NewBlock(&header1, nil, nil, nil) block2 = types.NewBlock(&header2, nil, nil, nil) + block3 = types.NewBlock(&header3, nil, nil, nil) event1 = core.ChainEvent{Block: block1} event2 = core.ChainEvent{Block: block2} + event3 = core.ChainEvent{Block: block3} ) -func testServiceLoop(t *testing.T) { - eventsChannel <- event1 - eventsChannel <- event2 - +func testErrorInChainEventLoop(t *testing.T) { + //the first chain event causes and error (in blockchain mock) extractor := mocks.Extractor{} - close(eventsChannel) blockChain := mocks.BlockChain{} - service := service2.StateDiffService{ + service := s.StateDiffService{ Builder: nil, Extractor: &extractor, BlockChain: &blockChain, } - blockChain.SetParentBlockToReturn([]*types.Block{parentBlock1, parentBlock2}) + blockChain.SetParentBlocksToReturn([]*types.Block{parentBlock1, parentBlock2}) + blockChain.SetChainEvents([]core.ChainEvent{event1, event2, event3}) service.Loop(eventsChannel) //parent and current blocks are passed to the extractor @@ -75,3 +77,31 @@ func testServiceLoop(t *testing.T) { t.Logf("Actual does not equal expected.\nactual:%+v\nexpected: %+v", blockChain.ParentHashesLookedUp, expectedHashes) } } + +func testErrorInBlockLoop(t *testing.T) { + //second block's parent block can't be found + extractor := mocks.Extractor{} + + blockChain := mocks.BlockChain{} + service := s.StateDiffService{ + Builder: nil, + Extractor: &extractor, + BlockChain: &blockChain, + } + + blockChain.SetParentBlocksToReturn([]*types.Block{parentBlock1, nil}) + blockChain.SetChainEvents([]core.ChainEvent{event1, event2}) + service.Loop(eventsChannel) + + //only the first current block (and it's parent) are passed to the extractor + expectedCurrentBlocks := []types.Block{*block1} + if !reflect.DeepEqual(extractor.CurrentBlocks, expectedCurrentBlocks) { + t.Error("Test failure:", t.Name()) + t.Logf("Actual does not equal expected.\nactual:%+v\nexpected: %+v", extractor.CurrentBlocks, expectedCurrentBlocks) + } + expectedParentBlocks := []types.Block{*parentBlock1} + if !reflect.DeepEqual(extractor.ParentBlocks, expectedParentBlocks) { + t.Error("Test failure:", t.Name()) + t.Logf("Actual does not equal expected.\nactual:%+v\nexpected: %+v", extractor.CurrentBlocks, expectedParentBlocks) + } +} diff --git a/statediff/testhelpers/mocks/blockchain.go b/statediff/testhelpers/mocks/blockchain.go index baa7b3cec..f2d77ea34 100644 --- a/statediff/testhelpers/mocks/blockchain.go +++ b/statediff/testhelpers/mocks/blockchain.go @@ -1,6 +1,8 @@ package mocks import ( + "errors" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" @@ -11,24 +13,47 @@ type BlockChain struct { ParentHashesLookedUp []common.Hash parentBlocksToReturn []*types.Block callCount int + ChainEvents []core.ChainEvent } -func (mc *BlockChain) SetParentBlockToReturn(blocks []*types.Block) { +func (mc *BlockChain) SetParentBlocksToReturn(blocks []*types.Block) { mc.parentBlocksToReturn = blocks } func (mc *BlockChain) GetBlockByHash(hash common.Hash) *types.Block { mc.ParentHashesLookedUp = append(mc.ParentHashesLookedUp, hash) - var parentBlock types.Block + var parentBlock *types.Block if len(mc.parentBlocksToReturn) > 0 { - parentBlock = *mc.parentBlocksToReturn[mc.callCount] + parentBlock = mc.parentBlocksToReturn[mc.callCount] } mc.callCount++ - return &parentBlock + return parentBlock } -func (BlockChain) SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription { - panic("implement me") +func (bc *BlockChain) SetChainEvents(chainEvents []core.ChainEvent) { + bc.ChainEvents = chainEvents +} + +func (bc *BlockChain) SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription { + subErr := errors.New("Subscription Error") + + var eventCounter int + subscription := event.NewSubscription(func(quit <-chan struct{}) error { + for _, chainEvent := range bc.ChainEvents { + if eventCounter > 1 { + return subErr + } + select { + case ch <- chainEvent: + case <-quit: + return nil + } + eventCounter++ + } + return nil + }) + + return subscription }