WIP: fix payment channel locking
This commit is contained in:
parent
bf9116c65a
commit
fdfccf0466
@ -377,7 +377,8 @@ type FullNode interface {
|
||||
// MethodGroup: Paych
|
||||
// The Paych methods are for interacting with and managing payment channels
|
||||
|
||||
PaychGet(ctx context.Context, from, to address.Address, ensureFunds types.BigInt) (*ChannelInfo, error)
|
||||
PaychGet(ctx context.Context, from, to address.Address, amt types.BigInt) (*ChannelInfo, error)
|
||||
PaychGetWaitReady(context.Context, cid.Cid) (address.Address, error)
|
||||
PaychList(context.Context) ([]address.Address, error)
|
||||
PaychStatus(context.Context, address.Address) (*PaychStatus, error)
|
||||
PaychSettle(context.Context, address.Address) (cid.Cid, error)
|
||||
|
@ -183,7 +183,8 @@ type FullNodeStruct struct {
|
||||
|
||||
MarketEnsureAvailable func(context.Context, address.Address, address.Address, types.BigInt) (cid.Cid, error) `perm:"sign"`
|
||||
|
||||
PaychGet func(ctx context.Context, from, to address.Address, ensureFunds types.BigInt) (*api.ChannelInfo, error) `perm:"sign"`
|
||||
PaychGet func(ctx context.Context, from, to address.Address, amt types.BigInt) (*api.ChannelInfo, error) `perm:"sign"`
|
||||
PaychGetWaitReady func(context.Context, cid.Cid) (address.Address, error) `perm:"sign"` // TODO: is perm:"sign" correct?
|
||||
PaychList func(context.Context) ([]address.Address, error) `perm:"read"`
|
||||
PaychStatus func(context.Context, address.Address) (*api.PaychStatus, error) `perm:"read"`
|
||||
PaychSettle func(context.Context, address.Address) (cid.Cid, error) `perm:"sign"`
|
||||
@ -803,8 +804,12 @@ func (c *FullNodeStruct) MarketEnsureAvailable(ctx context.Context, addr, wallet
|
||||
return c.Internal.MarketEnsureAvailable(ctx, addr, wallet, amt)
|
||||
}
|
||||
|
||||
func (c *FullNodeStruct) PaychGet(ctx context.Context, from, to address.Address, ensureFunds types.BigInt) (*api.ChannelInfo, error) {
|
||||
return c.Internal.PaychGet(ctx, from, to, ensureFunds)
|
||||
func (c *FullNodeStruct) PaychGet(ctx context.Context, from, to address.Address, amt types.BigInt) (*api.ChannelInfo, error) {
|
||||
return c.Internal.PaychGet(ctx, from, to, amt)
|
||||
}
|
||||
|
||||
func (c *FullNodeStruct) PaychGetWaitReady(ctx context.Context, mcid cid.Cid) (address.Address, error) {
|
||||
return c.Internal.PaychGetWaitReady(ctx, mcid)
|
||||
}
|
||||
|
||||
func (c *FullNodeStruct) PaychList(ctx context.Context) ([]address.Address, error) {
|
||||
|
@ -1,18 +1,17 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/filecoin-project/specs-actors/actors/builtin"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/filecoin-project/specs-actors/actors/builtin"
|
||||
|
||||
"github.com/filecoin-project/specs-actors/actors/abi"
|
||||
"github.com/filecoin-project/specs-actors/actors/abi/big"
|
||||
initactor "github.com/filecoin-project/specs-actors/actors/builtin/init"
|
||||
"github.com/filecoin-project/specs-actors/actors/builtin/paych"
|
||||
"github.com/ipfs/go-cid"
|
||||
|
||||
@ -77,13 +76,10 @@ func TestPaymentChannels(t *testing.T, b APIBuilder, blocktime time.Duration) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
res := waitForMessage(ctx, t, paymentCreator, channelInfo.ChannelMessage, time.Second, "channel create")
|
||||
var params initactor.ExecReturn
|
||||
err = params.UnmarshalCBOR(bytes.NewReader(res.Receipt.Return))
|
||||
channel, err := paymentCreator.PaychGetWaitReady(ctx, channelInfo.ChannelMessage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
channel := params.RobustAddress
|
||||
|
||||
// allocate three lanes
|
||||
var lanes []uint64
|
||||
@ -129,7 +125,7 @@ func TestPaymentChannels(t *testing.T, b APIBuilder, blocktime time.Duration) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
res = waitForMessage(ctx, t, paymentCreator, settleMsgCid, time.Second*10, "settle")
|
||||
res := waitForMessage(ctx, t, paymentCreator, settleMsgCid, time.Second*10, "settle")
|
||||
if res.Receipt.ExitCode != 0 {
|
||||
t.Fatal("Unable to settle payment channel")
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ var paychCmd = &cli.Command{
|
||||
|
||||
var paychGetCmd = &cli.Command{
|
||||
Name: "get",
|
||||
Usage: "Create a new payment channel or get existing one",
|
||||
Usage: "Create a new payment channel or get existing one and add amount to it",
|
||||
ArgsUsage: "[fromAddress toAddress amount]",
|
||||
Action: func(cctx *cli.Context) error {
|
||||
if cctx.Args().Len() != 3 {
|
||||
|
@ -35,6 +35,7 @@ func main() {
|
||||
err = gen.WriteMapEncodersToFile("./paychmgr/cbor_gen.go", "paychmgr",
|
||||
paychmgr.VoucherInfo{},
|
||||
paychmgr.ChannelInfo{},
|
||||
paychmgr.MsgInfo{},
|
||||
)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
|
@ -5,8 +5,6 @@ import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/filecoin-project/lotus/markets/dealfilter"
|
||||
|
||||
logging "github.com/ipfs/go-log"
|
||||
ci "github.com/libp2p/go-libp2p-core/crypto"
|
||||
"github.com/libp2p/go-libp2p-core/host"
|
||||
@ -111,6 +109,7 @@ const (
|
||||
HandleIncomingMessagesKey
|
||||
|
||||
RegisterClientValidatorKey
|
||||
HandlePaymentChannelManagerKey
|
||||
|
||||
// miner
|
||||
GetParamsKey
|
||||
@ -279,6 +278,7 @@ func Online() Option {
|
||||
Override(new(*paychmgr.Store), paychmgr.NewStore),
|
||||
Override(new(*paychmgr.Manager), paychmgr.NewManager),
|
||||
Override(new(*market.FundMgr), market.StartFundManager),
|
||||
Override(HandlePaymentChannelManagerKey, paychmgr.HandleManager),
|
||||
Override(SettlePaymentChannelsKey, settler.SettlePaymentChannels),
|
||||
),
|
||||
|
||||
|
@ -28,8 +28,8 @@ type PaychAPI struct {
|
||||
PaychMgr *paychmgr.Manager
|
||||
}
|
||||
|
||||
func (a *PaychAPI) PaychGet(ctx context.Context, from, to address.Address, ensureFunds types.BigInt) (*api.ChannelInfo, error) {
|
||||
ch, mcid, err := a.PaychMgr.GetPaych(ctx, from, to, ensureFunds)
|
||||
func (a *PaychAPI) PaychGet(ctx context.Context, from, to address.Address, amt types.BigInt) (*api.ChannelInfo, error) {
|
||||
ch, mcid, err := a.PaychMgr.GetPaych(ctx, from, to, amt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -40,6 +40,10 @@ func (a *PaychAPI) PaychGet(ctx context.Context, from, to address.Address, ensur
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *PaychAPI) PaychGetWaitReady(ctx context.Context, mcid cid.Cid) (address.Address, error) {
|
||||
return a.PaychMgr.GetPaychWaitReady(ctx, mcid)
|
||||
}
|
||||
|
||||
func (a *PaychAPI) PaychAllocateLane(ctx context.Context, ch address.Address) (uint64, error) {
|
||||
return a.PaychMgr.AllocateLane(ch)
|
||||
}
|
||||
@ -66,7 +70,7 @@ func (a *PaychAPI) PaychNewPayment(ctx context.Context, from, to address.Address
|
||||
ChannelAddr: ch.Channel,
|
||||
|
||||
Amount: v.Amount,
|
||||
Lane: uint64(lane),
|
||||
Lane: lane,
|
||||
|
||||
Extra: v.Extra,
|
||||
TimeLockMin: v.TimeLockMin,
|
||||
@ -108,24 +112,7 @@ func (a *PaychAPI) PaychStatus(ctx context.Context, pch address.Address) (*api.P
|
||||
}
|
||||
|
||||
func (a *PaychAPI) PaychSettle(ctx context.Context, addr address.Address) (cid.Cid, error) {
|
||||
|
||||
ci, err := a.PaychMgr.GetChannelInfo(addr)
|
||||
if err != nil {
|
||||
return cid.Undef, err
|
||||
}
|
||||
|
||||
msg := &types.Message{
|
||||
To: addr,
|
||||
From: ci.Control,
|
||||
Value: types.NewInt(0),
|
||||
Method: builtin.MethodsPaych.Settle,
|
||||
}
|
||||
smgs, err := a.MpoolPushMessage(ctx, msg)
|
||||
|
||||
if err != nil {
|
||||
return cid.Undef, err
|
||||
}
|
||||
return smgs.Cid(), nil
|
||||
return a.PaychMgr.Settle(ctx, addr)
|
||||
}
|
||||
|
||||
func (a *PaychAPI) PaychCollect(ctx context.Context, addr address.Address) (cid.Cid, error) {
|
||||
|
67
paychmgr/accessorcache.go
Normal file
67
paychmgr/accessorcache.go
Normal file
@ -0,0 +1,67 @@
|
||||
package paychmgr
|
||||
|
||||
import "github.com/filecoin-project/go-address"
|
||||
|
||||
// accessorByFromTo gets a channel accessor for a given from / to pair.
|
||||
// The channel accessor facilitates locking a channel so that operations
|
||||
// must be performed sequentially on a channel (but can be performed at
|
||||
// the same time on different channels).
|
||||
func (pm *Manager) accessorByFromTo(from address.Address, to address.Address) (*channelAccessor, error) {
|
||||
key := pm.accessorCacheKey(from, to)
|
||||
|
||||
// First take a read lock and check the cache
|
||||
pm.lk.RLock()
|
||||
ca, ok := pm.channels[key]
|
||||
pm.lk.RUnlock()
|
||||
if ok {
|
||||
return ca, nil
|
||||
}
|
||||
|
||||
// Not in cache, so take a write lock
|
||||
pm.lk.Lock()
|
||||
defer pm.lk.Unlock()
|
||||
|
||||
// Need to check cache again in case it was updated between releasing read
|
||||
// lock and taking write lock
|
||||
ca, ok = pm.channels[key]
|
||||
if !ok {
|
||||
// Not in cache, so create a new one and store in cache
|
||||
ca = pm.addAccessorToCache(from, to)
|
||||
}
|
||||
|
||||
return ca, nil
|
||||
}
|
||||
|
||||
// accessorByAddress gets a channel accessor for a given channel address.
|
||||
// The channel accessor facilitates locking a channel so that operations
|
||||
// must be performed sequentially on a channel (but can be performed at
|
||||
// the same time on different channels).
|
||||
func (pm *Manager) accessorByAddress(ch address.Address) (*channelAccessor, error) {
|
||||
// Get the channel from / to
|
||||
pm.lk.RLock()
|
||||
channelInfo, err := pm.store.ByAddress(ch)
|
||||
pm.lk.RUnlock()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: cache by channel address so we can get by address instead of using from / to
|
||||
return pm.accessorByFromTo(channelInfo.Control, channelInfo.Target)
|
||||
}
|
||||
|
||||
// accessorCacheKey returns the cache key use to reference a channel accessor
|
||||
func (pm *Manager) accessorCacheKey(from address.Address, to address.Address) string {
|
||||
return from.String() + "->" + to.String()
|
||||
}
|
||||
|
||||
// addAccessorToCache adds a channel accessor to a cache. Note that channelInfo
|
||||
// may be nil if the channel hasn't been created yet, but we still want to
|
||||
// reference the same channel accessor for a given from/to, so that all
|
||||
// attempts to access a channel use the same lock (the lock on the accessor)
|
||||
func (pm *Manager) addAccessorToCache(from address.Address, to address.Address) *channelAccessor {
|
||||
key := pm.accessorCacheKey(from, to)
|
||||
ca := newChannelAccessor(pm)
|
||||
// TODO: Use LRU
|
||||
pm.channels[key] = ca
|
||||
return ca
|
||||
}
|
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/filecoin-project/go-address"
|
||||
"github.com/filecoin-project/specs-actors/actors/builtin/paych"
|
||||
cbg "github.com/whyrusleeping/cbor-gen"
|
||||
xerrors "golang.org/x/xerrors"
|
||||
@ -156,12 +157,35 @@ func (t *ChannelInfo) MarshalCBOR(w io.Writer) error {
|
||||
_, err := w.Write(cbg.CborNull)
|
||||
return err
|
||||
}
|
||||
if _, err := w.Write([]byte{166}); err != nil {
|
||||
if _, err := w.Write([]byte{172}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
scratch := make([]byte, 9)
|
||||
|
||||
// t.ChannelID (string) (string)
|
||||
if len("ChannelID") > cbg.MaxLength {
|
||||
return xerrors.Errorf("Value in field \"ChannelID\" was too long")
|
||||
}
|
||||
|
||||
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("ChannelID"))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.WriteString(w, string("ChannelID")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(t.ChannelID) > cbg.MaxLength {
|
||||
return xerrors.Errorf("Value in field t.ChannelID was too long")
|
||||
}
|
||||
|
||||
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.ChannelID))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.WriteString(w, string(t.ChannelID)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// t.Channel (address.Address) (struct)
|
||||
if len("Channel") > cbg.MaxLength {
|
||||
return xerrors.Errorf("Value in field \"Channel\" was too long")
|
||||
@ -267,6 +291,97 @@ func (t *ChannelInfo) MarshalCBOR(w io.Writer) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// t.Amount (big.Int) (struct)
|
||||
if len("Amount") > cbg.MaxLength {
|
||||
return xerrors.Errorf("Value in field \"Amount\" was too long")
|
||||
}
|
||||
|
||||
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Amount"))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.WriteString(w, string("Amount")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := t.Amount.MarshalCBOR(w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// t.PendingAmount (big.Int) (struct)
|
||||
if len("PendingAmount") > cbg.MaxLength {
|
||||
return xerrors.Errorf("Value in field \"PendingAmount\" was too long")
|
||||
}
|
||||
|
||||
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("PendingAmount"))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.WriteString(w, string("PendingAmount")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := t.PendingAmount.MarshalCBOR(w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// t.CreateMsg (cid.Cid) (struct)
|
||||
if len("CreateMsg") > cbg.MaxLength {
|
||||
return xerrors.Errorf("Value in field \"CreateMsg\" was too long")
|
||||
}
|
||||
|
||||
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("CreateMsg"))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.WriteString(w, string("CreateMsg")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if t.CreateMsg == nil {
|
||||
if _, err := w.Write(cbg.CborNull); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := cbg.WriteCidBuf(scratch, w, *t.CreateMsg); err != nil {
|
||||
return xerrors.Errorf("failed to write cid field t.CreateMsg: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// t.AddFundsMsg (cid.Cid) (struct)
|
||||
if len("AddFundsMsg") > cbg.MaxLength {
|
||||
return xerrors.Errorf("Value in field \"AddFundsMsg\" was too long")
|
||||
}
|
||||
|
||||
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("AddFundsMsg"))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.WriteString(w, string("AddFundsMsg")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if t.AddFundsMsg == nil {
|
||||
if _, err := w.Write(cbg.CborNull); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := cbg.WriteCidBuf(scratch, w, *t.AddFundsMsg); err != nil {
|
||||
return xerrors.Errorf("failed to write cid field t.AddFundsMsg: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// t.Settling (bool) (bool)
|
||||
if len("Settling") > cbg.MaxLength {
|
||||
return xerrors.Errorf("Value in field \"Settling\" was too long")
|
||||
}
|
||||
|
||||
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Settling"))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.WriteString(w, string("Settling")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := cbg.WriteBool(w, t.Settling); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -303,13 +418,36 @@ func (t *ChannelInfo) UnmarshalCBOR(r io.Reader) error {
|
||||
}
|
||||
|
||||
switch name {
|
||||
// t.Channel (address.Address) (struct)
|
||||
// t.ChannelID (string) (string)
|
||||
case "ChannelID":
|
||||
|
||||
{
|
||||
sval, err := cbg.ReadStringBuf(br, scratch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.ChannelID = string(sval)
|
||||
}
|
||||
// t.Channel (address.Address) (struct)
|
||||
case "Channel":
|
||||
|
||||
{
|
||||
|
||||
if err := t.Channel.UnmarshalCBOR(br); err != nil {
|
||||
return xerrors.Errorf("unmarshaling t.Channel: %w", err)
|
||||
pb, err := br.PeekByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if pb == cbg.CborNull[0] {
|
||||
var nbuf [1]byte
|
||||
if _, err := br.Read(nbuf[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
t.Channel = new(address.Address)
|
||||
if err := t.Channel.UnmarshalCBOR(br); err != nil {
|
||||
return xerrors.Errorf("unmarshaling t.Channel pointer: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@ -393,6 +531,279 @@ func (t *ChannelInfo) UnmarshalCBOR(r io.Reader) error {
|
||||
t.NextLane = uint64(extra)
|
||||
|
||||
}
|
||||
// t.Amount (big.Int) (struct)
|
||||
case "Amount":
|
||||
|
||||
{
|
||||
|
||||
if err := t.Amount.UnmarshalCBOR(br); err != nil {
|
||||
return xerrors.Errorf("unmarshaling t.Amount: %w", err)
|
||||
}
|
||||
|
||||
}
|
||||
// t.PendingAmount (big.Int) (struct)
|
||||
case "PendingAmount":
|
||||
|
||||
{
|
||||
|
||||
if err := t.PendingAmount.UnmarshalCBOR(br); err != nil {
|
||||
return xerrors.Errorf("unmarshaling t.PendingAmount: %w", err)
|
||||
}
|
||||
|
||||
}
|
||||
// t.CreateMsg (cid.Cid) (struct)
|
||||
case "CreateMsg":
|
||||
|
||||
{
|
||||
|
||||
pb, err := br.PeekByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if pb == cbg.CborNull[0] {
|
||||
var nbuf [1]byte
|
||||
if _, err := br.Read(nbuf[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
|
||||
c, err := cbg.ReadCid(br)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to read cid field t.CreateMsg: %w", err)
|
||||
}
|
||||
|
||||
t.CreateMsg = &c
|
||||
}
|
||||
|
||||
}
|
||||
// t.AddFundsMsg (cid.Cid) (struct)
|
||||
case "AddFundsMsg":
|
||||
|
||||
{
|
||||
|
||||
pb, err := br.PeekByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if pb == cbg.CborNull[0] {
|
||||
var nbuf [1]byte
|
||||
if _, err := br.Read(nbuf[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
|
||||
c, err := cbg.ReadCid(br)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to read cid field t.AddFundsMsg: %w", err)
|
||||
}
|
||||
|
||||
t.AddFundsMsg = &c
|
||||
}
|
||||
|
||||
}
|
||||
// t.Settling (bool) (bool)
|
||||
case "Settling":
|
||||
|
||||
maj, extra, err = cbg.CborReadHeaderBuf(br, scratch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if maj != cbg.MajOther {
|
||||
return fmt.Errorf("booleans must be major type 7")
|
||||
}
|
||||
switch extra {
|
||||
case 20:
|
||||
t.Settling = false
|
||||
case 21:
|
||||
t.Settling = true
|
||||
default:
|
||||
return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra)
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown struct field %d: '%s'", i, name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
func (t *MsgInfo) MarshalCBOR(w io.Writer) error {
|
||||
if t == nil {
|
||||
_, err := w.Write(cbg.CborNull)
|
||||
return err
|
||||
}
|
||||
if _, err := w.Write([]byte{164}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
scratch := make([]byte, 9)
|
||||
|
||||
// t.ChannelID (string) (string)
|
||||
if len("ChannelID") > cbg.MaxLength {
|
||||
return xerrors.Errorf("Value in field \"ChannelID\" was too long")
|
||||
}
|
||||
|
||||
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("ChannelID"))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.WriteString(w, string("ChannelID")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(t.ChannelID) > cbg.MaxLength {
|
||||
return xerrors.Errorf("Value in field t.ChannelID was too long")
|
||||
}
|
||||
|
||||
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.ChannelID))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.WriteString(w, string(t.ChannelID)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// t.MsgCid (cid.Cid) (struct)
|
||||
if len("MsgCid") > cbg.MaxLength {
|
||||
return xerrors.Errorf("Value in field \"MsgCid\" was too long")
|
||||
}
|
||||
|
||||
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("MsgCid"))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.WriteString(w, string("MsgCid")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := cbg.WriteCidBuf(scratch, w, t.MsgCid); err != nil {
|
||||
return xerrors.Errorf("failed to write cid field t.MsgCid: %w", err)
|
||||
}
|
||||
|
||||
// t.Received (bool) (bool)
|
||||
if len("Received") > cbg.MaxLength {
|
||||
return xerrors.Errorf("Value in field \"Received\" was too long")
|
||||
}
|
||||
|
||||
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Received"))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.WriteString(w, string("Received")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := cbg.WriteBool(w, t.Received); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// t.Err (string) (string)
|
||||
if len("Err") > cbg.MaxLength {
|
||||
return xerrors.Errorf("Value in field \"Err\" was too long")
|
||||
}
|
||||
|
||||
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Err"))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.WriteString(w, string("Err")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(t.Err) > cbg.MaxLength {
|
||||
return xerrors.Errorf("Value in field t.Err was too long")
|
||||
}
|
||||
|
||||
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Err))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.WriteString(w, string(t.Err)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *MsgInfo) UnmarshalCBOR(r io.Reader) error {
|
||||
*t = MsgInfo{}
|
||||
|
||||
br := cbg.GetPeeker(r)
|
||||
scratch := make([]byte, 8)
|
||||
|
||||
maj, extra, err := cbg.CborReadHeaderBuf(br, scratch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if maj != cbg.MajMap {
|
||||
return fmt.Errorf("cbor input should be of type map")
|
||||
}
|
||||
|
||||
if extra > cbg.MaxLength {
|
||||
return fmt.Errorf("MsgInfo: map struct too large (%d)", extra)
|
||||
}
|
||||
|
||||
var name string
|
||||
n := extra
|
||||
|
||||
for i := uint64(0); i < n; i++ {
|
||||
|
||||
{
|
||||
sval, err := cbg.ReadStringBuf(br, scratch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
name = string(sval)
|
||||
}
|
||||
|
||||
switch name {
|
||||
// t.ChannelID (string) (string)
|
||||
case "ChannelID":
|
||||
|
||||
{
|
||||
sval, err := cbg.ReadStringBuf(br, scratch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.ChannelID = string(sval)
|
||||
}
|
||||
// t.MsgCid (cid.Cid) (struct)
|
||||
case "MsgCid":
|
||||
|
||||
{
|
||||
|
||||
c, err := cbg.ReadCid(br)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to read cid field t.MsgCid: %w", err)
|
||||
}
|
||||
|
||||
t.MsgCid = c
|
||||
|
||||
}
|
||||
// t.Received (bool) (bool)
|
||||
case "Received":
|
||||
|
||||
maj, extra, err = cbg.CborReadHeaderBuf(br, scratch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if maj != cbg.MajOther {
|
||||
return fmt.Errorf("booleans must be major type 7")
|
||||
}
|
||||
switch extra {
|
||||
case 20:
|
||||
t.Received = false
|
||||
case 21:
|
||||
t.Received = true
|
||||
default:
|
||||
return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra)
|
||||
}
|
||||
// t.Err (string) (string)
|
||||
case "Err":
|
||||
|
||||
{
|
||||
sval, err := cbg.ReadStringBuf(br, scratch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.Err = string(sval)
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown struct field %d: '%s'", i, name)
|
||||
|
33
paychmgr/channellock.go
Normal file
33
paychmgr/channellock.go
Normal file
@ -0,0 +1,33 @@
|
||||
package paychmgr
|
||||
|
||||
import "sync"
|
||||
|
||||
type rwlock interface {
|
||||
RLock()
|
||||
RUnlock()
|
||||
}
|
||||
|
||||
// channelLock manages locking for a specific channel.
|
||||
// Some operations update the state of a single channel, and need to block
|
||||
// other operations only on the same channel's state.
|
||||
// Some operations update state that affects all channels, and need to block
|
||||
// any operation against any channel.
|
||||
type channelLock struct {
|
||||
globalLock rwlock
|
||||
chanLock sync.Mutex
|
||||
}
|
||||
|
||||
func (l *channelLock) Lock() {
|
||||
// Wait for other operations by this channel to finish.
|
||||
// Exclusive per-channel (no other ops by this channel allowed).
|
||||
l.chanLock.Lock()
|
||||
// Wait for operations affecting all channels to finish.
|
||||
// Allows ops by other channels in parallel, but blocks all operations
|
||||
// if global lock is taken exclusively (eg when adding a channel)
|
||||
l.globalLock.RLock()
|
||||
}
|
||||
|
||||
func (l *channelLock) Unlock() {
|
||||
l.globalLock.RUnlock()
|
||||
l.chanLock.Unlock()
|
||||
}
|
278
paychmgr/manager.go
Normal file
278
paychmgr/manager.go
Normal file
@ -0,0 +1,278 @@
|
||||
package paychmgr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/ipfs/go-datastore"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
xerrors "golang.org/x/xerrors"
|
||||
|
||||
"github.com/filecoin-project/lotus/api"
|
||||
|
||||
"github.com/filecoin-project/specs-actors/actors/builtin/paych"
|
||||
|
||||
"github.com/ipfs/go-cid"
|
||||
logging "github.com/ipfs/go-log/v2"
|
||||
"go.uber.org/fx"
|
||||
|
||||
"github.com/filecoin-project/go-address"
|
||||
|
||||
"github.com/filecoin-project/lotus/chain/stmgr"
|
||||
"github.com/filecoin-project/lotus/chain/types"
|
||||
"github.com/filecoin-project/lotus/node/impl/full"
|
||||
)
|
||||
|
||||
var log = logging.Logger("paych")
|
||||
|
||||
type ManagerApi struct {
|
||||
fx.In
|
||||
|
||||
full.MpoolAPI
|
||||
full.WalletAPI
|
||||
full.StateAPI
|
||||
}
|
||||
|
||||
type StateManagerApi interface {
|
||||
LoadActorState(ctx context.Context, a address.Address, out interface{}, ts *types.TipSet) (*types.Actor, error)
|
||||
Call(ctx context.Context, msg *types.Message, ts *types.TipSet) (*api.InvocResult, error)
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
// The Manager context is used to terminate wait operations on shutdown
|
||||
ctx context.Context
|
||||
shutdown context.CancelFunc
|
||||
|
||||
store *Store
|
||||
sm StateManagerApi
|
||||
sa *stateAccessor
|
||||
pchapi paychApi
|
||||
|
||||
lk sync.RWMutex
|
||||
channels map[string]*channelAccessor
|
||||
|
||||
mpool full.MpoolAPI
|
||||
wallet full.WalletAPI
|
||||
state full.StateAPI
|
||||
}
|
||||
|
||||
type paychAPIImpl struct {
|
||||
full.MpoolAPI
|
||||
full.StateAPI
|
||||
}
|
||||
|
||||
func NewManager(sm *stmgr.StateManager, pchstore *Store, api ManagerApi) *Manager {
|
||||
return &Manager{
|
||||
store: pchstore,
|
||||
sm: sm,
|
||||
sa: &stateAccessor{sm: sm},
|
||||
channels: make(map[string]*channelAccessor),
|
||||
// TODO: Is this the correct way to do this or can I do something different
|
||||
// with dependency injection?
|
||||
pchapi: &paychAPIImpl{api.MpoolAPI, api.StateAPI},
|
||||
|
||||
mpool: api.MpoolAPI,
|
||||
wallet: api.WalletAPI,
|
||||
state: api.StateAPI,
|
||||
}
|
||||
}
|
||||
|
||||
// newManager is used by the tests to supply mocks
|
||||
func newManager(sm StateManagerApi, pchstore *Store, pchapi paychApi) (*Manager, error) {
|
||||
pm := &Manager{
|
||||
store: pchstore,
|
||||
sm: sm,
|
||||
sa: &stateAccessor{sm: sm},
|
||||
channels: make(map[string]*channelAccessor),
|
||||
pchapi: pchapi,
|
||||
}
|
||||
return pm, pm.Start(context.Background())
|
||||
}
|
||||
|
||||
// HandleManager is called by dependency injection to set up hooks
|
||||
func HandleManager(lc fx.Lifecycle, pm *Manager) {
|
||||
lc.Append(fx.Hook{
|
||||
OnStart: func(ctx context.Context) error {
|
||||
return pm.Start(ctx)
|
||||
},
|
||||
OnStop: func(context.Context) error {
|
||||
return pm.Stop()
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Start checks the datastore to see if there are any channels that have
|
||||
// outstanding add funds messages, and if so, waits on the messages.
|
||||
// Outstanding messages can occur if an add funds message was sent
|
||||
// and then lotus was shut down or crashed before the result was
|
||||
// received.
|
||||
func (pm *Manager) Start(ctx context.Context) error {
|
||||
pm.ctx, pm.shutdown = context.WithCancel(ctx)
|
||||
|
||||
cis, err := pm.store.WithPendingAddFunds()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
group := errgroup.Group{}
|
||||
for _, chanInfo := range cis {
|
||||
ci := chanInfo
|
||||
if ci.CreateMsg != nil {
|
||||
group.Go(func() error {
|
||||
ca, err := pm.accessorByFromTo(ci.Control, ci.Target)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("error initializing payment channel manager %s -> %s: %s", ci.Control, ci.Target, err)
|
||||
}
|
||||
go ca.waitForPaychCreateMsg(ci.Control, ci.Target, *ci.CreateMsg, nil)
|
||||
return nil
|
||||
})
|
||||
} else if ci.AddFundsMsg != nil {
|
||||
group.Go(func() error {
|
||||
ca, err := pm.accessorByAddress(*ci.Channel)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("error initializing payment channel manager %s: %s", ci.Channel, err)
|
||||
}
|
||||
go ca.waitForAddFundsMsg(ci.Control, ci.Target, *ci.AddFundsMsg, nil)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return group.Wait()
|
||||
}
|
||||
|
||||
// Stop shuts down any processes used by the manager
|
||||
func (pm *Manager) Stop() error {
|
||||
pm.shutdown()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *Manager) TrackOutboundChannel(ctx context.Context, ch address.Address) error {
|
||||
return pm.trackChannel(ctx, ch, DirOutbound)
|
||||
}
|
||||
|
||||
func (pm *Manager) TrackInboundChannel(ctx context.Context, ch address.Address) error {
|
||||
return pm.trackChannel(ctx, ch, DirInbound)
|
||||
}
|
||||
|
||||
func (pm *Manager) trackChannel(ctx context.Context, ch address.Address, dir uint64) error {
|
||||
pm.lk.Lock()
|
||||
defer pm.lk.Unlock()
|
||||
|
||||
ci, err := pm.sa.loadStateChannelInfo(ctx, ch, dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return pm.store.TrackChannel(ci)
|
||||
}
|
||||
|
||||
func (pm *Manager) GetPaych(ctx context.Context, from, to address.Address, amt types.BigInt) (address.Address, cid.Cid, error) {
|
||||
chanAccessor, err := pm.accessorByFromTo(from, to)
|
||||
if err != nil {
|
||||
return address.Undef, cid.Undef, err
|
||||
}
|
||||
|
||||
return chanAccessor.getPaych(ctx, from, to, amt)
|
||||
}
|
||||
|
||||
// GetPaychWaitReady waits until the create channel / add funds message with the
|
||||
// given message CID arrives.
|
||||
// The returned channel address can safely be used against the Manager methods.
|
||||
func (pm *Manager) GetPaychWaitReady(ctx context.Context, mcid cid.Cid) (address.Address, error) {
|
||||
// Find the channel associated with the message CID
|
||||
ci, err := pm.store.ByMessageCid(mcid)
|
||||
if err != nil {
|
||||
if err == datastore.ErrNotFound {
|
||||
return address.Undef, xerrors.Errorf("Could not find wait msg cid %s", mcid)
|
||||
}
|
||||
return address.Undef, err
|
||||
}
|
||||
|
||||
chanAccessor, err := pm.accessorByFromTo(ci.Control, ci.Target)
|
||||
if err != nil {
|
||||
return address.Undef, err
|
||||
}
|
||||
|
||||
return chanAccessor.getPaychWaitReady(ctx, mcid)
|
||||
}
|
||||
|
||||
func (pm *Manager) ListChannels() ([]address.Address, error) {
|
||||
// Need to take an exclusive lock here so that channel operations can't run
|
||||
// in parallel (see channelLock)
|
||||
pm.lk.Lock()
|
||||
defer pm.lk.Unlock()
|
||||
|
||||
return pm.store.ListChannels()
|
||||
}
|
||||
|
||||
func (pm *Manager) GetChannelInfo(addr address.Address) (*ChannelInfo, error) {
|
||||
ca, err := pm.accessorByAddress(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ca.getChannelInfo(addr)
|
||||
}
|
||||
|
||||
// CheckVoucherValid checks if the given voucher is valid (is or could become spendable at some point)
|
||||
func (pm *Manager) CheckVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) error {
|
||||
ca, err := pm.accessorByAddress(ch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = ca.checkVoucherValid(ctx, ch, sv)
|
||||
return err
|
||||
}
|
||||
|
||||
// CheckVoucherSpendable checks if the given voucher is currently spendable
|
||||
func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, secret []byte, proof []byte) (bool, error) {
|
||||
ca, err := pm.accessorByAddress(ch)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return ca.checkVoucherSpendable(ctx, ch, sv, secret, proof)
|
||||
}
|
||||
|
||||
func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) {
|
||||
ca, err := pm.accessorByAddress(ch)
|
||||
if err != nil {
|
||||
return types.NewInt(0), err
|
||||
}
|
||||
return ca.addVoucher(ctx, ch, sv, proof, minDelta)
|
||||
}
|
||||
|
||||
func (pm *Manager) AllocateLane(ch address.Address) (uint64, error) {
|
||||
ca, err := pm.accessorByAddress(ch)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return ca.allocateLane(ch)
|
||||
}
|
||||
|
||||
func (pm *Manager) ListVouchers(ctx context.Context, ch address.Address) ([]*VoucherInfo, error) {
|
||||
ca, err := pm.accessorByAddress(ch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ca.listVouchers(ctx, ch)
|
||||
}
|
||||
|
||||
func (pm *Manager) NextNonceForLane(ctx context.Context, ch address.Address, lane uint64) (uint64, error) {
|
||||
ca, err := pm.accessorByAddress(ch)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return ca.nextNonceForLane(ctx, ch, lane)
|
||||
}
|
||||
|
||||
func (pm *Manager) Settle(ctx context.Context, addr address.Address) (cid.Cid, error) {
|
||||
ca, err := pm.accessorByAddress(addr)
|
||||
if err != nil {
|
||||
return cid.Undef, err
|
||||
}
|
||||
return ca.settle(ctx, addr)
|
||||
}
|
58
paychmgr/msglistener.go
Normal file
58
paychmgr/msglistener.go
Normal file
@ -0,0 +1,58 @@
|
||||
package paychmgr
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/ipfs/go-cid"
|
||||
)
|
||||
|
||||
type msgListener struct {
|
||||
id string
|
||||
cb func(c cid.Cid, err error)
|
||||
}
|
||||
|
||||
type msgListeners struct {
|
||||
lk sync.Mutex
|
||||
listeners []*msgListener
|
||||
}
|
||||
|
||||
func (ml *msgListeners) onMsg(mcid cid.Cid, cb func(error)) string {
|
||||
ml.lk.Lock()
|
||||
defer ml.lk.Unlock()
|
||||
|
||||
l := &msgListener{
|
||||
id: uuid.New().String(),
|
||||
cb: func(c cid.Cid, err error) {
|
||||
if mcid.Equals(c) {
|
||||
cb(err)
|
||||
}
|
||||
},
|
||||
}
|
||||
ml.listeners = append(ml.listeners, l)
|
||||
return l.id
|
||||
}
|
||||
|
||||
func (ml *msgListeners) fireMsgComplete(mcid cid.Cid, err error) {
|
||||
ml.lk.Lock()
|
||||
defer ml.lk.Unlock()
|
||||
|
||||
for _, l := range ml.listeners {
|
||||
l.cb(mcid, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ml *msgListeners) unsubscribe(sub string) {
|
||||
for i, l := range ml.listeners {
|
||||
if l.id == sub {
|
||||
ml.removeListener(i)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ml *msgListeners) removeListener(i int) {
|
||||
copy(ml.listeners[i:], ml.listeners[i+1:])
|
||||
ml.listeners[len(ml.listeners)-1] = nil
|
||||
ml.listeners = ml.listeners[:len(ml.listeners)-1]
|
||||
}
|
96
paychmgr/msglistener_test.go
Normal file
96
paychmgr/msglistener_test.go
Normal file
@ -0,0 +1,96 @@
|
||||
package paychmgr
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ipfs/go-cid"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
func testCids() []cid.Cid {
|
||||
c1, _ := cid.Decode("QmdmGQmRgRjazArukTbsXuuxmSHsMCcRYPAZoGhd6e3MuS")
|
||||
c2, _ := cid.Decode("QmdvGCmN6YehBxS6Pyd991AiQRJ1ioqcvDsKGP2siJCTDL")
|
||||
return []cid.Cid{c1, c2}
|
||||
}
|
||||
|
||||
func TestMsgListener(t *testing.T) {
|
||||
var ml msgListeners
|
||||
|
||||
done := false
|
||||
experr := xerrors.Errorf("some err")
|
||||
cids := testCids()
|
||||
ml.onMsg(cids[0], func(err error) {
|
||||
require.Equal(t, experr, err)
|
||||
done = true
|
||||
})
|
||||
|
||||
ml.fireMsgComplete(cids[0], experr)
|
||||
|
||||
if !done {
|
||||
t.Fatal("failed to fire event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMsgListenerNilErr(t *testing.T) {
|
||||
var ml msgListeners
|
||||
|
||||
done := false
|
||||
cids := testCids()
|
||||
ml.onMsg(cids[0], func(err error) {
|
||||
require.Nil(t, err)
|
||||
done = true
|
||||
})
|
||||
|
||||
ml.fireMsgComplete(cids[0], nil)
|
||||
|
||||
if !done {
|
||||
t.Fatal("failed to fire event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMsgListenerUnsub(t *testing.T) {
|
||||
var ml msgListeners
|
||||
|
||||
done := false
|
||||
experr := xerrors.Errorf("some err")
|
||||
cids := testCids()
|
||||
id1 := ml.onMsg(cids[0], func(err error) {
|
||||
t.Fatal("should not call unsubscribed listener")
|
||||
})
|
||||
ml.onMsg(cids[0], func(err error) {
|
||||
require.Equal(t, experr, err)
|
||||
done = true
|
||||
})
|
||||
|
||||
ml.unsubscribe(id1)
|
||||
ml.fireMsgComplete(cids[0], experr)
|
||||
|
||||
if !done {
|
||||
t.Fatal("failed to fire event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMsgListenerMulti(t *testing.T) {
|
||||
var ml msgListeners
|
||||
|
||||
count := 0
|
||||
cids := testCids()
|
||||
ml.onMsg(cids[0], func(err error) {
|
||||
count++
|
||||
})
|
||||
ml.onMsg(cids[0], func(err error) {
|
||||
count++
|
||||
})
|
||||
ml.onMsg(cids[1], func(err error) {
|
||||
count++
|
||||
})
|
||||
|
||||
ml.fireMsgComplete(cids[0], nil)
|
||||
require.Equal(t, 2, count)
|
||||
|
||||
ml.fireMsgComplete(cids[1], nil)
|
||||
require.Equal(t, 3, count)
|
||||
}
|
@ -5,118 +5,75 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/filecoin-project/specs-actors/actors/abi/big"
|
||||
|
||||
"github.com/filecoin-project/lotus/api"
|
||||
"github.com/ipfs/go-cid"
|
||||
|
||||
"github.com/filecoin-project/go-address"
|
||||
cborutil "github.com/filecoin-project/go-cbor-util"
|
||||
"github.com/filecoin-project/lotus/chain/actors"
|
||||
"github.com/filecoin-project/lotus/chain/types"
|
||||
"github.com/filecoin-project/lotus/lib/sigs"
|
||||
"github.com/filecoin-project/specs-actors/actors/abi/big"
|
||||
"github.com/filecoin-project/specs-actors/actors/builtin"
|
||||
"github.com/filecoin-project/specs-actors/actors/builtin/account"
|
||||
"github.com/filecoin-project/specs-actors/actors/builtin/paych"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
logging "github.com/ipfs/go-log/v2"
|
||||
"go.uber.org/fx"
|
||||
|
||||
"github.com/filecoin-project/go-address"
|
||||
|
||||
"github.com/filecoin-project/lotus/chain/actors"
|
||||
"github.com/filecoin-project/lotus/chain/stmgr"
|
||||
"github.com/filecoin-project/lotus/chain/types"
|
||||
"github.com/filecoin-project/lotus/lib/sigs"
|
||||
"github.com/filecoin-project/lotus/node/impl/full"
|
||||
xerrors "golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
var log = logging.Logger("paych")
|
||||
|
||||
type ManagerApi struct {
|
||||
fx.In
|
||||
|
||||
full.MpoolAPI
|
||||
full.WalletAPI
|
||||
full.StateAPI
|
||||
// channelAccessor is used to simplify locking when accessing a channel
|
||||
type channelAccessor struct {
|
||||
// waitCtx is used by processes that wait for things to be confirmed
|
||||
// on chain
|
||||
waitCtx context.Context
|
||||
sm StateManagerApi
|
||||
sa *stateAccessor
|
||||
api paychApi
|
||||
store *Store
|
||||
lk *channelLock
|
||||
fundsReqQueue []*fundsReq
|
||||
msgListeners msgListeners
|
||||
}
|
||||
|
||||
type StateManagerApi interface {
|
||||
LoadActorState(ctx context.Context, a address.Address, out interface{}, ts *types.TipSet) (*types.Actor, error)
|
||||
Call(ctx context.Context, msg *types.Message, ts *types.TipSet) (*api.InvocResult, error)
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
store *Store
|
||||
sm StateManagerApi
|
||||
|
||||
mpool full.MpoolAPI
|
||||
wallet full.WalletAPI
|
||||
state full.StateAPI
|
||||
}
|
||||
|
||||
func NewManager(sm *stmgr.StateManager, pchstore *Store, api ManagerApi) *Manager {
|
||||
return &Manager{
|
||||
store: pchstore,
|
||||
sm: sm,
|
||||
|
||||
mpool: api.MpoolAPI,
|
||||
wallet: api.WalletAPI,
|
||||
state: api.StateAPI,
|
||||
func newChannelAccessor(pm *Manager) *channelAccessor {
|
||||
return &channelAccessor{
|
||||
lk: &channelLock{globalLock: &pm.lk},
|
||||
sm: pm.sm,
|
||||
sa: &stateAccessor{sm: pm.sm},
|
||||
api: pm.pchapi,
|
||||
store: pm.store,
|
||||
waitCtx: pm.ctx,
|
||||
}
|
||||
}
|
||||
|
||||
// Used by the tests to supply mocks
|
||||
func newManager(sm StateManagerApi, pchstore *Store) *Manager {
|
||||
return &Manager{
|
||||
store: pchstore,
|
||||
sm: sm,
|
||||
}
|
||||
func (ca *channelAccessor) getChannelInfo(addr address.Address) (*ChannelInfo, error) {
|
||||
ca.lk.Lock()
|
||||
defer ca.lk.Unlock()
|
||||
|
||||
return ca.store.ByAddress(addr)
|
||||
}
|
||||
|
||||
func (pm *Manager) TrackOutboundChannel(ctx context.Context, ch address.Address) error {
|
||||
return pm.trackChannel(ctx, ch, DirOutbound)
|
||||
func (ca *channelAccessor) checkVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (map[uint64]*paych.LaneState, error) {
|
||||
ca.lk.Lock()
|
||||
defer ca.lk.Unlock()
|
||||
|
||||
return ca.checkVoucherValidUnlocked(ctx, ch, sv)
|
||||
}
|
||||
|
||||
func (pm *Manager) TrackInboundChannel(ctx context.Context, ch address.Address) error {
|
||||
return pm.trackChannel(ctx, ch, DirInbound)
|
||||
}
|
||||
|
||||
func (pm *Manager) trackChannel(ctx context.Context, ch address.Address, dir uint64) error {
|
||||
ci, err := pm.loadStateChannelInfo(ctx, ch, dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return pm.store.TrackChannel(ci)
|
||||
}
|
||||
|
||||
func (pm *Manager) ListChannels() ([]address.Address, error) {
|
||||
return pm.store.ListChannels()
|
||||
}
|
||||
|
||||
func (pm *Manager) GetChannelInfo(addr address.Address) (*ChannelInfo, error) {
|
||||
return pm.store.getChannelInfo(addr)
|
||||
}
|
||||
|
||||
// CheckVoucherValid checks if the given voucher is valid (is or could become spendable at some point)
|
||||
func (pm *Manager) CheckVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) error {
|
||||
_, err := pm.checkVoucherValid(ctx, ch, sv)
|
||||
return err
|
||||
}
|
||||
|
||||
func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (map[uint64]*paych.LaneState, error) {
|
||||
func (ca *channelAccessor) checkVoucherValidUnlocked(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (map[uint64]*paych.LaneState, error) {
|
||||
if sv.ChannelAddr != ch {
|
||||
return nil, xerrors.Errorf("voucher ChannelAddr doesn't match channel address, got %s, expected %s", sv.ChannelAddr, ch)
|
||||
}
|
||||
|
||||
act, pchState, err := pm.loadPaychState(ctx, ch)
|
||||
act, pchState, err := ca.sa.loadPaychState(ctx, ch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var account account.State
|
||||
_, err = pm.sm.LoadActorState(ctx, pchState.From, &account, nil)
|
||||
var actState account.State
|
||||
_, err = ca.sm.LoadActorState(ctx, pchState.From, &actState, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
from := account.Address
|
||||
from := actState.Address
|
||||
|
||||
// verify signature
|
||||
vb, err := sv.SigningBytes()
|
||||
@ -132,7 +89,7 @@ func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv
|
||||
}
|
||||
|
||||
// Check the voucher against the highest known voucher nonce / value
|
||||
laneStates, err := pm.laneState(pchState, ch)
|
||||
laneStates, err := ca.laneState(pchState, ch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -164,7 +121,7 @@ func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv
|
||||
// lane 2: 2
|
||||
// -
|
||||
// total: 7
|
||||
totalRedeemed, err := pm.totalRedeemedWithVoucher(laneStates, sv)
|
||||
totalRedeemed, err := ca.totalRedeemedWithVoucher(laneStates, sv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -183,15 +140,17 @@ func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv
|
||||
return laneStates, nil
|
||||
}
|
||||
|
||||
// CheckVoucherSpendable checks if the given voucher is currently spendable
|
||||
func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, secret []byte, proof []byte) (bool, error) {
|
||||
recipient, err := pm.getPaychRecipient(ctx, ch)
|
||||
func (ca *channelAccessor) checkVoucherSpendable(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, secret []byte, proof []byte) (bool, error) {
|
||||
ca.lk.Lock()
|
||||
defer ca.lk.Unlock()
|
||||
|
||||
recipient, err := ca.getPaychRecipient(ctx, ch)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if sv.Extra != nil && proof == nil {
|
||||
known, err := pm.ListVouchers(ctx, ch)
|
||||
known, err := ca.store.VouchersForPaych(ch)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -221,7 +180,7 @@ func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address
|
||||
return false, err
|
||||
}
|
||||
|
||||
ret, err := pm.sm.Call(ctx, &types.Message{
|
||||
ret, err := ca.sm.Call(ctx, &types.Message{
|
||||
From: recipient,
|
||||
To: ch,
|
||||
Method: builtin.MethodsPaych.UpdateChannelState,
|
||||
@ -238,22 +197,22 @@ func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (pm *Manager) getPaychRecipient(ctx context.Context, ch address.Address) (address.Address, error) {
|
||||
func (ca *channelAccessor) getPaychRecipient(ctx context.Context, ch address.Address) (address.Address, error) {
|
||||
var state paych.State
|
||||
if _, err := pm.sm.LoadActorState(ctx, ch, &state, nil); err != nil {
|
||||
if _, err := ca.sm.LoadActorState(ctx, ch, &state, nil); err != nil {
|
||||
return address.Address{}, err
|
||||
}
|
||||
|
||||
return state.To, nil
|
||||
}
|
||||
|
||||
func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) {
|
||||
pm.store.lk.Lock()
|
||||
defer pm.store.lk.Unlock()
|
||||
func (ca *channelAccessor) addVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) {
|
||||
ca.lk.Lock()
|
||||
defer ca.lk.Unlock()
|
||||
|
||||
ci, err := pm.store.getChannelInfo(ch)
|
||||
ci, err := ca.store.ByAddress(ch)
|
||||
if err != nil {
|
||||
return types.NewInt(0), err
|
||||
return types.BigInt{}, err
|
||||
}
|
||||
|
||||
// Check if the voucher has already been added
|
||||
@ -275,7 +234,7 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych
|
||||
Proof: proof,
|
||||
}
|
||||
|
||||
return types.NewInt(0), pm.store.putChannelInfo(ci)
|
||||
return types.NewInt(0), ca.store.putChannelInfo(ci)
|
||||
}
|
||||
|
||||
// Otherwise just ignore the duplicate voucher
|
||||
@ -284,7 +243,7 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych
|
||||
}
|
||||
|
||||
// Check voucher validity
|
||||
laneStates, err := pm.checkVoucherValid(ctx, ch, sv)
|
||||
laneStates, err := ca.checkVoucherValidUnlocked(ctx, ch, sv)
|
||||
if err != nil {
|
||||
return types.NewInt(0), err
|
||||
}
|
||||
@ -311,35 +270,32 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych
|
||||
ci.NextLane = sv.Lane + 1
|
||||
}
|
||||
|
||||
return delta, pm.store.putChannelInfo(ci)
|
||||
return delta, ca.store.putChannelInfo(ci)
|
||||
}
|
||||
|
||||
func (pm *Manager) AllocateLane(ch address.Address) (uint64, error) {
|
||||
func (ca *channelAccessor) allocateLane(ch address.Address) (uint64, error) {
|
||||
ca.lk.Lock()
|
||||
defer ca.lk.Unlock()
|
||||
|
||||
// TODO: should this take into account lane state?
|
||||
return pm.store.AllocateLane(ch)
|
||||
return ca.store.AllocateLane(ch)
|
||||
}
|
||||
|
||||
func (pm *Manager) ListVouchers(ctx context.Context, ch address.Address) ([]*VoucherInfo, error) {
|
||||
func (ca *channelAccessor) listVouchers(ctx context.Context, ch address.Address) ([]*VoucherInfo, error) {
|
||||
ca.lk.Lock()
|
||||
defer ca.lk.Unlock()
|
||||
|
||||
// TODO: just having a passthrough method like this feels odd. Seems like
|
||||
// there should be some filtering we're doing here
|
||||
return pm.store.VouchersForPaych(ch)
|
||||
return ca.store.VouchersForPaych(ch)
|
||||
}
|
||||
|
||||
func (pm *Manager) OutboundChanTo(from, to address.Address) (address.Address, error) {
|
||||
pm.store.lk.Lock()
|
||||
defer pm.store.lk.Unlock()
|
||||
func (ca *channelAccessor) nextNonceForLane(ctx context.Context, ch address.Address, lane uint64) (uint64, error) {
|
||||
ca.lk.Lock()
|
||||
defer ca.lk.Unlock()
|
||||
|
||||
return pm.store.findChan(func(ci *ChannelInfo) bool {
|
||||
if ci.Direction != DirOutbound {
|
||||
return false
|
||||
}
|
||||
return ci.Control == from && ci.Target == to
|
||||
})
|
||||
}
|
||||
|
||||
func (pm *Manager) NextNonceForLane(ctx context.Context, ch address.Address, lane uint64) (uint64, error) {
|
||||
// TODO: should this take into account lane state?
|
||||
vouchers, err := pm.store.VouchersForPaych(ch)
|
||||
vouchers, err := ca.store.VouchersForPaych(ch)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -355,3 +311,109 @@ func (pm *Manager) NextNonceForLane(ctx context.Context, ch address.Address, lan
|
||||
|
||||
return maxnonce + 1, nil
|
||||
}
|
||||
|
||||
// laneState gets the LaneStates from chain, then applies all vouchers in
|
||||
// the data store over the chain state
|
||||
func (ca *channelAccessor) laneState(state *paych.State, ch address.Address) (map[uint64]*paych.LaneState, error) {
|
||||
// TODO: we probably want to call UpdateChannelState with all vouchers to be fully correct
|
||||
// (but technically dont't need to)
|
||||
laneStates := make(map[uint64]*paych.LaneState, len(state.LaneStates))
|
||||
|
||||
// Get the lane state from the chain
|
||||
for _, laneState := range state.LaneStates {
|
||||
laneStates[laneState.ID] = laneState
|
||||
}
|
||||
|
||||
// Apply locally stored vouchers
|
||||
vouchers, err := ca.store.VouchersForPaych(ch)
|
||||
if err != nil && err != ErrChannelNotTracked {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, v := range vouchers {
|
||||
for range v.Voucher.Merges {
|
||||
return nil, xerrors.Errorf("paych merges not handled yet")
|
||||
}
|
||||
|
||||
// If there's a voucher for a lane that isn't in chain state just
|
||||
// create it
|
||||
ls, ok := laneStates[v.Voucher.Lane]
|
||||
if !ok {
|
||||
ls = &paych.LaneState{
|
||||
ID: v.Voucher.Lane,
|
||||
Redeemed: types.NewInt(0),
|
||||
Nonce: 0,
|
||||
}
|
||||
laneStates[v.Voucher.Lane] = ls
|
||||
}
|
||||
|
||||
if v.Voucher.Nonce < ls.Nonce {
|
||||
continue
|
||||
}
|
||||
|
||||
ls.Nonce = v.Voucher.Nonce
|
||||
ls.Redeemed = v.Voucher.Amount
|
||||
}
|
||||
|
||||
return laneStates, nil
|
||||
}
|
||||
|
||||
// Get the total redeemed amount across all lanes, after applying the voucher
|
||||
func (ca *channelAccessor) totalRedeemedWithVoucher(laneStates map[uint64]*paych.LaneState, sv *paych.SignedVoucher) (big.Int, error) {
|
||||
// TODO: merges
|
||||
if len(sv.Merges) != 0 {
|
||||
return big.Int{}, xerrors.Errorf("dont currently support paych lane merges")
|
||||
}
|
||||
|
||||
total := big.NewInt(0)
|
||||
for _, ls := range laneStates {
|
||||
total = big.Add(total, ls.Redeemed)
|
||||
}
|
||||
|
||||
lane, ok := laneStates[sv.Lane]
|
||||
if ok {
|
||||
// If the voucher is for an existing lane, and the voucher nonce
|
||||
// and is higher than the lane nonce
|
||||
if sv.Nonce > lane.Nonce {
|
||||
// Add the delta between the redeemed amount and the voucher
|
||||
// amount to the total
|
||||
delta := big.Sub(sv.Amount, lane.Redeemed)
|
||||
total = big.Add(total, delta)
|
||||
}
|
||||
} else {
|
||||
// If the voucher is *not* for an existing lane, just add its
|
||||
// value (implicitly a new lane will be created for the voucher)
|
||||
total = big.Add(total, sv.Amount)
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (ca *channelAccessor) settle(ctx context.Context, ch address.Address) (cid.Cid, error) {
|
||||
ca.lk.Lock()
|
||||
defer ca.lk.Unlock()
|
||||
|
||||
ci, err := ca.store.ByAddress(ch)
|
||||
if err != nil {
|
||||
return cid.Undef, err
|
||||
}
|
||||
|
||||
msg := &types.Message{
|
||||
To: ch,
|
||||
From: ci.Control,
|
||||
Value: types.NewInt(0),
|
||||
Method: builtin.MethodsPaych.Settle,
|
||||
}
|
||||
smgs, err := ca.api.MpoolPushMessage(ctx, msg)
|
||||
if err != nil {
|
||||
return cid.Undef, err
|
||||
}
|
||||
|
||||
ci.Settling = true
|
||||
err = ca.store.putChannelInfo(ci)
|
||||
if err != nil {
|
||||
log.Errorf("Error marking channel as settled: %s", err)
|
||||
}
|
||||
|
||||
return smgs.Cid(), err
|
||||
}
|
||||
|
@ -104,13 +104,15 @@ func TestPaychOutbound(t *testing.T) {
|
||||
LaneStates: []*paych.LaneState{},
|
||||
})
|
||||
|
||||
mgr := newManager(sm, store)
|
||||
err := mgr.TrackOutboundChannel(ctx, ch)
|
||||
mgr, err := newManager(sm, store, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.TrackOutboundChannel(ctx, ch)
|
||||
require.NoError(t, err)
|
||||
|
||||
ci, err := mgr.GetChannelInfo(ch)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ci.Channel, ch)
|
||||
require.Equal(t, *ci.Channel, ch)
|
||||
require.Equal(t, ci.Control, from)
|
||||
require.Equal(t, ci.Target, to)
|
||||
require.EqualValues(t, ci.Direction, DirOutbound)
|
||||
@ -140,13 +142,15 @@ func TestPaychInbound(t *testing.T) {
|
||||
LaneStates: []*paych.LaneState{},
|
||||
})
|
||||
|
||||
mgr := newManager(sm, store)
|
||||
err := mgr.TrackInboundChannel(ctx, ch)
|
||||
mgr, err := newManager(sm, store, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.TrackInboundChannel(ctx, ch)
|
||||
require.NoError(t, err)
|
||||
|
||||
ci, err := mgr.GetChannelInfo(ch)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ci.Channel, ch)
|
||||
require.Equal(t, *ci.Channel, ch)
|
||||
require.Equal(t, ci.Control, to)
|
||||
require.Equal(t, ci.Target, from)
|
||||
require.EqualValues(t, ci.Direction, DirInbound)
|
||||
@ -321,8 +325,10 @@ func TestCheckVoucherValid(t *testing.T) {
|
||||
LaneStates: tcase.laneStates,
|
||||
})
|
||||
|
||||
mgr := newManager(sm, store)
|
||||
err := mgr.TrackInboundChannel(ctx, ch)
|
||||
mgr, err := newManager(sm, store, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.TrackInboundChannel(ctx, ch)
|
||||
require.NoError(t, err)
|
||||
|
||||
sv := testCreateVoucher(t, ch, tcase.voucherLane, tcase.voucherNonce, tcase.voucherAmount, tcase.key)
|
||||
@ -382,8 +388,10 @@ func TestCheckVoucherValidCountingAllLanes(t *testing.T) {
|
||||
LaneStates: laneStates,
|
||||
})
|
||||
|
||||
mgr := newManager(sm, store)
|
||||
err := mgr.TrackInboundChannel(ctx, ch)
|
||||
mgr, err := newManager(sm, store, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.TrackInboundChannel(ctx, ch)
|
||||
require.NoError(t, err)
|
||||
|
||||
//
|
||||
@ -690,8 +698,10 @@ func testSetupMgrWithChannel(ctx context.Context, t *testing.T) (*Manager, addre
|
||||
})
|
||||
|
||||
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
|
||||
mgr := newManager(sm, store)
|
||||
err := mgr.TrackInboundChannel(ctx, ch)
|
||||
mgr, err := newManager(sm, store, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.TrackInboundChannel(ctx, ch)
|
||||
require.NoError(t, err)
|
||||
return mgr, ch, fromKeyPrivate
|
||||
}
|
||||
|
756
paychmgr/paychget_test.go
Normal file
756
paychmgr/paychget_test.go
Normal file
@ -0,0 +1,756 @@
|
||||
package paychmgr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cborrpc "github.com/filecoin-project/go-cbor-util"
|
||||
|
||||
init_ "github.com/filecoin-project/specs-actors/actors/builtin/init"
|
||||
|
||||
"github.com/filecoin-project/specs-actors/actors/builtin"
|
||||
|
||||
"github.com/filecoin-project/lotus/api"
|
||||
"github.com/filecoin-project/lotus/chain/types"
|
||||
|
||||
"github.com/filecoin-project/go-address"
|
||||
|
||||
"github.com/filecoin-project/specs-actors/actors/abi/big"
|
||||
tutils "github.com/filecoin-project/specs-actors/support/testing"
|
||||
"github.com/ipfs/go-cid"
|
||||
ds "github.com/ipfs/go-datastore"
|
||||
ds_sync "github.com/ipfs/go-datastore/sync"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type waitingCall struct {
|
||||
response chan types.MessageReceipt
|
||||
}
|
||||
|
||||
type mockPaychAPI struct {
|
||||
lk sync.Mutex
|
||||
messages map[cid.Cid]*types.SignedMessage
|
||||
waitingCalls []*waitingCall
|
||||
}
|
||||
|
||||
func newMockPaychAPI() *mockPaychAPI {
|
||||
return &mockPaychAPI{
|
||||
messages: make(map[cid.Cid]*types.SignedMessage),
|
||||
}
|
||||
}
|
||||
|
||||
func (pchapi *mockPaychAPI) StateWaitMsg(ctx context.Context, msg cid.Cid, confidence uint64) (*api.MsgLookup, error) {
|
||||
response := make(chan types.MessageReceipt)
|
||||
|
||||
pchapi.lk.Lock()
|
||||
pchapi.waitingCalls = append(pchapi.waitingCalls, &waitingCall{response: response})
|
||||
pchapi.lk.Unlock()
|
||||
|
||||
receipt := <-response
|
||||
|
||||
return &api.MsgLookup{Receipt: receipt}, nil
|
||||
}
|
||||
|
||||
func (pchapi *mockPaychAPI) MpoolPushMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error) {
|
||||
pchapi.lk.Lock()
|
||||
defer pchapi.lk.Unlock()
|
||||
|
||||
smsg := &types.SignedMessage{Message: *msg}
|
||||
pchapi.messages[smsg.Cid()] = smsg
|
||||
return smsg, nil
|
||||
}
|
||||
|
||||
func (pchapi *mockPaychAPI) pushedMessages(c cid.Cid) *types.SignedMessage {
|
||||
pchapi.lk.Lock()
|
||||
defer pchapi.lk.Unlock()
|
||||
|
||||
return pchapi.messages[c]
|
||||
}
|
||||
|
||||
func (pchapi *mockPaychAPI) pushedMessageCount() int {
|
||||
pchapi.lk.Lock()
|
||||
defer pchapi.lk.Unlock()
|
||||
|
||||
return len(pchapi.messages)
|
||||
}
|
||||
|
||||
func (pchapi *mockPaychAPI) finishWaitingCalls(receipt types.MessageReceipt) {
|
||||
pchapi.lk.Lock()
|
||||
defer pchapi.lk.Unlock()
|
||||
|
||||
for _, call := range pchapi.waitingCalls {
|
||||
call.response <- receipt
|
||||
}
|
||||
pchapi.waitingCalls = nil
|
||||
}
|
||||
|
||||
func (pchapi *mockPaychAPI) close() {
|
||||
pchapi.finishWaitingCalls(types.MessageReceipt{
|
||||
ExitCode: 0,
|
||||
Return: []byte{},
|
||||
})
|
||||
}
|
||||
|
||||
func testChannelResponse(t *testing.T, ch address.Address) types.MessageReceipt {
|
||||
createChannelRet := init_.ExecReturn{
|
||||
IDAddress: ch,
|
||||
RobustAddress: ch,
|
||||
}
|
||||
createChannelRetBytes, err := cborrpc.Dump(&createChannelRet)
|
||||
require.NoError(t, err)
|
||||
createChannelResponse := types.MessageReceipt{
|
||||
ExitCode: 0,
|
||||
Return: createChannelRetBytes,
|
||||
}
|
||||
return createChannelResponse
|
||||
}
|
||||
|
||||
// TestPaychGetCreateChannelMsg tests that GetPaych sends a message to create
|
||||
// a new channel with the correct funds
|
||||
func TestPaychGetCreateChannelMsg(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
|
||||
|
||||
from := tutils.NewIDAddr(t, 101)
|
||||
to := tutils.NewIDAddr(t, 102)
|
||||
|
||||
sm := newMockStateManager()
|
||||
pchapi := newMockPaychAPI()
|
||||
defer pchapi.close()
|
||||
|
||||
mgr, err := newManager(sm, store, pchapi)
|
||||
require.NoError(t, err)
|
||||
|
||||
amt := big.NewInt(10)
|
||||
ch, mcid, err := mgr.GetPaych(ctx, from, to, amt)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, address.Undef, ch)
|
||||
|
||||
pushedMsg := pchapi.pushedMessages(mcid)
|
||||
require.Equal(t, from, pushedMsg.Message.From)
|
||||
require.Equal(t, builtin.InitActorAddr, pushedMsg.Message.To)
|
||||
require.Equal(t, amt, pushedMsg.Message.Value)
|
||||
}
|
||||
|
||||
// TestPaychGetCreateChannelThenAddFunds tests creating a channel and then
|
||||
// adding funds to it
|
||||
func TestPaychGetCreateChannelThenAddFunds(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
|
||||
|
||||
ch := tutils.NewIDAddr(t, 100)
|
||||
from := tutils.NewIDAddr(t, 101)
|
||||
to := tutils.NewIDAddr(t, 102)
|
||||
|
||||
sm := newMockStateManager()
|
||||
pchapi := newMockPaychAPI()
|
||||
defer pchapi.close()
|
||||
|
||||
mgr, err := newManager(sm, store, pchapi)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send create message for a channel with value 10
|
||||
amt := big.NewInt(10)
|
||||
_, createMsgCid, err := mgr.GetPaych(ctx, from, to, amt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have no channels yet (message sent but channel not created)
|
||||
cis, err := mgr.ListChannels()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cis, 0)
|
||||
|
||||
// 1. Set up create channel response (sent in response to WaitForMsg())
|
||||
response := testChannelResponse(t, ch)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
// 2. Request add funds - should block until create channel has completed
|
||||
amt2 := big.NewInt(5)
|
||||
ch2, addFundsMsgCid, err := mgr.GetPaych(ctx, from, to, amt2)
|
||||
|
||||
// 4. This GetPaych should return after create channel from first
|
||||
// GetPaych completes
|
||||
require.NoError(t, err)
|
||||
|
||||
// Expect the channel to be the same
|
||||
require.Equal(t, ch, ch2)
|
||||
// Expect add funds message CID to be different to create message cid
|
||||
require.NotEqual(t, createMsgCid, addFundsMsgCid)
|
||||
|
||||
// Should have one channel, whose address is the channel that was created
|
||||
cis, err := mgr.ListChannels()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cis, 1)
|
||||
require.Equal(t, ch, cis[0])
|
||||
|
||||
// Amount should be amount sent to first GetPaych (to create
|
||||
// channel).
|
||||
// PendingAmount should be amount sent in second GetPaych
|
||||
// (second GetPaych triggered add funds, which has not yet been confirmed)
|
||||
ci, err := mgr.GetChannelInfo(ch)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 10, ci.Amount.Int64())
|
||||
require.EqualValues(t, 5, ci.PendingAmount.Int64())
|
||||
require.Nil(t, ci.CreateMsg)
|
||||
|
||||
// Trigger add funds confirmation
|
||||
pchapi.finishWaitingCalls(types.MessageReceipt{ExitCode: 0})
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// Should still have one channel
|
||||
cis, err = mgr.ListChannels()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cis, 1)
|
||||
require.Equal(t, ch, cis[0])
|
||||
|
||||
// Channel amount should include last amount sent to GetPaych
|
||||
ci, err = mgr.GetChannelInfo(ch)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 15, ci.Amount.Int64())
|
||||
require.EqualValues(t, 0, ci.PendingAmount.Int64())
|
||||
require.Nil(t, ci.AddFundsMsg)
|
||||
}()
|
||||
|
||||
// Give the go routine above a moment to run
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// 3. Send create channel response
|
||||
pchapi.finishWaitingCalls(response)
|
||||
|
||||
<-done
|
||||
}
|
||||
|
||||
// TestPaychGetCreateChannelWithErrorThenCreateAgain tests that if an
|
||||
// operation is queued up behind a create channel operation, and the create
|
||||
// channel fails, then the waiting operation can succeed.
|
||||
func TestPaychGetCreateChannelWithErrorThenCreateAgain(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
|
||||
|
||||
from := tutils.NewIDAddr(t, 101)
|
||||
to := tutils.NewIDAddr(t, 102)
|
||||
|
||||
sm := newMockStateManager()
|
||||
pchapi := newMockPaychAPI()
|
||||
defer pchapi.close()
|
||||
|
||||
mgr, err := newManager(sm, store, pchapi)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send create message for a channel
|
||||
amt := big.NewInt(10)
|
||||
_, _, err = mgr.GetPaych(ctx, from, to, amt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 1. Set up create channel response (sent in response to WaitForMsg())
|
||||
// This response indicates an error.
|
||||
errResponse := types.MessageReceipt{
|
||||
ExitCode: 1, // error
|
||||
Return: []byte{},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
// 2. Should block until create channel has completed.
|
||||
// Because first channel create fails, this request
|
||||
// should be for channel create.
|
||||
amt2 := big.NewInt(5)
|
||||
ch2, _, err := mgr.GetPaych(ctx, from, to, amt2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, address.Undef, ch2)
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// 4. Send a success response
|
||||
ch := tutils.NewIDAddr(t, 100)
|
||||
successResponse := testChannelResponse(t, ch)
|
||||
pchapi.finishWaitingCalls(successResponse)
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// Should have one channel, whose address is the channel that was created
|
||||
cis, err := mgr.ListChannels()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cis, 1)
|
||||
require.Equal(t, ch, cis[0])
|
||||
|
||||
ci, err := mgr.GetChannelInfo(ch)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, amt2, ci.Amount)
|
||||
}()
|
||||
|
||||
// Give the go routine above a moment to run
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// 3. Send error response to first channel create
|
||||
pchapi.finishWaitingCalls(errResponse)
|
||||
|
||||
<-done
|
||||
}
|
||||
|
||||
// TestPaychGetRecoverAfterError tests that after a create channel fails, the
|
||||
// next attempt to create channel can succeed.
|
||||
func TestPaychGetRecoverAfterError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
|
||||
|
||||
ch := tutils.NewIDAddr(t, 100)
|
||||
from := tutils.NewIDAddr(t, 101)
|
||||
to := tutils.NewIDAddr(t, 102)
|
||||
|
||||
sm := newMockStateManager()
|
||||
pchapi := newMockPaychAPI()
|
||||
defer pchapi.close()
|
||||
|
||||
mgr, err := newManager(sm, store, pchapi)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send create message for a channel
|
||||
amt := big.NewInt(10)
|
||||
_, _, err = mgr.GetPaych(ctx, from, to, amt)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// Send error create channel response
|
||||
pchapi.finishWaitingCalls(types.MessageReceipt{
|
||||
ExitCode: 1, // error
|
||||
Return: []byte{},
|
||||
})
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// Send create message for a channel again
|
||||
amt2 := big.NewInt(7)
|
||||
_, _, err = mgr.GetPaych(ctx, from, to, amt2)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// Send success create channel response
|
||||
response := testChannelResponse(t, ch)
|
||||
pchapi.finishWaitingCalls(response)
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// Should have one channel, whose address is the channel that was created
|
||||
cis, err := mgr.ListChannels()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cis, 1)
|
||||
require.Equal(t, ch, cis[0])
|
||||
|
||||
ci, err := mgr.GetChannelInfo(ch)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, amt2, ci.Amount)
|
||||
require.EqualValues(t, 0, ci.PendingAmount.Int64())
|
||||
require.Nil(t, ci.CreateMsg)
|
||||
}
|
||||
|
||||
// TestPaychGetRecoverAfterAddFundsError tests that after an add funds fails, the
|
||||
// next attempt to add funds can succeed.
|
||||
func TestPaychGetRecoverAfterAddFundsError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
|
||||
|
||||
ch := tutils.NewIDAddr(t, 100)
|
||||
from := tutils.NewIDAddr(t, 101)
|
||||
to := tutils.NewIDAddr(t, 102)
|
||||
|
||||
sm := newMockStateManager()
|
||||
pchapi := newMockPaychAPI()
|
||||
defer pchapi.close()
|
||||
|
||||
mgr, err := newManager(sm, store, pchapi)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send create message for a channel
|
||||
amt := big.NewInt(10)
|
||||
_, _, err = mgr.GetPaych(ctx, from, to, amt)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// Send success create channel response
|
||||
response := testChannelResponse(t, ch)
|
||||
pchapi.finishWaitingCalls(response)
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// Send add funds message for channel
|
||||
amt2 := big.NewInt(5)
|
||||
_, _, err = mgr.GetPaych(ctx, from, to, amt2)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// Send error add funds response
|
||||
pchapi.finishWaitingCalls(types.MessageReceipt{
|
||||
ExitCode: 1, // error
|
||||
Return: []byte{},
|
||||
})
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// Should have one channel, whose address is the channel that was created
|
||||
cis, err := mgr.ListChannels()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cis, 1)
|
||||
require.Equal(t, ch, cis[0])
|
||||
|
||||
ci, err := mgr.GetChannelInfo(ch)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, amt, ci.Amount)
|
||||
require.EqualValues(t, 0, ci.PendingAmount.Int64())
|
||||
require.Nil(t, ci.CreateMsg)
|
||||
require.Nil(t, ci.AddFundsMsg)
|
||||
|
||||
// Send add funds message for channel again
|
||||
amt3 := big.NewInt(2)
|
||||
_, _, err = mgr.GetPaych(ctx, from, to, amt3)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// Send success add funds response
|
||||
pchapi.finishWaitingCalls(types.MessageReceipt{
|
||||
ExitCode: 0,
|
||||
Return: []byte{},
|
||||
})
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// Should have one channel, whose address is the channel that was created
|
||||
cis, err = mgr.ListChannels()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cis, 1)
|
||||
require.Equal(t, ch, cis[0])
|
||||
|
||||
// Amount should include amount for successful add funds msg
|
||||
ci, err = mgr.GetChannelInfo(ch)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, amt.Int64()+amt3.Int64(), ci.Amount.Int64())
|
||||
require.EqualValues(t, 0, ci.PendingAmount.Int64())
|
||||
require.Nil(t, ci.CreateMsg)
|
||||
require.Nil(t, ci.AddFundsMsg)
|
||||
}
|
||||
|
||||
// TestPaychGetRestartAfterCreateChannelMsg tests that if the system stops
|
||||
// right after the create channel message is sent, the channel will be
|
||||
// created when the system restarts.
|
||||
func TestPaychGetRestartAfterCreateChannelMsg(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
|
||||
|
||||
ch := tutils.NewIDAddr(t, 100)
|
||||
from := tutils.NewIDAddr(t, 101)
|
||||
to := tutils.NewIDAddr(t, 102)
|
||||
|
||||
sm := newMockStateManager()
|
||||
pchapi := newMockPaychAPI()
|
||||
|
||||
mgr, err := newManager(sm, store, pchapi)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send create message for a channel with value 10
|
||||
amt := big.NewInt(10)
|
||||
_, createMsgCid, err := mgr.GetPaych(ctx, from, to, amt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate shutting down lotus
|
||||
pchapi.close()
|
||||
|
||||
// Create a new manager with the same datastore
|
||||
sm2 := newMockStateManager()
|
||||
pchapi2 := newMockPaychAPI()
|
||||
defer pchapi2.close()
|
||||
|
||||
mgr2, err := newManager(sm2, store, pchapi2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have no channels yet (message sent but channel not created)
|
||||
cis, err := mgr2.ListChannels()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cis, 0)
|
||||
|
||||
// 1. Set up create channel response (sent in response to WaitForMsg())
|
||||
response := testChannelResponse(t, ch)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
// 2. Request add funds - should block until create channel has completed
|
||||
amt2 := big.NewInt(5)
|
||||
ch2, addFundsMsgCid, err := mgr2.GetPaych(ctx, from, to, amt2)
|
||||
|
||||
// 4. This GetPaych should return after create channel from first
|
||||
// GetPaych completes
|
||||
require.NoError(t, err)
|
||||
|
||||
// Expect the channel to have been created
|
||||
require.Equal(t, ch, ch2)
|
||||
// Expect add funds message CID to be different to create message cid
|
||||
require.NotEqual(t, createMsgCid, addFundsMsgCid)
|
||||
|
||||
// Should have one channel, whose address is the channel that was created
|
||||
cis, err := mgr2.ListChannels()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cis, 1)
|
||||
require.Equal(t, ch, cis[0])
|
||||
|
||||
// Amount should be amount sent to first GetPaych (to create
|
||||
// channel).
|
||||
// PendingAmount should be amount sent in second GetPaych
|
||||
// (second GetPaych triggered add funds, which has not yet been confirmed)
|
||||
ci, err := mgr2.GetChannelInfo(ch)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 10, ci.Amount.Int64())
|
||||
require.EqualValues(t, 5, ci.PendingAmount.Int64())
|
||||
require.Nil(t, ci.CreateMsg)
|
||||
}()
|
||||
|
||||
// Give the go routine above a moment to run
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// 3. Send create channel response
|
||||
pchapi2.finishWaitingCalls(response)
|
||||
|
||||
<-done
|
||||
}
|
||||
|
||||
// TestPaychGetRestartAfterAddFundsMsg tests that if the system stops
|
||||
// right after the add funds message is sent, the add funds will be
|
||||
// processed when the system restarts.
|
||||
func TestPaychGetRestartAfterAddFundsMsg(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
|
||||
|
||||
ch := tutils.NewIDAddr(t, 100)
|
||||
from := tutils.NewIDAddr(t, 101)
|
||||
to := tutils.NewIDAddr(t, 102)
|
||||
|
||||
sm := newMockStateManager()
|
||||
pchapi := newMockPaychAPI()
|
||||
|
||||
mgr, err := newManager(sm, store, pchapi)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send create message for a channel
|
||||
amt := big.NewInt(10)
|
||||
_, _, err = mgr.GetPaych(ctx, from, to, amt)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// Send success create channel response
|
||||
response := testChannelResponse(t, ch)
|
||||
pchapi.finishWaitingCalls(response)
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// Send add funds message for channel
|
||||
amt2 := big.NewInt(5)
|
||||
_, _, err = mgr.GetPaych(ctx, from, to, amt2)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// Simulate shutting down lotus
|
||||
pchapi.close()
|
||||
|
||||
// Create a new manager with the same datastore
|
||||
sm2 := newMockStateManager()
|
||||
pchapi2 := newMockPaychAPI()
|
||||
defer pchapi2.close()
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
mgr2, err := newManager(sm2, store, pchapi2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send success add funds response
|
||||
pchapi2.finishWaitingCalls(types.MessageReceipt{
|
||||
ExitCode: 0,
|
||||
Return: []byte{},
|
||||
})
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// Should have one channel, whose address is the channel that was created
|
||||
cis, err := mgr2.ListChannels()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cis, 1)
|
||||
require.Equal(t, ch, cis[0])
|
||||
|
||||
// Amount should include amount for successful add funds msg
|
||||
ci, err := mgr2.GetChannelInfo(ch)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, amt.Int64()+amt2.Int64(), ci.Amount.Int64())
|
||||
require.EqualValues(t, 0, ci.PendingAmount.Int64())
|
||||
require.Nil(t, ci.CreateMsg)
|
||||
require.Nil(t, ci.AddFundsMsg)
|
||||
}
|
||||
|
||||
// TestPaychGetWait tests that GetPaychWaitReady correctly waits for the
|
||||
// channel to be created or funds to be added
|
||||
func TestPaychGetWait(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
|
||||
|
||||
from := tutils.NewIDAddr(t, 101)
|
||||
to := tutils.NewIDAddr(t, 102)
|
||||
|
||||
sm := newMockStateManager()
|
||||
pchapi := newMockPaychAPI()
|
||||
defer pchapi.close()
|
||||
|
||||
mgr, err := newManager(sm, store, pchapi)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 1. Get
|
||||
amt := big.NewInt(10)
|
||||
_, mcid, err := mgr.GetPaych(ctx, from, to, amt)
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan address.Address)
|
||||
go func() {
|
||||
// 2. Wait till ready
|
||||
ch, err := mgr.GetPaychWaitReady(ctx, mcid)
|
||||
require.NoError(t, err)
|
||||
|
||||
done <- ch
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// 3. Send response
|
||||
expch := tutils.NewIDAddr(t, 100)
|
||||
response := testChannelResponse(t, expch)
|
||||
pchapi.finishWaitingCalls(response)
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
ch := <-done
|
||||
require.Equal(t, expch, ch)
|
||||
|
||||
// 4. Wait again - message has already been received so should
|
||||
// return immediately
|
||||
ch, err = mgr.GetPaychWaitReady(ctx, mcid)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expch, ch)
|
||||
|
||||
// Request add funds
|
||||
amt2 := big.NewInt(15)
|
||||
_, addFundsMsgCid, err := mgr.GetPaych(ctx, from, to, amt2)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
go func() {
|
||||
// 5. Wait for add funds
|
||||
ch, err := mgr.GetPaychWaitReady(ctx, addFundsMsgCid)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expch, ch)
|
||||
|
||||
done <- ch
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// 6. Send add funds response
|
||||
addFundsResponse := types.MessageReceipt{
|
||||
ExitCode: 0,
|
||||
Return: []byte{},
|
||||
}
|
||||
pchapi.finishWaitingCalls(addFundsResponse)
|
||||
|
||||
<-done
|
||||
}
|
||||
|
||||
// TestPaychGetWaitErr tests that GetPaychWaitReady correctly handles errors
|
||||
func TestPaychGetWaitErr(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
|
||||
|
||||
from := tutils.NewIDAddr(t, 101)
|
||||
to := tutils.NewIDAddr(t, 102)
|
||||
|
||||
sm := newMockStateManager()
|
||||
pchapi := newMockPaychAPI()
|
||||
defer pchapi.close()
|
||||
|
||||
mgr, err := newManager(sm, store, pchapi)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 1. Create channel
|
||||
amt := big.NewInt(10)
|
||||
_, mcid, err := mgr.GetPaych(ctx, from, to, amt)
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan address.Address)
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
// 2. Wait for channel to be created
|
||||
_, err := mgr.GetPaychWaitReady(ctx, mcid)
|
||||
|
||||
// 4. Channel creation should have failed
|
||||
require.NotNil(t, err)
|
||||
|
||||
// 5. Call wait again with the same message CID
|
||||
_, err = mgr.GetPaychWaitReady(ctx, mcid)
|
||||
|
||||
// 6. Should return immediately with the same error
|
||||
require.NotNil(t, err)
|
||||
}()
|
||||
|
||||
// Give the wait a moment to start before sending response
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// 3. Send error response to create channel
|
||||
response := types.MessageReceipt{
|
||||
ExitCode: 1, // error
|
||||
Return: []byte{},
|
||||
}
|
||||
pchapi.finishWaitingCalls(response)
|
||||
|
||||
<-done
|
||||
}
|
||||
|
||||
// TestPaychGetWaitCtx tests that GetPaychWaitReady returns early if the context
|
||||
// is cancelled
|
||||
func TestPaychGetWaitCtx(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
|
||||
|
||||
from := tutils.NewIDAddr(t, 101)
|
||||
to := tutils.NewIDAddr(t, 102)
|
||||
|
||||
sm := newMockStateManager()
|
||||
pchapi := newMockPaychAPI()
|
||||
defer pchapi.close()
|
||||
|
||||
mgr, err := newManager(sm, store, pchapi)
|
||||
require.NoError(t, err)
|
||||
|
||||
amt := big.NewInt(10)
|
||||
_, mcid, err := mgr.GetPaych(ctx, from, to, amt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// When the context is cancelled, should unblock wait
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
_, err = mgr.GetPaychWaitReady(ctx, mcid)
|
||||
require.Error(t, ctx.Err(), err)
|
||||
}
|
72
paychmgr/settle_test.go
Normal file
72
paychmgr/settle_test.go
Normal file
@ -0,0 +1,72 @@
|
||||
package paychmgr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ipfs/go-cid"
|
||||
|
||||
"github.com/filecoin-project/specs-actors/actors/abi/big"
|
||||
tutils "github.com/filecoin-project/specs-actors/support/testing"
|
||||
ds "github.com/ipfs/go-datastore"
|
||||
ds_sync "github.com/ipfs/go-datastore/sync"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPaychSettle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
|
||||
|
||||
expch := tutils.NewIDAddr(t, 100)
|
||||
expch2 := tutils.NewIDAddr(t, 101)
|
||||
from := tutils.NewIDAddr(t, 101)
|
||||
to := tutils.NewIDAddr(t, 102)
|
||||
|
||||
sm := newMockStateManager()
|
||||
pchapi := newMockPaychAPI()
|
||||
defer pchapi.close()
|
||||
|
||||
mgr, err := newManager(sm, store, pchapi)
|
||||
require.NoError(t, err)
|
||||
|
||||
amt := big.NewInt(10)
|
||||
_, mcid, err := mgr.GetPaych(ctx, from, to, amt)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Send channel create response
|
||||
response := testChannelResponse(t, expch)
|
||||
pchapi.finishWaitingCalls(response)
|
||||
|
||||
// Get the channel address
|
||||
ch, err := mgr.GetPaychWaitReady(ctx, mcid)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expch, ch)
|
||||
|
||||
// Settle the channel
|
||||
_, err = mgr.Settle(ctx, ch)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send another request for funds to the same from/to
|
||||
// (should create a new channel because the previous channel
|
||||
// is settling)
|
||||
amt2 := big.NewInt(5)
|
||||
_, mcid2, err := mgr.GetPaych(ctx, from, to, amt2)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, cid.Undef, mcid2)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Send new channel create response
|
||||
response2 := testChannelResponse(t, expch2)
|
||||
pchapi.finishWaitingCalls(response2)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Make sure the new channel is different from the old channel
|
||||
ch2, err := mgr.GetPaychWaitReady(ctx, mcid2)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, ch, ch2)
|
||||
}
|
@ -3,6 +3,11 @@ package paychmgr
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/filecoin-project/lotus/api"
|
||||
|
||||
"github.com/filecoin-project/specs-actors/actors/abi/big"
|
||||
|
||||
"github.com/filecoin-project/specs-actors/actors/builtin"
|
||||
init_ "github.com/filecoin-project/specs-actors/actors/builtin/init"
|
||||
@ -17,7 +22,211 @@ import (
|
||||
"github.com/filecoin-project/lotus/chain/types"
|
||||
)
|
||||
|
||||
func (pm *Manager) createPaych(ctx context.Context, from, to address.Address, amt types.BigInt) (cid.Cid, error) {
|
||||
type paychApi interface {
|
||||
StateWaitMsg(ctx context.Context, msg cid.Cid, confidence uint64) (*api.MsgLookup, error)
|
||||
MpoolPushMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error)
|
||||
}
|
||||
|
||||
// paychFundsRes is the response to a create channel or add funds request
|
||||
type paychFundsRes struct {
|
||||
channel address.Address
|
||||
mcid cid.Cid
|
||||
err error
|
||||
}
|
||||
|
||||
type onCompleteFn func(*paychFundsRes)
|
||||
|
||||
// fundsReq is a request to create a channel or add funds to a channel
|
||||
type fundsReq struct {
|
||||
ctx context.Context
|
||||
from address.Address
|
||||
to address.Address
|
||||
amt types.BigInt
|
||||
onComplete onCompleteFn
|
||||
}
|
||||
|
||||
// getPaych ensures that a channel exists between the from and to addresses,
|
||||
// and adds the given amount of funds.
|
||||
// If the channel does not exist a create channel message is sent and the
|
||||
// message CID is returned.
|
||||
// If the channel does exist an add funds message is sent and both the channel
|
||||
// address and message CID are returned.
|
||||
// If there is an in progress operation (create channel / add funds), getPaych
|
||||
// blocks until the previous operation completes, then returns both the channel
|
||||
// address and the CID of the new add funds message.
|
||||
// If an operation returns an error, subsequent waiting operations will still
|
||||
// be attempted.
|
||||
func (ca *channelAccessor) getPaych(ctx context.Context, from, to address.Address, amt types.BigInt) (address.Address, cid.Cid, error) {
|
||||
// Add the request to add funds to a queue and wait for the result
|
||||
promise := ca.enqueue(&fundsReq{ctx: ctx, from: from, to: to, amt: amt})
|
||||
select {
|
||||
case res := <-promise:
|
||||
return res.channel, res.mcid, res.err
|
||||
case <-ctx.Done():
|
||||
return address.Undef, cid.Undef, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Queue up an add funds operation
|
||||
func (ca *channelAccessor) enqueue(task *fundsReq) chan *paychFundsRes {
|
||||
promise := make(chan *paychFundsRes)
|
||||
task.onComplete = func(res *paychFundsRes) {
|
||||
select {
|
||||
case <-task.ctx.Done():
|
||||
case promise <- res:
|
||||
}
|
||||
}
|
||||
|
||||
ca.lk.Lock()
|
||||
defer ca.lk.Unlock()
|
||||
|
||||
ca.fundsReqQueue = append(ca.fundsReqQueue, task)
|
||||
go ca.processNextQueueItem()
|
||||
|
||||
return promise
|
||||
}
|
||||
|
||||
// Run the operation at the head of the queue
|
||||
func (ca *channelAccessor) processNextQueueItem() {
|
||||
ca.lk.Lock()
|
||||
defer ca.lk.Unlock()
|
||||
|
||||
if len(ca.fundsReqQueue) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
head := ca.fundsReqQueue[0]
|
||||
res := ca.processTask(head.ctx, head.from, head.to, head.amt, head.onComplete)
|
||||
|
||||
// If the task is waiting on an external event (eg something to appear on
|
||||
// chain) it will return nil
|
||||
if res == nil {
|
||||
// Stop processing the fundsReqQueue and wait. When the event occurs it will
|
||||
// call processNextQueueItem() again
|
||||
return
|
||||
}
|
||||
|
||||
// The task has finished processing so clean it up
|
||||
ca.fundsReqQueue[0] = nil // allow GC of element
|
||||
ca.fundsReqQueue = ca.fundsReqQueue[1:]
|
||||
|
||||
// Call the task callback with its results
|
||||
head.onComplete(res)
|
||||
|
||||
// Process the next task
|
||||
if len(ca.fundsReqQueue) > 0 {
|
||||
go ca.processNextQueueItem()
|
||||
}
|
||||
}
|
||||
|
||||
// msgWaitComplete is called when the message for a previous task is confirmed
|
||||
// or there is an error.
|
||||
func (ca *channelAccessor) msgWaitComplete(mcid cid.Cid, err error, cb onCompleteFn) {
|
||||
ca.lk.Lock()
|
||||
defer ca.lk.Unlock()
|
||||
|
||||
// Save the message result to the store
|
||||
dserr := ca.store.SaveMessageResult(mcid, err)
|
||||
if dserr != nil {
|
||||
log.Errorf("saving message result: %s", dserr)
|
||||
}
|
||||
|
||||
// Call the onComplete callback
|
||||
ca.callOnComplete(mcid, err, cb)
|
||||
|
||||
// Inform listeners that the message has completed
|
||||
ca.msgListeners.fireMsgComplete(mcid, err)
|
||||
|
||||
// The queue may have been waiting for msg completion to proceed, so
|
||||
// process the next queue item
|
||||
if len(ca.fundsReqQueue) > 0 {
|
||||
go ca.processNextQueueItem()
|
||||
}
|
||||
}
|
||||
|
||||
// callOnComplete calls the onComplete callback for a task
|
||||
func (ca *channelAccessor) callOnComplete(mcid cid.Cid, err error, cb onCompleteFn) {
|
||||
if cb == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
go cb(&paychFundsRes{err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// Get the channel address
|
||||
ci, storeErr := ca.store.ByMessageCid(mcid)
|
||||
if storeErr != nil {
|
||||
log.Errorf("getting channel by message cid: %s", err)
|
||||
go cb(&paychFundsRes{err: storeErr})
|
||||
return
|
||||
}
|
||||
|
||||
if ci.Channel == nil {
|
||||
panic("channel address is nil when calling onComplete callback")
|
||||
}
|
||||
|
||||
go cb(&paychFundsRes{channel: *ci.Channel, mcid: mcid, err: err})
|
||||
}
|
||||
|
||||
// processTask checks the state of the channel and takes appropriate action
|
||||
// (see description of getPaych).
|
||||
// Note that processTask may be called repeatedly in the same state, and should
|
||||
// return nil if there is no state change to be made (eg when waiting for a
|
||||
// message to be confirmed on chain)
|
||||
func (ca *channelAccessor) processTask(
|
||||
ctx context.Context,
|
||||
from address.Address,
|
||||
to address.Address,
|
||||
amt types.BigInt,
|
||||
onComplete onCompleteFn,
|
||||
) *paychFundsRes {
|
||||
// Get the payment channel for the from/to addresses.
|
||||
// Note: It's ok if we get ErrChannelNotTracked. It just means we need to
|
||||
// create a channel.
|
||||
channelInfo, err := ca.store.OutboundActiveByFromTo(from, to)
|
||||
if err != nil && err != ErrChannelNotTracked {
|
||||
return &paychFundsRes{err: err}
|
||||
}
|
||||
|
||||
// If a channel has not yet been created, create one.
|
||||
// Note that if the previous attempt to create the channel failed because of a VM error
|
||||
// (eg not enough gas), both channelInfo.Channel and channelInfo.CreateMsg will be nil.
|
||||
if channelInfo == nil || channelInfo.Channel == nil && channelInfo.CreateMsg == nil {
|
||||
mcid, err := ca.createPaych(ctx, from, to, amt, onComplete)
|
||||
if err != nil {
|
||||
return &paychFundsRes{err: err}
|
||||
}
|
||||
|
||||
return &paychFundsRes{mcid: mcid}
|
||||
}
|
||||
|
||||
// If the create channel message has been sent but the channel hasn't
|
||||
// been created on chain yet
|
||||
if channelInfo.CreateMsg != nil {
|
||||
// Wait for the channel to be created before trying again
|
||||
return nil
|
||||
}
|
||||
|
||||
// If an add funds message was sent to the chain but hasn't been confirmed
|
||||
// on chain yet
|
||||
if channelInfo.AddFundsMsg != nil {
|
||||
// Wait for the add funds message to be confirmed before trying again
|
||||
return nil
|
||||
}
|
||||
|
||||
// We need to add more funds, so send an add funds message to
|
||||
// cover the amount for this request
|
||||
mcid, err := ca.addFunds(ctx, from, to, amt, onComplete)
|
||||
if err != nil {
|
||||
return &paychFundsRes{err: err}
|
||||
}
|
||||
return &paychFundsRes{channel: *channelInfo.Channel, mcid: *mcid}
|
||||
}
|
||||
|
||||
// createPaych sends a message to create the channel and returns the message cid
|
||||
func (ca *channelAccessor) createPaych(ctx context.Context, from, to address.Address, amt types.BigInt, cb onCompleteFn) (cid.Cid, error) {
|
||||
params, aerr := actors.SerializeParams(&paych.ConstructorParams{From: from, To: to})
|
||||
if aerr != nil {
|
||||
return cid.Undef, aerr
|
||||
@ -41,106 +250,268 @@ func (pm *Manager) createPaych(ctx context.Context, from, to address.Address, am
|
||||
GasPrice: types.NewInt(0),
|
||||
}
|
||||
|
||||
smsg, err := pm.mpool.MpoolPushMessage(ctx, msg)
|
||||
smsg, err := ca.api.MpoolPushMessage(ctx, msg)
|
||||
if err != nil {
|
||||
return cid.Undef, xerrors.Errorf("initializing paych actor: %w", err)
|
||||
}
|
||||
mcid := smsg.Cid()
|
||||
go pm.waitForPaychCreateMsg(ctx, mcid)
|
||||
|
||||
// Create a new channel in the store
|
||||
if _, err := ca.store.createChannel(from, to, mcid, amt); err != nil {
|
||||
log.Errorf("creating channel: %s", err)
|
||||
return cid.Undef, err
|
||||
}
|
||||
|
||||
// Wait for the channel to be created on chain
|
||||
go ca.waitForPaychCreateMsg(from, to, mcid, cb)
|
||||
|
||||
return mcid, nil
|
||||
}
|
||||
|
||||
// WaitForPaychCreateMsg waits for mcid to appear on chain and returns the robust address of the
|
||||
// waitForPaychCreateMsg waits for mcid to appear on chain and stores the robust address of the
|
||||
// created payment channel
|
||||
// TODO: wait outside the store lock!
|
||||
// (tricky because we need to setup channel tracking before we know its address)
|
||||
func (pm *Manager) waitForPaychCreateMsg(ctx context.Context, mcid cid.Cid) {
|
||||
defer pm.store.lk.Unlock()
|
||||
mwait, err := pm.state.StateWaitMsg(ctx, mcid, build.MessageConfidence)
|
||||
func (ca *channelAccessor) waitForPaychCreateMsg(from address.Address, to address.Address, mcid cid.Cid, cb onCompleteFn) {
|
||||
err := ca.waitPaychCreateMsg(from, to, mcid)
|
||||
ca.msgWaitComplete(mcid, err, cb)
|
||||
}
|
||||
|
||||
func (ca *channelAccessor) waitPaychCreateMsg(from address.Address, to address.Address, mcid cid.Cid) error {
|
||||
mwait, err := ca.api.StateWaitMsg(ca.waitCtx, mcid, build.MessageConfidence)
|
||||
if err != nil {
|
||||
log.Errorf("wait msg: %w", err)
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
if mwait.Receipt.ExitCode != 0 {
|
||||
log.Errorf("payment channel creation failed (exit code %d)", mwait.Receipt.ExitCode)
|
||||
return
|
||||
err := xerrors.Errorf("payment channel creation failed (exit code %d)", mwait.Receipt.ExitCode)
|
||||
log.Error(err)
|
||||
|
||||
ca.lk.Lock()
|
||||
defer ca.lk.Unlock()
|
||||
|
||||
ca.mutateChannelInfo(from, to, func(channelInfo *ChannelInfo) {
|
||||
channelInfo.PendingAmount = big.NewInt(0)
|
||||
channelInfo.CreateMsg = nil
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
var decodedReturn init_.ExecReturn
|
||||
err = decodedReturn.UnmarshalCBOR(bytes.NewReader(mwait.Receipt.Return))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
paychaddr := decodedReturn.RobustAddress
|
||||
|
||||
ci, err := pm.loadStateChannelInfo(ctx, paychaddr, DirOutbound)
|
||||
if err != nil {
|
||||
log.Errorf("loading channel info: %w", err)
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
if err := pm.store.trackChannel(ci); err != nil {
|
||||
log.Errorf("tracking channel: %w", err)
|
||||
}
|
||||
ca.lk.Lock()
|
||||
defer ca.lk.Unlock()
|
||||
|
||||
// Store robust address of channel
|
||||
ca.mutateChannelInfo(from, to, func(channelInfo *ChannelInfo) {
|
||||
channelInfo.Channel = &decodedReturn.RobustAddress
|
||||
channelInfo.Amount = channelInfo.PendingAmount
|
||||
channelInfo.PendingAmount = big.NewInt(0)
|
||||
channelInfo.CreateMsg = nil
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *Manager) addFunds(ctx context.Context, ch address.Address, from address.Address, amt types.BigInt) (cid.Cid, error) {
|
||||
// addFunds sends a message to add funds to the channel and returns the message cid
|
||||
func (ca *channelAccessor) addFunds(ctx context.Context, from address.Address, to address.Address, amt types.BigInt, cb onCompleteFn) (*cid.Cid, error) {
|
||||
channelInfo, err := ca.store.OutboundActiveByFromTo(from, to)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg := &types.Message{
|
||||
To: ch,
|
||||
From: from,
|
||||
To: *channelInfo.Channel,
|
||||
From: channelInfo.Control,
|
||||
Value: amt,
|
||||
Method: 0,
|
||||
GasLimit: 0,
|
||||
GasPrice: types.NewInt(0),
|
||||
}
|
||||
|
||||
smsg, err := pm.mpool.MpoolPushMessage(ctx, msg)
|
||||
smsg, err := ca.api.MpoolPushMessage(ctx, msg)
|
||||
if err != nil {
|
||||
return cid.Undef, err
|
||||
return nil, err
|
||||
}
|
||||
mcid := smsg.Cid()
|
||||
go pm.waitForAddFundsMsg(ctx, mcid)
|
||||
return mcid, nil
|
||||
|
||||
// Store the add funds message CID on the channel
|
||||
ca.mutateChannelInfo(from, to, func(ci *ChannelInfo) {
|
||||
ci.PendingAmount = amt
|
||||
ci.AddFundsMsg = &mcid
|
||||
})
|
||||
|
||||
// Store a reference from the message CID to the channel, so that we can
|
||||
// look up the channel from the message CID
|
||||
err = ca.store.SaveNewMessage(channelInfo.ChannelID, mcid)
|
||||
if err != nil {
|
||||
log.Errorf("saving add funds message CID %s: %s", mcid, err)
|
||||
}
|
||||
|
||||
go ca.waitForAddFundsMsg(from, to, mcid, cb)
|
||||
|
||||
return &mcid, nil
|
||||
}
|
||||
|
||||
// WaitForAddFundsMsg waits for mcid to appear on chain and returns error, if any
|
||||
// TODO: wait outside the store lock!
|
||||
// (tricky because we need to setup channel tracking before we know it's address)
|
||||
func (pm *Manager) waitForAddFundsMsg(ctx context.Context, mcid cid.Cid) {
|
||||
defer pm.store.lk.Unlock()
|
||||
mwait, err := pm.state.StateWaitMsg(ctx, mcid, build.MessageConfidence)
|
||||
// waitForAddFundsMsg waits for mcid to appear on chain and returns error, if any
|
||||
func (ca *channelAccessor) waitForAddFundsMsg(from address.Address, to address.Address, mcid cid.Cid, cb onCompleteFn) {
|
||||
err := ca.waitAddFundsMsg(from, to, mcid)
|
||||
ca.msgWaitComplete(mcid, err, cb)
|
||||
}
|
||||
|
||||
func (ca *channelAccessor) waitAddFundsMsg(from address.Address, to address.Address, mcid cid.Cid) error {
|
||||
mwait, err := ca.api.StateWaitMsg(ca.waitCtx, mcid, build.MessageConfidence)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
if mwait.Receipt.ExitCode != 0 {
|
||||
log.Errorf("voucher channel creation failed: adding funds (exit code %d)", mwait.Receipt.ExitCode)
|
||||
err := xerrors.Errorf("voucher channel creation failed: adding funds (exit code %d)", mwait.Receipt.ExitCode)
|
||||
log.Error(err)
|
||||
|
||||
ca.lk.Lock()
|
||||
defer ca.lk.Unlock()
|
||||
|
||||
ca.mutateChannelInfo(from, to, func(channelInfo *ChannelInfo) {
|
||||
channelInfo.PendingAmount = big.NewInt(0)
|
||||
channelInfo.AddFundsMsg = nil
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
ca.lk.Lock()
|
||||
defer ca.lk.Unlock()
|
||||
|
||||
// Store updated amount
|
||||
ca.mutateChannelInfo(from, to, func(channelInfo *ChannelInfo) {
|
||||
channelInfo.Amount = types.BigAdd(channelInfo.Amount, channelInfo.PendingAmount)
|
||||
channelInfo.PendingAmount = big.NewInt(0)
|
||||
channelInfo.AddFundsMsg = nil
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Change the state of the channel in the store
|
||||
func (ca *channelAccessor) mutateChannelInfo(from address.Address, to address.Address, mutate func(*ChannelInfo)) {
|
||||
channelInfo, err := ca.store.OutboundActiveByFromTo(from, to)
|
||||
|
||||
// If there's an error reading or writing to the store just log an error.
|
||||
// For now we're assuming it's unlikely to happen in practice.
|
||||
// Later we may want to implement a transactional approach, whereby
|
||||
// we record to the store that we're going to send a message, send
|
||||
// the message, and then record that the message was sent.
|
||||
if err != nil {
|
||||
log.Errorf("Error reading channel info from store: %s", err)
|
||||
}
|
||||
|
||||
mutate(channelInfo)
|
||||
|
||||
err = ca.store.putChannelInfo(channelInfo)
|
||||
if err != nil {
|
||||
log.Errorf("Error writing channel info to store: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *Manager) GetPaych(ctx context.Context, from, to address.Address, ensureFree types.BigInt) (address.Address, cid.Cid, error) {
|
||||
pm.store.lk.Lock() // unlock only on err; wait funcs will defer unlock
|
||||
var mcid cid.Cid
|
||||
ch, err := pm.store.findChan(func(ci *ChannelInfo) bool {
|
||||
if ci.Direction != DirOutbound {
|
||||
return false
|
||||
// getPaychWaitReady waits for a the response to the message with the given cid
|
||||
func (ca *channelAccessor) getPaychWaitReady(ctx context.Context, mcid cid.Cid) (address.Address, error) {
|
||||
ca.lk.Lock()
|
||||
|
||||
// First check if the message has completed
|
||||
msgInfo, err := ca.store.GetMessage(mcid)
|
||||
if err != nil {
|
||||
ca.lk.Unlock()
|
||||
|
||||
return address.Undef, err
|
||||
}
|
||||
|
||||
// If the create channel / add funds message failed, return an error
|
||||
if len(msgInfo.Err) > 0 {
|
||||
ca.lk.Unlock()
|
||||
|
||||
return address.Undef, xerrors.New(msgInfo.Err)
|
||||
}
|
||||
|
||||
// If the message has completed successfully
|
||||
if msgInfo.Received {
|
||||
ca.lk.Unlock()
|
||||
|
||||
// Get the channel address
|
||||
ci, err := ca.store.ByMessageCid(mcid)
|
||||
if err != nil {
|
||||
return address.Undef, err
|
||||
}
|
||||
return ci.Control == from && ci.Target == to
|
||||
})
|
||||
if err != nil {
|
||||
pm.store.lk.Unlock()
|
||||
return address.Undef, cid.Undef, xerrors.Errorf("findChan: %w", err)
|
||||
|
||||
if ci.Channel == nil {
|
||||
panic(fmt.Sprintf("create / add funds message %s succeeded but channelInfo.Channel is nil", mcid))
|
||||
}
|
||||
return *ci.Channel, nil
|
||||
}
|
||||
if ch != address.Undef {
|
||||
// TODO: Track available funds
|
||||
mcid, err = pm.addFunds(ctx, ch, from, ensureFree)
|
||||
} else {
|
||||
mcid, err = pm.createPaych(ctx, from, to, ensureFree)
|
||||
|
||||
// The message hasn't completed yet so wait for it to complete
|
||||
promise := ca.msgPromise(ctx, mcid)
|
||||
|
||||
// Unlock while waiting
|
||||
ca.lk.Unlock()
|
||||
|
||||
select {
|
||||
case res := <-promise:
|
||||
return res.channel, res.err
|
||||
case <-ctx.Done():
|
||||
return address.Undef, ctx.Err()
|
||||
}
|
||||
if err != nil {
|
||||
pm.store.lk.Unlock()
|
||||
}
|
||||
return ch, mcid, err
|
||||
}
|
||||
|
||||
type onMsgRes struct {
|
||||
channel address.Address
|
||||
err error
|
||||
}
|
||||
|
||||
// msgPromise returns a channel that receives the result of the message with
|
||||
// the given CID
|
||||
func (ca *channelAccessor) msgPromise(ctx context.Context, mcid cid.Cid) chan onMsgRes {
|
||||
promise := make(chan onMsgRes)
|
||||
triggerUnsub := make(chan struct{})
|
||||
sub := ca.msgListeners.onMsg(mcid, func(err error) {
|
||||
close(triggerUnsub)
|
||||
|
||||
// Use a go-routine so as not to block the event handler loop
|
||||
go func() {
|
||||
res := onMsgRes{err: err}
|
||||
if res.err == nil {
|
||||
// Get the channel associated with the message cid
|
||||
ci, err := ca.store.ByMessageCid(mcid)
|
||||
if err != nil {
|
||||
res.err = err
|
||||
} else {
|
||||
res.channel = *ci.Channel
|
||||
}
|
||||
}
|
||||
|
||||
// Pass the result to the caller
|
||||
select {
|
||||
case promise <- res:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
// Unsubscribe when the message is received or the context is done
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-triggerUnsub:
|
||||
}
|
||||
|
||||
ca.msgListeners.unsubscribe(sub)
|
||||
}()
|
||||
|
||||
return promise
|
||||
}
|
||||
|
@ -3,20 +3,21 @@ package paychmgr
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/filecoin-project/specs-actors/actors/abi/big"
|
||||
|
||||
"github.com/filecoin-project/specs-actors/actors/builtin/account"
|
||||
|
||||
"github.com/filecoin-project/go-address"
|
||||
"github.com/filecoin-project/specs-actors/actors/builtin/paych"
|
||||
xerrors "golang.org/x/xerrors"
|
||||
|
||||
"github.com/filecoin-project/lotus/chain/types"
|
||||
)
|
||||
|
||||
func (pm *Manager) loadPaychState(ctx context.Context, ch address.Address) (*types.Actor, *paych.State, error) {
|
||||
type stateAccessor struct {
|
||||
sm StateManagerApi
|
||||
}
|
||||
|
||||
func (ca *stateAccessor) loadPaychState(ctx context.Context, ch address.Address) (*types.Actor, *paych.State, error) {
|
||||
var pcast paych.State
|
||||
act, err := pm.sm.LoadActorState(ctx, ch, &pcast, nil)
|
||||
act, err := ca.sm.LoadActorState(ctx, ch, &pcast, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@ -24,26 +25,26 @@ func (pm *Manager) loadPaychState(ctx context.Context, ch address.Address) (*typ
|
||||
return act, &pcast, nil
|
||||
}
|
||||
|
||||
func (pm *Manager) loadStateChannelInfo(ctx context.Context, ch address.Address, dir uint64) (*ChannelInfo, error) {
|
||||
_, st, err := pm.loadPaychState(ctx, ch)
|
||||
func (ca *stateAccessor) loadStateChannelInfo(ctx context.Context, ch address.Address, dir uint64) (*ChannelInfo, error) {
|
||||
_, st, err := ca.loadPaychState(ctx, ch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var account account.State
|
||||
_, err = pm.sm.LoadActorState(ctx, st.From, &account, nil)
|
||||
_, err = ca.sm.LoadActorState(ctx, st.From, &account, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
from := account.Address
|
||||
_, err = pm.sm.LoadActorState(ctx, st.To, &account, nil)
|
||||
_, err = ca.sm.LoadActorState(ctx, st.To, &account, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
to := account.Address
|
||||
|
||||
ci := &ChannelInfo{
|
||||
Channel: ch,
|
||||
Channel: &ch,
|
||||
Direction: dir,
|
||||
NextLane: nextLaneFromState(st),
|
||||
}
|
||||
@ -72,80 +73,3 @@ func nextLaneFromState(st *paych.State) uint64 {
|
||||
}
|
||||
return maxLane + 1
|
||||
}
|
||||
|
||||
// laneState gets the LaneStates from chain, then applies all vouchers in
|
||||
// the data store over the chain state
|
||||
func (pm *Manager) laneState(state *paych.State, ch address.Address) (map[uint64]*paych.LaneState, error) {
|
||||
// TODO: we probably want to call UpdateChannelState with all vouchers to be fully correct
|
||||
// (but technically dont't need to)
|
||||
laneStates := make(map[uint64]*paych.LaneState, len(state.LaneStates))
|
||||
|
||||
// Get the lane state from the chain
|
||||
for _, laneState := range state.LaneStates {
|
||||
laneStates[laneState.ID] = laneState
|
||||
}
|
||||
|
||||
// Apply locally stored vouchers
|
||||
vouchers, err := pm.store.VouchersForPaych(ch)
|
||||
if err != nil && err != ErrChannelNotTracked {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, v := range vouchers {
|
||||
for range v.Voucher.Merges {
|
||||
return nil, xerrors.Errorf("paych merges not handled yet")
|
||||
}
|
||||
|
||||
// If there's a voucher for a lane that isn't in chain state just
|
||||
// create it
|
||||
ls, ok := laneStates[v.Voucher.Lane]
|
||||
if !ok {
|
||||
ls = &paych.LaneState{
|
||||
ID: v.Voucher.Lane,
|
||||
Redeemed: types.NewInt(0),
|
||||
Nonce: 0,
|
||||
}
|
||||
laneStates[v.Voucher.Lane] = ls
|
||||
}
|
||||
|
||||
if v.Voucher.Nonce < ls.Nonce {
|
||||
continue
|
||||
}
|
||||
|
||||
ls.Nonce = v.Voucher.Nonce
|
||||
ls.Redeemed = v.Voucher.Amount
|
||||
}
|
||||
|
||||
return laneStates, nil
|
||||
}
|
||||
|
||||
// Get the total redeemed amount across all lanes, after applying the voucher
|
||||
func (pm *Manager) totalRedeemedWithVoucher(laneStates map[uint64]*paych.LaneState, sv *paych.SignedVoucher) (big.Int, error) {
|
||||
// TODO: merges
|
||||
if len(sv.Merges) != 0 {
|
||||
return big.Int{}, xerrors.Errorf("dont currently support paych lane merges")
|
||||
}
|
||||
|
||||
total := big.NewInt(0)
|
||||
for _, ls := range laneStates {
|
||||
total = big.Add(total, ls.Redeemed)
|
||||
}
|
||||
|
||||
lane, ok := laneStates[sv.Lane]
|
||||
if ok {
|
||||
// If the voucher is for an existing lane, and the voucher nonce
|
||||
// and is higher than the lane nonce
|
||||
if sv.Nonce > lane.Nonce {
|
||||
// Add the delta between the redeemed amount and the voucher
|
||||
// amount to the total
|
||||
delta := big.Sub(sv.Amount, lane.Redeemed)
|
||||
total = big.Add(total, delta)
|
||||
}
|
||||
} else {
|
||||
// If the voucher is *not* for an existing lane, just add its
|
||||
// value (implicitly a new lane will be created for the voucher)
|
||||
total = big.Add(total, sv.Amount)
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
@ -4,14 +4,16 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/filecoin-project/lotus/chain/types"
|
||||
|
||||
"github.com/filecoin-project/specs-actors/actors/builtin/paych"
|
||||
"github.com/ipfs/go-cid"
|
||||
"github.com/ipfs/go-datastore"
|
||||
"github.com/ipfs/go-datastore/namespace"
|
||||
dsq "github.com/ipfs/go-datastore/query"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/filecoin-project/go-address"
|
||||
cborrpc "github.com/filecoin-project/go-cbor-util"
|
||||
@ -22,8 +24,6 @@ import (
|
||||
var ErrChannelNotTracked = errors.New("channel not tracked")
|
||||
|
||||
type Store struct {
|
||||
lk sync.Mutex // TODO: this can be split per paych
|
||||
|
||||
ds datastore.Batching
|
||||
}
|
||||
|
||||
@ -39,85 +39,107 @@ const (
|
||||
DirOutbound = 2
|
||||
)
|
||||
|
||||
const (
|
||||
dsKeyChannelInfo = "ChannelInfo"
|
||||
dsKeyMsgCid = "MsgCid"
|
||||
)
|
||||
|
||||
type VoucherInfo struct {
|
||||
Voucher *paych.SignedVoucher
|
||||
Proof []byte
|
||||
}
|
||||
|
||||
// ChannelInfo keeps track of information about a channel
|
||||
type ChannelInfo struct {
|
||||
Channel address.Address
|
||||
// ChannelID is a uuid set at channel creation
|
||||
ChannelID string
|
||||
// Channel address - may be nil if the channel hasn't been created yet
|
||||
Channel *address.Address
|
||||
// Control is the address of the account that created the channel
|
||||
Control address.Address
|
||||
Target address.Address
|
||||
|
||||
// Target is the address of the account on the other end of the channel
|
||||
Target address.Address
|
||||
// Direction indicates if the channel is inbound (this node is the Target)
|
||||
// or outbound (this node is the Control)
|
||||
Direction uint64
|
||||
Vouchers []*VoucherInfo
|
||||
NextLane uint64
|
||||
// Vouchers is a list of all vouchers sent on the channel
|
||||
Vouchers []*VoucherInfo
|
||||
// NextLane is the number of the next lane that should be used when the
|
||||
// client requests a new lane (eg to create a voucher for a new deal)
|
||||
NextLane uint64
|
||||
// Amount added to the channel.
|
||||
// Note: This amount is only used by GetPaych to keep track of how much
|
||||
// has locally been added to the channel. It should reflect the channel's
|
||||
// Balance on chain as long as all operations occur on the same datastore.
|
||||
Amount types.BigInt
|
||||
// PendingAmount is the amount that we're awaiting confirmation of
|
||||
PendingAmount types.BigInt
|
||||
// CreateMsg is the CID of a pending create message (while waiting for confirmation)
|
||||
CreateMsg *cid.Cid
|
||||
// AddFundsMsg is the CID of a pending add funds message (while waiting for confirmation)
|
||||
AddFundsMsg *cid.Cid
|
||||
// Settling indicates whether the channel has entered into the settling state
|
||||
Settling bool
|
||||
}
|
||||
|
||||
func dskeyForChannel(addr address.Address) datastore.Key {
|
||||
return datastore.NewKey(addr.String())
|
||||
}
|
||||
|
||||
func (ps *Store) putChannelInfo(ci *ChannelInfo) error {
|
||||
k := dskeyForChannel(ci.Channel)
|
||||
|
||||
b, err := cborrpc.Dump(ci)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ps.ds.Put(k, b)
|
||||
}
|
||||
|
||||
func (ps *Store) getChannelInfo(addr address.Address) (*ChannelInfo, error) {
|
||||
k := dskeyForChannel(addr)
|
||||
|
||||
b, err := ps.ds.Get(k)
|
||||
if err == datastore.ErrNotFound {
|
||||
return nil, ErrChannelNotTracked
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ci ChannelInfo
|
||||
if err := ci.UnmarshalCBOR(bytes.NewReader(b)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ci, nil
|
||||
}
|
||||
|
||||
func (ps *Store) TrackChannel(ch *ChannelInfo) error {
|
||||
ps.lk.Lock()
|
||||
defer ps.lk.Unlock()
|
||||
|
||||
return ps.trackChannel(ch)
|
||||
}
|
||||
|
||||
func (ps *Store) trackChannel(ch *ChannelInfo) error {
|
||||
_, err := ps.getChannelInfo(ch.Channel)
|
||||
// TrackChannel stores a channel, returning an error if the channel was already
|
||||
// being tracked
|
||||
func (ps *Store) TrackChannel(ci *ChannelInfo) error {
|
||||
_, err := ps.ByAddress(*ci.Channel)
|
||||
switch err {
|
||||
default:
|
||||
return err
|
||||
case nil:
|
||||
return fmt.Errorf("already tracking channel: %s", ch.Channel)
|
||||
return fmt.Errorf("already tracking channel: %s", ci.Channel)
|
||||
case ErrChannelNotTracked:
|
||||
return ps.putChannelInfo(ch)
|
||||
return ps.putChannelInfo(ci)
|
||||
}
|
||||
}
|
||||
|
||||
// ListChannels returns the addresses of all channels that have been created
|
||||
func (ps *Store) ListChannels() ([]address.Address, error) {
|
||||
ps.lk.Lock()
|
||||
defer ps.lk.Unlock()
|
||||
cis, err := ps.findChans(func(ci *ChannelInfo) bool {
|
||||
return ci.Channel != nil
|
||||
}, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := ps.ds.Query(dsq.Query{KeysOnly: true})
|
||||
addrs := make([]address.Address, 0, len(cis))
|
||||
for _, ci := range cis {
|
||||
addrs = append(addrs, *ci.Channel)
|
||||
}
|
||||
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
// findChan finds a single channel using the given filter.
|
||||
// If there isn't a channel that matches the filter, returns ErrChannelNotTracked
|
||||
func (ps *Store) findChan(filter func(ci *ChannelInfo) bool) (*ChannelInfo, error) {
|
||||
cis, err := ps.findChans(filter, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(cis) == 0 {
|
||||
return nil, ErrChannelNotTracked
|
||||
}
|
||||
|
||||
return &cis[0], err
|
||||
}
|
||||
|
||||
// findChans loops over all channels, only including those that pass the filter.
|
||||
// max is the maximum number of channels to return. Set to zero to return unlimited channels.
|
||||
func (ps *Store) findChans(filter func(*ChannelInfo) bool, max int) ([]ChannelInfo, error) {
|
||||
res, err := ps.ds.Query(dsq.Query{Prefix: dsKeyChannelInfo})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Close() //nolint:errcheck
|
||||
|
||||
var out []address.Address
|
||||
var stored ChannelInfo
|
||||
var matches []ChannelInfo
|
||||
|
||||
for {
|
||||
res, ok := res.NextSync()
|
||||
if !ok {
|
||||
@ -128,60 +150,31 @@ func (ps *Store) ListChannels() ([]address.Address, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addr, err := address.NewFromString(strings.TrimPrefix(res.Key, "/"))
|
||||
ci, err := unmarshallChannelInfo(&stored, res)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed reading paych key (%q) from datastore: %w", res.Key, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out = append(out, addr)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (ps *Store) findChan(filter func(*ChannelInfo) bool) (address.Address, error) {
|
||||
res, err := ps.ds.Query(dsq.Query{})
|
||||
if err != nil {
|
||||
return address.Undef, err
|
||||
}
|
||||
defer res.Close() //nolint:errcheck
|
||||
|
||||
var ci ChannelInfo
|
||||
|
||||
for {
|
||||
res, ok := res.NextSync()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
if res.Error != nil {
|
||||
return address.Undef, err
|
||||
}
|
||||
|
||||
if err := ci.UnmarshalCBOR(bytes.NewReader(res.Value)); err != nil {
|
||||
return address.Undef, err
|
||||
}
|
||||
|
||||
if !filter(&ci) {
|
||||
if !filter(ci) {
|
||||
continue
|
||||
}
|
||||
|
||||
addr, err := address.NewFromString(strings.TrimPrefix(res.Key, "/"))
|
||||
if err != nil {
|
||||
return address.Undef, xerrors.Errorf("failed reading paych key (%q) from datastore: %w", res.Key, err)
|
||||
}
|
||||
matches = append(matches, *ci)
|
||||
|
||||
return addr, nil
|
||||
// If we've reached the maximum number of matches, return.
|
||||
// Note that if max is zero we return an unlimited number of matches
|
||||
// because len(matches) will always be at least 1.
|
||||
if len(matches) == max {
|
||||
return matches, nil
|
||||
}
|
||||
}
|
||||
|
||||
return address.Undef, nil
|
||||
return matches, nil
|
||||
}
|
||||
|
||||
// AllocateLane allocates a new lane for the given channel
|
||||
func (ps *Store) AllocateLane(ch address.Address) (uint64, error) {
|
||||
ps.lk.Lock()
|
||||
defer ps.lk.Unlock()
|
||||
|
||||
ci, err := ps.getChannelInfo(ch)
|
||||
ci, err := ps.ByAddress(ch)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -192,11 +185,210 @@ func (ps *Store) AllocateLane(ch address.Address) (uint64, error) {
|
||||
return out, ps.putChannelInfo(ci)
|
||||
}
|
||||
|
||||
// VouchersForPaych gets the vouchers for the given channel
|
||||
func (ps *Store) VouchersForPaych(ch address.Address) ([]*VoucherInfo, error) {
|
||||
ci, err := ps.getChannelInfo(ch)
|
||||
ci, err := ps.ByAddress(ch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ci.Vouchers, nil
|
||||
}
|
||||
|
||||
// ByAddress gets the channel that matches the given address
|
||||
func (ps *Store) ByAddress(addr address.Address) (*ChannelInfo, error) {
|
||||
return ps.findChan(func(ci *ChannelInfo) bool {
|
||||
return ci.Channel != nil && *ci.Channel == addr
|
||||
})
|
||||
}
|
||||
|
||||
// MsgInfo stores information about a create channel / add funds message
|
||||
// that has been sent
|
||||
type MsgInfo struct {
|
||||
// ChannelID links the message to a channel
|
||||
ChannelID string
|
||||
// MsgCid is the CID of the message
|
||||
MsgCid cid.Cid
|
||||
// Received indicates whether a response has been received
|
||||
Received bool
|
||||
// Err is the error received in the response
|
||||
Err string
|
||||
}
|
||||
|
||||
// The datastore key used to identify the message
|
||||
func dskeyForMsg(mcid cid.Cid) datastore.Key {
|
||||
return datastore.KeyWithNamespaces([]string{dsKeyMsgCid, mcid.String()})
|
||||
}
|
||||
|
||||
// SaveNewMessage is called when a message is sent
|
||||
func (ps *Store) SaveNewMessage(channelID string, mcid cid.Cid) error {
|
||||
k := dskeyForMsg(mcid)
|
||||
|
||||
b, err := cborrpc.Dump(&MsgInfo{ChannelID: channelID, MsgCid: mcid})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ps.ds.Put(k, b)
|
||||
}
|
||||
|
||||
// SaveMessageResult is called when the result of a message is received
|
||||
func (ps *Store) SaveMessageResult(mcid cid.Cid, msgErr error) error {
|
||||
minfo, err := ps.GetMessage(mcid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
k := dskeyForMsg(mcid)
|
||||
minfo.Received = true
|
||||
if msgErr != nil {
|
||||
minfo.Err = msgErr.Error()
|
||||
}
|
||||
|
||||
b, err := cborrpc.Dump(minfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ps.ds.Put(k, b)
|
||||
}
|
||||
|
||||
// ByMessageCid gets the channel associated with a message
|
||||
func (ps *Store) ByMessageCid(mcid cid.Cid) (*ChannelInfo, error) {
|
||||
minfo, err := ps.GetMessage(mcid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ci, err := ps.findChan(func(ci *ChannelInfo) bool {
|
||||
return ci.ChannelID == minfo.ChannelID
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ci, err
|
||||
}
|
||||
|
||||
// GetMessage gets the message info for a given message CID
|
||||
func (ps *Store) GetMessage(mcid cid.Cid) (*MsgInfo, error) {
|
||||
k := dskeyForMsg(mcid)
|
||||
|
||||
val, err := ps.ds.Get(k)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var minfo MsgInfo
|
||||
if err := minfo.UnmarshalCBOR(bytes.NewReader(val)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &minfo, nil
|
||||
}
|
||||
|
||||
// OutboundActiveByFromTo looks for outbound channels that have not been
|
||||
// settled, with the given from / to addresses
|
||||
func (ps *Store) OutboundActiveByFromTo(from address.Address, to address.Address) (*ChannelInfo, error) {
|
||||
return ps.findChan(func(ci *ChannelInfo) bool {
|
||||
if ci.Direction != DirOutbound {
|
||||
return false
|
||||
}
|
||||
if ci.Settling {
|
||||
return false
|
||||
}
|
||||
return ci.Control == from && ci.Target == to
|
||||
})
|
||||
}
|
||||
|
||||
// WithPendingAddFunds is used on startup to find channels for which a
|
||||
// create channel or add funds message has been sent, but lotus shut down
|
||||
// before the response was received.
|
||||
func (ps *Store) WithPendingAddFunds() ([]ChannelInfo, error) {
|
||||
return ps.findChans(func(ci *ChannelInfo) bool {
|
||||
if ci.Direction != DirOutbound {
|
||||
return false
|
||||
}
|
||||
return ci.CreateMsg != nil || ci.AddFundsMsg != nil
|
||||
}, 0)
|
||||
}
|
||||
|
||||
// createChannel creates an outbound channel for the given from / to, ensuring
|
||||
// it has a higher sequence number than any existing channel with the same from / to
|
||||
func (ps *Store) createChannel(from address.Address, to address.Address, createMsgCid cid.Cid, amt types.BigInt) (*ChannelInfo, error) {
|
||||
ci := &ChannelInfo{
|
||||
ChannelID: uuid.New().String(),
|
||||
Direction: DirOutbound,
|
||||
NextLane: 0,
|
||||
Control: from,
|
||||
Target: to,
|
||||
CreateMsg: &createMsgCid,
|
||||
PendingAmount: amt,
|
||||
}
|
||||
|
||||
// Save the new channel
|
||||
err := ps.putChannelInfo(ci)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Save a reference to the create message
|
||||
err = ps.SaveNewMessage(ci.ChannelID, createMsgCid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ci, err
|
||||
}
|
||||
|
||||
// The datastore key used to identify the channel info
|
||||
func dskeyForChannel(ci *ChannelInfo) datastore.Key {
|
||||
chanKey := fmt.Sprintf("%s->%s", ci.Control.String(), ci.Target.String())
|
||||
return datastore.KeyWithNamespaces([]string{dsKeyChannelInfo, chanKey})
|
||||
}
|
||||
|
||||
// putChannelInfo stores the channel info in the datastore
|
||||
func (ps *Store) putChannelInfo(ci *ChannelInfo) error {
|
||||
k := dskeyForChannel(ci)
|
||||
|
||||
b, err := marshallChannelInfo(ci)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ps.ds.Put(k, b)
|
||||
}
|
||||
|
||||
// TODO: This is a hack to get around not being able to CBOR marshall a nil
|
||||
// address.Address. It's been fixed in address.Address but we need to wait
|
||||
// for the change to propagate to specs-actors before we can remove this hack.
|
||||
var emptyAddr address.Address
|
||||
|
||||
func init() {
|
||||
addr, err := address.NewActorAddress([]byte("empty"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
emptyAddr = addr
|
||||
}
|
||||
|
||||
func marshallChannelInfo(ci *ChannelInfo) ([]byte, error) {
|
||||
// See note above about CBOR marshalling address.Address
|
||||
if ci.Channel == nil {
|
||||
ci.Channel = &emptyAddr
|
||||
}
|
||||
return cborrpc.Dump(ci)
|
||||
}
|
||||
|
||||
func unmarshallChannelInfo(stored *ChannelInfo, res dsq.Result) (*ChannelInfo, error) {
|
||||
if err := stored.UnmarshalCBOR(bytes.NewReader(res.Value)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// See note above about CBOR marshalling address.Address
|
||||
if stored.Channel != nil && *stored.Channel == emptyAddr {
|
||||
stored.Channel = nil
|
||||
}
|
||||
|
||||
return stored, nil
|
||||
}
|
||||
|
@ -17,8 +17,9 @@ func TestStore(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, addrs, 0)
|
||||
|
||||
ch := tutils.NewIDAddr(t, 100)
|
||||
ci := &ChannelInfo{
|
||||
Channel: tutils.NewIDAddr(t, 100),
|
||||
Channel: &ch,
|
||||
Control: tutils.NewIDAddr(t, 101),
|
||||
Target: tutils.NewIDAddr(t, 102),
|
||||
|
||||
@ -26,8 +27,9 @@ func TestStore(t *testing.T) {
|
||||
Vouchers: []*VoucherInfo{{Voucher: nil, Proof: []byte{}}},
|
||||
}
|
||||
|
||||
ch2 := tutils.NewIDAddr(t, 200)
|
||||
ci2 := &ChannelInfo{
|
||||
Channel: tutils.NewIDAddr(t, 200),
|
||||
Channel: &ch2,
|
||||
Control: tutils.NewIDAddr(t, 201),
|
||||
Target: tutils.NewIDAddr(t, 202),
|
||||
|
||||
@ -55,7 +57,7 @@ func TestStore(t *testing.T) {
|
||||
require.Contains(t, addrsStrings(addrs), "t0200")
|
||||
|
||||
// Request vouchers for channel
|
||||
vouchers, err := store.VouchersForPaych(ci.Channel)
|
||||
vouchers, err := store.VouchersForPaych(*ci.Channel)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, vouchers, 1)
|
||||
|
||||
@ -64,12 +66,12 @@ func TestStore(t *testing.T) {
|
||||
require.Equal(t, err, ErrChannelNotTracked)
|
||||
|
||||
// Allocate lane for channel
|
||||
lane, err := store.AllocateLane(ci.Channel)
|
||||
lane, err := store.AllocateLane(*ci.Channel)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, lane, uint64(0))
|
||||
|
||||
// Allocate next lane for channel
|
||||
lane, err = store.AllocateLane(ci.Channel)
|
||||
lane, err = store.AllocateLane(*ci.Channel)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, lane, uint64(1))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user