diff --git a/paychmgr/paychget_test.go b/paychmgr/paychget_test.go index 3dd061259..ab4c65c77 100644 --- a/paychmgr/paychget_test.go +++ b/paychmgr/paychget_test.go @@ -4,6 +4,7 @@ import ( "context" "sync" "testing" + "time" cborrpc "github.com/filecoin-project/go-cbor-util" @@ -741,3 +742,292 @@ func TestPaychGetWaitCtx(t *testing.T) { _, err = mgr.GetPaychWaitReady(ctx, mcid) require.Error(t, ctx.Err(), err) } + +// TestPaychGetMergeAddFunds tests that if a create channel is in +// progress and two add funds are queued up behind it, the two add funds +// will be merged +func TestPaychGetMergeAddFunds(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel with value 10 + createAmt := big.NewInt(10) + _, createMsgCid, err := mgr.GetPaych(ctx, from, to, createAmt) + require.NoError(t, err) + + // Queue up two add funds requests behind create channel + //var addFundsQueuedUp sync.WaitGroup + //addFundsQueuedUp.Add(2) + var addFundsSent sync.WaitGroup + addFundsSent.Add(2) + + addFundsAmt1 := big.NewInt(5) + addFundsAmt2 := big.NewInt(3) + var addFundsCh1 address.Address + var addFundsCh2 address.Address + var addFundsMcid1 cid.Cid + var addFundsMcid2 cid.Cid + go func() { + //go addFundsQueuedUp.Done() + defer addFundsSent.Done() + + // Request add funds - should block until create channel has completed + addFundsCh1, addFundsMcid1, err = mgr.GetPaych(ctx, from, to, addFundsAmt1) + require.NoError(t, err) + }() + + go func() { + //go addFundsQueuedUp.Done() + defer addFundsSent.Done() + + // Request add funds again - should merge with waiting add funds request + addFundsCh2, addFundsMcid2, err = mgr.GetPaych(ctx, from, to, addFundsAmt2) + require.NoError(t, err) + }() + // Wait for add funds requests to be queued up + waitForQueueSize(t, mgr, from, to, 2) + + // Send create channel response + response := testChannelResponse(t, ch) + pchapi.receiveMsgResponse(createMsgCid, response) + + // Wait for create channel response + chres, err := mgr.GetPaychWaitReady(ctx, createMsgCid) + require.NoError(t, err) + require.Equal(t, ch, chres) + + // Wait for add funds requests to be sent + addFundsSent.Wait() + + // Expect add funds requests to have same channel as create channel and + // same message cid as each other (because they should have been merged) + require.Equal(t, ch, addFundsCh1) + require.Equal(t, ch, addFundsCh2) + require.Equal(t, addFundsMcid1, addFundsMcid2) + + // Send success add funds response + pchapi.receiveMsgResponse(addFundsMcid1, types.MessageReceipt{ + ExitCode: 0, + Return: []byte{}, + }) + + // Wait for add funds response + addFundsCh, err := mgr.GetPaychWaitReady(ctx, addFundsMcid1) + require.NoError(t, err) + require.Equal(t, ch, addFundsCh) + + // Make sure that one create channel message and one add funds message was + // sent + require.Equal(t, 2, pchapi.pushedMessageCount()) + + // Check create message amount is correct + createMsg := pchapi.pushedMessages(createMsgCid) + require.Equal(t, from, createMsg.Message.From) + require.Equal(t, builtin.InitActorAddr, createMsg.Message.To) + require.Equal(t, createAmt, createMsg.Message.Value) + + // Check merged add funds amount is the sum of the individual + // amounts + addFundsMsg := pchapi.pushedMessages(addFundsMcid1) + require.Equal(t, from, addFundsMsg.Message.From) + require.Equal(t, ch, addFundsMsg.Message.To) + require.Equal(t, types.BigAdd(addFundsAmt1, addFundsAmt2), addFundsMsg.Message.Value) +} + +// TestPaychGetMergeAddFundsCtxCancelOne tests that when a queued add funds +// request is cancelled, its amount is removed from the total merged add funds +func TestPaychGetMergeAddFundsCtxCancelOne(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel with value 10 + createAmt := big.NewInt(10) + _, createMsgCid, err := mgr.GetPaych(ctx, from, to, createAmt) + require.NoError(t, err) + + // Queue up two add funds requests behind create channel + var addFundsSent sync.WaitGroup + addFundsSent.Add(2) + + addFundsAmt1 := big.NewInt(5) + addFundsAmt2 := big.NewInt(3) + var addFundsCh2 address.Address + var addFundsMcid2 cid.Cid + var addFundsErr1 error + addFundsCtx1, cancelAddFundsCtx1 := context.WithCancel(ctx) + go func() { + defer addFundsSent.Done() + + // Request add funds - should block until create channel has completed + _, _, addFundsErr1 = mgr.GetPaych(addFundsCtx1, from, to, addFundsAmt1) + }() + + go func() { + defer addFundsSent.Done() + + // Request add funds again - should merge with waiting add funds request + addFundsCh2, addFundsMcid2, err = mgr.GetPaych(ctx, from, to, addFundsAmt2) + require.NoError(t, err) + }() + // Wait for add funds requests to be queued up + waitForQueueSize(t, mgr, from, to, 2) + + // Cancel the first add funds request + cancelAddFundsCtx1() + + // Send create channel response + response := testChannelResponse(t, ch) + pchapi.receiveMsgResponse(createMsgCid, response) + + // Wait for create channel response + chres, err := mgr.GetPaychWaitReady(ctx, createMsgCid) + require.NoError(t, err) + require.Equal(t, ch, chres) + + // Wait for add funds requests to be sent + addFundsSent.Wait() + + // Expect first add funds request to have been cancelled + require.NotNil(t, addFundsErr1) + require.Equal(t, ch, addFundsCh2) + + // Send success add funds response + pchapi.receiveMsgResponse(addFundsMcid2, types.MessageReceipt{ + ExitCode: 0, + Return: []byte{}, + }) + + // Wait for add funds response + addFundsCh, err := mgr.GetPaychWaitReady(ctx, addFundsMcid2) + require.NoError(t, err) + require.Equal(t, ch, addFundsCh) + + // Make sure that one create channel message and one add funds message was + // sent + require.Equal(t, 2, pchapi.pushedMessageCount()) + + // Check create message amount is correct + createMsg := pchapi.pushedMessages(createMsgCid) + require.Equal(t, from, createMsg.Message.From) + require.Equal(t, builtin.InitActorAddr, createMsg.Message.To) + require.Equal(t, createAmt, createMsg.Message.Value) + + // Check merged add funds amount only includes the second add funds amount + // (because first was cancelled) + addFundsMsg := pchapi.pushedMessages(addFundsMcid2) + require.Equal(t, from, addFundsMsg.Message.From) + require.Equal(t, ch, addFundsMsg.Message.To) + require.Equal(t, addFundsAmt2, addFundsMsg.Message.Value) +} + +// TestPaychGetMergeAddFundsCtxCancelAll tests that when all queued add funds +// requests are cancelled, no add funds message is sent +func TestPaychGetMergeAddFundsCtxCancelAll(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel with value 10 + createAmt := big.NewInt(10) + _, createMsgCid, err := mgr.GetPaych(ctx, from, to, createAmt) + require.NoError(t, err) + + // Queue up two add funds requests behind create channel + var addFundsSent sync.WaitGroup + addFundsSent.Add(2) + + var addFundsErr1 error + var addFundsErr2 error + addFundsCtx1, cancelAddFundsCtx1 := context.WithCancel(ctx) + addFundsCtx2, cancelAddFundsCtx2 := context.WithCancel(ctx) + go func() { + defer addFundsSent.Done() + + // Request add funds - should block until create channel has completed + _, _, addFundsErr1 = mgr.GetPaych(addFundsCtx1, from, to, big.NewInt(5)) + }() + + go func() { + defer addFundsSent.Done() + + // Request add funds again - should merge with waiting add funds request + _, _, addFundsErr2 = mgr.GetPaych(addFundsCtx2, from, to, big.NewInt(3)) + require.NoError(t, err) + }() + // Wait for add funds requests to be queued up + waitForQueueSize(t, mgr, from, to, 2) + + // Cancel all add funds requests + cancelAddFundsCtx1() + cancelAddFundsCtx2() + + // Send create channel response + response := testChannelResponse(t, ch) + pchapi.receiveMsgResponse(createMsgCid, response) + + // Wait for create channel response + chres, err := mgr.GetPaychWaitReady(ctx, createMsgCid) + require.NoError(t, err) + require.Equal(t, ch, chres) + + // Wait for add funds requests to error out + addFundsSent.Wait() + + require.NotNil(t, addFundsErr1) + require.NotNil(t, addFundsErr2) + + // Make sure that just the create channel message was sent + require.Equal(t, 1, pchapi.pushedMessageCount()) + + // Check create message amount is correct + createMsg := pchapi.pushedMessages(createMsgCid) + require.Equal(t, from, createMsg.Message.From) + require.Equal(t, builtin.InitActorAddr, createMsg.Message.To) + require.Equal(t, createAmt, createMsg.Message.Value) +} + +// waitForQueueSize waits for the funds request queue to be of the given size +func waitForQueueSize(t *testing.T, mgr *Manager, from address.Address, to address.Address, size int) { + ca, err := mgr.accessorByFromTo(from, to) + require.NoError(t, err) + + for { + if ca.queueSize() == size { + return + } + + time.Sleep(time.Millisecond) + } +} diff --git a/paychmgr/simple.go b/paychmgr/simple.go index 61c0a8fb1..12ff40d82 100644 --- a/paychmgr/simple.go +++ b/paychmgr/simple.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "sync" "golang.org/x/sync/errgroup" @@ -36,15 +37,141 @@ type paychFundsRes struct { err error } -type onCompleteFn func(*paychFundsRes) - // fundsReq is a request to create a channel or add funds to a channel type fundsReq struct { - ctx context.Context - from address.Address - to address.Address - amt types.BigInt - onComplete onCompleteFn + ctx context.Context + promise chan *paychFundsRes + from address.Address + to address.Address + amt types.BigInt + + lk sync.Mutex + // merge parent, if this req is part of a merge + merge *mergedFundsReq + // whether the req's context has been cancelled + active bool +} + +func newFundsReq(ctx context.Context, from address.Address, to address.Address, amt types.BigInt) *fundsReq { + promise := make(chan *paychFundsRes) + return &fundsReq{ + ctx: ctx, + promise: promise, + from: from, + to: to, + amt: amt, + active: true, + } +} + +// onComplete is called when the funds request has been executed +func (r *fundsReq) onComplete(res *paychFundsRes) { + select { + case <-r.ctx.Done(): + case r.promise <- res: + } +} + +// cancel is called when the req's context is cancelled +func (r *fundsReq) cancel() { + r.lk.Lock() + + r.active = false + m := r.merge + + r.lk.Unlock() + + // If there's a merge parent, tell the merge parent to check if it has any + // active reqs left + if m != nil { + m.checkActive() + } +} + +// isActive indicates whether the req's context has been cancelled +func (r *fundsReq) isActive() bool { + r.lk.Lock() + defer r.lk.Unlock() + + return r.active +} + +// setMergeParent sets the merge that this req is part of +func (r *fundsReq) setMergeParent(m *mergedFundsReq) { + r.lk.Lock() + defer r.lk.Unlock() + + r.merge = m +} + +// mergedFundsReq merges together multiple add funds requests that are queued +// up, so that only one message is sent for all the requests (instead of one +// message for each request) +type mergedFundsReq struct { + ctx context.Context + cancel context.CancelFunc + reqs []*fundsReq +} + +func newMergedFundsReq(reqs []*fundsReq) *mergedFundsReq { + ctx, cancel := context.WithCancel(context.Background()) + m := &mergedFundsReq{ + ctx: ctx, + cancel: cancel, + reqs: reqs, + } + + for _, r := range m.reqs { + r.setMergeParent(m) + } + + // If the requests were all cancelled while being added, cancel the context + // immediately + m.checkActive() + + return m +} + +// Called when a fundsReq is cancelled +func (m *mergedFundsReq) checkActive() { + // Check if there are any active fundsReqs + for _, r := range m.reqs { + if r.isActive() { + return + } + } + + // If all fundsReqs have been cancelled, cancel the context + m.cancel() +} + +// onComplete is called when the queue has executed the mergeFundsReq. +// Calls onComplete on each fundsReq in the mergeFundsReq. +func (m *mergedFundsReq) onComplete(res *paychFundsRes) { + for _, r := range m.reqs { + if r.isActive() { + r.onComplete(res) + } + } +} + +func (m *mergedFundsReq) from() address.Address { + return m.reqs[0].from +} + +func (m *mergedFundsReq) to() address.Address { + return m.reqs[0].to +} + +// sum is the sum of the amounts in all requests in the merge +func (m *mergedFundsReq) sum() types.BigInt { + sum := types.NewInt(0) + for _, r := range m.reqs { + if r.isActive() { + sum = types.BigAdd(sum, r.amt) + } + } + return sum } // getPaych ensures that a channel exists between the from and to addresses, @@ -60,65 +187,97 @@ type fundsReq struct { // be attempted. func (ca *channelAccessor) getPaych(ctx context.Context, from, to address.Address, amt types.BigInt) (address.Address, cid.Cid, error) { // Add the request to add funds to a queue and wait for the result - promise := ca.enqueue(&fundsReq{ctx: ctx, from: from, to: to, amt: amt}) + freq := newFundsReq(ctx, from, to, amt) + ca.enqueue(freq) select { - case res := <-promise: + case res := <-freq.promise: return res.channel, res.mcid, res.err case <-ctx.Done(): + freq.cancel() return address.Undef, cid.Undef, ctx.Err() } } -// Queue up an add funds operation -func (ca *channelAccessor) enqueue(task *fundsReq) chan *paychFundsRes { - promise := make(chan *paychFundsRes) - task.onComplete = func(res *paychFundsRes) { - select { - case <-task.ctx.Done(): - case promise <- res: - } - } - +// Queue up an add funds operations +func (ca *channelAccessor) enqueue(task *fundsReq) { ca.lk.Lock() defer ca.lk.Unlock() ca.fundsReqQueue = append(ca.fundsReqQueue, task) - go ca.processNextQueueItem() - - return promise + go ca.processQueue() } -// Run the operation at the head of the queue -func (ca *channelAccessor) processNextQueueItem() { +// Run the operations in the queue +func (ca *channelAccessor) processQueue() { ca.lk.Lock() defer ca.lk.Unlock() + // Remove cancelled requests + ca.filterQueue() + + // If there's nothing in the queue, bail out if len(ca.fundsReqQueue) == 0 { return } - head := ca.fundsReqQueue[0] - res := ca.processTask(head.ctx, head.from, head.to, head.amt) + // Merge all pending requests into one. + // For example if there are pending requests for 3, 2, 4 then + // amt = 3 + 2 + 4 = 9 + merged := newMergedFundsReq(ca.fundsReqQueue[:]) + amt := merged.sum() + if amt.IsZero() { + // Note: The amount can be zero if requests are cancelled as we're + // building the mergedFundsReq + return + } + + res := ca.processTask(merged.ctx, merged.from(), merged.to(), amt) // If the task is waiting on an external event (eg something to appear on // chain) it will return nil if res == nil { // Stop processing the fundsReqQueue and wait. When the event occurs it will - // call processNextQueueItem() again + // call processQueue() again return } - // The task has finished processing so clean it up - ca.fundsReqQueue[0] = nil // allow GC of element - ca.fundsReqQueue = ca.fundsReqQueue[1:] + // Finished processing so clear the queue + ca.fundsReqQueue = nil // Call the task callback with its results - head.onComplete(res) + merged.onComplete(res) +} - // Process the next task - if len(ca.fundsReqQueue) > 0 { - go ca.processNextQueueItem() +// filterQueue filters cancelled requests out of the queue +func (ca *channelAccessor) filterQueue() { + if len(ca.fundsReqQueue) == 0 { + return } + + // Remove cancelled requests + i := 0 + for _, r := range ca.fundsReqQueue { + if r.isActive() { + ca.fundsReqQueue[i] = r + i++ + } + } + + // Allow GC of remaining slice elements + for rem := i; rem < len(ca.fundsReqQueue); rem++ { + ca.fundsReqQueue[i] = nil + } + + // Resize slice + ca.fundsReqQueue = ca.fundsReqQueue[:i] +} + +// queueSize is the size of the funds request queue (used by tests) +func (ca *channelAccessor) queueSize() int { + ca.lk.Lock() + defer ca.lk.Unlock() + + return len(ca.fundsReqQueue) } // msgWaitComplete is called when the message for a previous task is confirmed @@ -139,7 +298,7 @@ func (ca *channelAccessor) msgWaitComplete(mcid cid.Cid, err error) { // The queue may have been waiting for msg completion to proceed, so // process the next queue item if len(ca.fundsReqQueue) > 0 { - go ca.processNextQueueItem() + go ca.processQueue() } } diff --git a/paychmgr/store.go b/paychmgr/store.go index d7c6e82e7..62c4cf9b2 100644 --- a/paychmgr/store.go +++ b/paychmgr/store.go @@ -328,8 +328,7 @@ func (ps *Store) ByChannelID(channelID string) (*ChannelInfo, error) { return unmarshallChannelInfo(&stored, res) } -// CreateChannel creates an outbound channel for the given from / to, ensuring -// it has a higher sequence number than any existing channel with the same from / to +// CreateChannel creates an outbound channel for the given from / to func (ps *Store) CreateChannel(from address.Address, to address.Address, createMsgCid cid.Cid, amt types.BigInt) (*ChannelInfo, error) { ci := &ChannelInfo{ Direction: DirOutbound,