refactor: simplify chain event Called API

This commit is contained in:
Dirk McCormick 2020-10-30 14:00:32 +01:00
parent c3d00b0ac6
commit 79a8ff04fd
6 changed files with 45 additions and 54 deletions

View File

@ -459,7 +459,7 @@ type messageEvents struct {
hcAPI headChangeAPI
lk sync.RWMutex
matchers map[triggerID][]MsgMatchFunc
matchers map[triggerID]MsgMatchFunc
}
func newMessageEvents(ctx context.Context, hcAPI headChangeAPI, cs eventAPI) messageEvents {
@ -467,7 +467,7 @@ func newMessageEvents(ctx context.Context, hcAPI headChangeAPI, cs eventAPI) mes
ctx: ctx,
cs: cs,
hcAPI: hcAPI,
matchers: map[triggerID][]MsgMatchFunc{},
matchers: make(map[triggerID]MsgMatchFunc),
}
}
@ -482,32 +482,23 @@ func (me *messageEvents) checkNewCalls(ts *types.TipSet) (map[triggerID]eventDat
me.lk.RLock()
defer me.lk.RUnlock()
// For each message in the tipset
res := make(map[triggerID]eventData)
me.messagesForTs(pts, func(msg *types.Message) {
// TODO: provide receipts
for tid, matchFns := range me.matchers {
var matched bool
var once bool
for _, matchFn := range matchFns {
matchOne, ok, err := matchFn(msg)
if err != nil {
log.Errorf("event matcher failed: %s", err)
continue
}
matched = ok
once = matchOne
if matched {
break
}
// Run each trigger's matcher against the message
for tid, matchFn := range me.matchers {
matched, err := matchFn(msg)
if err != nil {
log.Errorf("event matcher failed: %s", err)
continue
}
// If there was a match, include the message in the results for the
// trigger
if matched {
res[tid] = msg
if once {
break
}
}
}
})
@ -555,7 +546,7 @@ func (me *messageEvents) messagesForTs(ts *types.TipSet, consume func(*types.Mes
// `curH`-`ts.Height` = `confidence`
type MsgHandler func(msg *types.Message, rec *types.MessageReceipt, ts *types.TipSet, curH abi.ChainEpoch) (more bool, err error)
type MsgMatchFunc func(msg *types.Message) (matchOnce bool, matched bool, err error)
type MsgMatchFunc func(msg *types.Message) (matched bool, err error)
// Called registers a callback which is triggered when a specified method is
// called on an actor, or a timeout is reached.
@ -607,7 +598,7 @@ func (me *messageEvents) Called(check CheckFunc, msgHnd MsgHandler, rev RevertHa
me.lk.Lock()
defer me.lk.Unlock()
me.matchers[id] = append(me.matchers[id], mf)
me.matchers[id] = mf
return nil
}

View File

@ -572,9 +572,9 @@ func TestAtChainedConfidenceNull(t *testing.T) {
require.Equal(t, false, reverted)
}
func matchAddrMethod(to address.Address, m abi.MethodNum) func(msg *types.Message) (matchOnce bool, matched bool, err error) {
return func(msg *types.Message) (matchOnce bool, matched bool, err error) {
return true, to == msg.To && m == msg.Method, nil
func matchAddrMethod(to address.Address, m abi.MethodNum) func(msg *types.Message) (matched bool, err error) {
return func(msg *types.Message) (matched bool, err error) {
return to == msg.To && m == msg.Method, nil
}
}

View File

@ -34,11 +34,11 @@ func (me *messageEvents) CheckMsg(ctx context.Context, smsg types.ChainMsg, hnd
}
func (me *messageEvents) MatchMsg(inmsg *types.Message) MsgMatchFunc {
return func(msg *types.Message) (matchOnce bool, matched bool, err error) {
return func(msg *types.Message) (matched bool, err error) {
if msg.From == inmsg.From && msg.Nonce == inmsg.Nonce && !inmsg.Equals(msg) {
return true, false, xerrors.Errorf("matching msg %s from %s, nonce %d: got duplicate origin/nonce msg %d", inmsg.Cid(), inmsg.From, inmsg.Nonce, msg.Nonce)
return false, xerrors.Errorf("matching msg %s from %s, nonce %d: got duplicate origin/nonce msg %d", inmsg.Cid(), inmsg.From, inmsg.Nonce, msg.Nonce)
}
return true, inmsg.Equals(msg), nil
return inmsg.Equals(msg), nil
}
}

View File

@ -263,44 +263,44 @@ func (c *ClientNodeAdapter) OnDealSectorCommitted(ctx context.Context, provider
var sectorNumber abi.SectorNumber
var sectorFound bool
matchEvent := func(msg *types.Message) (matchOnce bool, matched bool, err error) {
matchEvent := func(msg *types.Message) (matched bool, err error) {
if msg.To != provider {
return true, false, nil
return false, nil
}
switch msg.Method {
case miner2.MethodsMiner.PreCommitSector:
var params miner.SectorPreCommitInfo
if err := params.UnmarshalCBOR(bytes.NewReader(msg.Params)); err != nil {
return true, false, xerrors.Errorf("unmarshal pre commit: %w", err)
return false, xerrors.Errorf("unmarshal pre commit: %w", err)
}
for _, did := range params.DealIDs {
if did == dealId {
sectorNumber = params.SectorNumber
sectorFound = true
return true, false, nil
return false, nil
}
}
return true, false, nil
return false, nil
case miner2.MethodsMiner.ProveCommitSector:
var params miner.ProveCommitSectorParams
if err := params.UnmarshalCBOR(bytes.NewReader(msg.Params)); err != nil {
return true, false, xerrors.Errorf("failed to unmarshal prove commit sector params: %w", err)
return false, xerrors.Errorf("failed to unmarshal prove commit sector params: %w", err)
}
if !sectorFound {
return true, false, nil
return false, nil
}
if params.SectorNumber != sectorNumber {
return true, false, nil
return false, nil
}
return false, true, nil
return true, nil
default:
return true, false, nil
return false, nil
}
}

View File

@ -307,44 +307,44 @@ func (n *ProviderNodeAdapter) OnDealSectorCommitted(ctx context.Context, provide
var sectorNumber abi.SectorNumber
var sectorFound bool
matchEvent := func(msg *types.Message) (matchOnce bool, matched bool, err error) {
matchEvent := func(msg *types.Message) (matched bool, err error) {
if msg.To != provider {
return true, false, nil
return false, nil
}
switch msg.Method {
case miner.Methods.PreCommitSector:
var params miner.SectorPreCommitInfo
if err := params.UnmarshalCBOR(bytes.NewReader(msg.Params)); err != nil {
return true, false, xerrors.Errorf("unmarshal pre commit: %w", err)
return false, xerrors.Errorf("unmarshal pre commit: %w", err)
}
for _, did := range params.DealIDs {
if did == dealID {
sectorNumber = params.SectorNumber
sectorFound = true
return true, false, nil
return false, nil
}
}
return true, false, nil
return false, nil
case miner.Methods.ProveCommitSector:
var params miner.ProveCommitSectorParams
if err := params.UnmarshalCBOR(bytes.NewReader(msg.Params)); err != nil {
return true, false, xerrors.Errorf("failed to unmarshal prove commit sector params: %w", err)
return false, xerrors.Errorf("failed to unmarshal prove commit sector params: %w", err)
}
if !sectorFound {
return true, false, nil
return false, nil
}
if params.SectorNumber != sectorNumber {
return true, false, nil
return false, nil
}
return false, true, nil
return true, nil
default:
return true, false, nil
return false, nil
}
}

View File

@ -103,27 +103,27 @@ func (pcs *paymentChannelSettler) revertHandler(ctx context.Context, ts *types.T
return nil
}
func (pcs *paymentChannelSettler) matcher(msg *types.Message) (matchOnce bool, matched bool, err error) {
func (pcs *paymentChannelSettler) matcher(msg *types.Message) (matched bool, err error) {
// Check if this is a settle payment channel message
if msg.Method != paych.Methods.Settle {
return false, false, nil
return false, nil
}
// Check if this payment channel is of concern to this node (i.e. tracked in payment channel store),
// and its inbound (i.e. we're getting vouchers that we may need to redeem)
trackedAddresses, err := pcs.api.PaychList(pcs.ctx)
if err != nil {
return false, false, err
return false, err
}
for _, addr := range trackedAddresses {
if msg.To == addr {
status, err := pcs.api.PaychStatus(pcs.ctx, addr)
if err != nil {
return false, false, err
return false, err
}
if status.Direction == api.PCHInbound {
return false, true, nil
return true, nil
}
}
}
return false, false, nil
return false, nil
}