diff --git a/go.mod b/go.mod index 51ca85830..89460674e 100644 --- a/go.mod +++ b/go.mod @@ -44,6 +44,7 @@ require ( github.com/google/uuid v1.1.1 github.com/gorilla/mux v1.7.4 github.com/gorilla/websocket v1.4.2 + github.com/hannahhoward/go-pubsub v0.0.0-20200423002714-8d62886cc36e github.com/hashicorp/go-multierror v1.1.0 github.com/hashicorp/golang-lru v0.5.4 github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d diff --git a/paychmgr/msglistener.go b/paychmgr/msglistener.go index 0a38cc2da..d1204e486 100644 --- a/paychmgr/msglistener.go +++ b/paychmgr/msglistener.go @@ -1,61 +1,56 @@ package paychmgr import ( - "sync" + "golang.org/x/xerrors" + + "github.com/hannahhoward/go-pubsub" - "github.com/google/uuid" "github.com/ipfs/go-cid" ) -type msgListener struct { - id string - cb func(c cid.Cid, err error) -} - type msgListeners struct { - lk sync.Mutex - listeners []*msgListener + ps *pubsub.PubSub } -func (ml *msgListeners) onMsg(mcid cid.Cid, cb func(error)) string { - ml.lk.Lock() - defer ml.lk.Unlock() - - l := &msgListener{ - id: uuid.New().String(), - cb: func(c cid.Cid, err error) { - if mcid.Equals(c) { - cb(err) - } - }, - } - ml.listeners = append(ml.listeners, l) - return l.id +type msgCompleteEvt struct { + mcid cid.Cid + err error } -func (ml *msgListeners) fireMsgComplete(mcid cid.Cid, err error) { - ml.lk.Lock() - defer ml.lk.Unlock() +type subscriberFn func(msgCompleteEvt) - for _, l := range ml.listeners { - l.cb(mcid, err) - } +func newMsgListeners() msgListeners { + ps := pubsub.New(func(event pubsub.Event, subFn pubsub.SubscriberFn) error { + evt, ok := event.(msgCompleteEvt) + if !ok { + return xerrors.Errorf("wrong type of event") + } + sub, ok := subFn.(subscriberFn) + if !ok { + return xerrors.Errorf("wrong type of subscriber") + } + sub(evt) + return nil + }) + return msgListeners{ps: ps} } -func (ml *msgListeners) unsubscribe(sub string) { - ml.lk.Lock() - defer ml.lk.Unlock() - - for i, l := range ml.listeners { - if l.id == sub { - ml.removeListener(i) - return +// onMsgComplete registers a callback for when the message with the given cid +// completes +func (ml *msgListeners) onMsgComplete(mcid cid.Cid, cb func(error)) pubsub.Unsubscribe { + var fn subscriberFn = func(evt msgCompleteEvt) { + if mcid.Equals(evt.mcid) { + cb(evt.err) } } + return ml.ps.Subscribe(fn) } -func (ml *msgListeners) removeListener(i int) { - copy(ml.listeners[i:], ml.listeners[i+1:]) - ml.listeners[len(ml.listeners)-1] = nil - ml.listeners = ml.listeners[:len(ml.listeners)-1] +// fireMsgComplete is called when a message completes +func (ml *msgListeners) fireMsgComplete(mcid cid.Cid, err error) { + e := ml.ps.Publish(msgCompleteEvt{mcid: mcid, err: err}) + if e != nil { + // In theory we shouldn't ever get an error here + log.Errorf("unexpected error publishing message complete: %s", e) + } } diff --git a/paychmgr/msglistener_test.go b/paychmgr/msglistener_test.go index fd457a518..2c3ae16e4 100644 --- a/paychmgr/msglistener_test.go +++ b/paychmgr/msglistener_test.go @@ -17,12 +17,12 @@ func testCids() []cid.Cid { } func TestMsgListener(t *testing.T) { - var ml msgListeners + ml := newMsgListeners() done := false experr := xerrors.Errorf("some err") cids := testCids() - ml.onMsg(cids[0], func(err error) { + ml.onMsgComplete(cids[0], func(err error) { require.Equal(t, experr, err) done = true }) @@ -35,11 +35,11 @@ func TestMsgListener(t *testing.T) { } func TestMsgListenerNilErr(t *testing.T) { - var ml msgListeners + ml := newMsgListeners() done := false cids := testCids() - ml.onMsg(cids[0], func(err error) { + ml.onMsgComplete(cids[0], func(err error) { require.Nil(t, err) done = true }) @@ -52,20 +52,20 @@ func TestMsgListenerNilErr(t *testing.T) { } func TestMsgListenerUnsub(t *testing.T) { - var ml msgListeners + ml := newMsgListeners() done := false experr := xerrors.Errorf("some err") cids := testCids() - id1 := ml.onMsg(cids[0], func(err error) { + unsub := ml.onMsgComplete(cids[0], func(err error) { t.Fatal("should not call unsubscribed listener") }) - ml.onMsg(cids[0], func(err error) { + ml.onMsgComplete(cids[0], func(err error) { require.Equal(t, experr, err) done = true }) - ml.unsubscribe(id1) + unsub() ml.fireMsgComplete(cids[0], experr) if !done { @@ -74,17 +74,17 @@ func TestMsgListenerUnsub(t *testing.T) { } func TestMsgListenerMulti(t *testing.T) { - var ml msgListeners + ml := newMsgListeners() count := 0 cids := testCids() - ml.onMsg(cids[0], func(err error) { + ml.onMsgComplete(cids[0], func(err error) { count++ }) - ml.onMsg(cids[0], func(err error) { + ml.onMsgComplete(cids[0], func(err error) { count++ }) - ml.onMsg(cids[1], func(err error) { + ml.onMsgComplete(cids[1], func(err error) { count++ }) diff --git a/paychmgr/paych.go b/paychmgr/paych.go index ba2018e96..f0d347e0b 100644 --- a/paychmgr/paych.go +++ b/paychmgr/paych.go @@ -35,12 +35,13 @@ type channelAccessor struct { func newChannelAccessor(pm *Manager) *channelAccessor { return &channelAccessor{ - lk: &channelLock{globalLock: &pm.lk}, - sm: pm.sm, - sa: &stateAccessor{sm: pm.sm}, - api: pm.pchapi, - store: pm.store, - waitCtx: pm.ctx, + lk: &channelLock{globalLock: &pm.lk}, + sm: pm.sm, + sa: &stateAccessor{sm: pm.sm}, + api: pm.pchapi, + store: pm.store, + msgListeners: newMsgListeners(), + waitCtx: pm.ctx, } } diff --git a/paychmgr/simple.go b/paychmgr/simple.go index 12ff40d82..ea65b48b2 100644 --- a/paychmgr/simple.go +++ b/paychmgr/simple.go @@ -640,7 +640,7 @@ type onMsgRes struct { func (ca *channelAccessor) msgPromise(ctx context.Context, mcid cid.Cid) chan onMsgRes { promise := make(chan onMsgRes) triggerUnsub := make(chan struct{}) - sub := ca.msgListeners.onMsg(mcid, func(err error) { + unsub := ca.msgListeners.onMsgComplete(mcid, func(err error) { close(triggerUnsub) // Use a go-routine so as not to block the event handler loop @@ -671,7 +671,7 @@ func (ca *channelAccessor) msgPromise(ctx context.Context, mcid cid.Cid) chan on case <-triggerUnsub: } - ca.msgListeners.unsubscribe(sub) + unsub() }() return promise