From b6dc2d509e14eb17e5a68743c144807264555bdb Mon Sep 17 00:00:00 2001 From: Elizabeth Engelman Date: Mon, 28 Jan 2019 12:47:52 -0600 Subject: [PATCH] Gracefully exit geth command --- statediff/service/service.go | 67 +++++++++++++++++------ statediff/service/service_test.go | 11 ++-- statediff/testhelpers/mocks/blockchain.go | 28 +++++++++- 3 files changed, 81 insertions(+), 25 deletions(-) diff --git a/statediff/service/service.go b/statediff/service/service.go index 2b93a1dd1..065b0ea18 100644 --- a/statediff/service/service.go +++ b/statediff/service/service.go @@ -49,34 +49,67 @@ func (StateDiffService) APIs() []rpc.API { return []rpc.API{} } -func (sds *StateDiffService) Loop(events chan core.ChainEvent) { - for elem := range events { - currentBlock := elem.Block - parentHash := currentBlock.ParentHash() - parentBlock := sds.BlockChain.GetBlockByHash(parentHash) +func (sds *StateDiffService) Loop(chainEventCh chan core.ChainEvent) { + chainEventSub := sds.BlockChain.SubscribeChainEvent(chainEventCh) + defer chainEventSub.Unsubscribe() - stateDiffLocation, err := sds.Extractor.ExtractStateDiff(*parentBlock, *currentBlock) - if err != nil { - log.Error("Error extracting statediff", "block number", currentBlock.Number(), "error", err) - } else { - log.Info("Statediff extracted", "block number", currentBlock.Number(), "location", stateDiffLocation) + blocksCh := make(chan *types.Block) + errCh := chainEventSub.Err() + quitCh := make(chan struct{}) + + go func() { + HandleLoop: + for { + select { + //Notify chain event channel of events + case chainEvent := <-chainEventCh: + log.Debug("Event received from chainEventCh", "event", chainEvent) + blocksCh <- chainEvent.Block + //if node stopped + case err := <-errCh: + log.Debug("Error from chain event subscription, breaking loop.", "error", err) + break HandleLoop + } + } + close(quitCh) + }() + + //loop through chain events until no more + for { + select { + case block := <-blocksCh: + currentBlock := block + parentHash := currentBlock.ParentHash() + parentBlock := sds.BlockChain.GetBlockByHash(parentHash) + if parentBlock == nil { + log.Warn("Parent block is nil, skipping this block", + "parent block hash", parentHash.String(), + "current block number", currentBlock.Number()) + break + } + + stateDiffLocation, err := sds.Extractor.ExtractStateDiff(*parentBlock, *currentBlock) + if err != nil { + log.Error("Error extracting statediff", "block number", currentBlock.Number(), "error", err) + } else { + log.Info("Statediff extracted", "block number", currentBlock.Number(), "location", stateDiffLocation) + } + case <-quitCh: + return } } } -var eventsChannel chan core.ChainEvent - func (sds *StateDiffService) Start(server *p2p.Server) error { log.Info("Starting statediff service") - eventsChannel := make(chan core.ChainEvent, 10) - sds.BlockChain.SubscribeChainEvent(eventsChannel) - go sds.Loop(eventsChannel) + + chainEventCh := make(chan core.ChainEvent, 10) + go sds.Loop(chainEventCh) + return nil } func (StateDiffService) Stop() error { log.Info("Stopping statediff service") - close(eventsChannel) - return nil } diff --git a/statediff/service/service_test.go b/statediff/service/service_test.go index 15b4bc1b2..c4cbc186a 100644 --- a/statediff/service/service_test.go +++ b/statediff/service/service_test.go @@ -9,7 +9,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" - service2 "github.com/ethereum/go-ethereum/statediff/service" + s "github.com/ethereum/go-ethereum/statediff/service" "github.com/ethereum/go-ethereum/statediff/testhelpers/mocks" ) @@ -18,7 +18,7 @@ func TestServiceLoop(t *testing.T) { } var ( - eventsChannel = make(chan core.ChainEvent, 10) + eventsChannel = make(chan core.ChainEvent) parentHeader1 = types.Header{Number: big.NewInt(rand.Int63())} parentHeader2 = types.Header{Number: big.NewInt(rand.Int63())} @@ -40,20 +40,19 @@ var ( ) func testServiceLoop(t *testing.T) { - eventsChannel <- event1 - eventsChannel <- event2 extractor := mocks.Extractor{} - close(eventsChannel) + //close(eventsChannel) blockChain := mocks.BlockChain{} - service := service2.StateDiffService{ + service := s.StateDiffService{ Builder: nil, Extractor: &extractor, BlockChain: &blockChain, } blockChain.SetParentBlockToReturn([]*types.Block{parentBlock1, parentBlock2}) + blockChain.SetChainEvents([]core.ChainEvent{event1, event2}) service.Loop(eventsChannel) //parent and current blocks are passed to the extractor diff --git a/statediff/testhelpers/mocks/blockchain.go b/statediff/testhelpers/mocks/blockchain.go index baa7b3cec..0ea92fa9e 100644 --- a/statediff/testhelpers/mocks/blockchain.go +++ b/statediff/testhelpers/mocks/blockchain.go @@ -5,12 +5,14 @@ import ( "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/event" + "errors" ) type BlockChain struct { ParentHashesLookedUp []common.Hash parentBlocksToReturn []*types.Block callCount int + ChainEvents []core.ChainEvent } func (mc *BlockChain) SetParentBlockToReturn(blocks []*types.Block) { @@ -29,6 +31,28 @@ func (mc *BlockChain) GetBlockByHash(hash common.Hash) *types.Block { return &parentBlock } -func (BlockChain) SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription { - panic("implement me") +func (bc *BlockChain) SetChainEvents(chainEvents []core.ChainEvent) { + bc.ChainEvents = chainEvents +} + +func (bc *BlockChain) SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription { + subErr := errors.New("Subscription Error") + + var eventCounter int + subscription := event.NewSubscription(func(quit <-chan struct{}) error { + for _, chainEvent := range bc.ChainEvents { + if eventCounter > 1 { + return subErr + } + select { + case ch <- chainEvent: + case <-quit: + return nil + } + eventCounter++ + } + return nil + }) + + return subscription }