WIP: fix payment channel locking

This commit is contained in:
Dirk McCormick 2020-07-28 19:16:47 -04:00
parent bf9116c65a
commit fdfccf0466
21 changed files with 2744 additions and 422 deletions

View File

@ -377,7 +377,8 @@ type FullNode interface {
// MethodGroup: Paych
// The Paych methods are for interacting with and managing payment channels
PaychGet(ctx context.Context, from, to address.Address, ensureFunds types.BigInt) (*ChannelInfo, error)
PaychGet(ctx context.Context, from, to address.Address, amt types.BigInt) (*ChannelInfo, error)
PaychGetWaitReady(context.Context, cid.Cid) (address.Address, error)
PaychList(context.Context) ([]address.Address, error)
PaychStatus(context.Context, address.Address) (*PaychStatus, error)
PaychSettle(context.Context, address.Address) (cid.Cid, error)

View File

@ -183,7 +183,8 @@ type FullNodeStruct struct {
MarketEnsureAvailable func(context.Context, address.Address, address.Address, types.BigInt) (cid.Cid, error) `perm:"sign"`
PaychGet func(ctx context.Context, from, to address.Address, ensureFunds types.BigInt) (*api.ChannelInfo, error) `perm:"sign"`
PaychGet func(ctx context.Context, from, to address.Address, amt types.BigInt) (*api.ChannelInfo, error) `perm:"sign"`
PaychGetWaitReady func(context.Context, cid.Cid) (address.Address, error) `perm:"sign"` // TODO: is perm:"sign" correct?
PaychList func(context.Context) ([]address.Address, error) `perm:"read"`
PaychStatus func(context.Context, address.Address) (*api.PaychStatus, error) `perm:"read"`
PaychSettle func(context.Context, address.Address) (cid.Cid, error) `perm:"sign"`
@ -803,8 +804,12 @@ func (c *FullNodeStruct) MarketEnsureAvailable(ctx context.Context, addr, wallet
return c.Internal.MarketEnsureAvailable(ctx, addr, wallet, amt)
}
func (c *FullNodeStruct) PaychGet(ctx context.Context, from, to address.Address, ensureFunds types.BigInt) (*api.ChannelInfo, error) {
return c.Internal.PaychGet(ctx, from, to, ensureFunds)
func (c *FullNodeStruct) PaychGet(ctx context.Context, from, to address.Address, amt types.BigInt) (*api.ChannelInfo, error) {
return c.Internal.PaychGet(ctx, from, to, amt)
}
func (c *FullNodeStruct) PaychGetWaitReady(ctx context.Context, mcid cid.Cid) (address.Address, error) {
return c.Internal.PaychGetWaitReady(ctx, mcid)
}
func (c *FullNodeStruct) PaychList(ctx context.Context) ([]address.Address, error) {

View File

@ -1,18 +1,17 @@
package test
import (
"bytes"
"context"
"fmt"
"github.com/filecoin-project/specs-actors/actors/builtin"
"os"
"sync/atomic"
"testing"
"time"
"github.com/filecoin-project/specs-actors/actors/builtin"
"github.com/filecoin-project/specs-actors/actors/abi"
"github.com/filecoin-project/specs-actors/actors/abi/big"
initactor "github.com/filecoin-project/specs-actors/actors/builtin/init"
"github.com/filecoin-project/specs-actors/actors/builtin/paych"
"github.com/ipfs/go-cid"
@ -77,13 +76,10 @@ func TestPaymentChannels(t *testing.T, b APIBuilder, blocktime time.Duration) {
t.Fatal(err)
}
res := waitForMessage(ctx, t, paymentCreator, channelInfo.ChannelMessage, time.Second, "channel create")
var params initactor.ExecReturn
err = params.UnmarshalCBOR(bytes.NewReader(res.Receipt.Return))
channel, err := paymentCreator.PaychGetWaitReady(ctx, channelInfo.ChannelMessage)
if err != nil {
t.Fatal(err)
}
channel := params.RobustAddress
// allocate three lanes
var lanes []uint64
@ -129,7 +125,7 @@ func TestPaymentChannels(t *testing.T, b APIBuilder, blocktime time.Duration) {
t.Fatal(err)
}
res = waitForMessage(ctx, t, paymentCreator, settleMsgCid, time.Second*10, "settle")
res := waitForMessage(ctx, t, paymentCreator, settleMsgCid, time.Second*10, "settle")
if res.Receipt.ExitCode != 0 {
t.Fatal("Unable to settle payment channel")
}

View File

@ -28,7 +28,7 @@ var paychCmd = &cli.Command{
var paychGetCmd = &cli.Command{
Name: "get",
Usage: "Create a new payment channel or get existing one",
Usage: "Create a new payment channel or get existing one and add amount to it",
ArgsUsage: "[fromAddress toAddress amount]",
Action: func(cctx *cli.Context) error {
if cctx.Args().Len() != 3 {

View File

@ -35,6 +35,7 @@ func main() {
err = gen.WriteMapEncodersToFile("./paychmgr/cbor_gen.go", "paychmgr",
paychmgr.VoucherInfo{},
paychmgr.ChannelInfo{},
paychmgr.MsgInfo{},
)
if err != nil {
fmt.Println(err)

View File

@ -5,8 +5,6 @@ import (
"errors"
"time"
"github.com/filecoin-project/lotus/markets/dealfilter"
logging "github.com/ipfs/go-log"
ci "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/host"
@ -111,6 +109,7 @@ const (
HandleIncomingMessagesKey
RegisterClientValidatorKey
HandlePaymentChannelManagerKey
// miner
GetParamsKey
@ -279,6 +278,7 @@ func Online() Option {
Override(new(*paychmgr.Store), paychmgr.NewStore),
Override(new(*paychmgr.Manager), paychmgr.NewManager),
Override(new(*market.FundMgr), market.StartFundManager),
Override(HandlePaymentChannelManagerKey, paychmgr.HandleManager),
Override(SettlePaymentChannelsKey, settler.SettlePaymentChannels),
),

View File

@ -28,8 +28,8 @@ type PaychAPI struct {
PaychMgr *paychmgr.Manager
}
func (a *PaychAPI) PaychGet(ctx context.Context, from, to address.Address, ensureFunds types.BigInt) (*api.ChannelInfo, error) {
ch, mcid, err := a.PaychMgr.GetPaych(ctx, from, to, ensureFunds)
func (a *PaychAPI) PaychGet(ctx context.Context, from, to address.Address, amt types.BigInt) (*api.ChannelInfo, error) {
ch, mcid, err := a.PaychMgr.GetPaych(ctx, from, to, amt)
if err != nil {
return nil, err
}
@ -40,6 +40,10 @@ func (a *PaychAPI) PaychGet(ctx context.Context, from, to address.Address, ensur
}, nil
}
func (a *PaychAPI) PaychGetWaitReady(ctx context.Context, mcid cid.Cid) (address.Address, error) {
return a.PaychMgr.GetPaychWaitReady(ctx, mcid)
}
func (a *PaychAPI) PaychAllocateLane(ctx context.Context, ch address.Address) (uint64, error) {
return a.PaychMgr.AllocateLane(ch)
}
@ -66,7 +70,7 @@ func (a *PaychAPI) PaychNewPayment(ctx context.Context, from, to address.Address
ChannelAddr: ch.Channel,
Amount: v.Amount,
Lane: uint64(lane),
Lane: lane,
Extra: v.Extra,
TimeLockMin: v.TimeLockMin,
@ -108,24 +112,7 @@ func (a *PaychAPI) PaychStatus(ctx context.Context, pch address.Address) (*api.P
}
func (a *PaychAPI) PaychSettle(ctx context.Context, addr address.Address) (cid.Cid, error) {
ci, err := a.PaychMgr.GetChannelInfo(addr)
if err != nil {
return cid.Undef, err
}
msg := &types.Message{
To: addr,
From: ci.Control,
Value: types.NewInt(0),
Method: builtin.MethodsPaych.Settle,
}
smgs, err := a.MpoolPushMessage(ctx, msg)
if err != nil {
return cid.Undef, err
}
return smgs.Cid(), nil
return a.PaychMgr.Settle(ctx, addr)
}
func (a *PaychAPI) PaychCollect(ctx context.Context, addr address.Address) (cid.Cid, error) {

67
paychmgr/accessorcache.go Normal file
View File

@ -0,0 +1,67 @@
package paychmgr
import "github.com/filecoin-project/go-address"
// accessorByFromTo gets a channel accessor for a given from / to pair.
// The channel accessor facilitates locking a channel so that operations
// must be performed sequentially on a channel (but can be performed at
// the same time on different channels).
func (pm *Manager) accessorByFromTo(from address.Address, to address.Address) (*channelAccessor, error) {
key := pm.accessorCacheKey(from, to)
// First take a read lock and check the cache
pm.lk.RLock()
ca, ok := pm.channels[key]
pm.lk.RUnlock()
if ok {
return ca, nil
}
// Not in cache, so take a write lock
pm.lk.Lock()
defer pm.lk.Unlock()
// Need to check cache again in case it was updated between releasing read
// lock and taking write lock
ca, ok = pm.channels[key]
if !ok {
// Not in cache, so create a new one and store in cache
ca = pm.addAccessorToCache(from, to)
}
return ca, nil
}
// accessorByAddress gets a channel accessor for a given channel address.
// The channel accessor facilitates locking a channel so that operations
// must be performed sequentially on a channel (but can be performed at
// the same time on different channels).
func (pm *Manager) accessorByAddress(ch address.Address) (*channelAccessor, error) {
// Get the channel from / to
pm.lk.RLock()
channelInfo, err := pm.store.ByAddress(ch)
pm.lk.RUnlock()
if err != nil {
return nil, err
}
// TODO: cache by channel address so we can get by address instead of using from / to
return pm.accessorByFromTo(channelInfo.Control, channelInfo.Target)
}
// accessorCacheKey returns the cache key use to reference a channel accessor
func (pm *Manager) accessorCacheKey(from address.Address, to address.Address) string {
return from.String() + "->" + to.String()
}
// addAccessorToCache adds a channel accessor to a cache. Note that channelInfo
// may be nil if the channel hasn't been created yet, but we still want to
// reference the same channel accessor for a given from/to, so that all
// attempts to access a channel use the same lock (the lock on the accessor)
func (pm *Manager) addAccessorToCache(from address.Address, to address.Address) *channelAccessor {
key := pm.accessorCacheKey(from, to)
ca := newChannelAccessor(pm)
// TODO: Use LRU
pm.channels[key] = ca
return ca
}

View File

@ -6,6 +6,7 @@ import (
"fmt"
"io"
"github.com/filecoin-project/go-address"
"github.com/filecoin-project/specs-actors/actors/builtin/paych"
cbg "github.com/whyrusleeping/cbor-gen"
xerrors "golang.org/x/xerrors"
@ -156,12 +157,35 @@ func (t *ChannelInfo) MarshalCBOR(w io.Writer) error {
_, err := w.Write(cbg.CborNull)
return err
}
if _, err := w.Write([]byte{166}); err != nil {
if _, err := w.Write([]byte{172}); err != nil {
return err
}
scratch := make([]byte, 9)
// t.ChannelID (string) (string)
if len("ChannelID") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"ChannelID\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("ChannelID"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("ChannelID")); err != nil {
return err
}
if len(t.ChannelID) > cbg.MaxLength {
return xerrors.Errorf("Value in field t.ChannelID was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.ChannelID))); err != nil {
return err
}
if _, err := io.WriteString(w, string(t.ChannelID)); err != nil {
return err
}
// t.Channel (address.Address) (struct)
if len("Channel") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"Channel\" was too long")
@ -267,6 +291,97 @@ func (t *ChannelInfo) MarshalCBOR(w io.Writer) error {
return err
}
// t.Amount (big.Int) (struct)
if len("Amount") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"Amount\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Amount"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("Amount")); err != nil {
return err
}
if err := t.Amount.MarshalCBOR(w); err != nil {
return err
}
// t.PendingAmount (big.Int) (struct)
if len("PendingAmount") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"PendingAmount\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("PendingAmount"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("PendingAmount")); err != nil {
return err
}
if err := t.PendingAmount.MarshalCBOR(w); err != nil {
return err
}
// t.CreateMsg (cid.Cid) (struct)
if len("CreateMsg") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"CreateMsg\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("CreateMsg"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("CreateMsg")); err != nil {
return err
}
if t.CreateMsg == nil {
if _, err := w.Write(cbg.CborNull); err != nil {
return err
}
} else {
if err := cbg.WriteCidBuf(scratch, w, *t.CreateMsg); err != nil {
return xerrors.Errorf("failed to write cid field t.CreateMsg: %w", err)
}
}
// t.AddFundsMsg (cid.Cid) (struct)
if len("AddFundsMsg") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"AddFundsMsg\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("AddFundsMsg"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("AddFundsMsg")); err != nil {
return err
}
if t.AddFundsMsg == nil {
if _, err := w.Write(cbg.CborNull); err != nil {
return err
}
} else {
if err := cbg.WriteCidBuf(scratch, w, *t.AddFundsMsg); err != nil {
return xerrors.Errorf("failed to write cid field t.AddFundsMsg: %w", err)
}
}
// t.Settling (bool) (bool)
if len("Settling") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"Settling\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Settling"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("Settling")); err != nil {
return err
}
if err := cbg.WriteBool(w, t.Settling); err != nil {
return err
}
return nil
}
@ -303,13 +418,36 @@ func (t *ChannelInfo) UnmarshalCBOR(r io.Reader) error {
}
switch name {
// t.Channel (address.Address) (struct)
// t.ChannelID (string) (string)
case "ChannelID":
{
sval, err := cbg.ReadStringBuf(br, scratch)
if err != nil {
return err
}
t.ChannelID = string(sval)
}
// t.Channel (address.Address) (struct)
case "Channel":
{
if err := t.Channel.UnmarshalCBOR(br); err != nil {
return xerrors.Errorf("unmarshaling t.Channel: %w", err)
pb, err := br.PeekByte()
if err != nil {
return err
}
if pb == cbg.CborNull[0] {
var nbuf [1]byte
if _, err := br.Read(nbuf[:]); err != nil {
return err
}
} else {
t.Channel = new(address.Address)
if err := t.Channel.UnmarshalCBOR(br); err != nil {
return xerrors.Errorf("unmarshaling t.Channel pointer: %w", err)
}
}
}
@ -393,6 +531,279 @@ func (t *ChannelInfo) UnmarshalCBOR(r io.Reader) error {
t.NextLane = uint64(extra)
}
// t.Amount (big.Int) (struct)
case "Amount":
{
if err := t.Amount.UnmarshalCBOR(br); err != nil {
return xerrors.Errorf("unmarshaling t.Amount: %w", err)
}
}
// t.PendingAmount (big.Int) (struct)
case "PendingAmount":
{
if err := t.PendingAmount.UnmarshalCBOR(br); err != nil {
return xerrors.Errorf("unmarshaling t.PendingAmount: %w", err)
}
}
// t.CreateMsg (cid.Cid) (struct)
case "CreateMsg":
{
pb, err := br.PeekByte()
if err != nil {
return err
}
if pb == cbg.CborNull[0] {
var nbuf [1]byte
if _, err := br.Read(nbuf[:]); err != nil {
return err
}
} else {
c, err := cbg.ReadCid(br)
if err != nil {
return xerrors.Errorf("failed to read cid field t.CreateMsg: %w", err)
}
t.CreateMsg = &c
}
}
// t.AddFundsMsg (cid.Cid) (struct)
case "AddFundsMsg":
{
pb, err := br.PeekByte()
if err != nil {
return err
}
if pb == cbg.CborNull[0] {
var nbuf [1]byte
if _, err := br.Read(nbuf[:]); err != nil {
return err
}
} else {
c, err := cbg.ReadCid(br)
if err != nil {
return xerrors.Errorf("failed to read cid field t.AddFundsMsg: %w", err)
}
t.AddFundsMsg = &c
}
}
// t.Settling (bool) (bool)
case "Settling":
maj, extra, err = cbg.CborReadHeaderBuf(br, scratch)
if err != nil {
return err
}
if maj != cbg.MajOther {
return fmt.Errorf("booleans must be major type 7")
}
switch extra {
case 20:
t.Settling = false
case 21:
t.Settling = true
default:
return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra)
}
default:
return fmt.Errorf("unknown struct field %d: '%s'", i, name)
}
}
return nil
}
func (t *MsgInfo) MarshalCBOR(w io.Writer) error {
if t == nil {
_, err := w.Write(cbg.CborNull)
return err
}
if _, err := w.Write([]byte{164}); err != nil {
return err
}
scratch := make([]byte, 9)
// t.ChannelID (string) (string)
if len("ChannelID") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"ChannelID\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("ChannelID"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("ChannelID")); err != nil {
return err
}
if len(t.ChannelID) > cbg.MaxLength {
return xerrors.Errorf("Value in field t.ChannelID was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.ChannelID))); err != nil {
return err
}
if _, err := io.WriteString(w, string(t.ChannelID)); err != nil {
return err
}
// t.MsgCid (cid.Cid) (struct)
if len("MsgCid") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"MsgCid\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("MsgCid"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("MsgCid")); err != nil {
return err
}
if err := cbg.WriteCidBuf(scratch, w, t.MsgCid); err != nil {
return xerrors.Errorf("failed to write cid field t.MsgCid: %w", err)
}
// t.Received (bool) (bool)
if len("Received") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"Received\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Received"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("Received")); err != nil {
return err
}
if err := cbg.WriteBool(w, t.Received); err != nil {
return err
}
// t.Err (string) (string)
if len("Err") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"Err\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Err"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("Err")); err != nil {
return err
}
if len(t.Err) > cbg.MaxLength {
return xerrors.Errorf("Value in field t.Err was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Err))); err != nil {
return err
}
if _, err := io.WriteString(w, string(t.Err)); err != nil {
return err
}
return nil
}
func (t *MsgInfo) UnmarshalCBOR(r io.Reader) error {
*t = MsgInfo{}
br := cbg.GetPeeker(r)
scratch := make([]byte, 8)
maj, extra, err := cbg.CborReadHeaderBuf(br, scratch)
if err != nil {
return err
}
if maj != cbg.MajMap {
return fmt.Errorf("cbor input should be of type map")
}
if extra > cbg.MaxLength {
return fmt.Errorf("MsgInfo: map struct too large (%d)", extra)
}
var name string
n := extra
for i := uint64(0); i < n; i++ {
{
sval, err := cbg.ReadStringBuf(br, scratch)
if err != nil {
return err
}
name = string(sval)
}
switch name {
// t.ChannelID (string) (string)
case "ChannelID":
{
sval, err := cbg.ReadStringBuf(br, scratch)
if err != nil {
return err
}
t.ChannelID = string(sval)
}
// t.MsgCid (cid.Cid) (struct)
case "MsgCid":
{
c, err := cbg.ReadCid(br)
if err != nil {
return xerrors.Errorf("failed to read cid field t.MsgCid: %w", err)
}
t.MsgCid = c
}
// t.Received (bool) (bool)
case "Received":
maj, extra, err = cbg.CborReadHeaderBuf(br, scratch)
if err != nil {
return err
}
if maj != cbg.MajOther {
return fmt.Errorf("booleans must be major type 7")
}
switch extra {
case 20:
t.Received = false
case 21:
t.Received = true
default:
return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra)
}
// t.Err (string) (string)
case "Err":
{
sval, err := cbg.ReadStringBuf(br, scratch)
if err != nil {
return err
}
t.Err = string(sval)
}
default:
return fmt.Errorf("unknown struct field %d: '%s'", i, name)

33
paychmgr/channellock.go Normal file
View File

@ -0,0 +1,33 @@
package paychmgr
import "sync"
type rwlock interface {
RLock()
RUnlock()
}
// channelLock manages locking for a specific channel.
// Some operations update the state of a single channel, and need to block
// other operations only on the same channel's state.
// Some operations update state that affects all channels, and need to block
// any operation against any channel.
type channelLock struct {
globalLock rwlock
chanLock sync.Mutex
}
func (l *channelLock) Lock() {
// Wait for other operations by this channel to finish.
// Exclusive per-channel (no other ops by this channel allowed).
l.chanLock.Lock()
// Wait for operations affecting all channels to finish.
// Allows ops by other channels in parallel, but blocks all operations
// if global lock is taken exclusively (eg when adding a channel)
l.globalLock.RLock()
}
func (l *channelLock) Unlock() {
l.globalLock.RUnlock()
l.chanLock.Unlock()
}

278
paychmgr/manager.go Normal file
View File

@ -0,0 +1,278 @@
package paychmgr
import (
"context"
"sync"
"github.com/ipfs/go-datastore"
"golang.org/x/sync/errgroup"
xerrors "golang.org/x/xerrors"
"github.com/filecoin-project/lotus/api"
"github.com/filecoin-project/specs-actors/actors/builtin/paych"
"github.com/ipfs/go-cid"
logging "github.com/ipfs/go-log/v2"
"go.uber.org/fx"
"github.com/filecoin-project/go-address"
"github.com/filecoin-project/lotus/chain/stmgr"
"github.com/filecoin-project/lotus/chain/types"
"github.com/filecoin-project/lotus/node/impl/full"
)
var log = logging.Logger("paych")
type ManagerApi struct {
fx.In
full.MpoolAPI
full.WalletAPI
full.StateAPI
}
type StateManagerApi interface {
LoadActorState(ctx context.Context, a address.Address, out interface{}, ts *types.TipSet) (*types.Actor, error)
Call(ctx context.Context, msg *types.Message, ts *types.TipSet) (*api.InvocResult, error)
}
type Manager struct {
// The Manager context is used to terminate wait operations on shutdown
ctx context.Context
shutdown context.CancelFunc
store *Store
sm StateManagerApi
sa *stateAccessor
pchapi paychApi
lk sync.RWMutex
channels map[string]*channelAccessor
mpool full.MpoolAPI
wallet full.WalletAPI
state full.StateAPI
}
type paychAPIImpl struct {
full.MpoolAPI
full.StateAPI
}
func NewManager(sm *stmgr.StateManager, pchstore *Store, api ManagerApi) *Manager {
return &Manager{
store: pchstore,
sm: sm,
sa: &stateAccessor{sm: sm},
channels: make(map[string]*channelAccessor),
// TODO: Is this the correct way to do this or can I do something different
// with dependency injection?
pchapi: &paychAPIImpl{api.MpoolAPI, api.StateAPI},
mpool: api.MpoolAPI,
wallet: api.WalletAPI,
state: api.StateAPI,
}
}
// newManager is used by the tests to supply mocks
func newManager(sm StateManagerApi, pchstore *Store, pchapi paychApi) (*Manager, error) {
pm := &Manager{
store: pchstore,
sm: sm,
sa: &stateAccessor{sm: sm},
channels: make(map[string]*channelAccessor),
pchapi: pchapi,
}
return pm, pm.Start(context.Background())
}
// HandleManager is called by dependency injection to set up hooks
func HandleManager(lc fx.Lifecycle, pm *Manager) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
return pm.Start(ctx)
},
OnStop: func(context.Context) error {
return pm.Stop()
},
})
}
// Start checks the datastore to see if there are any channels that have
// outstanding add funds messages, and if so, waits on the messages.
// Outstanding messages can occur if an add funds message was sent
// and then lotus was shut down or crashed before the result was
// received.
func (pm *Manager) Start(ctx context.Context) error {
pm.ctx, pm.shutdown = context.WithCancel(ctx)
cis, err := pm.store.WithPendingAddFunds()
if err != nil {
return err
}
group := errgroup.Group{}
for _, chanInfo := range cis {
ci := chanInfo
if ci.CreateMsg != nil {
group.Go(func() error {
ca, err := pm.accessorByFromTo(ci.Control, ci.Target)
if err != nil {
return xerrors.Errorf("error initializing payment channel manager %s -> %s: %s", ci.Control, ci.Target, err)
}
go ca.waitForPaychCreateMsg(ci.Control, ci.Target, *ci.CreateMsg, nil)
return nil
})
} else if ci.AddFundsMsg != nil {
group.Go(func() error {
ca, err := pm.accessorByAddress(*ci.Channel)
if err != nil {
return xerrors.Errorf("error initializing payment channel manager %s: %s", ci.Channel, err)
}
go ca.waitForAddFundsMsg(ci.Control, ci.Target, *ci.AddFundsMsg, nil)
return nil
})
}
}
return group.Wait()
}
// Stop shuts down any processes used by the manager
func (pm *Manager) Stop() error {
pm.shutdown()
return nil
}
func (pm *Manager) TrackOutboundChannel(ctx context.Context, ch address.Address) error {
return pm.trackChannel(ctx, ch, DirOutbound)
}
func (pm *Manager) TrackInboundChannel(ctx context.Context, ch address.Address) error {
return pm.trackChannel(ctx, ch, DirInbound)
}
func (pm *Manager) trackChannel(ctx context.Context, ch address.Address, dir uint64) error {
pm.lk.Lock()
defer pm.lk.Unlock()
ci, err := pm.sa.loadStateChannelInfo(ctx, ch, dir)
if err != nil {
return err
}
return pm.store.TrackChannel(ci)
}
func (pm *Manager) GetPaych(ctx context.Context, from, to address.Address, amt types.BigInt) (address.Address, cid.Cid, error) {
chanAccessor, err := pm.accessorByFromTo(from, to)
if err != nil {
return address.Undef, cid.Undef, err
}
return chanAccessor.getPaych(ctx, from, to, amt)
}
// GetPaychWaitReady waits until the create channel / add funds message with the
// given message CID arrives.
// The returned channel address can safely be used against the Manager methods.
func (pm *Manager) GetPaychWaitReady(ctx context.Context, mcid cid.Cid) (address.Address, error) {
// Find the channel associated with the message CID
ci, err := pm.store.ByMessageCid(mcid)
if err != nil {
if err == datastore.ErrNotFound {
return address.Undef, xerrors.Errorf("Could not find wait msg cid %s", mcid)
}
return address.Undef, err
}
chanAccessor, err := pm.accessorByFromTo(ci.Control, ci.Target)
if err != nil {
return address.Undef, err
}
return chanAccessor.getPaychWaitReady(ctx, mcid)
}
func (pm *Manager) ListChannels() ([]address.Address, error) {
// Need to take an exclusive lock here so that channel operations can't run
// in parallel (see channelLock)
pm.lk.Lock()
defer pm.lk.Unlock()
return pm.store.ListChannels()
}
func (pm *Manager) GetChannelInfo(addr address.Address) (*ChannelInfo, error) {
ca, err := pm.accessorByAddress(addr)
if err != nil {
return nil, err
}
return ca.getChannelInfo(addr)
}
// CheckVoucherValid checks if the given voucher is valid (is or could become spendable at some point)
func (pm *Manager) CheckVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) error {
ca, err := pm.accessorByAddress(ch)
if err != nil {
return err
}
_, err = ca.checkVoucherValid(ctx, ch, sv)
return err
}
// CheckVoucherSpendable checks if the given voucher is currently spendable
func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, secret []byte, proof []byte) (bool, error) {
ca, err := pm.accessorByAddress(ch)
if err != nil {
return false, err
}
return ca.checkVoucherSpendable(ctx, ch, sv, secret, proof)
}
func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) {
ca, err := pm.accessorByAddress(ch)
if err != nil {
return types.NewInt(0), err
}
return ca.addVoucher(ctx, ch, sv, proof, minDelta)
}
func (pm *Manager) AllocateLane(ch address.Address) (uint64, error) {
ca, err := pm.accessorByAddress(ch)
if err != nil {
return 0, err
}
return ca.allocateLane(ch)
}
func (pm *Manager) ListVouchers(ctx context.Context, ch address.Address) ([]*VoucherInfo, error) {
ca, err := pm.accessorByAddress(ch)
if err != nil {
return nil, err
}
return ca.listVouchers(ctx, ch)
}
func (pm *Manager) NextNonceForLane(ctx context.Context, ch address.Address, lane uint64) (uint64, error) {
ca, err := pm.accessorByAddress(ch)
if err != nil {
return 0, err
}
return ca.nextNonceForLane(ctx, ch, lane)
}
func (pm *Manager) Settle(ctx context.Context, addr address.Address) (cid.Cid, error) {
ca, err := pm.accessorByAddress(addr)
if err != nil {
return cid.Undef, err
}
return ca.settle(ctx, addr)
}

58
paychmgr/msglistener.go Normal file
View File

@ -0,0 +1,58 @@
package paychmgr
import (
"sync"
"github.com/google/uuid"
"github.com/ipfs/go-cid"
)
type msgListener struct {
id string
cb func(c cid.Cid, err error)
}
type msgListeners struct {
lk sync.Mutex
listeners []*msgListener
}
func (ml *msgListeners) onMsg(mcid cid.Cid, cb func(error)) string {
ml.lk.Lock()
defer ml.lk.Unlock()
l := &msgListener{
id: uuid.New().String(),
cb: func(c cid.Cid, err error) {
if mcid.Equals(c) {
cb(err)
}
},
}
ml.listeners = append(ml.listeners, l)
return l.id
}
func (ml *msgListeners) fireMsgComplete(mcid cid.Cid, err error) {
ml.lk.Lock()
defer ml.lk.Unlock()
for _, l := range ml.listeners {
l.cb(mcid, err)
}
}
func (ml *msgListeners) unsubscribe(sub string) {
for i, l := range ml.listeners {
if l.id == sub {
ml.removeListener(i)
return
}
}
}
func (ml *msgListeners) removeListener(i int) {
copy(ml.listeners[i:], ml.listeners[i+1:])
ml.listeners[len(ml.listeners)-1] = nil
ml.listeners = ml.listeners[:len(ml.listeners)-1]
}

View File

@ -0,0 +1,96 @@
package paychmgr
import (
"testing"
"github.com/ipfs/go-cid"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
)
func testCids() []cid.Cid {
c1, _ := cid.Decode("QmdmGQmRgRjazArukTbsXuuxmSHsMCcRYPAZoGhd6e3MuS")
c2, _ := cid.Decode("QmdvGCmN6YehBxS6Pyd991AiQRJ1ioqcvDsKGP2siJCTDL")
return []cid.Cid{c1, c2}
}
func TestMsgListener(t *testing.T) {
var ml msgListeners
done := false
experr := xerrors.Errorf("some err")
cids := testCids()
ml.onMsg(cids[0], func(err error) {
require.Equal(t, experr, err)
done = true
})
ml.fireMsgComplete(cids[0], experr)
if !done {
t.Fatal("failed to fire event")
}
}
func TestMsgListenerNilErr(t *testing.T) {
var ml msgListeners
done := false
cids := testCids()
ml.onMsg(cids[0], func(err error) {
require.Nil(t, err)
done = true
})
ml.fireMsgComplete(cids[0], nil)
if !done {
t.Fatal("failed to fire event")
}
}
func TestMsgListenerUnsub(t *testing.T) {
var ml msgListeners
done := false
experr := xerrors.Errorf("some err")
cids := testCids()
id1 := ml.onMsg(cids[0], func(err error) {
t.Fatal("should not call unsubscribed listener")
})
ml.onMsg(cids[0], func(err error) {
require.Equal(t, experr, err)
done = true
})
ml.unsubscribe(id1)
ml.fireMsgComplete(cids[0], experr)
if !done {
t.Fatal("failed to fire event")
}
}
func TestMsgListenerMulti(t *testing.T) {
var ml msgListeners
count := 0
cids := testCids()
ml.onMsg(cids[0], func(err error) {
count++
})
ml.onMsg(cids[0], func(err error) {
count++
})
ml.onMsg(cids[1], func(err error) {
count++
})
ml.fireMsgComplete(cids[0], nil)
require.Equal(t, 2, count)
ml.fireMsgComplete(cids[1], nil)
require.Equal(t, 3, count)
}

View File

@ -5,118 +5,75 @@ import (
"context"
"fmt"
"github.com/filecoin-project/specs-actors/actors/abi/big"
"github.com/filecoin-project/lotus/api"
"github.com/ipfs/go-cid"
"github.com/filecoin-project/go-address"
cborutil "github.com/filecoin-project/go-cbor-util"
"github.com/filecoin-project/lotus/chain/actors"
"github.com/filecoin-project/lotus/chain/types"
"github.com/filecoin-project/lotus/lib/sigs"
"github.com/filecoin-project/specs-actors/actors/abi/big"
"github.com/filecoin-project/specs-actors/actors/builtin"
"github.com/filecoin-project/specs-actors/actors/builtin/account"
"github.com/filecoin-project/specs-actors/actors/builtin/paych"
"golang.org/x/xerrors"
logging "github.com/ipfs/go-log/v2"
"go.uber.org/fx"
"github.com/filecoin-project/go-address"
"github.com/filecoin-project/lotus/chain/actors"
"github.com/filecoin-project/lotus/chain/stmgr"
"github.com/filecoin-project/lotus/chain/types"
"github.com/filecoin-project/lotus/lib/sigs"
"github.com/filecoin-project/lotus/node/impl/full"
xerrors "golang.org/x/xerrors"
)
var log = logging.Logger("paych")
type ManagerApi struct {
fx.In
full.MpoolAPI
full.WalletAPI
full.StateAPI
// channelAccessor is used to simplify locking when accessing a channel
type channelAccessor struct {
// waitCtx is used by processes that wait for things to be confirmed
// on chain
waitCtx context.Context
sm StateManagerApi
sa *stateAccessor
api paychApi
store *Store
lk *channelLock
fundsReqQueue []*fundsReq
msgListeners msgListeners
}
type StateManagerApi interface {
LoadActorState(ctx context.Context, a address.Address, out interface{}, ts *types.TipSet) (*types.Actor, error)
Call(ctx context.Context, msg *types.Message, ts *types.TipSet) (*api.InvocResult, error)
}
type Manager struct {
store *Store
sm StateManagerApi
mpool full.MpoolAPI
wallet full.WalletAPI
state full.StateAPI
}
func NewManager(sm *stmgr.StateManager, pchstore *Store, api ManagerApi) *Manager {
return &Manager{
store: pchstore,
sm: sm,
mpool: api.MpoolAPI,
wallet: api.WalletAPI,
state: api.StateAPI,
func newChannelAccessor(pm *Manager) *channelAccessor {
return &channelAccessor{
lk: &channelLock{globalLock: &pm.lk},
sm: pm.sm,
sa: &stateAccessor{sm: pm.sm},
api: pm.pchapi,
store: pm.store,
waitCtx: pm.ctx,
}
}
// Used by the tests to supply mocks
func newManager(sm StateManagerApi, pchstore *Store) *Manager {
return &Manager{
store: pchstore,
sm: sm,
}
func (ca *channelAccessor) getChannelInfo(addr address.Address) (*ChannelInfo, error) {
ca.lk.Lock()
defer ca.lk.Unlock()
return ca.store.ByAddress(addr)
}
func (pm *Manager) TrackOutboundChannel(ctx context.Context, ch address.Address) error {
return pm.trackChannel(ctx, ch, DirOutbound)
func (ca *channelAccessor) checkVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (map[uint64]*paych.LaneState, error) {
ca.lk.Lock()
defer ca.lk.Unlock()
return ca.checkVoucherValidUnlocked(ctx, ch, sv)
}
func (pm *Manager) TrackInboundChannel(ctx context.Context, ch address.Address) error {
return pm.trackChannel(ctx, ch, DirInbound)
}
func (pm *Manager) trackChannel(ctx context.Context, ch address.Address, dir uint64) error {
ci, err := pm.loadStateChannelInfo(ctx, ch, dir)
if err != nil {
return err
}
return pm.store.TrackChannel(ci)
}
func (pm *Manager) ListChannels() ([]address.Address, error) {
return pm.store.ListChannels()
}
func (pm *Manager) GetChannelInfo(addr address.Address) (*ChannelInfo, error) {
return pm.store.getChannelInfo(addr)
}
// CheckVoucherValid checks if the given voucher is valid (is or could become spendable at some point)
func (pm *Manager) CheckVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) error {
_, err := pm.checkVoucherValid(ctx, ch, sv)
return err
}
func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (map[uint64]*paych.LaneState, error) {
func (ca *channelAccessor) checkVoucherValidUnlocked(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (map[uint64]*paych.LaneState, error) {
if sv.ChannelAddr != ch {
return nil, xerrors.Errorf("voucher ChannelAddr doesn't match channel address, got %s, expected %s", sv.ChannelAddr, ch)
}
act, pchState, err := pm.loadPaychState(ctx, ch)
act, pchState, err := ca.sa.loadPaychState(ctx, ch)
if err != nil {
return nil, err
}
var account account.State
_, err = pm.sm.LoadActorState(ctx, pchState.From, &account, nil)
var actState account.State
_, err = ca.sm.LoadActorState(ctx, pchState.From, &actState, nil)
if err != nil {
return nil, err
}
from := account.Address
from := actState.Address
// verify signature
vb, err := sv.SigningBytes()
@ -132,7 +89,7 @@ func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv
}
// Check the voucher against the highest known voucher nonce / value
laneStates, err := pm.laneState(pchState, ch)
laneStates, err := ca.laneState(pchState, ch)
if err != nil {
return nil, err
}
@ -164,7 +121,7 @@ func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv
// lane 2: 2
// -
// total: 7
totalRedeemed, err := pm.totalRedeemedWithVoucher(laneStates, sv)
totalRedeemed, err := ca.totalRedeemedWithVoucher(laneStates, sv)
if err != nil {
return nil, err
}
@ -183,15 +140,17 @@ func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv
return laneStates, nil
}
// CheckVoucherSpendable checks if the given voucher is currently spendable
func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, secret []byte, proof []byte) (bool, error) {
recipient, err := pm.getPaychRecipient(ctx, ch)
func (ca *channelAccessor) checkVoucherSpendable(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, secret []byte, proof []byte) (bool, error) {
ca.lk.Lock()
defer ca.lk.Unlock()
recipient, err := ca.getPaychRecipient(ctx, ch)
if err != nil {
return false, err
}
if sv.Extra != nil && proof == nil {
known, err := pm.ListVouchers(ctx, ch)
known, err := ca.store.VouchersForPaych(ch)
if err != nil {
return false, err
}
@ -221,7 +180,7 @@ func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address
return false, err
}
ret, err := pm.sm.Call(ctx, &types.Message{
ret, err := ca.sm.Call(ctx, &types.Message{
From: recipient,
To: ch,
Method: builtin.MethodsPaych.UpdateChannelState,
@ -238,22 +197,22 @@ func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address
return true, nil
}
func (pm *Manager) getPaychRecipient(ctx context.Context, ch address.Address) (address.Address, error) {
func (ca *channelAccessor) getPaychRecipient(ctx context.Context, ch address.Address) (address.Address, error) {
var state paych.State
if _, err := pm.sm.LoadActorState(ctx, ch, &state, nil); err != nil {
if _, err := ca.sm.LoadActorState(ctx, ch, &state, nil); err != nil {
return address.Address{}, err
}
return state.To, nil
}
func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) {
pm.store.lk.Lock()
defer pm.store.lk.Unlock()
func (ca *channelAccessor) addVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) {
ca.lk.Lock()
defer ca.lk.Unlock()
ci, err := pm.store.getChannelInfo(ch)
ci, err := ca.store.ByAddress(ch)
if err != nil {
return types.NewInt(0), err
return types.BigInt{}, err
}
// Check if the voucher has already been added
@ -275,7 +234,7 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych
Proof: proof,
}
return types.NewInt(0), pm.store.putChannelInfo(ci)
return types.NewInt(0), ca.store.putChannelInfo(ci)
}
// Otherwise just ignore the duplicate voucher
@ -284,7 +243,7 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych
}
// Check voucher validity
laneStates, err := pm.checkVoucherValid(ctx, ch, sv)
laneStates, err := ca.checkVoucherValidUnlocked(ctx, ch, sv)
if err != nil {
return types.NewInt(0), err
}
@ -311,35 +270,32 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych
ci.NextLane = sv.Lane + 1
}
return delta, pm.store.putChannelInfo(ci)
return delta, ca.store.putChannelInfo(ci)
}
func (pm *Manager) AllocateLane(ch address.Address) (uint64, error) {
func (ca *channelAccessor) allocateLane(ch address.Address) (uint64, error) {
ca.lk.Lock()
defer ca.lk.Unlock()
// TODO: should this take into account lane state?
return pm.store.AllocateLane(ch)
return ca.store.AllocateLane(ch)
}
func (pm *Manager) ListVouchers(ctx context.Context, ch address.Address) ([]*VoucherInfo, error) {
func (ca *channelAccessor) listVouchers(ctx context.Context, ch address.Address) ([]*VoucherInfo, error) {
ca.lk.Lock()
defer ca.lk.Unlock()
// TODO: just having a passthrough method like this feels odd. Seems like
// there should be some filtering we're doing here
return pm.store.VouchersForPaych(ch)
return ca.store.VouchersForPaych(ch)
}
func (pm *Manager) OutboundChanTo(from, to address.Address) (address.Address, error) {
pm.store.lk.Lock()
defer pm.store.lk.Unlock()
func (ca *channelAccessor) nextNonceForLane(ctx context.Context, ch address.Address, lane uint64) (uint64, error) {
ca.lk.Lock()
defer ca.lk.Unlock()
return pm.store.findChan(func(ci *ChannelInfo) bool {
if ci.Direction != DirOutbound {
return false
}
return ci.Control == from && ci.Target == to
})
}
func (pm *Manager) NextNonceForLane(ctx context.Context, ch address.Address, lane uint64) (uint64, error) {
// TODO: should this take into account lane state?
vouchers, err := pm.store.VouchersForPaych(ch)
vouchers, err := ca.store.VouchersForPaych(ch)
if err != nil {
return 0, err
}
@ -355,3 +311,109 @@ func (pm *Manager) NextNonceForLane(ctx context.Context, ch address.Address, lan
return maxnonce + 1, nil
}
// laneState gets the LaneStates from chain, then applies all vouchers in
// the data store over the chain state
func (ca *channelAccessor) laneState(state *paych.State, ch address.Address) (map[uint64]*paych.LaneState, error) {
// TODO: we probably want to call UpdateChannelState with all vouchers to be fully correct
// (but technically dont't need to)
laneStates := make(map[uint64]*paych.LaneState, len(state.LaneStates))
// Get the lane state from the chain
for _, laneState := range state.LaneStates {
laneStates[laneState.ID] = laneState
}
// Apply locally stored vouchers
vouchers, err := ca.store.VouchersForPaych(ch)
if err != nil && err != ErrChannelNotTracked {
return nil, err
}
for _, v := range vouchers {
for range v.Voucher.Merges {
return nil, xerrors.Errorf("paych merges not handled yet")
}
// If there's a voucher for a lane that isn't in chain state just
// create it
ls, ok := laneStates[v.Voucher.Lane]
if !ok {
ls = &paych.LaneState{
ID: v.Voucher.Lane,
Redeemed: types.NewInt(0),
Nonce: 0,
}
laneStates[v.Voucher.Lane] = ls
}
if v.Voucher.Nonce < ls.Nonce {
continue
}
ls.Nonce = v.Voucher.Nonce
ls.Redeemed = v.Voucher.Amount
}
return laneStates, nil
}
// Get the total redeemed amount across all lanes, after applying the voucher
func (ca *channelAccessor) totalRedeemedWithVoucher(laneStates map[uint64]*paych.LaneState, sv *paych.SignedVoucher) (big.Int, error) {
// TODO: merges
if len(sv.Merges) != 0 {
return big.Int{}, xerrors.Errorf("dont currently support paych lane merges")
}
total := big.NewInt(0)
for _, ls := range laneStates {
total = big.Add(total, ls.Redeemed)
}
lane, ok := laneStates[sv.Lane]
if ok {
// If the voucher is for an existing lane, and the voucher nonce
// and is higher than the lane nonce
if sv.Nonce > lane.Nonce {
// Add the delta between the redeemed amount and the voucher
// amount to the total
delta := big.Sub(sv.Amount, lane.Redeemed)
total = big.Add(total, delta)
}
} else {
// If the voucher is *not* for an existing lane, just add its
// value (implicitly a new lane will be created for the voucher)
total = big.Add(total, sv.Amount)
}
return total, nil
}
func (ca *channelAccessor) settle(ctx context.Context, ch address.Address) (cid.Cid, error) {
ca.lk.Lock()
defer ca.lk.Unlock()
ci, err := ca.store.ByAddress(ch)
if err != nil {
return cid.Undef, err
}
msg := &types.Message{
To: ch,
From: ci.Control,
Value: types.NewInt(0),
Method: builtin.MethodsPaych.Settle,
}
smgs, err := ca.api.MpoolPushMessage(ctx, msg)
if err != nil {
return cid.Undef, err
}
ci.Settling = true
err = ca.store.putChannelInfo(ci)
if err != nil {
log.Errorf("Error marking channel as settled: %s", err)
}
return smgs.Cid(), err
}

View File

@ -104,13 +104,15 @@ func TestPaychOutbound(t *testing.T) {
LaneStates: []*paych.LaneState{},
})
mgr := newManager(sm, store)
err := mgr.TrackOutboundChannel(ctx, ch)
mgr, err := newManager(sm, store, nil)
require.NoError(t, err)
err = mgr.TrackOutboundChannel(ctx, ch)
require.NoError(t, err)
ci, err := mgr.GetChannelInfo(ch)
require.NoError(t, err)
require.Equal(t, ci.Channel, ch)
require.Equal(t, *ci.Channel, ch)
require.Equal(t, ci.Control, from)
require.Equal(t, ci.Target, to)
require.EqualValues(t, ci.Direction, DirOutbound)
@ -140,13 +142,15 @@ func TestPaychInbound(t *testing.T) {
LaneStates: []*paych.LaneState{},
})
mgr := newManager(sm, store)
err := mgr.TrackInboundChannel(ctx, ch)
mgr, err := newManager(sm, store, nil)
require.NoError(t, err)
err = mgr.TrackInboundChannel(ctx, ch)
require.NoError(t, err)
ci, err := mgr.GetChannelInfo(ch)
require.NoError(t, err)
require.Equal(t, ci.Channel, ch)
require.Equal(t, *ci.Channel, ch)
require.Equal(t, ci.Control, to)
require.Equal(t, ci.Target, from)
require.EqualValues(t, ci.Direction, DirInbound)
@ -321,8 +325,10 @@ func TestCheckVoucherValid(t *testing.T) {
LaneStates: tcase.laneStates,
})
mgr := newManager(sm, store)
err := mgr.TrackInboundChannel(ctx, ch)
mgr, err := newManager(sm, store, nil)
require.NoError(t, err)
err = mgr.TrackInboundChannel(ctx, ch)
require.NoError(t, err)
sv := testCreateVoucher(t, ch, tcase.voucherLane, tcase.voucherNonce, tcase.voucherAmount, tcase.key)
@ -382,8 +388,10 @@ func TestCheckVoucherValidCountingAllLanes(t *testing.T) {
LaneStates: laneStates,
})
mgr := newManager(sm, store)
err := mgr.TrackInboundChannel(ctx, ch)
mgr, err := newManager(sm, store, nil)
require.NoError(t, err)
err = mgr.TrackInboundChannel(ctx, ch)
require.NoError(t, err)
//
@ -690,8 +698,10 @@ func testSetupMgrWithChannel(ctx context.Context, t *testing.T) (*Manager, addre
})
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
mgr := newManager(sm, store)
err := mgr.TrackInboundChannel(ctx, ch)
mgr, err := newManager(sm, store, nil)
require.NoError(t, err)
err = mgr.TrackInboundChannel(ctx, ch)
require.NoError(t, err)
return mgr, ch, fromKeyPrivate
}

756
paychmgr/paychget_test.go Normal file
View File

@ -0,0 +1,756 @@
package paychmgr
import (
"context"
"sync"
"testing"
"time"
cborrpc "github.com/filecoin-project/go-cbor-util"
init_ "github.com/filecoin-project/specs-actors/actors/builtin/init"
"github.com/filecoin-project/specs-actors/actors/builtin"
"github.com/filecoin-project/lotus/api"
"github.com/filecoin-project/lotus/chain/types"
"github.com/filecoin-project/go-address"
"github.com/filecoin-project/specs-actors/actors/abi/big"
tutils "github.com/filecoin-project/specs-actors/support/testing"
"github.com/ipfs/go-cid"
ds "github.com/ipfs/go-datastore"
ds_sync "github.com/ipfs/go-datastore/sync"
"github.com/stretchr/testify/require"
)
type waitingCall struct {
response chan types.MessageReceipt
}
type mockPaychAPI struct {
lk sync.Mutex
messages map[cid.Cid]*types.SignedMessage
waitingCalls []*waitingCall
}
func newMockPaychAPI() *mockPaychAPI {
return &mockPaychAPI{
messages: make(map[cid.Cid]*types.SignedMessage),
}
}
func (pchapi *mockPaychAPI) StateWaitMsg(ctx context.Context, msg cid.Cid, confidence uint64) (*api.MsgLookup, error) {
response := make(chan types.MessageReceipt)
pchapi.lk.Lock()
pchapi.waitingCalls = append(pchapi.waitingCalls, &waitingCall{response: response})
pchapi.lk.Unlock()
receipt := <-response
return &api.MsgLookup{Receipt: receipt}, nil
}
func (pchapi *mockPaychAPI) MpoolPushMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error) {
pchapi.lk.Lock()
defer pchapi.lk.Unlock()
smsg := &types.SignedMessage{Message: *msg}
pchapi.messages[smsg.Cid()] = smsg
return smsg, nil
}
func (pchapi *mockPaychAPI) pushedMessages(c cid.Cid) *types.SignedMessage {
pchapi.lk.Lock()
defer pchapi.lk.Unlock()
return pchapi.messages[c]
}
func (pchapi *mockPaychAPI) pushedMessageCount() int {
pchapi.lk.Lock()
defer pchapi.lk.Unlock()
return len(pchapi.messages)
}
func (pchapi *mockPaychAPI) finishWaitingCalls(receipt types.MessageReceipt) {
pchapi.lk.Lock()
defer pchapi.lk.Unlock()
for _, call := range pchapi.waitingCalls {
call.response <- receipt
}
pchapi.waitingCalls = nil
}
func (pchapi *mockPaychAPI) close() {
pchapi.finishWaitingCalls(types.MessageReceipt{
ExitCode: 0,
Return: []byte{},
})
}
func testChannelResponse(t *testing.T, ch address.Address) types.MessageReceipt {
createChannelRet := init_.ExecReturn{
IDAddress: ch,
RobustAddress: ch,
}
createChannelRetBytes, err := cborrpc.Dump(&createChannelRet)
require.NoError(t, err)
createChannelResponse := types.MessageReceipt{
ExitCode: 0,
Return: createChannelRetBytes,
}
return createChannelResponse
}
// TestPaychGetCreateChannelMsg tests that GetPaych sends a message to create
// a new channel with the correct funds
func TestPaychGetCreateChannelMsg(t *testing.T) {
ctx := context.Background()
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
from := tutils.NewIDAddr(t, 101)
to := tutils.NewIDAddr(t, 102)
sm := newMockStateManager()
pchapi := newMockPaychAPI()
defer pchapi.close()
mgr, err := newManager(sm, store, pchapi)
require.NoError(t, err)
amt := big.NewInt(10)
ch, mcid, err := mgr.GetPaych(ctx, from, to, amt)
require.NoError(t, err)
require.Equal(t, address.Undef, ch)
pushedMsg := pchapi.pushedMessages(mcid)
require.Equal(t, from, pushedMsg.Message.From)
require.Equal(t, builtin.InitActorAddr, pushedMsg.Message.To)
require.Equal(t, amt, pushedMsg.Message.Value)
}
// TestPaychGetCreateChannelThenAddFunds tests creating a channel and then
// adding funds to it
func TestPaychGetCreateChannelThenAddFunds(t *testing.T) {
ctx := context.Background()
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
ch := tutils.NewIDAddr(t, 100)
from := tutils.NewIDAddr(t, 101)
to := tutils.NewIDAddr(t, 102)
sm := newMockStateManager()
pchapi := newMockPaychAPI()
defer pchapi.close()
mgr, err := newManager(sm, store, pchapi)
require.NoError(t, err)
// Send create message for a channel with value 10
amt := big.NewInt(10)
_, createMsgCid, err := mgr.GetPaych(ctx, from, to, amt)
require.NoError(t, err)
// Should have no channels yet (message sent but channel not created)
cis, err := mgr.ListChannels()
require.NoError(t, err)
require.Len(t, cis, 0)
// 1. Set up create channel response (sent in response to WaitForMsg())
response := testChannelResponse(t, ch)
done := make(chan struct{})
go func() {
defer close(done)
// 2. Request add funds - should block until create channel has completed
amt2 := big.NewInt(5)
ch2, addFundsMsgCid, err := mgr.GetPaych(ctx, from, to, amt2)
// 4. This GetPaych should return after create channel from first
// GetPaych completes
require.NoError(t, err)
// Expect the channel to be the same
require.Equal(t, ch, ch2)
// Expect add funds message CID to be different to create message cid
require.NotEqual(t, createMsgCid, addFundsMsgCid)
// Should have one channel, whose address is the channel that was created
cis, err := mgr.ListChannels()
require.NoError(t, err)
require.Len(t, cis, 1)
require.Equal(t, ch, cis[0])
// Amount should be amount sent to first GetPaych (to create
// channel).
// PendingAmount should be amount sent in second GetPaych
// (second GetPaych triggered add funds, which has not yet been confirmed)
ci, err := mgr.GetChannelInfo(ch)
require.NoError(t, err)
require.EqualValues(t, 10, ci.Amount.Int64())
require.EqualValues(t, 5, ci.PendingAmount.Int64())
require.Nil(t, ci.CreateMsg)
// Trigger add funds confirmation
pchapi.finishWaitingCalls(types.MessageReceipt{ExitCode: 0})
time.Sleep(time.Millisecond * 10)
// Should still have one channel
cis, err = mgr.ListChannels()
require.NoError(t, err)
require.Len(t, cis, 1)
require.Equal(t, ch, cis[0])
// Channel amount should include last amount sent to GetPaych
ci, err = mgr.GetChannelInfo(ch)
require.NoError(t, err)
require.EqualValues(t, 15, ci.Amount.Int64())
require.EqualValues(t, 0, ci.PendingAmount.Int64())
require.Nil(t, ci.AddFundsMsg)
}()
// Give the go routine above a moment to run
time.Sleep(time.Millisecond * 10)
// 3. Send create channel response
pchapi.finishWaitingCalls(response)
<-done
}
// TestPaychGetCreateChannelWithErrorThenCreateAgain tests that if an
// operation is queued up behind a create channel operation, and the create
// channel fails, then the waiting operation can succeed.
func TestPaychGetCreateChannelWithErrorThenCreateAgain(t *testing.T) {
ctx := context.Background()
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
from := tutils.NewIDAddr(t, 101)
to := tutils.NewIDAddr(t, 102)
sm := newMockStateManager()
pchapi := newMockPaychAPI()
defer pchapi.close()
mgr, err := newManager(sm, store, pchapi)
require.NoError(t, err)
// Send create message for a channel
amt := big.NewInt(10)
_, _, err = mgr.GetPaych(ctx, from, to, amt)
require.NoError(t, err)
// 1. Set up create channel response (sent in response to WaitForMsg())
// This response indicates an error.
errResponse := types.MessageReceipt{
ExitCode: 1, // error
Return: []byte{},
}
done := make(chan struct{})
go func() {
defer close(done)
// 2. Should block until create channel has completed.
// Because first channel create fails, this request
// should be for channel create.
amt2 := big.NewInt(5)
ch2, _, err := mgr.GetPaych(ctx, from, to, amt2)
require.NoError(t, err)
require.Equal(t, address.Undef, ch2)
time.Sleep(time.Millisecond * 10)
// 4. Send a success response
ch := tutils.NewIDAddr(t, 100)
successResponse := testChannelResponse(t, ch)
pchapi.finishWaitingCalls(successResponse)
time.Sleep(time.Millisecond * 10)
// Should have one channel, whose address is the channel that was created
cis, err := mgr.ListChannels()
require.NoError(t, err)
require.Len(t, cis, 1)
require.Equal(t, ch, cis[0])
ci, err := mgr.GetChannelInfo(ch)
require.NoError(t, err)
require.Equal(t, amt2, ci.Amount)
}()
// Give the go routine above a moment to run
time.Sleep(time.Millisecond * 10)
// 3. Send error response to first channel create
pchapi.finishWaitingCalls(errResponse)
<-done
}
// TestPaychGetRecoverAfterError tests that after a create channel fails, the
// next attempt to create channel can succeed.
func TestPaychGetRecoverAfterError(t *testing.T) {
ctx := context.Background()
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
ch := tutils.NewIDAddr(t, 100)
from := tutils.NewIDAddr(t, 101)
to := tutils.NewIDAddr(t, 102)
sm := newMockStateManager()
pchapi := newMockPaychAPI()
defer pchapi.close()
mgr, err := newManager(sm, store, pchapi)
require.NoError(t, err)
// Send create message for a channel
amt := big.NewInt(10)
_, _, err = mgr.GetPaych(ctx, from, to, amt)
require.NoError(t, err)
time.Sleep(time.Millisecond * 10)
// Send error create channel response
pchapi.finishWaitingCalls(types.MessageReceipt{
ExitCode: 1, // error
Return: []byte{},
})
time.Sleep(time.Millisecond * 10)
// Send create message for a channel again
amt2 := big.NewInt(7)
_, _, err = mgr.GetPaych(ctx, from, to, amt2)
require.NoError(t, err)
time.Sleep(time.Millisecond * 10)
// Send success create channel response
response := testChannelResponse(t, ch)
pchapi.finishWaitingCalls(response)
time.Sleep(time.Millisecond * 10)
// Should have one channel, whose address is the channel that was created
cis, err := mgr.ListChannels()
require.NoError(t, err)
require.Len(t, cis, 1)
require.Equal(t, ch, cis[0])
ci, err := mgr.GetChannelInfo(ch)
require.NoError(t, err)
require.Equal(t, amt2, ci.Amount)
require.EqualValues(t, 0, ci.PendingAmount.Int64())
require.Nil(t, ci.CreateMsg)
}
// TestPaychGetRecoverAfterAddFundsError tests that after an add funds fails, the
// next attempt to add funds can succeed.
func TestPaychGetRecoverAfterAddFundsError(t *testing.T) {
ctx := context.Background()
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
ch := tutils.NewIDAddr(t, 100)
from := tutils.NewIDAddr(t, 101)
to := tutils.NewIDAddr(t, 102)
sm := newMockStateManager()
pchapi := newMockPaychAPI()
defer pchapi.close()
mgr, err := newManager(sm, store, pchapi)
require.NoError(t, err)
// Send create message for a channel
amt := big.NewInt(10)
_, _, err = mgr.GetPaych(ctx, from, to, amt)
require.NoError(t, err)
time.Sleep(time.Millisecond * 10)
// Send success create channel response
response := testChannelResponse(t, ch)
pchapi.finishWaitingCalls(response)
time.Sleep(time.Millisecond * 10)
// Send add funds message for channel
amt2 := big.NewInt(5)
_, _, err = mgr.GetPaych(ctx, from, to, amt2)
require.NoError(t, err)
time.Sleep(time.Millisecond * 10)
// Send error add funds response
pchapi.finishWaitingCalls(types.MessageReceipt{
ExitCode: 1, // error
Return: []byte{},
})
time.Sleep(time.Millisecond * 10)
// Should have one channel, whose address is the channel that was created
cis, err := mgr.ListChannels()
require.NoError(t, err)
require.Len(t, cis, 1)
require.Equal(t, ch, cis[0])
ci, err := mgr.GetChannelInfo(ch)
require.NoError(t, err)
require.Equal(t, amt, ci.Amount)
require.EqualValues(t, 0, ci.PendingAmount.Int64())
require.Nil(t, ci.CreateMsg)
require.Nil(t, ci.AddFundsMsg)
// Send add funds message for channel again
amt3 := big.NewInt(2)
_, _, err = mgr.GetPaych(ctx, from, to, amt3)
require.NoError(t, err)
time.Sleep(time.Millisecond * 10)
// Send success add funds response
pchapi.finishWaitingCalls(types.MessageReceipt{
ExitCode: 0,
Return: []byte{},
})
time.Sleep(time.Millisecond * 10)
// Should have one channel, whose address is the channel that was created
cis, err = mgr.ListChannels()
require.NoError(t, err)
require.Len(t, cis, 1)
require.Equal(t, ch, cis[0])
// Amount should include amount for successful add funds msg
ci, err = mgr.GetChannelInfo(ch)
require.NoError(t, err)
require.Equal(t, amt.Int64()+amt3.Int64(), ci.Amount.Int64())
require.EqualValues(t, 0, ci.PendingAmount.Int64())
require.Nil(t, ci.CreateMsg)
require.Nil(t, ci.AddFundsMsg)
}
// TestPaychGetRestartAfterCreateChannelMsg tests that if the system stops
// right after the create channel message is sent, the channel will be
// created when the system restarts.
func TestPaychGetRestartAfterCreateChannelMsg(t *testing.T) {
ctx := context.Background()
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
ch := tutils.NewIDAddr(t, 100)
from := tutils.NewIDAddr(t, 101)
to := tutils.NewIDAddr(t, 102)
sm := newMockStateManager()
pchapi := newMockPaychAPI()
mgr, err := newManager(sm, store, pchapi)
require.NoError(t, err)
// Send create message for a channel with value 10
amt := big.NewInt(10)
_, createMsgCid, err := mgr.GetPaych(ctx, from, to, amt)
require.NoError(t, err)
// Simulate shutting down lotus
pchapi.close()
// Create a new manager with the same datastore
sm2 := newMockStateManager()
pchapi2 := newMockPaychAPI()
defer pchapi2.close()
mgr2, err := newManager(sm2, store, pchapi2)
require.NoError(t, err)
// Should have no channels yet (message sent but channel not created)
cis, err := mgr2.ListChannels()
require.NoError(t, err)
require.Len(t, cis, 0)
// 1. Set up create channel response (sent in response to WaitForMsg())
response := testChannelResponse(t, ch)
done := make(chan struct{})
go func() {
defer close(done)
// 2. Request add funds - should block until create channel has completed
amt2 := big.NewInt(5)
ch2, addFundsMsgCid, err := mgr2.GetPaych(ctx, from, to, amt2)
// 4. This GetPaych should return after create channel from first
// GetPaych completes
require.NoError(t, err)
// Expect the channel to have been created
require.Equal(t, ch, ch2)
// Expect add funds message CID to be different to create message cid
require.NotEqual(t, createMsgCid, addFundsMsgCid)
// Should have one channel, whose address is the channel that was created
cis, err := mgr2.ListChannels()
require.NoError(t, err)
require.Len(t, cis, 1)
require.Equal(t, ch, cis[0])
// Amount should be amount sent to first GetPaych (to create
// channel).
// PendingAmount should be amount sent in second GetPaych
// (second GetPaych triggered add funds, which has not yet been confirmed)
ci, err := mgr2.GetChannelInfo(ch)
require.NoError(t, err)
require.EqualValues(t, 10, ci.Amount.Int64())
require.EqualValues(t, 5, ci.PendingAmount.Int64())
require.Nil(t, ci.CreateMsg)
}()
// Give the go routine above a moment to run
time.Sleep(time.Millisecond * 10)
// 3. Send create channel response
pchapi2.finishWaitingCalls(response)
<-done
}
// TestPaychGetRestartAfterAddFundsMsg tests that if the system stops
// right after the add funds message is sent, the add funds will be
// processed when the system restarts.
func TestPaychGetRestartAfterAddFundsMsg(t *testing.T) {
ctx := context.Background()
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
ch := tutils.NewIDAddr(t, 100)
from := tutils.NewIDAddr(t, 101)
to := tutils.NewIDAddr(t, 102)
sm := newMockStateManager()
pchapi := newMockPaychAPI()
mgr, err := newManager(sm, store, pchapi)
require.NoError(t, err)
// Send create message for a channel
amt := big.NewInt(10)
_, _, err = mgr.GetPaych(ctx, from, to, amt)
require.NoError(t, err)
time.Sleep(time.Millisecond * 10)
// Send success create channel response
response := testChannelResponse(t, ch)
pchapi.finishWaitingCalls(response)
time.Sleep(time.Millisecond * 10)
// Send add funds message for channel
amt2 := big.NewInt(5)
_, _, err = mgr.GetPaych(ctx, from, to, amt2)
require.NoError(t, err)
time.Sleep(time.Millisecond * 10)
// Simulate shutting down lotus
pchapi.close()
// Create a new manager with the same datastore
sm2 := newMockStateManager()
pchapi2 := newMockPaychAPI()
defer pchapi2.close()
time.Sleep(time.Millisecond * 10)
mgr2, err := newManager(sm2, store, pchapi2)
require.NoError(t, err)
// Send success add funds response
pchapi2.finishWaitingCalls(types.MessageReceipt{
ExitCode: 0,
Return: []byte{},
})
time.Sleep(time.Millisecond * 10)
// Should have one channel, whose address is the channel that was created
cis, err := mgr2.ListChannels()
require.NoError(t, err)
require.Len(t, cis, 1)
require.Equal(t, ch, cis[0])
// Amount should include amount for successful add funds msg
ci, err := mgr2.GetChannelInfo(ch)
require.NoError(t, err)
require.Equal(t, amt.Int64()+amt2.Int64(), ci.Amount.Int64())
require.EqualValues(t, 0, ci.PendingAmount.Int64())
require.Nil(t, ci.CreateMsg)
require.Nil(t, ci.AddFundsMsg)
}
// TestPaychGetWait tests that GetPaychWaitReady correctly waits for the
// channel to be created or funds to be added
func TestPaychGetWait(t *testing.T) {
ctx := context.Background()
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
from := tutils.NewIDAddr(t, 101)
to := tutils.NewIDAddr(t, 102)
sm := newMockStateManager()
pchapi := newMockPaychAPI()
defer pchapi.close()
mgr, err := newManager(sm, store, pchapi)
require.NoError(t, err)
// 1. Get
amt := big.NewInt(10)
_, mcid, err := mgr.GetPaych(ctx, from, to, amt)
require.NoError(t, err)
done := make(chan address.Address)
go func() {
// 2. Wait till ready
ch, err := mgr.GetPaychWaitReady(ctx, mcid)
require.NoError(t, err)
done <- ch
}()
time.Sleep(time.Millisecond * 10)
// 3. Send response
expch := tutils.NewIDAddr(t, 100)
response := testChannelResponse(t, expch)
pchapi.finishWaitingCalls(response)
time.Sleep(time.Millisecond * 10)
ch := <-done
require.Equal(t, expch, ch)
// 4. Wait again - message has already been received so should
// return immediately
ch, err = mgr.GetPaychWaitReady(ctx, mcid)
require.NoError(t, err)
require.Equal(t, expch, ch)
// Request add funds
amt2 := big.NewInt(15)
_, addFundsMsgCid, err := mgr.GetPaych(ctx, from, to, amt2)
require.NoError(t, err)
time.Sleep(time.Millisecond * 10)
go func() {
// 5. Wait for add funds
ch, err := mgr.GetPaychWaitReady(ctx, addFundsMsgCid)
require.NoError(t, err)
require.Equal(t, expch, ch)
done <- ch
}()
time.Sleep(time.Millisecond * 10)
// 6. Send add funds response
addFundsResponse := types.MessageReceipt{
ExitCode: 0,
Return: []byte{},
}
pchapi.finishWaitingCalls(addFundsResponse)
<-done
}
// TestPaychGetWaitErr tests that GetPaychWaitReady correctly handles errors
func TestPaychGetWaitErr(t *testing.T) {
ctx := context.Background()
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
from := tutils.NewIDAddr(t, 101)
to := tutils.NewIDAddr(t, 102)
sm := newMockStateManager()
pchapi := newMockPaychAPI()
defer pchapi.close()
mgr, err := newManager(sm, store, pchapi)
require.NoError(t, err)
// 1. Create channel
amt := big.NewInt(10)
_, mcid, err := mgr.GetPaych(ctx, from, to, amt)
require.NoError(t, err)
done := make(chan address.Address)
go func() {
defer close(done)
// 2. Wait for channel to be created
_, err := mgr.GetPaychWaitReady(ctx, mcid)
// 4. Channel creation should have failed
require.NotNil(t, err)
// 5. Call wait again with the same message CID
_, err = mgr.GetPaychWaitReady(ctx, mcid)
// 6. Should return immediately with the same error
require.NotNil(t, err)
}()
// Give the wait a moment to start before sending response
time.Sleep(time.Millisecond * 10)
// 3. Send error response to create channel
response := types.MessageReceipt{
ExitCode: 1, // error
Return: []byte{},
}
pchapi.finishWaitingCalls(response)
<-done
}
// TestPaychGetWaitCtx tests that GetPaychWaitReady returns early if the context
// is cancelled
func TestPaychGetWaitCtx(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
from := tutils.NewIDAddr(t, 101)
to := tutils.NewIDAddr(t, 102)
sm := newMockStateManager()
pchapi := newMockPaychAPI()
defer pchapi.close()
mgr, err := newManager(sm, store, pchapi)
require.NoError(t, err)
amt := big.NewInt(10)
_, mcid, err := mgr.GetPaych(ctx, from, to, amt)
require.NoError(t, err)
// When the context is cancelled, should unblock wait
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
_, err = mgr.GetPaychWaitReady(ctx, mcid)
require.Error(t, ctx.Err(), err)
}

72
paychmgr/settle_test.go Normal file
View File

@ -0,0 +1,72 @@
package paychmgr
import (
"context"
"testing"
"time"
"github.com/ipfs/go-cid"
"github.com/filecoin-project/specs-actors/actors/abi/big"
tutils "github.com/filecoin-project/specs-actors/support/testing"
ds "github.com/ipfs/go-datastore"
ds_sync "github.com/ipfs/go-datastore/sync"
"github.com/stretchr/testify/require"
)
func TestPaychSettle(t *testing.T) {
ctx := context.Background()
store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore()))
expch := tutils.NewIDAddr(t, 100)
expch2 := tutils.NewIDAddr(t, 101)
from := tutils.NewIDAddr(t, 101)
to := tutils.NewIDAddr(t, 102)
sm := newMockStateManager()
pchapi := newMockPaychAPI()
defer pchapi.close()
mgr, err := newManager(sm, store, pchapi)
require.NoError(t, err)
amt := big.NewInt(10)
_, mcid, err := mgr.GetPaych(ctx, from, to, amt)
require.NoError(t, err)
time.Sleep(10 * time.Millisecond)
// Send channel create response
response := testChannelResponse(t, expch)
pchapi.finishWaitingCalls(response)
// Get the channel address
ch, err := mgr.GetPaychWaitReady(ctx, mcid)
require.NoError(t, err)
require.Equal(t, expch, ch)
// Settle the channel
_, err = mgr.Settle(ctx, ch)
require.NoError(t, err)
// Send another request for funds to the same from/to
// (should create a new channel because the previous channel
// is settling)
amt2 := big.NewInt(5)
_, mcid2, err := mgr.GetPaych(ctx, from, to, amt2)
require.NoError(t, err)
require.NotEqual(t, cid.Undef, mcid2)
time.Sleep(10 * time.Millisecond)
// Send new channel create response
response2 := testChannelResponse(t, expch2)
pchapi.finishWaitingCalls(response2)
time.Sleep(10 * time.Millisecond)
// Make sure the new channel is different from the old channel
ch2, err := mgr.GetPaychWaitReady(ctx, mcid2)
require.NoError(t, err)
require.NotEqual(t, ch, ch2)
}

View File

@ -3,6 +3,11 @@ package paychmgr
import (
"bytes"
"context"
"fmt"
"github.com/filecoin-project/lotus/api"
"github.com/filecoin-project/specs-actors/actors/abi/big"
"github.com/filecoin-project/specs-actors/actors/builtin"
init_ "github.com/filecoin-project/specs-actors/actors/builtin/init"
@ -17,7 +22,211 @@ import (
"github.com/filecoin-project/lotus/chain/types"
)
func (pm *Manager) createPaych(ctx context.Context, from, to address.Address, amt types.BigInt) (cid.Cid, error) {
type paychApi interface {
StateWaitMsg(ctx context.Context, msg cid.Cid, confidence uint64) (*api.MsgLookup, error)
MpoolPushMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error)
}
// paychFundsRes is the response to a create channel or add funds request
type paychFundsRes struct {
channel address.Address
mcid cid.Cid
err error
}
type onCompleteFn func(*paychFundsRes)
// fundsReq is a request to create a channel or add funds to a channel
type fundsReq struct {
ctx context.Context
from address.Address
to address.Address
amt types.BigInt
onComplete onCompleteFn
}
// getPaych ensures that a channel exists between the from and to addresses,
// and adds the given amount of funds.
// If the channel does not exist a create channel message is sent and the
// message CID is returned.
// If the channel does exist an add funds message is sent and both the channel
// address and message CID are returned.
// If there is an in progress operation (create channel / add funds), getPaych
// blocks until the previous operation completes, then returns both the channel
// address and the CID of the new add funds message.
// If an operation returns an error, subsequent waiting operations will still
// be attempted.
func (ca *channelAccessor) getPaych(ctx context.Context, from, to address.Address, amt types.BigInt) (address.Address, cid.Cid, error) {
// Add the request to add funds to a queue and wait for the result
promise := ca.enqueue(&fundsReq{ctx: ctx, from: from, to: to, amt: amt})
select {
case res := <-promise:
return res.channel, res.mcid, res.err
case <-ctx.Done():
return address.Undef, cid.Undef, ctx.Err()
}
}
// Queue up an add funds operation
func (ca *channelAccessor) enqueue(task *fundsReq) chan *paychFundsRes {
promise := make(chan *paychFundsRes)
task.onComplete = func(res *paychFundsRes) {
select {
case <-task.ctx.Done():
case promise <- res:
}
}
ca.lk.Lock()
defer ca.lk.Unlock()
ca.fundsReqQueue = append(ca.fundsReqQueue, task)
go ca.processNextQueueItem()
return promise
}
// Run the operation at the head of the queue
func (ca *channelAccessor) processNextQueueItem() {
ca.lk.Lock()
defer ca.lk.Unlock()
if len(ca.fundsReqQueue) == 0 {
return
}
head := ca.fundsReqQueue[0]
res := ca.processTask(head.ctx, head.from, head.to, head.amt, head.onComplete)
// If the task is waiting on an external event (eg something to appear on
// chain) it will return nil
if res == nil {
// Stop processing the fundsReqQueue and wait. When the event occurs it will
// call processNextQueueItem() again
return
}
// The task has finished processing so clean it up
ca.fundsReqQueue[0] = nil // allow GC of element
ca.fundsReqQueue = ca.fundsReqQueue[1:]
// Call the task callback with its results
head.onComplete(res)
// Process the next task
if len(ca.fundsReqQueue) > 0 {
go ca.processNextQueueItem()
}
}
// msgWaitComplete is called when the message for a previous task is confirmed
// or there is an error.
func (ca *channelAccessor) msgWaitComplete(mcid cid.Cid, err error, cb onCompleteFn) {
ca.lk.Lock()
defer ca.lk.Unlock()
// Save the message result to the store
dserr := ca.store.SaveMessageResult(mcid, err)
if dserr != nil {
log.Errorf("saving message result: %s", dserr)
}
// Call the onComplete callback
ca.callOnComplete(mcid, err, cb)
// Inform listeners that the message has completed
ca.msgListeners.fireMsgComplete(mcid, err)
// The queue may have been waiting for msg completion to proceed, so
// process the next queue item
if len(ca.fundsReqQueue) > 0 {
go ca.processNextQueueItem()
}
}
// callOnComplete calls the onComplete callback for a task
func (ca *channelAccessor) callOnComplete(mcid cid.Cid, err error, cb onCompleteFn) {
if cb == nil {
return
}
if err != nil {
go cb(&paychFundsRes{err: err})
return
}
// Get the channel address
ci, storeErr := ca.store.ByMessageCid(mcid)
if storeErr != nil {
log.Errorf("getting channel by message cid: %s", err)
go cb(&paychFundsRes{err: storeErr})
return
}
if ci.Channel == nil {
panic("channel address is nil when calling onComplete callback")
}
go cb(&paychFundsRes{channel: *ci.Channel, mcid: mcid, err: err})
}
// processTask checks the state of the channel and takes appropriate action
// (see description of getPaych).
// Note that processTask may be called repeatedly in the same state, and should
// return nil if there is no state change to be made (eg when waiting for a
// message to be confirmed on chain)
func (ca *channelAccessor) processTask(
ctx context.Context,
from address.Address,
to address.Address,
amt types.BigInt,
onComplete onCompleteFn,
) *paychFundsRes {
// Get the payment channel for the from/to addresses.
// Note: It's ok if we get ErrChannelNotTracked. It just means we need to
// create a channel.
channelInfo, err := ca.store.OutboundActiveByFromTo(from, to)
if err != nil && err != ErrChannelNotTracked {
return &paychFundsRes{err: err}
}
// If a channel has not yet been created, create one.
// Note that if the previous attempt to create the channel failed because of a VM error
// (eg not enough gas), both channelInfo.Channel and channelInfo.CreateMsg will be nil.
if channelInfo == nil || channelInfo.Channel == nil && channelInfo.CreateMsg == nil {
mcid, err := ca.createPaych(ctx, from, to, amt, onComplete)
if err != nil {
return &paychFundsRes{err: err}
}
return &paychFundsRes{mcid: mcid}
}
// If the create channel message has been sent but the channel hasn't
// been created on chain yet
if channelInfo.CreateMsg != nil {
// Wait for the channel to be created before trying again
return nil
}
// If an add funds message was sent to the chain but hasn't been confirmed
// on chain yet
if channelInfo.AddFundsMsg != nil {
// Wait for the add funds message to be confirmed before trying again
return nil
}
// We need to add more funds, so send an add funds message to
// cover the amount for this request
mcid, err := ca.addFunds(ctx, from, to, amt, onComplete)
if err != nil {
return &paychFundsRes{err: err}
}
return &paychFundsRes{channel: *channelInfo.Channel, mcid: *mcid}
}
// createPaych sends a message to create the channel and returns the message cid
func (ca *channelAccessor) createPaych(ctx context.Context, from, to address.Address, amt types.BigInt, cb onCompleteFn) (cid.Cid, error) {
params, aerr := actors.SerializeParams(&paych.ConstructorParams{From: from, To: to})
if aerr != nil {
return cid.Undef, aerr
@ -41,106 +250,268 @@ func (pm *Manager) createPaych(ctx context.Context, from, to address.Address, am
GasPrice: types.NewInt(0),
}
smsg, err := pm.mpool.MpoolPushMessage(ctx, msg)
smsg, err := ca.api.MpoolPushMessage(ctx, msg)
if err != nil {
return cid.Undef, xerrors.Errorf("initializing paych actor: %w", err)
}
mcid := smsg.Cid()
go pm.waitForPaychCreateMsg(ctx, mcid)
// Create a new channel in the store
if _, err := ca.store.createChannel(from, to, mcid, amt); err != nil {
log.Errorf("creating channel: %s", err)
return cid.Undef, err
}
// Wait for the channel to be created on chain
go ca.waitForPaychCreateMsg(from, to, mcid, cb)
return mcid, nil
}
// WaitForPaychCreateMsg waits for mcid to appear on chain and returns the robust address of the
// waitForPaychCreateMsg waits for mcid to appear on chain and stores the robust address of the
// created payment channel
// TODO: wait outside the store lock!
// (tricky because we need to setup channel tracking before we know its address)
func (pm *Manager) waitForPaychCreateMsg(ctx context.Context, mcid cid.Cid) {
defer pm.store.lk.Unlock()
mwait, err := pm.state.StateWaitMsg(ctx, mcid, build.MessageConfidence)
func (ca *channelAccessor) waitForPaychCreateMsg(from address.Address, to address.Address, mcid cid.Cid, cb onCompleteFn) {
err := ca.waitPaychCreateMsg(from, to, mcid)
ca.msgWaitComplete(mcid, err, cb)
}
func (ca *channelAccessor) waitPaychCreateMsg(from address.Address, to address.Address, mcid cid.Cid) error {
mwait, err := ca.api.StateWaitMsg(ca.waitCtx, mcid, build.MessageConfidence)
if err != nil {
log.Errorf("wait msg: %w", err)
return
return err
}
if mwait.Receipt.ExitCode != 0 {
log.Errorf("payment channel creation failed (exit code %d)", mwait.Receipt.ExitCode)
return
err := xerrors.Errorf("payment channel creation failed (exit code %d)", mwait.Receipt.ExitCode)
log.Error(err)
ca.lk.Lock()
defer ca.lk.Unlock()
ca.mutateChannelInfo(from, to, func(channelInfo *ChannelInfo) {
channelInfo.PendingAmount = big.NewInt(0)
channelInfo.CreateMsg = nil
})
return err
}
var decodedReturn init_.ExecReturn
err = decodedReturn.UnmarshalCBOR(bytes.NewReader(mwait.Receipt.Return))
if err != nil {
log.Error(err)
return
}
paychaddr := decodedReturn.RobustAddress
ci, err := pm.loadStateChannelInfo(ctx, paychaddr, DirOutbound)
if err != nil {
log.Errorf("loading channel info: %w", err)
return
return err
}
if err := pm.store.trackChannel(ci); err != nil {
log.Errorf("tracking channel: %w", err)
}
ca.lk.Lock()
defer ca.lk.Unlock()
// Store robust address of channel
ca.mutateChannelInfo(from, to, func(channelInfo *ChannelInfo) {
channelInfo.Channel = &decodedReturn.RobustAddress
channelInfo.Amount = channelInfo.PendingAmount
channelInfo.PendingAmount = big.NewInt(0)
channelInfo.CreateMsg = nil
})
return nil
}
func (pm *Manager) addFunds(ctx context.Context, ch address.Address, from address.Address, amt types.BigInt) (cid.Cid, error) {
// addFunds sends a message to add funds to the channel and returns the message cid
func (ca *channelAccessor) addFunds(ctx context.Context, from address.Address, to address.Address, amt types.BigInt, cb onCompleteFn) (*cid.Cid, error) {
channelInfo, err := ca.store.OutboundActiveByFromTo(from, to)
if err != nil {
return nil, err
}
msg := &types.Message{
To: ch,
From: from,
To: *channelInfo.Channel,
From: channelInfo.Control,
Value: amt,
Method: 0,
GasLimit: 0,
GasPrice: types.NewInt(0),
}
smsg, err := pm.mpool.MpoolPushMessage(ctx, msg)
smsg, err := ca.api.MpoolPushMessage(ctx, msg)
if err != nil {
return cid.Undef, err
return nil, err
}
mcid := smsg.Cid()
go pm.waitForAddFundsMsg(ctx, mcid)
return mcid, nil
// Store the add funds message CID on the channel
ca.mutateChannelInfo(from, to, func(ci *ChannelInfo) {
ci.PendingAmount = amt
ci.AddFundsMsg = &mcid
})
// Store a reference from the message CID to the channel, so that we can
// look up the channel from the message CID
err = ca.store.SaveNewMessage(channelInfo.ChannelID, mcid)
if err != nil {
log.Errorf("saving add funds message CID %s: %s", mcid, err)
}
go ca.waitForAddFundsMsg(from, to, mcid, cb)
return &mcid, nil
}
// WaitForAddFundsMsg waits for mcid to appear on chain and returns error, if any
// TODO: wait outside the store lock!
// (tricky because we need to setup channel tracking before we know it's address)
func (pm *Manager) waitForAddFundsMsg(ctx context.Context, mcid cid.Cid) {
defer pm.store.lk.Unlock()
mwait, err := pm.state.StateWaitMsg(ctx, mcid, build.MessageConfidence)
// waitForAddFundsMsg waits for mcid to appear on chain and returns error, if any
func (ca *channelAccessor) waitForAddFundsMsg(from address.Address, to address.Address, mcid cid.Cid, cb onCompleteFn) {
err := ca.waitAddFundsMsg(from, to, mcid)
ca.msgWaitComplete(mcid, err, cb)
}
func (ca *channelAccessor) waitAddFundsMsg(from address.Address, to address.Address, mcid cid.Cid) error {
mwait, err := ca.api.StateWaitMsg(ca.waitCtx, mcid, build.MessageConfidence)
if err != nil {
log.Error(err)
return err
}
if mwait.Receipt.ExitCode != 0 {
log.Errorf("voucher channel creation failed: adding funds (exit code %d)", mwait.Receipt.ExitCode)
err := xerrors.Errorf("voucher channel creation failed: adding funds (exit code %d)", mwait.Receipt.ExitCode)
log.Error(err)
ca.lk.Lock()
defer ca.lk.Unlock()
ca.mutateChannelInfo(from, to, func(channelInfo *ChannelInfo) {
channelInfo.PendingAmount = big.NewInt(0)
channelInfo.AddFundsMsg = nil
})
return err
}
ca.lk.Lock()
defer ca.lk.Unlock()
// Store updated amount
ca.mutateChannelInfo(from, to, func(channelInfo *ChannelInfo) {
channelInfo.Amount = types.BigAdd(channelInfo.Amount, channelInfo.PendingAmount)
channelInfo.PendingAmount = big.NewInt(0)
channelInfo.AddFundsMsg = nil
})
return nil
}
// Change the state of the channel in the store
func (ca *channelAccessor) mutateChannelInfo(from address.Address, to address.Address, mutate func(*ChannelInfo)) {
channelInfo, err := ca.store.OutboundActiveByFromTo(from, to)
// If there's an error reading or writing to the store just log an error.
// For now we're assuming it's unlikely to happen in practice.
// Later we may want to implement a transactional approach, whereby
// we record to the store that we're going to send a message, send
// the message, and then record that the message was sent.
if err != nil {
log.Errorf("Error reading channel info from store: %s", err)
}
mutate(channelInfo)
err = ca.store.putChannelInfo(channelInfo)
if err != nil {
log.Errorf("Error writing channel info to store: %s", err)
}
}
func (pm *Manager) GetPaych(ctx context.Context, from, to address.Address, ensureFree types.BigInt) (address.Address, cid.Cid, error) {
pm.store.lk.Lock() // unlock only on err; wait funcs will defer unlock
var mcid cid.Cid
ch, err := pm.store.findChan(func(ci *ChannelInfo) bool {
if ci.Direction != DirOutbound {
return false
// getPaychWaitReady waits for a the response to the message with the given cid
func (ca *channelAccessor) getPaychWaitReady(ctx context.Context, mcid cid.Cid) (address.Address, error) {
ca.lk.Lock()
// First check if the message has completed
msgInfo, err := ca.store.GetMessage(mcid)
if err != nil {
ca.lk.Unlock()
return address.Undef, err
}
// If the create channel / add funds message failed, return an error
if len(msgInfo.Err) > 0 {
ca.lk.Unlock()
return address.Undef, xerrors.New(msgInfo.Err)
}
// If the message has completed successfully
if msgInfo.Received {
ca.lk.Unlock()
// Get the channel address
ci, err := ca.store.ByMessageCid(mcid)
if err != nil {
return address.Undef, err
}
return ci.Control == from && ci.Target == to
})
if err != nil {
pm.store.lk.Unlock()
return address.Undef, cid.Undef, xerrors.Errorf("findChan: %w", err)
if ci.Channel == nil {
panic(fmt.Sprintf("create / add funds message %s succeeded but channelInfo.Channel is nil", mcid))
}
return *ci.Channel, nil
}
if ch != address.Undef {
// TODO: Track available funds
mcid, err = pm.addFunds(ctx, ch, from, ensureFree)
} else {
mcid, err = pm.createPaych(ctx, from, to, ensureFree)
// The message hasn't completed yet so wait for it to complete
promise := ca.msgPromise(ctx, mcid)
// Unlock while waiting
ca.lk.Unlock()
select {
case res := <-promise:
return res.channel, res.err
case <-ctx.Done():
return address.Undef, ctx.Err()
}
if err != nil {
pm.store.lk.Unlock()
}
return ch, mcid, err
}
type onMsgRes struct {
channel address.Address
err error
}
// msgPromise returns a channel that receives the result of the message with
// the given CID
func (ca *channelAccessor) msgPromise(ctx context.Context, mcid cid.Cid) chan onMsgRes {
promise := make(chan onMsgRes)
triggerUnsub := make(chan struct{})
sub := ca.msgListeners.onMsg(mcid, func(err error) {
close(triggerUnsub)
// Use a go-routine so as not to block the event handler loop
go func() {
res := onMsgRes{err: err}
if res.err == nil {
// Get the channel associated with the message cid
ci, err := ca.store.ByMessageCid(mcid)
if err != nil {
res.err = err
} else {
res.channel = *ci.Channel
}
}
// Pass the result to the caller
select {
case promise <- res:
case <-ctx.Done():
}
}()
})
// Unsubscribe when the message is received or the context is done
go func() {
select {
case <-ctx.Done():
case <-triggerUnsub:
}
ca.msgListeners.unsubscribe(sub)
}()
return promise
}

View File

@ -3,20 +3,21 @@ package paychmgr
import (
"context"
"github.com/filecoin-project/specs-actors/actors/abi/big"
"github.com/filecoin-project/specs-actors/actors/builtin/account"
"github.com/filecoin-project/go-address"
"github.com/filecoin-project/specs-actors/actors/builtin/paych"
xerrors "golang.org/x/xerrors"
"github.com/filecoin-project/lotus/chain/types"
)
func (pm *Manager) loadPaychState(ctx context.Context, ch address.Address) (*types.Actor, *paych.State, error) {
type stateAccessor struct {
sm StateManagerApi
}
func (ca *stateAccessor) loadPaychState(ctx context.Context, ch address.Address) (*types.Actor, *paych.State, error) {
var pcast paych.State
act, err := pm.sm.LoadActorState(ctx, ch, &pcast, nil)
act, err := ca.sm.LoadActorState(ctx, ch, &pcast, nil)
if err != nil {
return nil, nil, err
}
@ -24,26 +25,26 @@ func (pm *Manager) loadPaychState(ctx context.Context, ch address.Address) (*typ
return act, &pcast, nil
}
func (pm *Manager) loadStateChannelInfo(ctx context.Context, ch address.Address, dir uint64) (*ChannelInfo, error) {
_, st, err := pm.loadPaychState(ctx, ch)
func (ca *stateAccessor) loadStateChannelInfo(ctx context.Context, ch address.Address, dir uint64) (*ChannelInfo, error) {
_, st, err := ca.loadPaychState(ctx, ch)
if err != nil {
return nil, err
}
var account account.State
_, err = pm.sm.LoadActorState(ctx, st.From, &account, nil)
_, err = ca.sm.LoadActorState(ctx, st.From, &account, nil)
if err != nil {
return nil, err
}
from := account.Address
_, err = pm.sm.LoadActorState(ctx, st.To, &account, nil)
_, err = ca.sm.LoadActorState(ctx, st.To, &account, nil)
if err != nil {
return nil, err
}
to := account.Address
ci := &ChannelInfo{
Channel: ch,
Channel: &ch,
Direction: dir,
NextLane: nextLaneFromState(st),
}
@ -72,80 +73,3 @@ func nextLaneFromState(st *paych.State) uint64 {
}
return maxLane + 1
}
// laneState gets the LaneStates from chain, then applies all vouchers in
// the data store over the chain state
func (pm *Manager) laneState(state *paych.State, ch address.Address) (map[uint64]*paych.LaneState, error) {
// TODO: we probably want to call UpdateChannelState with all vouchers to be fully correct
// (but technically dont't need to)
laneStates := make(map[uint64]*paych.LaneState, len(state.LaneStates))
// Get the lane state from the chain
for _, laneState := range state.LaneStates {
laneStates[laneState.ID] = laneState
}
// Apply locally stored vouchers
vouchers, err := pm.store.VouchersForPaych(ch)
if err != nil && err != ErrChannelNotTracked {
return nil, err
}
for _, v := range vouchers {
for range v.Voucher.Merges {
return nil, xerrors.Errorf("paych merges not handled yet")
}
// If there's a voucher for a lane that isn't in chain state just
// create it
ls, ok := laneStates[v.Voucher.Lane]
if !ok {
ls = &paych.LaneState{
ID: v.Voucher.Lane,
Redeemed: types.NewInt(0),
Nonce: 0,
}
laneStates[v.Voucher.Lane] = ls
}
if v.Voucher.Nonce < ls.Nonce {
continue
}
ls.Nonce = v.Voucher.Nonce
ls.Redeemed = v.Voucher.Amount
}
return laneStates, nil
}
// Get the total redeemed amount across all lanes, after applying the voucher
func (pm *Manager) totalRedeemedWithVoucher(laneStates map[uint64]*paych.LaneState, sv *paych.SignedVoucher) (big.Int, error) {
// TODO: merges
if len(sv.Merges) != 0 {
return big.Int{}, xerrors.Errorf("dont currently support paych lane merges")
}
total := big.NewInt(0)
for _, ls := range laneStates {
total = big.Add(total, ls.Redeemed)
}
lane, ok := laneStates[sv.Lane]
if ok {
// If the voucher is for an existing lane, and the voucher nonce
// and is higher than the lane nonce
if sv.Nonce > lane.Nonce {
// Add the delta between the redeemed amount and the voucher
// amount to the total
delta := big.Sub(sv.Amount, lane.Redeemed)
total = big.Add(total, delta)
}
} else {
// If the voucher is *not* for an existing lane, just add its
// value (implicitly a new lane will be created for the voucher)
total = big.Add(total, sv.Amount)
}
return total, nil
}

View File

@ -4,14 +4,16 @@ import (
"bytes"
"errors"
"fmt"
"strings"
"sync"
"github.com/google/uuid"
"github.com/filecoin-project/lotus/chain/types"
"github.com/filecoin-project/specs-actors/actors/builtin/paych"
"github.com/ipfs/go-cid"
"github.com/ipfs/go-datastore"
"github.com/ipfs/go-datastore/namespace"
dsq "github.com/ipfs/go-datastore/query"
"golang.org/x/xerrors"
"github.com/filecoin-project/go-address"
cborrpc "github.com/filecoin-project/go-cbor-util"
@ -22,8 +24,6 @@ import (
var ErrChannelNotTracked = errors.New("channel not tracked")
type Store struct {
lk sync.Mutex // TODO: this can be split per paych
ds datastore.Batching
}
@ -39,85 +39,107 @@ const (
DirOutbound = 2
)
const (
dsKeyChannelInfo = "ChannelInfo"
dsKeyMsgCid = "MsgCid"
)
type VoucherInfo struct {
Voucher *paych.SignedVoucher
Proof []byte
}
// ChannelInfo keeps track of information about a channel
type ChannelInfo struct {
Channel address.Address
// ChannelID is a uuid set at channel creation
ChannelID string
// Channel address - may be nil if the channel hasn't been created yet
Channel *address.Address
// Control is the address of the account that created the channel
Control address.Address
Target address.Address
// Target is the address of the account on the other end of the channel
Target address.Address
// Direction indicates if the channel is inbound (this node is the Target)
// or outbound (this node is the Control)
Direction uint64
Vouchers []*VoucherInfo
NextLane uint64
// Vouchers is a list of all vouchers sent on the channel
Vouchers []*VoucherInfo
// NextLane is the number of the next lane that should be used when the
// client requests a new lane (eg to create a voucher for a new deal)
NextLane uint64
// Amount added to the channel.
// Note: This amount is only used by GetPaych to keep track of how much
// has locally been added to the channel. It should reflect the channel's
// Balance on chain as long as all operations occur on the same datastore.
Amount types.BigInt
// PendingAmount is the amount that we're awaiting confirmation of
PendingAmount types.BigInt
// CreateMsg is the CID of a pending create message (while waiting for confirmation)
CreateMsg *cid.Cid
// AddFundsMsg is the CID of a pending add funds message (while waiting for confirmation)
AddFundsMsg *cid.Cid
// Settling indicates whether the channel has entered into the settling state
Settling bool
}
func dskeyForChannel(addr address.Address) datastore.Key {
return datastore.NewKey(addr.String())
}
func (ps *Store) putChannelInfo(ci *ChannelInfo) error {
k := dskeyForChannel(ci.Channel)
b, err := cborrpc.Dump(ci)
if err != nil {
return err
}
return ps.ds.Put(k, b)
}
func (ps *Store) getChannelInfo(addr address.Address) (*ChannelInfo, error) {
k := dskeyForChannel(addr)
b, err := ps.ds.Get(k)
if err == datastore.ErrNotFound {
return nil, ErrChannelNotTracked
}
if err != nil {
return nil, err
}
var ci ChannelInfo
if err := ci.UnmarshalCBOR(bytes.NewReader(b)); err != nil {
return nil, err
}
return &ci, nil
}
func (ps *Store) TrackChannel(ch *ChannelInfo) error {
ps.lk.Lock()
defer ps.lk.Unlock()
return ps.trackChannel(ch)
}
func (ps *Store) trackChannel(ch *ChannelInfo) error {
_, err := ps.getChannelInfo(ch.Channel)
// TrackChannel stores a channel, returning an error if the channel was already
// being tracked
func (ps *Store) TrackChannel(ci *ChannelInfo) error {
_, err := ps.ByAddress(*ci.Channel)
switch err {
default:
return err
case nil:
return fmt.Errorf("already tracking channel: %s", ch.Channel)
return fmt.Errorf("already tracking channel: %s", ci.Channel)
case ErrChannelNotTracked:
return ps.putChannelInfo(ch)
return ps.putChannelInfo(ci)
}
}
// ListChannels returns the addresses of all channels that have been created
func (ps *Store) ListChannels() ([]address.Address, error) {
ps.lk.Lock()
defer ps.lk.Unlock()
cis, err := ps.findChans(func(ci *ChannelInfo) bool {
return ci.Channel != nil
}, 0)
if err != nil {
return nil, err
}
res, err := ps.ds.Query(dsq.Query{KeysOnly: true})
addrs := make([]address.Address, 0, len(cis))
for _, ci := range cis {
addrs = append(addrs, *ci.Channel)
}
return addrs, nil
}
// findChan finds a single channel using the given filter.
// If there isn't a channel that matches the filter, returns ErrChannelNotTracked
func (ps *Store) findChan(filter func(ci *ChannelInfo) bool) (*ChannelInfo, error) {
cis, err := ps.findChans(filter, 1)
if err != nil {
return nil, err
}
if len(cis) == 0 {
return nil, ErrChannelNotTracked
}
return &cis[0], err
}
// findChans loops over all channels, only including those that pass the filter.
// max is the maximum number of channels to return. Set to zero to return unlimited channels.
func (ps *Store) findChans(filter func(*ChannelInfo) bool, max int) ([]ChannelInfo, error) {
res, err := ps.ds.Query(dsq.Query{Prefix: dsKeyChannelInfo})
if err != nil {
return nil, err
}
defer res.Close() //nolint:errcheck
var out []address.Address
var stored ChannelInfo
var matches []ChannelInfo
for {
res, ok := res.NextSync()
if !ok {
@ -128,60 +150,31 @@ func (ps *Store) ListChannels() ([]address.Address, error) {
return nil, err
}
addr, err := address.NewFromString(strings.TrimPrefix(res.Key, "/"))
ci, err := unmarshallChannelInfo(&stored, res)
if err != nil {
return nil, xerrors.Errorf("failed reading paych key (%q) from datastore: %w", res.Key, err)
return nil, err
}
out = append(out, addr)
}
return out, nil
}
func (ps *Store) findChan(filter func(*ChannelInfo) bool) (address.Address, error) {
res, err := ps.ds.Query(dsq.Query{})
if err != nil {
return address.Undef, err
}
defer res.Close() //nolint:errcheck
var ci ChannelInfo
for {
res, ok := res.NextSync()
if !ok {
break
}
if res.Error != nil {
return address.Undef, err
}
if err := ci.UnmarshalCBOR(bytes.NewReader(res.Value)); err != nil {
return address.Undef, err
}
if !filter(&ci) {
if !filter(ci) {
continue
}
addr, err := address.NewFromString(strings.TrimPrefix(res.Key, "/"))
if err != nil {
return address.Undef, xerrors.Errorf("failed reading paych key (%q) from datastore: %w", res.Key, err)
}
matches = append(matches, *ci)
return addr, nil
// If we've reached the maximum number of matches, return.
// Note that if max is zero we return an unlimited number of matches
// because len(matches) will always be at least 1.
if len(matches) == max {
return matches, nil
}
}
return address.Undef, nil
return matches, nil
}
// AllocateLane allocates a new lane for the given channel
func (ps *Store) AllocateLane(ch address.Address) (uint64, error) {
ps.lk.Lock()
defer ps.lk.Unlock()
ci, err := ps.getChannelInfo(ch)
ci, err := ps.ByAddress(ch)
if err != nil {
return 0, err
}
@ -192,11 +185,210 @@ func (ps *Store) AllocateLane(ch address.Address) (uint64, error) {
return out, ps.putChannelInfo(ci)
}
// VouchersForPaych gets the vouchers for the given channel
func (ps *Store) VouchersForPaych(ch address.Address) ([]*VoucherInfo, error) {
ci, err := ps.getChannelInfo(ch)
ci, err := ps.ByAddress(ch)
if err != nil {
return nil, err
}
return ci.Vouchers, nil
}
// ByAddress gets the channel that matches the given address
func (ps *Store) ByAddress(addr address.Address) (*ChannelInfo, error) {
return ps.findChan(func(ci *ChannelInfo) bool {
return ci.Channel != nil && *ci.Channel == addr
})
}
// MsgInfo stores information about a create channel / add funds message
// that has been sent
type MsgInfo struct {
// ChannelID links the message to a channel
ChannelID string
// MsgCid is the CID of the message
MsgCid cid.Cid
// Received indicates whether a response has been received
Received bool
// Err is the error received in the response
Err string
}
// The datastore key used to identify the message
func dskeyForMsg(mcid cid.Cid) datastore.Key {
return datastore.KeyWithNamespaces([]string{dsKeyMsgCid, mcid.String()})
}
// SaveNewMessage is called when a message is sent
func (ps *Store) SaveNewMessage(channelID string, mcid cid.Cid) error {
k := dskeyForMsg(mcid)
b, err := cborrpc.Dump(&MsgInfo{ChannelID: channelID, MsgCid: mcid})
if err != nil {
return err
}
return ps.ds.Put(k, b)
}
// SaveMessageResult is called when the result of a message is received
func (ps *Store) SaveMessageResult(mcid cid.Cid, msgErr error) error {
minfo, err := ps.GetMessage(mcid)
if err != nil {
return err
}
k := dskeyForMsg(mcid)
minfo.Received = true
if msgErr != nil {
minfo.Err = msgErr.Error()
}
b, err := cborrpc.Dump(minfo)
if err != nil {
return err
}
return ps.ds.Put(k, b)
}
// ByMessageCid gets the channel associated with a message
func (ps *Store) ByMessageCid(mcid cid.Cid) (*ChannelInfo, error) {
minfo, err := ps.GetMessage(mcid)
if err != nil {
return nil, err
}
ci, err := ps.findChan(func(ci *ChannelInfo) bool {
return ci.ChannelID == minfo.ChannelID
})
if err != nil {
return nil, err
}
return ci, err
}
// GetMessage gets the message info for a given message CID
func (ps *Store) GetMessage(mcid cid.Cid) (*MsgInfo, error) {
k := dskeyForMsg(mcid)
val, err := ps.ds.Get(k)
if err != nil {
return nil, err
}
var minfo MsgInfo
if err := minfo.UnmarshalCBOR(bytes.NewReader(val)); err != nil {
return nil, err
}
return &minfo, nil
}
// OutboundActiveByFromTo looks for outbound channels that have not been
// settled, with the given from / to addresses
func (ps *Store) OutboundActiveByFromTo(from address.Address, to address.Address) (*ChannelInfo, error) {
return ps.findChan(func(ci *ChannelInfo) bool {
if ci.Direction != DirOutbound {
return false
}
if ci.Settling {
return false
}
return ci.Control == from && ci.Target == to
})
}
// WithPendingAddFunds is used on startup to find channels for which a
// create channel or add funds message has been sent, but lotus shut down
// before the response was received.
func (ps *Store) WithPendingAddFunds() ([]ChannelInfo, error) {
return ps.findChans(func(ci *ChannelInfo) bool {
if ci.Direction != DirOutbound {
return false
}
return ci.CreateMsg != nil || ci.AddFundsMsg != nil
}, 0)
}
// createChannel creates an outbound channel for the given from / to, ensuring
// it has a higher sequence number than any existing channel with the same from / to
func (ps *Store) createChannel(from address.Address, to address.Address, createMsgCid cid.Cid, amt types.BigInt) (*ChannelInfo, error) {
ci := &ChannelInfo{
ChannelID: uuid.New().String(),
Direction: DirOutbound,
NextLane: 0,
Control: from,
Target: to,
CreateMsg: &createMsgCid,
PendingAmount: amt,
}
// Save the new channel
err := ps.putChannelInfo(ci)
if err != nil {
return nil, err
}
// Save a reference to the create message
err = ps.SaveNewMessage(ci.ChannelID, createMsgCid)
if err != nil {
return nil, err
}
return ci, err
}
// The datastore key used to identify the channel info
func dskeyForChannel(ci *ChannelInfo) datastore.Key {
chanKey := fmt.Sprintf("%s->%s", ci.Control.String(), ci.Target.String())
return datastore.KeyWithNamespaces([]string{dsKeyChannelInfo, chanKey})
}
// putChannelInfo stores the channel info in the datastore
func (ps *Store) putChannelInfo(ci *ChannelInfo) error {
k := dskeyForChannel(ci)
b, err := marshallChannelInfo(ci)
if err != nil {
return err
}
return ps.ds.Put(k, b)
}
// TODO: This is a hack to get around not being able to CBOR marshall a nil
// address.Address. It's been fixed in address.Address but we need to wait
// for the change to propagate to specs-actors before we can remove this hack.
var emptyAddr address.Address
func init() {
addr, err := address.NewActorAddress([]byte("empty"))
if err != nil {
panic(err)
}
emptyAddr = addr
}
func marshallChannelInfo(ci *ChannelInfo) ([]byte, error) {
// See note above about CBOR marshalling address.Address
if ci.Channel == nil {
ci.Channel = &emptyAddr
}
return cborrpc.Dump(ci)
}
func unmarshallChannelInfo(stored *ChannelInfo, res dsq.Result) (*ChannelInfo, error) {
if err := stored.UnmarshalCBOR(bytes.NewReader(res.Value)); err != nil {
return nil, err
}
// See note above about CBOR marshalling address.Address
if stored.Channel != nil && *stored.Channel == emptyAddr {
stored.Channel = nil
}
return stored, nil
}

View File

@ -17,8 +17,9 @@ func TestStore(t *testing.T) {
require.NoError(t, err)
require.Len(t, addrs, 0)
ch := tutils.NewIDAddr(t, 100)
ci := &ChannelInfo{
Channel: tutils.NewIDAddr(t, 100),
Channel: &ch,
Control: tutils.NewIDAddr(t, 101),
Target: tutils.NewIDAddr(t, 102),
@ -26,8 +27,9 @@ func TestStore(t *testing.T) {
Vouchers: []*VoucherInfo{{Voucher: nil, Proof: []byte{}}},
}
ch2 := tutils.NewIDAddr(t, 200)
ci2 := &ChannelInfo{
Channel: tutils.NewIDAddr(t, 200),
Channel: &ch2,
Control: tutils.NewIDAddr(t, 201),
Target: tutils.NewIDAddr(t, 202),
@ -55,7 +57,7 @@ func TestStore(t *testing.T) {
require.Contains(t, addrsStrings(addrs), "t0200")
// Request vouchers for channel
vouchers, err := store.VouchersForPaych(ci.Channel)
vouchers, err := store.VouchersForPaych(*ci.Channel)
require.NoError(t, err)
require.Len(t, vouchers, 1)
@ -64,12 +66,12 @@ func TestStore(t *testing.T) {
require.Equal(t, err, ErrChannelNotTracked)
// Allocate lane for channel
lane, err := store.AllocateLane(ci.Channel)
lane, err := store.AllocateLane(*ci.Channel)
require.NoError(t, err)
require.Equal(t, lane, uint64(0))
// Allocate next lane for channel
lane, err = store.AllocateLane(ci.Channel)
lane, err = store.AllocateLane(*ci.Channel)
require.NoError(t, err)
require.Equal(t, lane, uint64(1))