diff --git a/chain/messagesigner/messagesigner.go b/chain/messagesigner/messagesigner.go index ce9d01b3a..9f7b7bb5f 100644 --- a/chain/messagesigner/messagesigner.go +++ b/chain/messagesigner/messagesigner.go @@ -23,7 +23,7 @@ const dsKeyActorNonce = "ActorNextNonce" var log = logging.Logger("messagesigner") type MpoolNonceAPI interface { - GetNonce(address.Address) (uint64, error) + GetNonce(context.Context, address.Address, types.TipSetKey) (uint64, error) } // MessageSigner keeps track of nonces per address, and increments the nonce @@ -97,7 +97,7 @@ func (ms *MessageSigner) nextNonce(addr address.Address) (uint64, error) { // that have mempool nonces, so first check the mempool for a nonce for // this address. Note that the mempool returns the actor state's nonce // by default. - nonce, err := ms.mpool.GetNonce(addr) + nonce, err := ms.mpool.GetNonce(context.TODO(), addr, types.EmptyTSK) if err != nil { return 0, xerrors.Errorf("failed to get nonce from mempool: %w", err) } diff --git a/chain/messagesigner/messagesigner_test.go b/chain/messagesigner/messagesigner_test.go index 5eebd36da..7bba5b3e9 100644 --- a/chain/messagesigner/messagesigner_test.go +++ b/chain/messagesigner/messagesigner_test.go @@ -35,7 +35,7 @@ func (mp *mockMpool) setNonce(addr address.Address, nonce uint64) { mp.nonces[addr] = nonce } -func (mp *mockMpool) GetNonce(addr address.Address) (uint64, error) { +func (mp *mockMpool) GetNonce(_ context.Context, addr address.Address, _ types.TipSetKey) (uint64, error) { mp.lk.RLock() defer mp.lk.RUnlock() diff --git a/node/modules/mpoolnonceapi.go b/node/modules/mpoolnonceapi.go index efcb14037..61b38e821 100644 --- a/node/modules/mpoolnonceapi.go +++ b/node/modules/mpoolnonceapi.go @@ -2,6 +2,7 @@ package modules import ( "context" + "strings" "go.uber.org/fx" "golang.org/x/xerrors" @@ -19,41 +20,77 @@ import ( type MpoolNonceAPI struct { fx.In - StateAPI full.StateAPI + ChainModule full.ChainModuleAPI + StateModule full.StateModuleAPI } // GetNonce gets the nonce from current chain head. -func (a *MpoolNonceAPI) GetNonce(addr address.Address) (uint64, error) { - ts := a.StateAPI.Chain.GetHeaviestTipSet() +func (a *MpoolNonceAPI) GetNonce(ctx context.Context, addr address.Address, tsk types.TipSetKey) (uint64, error) { + var err error + var ts *types.TipSet + if tsk == types.EmptyTSK { + // we need consistent tsk + ts, err = a.ChainModule.ChainHead(ctx) + if err != nil { + return 0, xerrors.Errorf("getting head: %w", err) + } + tsk = ts.Key() + } else { + ts, err = a.ChainModule.ChainGetTipSet(ctx, tsk) + if err != nil { + return 0, xerrors.Errorf("getting tipset: %w", err) + } + } - // make sure we have a key address so we can compare with messages - keyAddr, err := a.StateAPI.StateManager.ResolveToKeyAddress(context.TODO(), addr, ts) - if err != nil { - return 0, err + keyAddr := addr + + if addr.Protocol() == address.ID { + // make sure we have a key address so we can compare with messages + keyAddr, err = a.StateModule.StateAccountKey(ctx, addr, tsk) + if err != nil { + return 0, xerrors.Errorf("getting account key: %w", err) + } + } else { + addr, err = a.StateModule.StateLookupID(ctx, addr, types.EmptyTSK) + if err != nil { + log.Infof("failed to look up id addr for %s: %w", addr, err) + addr = address.Undef + } } // Load the last nonce from the state, if it exists. highestNonce := uint64(0) - if baseActor, err := a.StateAPI.StateManager.LoadActorRaw(context.TODO(), addr, ts.ParentState()); err != nil { - if !xerrors.Is(err, types.ErrActorNotFound) { - return 0, err + act, err := a.StateModule.StateGetActor(ctx, keyAddr, ts.Key()) + if err != nil { + if strings.Contains(err.Error(), types.ErrActorNotFound.Error()) { + return 0, types.ErrActorNotFound + } + return 0, xerrors.Errorf("getting actor: %w", err) + } + highestNonce = act.Nonce + + apply := func(msg *types.Message) { + if msg.From != addr && msg.From != keyAddr { + return + } + if msg.Nonce == highestNonce { + highestNonce = msg.Nonce + 1 } - } else { - highestNonce = baseActor.Nonce } - // Otherwise, find the highest nonce in the tipset. - msgs, err := a.StateAPI.Chain.MessagesForTipset(ts) - if err != nil { - return 0, err - } - for _, msg := range msgs { - vmmsg := msg.VMMessage() - if vmmsg.From != keyAddr { - continue + for _, b := range ts.Blocks() { + msgs, err := a.ChainModule.ChainGetBlockMessages(ctx, b.Cid()) + if err != nil { + return 0, xerrors.Errorf("getting block messages: %w", err) } - if vmmsg.Nonce >= highestNonce { - highestNonce = vmmsg.Nonce + 1 + if keyAddr.Protocol() == address.BLS { + for _, m := range msgs.BlsMessages { + apply(m) + } + } else { + for _, sm := range msgs.SecpkMessages { + apply(&sm.Message) + } } } return highestNonce, nil