lotus/provider/lpmessage/sender.go
Andrew Jackson (Ajax) 2d7ba43777 fix sender
2023-12-05 17:39:15 +01:00

372 lines
11 KiB
Go

package lpmessage
import (
"bytes"
"context"
"time"
"github.com/google/uuid"
"github.com/ipfs/go-cid"
logging "github.com/ipfs/go-log/v2"
"go.uber.org/multierr"
"golang.org/x/xerrors"
"github.com/filecoin-project/go-address"
"github.com/filecoin-project/go-state-types/big"
"github.com/filecoin-project/lotus/api"
"github.com/filecoin-project/lotus/chain/types"
"github.com/filecoin-project/lotus/lib/harmony/harmonydb"
"github.com/filecoin-project/lotus/lib/harmony/harmonytask"
"github.com/filecoin-project/lotus/lib/harmony/resources"
"github.com/filecoin-project/lotus/lib/promise"
)
var log = logging.Logger("lpmessage")
var SendLockedWait = 100 * time.Millisecond
type SenderAPI interface {
StateAccountKey(ctx context.Context, addr address.Address, tsk types.TipSetKey) (address.Address, error)
GasEstimateMessageGas(ctx context.Context, msg *types.Message, spec *api.MessageSendSpec, tsk types.TipSetKey) (*types.Message, error)
WalletBalance(ctx context.Context, addr address.Address) (big.Int, error)
MpoolGetNonce(context.Context, address.Address) (uint64, error)
MpoolPush(context.Context, *types.SignedMessage) (cid.Cid, error)
}
type SignerAPI interface {
WalletSignMessage(context.Context, address.Address, *types.Message) (*types.SignedMessage, error)
}
// Sender abstracts away highly-available message sending with coordination through
// HarmonyDB. It make sure that nonces are assigned transactionally, and that
// messages are correctly broadcasted to the network. It ensures that messages
// are sent serially, and that failures to send don't cause nonce gaps.
type Sender struct {
api SenderAPI
sendTask *SendTask
db *harmonydb.DB
}
type SendTask struct {
sendTF promise.Promise[harmonytask.AddTaskFunc]
api SenderAPI
signer SignerAPI
db *harmonydb.DB
}
func (s *SendTask) Do(taskID harmonytask.TaskID, stillOwned func() bool) (done bool, err error) {
ctx := context.TODO()
// get message from db
var dbMsg struct {
FromKey string `db:"from_key"`
ToAddr string `db:"to_addr"`
UnsignedData []byte `db:"unsigned_data"`
UnsignedCid string `db:"unsigned_cid"`
// may not be null if we have somehow already signed but failed to send this message
Nonce *uint64 `db:"nonce"`
SignedData []byte `db:"signed_data"`
}
err = s.db.QueryRow(ctx, `select from_key, nonce, to_addr, unsigned_data, unsigned_cid from message_sends where send_task_id = $1`, taskID).Scan(&dbMsg)
if err != nil {
return false, xerrors.Errorf("getting message from db: %w", err)
}
// deserialize the message
var msg types.Message
err = msg.UnmarshalCBOR(bytes.NewReader(dbMsg.UnsignedData))
if err != nil {
return false, xerrors.Errorf("unmarshaling unsigned db message: %w", err)
}
// get db send lock
for {
// check if we still own the task
if !stillOwned() {
return false, xerrors.Errorf("lost ownership of task")
}
// try to acquire lock
cn, err := s.db.Exec(ctx, `INSERT INTO message_send_locks (from_key, task_id, claimed_at) VALUES ($1, $2, CURRENT_TIMESTAMP)
ON CONFLICT (from_key) DO UPDATE SET task_id = EXCLUDED.task_id, claimed_at = CURRENT_TIMESTAMP WHERE message_send_locks.task_id = $2;`, dbMsg.FromKey, taskID)
if err != nil {
return false, xerrors.Errorf("acquiring send lock: %w", err)
}
if cn == 1 {
// we got the lock
break
}
// we didn't get the lock, wait a bit and try again
log.Infow("waiting for send lock", "task_id", taskID, "from", dbMsg.FromKey)
time.Sleep(SendLockedWait)
}
// defer release db send lock
defer func() {
_, err2 := s.db.Exec(ctx, `delete from message_send_locks where from_key = $1 and task_id = $2`, dbMsg.FromKey, taskID)
if err2 != nil {
log.Errorw("releasing send lock", "task_id", taskID, "from", dbMsg.FromKey, "error", err2)
// make sure harmony retries this task so that we eventually release this lock
done = false
err = multierr.Append(err, xerrors.Errorf("releasing send lock: %w", err2))
}
}()
// assign nonce IF NOT ASSIGNED (max(api.MpoolGetNonce, db nonce+1))
var sigMsg *types.SignedMessage
if dbMsg.Nonce == nil {
msgNonce, err := s.api.MpoolGetNonce(ctx, msg.From)
if err != nil {
return false, xerrors.Errorf("getting nonce from mpool: %w", err)
}
// get nonce from db
var dbNonce *uint64
r := s.db.QueryRow(ctx, `select max(nonce) from message_sends where from_key = $1 and send_success = true`, msg.From.String())
if err := r.Scan(&dbNonce); err != nil {
return false, xerrors.Errorf("getting nonce from db: %w", err)
}
if dbNonce != nil && *dbNonce+1 > msgNonce {
msgNonce = *dbNonce + 1
}
msg.Nonce = msgNonce
// sign message
sigMsg, err = s.signer.WalletSignMessage(ctx, msg.From, &msg)
if err != nil {
return false, xerrors.Errorf("signing message: %w", err)
}
data, err := sigMsg.Serialize()
if err != nil {
return false, xerrors.Errorf("serializing message: %w", err)
}
jsonBytes, err := sigMsg.MarshalJSON()
if err != nil {
return false, xerrors.Errorf("marshaling message: %w", err)
}
// write to db
n, err := s.db.Exec(ctx, `update message_sends set nonce = $1, signed_data = $2, signed_json = $3, signed_cid = $4 where send_task_id = $5`,
msg.Nonce, data, string(jsonBytes), sigMsg.Cid().String(), taskID)
if err != nil {
return false, xerrors.Errorf("updating db record: %w", err)
}
if n != 1 {
log.Errorw("updating db record: expected 1 row to be affected, got %d", n)
return false, xerrors.Errorf("updating db record: expected 1 row to be affected, got %d", n)
}
} else {
// Note: this handles an unlikely edge-case:
// We have previously signed the message but either failed to send it or failed to update the db
// note that when that happens the likely cause is the provider process losing its db connection
// or getting killed before it can update the db. In that case the message lock will still be held
// so it will be safe to rebroadcast the signed message
// deserialize the signed message
sigMsg = new(types.SignedMessage)
err = sigMsg.UnmarshalCBOR(bytes.NewReader(dbMsg.SignedData))
if err != nil {
return false, xerrors.Errorf("unmarshaling signed db message: %w", err)
}
}
// send!
_, err = s.api.MpoolPush(ctx, sigMsg)
// persist send result
var sendSuccess = err == nil
var sendError string
if err != nil {
sendError = err.Error()
}
_, err = s.db.Exec(ctx, `update message_sends set send_success = $1, send_error = $2, send_time = CURRENT_TIMESTAMP where send_task_id = $3`, sendSuccess, sendError, taskID)
if err != nil {
return false, xerrors.Errorf("updating db record: %w", err)
}
return true, nil
}
func (s *SendTask) CanAccept(ids []harmonytask.TaskID, engine *harmonytask.TaskEngine) (*harmonytask.TaskID, error) {
if len(ids) == 0 {
// probably can't happen, but panicking is bad
return nil, nil
}
if s.signer == nil {
// can't sign messages here
return nil, nil
}
return &ids[0], nil
}
func (s *SendTask) TypeDetails() harmonytask.TaskTypeDetails {
return harmonytask.TaskTypeDetails{
Max: 1024,
Name: "SendMessage",
Cost: resources.Resources{
Cpu: 0,
Gpu: 0,
Ram: 1 << 20,
},
MaxFailures: 1000,
Follows: nil,
}
}
func (s *SendTask) Adder(taskFunc harmonytask.AddTaskFunc) {
s.sendTF.Set(taskFunc)
}
var _ harmonytask.TaskInterface = &SendTask{}
// NewSender creates a new Sender.
func NewSender(api SenderAPI, signer SignerAPI, db *harmonydb.DB) (*Sender, *SendTask) {
st := &SendTask{
api: api,
signer: signer,
db: db,
}
return &Sender{
api: api,
db: db,
sendTask: st,
}, st
}
// Send atomically assigns a nonce, signs, and pushes a message
// to mempool.
// maxFee is only used when GasFeeCap/GasPremium fields aren't specified
//
// When maxFee is set to 0, Send will guess appropriate fee
// based on current chain conditions
//
// Send behaves much like fullnodeApi.MpoolPushMessage, but it coordinates
// through HarmonyDB, making it safe to broadcast messages from multiple independent
// API nodes
//
// Send is also currently more strict about required parameters than MpoolPushMessage
func (s *Sender) Send(ctx context.Context, msg *types.Message, mss *api.MessageSendSpec, reason string) (cid.Cid, error) {
if mss == nil {
return cid.Undef, xerrors.Errorf("MessageSendSpec cannot be nil")
}
if (mss.MsgUuid != uuid.UUID{}) {
return cid.Undef, xerrors.Errorf("MessageSendSpec.MsgUuid must be zero")
}
fromA, err := s.api.StateAccountKey(ctx, msg.From, types.EmptyTSK)
if err != nil {
return cid.Undef, xerrors.Errorf("getting key address: %w", err)
}
msg.From = fromA
if msg.Nonce != 0 {
return cid.Undef, xerrors.Errorf("Send expects message nonce to be 0, was %d", msg.Nonce)
}
msg, err = s.api.GasEstimateMessageGas(ctx, msg, mss, types.EmptyTSK)
if err != nil {
return cid.Undef, xerrors.Errorf("GasEstimateMessageGas error: %w", err)
}
b, err := s.api.WalletBalance(ctx, msg.From)
if err != nil {
return cid.Undef, xerrors.Errorf("mpool push: getting origin balance: %w", err)
}
requiredFunds := big.Add(msg.Value, msg.RequiredFunds())
if b.LessThan(requiredFunds) {
return cid.Undef, xerrors.Errorf("mpool push: not enough funds: %s < %s", b, requiredFunds)
}
// push the task
taskAdder := s.sendTask.sendTF.Val(ctx)
unsBytes := new(bytes.Buffer)
err = msg.MarshalCBOR(unsBytes)
if err != nil {
return cid.Undef, xerrors.Errorf("marshaling message: %w", err)
}
taskAdder(func(id harmonytask.TaskID, tx *harmonydb.Tx) (shouldCommit bool, seriousError error) {
_, err := tx.Exec(`insert into message_sends (from_key, to_addr, send_reason, unsigned_data, unsigned_cid, send_task_id) values ($1, $2, $3, $4, $5, $6)`,
msg.From.String(), msg.To.String(), reason, unsBytes.Bytes(), msg.Cid().String(), id)
if err != nil {
return false, xerrors.Errorf("inserting message into db: %w", err)
}
return true, nil
})
// wait for exec
var (
pollInterval = 50 * time.Millisecond
pollIntervalMul = 2
maxPollInterval = 5 * time.Second
pollLoops = 0
sigCid cid.Cid
sendErr error
)
for {
var err error
var sigCidStr, sendError string
var sendSuccess *bool
err = s.db.QueryRow(ctx, `select signed_cid, send_success, send_error from message_sends where send_task_id = $1`, taskAdder).Scan(&sigCidStr, &sendSuccess, &sendError)
if err != nil {
return cid.Undef, xerrors.Errorf("getting cid for task: %w", err)
}
if sendSuccess == nil {
time.Sleep(pollInterval)
pollLoops++
pollInterval *= time.Duration(pollIntervalMul)
if pollInterval > maxPollInterval {
pollInterval = maxPollInterval
}
continue
}
if !*sendSuccess {
sendErr = xerrors.Errorf("send error: %s", sendError)
} else {
sigCid, err = cid.Parse(sigCidStr)
if err != nil {
return cid.Undef, xerrors.Errorf("parsing signed cid: %w", err)
}
}
break
}
log.Infow("sent message", "cid", sigCid, "task_id", taskAdder, "send_error", sendErr, "poll_loops", pollLoops)
return sigCid, sendErr
}