diff --git a/paychmgr/manager.go b/paychmgr/manager.go index ea77c67ef..653684cdc 100644 --- a/paychmgr/manager.go +++ b/paychmgr/manager.go @@ -87,6 +87,7 @@ func newManager(pchstore *Store, pchapi managerAPI) (*Manager, error) { channels: make(map[string]*channelAccessor), pchapi: pchapi, } + pm.ctx, pm.shutdown = context.WithCancel(context.Background()) return pm, pm.Start() } diff --git a/paychmgr/mock_test.go b/paychmgr/mock_test.go index 2c891803b..d1325ad31 100644 --- a/paychmgr/mock_test.go +++ b/paychmgr/mock_test.go @@ -136,7 +136,7 @@ func newMockPaychAPI() *mockPaychAPI { func (pchapi *mockPaychAPI) StateWaitMsg(ctx context.Context, mcid cid.Cid, confidence uint64, limit abi.ChainEpoch, allowReplaced bool) (*api.MsgLookup, error) { pchapi.lk.Lock() - response := make(chan types.MessageReceipt) + response := make(chan types.MessageReceipt, 1) if response, ok := pchapi.waitingResponses[mcid]; ok { defer pchapi.lk.Unlock() @@ -151,8 +151,12 @@ func (pchapi *mockPaychAPI) StateWaitMsg(ctx context.Context, mcid cid.Cid, conf pchapi.waitingCalls[mcid] = &waitingCall{response: response} pchapi.lk.Unlock() - receipt := <-response - return &api.MsgLookup{Receipt: receipt}, nil + select { + case receipt := <-response: + return &api.MsgLookup{Receipt: receipt}, nil + case <-ctx.Done(): + return nil, ctx.Err() + } } func (pchapi *mockPaychAPI) receiveMsgResponse(mcid cid.Cid, receipt types.MessageReceipt) { diff --git a/paychmgr/paychget_test.go b/paychmgr/paychget_test.go index 9c5f3b47b..e18639794 100644 --- a/paychmgr/paychget_test.go +++ b/paychmgr/paychget_test.go @@ -631,7 +631,7 @@ func TestPaychGetRestartAfterAddFundsMsg(t *testing.T) { require.NoError(t, err) // Simulate shutting down system - mock.close() + require.NoError(t, mgr.Stop()) // Create a new manager with the same datastore mock2 := newMockManagerAPI() diff --git a/paychmgr/simple.go b/paychmgr/simple.go index 3d0992efe..8e8363ffc 100644 --- a/paychmgr/simple.go +++ b/paychmgr/simple.go @@ -3,6 +3,7 @@ package paychmgr import ( "bytes" "context" + "errors" "fmt" "sort" "sync" @@ -351,6 +352,11 @@ func (ca *channelAccessor) queueSize() int { // msgWaitComplete is called when the message for a previous task is confirmed // or there is an error. func (ca *channelAccessor) msgWaitComplete(ctx context.Context, mcid cid.Cid, err error) { + // if context is canceled, should Not mark message to 'bad', just return. + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return + } + ca.lk.Lock() defer ca.lk.Unlock()