diff --git a/chain/types/bigint.go b/chain/types/bigint.go index 6a0a875bb..11946fa2a 100644 --- a/chain/types/bigint.go +++ b/chain/types/bigint.go @@ -81,18 +81,18 @@ func BigCmp(a, b BigInt) int { return a.Int.Cmp(b.Int) } -func (bi *BigInt) Nil() bool { +func (bi BigInt) Nil() bool { return bi.Int == nil } // LessThan returns true if bi < o -func (bi *BigInt) LessThan(o BigInt) bool { - return BigCmp(*bi, o) < 0 +func (bi BigInt) LessThan(o BigInt) bool { + return BigCmp(bi, o) < 0 } // LessThan returns true if bi > o -func (bi *BigInt) GreaterThan(o BigInt) bool { - return BigCmp(*bi, o) > 0 +func (bi BigInt) GreaterThan(o BigInt) bool { + return BigCmp(bi, o) > 0 } func (bi *BigInt) MarshalJSON() ([]byte, error) { diff --git a/miner/miner.go b/miner/miner.go index ba3b33c28..df47f3f14 100644 --- a/miner/miner.go +++ b/miner/miner.go @@ -299,7 +299,7 @@ func (m *Miner) createBlock(base *MiningBase, ticket *types.Ticket, proof types. return nil, errors.Wrapf(err, "failed to get pending messages") } - msgs, err := m.selectMessages(context.TODO(), base, pending) + msgs, err := selectMessages(context.TODO(), m.api.StateGetActor, base, pending) if err != nil { return nil, xerrors.Errorf("message filtering failed: %w", err) } @@ -310,31 +310,41 @@ func (m *Miner) createBlock(base *MiningBase, ticket *types.Ticket, proof types. return m.api.MinerCreateBlock(context.TODO(), m.addresses[0], base.ts, append(base.tickets, ticket), proof, msgs, uint64(uts)) } -func (m *Miner) selectMessages(ctx context.Context, base *MiningBase, msgs []*types.SignedMessage) ([]*types.SignedMessage, error) { +type actorLookup func(context.Context, address.Address, *types.TipSet) (*types.Actor, error) + +func selectMessages(ctx context.Context, al actorLookup, base *MiningBase, msgs []*types.SignedMessage) ([]*types.SignedMessage, error) { out := make([]*types.SignedMessage, 0, len(msgs)) inclNonces := make(map[address.Address]uint64) + inclBalances := make(map[address.Address]types.BigInt) for _, msg := range msgs { from := msg.Message.From - act, err := m.api.StateGetActor(ctx, from, base.ts) + act, err := al(ctx, from, base.ts) if err != nil { return nil, xerrors.Errorf("failed to check message sender balance: %w", err) } if _, ok := inclNonces[from]; !ok { inclNonces[from] = act.Nonce + inclBalances[from] = act.Balance } - if act.Balance.LessThan(msg.Message.RequiredFunds()) { - log.Warningf("message in mempool does not have enough funds: %s", msg.Cid()) + if inclBalances[from].LessThan(msg.Message.RequiredFunds()) { + log.Warnf("message in mempool does not have enough funds: %s", msg.Cid()) continue } if msg.Message.Nonce > inclNonces[from] { - log.Warningf("message in mempool has too high of a nonce: %s", msg.Cid()) + log.Warnf("message in mempool has too high of a nonce (%d > %d) %s", msg.Message.Nonce, inclNonces[from], msg.Cid()) continue } - inclNonces[from] = msg.Message.Nonce + if msg.Message.Nonce < inclNonces[from] { + log.Warnf("message in mempool has already used nonce (%d < %d) %s", msg.Message.Nonce, inclNonces[from], msg.Cid()) + continue + } + + inclNonces[from] = msg.Message.Nonce + 1 + inclBalances[from] = types.BigSub(inclBalances[from], msg.Message.RequiredFunds()) out = append(out, msg) } diff --git a/miner/miner_test.go b/miner/miner_test.go new file mode 100644 index 000000000..dc79bf129 --- /dev/null +++ b/miner/miner_test.go @@ -0,0 +1,99 @@ +package miner + +import ( + "context" + "testing" + + "github.com/filecoin-project/go-lotus/chain/address" + "github.com/filecoin-project/go-lotus/chain/types" +) + +func mustIDAddr(i uint64) address.Address { + a, err := address.NewIDAddress(i) + if err != nil { + panic(err) + } + + return a +} + +func TestMessageFiltering(t *testing.T) { + ctx := context.TODO() + a1 := mustIDAddr(1) + a2 := mustIDAddr(2) + + actors := map[address.Address]*types.Actor{ + a1: &types.Actor{ + Nonce: 3, + Balance: types.NewInt(1200), + }, + a2: &types.Actor{ + Nonce: 1, + Balance: types.NewInt(1000), + }, + } + + af := func(ctx context.Context, addr address.Address, ts *types.TipSet) (*types.Actor, error) { + return actors[addr], nil + } + + msgs := []types.Message{ + types.Message{ + From: a1, + Nonce: 3, + Value: types.NewInt(500), + GasLimit: types.NewInt(50), + GasPrice: types.NewInt(1), + }, + types.Message{ + From: a1, + Nonce: 4, + Value: types.NewInt(500), + GasLimit: types.NewInt(50), + GasPrice: types.NewInt(1), + }, + types.Message{ + From: a2, + Nonce: 1, + Value: types.NewInt(800), + GasLimit: types.NewInt(100), + GasPrice: types.NewInt(1), + }, + types.Message{ + From: a2, + Nonce: 0, + Value: types.NewInt(800), + GasLimit: types.NewInt(100), + GasPrice: types.NewInt(1), + }, + types.Message{ + From: a2, + Nonce: 2, + Value: types.NewInt(150), + GasLimit: types.NewInt(100), + GasPrice: types.NewInt(1), + }, + } + + outmsgs, err := selectMessages(ctx, af, &MiningBase{}, wrapMsgs(msgs)) + if err != nil { + t.Fatal(err) + } + + if len(outmsgs) != 3 { + t.Fatal("filtering didnt work as expected") + } + + m1 := outmsgs[2].Message + if m1.From != msgs[2].From || m1.Nonce != msgs[2].Nonce { + t.Fatal("filtering bad") + } +} + +func wrapMsgs(msgs []types.Message) []*types.SignedMessage { + var out []*types.SignedMessage + for _, m := range msgs { + out = append(out, &types.SignedMessage{Message: m}) + } + return out +}