diff --git a/statediff/api.go b/statediff/api.go index b2614d5e7..1c19e312c 100644 --- a/statediff/api.go +++ b/statediff/api.go @@ -178,10 +178,8 @@ func (api *PublicStateDiffAPI) StreamWrites(ctx context.Context) (*rpc.Subscript var err error defer func() { - if err != nil { - if err = api.sds.UnsubscribeWriteStatus(rpcSub.ID); err != nil { - log.Error("Failed to unsubscribe from job status stream: " + err.Error()) - } + if err = api.sds.UnsubscribeWriteStatus(rpcSub.ID); err != nil { + log.Error("Failed to unsubscribe from job status stream: " + err.Error()) } }() // loop and await payloads and relay them to the subscriber with the notifier diff --git a/statediff/service.go b/statediff/service.go index 5a171ef74..738723391 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -890,9 +890,6 @@ func (sds *Service) UnsubscribeWriteStatus(id rpc.ID) error { sds.Lock() close(sds.jobStatusSubs[id].quitChan) delete(sds.jobStatusSubs, id) - if len(sds.jobStatusSubs) == 0 { - sds.jobStatusSubs = nil - } sds.Unlock() return nil } diff --git a/statediff/service_test.go b/statediff/service_test.go index ca5c1116e..ceea79ece 100644 --- a/statediff/service_test.go +++ b/statediff/service_test.go @@ -302,29 +302,26 @@ func TestGetStateDiffAt(t *testing.T) { type writeSub struct { sub *rpc.ClientSubscription statusChan <-chan statediff.JobStatus - client *rpc.Client } -func (ws writeSub) close() { - ws.sub.Unsubscribe() - ws.client.Close() -} - -// awaitStatus awaits status update for writeStateDiffAt job -func subscribeWrites(ctx context.Context, svc *statediff.Service) (writeSub, error) { +func makeClient(svc *statediff.Service) *rpc.Client { server := rpc.NewServer() api := statediff.NewPublicStateDiffAPI(svc) err := server.RegisterName("statediff", api) if err != nil { - return writeSub{}, err + panic(err) } - client := rpc.DialInProc(server) - statusChan := make(chan statediff.JobStatus) - sub, err := client.Subscribe(ctx, "statediff", statusChan, "streamWrites") - return writeSub{sub, statusChan, client}, err + return rpc.DialInProc(server) } -func awaitJob(ws writeSub, job statediff.JobID, timeout time.Duration) (bool, error) { +// awaitStatus awaits status update for writeStateDiffAt job +func subscribeWrites(client *rpc.Client) (writeSub, error) { + statusChan := make(chan statediff.JobStatus) + sub, err := client.Subscribe(context.Background(), "statediff", statusChan, "streamWrites") + return writeSub{sub, statusChan}, err +} + +func (ws writeSub) await(job statediff.JobID, timeout time.Duration) (bool, error) { for { select { case err := <-ws.sub.Err(): @@ -358,13 +355,15 @@ func TestWriteStateDiffAt(t *testing.T) { // delay to avoid subscription request being sent after statediff is written, // and timeout to prevent hanging just in case it still happens writeDelay := 100 * time.Millisecond - jobTimeout := time.Second - ws, err := subscribeWrites(context.Background(), service) + jobTimeout := 200 * time.Millisecond + client := makeClient(service) + defer client.Close() + + ws, err := subscribeWrites(client) require.NoError(t, err) - defer ws.close() time.Sleep(writeDelay) job := service.WriteStateDiffAt(testBlock1.NumberU64(), defaultParams) - ok, err := awaitJob(ws, job, jobTimeout) + ok, err := ws.await(job, jobTimeout) require.NoError(t, err) require.True(t, ok) @@ -372,6 +371,27 @@ func TestWriteStateDiffAt(t *testing.T) { require.Equal(t, testBlock1.Hash(), builder.Args.BlockHash) require.Equal(t, parentBlock1.Root(), builder.Args.OldStateRoot) require.Equal(t, testBlock1.Root(), builder.Args.NewStateRoot) + + // unsubscribe and verify we get nothing + // TODO - StreamWrites receives EOF error after unsubscribing. Doesn't seem to impact + // anything but would be good to know why. + ws.sub.Unsubscribe() + time.Sleep(writeDelay) + job = service.WriteStateDiffAt(testBlock1.NumberU64(), defaultParams) + ok, _ = ws.await(job, jobTimeout) + require.False(t, ok) + + client.Close() + client = makeClient(service) + + // re-subscribe and test again + ws, err = subscribeWrites(client) + require.NoError(t, err) + time.Sleep(writeDelay) + job = service.WriteStateDiffAt(testBlock1.NumberU64(), defaultParams) + ok, err = ws.await(job, jobTimeout) + require.NoError(t, err) + require.True(t, ok) } func TestWaitForSync(t *testing.T) {