diff --git a/api/api_full.go b/api/api_full.go index 56a63072e..42981746f 100644 --- a/api/api_full.go +++ b/api/api_full.go @@ -377,7 +377,8 @@ type FullNode interface { // MethodGroup: Paych // The Paych methods are for interacting with and managing payment channels - PaychGet(ctx context.Context, from, to address.Address, ensureFunds types.BigInt) (*ChannelInfo, error) + PaychGet(ctx context.Context, from, to address.Address, amt types.BigInt) (*ChannelInfo, error) + PaychGetWaitReady(context.Context, cid.Cid) (address.Address, error) PaychList(context.Context) ([]address.Address, error) PaychStatus(context.Context, address.Address) (*PaychStatus, error) PaychSettle(context.Context, address.Address) (cid.Cid, error) diff --git a/api/apistruct/struct.go b/api/apistruct/struct.go index 607a54dc6..0e3af90a7 100644 --- a/api/apistruct/struct.go +++ b/api/apistruct/struct.go @@ -183,7 +183,8 @@ type FullNodeStruct struct { MarketEnsureAvailable func(context.Context, address.Address, address.Address, types.BigInt) (cid.Cid, error) `perm:"sign"` - PaychGet func(ctx context.Context, from, to address.Address, ensureFunds types.BigInt) (*api.ChannelInfo, error) `perm:"sign"` + PaychGet func(ctx context.Context, from, to address.Address, amt types.BigInt) (*api.ChannelInfo, error) `perm:"sign"` + PaychGetWaitReady func(context.Context, cid.Cid) (address.Address, error) `perm:"sign"` PaychList func(context.Context) ([]address.Address, error) `perm:"read"` PaychStatus func(context.Context, address.Address) (*api.PaychStatus, error) `perm:"read"` PaychSettle func(context.Context, address.Address) (cid.Cid, error) `perm:"sign"` @@ -803,8 +804,12 @@ func (c *FullNodeStruct) MarketEnsureAvailable(ctx context.Context, addr, wallet return c.Internal.MarketEnsureAvailable(ctx, addr, wallet, amt) } -func (c *FullNodeStruct) PaychGet(ctx context.Context, from, to address.Address, ensureFunds types.BigInt) (*api.ChannelInfo, error) { - return c.Internal.PaychGet(ctx, from, to, ensureFunds) +func (c *FullNodeStruct) PaychGet(ctx context.Context, from, to address.Address, amt types.BigInt) (*api.ChannelInfo, error) { + return c.Internal.PaychGet(ctx, from, to, amt) +} + +func (c *FullNodeStruct) PaychGetWaitReady(ctx context.Context, mcid cid.Cid) (address.Address, error) { + return c.Internal.PaychGetWaitReady(ctx, mcid) } func (c *FullNodeStruct) PaychList(ctx context.Context) ([]address.Address, error) { diff --git a/api/test/paych.go b/api/test/paych.go index 1684413a9..81348ecfa 100644 --- a/api/test/paych.go +++ b/api/test/paych.go @@ -1,18 +1,17 @@ package test import ( - "bytes" "context" "fmt" - "github.com/filecoin-project/specs-actors/actors/builtin" "os" "sync/atomic" "testing" "time" + "github.com/filecoin-project/specs-actors/actors/builtin" + "github.com/filecoin-project/specs-actors/actors/abi" "github.com/filecoin-project/specs-actors/actors/abi/big" - initactor "github.com/filecoin-project/specs-actors/actors/builtin/init" "github.com/filecoin-project/specs-actors/actors/builtin/paych" "github.com/ipfs/go-cid" @@ -28,8 +27,6 @@ import ( ) func TestPaymentChannels(t *testing.T, b APIBuilder, blocktime time.Duration) { - t.Skip("fixme") - _ = os.Setenv("BELLMAN_NO_GPU", "1") ctx := context.Background() @@ -77,13 +74,10 @@ func TestPaymentChannels(t *testing.T, b APIBuilder, blocktime time.Duration) { t.Fatal(err) } - res := waitForMessage(ctx, t, paymentCreator, channelInfo.ChannelMessage, time.Second, "channel create") - var params initactor.ExecReturn - err = params.UnmarshalCBOR(bytes.NewReader(res.Receipt.Return)) + channel, err := paymentCreator.PaychGetWaitReady(ctx, channelInfo.ChannelMessage) if err != nil { t.Fatal(err) } - channel := params.RobustAddress // allocate three lanes var lanes []uint64 @@ -129,16 +123,11 @@ func TestPaymentChannels(t *testing.T, b APIBuilder, blocktime time.Duration) { t.Fatal(err) } - res = waitForMessage(ctx, t, paymentCreator, settleMsgCid, time.Second*10, "settle") + res := waitForMessage(ctx, t, paymentCreator, settleMsgCid, time.Second*10, "settle") if res.Receipt.ExitCode != 0 { t.Fatal("Unable to settle payment channel") } - creatorPreCollectBalance, err := paymentCreator.WalletBalance(ctx, createrAddr) - if err != nil { - t.Fatal(err) - } - // wait for the receiver to submit their vouchers ev := events.NewEvents(ctx, paymentCreator) preds := state.NewStatePredicates(paymentCreator) @@ -172,24 +161,12 @@ func TestPaymentChannels(t *testing.T, b APIBuilder, blocktime time.Duration) { t.Fatal("Timed out waiting for receiver to submit vouchers") } - atomic.StoreInt64(&bm.nulls, paych.SettleDelay) + // wait for the settlement period to pass before collecting + waitForBlocks(ctx, t, bm, paymentReceiver, receiverAddr, paych.SettleDelay) - { - // wait for a block - m, err := paymentCreator.MpoolPushMessage(ctx, &types.Message{ - To: builtin.BurntFundsActorAddr, - From: createrAddr, - Value: types.NewInt(0), - GasPrice: big.Zero(), - }) - if err != nil { - t.Fatal(err) - } - - _, err = paymentCreator.StateWaitMsg(ctx, m.Cid(), 3) - if err != nil { - t.Fatal(err) - } + creatorPreCollectBalance, err := paymentCreator.WalletBalance(ctx, createrAddr) + if err != nil { + t.Fatal(err) } // collect funds (from receiver, though either party can do it) @@ -197,7 +174,7 @@ func TestPaymentChannels(t *testing.T, b APIBuilder, blocktime time.Duration) { if err != nil { t.Fatal(err) } - res, err = paymentReceiver.StateWaitMsg(ctx, collectMsg, 1) + res, err = paymentReceiver.StateWaitMsg(ctx, collectMsg, 3) if err != nil { t.Fatal(err) } @@ -213,6 +190,7 @@ func TestPaymentChannels(t *testing.T, b APIBuilder, blocktime time.Duration) { // The highest nonce voucher that the creator sent on each lane is 2000 totalVouchers := int64(len(lanes) * 2000) + // When receiver submits the tokens to the chain, creator should get a // refund on the remaining balance, which is // channel amount - total voucher value @@ -226,6 +204,36 @@ func TestPaymentChannels(t *testing.T, b APIBuilder, blocktime time.Duration) { bm.stop() } +func waitForBlocks(ctx context.Context, t *testing.T, bm *blockMiner, paymentReceiver TestNode, receiverAddr address.Address, count int) { + // We need to add null blocks in batches, if we add too many the chain can't sync + batchSize := 60 + for i := 0; i < count; i += batchSize { + size := batchSize + if i > count { + size = count - i + } + + // Add a batch of null blocks + atomic.StoreInt64(&bm.nulls, int64(size-1)) + + // Add a real block + m, err := paymentReceiver.MpoolPushMessage(ctx, &types.Message{ + To: builtin.BurntFundsActorAddr, + From: receiverAddr, + Value: types.NewInt(0), + GasPrice: big.Zero(), + }) + if err != nil { + t.Fatal(err) + } + + _, err = paymentReceiver.StateWaitMsg(ctx, m.Cid(), 1) + if err != nil { + t.Fatal(err) + } + } +} + func waitForMessage(ctx context.Context, t *testing.T, paymentCreator TestNode, msgCid cid.Cid, duration time.Duration, desc string) *api.MsgLookup { ctx, cancel := context.WithTimeout(ctx, duration) defer cancel() diff --git a/chain/vm/runtime.go b/chain/vm/runtime.go index 9985d4a73..5ab0dd3b7 100644 --- a/chain/vm/runtime.go +++ b/chain/vm/runtime.go @@ -294,7 +294,11 @@ func (rt *Runtime) CreateActor(codeID cid.Cid, address address.Address) { _ = rt.chargeGasSafe(gasOnActorExec) } -func (rt *Runtime) DeleteActor(addr address.Address) { +// DeleteActor deletes the executing actor from the state tree, transferring +// any balance to beneficiary. +// Aborts if the beneficiary does not exist. +// May only be called by the actor itself. +func (rt *Runtime) DeleteActor(beneficiary address.Address) { rt.chargeGas(rt.Pricelist().OnDeleteActor()) act, err := rt.state.GetActor(rt.Message().Receiver()) if err != nil { @@ -304,11 +308,13 @@ func (rt *Runtime) DeleteActor(addr address.Address) { panic(aerrors.Fatalf("failed to get actor: %s", err)) } if !act.Balance.IsZero() { - if err := rt.vm.transfer(rt.Message().Receiver(), builtin.BurntFundsActorAddr, act.Balance); err != nil { - panic(aerrors.Fatalf("failed to transfer balance to burnt funds actor: %s", err)) + // Transfer the executing actor's balance to the beneficiary + if err := rt.vm.transfer(rt.Message().Receiver(), beneficiary, act.Balance); err != nil { + panic(aerrors.Fatalf("failed to transfer balance to beneficiary actor: %s", err)) } } + // Delete the executing actor if err := rt.state.DeleteActor(rt.Message().Receiver()); err != nil { panic(aerrors.Fatalf("failed to delete actor: %s", err)) } diff --git a/cli/paych.go b/cli/paych.go index 969a36df6..05dc1f319 100644 --- a/cli/paych.go +++ b/cli/paych.go @@ -28,7 +28,7 @@ var paychCmd = &cli.Command{ var paychGetCmd = &cli.Command{ Name: "get", - Usage: "Create a new payment channel or get existing one", + Usage: "Create a new payment channel or get existing one and add amount to it", ArgsUsage: "[fromAddress toAddress amount]", Action: func(cctx *cli.Context) error { if cctx.Args().Len() != 3 { diff --git a/gen/main.go b/gen/main.go index 01cd756f7..1467a8943 100644 --- a/gen/main.go +++ b/gen/main.go @@ -35,6 +35,7 @@ func main() { err = gen.WriteMapEncodersToFile("./paychmgr/cbor_gen.go", "paychmgr", paychmgr.VoucherInfo{}, paychmgr.ChannelInfo{}, + paychmgr.MsgInfo{}, ) if err != nil { fmt.Println(err) diff --git a/markets/retrievaladapter/client.go b/markets/retrievaladapter/client.go index 552a1b981..59eca595c 100644 --- a/markets/retrievaladapter/client.go +++ b/markets/retrievaladapter/client.go @@ -1,21 +1,16 @@ package retrievaladapter import ( - "bytes" "context" "github.com/filecoin-project/go-address" "github.com/filecoin-project/go-fil-markets/retrievalmarket" "github.com/filecoin-project/go-fil-markets/shared" "github.com/filecoin-project/specs-actors/actors/abi" - initactor "github.com/filecoin-project/specs-actors/actors/builtin/init" "github.com/filecoin-project/specs-actors/actors/builtin/paych" - "github.com/filecoin-project/specs-actors/actors/runtime/exitcode" "github.com/ipfs/go-cid" "github.com/multiformats/go-multiaddr" - "golang.org/x/xerrors" - "github.com/filecoin-project/lotus/build" "github.com/filecoin-project/lotus/chain/types" "github.com/filecoin-project/lotus/node/impl/full" payapi "github.com/filecoin-project/lotus/node/impl/paych" @@ -76,31 +71,12 @@ func (rcn *retrievalClientNode) GetChainHead(ctx context.Context) (shared.TipSet // WaitForPaymentChannelAddFunds waits messageCID to appear on chain. If it doesn't appear within // defaultMsgWaitTimeout it returns error func (rcn *retrievalClientNode) WaitForPaymentChannelAddFunds(messageCID cid.Cid) error { - _, mr, err := rcn.chainAPI.StateManager.WaitForMessage(context.TODO(), messageCID, build.MessageConfidence) - - if err != nil { - return err - } - if mr.ExitCode != exitcode.Ok { - return xerrors.Errorf("wait for payment channel to add funds failed. exit code: %d", mr.ExitCode) - } - return nil + _, err := rcn.payAPI.PaychMgr.GetPaychWaitReady(context.TODO(), messageCID) + return err } func (rcn *retrievalClientNode) WaitForPaymentChannelCreation(messageCID cid.Cid) (address.Address, error) { - _, mr, err := rcn.chainAPI.StateManager.WaitForMessage(context.TODO(), messageCID, build.MessageConfidence) - - if err != nil { - return address.Undef, err - } - if mr.ExitCode != exitcode.Ok { - return address.Undef, xerrors.Errorf("payment channel creation failed. exit code: %d", mr.ExitCode) - } - var retval initactor.ExecReturn - if err := retval.UnmarshalCBOR(bytes.NewReader(mr.Return)); err != nil { - return address.Undef, err - } - return retval.RobustAddress, nil + return rcn.payAPI.PaychMgr.GetPaychWaitReady(context.TODO(), messageCID) } func (rcn *retrievalClientNode) GetKnownAddresses(ctx context.Context, p retrievalmarket.RetrievalPeer, encodedTs shared.TipSetToken) ([]multiaddr.Multiaddr, error) { diff --git a/node/builder.go b/node/builder.go index 2a7cdb4e2..12d8cb888 100644 --- a/node/builder.go +++ b/node/builder.go @@ -5,8 +5,6 @@ import ( "errors" "time" - "github.com/filecoin-project/lotus/markets/dealfilter" - logging "github.com/ipfs/go-log" ci "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/host" @@ -111,6 +109,7 @@ const ( HandleIncomingMessagesKey RegisterClientValidatorKey + HandlePaymentChannelManagerKey // miner GetParamsKey @@ -279,6 +278,7 @@ func Online() Option { Override(new(*paychmgr.Store), paychmgr.NewStore), Override(new(*paychmgr.Manager), paychmgr.NewManager), Override(new(*market.FundMgr), market.StartFundManager), + Override(HandlePaymentChannelManagerKey, paychmgr.HandleManager), Override(SettlePaymentChannelsKey, settler.SettlePaymentChannels), ), diff --git a/node/impl/paych/paych.go b/node/impl/paych/paych.go index c9f2f215d..8e28979f5 100644 --- a/node/impl/paych/paych.go +++ b/node/impl/paych/paych.go @@ -28,8 +28,8 @@ type PaychAPI struct { PaychMgr *paychmgr.Manager } -func (a *PaychAPI) PaychGet(ctx context.Context, from, to address.Address, ensureFunds types.BigInt) (*api.ChannelInfo, error) { - ch, mcid, err := a.PaychMgr.GetPaych(ctx, from, to, ensureFunds) +func (a *PaychAPI) PaychGet(ctx context.Context, from, to address.Address, amt types.BigInt) (*api.ChannelInfo, error) { + ch, mcid, err := a.PaychMgr.GetPaych(ctx, from, to, amt) if err != nil { return nil, err } @@ -40,6 +40,10 @@ func (a *PaychAPI) PaychGet(ctx context.Context, from, to address.Address, ensur }, nil } +func (a *PaychAPI) PaychGetWaitReady(ctx context.Context, mcid cid.Cid) (address.Address, error) { + return a.PaychMgr.GetPaychWaitReady(ctx, mcid) +} + func (a *PaychAPI) PaychAllocateLane(ctx context.Context, ch address.Address) (uint64, error) { return a.PaychMgr.AllocateLane(ch) } @@ -66,7 +70,7 @@ func (a *PaychAPI) PaychNewPayment(ctx context.Context, from, to address.Address ChannelAddr: ch.Channel, Amount: v.Amount, - Lane: uint64(lane), + Lane: lane, Extra: v.Extra, TimeLockMin: v.TimeLockMin, @@ -108,24 +112,7 @@ func (a *PaychAPI) PaychStatus(ctx context.Context, pch address.Address) (*api.P } func (a *PaychAPI) PaychSettle(ctx context.Context, addr address.Address) (cid.Cid, error) { - - ci, err := a.PaychMgr.GetChannelInfo(addr) - if err != nil { - return cid.Undef, err - } - - msg := &types.Message{ - To: addr, - From: ci.Control, - Value: types.NewInt(0), - Method: builtin.MethodsPaych.Settle, - } - smgs, err := a.MpoolPushMessage(ctx, msg) - - if err != nil { - return cid.Undef, err - } - return smgs.Cid(), nil + return a.PaychMgr.Settle(ctx, addr) } func (a *PaychAPI) PaychCollect(ctx context.Context, addr address.Address) (cid.Cid, error) { diff --git a/paychmgr/accessorcache.go b/paychmgr/accessorcache.go new file mode 100644 index 000000000..43223200d --- /dev/null +++ b/paychmgr/accessorcache.go @@ -0,0 +1,67 @@ +package paychmgr + +import "github.com/filecoin-project/go-address" + +// accessorByFromTo gets a channel accessor for a given from / to pair. +// The channel accessor facilitates locking a channel so that operations +// must be performed sequentially on a channel (but can be performed at +// the same time on different channels). +func (pm *Manager) accessorByFromTo(from address.Address, to address.Address) (*channelAccessor, error) { + key := pm.accessorCacheKey(from, to) + + // First take a read lock and check the cache + pm.lk.RLock() + ca, ok := pm.channels[key] + pm.lk.RUnlock() + if ok { + return ca, nil + } + + // Not in cache, so take a write lock + pm.lk.Lock() + defer pm.lk.Unlock() + + // Need to check cache again in case it was updated between releasing read + // lock and taking write lock + ca, ok = pm.channels[key] + if !ok { + // Not in cache, so create a new one and store in cache + ca = pm.addAccessorToCache(from, to) + } + + return ca, nil +} + +// accessorByAddress gets a channel accessor for a given channel address. +// The channel accessor facilitates locking a channel so that operations +// must be performed sequentially on a channel (but can be performed at +// the same time on different channels). +func (pm *Manager) accessorByAddress(ch address.Address) (*channelAccessor, error) { + // Get the channel from / to + pm.lk.RLock() + channelInfo, err := pm.store.ByAddress(ch) + pm.lk.RUnlock() + if err != nil { + return nil, err + } + + // TODO: cache by channel address so we can get by address instead of using from / to + return pm.accessorByFromTo(channelInfo.Control, channelInfo.Target) +} + +// accessorCacheKey returns the cache key use to reference a channel accessor +func (pm *Manager) accessorCacheKey(from address.Address, to address.Address) string { + return from.String() + "->" + to.String() +} + +// addAccessorToCache adds a channel accessor to the cache. Note that the +// channel may not have been created yet, but we still want to reference +// the same channel accessor for a given from/to, so that all attempts to +// access a channel use the same lock (the lock on the accessor) +func (pm *Manager) addAccessorToCache(from address.Address, to address.Address) *channelAccessor { + key := pm.accessorCacheKey(from, to) + ca := newChannelAccessor(pm) + // TODO: Use LRU + pm.channels[key] = ca + return ca +} diff --git a/paychmgr/cbor_gen.go b/paychmgr/cbor_gen.go index 8876f6c8a..57666fe2d 100644 --- a/paychmgr/cbor_gen.go +++ b/paychmgr/cbor_gen.go @@ -6,6 +6,7 @@ import ( "fmt" "io" + "github.com/filecoin-project/go-address" "github.com/filecoin-project/specs-actors/actors/builtin/paych" cbg "github.com/whyrusleeping/cbor-gen" xerrors "golang.org/x/xerrors" @@ -156,12 +157,35 @@ func (t *ChannelInfo) MarshalCBOR(w io.Writer) error { _, err := w.Write(cbg.CborNull) return err } - if _, err := w.Write([]byte{166}); err != nil { + if _, err := w.Write([]byte{172}); err != nil { return err } scratch := make([]byte, 9) + // t.ChannelID (string) (string) + if len("ChannelID") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"ChannelID\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("ChannelID"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("ChannelID")); err != nil { + return err + } + + if len(t.ChannelID) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.ChannelID was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.ChannelID))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.ChannelID)); err != nil { + return err + } + // t.Channel (address.Address) (struct) if len("Channel") > cbg.MaxLength { return xerrors.Errorf("Value in field \"Channel\" was too long") @@ -267,6 +291,97 @@ func (t *ChannelInfo) MarshalCBOR(w io.Writer) error { return err } + // t.Amount (big.Int) (struct) + if len("Amount") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Amount\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Amount"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Amount")); err != nil { + return err + } + + if err := t.Amount.MarshalCBOR(w); err != nil { + return err + } + + // t.PendingAmount (big.Int) (struct) + if len("PendingAmount") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"PendingAmount\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("PendingAmount"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("PendingAmount")); err != nil { + return err + } + + if err := t.PendingAmount.MarshalCBOR(w); err != nil { + return err + } + + // t.CreateMsg (cid.Cid) (struct) + if len("CreateMsg") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"CreateMsg\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("CreateMsg"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("CreateMsg")); err != nil { + return err + } + + if t.CreateMsg == nil { + if _, err := w.Write(cbg.CborNull); err != nil { + return err + } + } else { + if err := cbg.WriteCidBuf(scratch, w, *t.CreateMsg); err != nil { + return xerrors.Errorf("failed to write cid field t.CreateMsg: %w", err) + } + } + + // t.AddFundsMsg (cid.Cid) (struct) + if len("AddFundsMsg") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"AddFundsMsg\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("AddFundsMsg"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("AddFundsMsg")); err != nil { + return err + } + + if t.AddFundsMsg == nil { + if _, err := w.Write(cbg.CborNull); err != nil { + return err + } + } else { + if err := cbg.WriteCidBuf(scratch, w, *t.AddFundsMsg); err != nil { + return xerrors.Errorf("failed to write cid field t.AddFundsMsg: %w", err) + } + } + + // t.Settling (bool) (bool) + if len("Settling") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Settling\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Settling"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Settling")); err != nil { + return err + } + + if err := cbg.WriteBool(w, t.Settling); err != nil { + return err + } return nil } @@ -303,13 +418,36 @@ func (t *ChannelInfo) UnmarshalCBOR(r io.Reader) error { } switch name { - // t.Channel (address.Address) (struct) + // t.ChannelID (string) (string) + case "ChannelID": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.ChannelID = string(sval) + } + // t.Channel (address.Address) (struct) case "Channel": { - if err := t.Channel.UnmarshalCBOR(br); err != nil { - return xerrors.Errorf("unmarshaling t.Channel: %w", err) + pb, err := br.PeekByte() + if err != nil { + return err + } + if pb == cbg.CborNull[0] { + var nbuf [1]byte + if _, err := br.Read(nbuf[:]); err != nil { + return err + } + } else { + t.Channel = new(address.Address) + if err := t.Channel.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.Channel pointer: %w", err) + } } } @@ -393,6 +531,279 @@ func (t *ChannelInfo) UnmarshalCBOR(r io.Reader) error { t.NextLane = uint64(extra) } + // t.Amount (big.Int) (struct) + case "Amount": + + { + + if err := t.Amount.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.Amount: %w", err) + } + + } + // t.PendingAmount (big.Int) (struct) + case "PendingAmount": + + { + + if err := t.PendingAmount.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.PendingAmount: %w", err) + } + + } + // t.CreateMsg (cid.Cid) (struct) + case "CreateMsg": + + { + + pb, err := br.PeekByte() + if err != nil { + return err + } + if pb == cbg.CborNull[0] { + var nbuf [1]byte + if _, err := br.Read(nbuf[:]); err != nil { + return err + } + } else { + + c, err := cbg.ReadCid(br) + if err != nil { + return xerrors.Errorf("failed to read cid field t.CreateMsg: %w", err) + } + + t.CreateMsg = &c + } + + } + // t.AddFundsMsg (cid.Cid) (struct) + case "AddFundsMsg": + + { + + pb, err := br.PeekByte() + if err != nil { + return err + } + if pb == cbg.CborNull[0] { + var nbuf [1]byte + if _, err := br.Read(nbuf[:]); err != nil { + return err + } + } else { + + c, err := cbg.ReadCid(br) + if err != nil { + return xerrors.Errorf("failed to read cid field t.AddFundsMsg: %w", err) + } + + t.AddFundsMsg = &c + } + + } + // t.Settling (bool) (bool) + case "Settling": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.Settling = false + case 21: + t.Settling = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) + } + + default: + return fmt.Errorf("unknown struct field %d: '%s'", i, name) + } + } + + return nil +} +func (t *MsgInfo) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + if _, err := w.Write([]byte{164}); err != nil { + return err + } + + scratch := make([]byte, 9) + + // t.ChannelID (string) (string) + if len("ChannelID") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"ChannelID\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("ChannelID"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("ChannelID")); err != nil { + return err + } + + if len(t.ChannelID) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.ChannelID was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.ChannelID))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.ChannelID)); err != nil { + return err + } + + // t.MsgCid (cid.Cid) (struct) + if len("MsgCid") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"MsgCid\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("MsgCid"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("MsgCid")); err != nil { + return err + } + + if err := cbg.WriteCidBuf(scratch, w, t.MsgCid); err != nil { + return xerrors.Errorf("failed to write cid field t.MsgCid: %w", err) + } + + // t.Received (bool) (bool) + if len("Received") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Received\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Received"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Received")); err != nil { + return err + } + + if err := cbg.WriteBool(w, t.Received); err != nil { + return err + } + + // t.Err (string) (string) + if len("Err") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Err\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Err"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Err")); err != nil { + return err + } + + if len(t.Err) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.Err was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Err))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.Err)); err != nil { + return err + } + return nil +} + +func (t *MsgInfo) UnmarshalCBOR(r io.Reader) error { + *t = MsgInfo{} + + br := cbg.GetPeeker(r) + scratch := make([]byte, 8) + + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajMap { + return fmt.Errorf("cbor input should be of type map") + } + + if extra > cbg.MaxLength { + return fmt.Errorf("MsgInfo: map struct too large (%d)", extra) + } + + var name string + n := extra + + for i := uint64(0); i < n; i++ { + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + name = string(sval) + } + + switch name { + // t.ChannelID (string) (string) + case "ChannelID": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.ChannelID = string(sval) + } + // t.MsgCid (cid.Cid) (struct) + case "MsgCid": + + { + + c, err := cbg.ReadCid(br) + if err != nil { + return xerrors.Errorf("failed to read cid field t.MsgCid: %w", err) + } + + t.MsgCid = c + + } + // t.Received (bool) (bool) + case "Received": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.Received = false + case 21: + t.Received = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) + } + // t.Err (string) (string) + case "Err": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.Err = string(sval) + } default: return fmt.Errorf("unknown struct field %d: '%s'", i, name) diff --git a/paychmgr/channellock.go b/paychmgr/channellock.go new file mode 100644 index 000000000..0dc785ec0 --- /dev/null +++ b/paychmgr/channellock.go @@ -0,0 +1,33 @@ +package paychmgr + +import "sync" + +type rwlock interface { + RLock() + RUnlock() +} + +// channelLock manages locking for a specific channel. +// Some operations update the state of a single channel, and need to block +// other operations only on the same channel's state. +// Some operations update state that affects all channels, and need to block +// any operation against any channel. +type channelLock struct { + globalLock rwlock + chanLock sync.Mutex +} + +func (l *channelLock) Lock() { + // Wait for other operations by this channel to finish. + // Exclusive per-channel (no other ops by this channel allowed). + l.chanLock.Lock() + // Wait for operations affecting all channels to finish. + // Allows ops by other channels in parallel, but blocks all operations + // if global lock is taken exclusively (eg when adding a channel) + l.globalLock.RLock() +} + +func (l *channelLock) Unlock() { + l.globalLock.RUnlock() + l.chanLock.Unlock() +} diff --git a/paychmgr/manager.go b/paychmgr/manager.go new file mode 100644 index 000000000..7f47640f1 --- /dev/null +++ b/paychmgr/manager.go @@ -0,0 +1,244 @@ +package paychmgr + +import ( + "context" + "sync" + + "github.com/filecoin-project/lotus/node/modules/helpers" + + "github.com/ipfs/go-datastore" + + xerrors "golang.org/x/xerrors" + + "github.com/filecoin-project/lotus/api" + + "github.com/filecoin-project/specs-actors/actors/builtin/paych" + + "github.com/ipfs/go-cid" + logging "github.com/ipfs/go-log/v2" + "go.uber.org/fx" + + "github.com/filecoin-project/go-address" + + "github.com/filecoin-project/lotus/chain/stmgr" + "github.com/filecoin-project/lotus/chain/types" + "github.com/filecoin-project/lotus/node/impl/full" +) + +var log = logging.Logger("paych") + +type ManagerApi struct { + fx.In + + full.MpoolAPI + full.WalletAPI + full.StateAPI +} + +type StateManagerApi interface { + LoadActorState(ctx context.Context, a address.Address, out interface{}, ts *types.TipSet) (*types.Actor, error) + Call(ctx context.Context, msg *types.Message, ts *types.TipSet) (*api.InvocResult, error) +} + +type Manager struct { + // The Manager context is used to terminate wait operations on shutdown + ctx context.Context + shutdown context.CancelFunc + + store *Store + sm StateManagerApi + sa *stateAccessor + pchapi paychApi + + lk sync.RWMutex + channels map[string]*channelAccessor + + mpool full.MpoolAPI + wallet full.WalletAPI + state full.StateAPI +} + +func NewManager(mctx helpers.MetricsCtx, lc fx.Lifecycle, sm *stmgr.StateManager, pchstore *Store, api ManagerApi) *Manager { + ctx := helpers.LifecycleCtx(mctx, lc) + ctx, shutdown := context.WithCancel(ctx) + + return &Manager{ + ctx: ctx, + shutdown: shutdown, + store: pchstore, + sm: sm, + sa: &stateAccessor{sm: sm}, + channels: make(map[string]*channelAccessor), + pchapi: &api, + + mpool: api.MpoolAPI, + wallet: api.WalletAPI, + state: api.StateAPI, + } +} + +// newManager is used by the tests to supply mocks +func newManager(sm StateManagerApi, pchstore *Store, pchapi paychApi) (*Manager, error) { + pm := &Manager{ + store: pchstore, + sm: sm, + sa: &stateAccessor{sm: sm}, + channels: make(map[string]*channelAccessor), + pchapi: pchapi, + } + return pm, pm.Start() +} + +// HandleManager is called by dependency injection to set up hooks +func HandleManager(lc fx.Lifecycle, pm *Manager) { + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + return pm.Start() + }, + OnStop: func(context.Context) error { + return pm.Stop() + }, + }) +} + +// Start restarts tracking of any messages that were sent to chain. +func (pm *Manager) Start() error { + return pm.restartPending() +} + +// Stop shuts down any processes used by the manager +func (pm *Manager) Stop() error { + pm.shutdown() + return nil +} + +func (pm *Manager) TrackOutboundChannel(ctx context.Context, ch address.Address) error { + return pm.trackChannel(ctx, ch, DirOutbound) +} + +func (pm *Manager) TrackInboundChannel(ctx context.Context, ch address.Address) error { + return pm.trackChannel(ctx, ch, DirInbound) +} + +func (pm *Manager) trackChannel(ctx context.Context, ch address.Address, dir uint64) error { + pm.lk.Lock() + defer pm.lk.Unlock() + + ci, err := pm.sa.loadStateChannelInfo(ctx, ch, dir) + if err != nil { + return err + } + + return pm.store.TrackChannel(ci) +} + +func (pm *Manager) GetPaych(ctx context.Context, from, to address.Address, amt types.BigInt) (address.Address, cid.Cid, error) { + chanAccessor, err := pm.accessorByFromTo(from, to) + if err != nil { + return address.Undef, cid.Undef, err + } + + return chanAccessor.getPaych(ctx, from, to, amt) +} + +// GetPaychWaitReady waits until the create channel / add funds message with the +// given message CID arrives. +// The returned channel address can safely be used against the Manager methods. +func (pm *Manager) GetPaychWaitReady(ctx context.Context, mcid cid.Cid) (address.Address, error) { + // Find the channel associated with the message CID + pm.lk.Lock() + ci, err := pm.store.ByMessageCid(mcid) + pm.lk.Unlock() + + if err != nil { + if err == datastore.ErrNotFound { + return address.Undef, xerrors.Errorf("Could not find wait msg cid %s", mcid) + } + return address.Undef, err + } + + chanAccessor, err := pm.accessorByFromTo(ci.Control, ci.Target) + if err != nil { + return address.Undef, err + } + + return chanAccessor.getPaychWaitReady(ctx, mcid) +} + +func (pm *Manager) ListChannels() ([]address.Address, error) { + // Need to take an exclusive lock here so that channel operations can't run + // in parallel (see channelLock) + pm.lk.Lock() + defer pm.lk.Unlock() + + return pm.store.ListChannels() +} + +func (pm *Manager) GetChannelInfo(addr address.Address) (*ChannelInfo, error) { + ca, err := pm.accessorByAddress(addr) + if err != nil { + return nil, err + } + return ca.getChannelInfo(addr) +} + +// CheckVoucherValid checks if the given voucher is valid (is or could become spendable at some point) +func (pm *Manager) CheckVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) error { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return err + } + + _, err = ca.checkVoucherValid(ctx, ch, sv) + return err +} + +// CheckVoucherSpendable checks if the given voucher is currently spendable +func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, secret []byte, proof []byte) (bool, error) { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return false, err + } + + return ca.checkVoucherSpendable(ctx, ch, sv, secret, proof) +} + +func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return types.NewInt(0), err + } + return ca.addVoucher(ctx, ch, sv, proof, minDelta) +} + +func (pm *Manager) AllocateLane(ch address.Address) (uint64, error) { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return 0, err + } + return ca.allocateLane(ch) +} + +func (pm *Manager) ListVouchers(ctx context.Context, ch address.Address) ([]*VoucherInfo, error) { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return nil, err + } + return ca.listVouchers(ctx, ch) +} + +func (pm *Manager) NextNonceForLane(ctx context.Context, ch address.Address, lane uint64) (uint64, error) { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return 0, err + } + return ca.nextNonceForLane(ctx, ch, lane) +} + +func (pm *Manager) Settle(ctx context.Context, addr address.Address) (cid.Cid, error) { + ca, err := pm.accessorByAddress(addr) + if err != nil { + return cid.Undef, err + } + return ca.settle(ctx, addr) +} diff --git a/paychmgr/msglistener.go b/paychmgr/msglistener.go new file mode 100644 index 000000000..0a38cc2da --- /dev/null +++ b/paychmgr/msglistener.go @@ -0,0 +1,61 @@ +package paychmgr + +import ( + "sync" + + "github.com/google/uuid" + "github.com/ipfs/go-cid" +) + +type msgListener struct { + id string + cb func(c cid.Cid, err error) +} + +type msgListeners struct { + lk sync.Mutex + listeners []*msgListener +} + +func (ml *msgListeners) onMsg(mcid cid.Cid, cb func(error)) string { + ml.lk.Lock() + defer ml.lk.Unlock() + + l := &msgListener{ + id: uuid.New().String(), + cb: func(c cid.Cid, err error) { + if mcid.Equals(c) { + cb(err) + } + }, + } + ml.listeners = append(ml.listeners, l) + return l.id +} + +func (ml *msgListeners) fireMsgComplete(mcid cid.Cid, err error) { + ml.lk.Lock() + defer ml.lk.Unlock() + + for _, l := range ml.listeners { + l.cb(mcid, err) + } +} + +func (ml *msgListeners) unsubscribe(sub string) { + ml.lk.Lock() + defer ml.lk.Unlock() + + for i, l := range ml.listeners { + if l.id == sub { + ml.removeListener(i) + return + } + } +} + +func (ml *msgListeners) removeListener(i int) { + copy(ml.listeners[i:], ml.listeners[i+1:]) + ml.listeners[len(ml.listeners)-1] = nil + ml.listeners = ml.listeners[:len(ml.listeners)-1] +} diff --git a/paychmgr/msglistener_test.go b/paychmgr/msglistener_test.go new file mode 100644 index 000000000..fd457a518 --- /dev/null +++ b/paychmgr/msglistener_test.go @@ -0,0 +1,96 @@ +package paychmgr + +import ( + "testing" + + "github.com/ipfs/go-cid" + + "github.com/stretchr/testify/require" + + "golang.org/x/xerrors" +) + +func testCids() []cid.Cid { + c1, _ := cid.Decode("QmdmGQmRgRjazArukTbsXuuxmSHsMCcRYPAZoGhd6e3MuS") + c2, _ := cid.Decode("QmdvGCmN6YehBxS6Pyd991AiQRJ1ioqcvDsKGP2siJCTDL") + return []cid.Cid{c1, c2} +} + +func TestMsgListener(t *testing.T) { + var ml msgListeners + + done := false + experr := xerrors.Errorf("some err") + cids := testCids() + ml.onMsg(cids[0], func(err error) { + require.Equal(t, experr, err) + done = true + }) + + ml.fireMsgComplete(cids[0], experr) + + if !done { + t.Fatal("failed to fire event") + } +} + +func TestMsgListenerNilErr(t *testing.T) { + var ml msgListeners + + done := false + cids := testCids() + ml.onMsg(cids[0], func(err error) { + require.Nil(t, err) + done = true + }) + + ml.fireMsgComplete(cids[0], nil) + + if !done { + t.Fatal("failed to fire event") + } +} + +func TestMsgListenerUnsub(t *testing.T) { + var ml msgListeners + + done := false + experr := xerrors.Errorf("some err") + cids := testCids() + id1 := ml.onMsg(cids[0], func(err error) { + t.Fatal("should not call unsubscribed listener") + }) + ml.onMsg(cids[0], func(err error) { + require.Equal(t, experr, err) + done = true + }) + + ml.unsubscribe(id1) + ml.fireMsgComplete(cids[0], experr) + + if !done { + t.Fatal("failed to fire event") + } +} + +func TestMsgListenerMulti(t *testing.T) { + var ml msgListeners + + count := 0 + cids := testCids() + ml.onMsg(cids[0], func(err error) { + count++ + }) + ml.onMsg(cids[0], func(err error) { + count++ + }) + ml.onMsg(cids[1], func(err error) { + count++ + }) + + ml.fireMsgComplete(cids[0], nil) + require.Equal(t, 2, count) + + ml.fireMsgComplete(cids[1], nil) + require.Equal(t, 3, count) +} diff --git a/paychmgr/paych.go b/paychmgr/paych.go index 85db664cd..f1d5199f6 100644 --- a/paychmgr/paych.go +++ b/paychmgr/paych.go @@ -5,118 +5,75 @@ import ( "context" "fmt" - "github.com/filecoin-project/specs-actors/actors/abi/big" - - "github.com/filecoin-project/lotus/api" + "github.com/ipfs/go-cid" + "github.com/filecoin-project/go-address" cborutil "github.com/filecoin-project/go-cbor-util" + "github.com/filecoin-project/lotus/chain/actors" + "github.com/filecoin-project/lotus/chain/types" + "github.com/filecoin-project/lotus/lib/sigs" + "github.com/filecoin-project/specs-actors/actors/abi/big" "github.com/filecoin-project/specs-actors/actors/builtin" "github.com/filecoin-project/specs-actors/actors/builtin/account" "github.com/filecoin-project/specs-actors/actors/builtin/paych" - "golang.org/x/xerrors" - - logging "github.com/ipfs/go-log/v2" - "go.uber.org/fx" - - "github.com/filecoin-project/go-address" - - "github.com/filecoin-project/lotus/chain/actors" - "github.com/filecoin-project/lotus/chain/stmgr" - "github.com/filecoin-project/lotus/chain/types" - "github.com/filecoin-project/lotus/lib/sigs" - "github.com/filecoin-project/lotus/node/impl/full" + xerrors "golang.org/x/xerrors" ) -var log = logging.Logger("paych") - -type ManagerApi struct { - fx.In - - full.MpoolAPI - full.WalletAPI - full.StateAPI +// channelAccessor is used to simplify locking when accessing a channel +type channelAccessor struct { + // waitCtx is used by processes that wait for things to be confirmed + // on chain + waitCtx context.Context + sm StateManagerApi + sa *stateAccessor + api paychApi + store *Store + lk *channelLock + fundsReqQueue []*fundsReq + msgListeners msgListeners } -type StateManagerApi interface { - LoadActorState(ctx context.Context, a address.Address, out interface{}, ts *types.TipSet) (*types.Actor, error) - Call(ctx context.Context, msg *types.Message, ts *types.TipSet) (*api.InvocResult, error) -} - -type Manager struct { - store *Store - sm StateManagerApi - - mpool full.MpoolAPI - wallet full.WalletAPI - state full.StateAPI -} - -func NewManager(sm *stmgr.StateManager, pchstore *Store, api ManagerApi) *Manager { - return &Manager{ - store: pchstore, - sm: sm, - - mpool: api.MpoolAPI, - wallet: api.WalletAPI, - state: api.StateAPI, +func newChannelAccessor(pm *Manager) *channelAccessor { + return &channelAccessor{ + lk: &channelLock{globalLock: &pm.lk}, + sm: pm.sm, + sa: &stateAccessor{sm: pm.sm}, + api: pm.pchapi, + store: pm.store, + waitCtx: pm.ctx, } } -// Used by the tests to supply mocks -func newManager(sm StateManagerApi, pchstore *Store) *Manager { - return &Manager{ - store: pchstore, - sm: sm, - } +func (ca *channelAccessor) getChannelInfo(addr address.Address) (*ChannelInfo, error) { + ca.lk.Lock() + defer ca.lk.Unlock() + + return ca.store.ByAddress(addr) } -func (pm *Manager) TrackOutboundChannel(ctx context.Context, ch address.Address) error { - return pm.trackChannel(ctx, ch, DirOutbound) +func (ca *channelAccessor) checkVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (map[uint64]*paych.LaneState, error) { + ca.lk.Lock() + defer ca.lk.Unlock() + + return ca.checkVoucherValidUnlocked(ctx, ch, sv) } -func (pm *Manager) TrackInboundChannel(ctx context.Context, ch address.Address) error { - return pm.trackChannel(ctx, ch, DirInbound) -} - -func (pm *Manager) trackChannel(ctx context.Context, ch address.Address, dir uint64) error { - ci, err := pm.loadStateChannelInfo(ctx, ch, dir) - if err != nil { - return err - } - - return pm.store.TrackChannel(ci) -} - -func (pm *Manager) ListChannels() ([]address.Address, error) { - return pm.store.ListChannels() -} - -func (pm *Manager) GetChannelInfo(addr address.Address) (*ChannelInfo, error) { - return pm.store.getChannelInfo(addr) -} - -// CheckVoucherValid checks if the given voucher is valid (is or could become spendable at some point) -func (pm *Manager) CheckVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) error { - _, err := pm.checkVoucherValid(ctx, ch, sv) - return err -} - -func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (map[uint64]*paych.LaneState, error) { +func (ca *channelAccessor) checkVoucherValidUnlocked(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (map[uint64]*paych.LaneState, error) { if sv.ChannelAddr != ch { return nil, xerrors.Errorf("voucher ChannelAddr doesn't match channel address, got %s, expected %s", sv.ChannelAddr, ch) } - act, pchState, err := pm.loadPaychState(ctx, ch) + act, pchState, err := ca.sa.loadPaychState(ctx, ch) if err != nil { return nil, err } - var account account.State - _, err = pm.sm.LoadActorState(ctx, pchState.From, &account, nil) + var actState account.State + _, err = ca.sm.LoadActorState(ctx, pchState.From, &actState, nil) if err != nil { return nil, err } - from := account.Address + from := actState.Address // verify signature vb, err := sv.SigningBytes() @@ -132,7 +89,7 @@ func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv } // Check the voucher against the highest known voucher nonce / value - laneStates, err := pm.laneState(pchState, ch) + laneStates, err := ca.laneState(pchState, ch) if err != nil { return nil, err } @@ -164,7 +121,7 @@ func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv // lane 2: 2 // - // total: 7 - totalRedeemed, err := pm.totalRedeemedWithVoucher(laneStates, sv) + totalRedeemed, err := ca.totalRedeemedWithVoucher(laneStates, sv) if err != nil { return nil, err } @@ -183,15 +140,17 @@ func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv return laneStates, nil } -// CheckVoucherSpendable checks if the given voucher is currently spendable -func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, secret []byte, proof []byte) (bool, error) { - recipient, err := pm.getPaychRecipient(ctx, ch) +func (ca *channelAccessor) checkVoucherSpendable(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, secret []byte, proof []byte) (bool, error) { + ca.lk.Lock() + defer ca.lk.Unlock() + + recipient, err := ca.getPaychRecipient(ctx, ch) if err != nil { return false, err } if sv.Extra != nil && proof == nil { - known, err := pm.ListVouchers(ctx, ch) + known, err := ca.store.VouchersForPaych(ch) if err != nil { return false, err } @@ -221,7 +180,7 @@ func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address return false, err } - ret, err := pm.sm.Call(ctx, &types.Message{ + ret, err := ca.sm.Call(ctx, &types.Message{ From: recipient, To: ch, Method: builtin.MethodsPaych.UpdateChannelState, @@ -238,22 +197,22 @@ func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address return true, nil } -func (pm *Manager) getPaychRecipient(ctx context.Context, ch address.Address) (address.Address, error) { +func (ca *channelAccessor) getPaychRecipient(ctx context.Context, ch address.Address) (address.Address, error) { var state paych.State - if _, err := pm.sm.LoadActorState(ctx, ch, &state, nil); err != nil { + if _, err := ca.sm.LoadActorState(ctx, ch, &state, nil); err != nil { return address.Address{}, err } return state.To, nil } -func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) { - pm.store.lk.Lock() - defer pm.store.lk.Unlock() +func (ca *channelAccessor) addVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) { + ca.lk.Lock() + defer ca.lk.Unlock() - ci, err := pm.store.getChannelInfo(ch) + ci, err := ca.store.ByAddress(ch) if err != nil { - return types.NewInt(0), err + return types.BigInt{}, err } // Check if the voucher has already been added @@ -275,7 +234,7 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych Proof: proof, } - return types.NewInt(0), pm.store.putChannelInfo(ci) + return types.NewInt(0), ca.store.putChannelInfo(ci) } // Otherwise just ignore the duplicate voucher @@ -284,7 +243,7 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych } // Check voucher validity - laneStates, err := pm.checkVoucherValid(ctx, ch, sv) + laneStates, err := ca.checkVoucherValidUnlocked(ctx, ch, sv) if err != nil { return types.NewInt(0), err } @@ -311,35 +270,32 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych ci.NextLane = sv.Lane + 1 } - return delta, pm.store.putChannelInfo(ci) + return delta, ca.store.putChannelInfo(ci) } -func (pm *Manager) AllocateLane(ch address.Address) (uint64, error) { +func (ca *channelAccessor) allocateLane(ch address.Address) (uint64, error) { + ca.lk.Lock() + defer ca.lk.Unlock() + // TODO: should this take into account lane state? - return pm.store.AllocateLane(ch) + return ca.store.AllocateLane(ch) } -func (pm *Manager) ListVouchers(ctx context.Context, ch address.Address) ([]*VoucherInfo, error) { +func (ca *channelAccessor) listVouchers(ctx context.Context, ch address.Address) ([]*VoucherInfo, error) { + ca.lk.Lock() + defer ca.lk.Unlock() + // TODO: just having a passthrough method like this feels odd. Seems like // there should be some filtering we're doing here - return pm.store.VouchersForPaych(ch) + return ca.store.VouchersForPaych(ch) } -func (pm *Manager) OutboundChanTo(from, to address.Address) (address.Address, error) { - pm.store.lk.Lock() - defer pm.store.lk.Unlock() +func (ca *channelAccessor) nextNonceForLane(ctx context.Context, ch address.Address, lane uint64) (uint64, error) { + ca.lk.Lock() + defer ca.lk.Unlock() - return pm.store.findChan(func(ci *ChannelInfo) bool { - if ci.Direction != DirOutbound { - return false - } - return ci.Control == from && ci.Target == to - }) -} - -func (pm *Manager) NextNonceForLane(ctx context.Context, ch address.Address, lane uint64) (uint64, error) { // TODO: should this take into account lane state? - vouchers, err := pm.store.VouchersForPaych(ch) + vouchers, err := ca.store.VouchersForPaych(ch) if err != nil { return 0, err } @@ -355,3 +311,109 @@ func (pm *Manager) NextNonceForLane(ctx context.Context, ch address.Address, lan return maxnonce + 1, nil } + +// laneState gets the LaneStates from chain, then applies all vouchers in +// the data store over the chain state +func (ca *channelAccessor) laneState(state *paych.State, ch address.Address) (map[uint64]*paych.LaneState, error) { + // TODO: we probably want to call UpdateChannelState with all vouchers to be fully correct + // (but technically dont't need to) + laneStates := make(map[uint64]*paych.LaneState, len(state.LaneStates)) + + // Get the lane state from the chain + for _, laneState := range state.LaneStates { + laneStates[laneState.ID] = laneState + } + + // Apply locally stored vouchers + vouchers, err := ca.store.VouchersForPaych(ch) + if err != nil && err != ErrChannelNotTracked { + return nil, err + } + + for _, v := range vouchers { + for range v.Voucher.Merges { + return nil, xerrors.Errorf("paych merges not handled yet") + } + + // If there's a voucher for a lane that isn't in chain state just + // create it + ls, ok := laneStates[v.Voucher.Lane] + if !ok { + ls = &paych.LaneState{ + ID: v.Voucher.Lane, + Redeemed: types.NewInt(0), + Nonce: 0, + } + laneStates[v.Voucher.Lane] = ls + } + + if v.Voucher.Nonce < ls.Nonce { + continue + } + + ls.Nonce = v.Voucher.Nonce + ls.Redeemed = v.Voucher.Amount + } + + return laneStates, nil +} + +// Get the total redeemed amount across all lanes, after applying the voucher +func (ca *channelAccessor) totalRedeemedWithVoucher(laneStates map[uint64]*paych.LaneState, sv *paych.SignedVoucher) (big.Int, error) { + // TODO: merges + if len(sv.Merges) != 0 { + return big.Int{}, xerrors.Errorf("dont currently support paych lane merges") + } + + total := big.NewInt(0) + for _, ls := range laneStates { + total = big.Add(total, ls.Redeemed) + } + + lane, ok := laneStates[sv.Lane] + if ok { + // If the voucher is for an existing lane, and the voucher nonce + // and is higher than the lane nonce + if sv.Nonce > lane.Nonce { + // Add the delta between the redeemed amount and the voucher + // amount to the total + delta := big.Sub(sv.Amount, lane.Redeemed) + total = big.Add(total, delta) + } + } else { + // If the voucher is *not* for an existing lane, just add its + // value (implicitly a new lane will be created for the voucher) + total = big.Add(total, sv.Amount) + } + + return total, nil +} + +func (ca *channelAccessor) settle(ctx context.Context, ch address.Address) (cid.Cid, error) { + ca.lk.Lock() + defer ca.lk.Unlock() + + ci, err := ca.store.ByAddress(ch) + if err != nil { + return cid.Undef, err + } + + msg := &types.Message{ + To: ch, + From: ci.Control, + Value: types.NewInt(0), + Method: builtin.MethodsPaych.Settle, + } + smgs, err := ca.api.MpoolPushMessage(ctx, msg) + if err != nil { + return cid.Undef, err + } + + ci.Settling = true + err = ca.store.putChannelInfo(ci) + if err != nil { + log.Errorf("Error marking channel as settled: %s", err) + } + + return smgs.Cid(), err +} diff --git a/paychmgr/paych_test.go b/paychmgr/paych_test.go index 2cbea5cb5..9c28fdcb8 100644 --- a/paychmgr/paych_test.go +++ b/paychmgr/paych_test.go @@ -104,13 +104,15 @@ func TestPaychOutbound(t *testing.T) { LaneStates: []*paych.LaneState{}, }) - mgr := newManager(sm, store) - err := mgr.TrackOutboundChannel(ctx, ch) + mgr, err := newManager(sm, store, nil) + require.NoError(t, err) + + err = mgr.TrackOutboundChannel(ctx, ch) require.NoError(t, err) ci, err := mgr.GetChannelInfo(ch) require.NoError(t, err) - require.Equal(t, ci.Channel, ch) + require.Equal(t, *ci.Channel, ch) require.Equal(t, ci.Control, from) require.Equal(t, ci.Target, to) require.EqualValues(t, ci.Direction, DirOutbound) @@ -140,13 +142,15 @@ func TestPaychInbound(t *testing.T) { LaneStates: []*paych.LaneState{}, }) - mgr := newManager(sm, store) - err := mgr.TrackInboundChannel(ctx, ch) + mgr, err := newManager(sm, store, nil) + require.NoError(t, err) + + err = mgr.TrackInboundChannel(ctx, ch) require.NoError(t, err) ci, err := mgr.GetChannelInfo(ch) require.NoError(t, err) - require.Equal(t, ci.Channel, ch) + require.Equal(t, *ci.Channel, ch) require.Equal(t, ci.Control, to) require.Equal(t, ci.Target, from) require.EqualValues(t, ci.Direction, DirInbound) @@ -321,8 +325,10 @@ func TestCheckVoucherValid(t *testing.T) { LaneStates: tcase.laneStates, }) - mgr := newManager(sm, store) - err := mgr.TrackInboundChannel(ctx, ch) + mgr, err := newManager(sm, store, nil) + require.NoError(t, err) + + err = mgr.TrackInboundChannel(ctx, ch) require.NoError(t, err) sv := testCreateVoucher(t, ch, tcase.voucherLane, tcase.voucherNonce, tcase.voucherAmount, tcase.key) @@ -382,8 +388,10 @@ func TestCheckVoucherValidCountingAllLanes(t *testing.T) { LaneStates: laneStates, }) - mgr := newManager(sm, store) - err := mgr.TrackInboundChannel(ctx, ch) + mgr, err := newManager(sm, store, nil) + require.NoError(t, err) + + err = mgr.TrackInboundChannel(ctx, ch) require.NoError(t, err) // @@ -690,8 +698,10 @@ func testSetupMgrWithChannel(ctx context.Context, t *testing.T) (*Manager, addre }) store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) - mgr := newManager(sm, store) - err := mgr.TrackInboundChannel(ctx, ch) + mgr, err := newManager(sm, store, nil) + require.NoError(t, err) + + err = mgr.TrackInboundChannel(ctx, ch) require.NoError(t, err) return mgr, ch, fromKeyPrivate } diff --git a/paychmgr/paychget_test.go b/paychmgr/paychget_test.go new file mode 100644 index 000000000..c19fadb7f --- /dev/null +++ b/paychmgr/paychget_test.go @@ -0,0 +1,779 @@ +package paychmgr + +import ( + "context" + "sync" + "testing" + "time" + + cborrpc "github.com/filecoin-project/go-cbor-util" + + init_ "github.com/filecoin-project/specs-actors/actors/builtin/init" + + "github.com/filecoin-project/specs-actors/actors/builtin" + + "github.com/filecoin-project/lotus/api" + "github.com/filecoin-project/lotus/chain/types" + + "github.com/filecoin-project/go-address" + + "github.com/filecoin-project/specs-actors/actors/abi/big" + tutils "github.com/filecoin-project/specs-actors/support/testing" + "github.com/ipfs/go-cid" + ds "github.com/ipfs/go-datastore" + ds_sync "github.com/ipfs/go-datastore/sync" + + "github.com/stretchr/testify/require" +) + +type waitingCall struct { + response chan types.MessageReceipt +} + +type mockPaychAPI struct { + lk sync.Mutex + messages map[cid.Cid]*types.SignedMessage + waitingCalls map[cid.Cid]*waitingCall + responses map[cid.Cid]types.MessageReceipt +} + +func newMockPaychAPI() *mockPaychAPI { + return &mockPaychAPI{ + messages: make(map[cid.Cid]*types.SignedMessage), + waitingCalls: make(map[cid.Cid]*waitingCall), + responses: make(map[cid.Cid]types.MessageReceipt), + } +} + +func (pchapi *mockPaychAPI) StateWaitMsg(ctx context.Context, mcid cid.Cid, confidence uint64) (*api.MsgLookup, error) { + response := make(chan types.MessageReceipt) + + pchapi.lk.Lock() + + if receipt, ok := pchapi.responses[mcid]; ok { + defer pchapi.lk.Unlock() + + delete(pchapi.responses, mcid) + return &api.MsgLookup{Receipt: receipt}, nil + } + + pchapi.waitingCalls[mcid] = &waitingCall{response: response} + pchapi.lk.Unlock() + + receipt := <-response + return &api.MsgLookup{Receipt: receipt}, nil +} + +func (pchapi *mockPaychAPI) receiveMsgResponse(mcid cid.Cid, receipt types.MessageReceipt) { + pchapi.lk.Lock() + defer pchapi.lk.Unlock() + + if call, ok := pchapi.waitingCalls[mcid]; ok { + delete(pchapi.waitingCalls, mcid) + call.response <- receipt + return + } + + pchapi.responses[mcid] = receipt +} + +// Send success response for any waiting calls +func (pchapi *mockPaychAPI) close() { + pchapi.lk.Lock() + defer pchapi.lk.Unlock() + + success := types.MessageReceipt{ + ExitCode: 0, + Return: []byte{}, + } + for mcid, call := range pchapi.waitingCalls { + delete(pchapi.waitingCalls, mcid) + call.response <- success + } +} + +func (pchapi *mockPaychAPI) MpoolPushMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error) { + pchapi.lk.Lock() + defer pchapi.lk.Unlock() + + smsg := &types.SignedMessage{Message: *msg} + pchapi.messages[smsg.Cid()] = smsg + return smsg, nil +} + +func (pchapi *mockPaychAPI) pushedMessages(c cid.Cid) *types.SignedMessage { + pchapi.lk.Lock() + defer pchapi.lk.Unlock() + + return pchapi.messages[c] +} + +func (pchapi *mockPaychAPI) pushedMessageCount() int { + pchapi.lk.Lock() + defer pchapi.lk.Unlock() + + return len(pchapi.messages) +} + +func testChannelResponse(t *testing.T, ch address.Address) types.MessageReceipt { + createChannelRet := init_.ExecReturn{ + IDAddress: ch, + RobustAddress: ch, + } + createChannelRetBytes, err := cborrpc.Dump(&createChannelRet) + require.NoError(t, err) + createChannelResponse := types.MessageReceipt{ + ExitCode: 0, + Return: createChannelRetBytes, + } + return createChannelResponse +} + +// TestPaychGetCreateChannelMsg tests that GetPaych sends a message to create +// a new channel with the correct funds +func TestPaychGetCreateChannelMsg(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + amt := big.NewInt(10) + ch, mcid, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + require.Equal(t, address.Undef, ch) + + pushedMsg := pchapi.pushedMessages(mcid) + require.Equal(t, from, pushedMsg.Message.From) + require.Equal(t, builtin.InitActorAddr, pushedMsg.Message.To) + require.Equal(t, amt, pushedMsg.Message.Value) +} + +// TestPaychGetCreateChannelThenAddFunds tests creating a channel and then +// adding funds to it +func TestPaychGetCreateChannelThenAddFunds(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel with value 10 + amt := big.NewInt(10) + _, createMsgCid, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + // Should have no channels yet (message sent but channel not created) + cis, err := mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 0) + + // 1. Set up create channel response (sent in response to WaitForMsg()) + response := testChannelResponse(t, ch) + + done := make(chan struct{}) + go func() { + defer close(done) + + // 2. Request add funds - should block until create channel has completed + amt2 := big.NewInt(5) + ch2, addFundsMsgCid, err := mgr.GetPaych(ctx, from, to, amt2) + + // 4. This GetPaych should return after create channel from first + // GetPaych completes + require.NoError(t, err) + + // Expect the channel to be the same + require.Equal(t, ch, ch2) + // Expect add funds message CID to be different to create message cid + require.NotEqual(t, createMsgCid, addFundsMsgCid) + + // Should have one channel, whose address is the channel that was created + cis, err := mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + // Amount should be amount sent to first GetPaych (to create + // channel). + // PendingAmount should be amount sent in second GetPaych + // (second GetPaych triggered add funds, which has not yet been confirmed) + ci, err := mgr.GetChannelInfo(ch) + require.NoError(t, err) + require.EqualValues(t, 10, ci.Amount.Int64()) + require.EqualValues(t, 5, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) + + // Trigger add funds confirmation + pchapi.receiveMsgResponse(addFundsMsgCid, types.MessageReceipt{ExitCode: 0}) + + // Wait for add funds confirmation to be processed by manager + _, err = mgr.GetPaychWaitReady(ctx, addFundsMsgCid) + require.NoError(t, err) + + // Should still have one channel + cis, err = mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + // Channel amount should include last amount sent to GetPaych + ci, err = mgr.GetChannelInfo(ch) + require.NoError(t, err) + require.EqualValues(t, 15, ci.Amount.Int64()) + require.EqualValues(t, 0, ci.PendingAmount.Int64()) + require.Nil(t, ci.AddFundsMsg) + }() + + // Give the go routine above a moment to run + time.Sleep(time.Millisecond * 10) + + // 3. Send create channel response + pchapi.receiveMsgResponse(createMsgCid, response) + + <-done +} + +// TestPaychGetCreateChannelWithErrorThenCreateAgain tests that if an +// operation is queued up behind a create channel operation, and the create +// channel fails, then the waiting operation can succeed. +func TestPaychGetCreateChannelWithErrorThenCreateAgain(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel + amt := big.NewInt(10) + _, mcid1, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + // 1. Set up create channel response (sent in response to WaitForMsg()) + // This response indicates an error. + errResponse := types.MessageReceipt{ + ExitCode: 1, // error + Return: []byte{}, + } + + done := make(chan struct{}) + go func() { + defer close(done) + + // 2. Should block until create channel has completed. + // Because first channel create fails, this request + // should be for channel create again. + amt2 := big.NewInt(5) + ch2, mcid2, err := mgr.GetPaych(ctx, from, to, amt2) + require.NoError(t, err) + require.Equal(t, address.Undef, ch2) + + time.Sleep(time.Millisecond * 10) + + // 4. Send a success response + ch := tutils.NewIDAddr(t, 100) + successResponse := testChannelResponse(t, ch) + pchapi.receiveMsgResponse(mcid2, successResponse) + + time.Sleep(time.Millisecond * 10) + + // Should have one channel, whose address is the channel that was created + cis, err := mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + ci, err := mgr.GetChannelInfo(ch) + require.NoError(t, err) + require.Equal(t, amt2, ci.Amount) + }() + + // Give the go routine above a moment to run + time.Sleep(time.Millisecond * 10) + + // 3. Send error response to first channel create + pchapi.receiveMsgResponse(mcid1, errResponse) + + <-done +} + +// TestPaychGetRecoverAfterError tests that after a create channel fails, the +// next attempt to create channel can succeed. +func TestPaychGetRecoverAfterError(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel + amt := big.NewInt(10) + _, mcid, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send error create channel response + pchapi.receiveMsgResponse(mcid, types.MessageReceipt{ + ExitCode: 1, // error + Return: []byte{}, + }) + + time.Sleep(time.Millisecond * 10) + + // Send create message for a channel again + amt2 := big.NewInt(7) + _, mcid2, err := mgr.GetPaych(ctx, from, to, amt2) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send success create channel response + response := testChannelResponse(t, ch) + pchapi.receiveMsgResponse(mcid2, response) + + time.Sleep(time.Millisecond * 10) + + // Should have one channel, whose address is the channel that was created + cis, err := mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + ci, err := mgr.GetChannelInfo(ch) + require.NoError(t, err) + require.Equal(t, amt2, ci.Amount) + require.EqualValues(t, 0, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) +} + +// TestPaychGetRecoverAfterAddFundsError tests that after an add funds fails, the +// next attempt to add funds can succeed. +func TestPaychGetRecoverAfterAddFundsError(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel + amt := big.NewInt(10) + _, mcid1, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send success create channel response + response := testChannelResponse(t, ch) + pchapi.receiveMsgResponse(mcid1, response) + + time.Sleep(time.Millisecond * 10) + + // Send add funds message for channel + amt2 := big.NewInt(5) + _, mcid2, err := mgr.GetPaych(ctx, from, to, amt2) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send error add funds response + pchapi.receiveMsgResponse(mcid2, types.MessageReceipt{ + ExitCode: 1, // error + Return: []byte{}, + }) + + time.Sleep(time.Millisecond * 10) + + // Should have one channel, whose address is the channel that was created + cis, err := mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + ci, err := mgr.GetChannelInfo(ch) + require.NoError(t, err) + require.Equal(t, amt, ci.Amount) + require.EqualValues(t, 0, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) + require.Nil(t, ci.AddFundsMsg) + + // Send add funds message for channel again + amt3 := big.NewInt(2) + _, mcid3, err := mgr.GetPaych(ctx, from, to, amt3) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send success add funds response + pchapi.receiveMsgResponse(mcid3, types.MessageReceipt{ + ExitCode: 0, + Return: []byte{}, + }) + + time.Sleep(time.Millisecond * 10) + + // Should have one channel, whose address is the channel that was created + cis, err = mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + // Amount should include amount for successful add funds msg + ci, err = mgr.GetChannelInfo(ch) + require.NoError(t, err) + require.Equal(t, amt.Int64()+amt3.Int64(), ci.Amount.Int64()) + require.EqualValues(t, 0, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) + require.Nil(t, ci.AddFundsMsg) +} + +// TestPaychGetRestartAfterCreateChannelMsg tests that if the system stops +// right after the create channel message is sent, the channel will be +// created when the system restarts. +func TestPaychGetRestartAfterCreateChannelMsg(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel with value 10 + amt := big.NewInt(10) + _, createMsgCid, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + // Simulate shutting down lotus + pchapi.close() + + // Create a new manager with the same datastore + sm2 := newMockStateManager() + pchapi2 := newMockPaychAPI() + defer pchapi2.close() + + mgr2, err := newManager(sm2, store, pchapi2) + require.NoError(t, err) + + // Should have no channels yet (message sent but channel not created) + cis, err := mgr2.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 0) + + // 1. Set up create channel response (sent in response to WaitForMsg()) + response := testChannelResponse(t, ch) + + done := make(chan struct{}) + go func() { + defer close(done) + + // 2. Request add funds - should block until create channel has completed + amt2 := big.NewInt(5) + ch2, addFundsMsgCid, err := mgr2.GetPaych(ctx, from, to, amt2) + + // 4. This GetPaych should return after create channel from first + // GetPaych completes + require.NoError(t, err) + + // Expect the channel to have been created + require.Equal(t, ch, ch2) + // Expect add funds message CID to be different to create message cid + require.NotEqual(t, createMsgCid, addFundsMsgCid) + + // Should have one channel, whose address is the channel that was created + cis, err := mgr2.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + // Amount should be amount sent to first GetPaych (to create + // channel). + // PendingAmount should be amount sent in second GetPaych + // (second GetPaych triggered add funds, which has not yet been confirmed) + ci, err := mgr2.GetChannelInfo(ch) + require.NoError(t, err) + require.EqualValues(t, 10, ci.Amount.Int64()) + require.EqualValues(t, 5, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) + }() + + // Give the go routine above a moment to run + time.Sleep(time.Millisecond * 10) + + // 3. Send create channel response + pchapi2.receiveMsgResponse(createMsgCid, response) + + <-done +} + +// TestPaychGetRestartAfterAddFundsMsg tests that if the system stops +// right after the add funds message is sent, the add funds will be +// processed when the system restarts. +func TestPaychGetRestartAfterAddFundsMsg(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel + amt := big.NewInt(10) + _, mcid1, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send success create channel response + response := testChannelResponse(t, ch) + pchapi.receiveMsgResponse(mcid1, response) + + time.Sleep(time.Millisecond * 10) + + // Send add funds message for channel + amt2 := big.NewInt(5) + _, mcid2, err := mgr.GetPaych(ctx, from, to, amt2) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Simulate shutting down lotus + pchapi.close() + + // Create a new manager with the same datastore + sm2 := newMockStateManager() + pchapi2 := newMockPaychAPI() + defer pchapi2.close() + + time.Sleep(time.Millisecond * 10) + + mgr2, err := newManager(sm2, store, pchapi2) + require.NoError(t, err) + + // Send success add funds response + pchapi2.receiveMsgResponse(mcid2, types.MessageReceipt{ + ExitCode: 0, + Return: []byte{}, + }) + + time.Sleep(time.Millisecond * 10) + + // Should have one channel, whose address is the channel that was created + cis, err := mgr2.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + // Amount should include amount for successful add funds msg + ci, err := mgr2.GetChannelInfo(ch) + require.NoError(t, err) + require.Equal(t, amt.Int64()+amt2.Int64(), ci.Amount.Int64()) + require.EqualValues(t, 0, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) + require.Nil(t, ci.AddFundsMsg) +} + +// TestPaychGetWait tests that GetPaychWaitReady correctly waits for the +// channel to be created or funds to be added +func TestPaychGetWait(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // 1. Get + amt := big.NewInt(10) + _, createMsgCid, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + done := make(chan address.Address) + go func() { + // 2. Wait till ready + ch, err := mgr.GetPaychWaitReady(ctx, createMsgCid) + require.NoError(t, err) + + done <- ch + }() + + time.Sleep(time.Millisecond * 10) + + // 3. Send response + expch := tutils.NewIDAddr(t, 100) + response := testChannelResponse(t, expch) + pchapi.receiveMsgResponse(createMsgCid, response) + + time.Sleep(time.Millisecond * 10) + + ch := <-done + require.Equal(t, expch, ch) + + // 4. Wait again - message has already been received so should + // return immediately + ch, err = mgr.GetPaychWaitReady(ctx, createMsgCid) + require.NoError(t, err) + require.Equal(t, expch, ch) + + // Request add funds + amt2 := big.NewInt(15) + _, addFundsMsgCid, err := mgr.GetPaych(ctx, from, to, amt2) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + go func() { + // 5. Wait for add funds + ch, err := mgr.GetPaychWaitReady(ctx, addFundsMsgCid) + require.NoError(t, err) + require.Equal(t, expch, ch) + + done <- ch + }() + + time.Sleep(time.Millisecond * 10) + + // 6. Send add funds response + addFundsResponse := types.MessageReceipt{ + ExitCode: 0, + Return: []byte{}, + } + pchapi.receiveMsgResponse(addFundsMsgCid, addFundsResponse) + + <-done +} + +// TestPaychGetWaitErr tests that GetPaychWaitReady correctly handles errors +func TestPaychGetWaitErr(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // 1. Create channel + amt := big.NewInt(10) + _, mcid, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + done := make(chan address.Address) + go func() { + defer close(done) + + // 2. Wait for channel to be created + _, err := mgr.GetPaychWaitReady(ctx, mcid) + + // 4. Channel creation should have failed + require.NotNil(t, err) + + // 5. Call wait again with the same message CID + _, err = mgr.GetPaychWaitReady(ctx, mcid) + + // 6. Should return immediately with the same error + require.NotNil(t, err) + }() + + // Give the wait a moment to start before sending response + time.Sleep(time.Millisecond * 10) + + // 3. Send error response to create channel + response := types.MessageReceipt{ + ExitCode: 1, // error + Return: []byte{}, + } + pchapi.receiveMsgResponse(mcid, response) + + <-done +} + +// TestPaychGetWaitCtx tests that GetPaychWaitReady returns early if the context +// is cancelled +func TestPaychGetWaitCtx(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + amt := big.NewInt(10) + _, mcid, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + // When the context is cancelled, should unblock wait + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = mgr.GetPaychWaitReady(ctx, mcid) + require.Error(t, ctx.Err(), err) +} diff --git a/paychmgr/settle_test.go b/paychmgr/settle_test.go new file mode 100644 index 000000000..292b139ea --- /dev/null +++ b/paychmgr/settle_test.go @@ -0,0 +1,77 @@ +package paychmgr + +import ( + "context" + "testing" + "time" + + "github.com/ipfs/go-cid" + + "github.com/filecoin-project/specs-actors/actors/abi/big" + tutils "github.com/filecoin-project/specs-actors/support/testing" + ds "github.com/ipfs/go-datastore" + ds_sync "github.com/ipfs/go-datastore/sync" + "github.com/stretchr/testify/require" +) + +func TestPaychSettle(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + expch := tutils.NewIDAddr(t, 100) + expch2 := tutils.NewIDAddr(t, 101) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + amt := big.NewInt(10) + _, mcid, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + time.Sleep(10 * time.Millisecond) + + // Send channel create response + response := testChannelResponse(t, expch) + pchapi.receiveMsgResponse(mcid, response) + + // Get the channel address + ch, err := mgr.GetPaychWaitReady(ctx, mcid) + require.NoError(t, err) + require.Equal(t, expch, ch) + + // Settle the channel + _, err = mgr.Settle(ctx, ch) + require.NoError(t, err) + + // Send another request for funds to the same from/to + // (should create a new channel because the previous channel + // is settling) + amt2 := big.NewInt(5) + _, mcid2, err := mgr.GetPaych(ctx, from, to, amt2) + require.NoError(t, err) + require.NotEqual(t, cid.Undef, mcid2) + + time.Sleep(10 * time.Millisecond) + + // Send new channel create response + response2 := testChannelResponse(t, expch2) + pchapi.receiveMsgResponse(mcid2, response2) + + time.Sleep(10 * time.Millisecond) + + // Make sure the new channel is different from the old channel + ch2, err := mgr.GetPaychWaitReady(ctx, mcid2) + require.NoError(t, err) + require.NotEqual(t, ch, ch2) + + // There should now be two channels + cis, err := mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 2) +} diff --git a/paychmgr/simple.go b/paychmgr/simple.go index 0d0075d62..67b5a4f41 100644 --- a/paychmgr/simple.go +++ b/paychmgr/simple.go @@ -3,6 +3,13 @@ package paychmgr import ( "bytes" "context" + "fmt" + + "golang.org/x/sync/errgroup" + + "github.com/filecoin-project/lotus/api" + + "github.com/filecoin-project/specs-actors/actors/abi/big" "github.com/filecoin-project/specs-actors/actors/builtin" init_ "github.com/filecoin-project/specs-actors/actors/builtin/init" @@ -17,7 +24,209 @@ import ( "github.com/filecoin-project/lotus/chain/types" ) -func (pm *Manager) createPaych(ctx context.Context, from, to address.Address, amt types.BigInt) (cid.Cid, error) { +type paychApi interface { + StateWaitMsg(ctx context.Context, msg cid.Cid, confidence uint64) (*api.MsgLookup, error) + MpoolPushMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error) +} + +// paychFundsRes is the response to a create channel or add funds request +type paychFundsRes struct { + channel address.Address + mcid cid.Cid + err error +} + +type onCompleteFn func(*paychFundsRes) + +// fundsReq is a request to create a channel or add funds to a channel +type fundsReq struct { + ctx context.Context + from address.Address + to address.Address + amt types.BigInt + onComplete onCompleteFn +} + +// getPaych ensures that a channel exists between the from and to addresses, +// and adds the given amount of funds. +// If the channel does not exist a create channel message is sent and the +// message CID is returned. +// If the channel does exist an add funds message is sent and both the channel +// address and message CID are returned. +// If there is an in progress operation (create channel / add funds), getPaych +// blocks until the previous operation completes, then returns both the channel +// address and the CID of the new add funds message. +// If an operation returns an error, subsequent waiting operations will still +// be attempted. +func (ca *channelAccessor) getPaych(ctx context.Context, from, to address.Address, amt types.BigInt) (address.Address, cid.Cid, error) { + // Add the request to add funds to a queue and wait for the result + promise := ca.enqueue(&fundsReq{ctx: ctx, from: from, to: to, amt: amt}) + select { + case res := <-promise: + return res.channel, res.mcid, res.err + case <-ctx.Done(): + return address.Undef, cid.Undef, ctx.Err() + } +} + +// Queue up an add funds operation +func (ca *channelAccessor) enqueue(task *fundsReq) chan *paychFundsRes { + promise := make(chan *paychFundsRes) + task.onComplete = func(res *paychFundsRes) { + select { + case <-task.ctx.Done(): + case promise <- res: + } + } + + ca.lk.Lock() + defer ca.lk.Unlock() + + ca.fundsReqQueue = append(ca.fundsReqQueue, task) + go ca.processNextQueueItem() + + return promise +} + +// Run the operation at the head of the queue +func (ca *channelAccessor) processNextQueueItem() { + ca.lk.Lock() + defer ca.lk.Unlock() + + if len(ca.fundsReqQueue) == 0 { + return + } + + head := ca.fundsReqQueue[0] + res := ca.processTask(head.ctx, head.from, head.to, head.amt, head.onComplete) + + // If the task is waiting on an external event (eg something to appear on + // chain) it will return nil + if res == nil { + // Stop processing the fundsReqQueue and wait. When the event occurs it will + // call processNextQueueItem() again + return + } + + // The task has finished processing so clean it up + ca.fundsReqQueue[0] = nil // allow GC of element + ca.fundsReqQueue = ca.fundsReqQueue[1:] + + // Call the task callback with its results + head.onComplete(res) + + // Process the next task + if len(ca.fundsReqQueue) > 0 { + go ca.processNextQueueItem() + } +} + +// msgWaitComplete is called when the message for a previous task is confirmed +// or there is an error. +func (ca *channelAccessor) msgWaitComplete(mcid cid.Cid, err error, cb onCompleteFn) { + ca.lk.Lock() + defer ca.lk.Unlock() + + // Save the message result to the store + dserr := ca.store.SaveMessageResult(mcid, err) + if dserr != nil { + log.Errorf("saving message result: %s", dserr) + } + + // Call the onComplete callback + ca.callOnComplete(mcid, err, cb) + + // Inform listeners that the message has completed + ca.msgListeners.fireMsgComplete(mcid, err) + + // The queue may have been waiting for msg completion to proceed, so + // process the next queue item + if len(ca.fundsReqQueue) > 0 { + go ca.processNextQueueItem() + } +} + +// callOnComplete calls the onComplete callback for a task +func (ca *channelAccessor) callOnComplete(mcid cid.Cid, err error, cb onCompleteFn) { + if cb == nil { + return + } + + if err != nil { + go cb(&paychFundsRes{err: err}) + return + } + + // Get the channel address + ci, storeErr := ca.store.ByMessageCid(mcid) + if storeErr != nil { + log.Errorf("getting channel by message cid: %s", err) + go cb(&paychFundsRes{err: storeErr}) + return + } + + if ci.Channel == nil { + panic("channel address is nil when calling onComplete callback") + } + + go cb(&paychFundsRes{channel: *ci.Channel, mcid: mcid, err: err}) +} + +// processTask checks the state of the channel and takes appropriate action +// (see description of getPaych). +// Note that processTask may be called repeatedly in the same state, and should +// return nil if there is no state change to be made (eg when waiting for a +// message to be confirmed on chain) +func (ca *channelAccessor) processTask( + ctx context.Context, + from address.Address, + to address.Address, + amt types.BigInt, + onComplete onCompleteFn, +) *paychFundsRes { + // Get the payment channel for the from/to addresses. + // Note: It's ok if we get ErrChannelNotTracked. It just means we need to + // create a channel. + channelInfo, err := ca.store.OutboundActiveByFromTo(from, to) + if err != nil && err != ErrChannelNotTracked { + return &paychFundsRes{err: err} + } + + // If a channel has not yet been created, create one. + if channelInfo == nil { + mcid, err := ca.createPaych(ctx, from, to, amt, onComplete) + if err != nil { + return &paychFundsRes{err: err} + } + + return &paychFundsRes{mcid: mcid} + } + + // If the create channel message has been sent but the channel hasn't + // been created on chain yet + if channelInfo.CreateMsg != nil { + // Wait for the channel to be created before trying again + return nil + } + + // If an add funds message was sent to the chain but hasn't been confirmed + // on chain yet + if channelInfo.AddFundsMsg != nil { + // Wait for the add funds message to be confirmed before trying again + return nil + } + + // We need to add more funds, so send an add funds message to + // cover the amount for this request + mcid, err := ca.addFunds(ctx, channelInfo, amt, onComplete) + if err != nil { + return &paychFundsRes{err: err} + } + return &paychFundsRes{channel: *channelInfo.Channel, mcid: *mcid} +} + +// createPaych sends a message to create the channel and returns the message cid +func (ca *channelAccessor) createPaych(ctx context.Context, from, to address.Address, amt types.BigInt, cb onCompleteFn) (cid.Cid, error) { params, aerr := actors.SerializeParams(&paych.ConstructorParams{From: from, To: to}) if aerr != nil { return cid.Undef, aerr @@ -41,106 +250,303 @@ func (pm *Manager) createPaych(ctx context.Context, from, to address.Address, am GasPrice: types.NewInt(0), } - smsg, err := pm.mpool.MpoolPushMessage(ctx, msg) + smsg, err := ca.api.MpoolPushMessage(ctx, msg) if err != nil { return cid.Undef, xerrors.Errorf("initializing paych actor: %w", err) } mcid := smsg.Cid() - go pm.waitForPaychCreateMsg(ctx, mcid) + + // Create a new channel in the store + ci, err := ca.store.CreateChannel(from, to, mcid, amt) + if err != nil { + log.Errorf("creating channel: %s", err) + return cid.Undef, err + } + + // Wait for the channel to be created on chain + go ca.waitForPaychCreateMsg(ci.ChannelID, mcid, cb) + return mcid, nil } -// WaitForPaychCreateMsg waits for mcid to appear on chain and returns the robust address of the +// waitForPaychCreateMsg waits for mcid to appear on chain and stores the robust address of the // created payment channel -// TODO: wait outside the store lock! -// (tricky because we need to setup channel tracking before we know its address) -func (pm *Manager) waitForPaychCreateMsg(ctx context.Context, mcid cid.Cid) { - defer pm.store.lk.Unlock() - mwait, err := pm.state.StateWaitMsg(ctx, mcid, build.MessageConfidence) +func (ca *channelAccessor) waitForPaychCreateMsg(channelID string, mcid cid.Cid, cb onCompleteFn) { + err := ca.waitPaychCreateMsg(channelID, mcid) + ca.msgWaitComplete(mcid, err, cb) +} + +func (ca *channelAccessor) waitPaychCreateMsg(channelID string, mcid cid.Cid) error { + mwait, err := ca.api.StateWaitMsg(ca.waitCtx, mcid, build.MessageConfidence) if err != nil { log.Errorf("wait msg: %w", err) - return + return err } + // If channel creation failed if mwait.Receipt.ExitCode != 0 { - log.Errorf("payment channel creation failed (exit code %d)", mwait.Receipt.ExitCode) - return + ca.lk.Lock() + defer ca.lk.Unlock() + + // Channel creation failed, so remove the channel from the datastore + dserr := ca.store.RemoveChannel(channelID) + if dserr != nil { + log.Errorf("failed to remove channel %s: %s", channelID, dserr) + } + + err := xerrors.Errorf("payment channel creation failed (exit code %d)", mwait.Receipt.ExitCode) + log.Error(err) + return err } var decodedReturn init_.ExecReturn err = decodedReturn.UnmarshalCBOR(bytes.NewReader(mwait.Receipt.Return)) if err != nil { log.Error(err) - return - } - paychaddr := decodedReturn.RobustAddress - - ci, err := pm.loadStateChannelInfo(ctx, paychaddr, DirOutbound) - if err != nil { - log.Errorf("loading channel info: %w", err) - return + return err } - if err := pm.store.trackChannel(ci); err != nil { - log.Errorf("tracking channel: %w", err) - } + ca.lk.Lock() + defer ca.lk.Unlock() + + // Store robust address of channel + ca.mutateChannelInfo(channelID, func(channelInfo *ChannelInfo) { + channelInfo.Channel = &decodedReturn.RobustAddress + channelInfo.Amount = channelInfo.PendingAmount + channelInfo.PendingAmount = big.NewInt(0) + channelInfo.CreateMsg = nil + }) + + return nil } -func (pm *Manager) addFunds(ctx context.Context, ch address.Address, from address.Address, amt types.BigInt) (cid.Cid, error) { +// addFunds sends a message to add funds to the channel and returns the message cid +func (ca *channelAccessor) addFunds(ctx context.Context, channelInfo *ChannelInfo, amt types.BigInt, cb onCompleteFn) (*cid.Cid, error) { msg := &types.Message{ - To: ch, - From: from, + To: *channelInfo.Channel, + From: channelInfo.Control, Value: amt, Method: 0, GasLimit: 0, GasPrice: types.NewInt(0), } - smsg, err := pm.mpool.MpoolPushMessage(ctx, msg) + smsg, err := ca.api.MpoolPushMessage(ctx, msg) if err != nil { - return cid.Undef, err + return nil, err } mcid := smsg.Cid() - go pm.waitForAddFundsMsg(ctx, mcid) - return mcid, nil + + // Store the add funds message CID on the channel + ca.mutateChannelInfo(channelInfo.ChannelID, func(ci *ChannelInfo) { + ci.PendingAmount = amt + ci.AddFundsMsg = &mcid + }) + + // Store a reference from the message CID to the channel, so that we can + // look up the channel from the message CID + err = ca.store.SaveNewMessage(channelInfo.ChannelID, mcid) + if err != nil { + log.Errorf("saving add funds message CID %s: %s", mcid, err) + } + + go ca.waitForAddFundsMsg(channelInfo.ChannelID, mcid, cb) + + return &mcid, nil } -// WaitForAddFundsMsg waits for mcid to appear on chain and returns error, if any -// TODO: wait outside the store lock! -// (tricky because we need to setup channel tracking before we know it's address) -func (pm *Manager) waitForAddFundsMsg(ctx context.Context, mcid cid.Cid) { - defer pm.store.lk.Unlock() - mwait, err := pm.state.StateWaitMsg(ctx, mcid, build.MessageConfidence) +// waitForAddFundsMsg waits for mcid to appear on chain and returns error, if any +func (ca *channelAccessor) waitForAddFundsMsg(channelID string, mcid cid.Cid, cb onCompleteFn) { + err := ca.waitAddFundsMsg(channelID, mcid) + ca.msgWaitComplete(mcid, err, cb) +} + +func (ca *channelAccessor) waitAddFundsMsg(channelID string, mcid cid.Cid) error { + mwait, err := ca.api.StateWaitMsg(ca.waitCtx, mcid, build.MessageConfidence) if err != nil { log.Error(err) + return err } if mwait.Receipt.ExitCode != 0 { - log.Errorf("voucher channel creation failed: adding funds (exit code %d)", mwait.Receipt.ExitCode) + err := xerrors.Errorf("voucher channel creation failed: adding funds (exit code %d)", mwait.Receipt.ExitCode) + log.Error(err) + + ca.lk.Lock() + defer ca.lk.Unlock() + + ca.mutateChannelInfo(channelID, func(channelInfo *ChannelInfo) { + channelInfo.PendingAmount = big.NewInt(0) + channelInfo.AddFundsMsg = nil + }) + + return err + } + + ca.lk.Lock() + defer ca.lk.Unlock() + + // Store updated amount + ca.mutateChannelInfo(channelID, func(channelInfo *ChannelInfo) { + channelInfo.Amount = types.BigAdd(channelInfo.Amount, channelInfo.PendingAmount) + channelInfo.PendingAmount = big.NewInt(0) + channelInfo.AddFundsMsg = nil + }) + + return nil +} + +// Change the state of the channel in the store +func (ca *channelAccessor) mutateChannelInfo(channelID string, mutate func(*ChannelInfo)) { + channelInfo, err := ca.store.ByChannelID(channelID) + + // If there's an error reading or writing to the store just log an error. + // For now we're assuming it's unlikely to happen in practice. + // Later we may want to implement a transactional approach, whereby + // we record to the store that we're going to send a message, send + // the message, and then record that the message was sent. + if err != nil { + log.Errorf("Error reading channel info from store: %s", err) + } + + mutate(channelInfo) + + err = ca.store.putChannelInfo(channelInfo) + if err != nil { + log.Errorf("Error writing channel info to store: %s", err) } } -func (pm *Manager) GetPaych(ctx context.Context, from, to address.Address, ensureFree types.BigInt) (address.Address, cid.Cid, error) { - pm.store.lk.Lock() // unlock only on err; wait funcs will defer unlock - var mcid cid.Cid - ch, err := pm.store.findChan(func(ci *ChannelInfo) bool { - if ci.Direction != DirOutbound { - return false +// restartPending checks the datastore to see if there are any channels that +// have outstanding create / add funds messages, and if so, waits on the +// messages. +// Outstanding messages can occur if a create / add funds message was sent and +// then the system was shut down or crashed before the result was received. +func (pm *Manager) restartPending() error { + cis, err := pm.store.WithPendingAddFunds() + if err != nil { + return err + } + + group := errgroup.Group{} + for _, chanInfo := range cis { + ci := chanInfo + if ci.CreateMsg != nil { + group.Go(func() error { + ca, err := pm.accessorByFromTo(ci.Control, ci.Target) + if err != nil { + return xerrors.Errorf("error initializing payment channel manager %s -> %s: %s", ci.Control, ci.Target, err) + } + go ca.waitForPaychCreateMsg(ci.ChannelID, *ci.CreateMsg, nil) + return nil + }) + } else if ci.AddFundsMsg != nil { + group.Go(func() error { + ca, err := pm.accessorByAddress(*ci.Channel) + if err != nil { + return xerrors.Errorf("error initializing payment channel manager %s: %s", ci.Channel, err) + } + go ca.waitForAddFundsMsg(ci.ChannelID, *ci.AddFundsMsg, nil) + return nil + }) } - return ci.Control == from && ci.Target == to - }) - if err != nil { - pm.store.lk.Unlock() - return address.Undef, cid.Undef, xerrors.Errorf("findChan: %w", err) } - if ch != address.Undef { - // TODO: Track available funds - mcid, err = pm.addFunds(ctx, ch, from, ensureFree) - } else { - mcid, err = pm.createPaych(ctx, from, to, ensureFree) - } - if err != nil { - pm.store.lk.Unlock() - } - return ch, mcid, err + + return group.Wait() +} + +// getPaychWaitReady waits for a the response to the message with the given cid +func (ca *channelAccessor) getPaychWaitReady(ctx context.Context, mcid cid.Cid) (address.Address, error) { + ca.lk.Lock() + + // First check if the message has completed + msgInfo, err := ca.store.GetMessage(mcid) + if err != nil { + ca.lk.Unlock() + + return address.Undef, err + } + + // If the create channel / add funds message failed, return an error + if len(msgInfo.Err) > 0 { + ca.lk.Unlock() + + return address.Undef, xerrors.New(msgInfo.Err) + } + + // If the message has completed successfully + if msgInfo.Received { + ca.lk.Unlock() + + // Get the channel address + ci, err := ca.store.ByMessageCid(mcid) + if err != nil { + return address.Undef, err + } + + if ci.Channel == nil { + panic(fmt.Sprintf("create / add funds message %s succeeded but channelInfo.Channel is nil", mcid)) + } + return *ci.Channel, nil + } + + // The message hasn't completed yet so wait for it to complete + promise := ca.msgPromise(ctx, mcid) + + // Unlock while waiting + ca.lk.Unlock() + + select { + case res := <-promise: + return res.channel, res.err + case <-ctx.Done(): + return address.Undef, ctx.Err() + } +} + +type onMsgRes struct { + channel address.Address + err error +} + +// msgPromise returns a channel that receives the result of the message with +// the given CID +func (ca *channelAccessor) msgPromise(ctx context.Context, mcid cid.Cid) chan onMsgRes { + promise := make(chan onMsgRes) + triggerUnsub := make(chan struct{}) + sub := ca.msgListeners.onMsg(mcid, func(err error) { + close(triggerUnsub) + + // Use a go-routine so as not to block the event handler loop + go func() { + res := onMsgRes{err: err} + if res.err == nil { + // Get the channel associated with the message cid + ci, err := ca.store.ByMessageCid(mcid) + if err != nil { + res.err = err + } else { + res.channel = *ci.Channel + } + } + + // Pass the result to the caller + select { + case promise <- res: + case <-ctx.Done(): + } + }() + }) + + // Unsubscribe when the message is received or the context is done + go func() { + select { + case <-ctx.Done(): + case <-triggerUnsub: + } + + ca.msgListeners.unsubscribe(sub) + }() + + return promise } diff --git a/paychmgr/state.go b/paychmgr/state.go index 7d06a35a4..9ba2740e6 100644 --- a/paychmgr/state.go +++ b/paychmgr/state.go @@ -3,20 +3,21 @@ package paychmgr import ( "context" - "github.com/filecoin-project/specs-actors/actors/abi/big" - "github.com/filecoin-project/specs-actors/actors/builtin/account" "github.com/filecoin-project/go-address" "github.com/filecoin-project/specs-actors/actors/builtin/paych" - xerrors "golang.org/x/xerrors" "github.com/filecoin-project/lotus/chain/types" ) -func (pm *Manager) loadPaychState(ctx context.Context, ch address.Address) (*types.Actor, *paych.State, error) { +type stateAccessor struct { + sm StateManagerApi +} + +func (ca *stateAccessor) loadPaychState(ctx context.Context, ch address.Address) (*types.Actor, *paych.State, error) { var pcast paych.State - act, err := pm.sm.LoadActorState(ctx, ch, &pcast, nil) + act, err := ca.sm.LoadActorState(ctx, ch, &pcast, nil) if err != nil { return nil, nil, err } @@ -24,26 +25,26 @@ func (pm *Manager) loadPaychState(ctx context.Context, ch address.Address) (*typ return act, &pcast, nil } -func (pm *Manager) loadStateChannelInfo(ctx context.Context, ch address.Address, dir uint64) (*ChannelInfo, error) { - _, st, err := pm.loadPaychState(ctx, ch) +func (ca *stateAccessor) loadStateChannelInfo(ctx context.Context, ch address.Address, dir uint64) (*ChannelInfo, error) { + _, st, err := ca.loadPaychState(ctx, ch) if err != nil { return nil, err } var account account.State - _, err = pm.sm.LoadActorState(ctx, st.From, &account, nil) + _, err = ca.sm.LoadActorState(ctx, st.From, &account, nil) if err != nil { return nil, err } from := account.Address - _, err = pm.sm.LoadActorState(ctx, st.To, &account, nil) + _, err = ca.sm.LoadActorState(ctx, st.To, &account, nil) if err != nil { return nil, err } to := account.Address ci := &ChannelInfo{ - Channel: ch, + Channel: &ch, Direction: dir, NextLane: nextLaneFromState(st), } @@ -72,80 +73,3 @@ func nextLaneFromState(st *paych.State) uint64 { } return maxLane + 1 } - -// laneState gets the LaneStates from chain, then applies all vouchers in -// the data store over the chain state -func (pm *Manager) laneState(state *paych.State, ch address.Address) (map[uint64]*paych.LaneState, error) { - // TODO: we probably want to call UpdateChannelState with all vouchers to be fully correct - // (but technically dont't need to) - laneStates := make(map[uint64]*paych.LaneState, len(state.LaneStates)) - - // Get the lane state from the chain - for _, laneState := range state.LaneStates { - laneStates[laneState.ID] = laneState - } - - // Apply locally stored vouchers - vouchers, err := pm.store.VouchersForPaych(ch) - if err != nil && err != ErrChannelNotTracked { - return nil, err - } - - for _, v := range vouchers { - for range v.Voucher.Merges { - return nil, xerrors.Errorf("paych merges not handled yet") - } - - // If there's a voucher for a lane that isn't in chain state just - // create it - ls, ok := laneStates[v.Voucher.Lane] - if !ok { - ls = &paych.LaneState{ - ID: v.Voucher.Lane, - Redeemed: types.NewInt(0), - Nonce: 0, - } - laneStates[v.Voucher.Lane] = ls - } - - if v.Voucher.Nonce < ls.Nonce { - continue - } - - ls.Nonce = v.Voucher.Nonce - ls.Redeemed = v.Voucher.Amount - } - - return laneStates, nil -} - -// Get the total redeemed amount across all lanes, after applying the voucher -func (pm *Manager) totalRedeemedWithVoucher(laneStates map[uint64]*paych.LaneState, sv *paych.SignedVoucher) (big.Int, error) { - // TODO: merges - if len(sv.Merges) != 0 { - return big.Int{}, xerrors.Errorf("dont currently support paych lane merges") - } - - total := big.NewInt(0) - for _, ls := range laneStates { - total = big.Add(total, ls.Redeemed) - } - - lane, ok := laneStates[sv.Lane] - if ok { - // If the voucher is for an existing lane, and the voucher nonce - // and is higher than the lane nonce - if sv.Nonce > lane.Nonce { - // Add the delta between the redeemed amount and the voucher - // amount to the total - delta := big.Sub(sv.Amount, lane.Redeemed) - total = big.Add(total, delta) - } - } else { - // If the voucher is *not* for an existing lane, just add its - // value (implicitly a new lane will be created for the voucher) - total = big.Add(total, sv.Amount) - } - - return total, nil -} diff --git a/paychmgr/store.go b/paychmgr/store.go index 66a514feb..d7c6e82e7 100644 --- a/paychmgr/store.go +++ b/paychmgr/store.go @@ -4,14 +4,16 @@ import ( "bytes" "errors" "fmt" - "strings" - "sync" + + "github.com/google/uuid" + + "github.com/filecoin-project/lotus/chain/types" "github.com/filecoin-project/specs-actors/actors/builtin/paych" + "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore" "github.com/ipfs/go-datastore/namespace" dsq "github.com/ipfs/go-datastore/query" - "golang.org/x/xerrors" "github.com/filecoin-project/go-address" cborrpc "github.com/filecoin-project/go-cbor-util" @@ -22,8 +24,6 @@ import ( var ErrChannelNotTracked = errors.New("channel not tracked") type Store struct { - lk sync.Mutex // TODO: this can be split per paych - ds datastore.Batching } @@ -39,85 +39,107 @@ const ( DirOutbound = 2 ) +const ( + dsKeyChannelInfo = "ChannelInfo" + dsKeyMsgCid = "MsgCid" +) + type VoucherInfo struct { Voucher *paych.SignedVoucher Proof []byte } +// ChannelInfo keeps track of information about a channel type ChannelInfo struct { - Channel address.Address + // ChannelID is a uuid set at channel creation + ChannelID string + // Channel address - may be nil if the channel hasn't been created yet + Channel *address.Address + // Control is the address of the account that created the channel Control address.Address - Target address.Address - + // Target is the address of the account on the other end of the channel + Target address.Address + // Direction indicates if the channel is inbound (this node is the Target) + // or outbound (this node is the Control) Direction uint64 - Vouchers []*VoucherInfo - NextLane uint64 + // Vouchers is a list of all vouchers sent on the channel + Vouchers []*VoucherInfo + // NextLane is the number of the next lane that should be used when the + // client requests a new lane (eg to create a voucher for a new deal) + NextLane uint64 + // Amount added to the channel. + // Note: This amount is only used by GetPaych to keep track of how much + // has locally been added to the channel. It should reflect the channel's + // Balance on chain as long as all operations occur on the same datastore. + Amount types.BigInt + // PendingAmount is the amount that we're awaiting confirmation of + PendingAmount types.BigInt + // CreateMsg is the CID of a pending create message (while waiting for confirmation) + CreateMsg *cid.Cid + // AddFundsMsg is the CID of a pending add funds message (while waiting for confirmation) + AddFundsMsg *cid.Cid + // Settling indicates whether the channel has entered into the settling state + Settling bool } -func dskeyForChannel(addr address.Address) datastore.Key { - return datastore.NewKey(addr.String()) -} - -func (ps *Store) putChannelInfo(ci *ChannelInfo) error { - k := dskeyForChannel(ci.Channel) - - b, err := cborrpc.Dump(ci) - if err != nil { - return err - } - - return ps.ds.Put(k, b) -} - -func (ps *Store) getChannelInfo(addr address.Address) (*ChannelInfo, error) { - k := dskeyForChannel(addr) - - b, err := ps.ds.Get(k) - if err == datastore.ErrNotFound { - return nil, ErrChannelNotTracked - } - if err != nil { - return nil, err - } - - var ci ChannelInfo - if err := ci.UnmarshalCBOR(bytes.NewReader(b)); err != nil { - return nil, err - } - - return &ci, nil -} - -func (ps *Store) TrackChannel(ch *ChannelInfo) error { - ps.lk.Lock() - defer ps.lk.Unlock() - - return ps.trackChannel(ch) -} - -func (ps *Store) trackChannel(ch *ChannelInfo) error { - _, err := ps.getChannelInfo(ch.Channel) +// TrackChannel stores a channel, returning an error if the channel was already +// being tracked +func (ps *Store) TrackChannel(ci *ChannelInfo) error { + _, err := ps.ByAddress(*ci.Channel) switch err { default: return err case nil: - return fmt.Errorf("already tracking channel: %s", ch.Channel) + return fmt.Errorf("already tracking channel: %s", ci.Channel) case ErrChannelNotTracked: - return ps.putChannelInfo(ch) + return ps.putChannelInfo(ci) } } +// ListChannels returns the addresses of all channels that have been created func (ps *Store) ListChannels() ([]address.Address, error) { - ps.lk.Lock() - defer ps.lk.Unlock() + cis, err := ps.findChans(func(ci *ChannelInfo) bool { + return ci.Channel != nil + }, 0) + if err != nil { + return nil, err + } - res, err := ps.ds.Query(dsq.Query{KeysOnly: true}) + addrs := make([]address.Address, 0, len(cis)) + for _, ci := range cis { + addrs = append(addrs, *ci.Channel) + } + + return addrs, nil +} + +// findChan finds a single channel using the given filter. +// If there isn't a channel that matches the filter, returns ErrChannelNotTracked +func (ps *Store) findChan(filter func(ci *ChannelInfo) bool) (*ChannelInfo, error) { + cis, err := ps.findChans(filter, 1) + if err != nil { + return nil, err + } + + if len(cis) == 0 { + return nil, ErrChannelNotTracked + } + + return &cis[0], err +} + +// findChans loops over all channels, only including those that pass the filter. +// max is the maximum number of channels to return. Set to zero to return unlimited channels. +func (ps *Store) findChans(filter func(*ChannelInfo) bool, max int) ([]ChannelInfo, error) { + res, err := ps.ds.Query(dsq.Query{Prefix: dsKeyChannelInfo}) if err != nil { return nil, err } defer res.Close() //nolint:errcheck - var out []address.Address + var stored ChannelInfo + var matches []ChannelInfo + for { res, ok := res.NextSync() if !ok { @@ -128,60 +150,31 @@ func (ps *Store) ListChannels() ([]address.Address, error) { return nil, err } - addr, err := address.NewFromString(strings.TrimPrefix(res.Key, "/")) + ci, err := unmarshallChannelInfo(&stored, res.Value) if err != nil { - return nil, xerrors.Errorf("failed reading paych key (%q) from datastore: %w", res.Key, err) + return nil, err } - out = append(out, addr) - } - - return out, nil -} - -func (ps *Store) findChan(filter func(*ChannelInfo) bool) (address.Address, error) { - res, err := ps.ds.Query(dsq.Query{}) - if err != nil { - return address.Undef, err - } - defer res.Close() //nolint:errcheck - - var ci ChannelInfo - - for { - res, ok := res.NextSync() - if !ok { - break - } - - if res.Error != nil { - return address.Undef, err - } - - if err := ci.UnmarshalCBOR(bytes.NewReader(res.Value)); err != nil { - return address.Undef, err - } - - if !filter(&ci) { + if !filter(ci) { continue } - addr, err := address.NewFromString(strings.TrimPrefix(res.Key, "/")) - if err != nil { - return address.Undef, xerrors.Errorf("failed reading paych key (%q) from datastore: %w", res.Key, err) - } + matches = append(matches, *ci) - return addr, nil + // If we've reached the maximum number of matches, return. + // Note that if max is zero we return an unlimited number of matches + // because len(matches) will always be at least 1. + if len(matches) == max { + return matches, nil + } } - return address.Undef, nil + return matches, nil } +// AllocateLane allocates a new lane for the given channel func (ps *Store) AllocateLane(ch address.Address) (uint64, error) { - ps.lk.Lock() - defer ps.lk.Unlock() - - ci, err := ps.getChannelInfo(ch) + ci, err := ps.ByAddress(ch) if err != nil { return 0, err } @@ -192,11 +185,231 @@ func (ps *Store) AllocateLane(ch address.Address) (uint64, error) { return out, ps.putChannelInfo(ci) } +// VouchersForPaych gets the vouchers for the given channel func (ps *Store) VouchersForPaych(ch address.Address) ([]*VoucherInfo, error) { - ci, err := ps.getChannelInfo(ch) + ci, err := ps.ByAddress(ch) if err != nil { return nil, err } return ci.Vouchers, nil } + +// ByAddress gets the channel that matches the given address +func (ps *Store) ByAddress(addr address.Address) (*ChannelInfo, error) { + return ps.findChan(func(ci *ChannelInfo) bool { + return ci.Channel != nil && *ci.Channel == addr + }) +} + +// MsgInfo stores information about a create channel / add funds message +// that has been sent +type MsgInfo struct { + // ChannelID links the message to a channel + ChannelID string + // MsgCid is the CID of the message + MsgCid cid.Cid + // Received indicates whether a response has been received + Received bool + // Err is the error received in the response + Err string +} + +// The datastore key used to identify the message +func dskeyForMsg(mcid cid.Cid) datastore.Key { + return datastore.KeyWithNamespaces([]string{dsKeyMsgCid, mcid.String()}) +} + +// SaveNewMessage is called when a message is sent +func (ps *Store) SaveNewMessage(channelID string, mcid cid.Cid) error { + k := dskeyForMsg(mcid) + + b, err := cborrpc.Dump(&MsgInfo{ChannelID: channelID, MsgCid: mcid}) + if err != nil { + return err + } + + return ps.ds.Put(k, b) +} + +// SaveMessageResult is called when the result of a message is received +func (ps *Store) SaveMessageResult(mcid cid.Cid, msgErr error) error { + minfo, err := ps.GetMessage(mcid) + if err != nil { + return err + } + + k := dskeyForMsg(mcid) + minfo.Received = true + if msgErr != nil { + minfo.Err = msgErr.Error() + } + + b, err := cborrpc.Dump(minfo) + if err != nil { + return err + } + + return ps.ds.Put(k, b) +} + +// ByMessageCid gets the channel associated with a message +func (ps *Store) ByMessageCid(mcid cid.Cid) (*ChannelInfo, error) { + minfo, err := ps.GetMessage(mcid) + if err != nil { + return nil, err + } + + ci, err := ps.findChan(func(ci *ChannelInfo) bool { + return ci.ChannelID == minfo.ChannelID + }) + if err != nil { + return nil, err + } + + return ci, err +} + +// GetMessage gets the message info for a given message CID +func (ps *Store) GetMessage(mcid cid.Cid) (*MsgInfo, error) { + k := dskeyForMsg(mcid) + + val, err := ps.ds.Get(k) + if err != nil { + return nil, err + } + + var minfo MsgInfo + if err := minfo.UnmarshalCBOR(bytes.NewReader(val)); err != nil { + return nil, err + } + + return &minfo, nil +} + +// OutboundActiveByFromTo looks for outbound channels that have not been +// settled, with the given from / to addresses +func (ps *Store) OutboundActiveByFromTo(from address.Address, to address.Address) (*ChannelInfo, error) { + return ps.findChan(func(ci *ChannelInfo) bool { + if ci.Direction != DirOutbound { + return false + } + if ci.Settling { + return false + } + return ci.Control == from && ci.Target == to + }) +} + +// WithPendingAddFunds is used on startup to find channels for which a +// create channel or add funds message has been sent, but lotus shut down +// before the response was received. +func (ps *Store) WithPendingAddFunds() ([]ChannelInfo, error) { + return ps.findChans(func(ci *ChannelInfo) bool { + if ci.Direction != DirOutbound { + return false + } + return ci.CreateMsg != nil || ci.AddFundsMsg != nil + }, 0) +} + +// ByChannelID gets channel info by channel ID +func (ps *Store) ByChannelID(channelID string) (*ChannelInfo, error) { + var stored ChannelInfo + + res, err := ps.ds.Get(dskeyForChannel(channelID)) + if err != nil { + if err == datastore.ErrNotFound { + return nil, ErrChannelNotTracked + } + return nil, err + } + + return unmarshallChannelInfo(&stored, res) +} + +// CreateChannel creates an outbound channel for the given from / to, ensuring +// it has a higher sequence number than any existing channel with the same from / to +func (ps *Store) CreateChannel(from address.Address, to address.Address, createMsgCid cid.Cid, amt types.BigInt) (*ChannelInfo, error) { + ci := &ChannelInfo{ + Direction: DirOutbound, + NextLane: 0, + Control: from, + Target: to, + CreateMsg: &createMsgCid, + PendingAmount: amt, + } + + // Save the new channel + err := ps.putChannelInfo(ci) + if err != nil { + return nil, err + } + + // Save a reference to the create message + err = ps.SaveNewMessage(ci.ChannelID, createMsgCid) + if err != nil { + return nil, err + } + + return ci, err +} + +// RemoveChannel removes the channel with the given channel ID +func (ps *Store) RemoveChannel(channelID string) error { + return ps.ds.Delete(dskeyForChannel(channelID)) +} + +// The datastore key used to identify the channel info +func dskeyForChannel(channelID string) datastore.Key { + return datastore.KeyWithNamespaces([]string{dsKeyChannelInfo, channelID}) +} + +// putChannelInfo stores the channel info in the datastore +func (ps *Store) putChannelInfo(ci *ChannelInfo) error { + if len(ci.ChannelID) == 0 { + ci.ChannelID = uuid.New().String() + } + k := dskeyForChannel(ci.ChannelID) + + b, err := marshallChannelInfo(ci) + if err != nil { + return err + } + + return ps.ds.Put(k, b) +} + +// TODO: This is a hack to get around not being able to CBOR marshall a nil +// address.Address. It's been fixed in address.Address but we need to wait +// for the change to propagate to specs-actors before we can remove this hack. +var emptyAddr address.Address + +func init() { + addr, err := address.NewActorAddress([]byte("empty")) + if err != nil { + panic(err) + } + emptyAddr = addr +} + +func marshallChannelInfo(ci *ChannelInfo) ([]byte, error) { + // See note above about CBOR marshalling address.Address + if ci.Channel == nil { + ci.Channel = &emptyAddr + } + return cborrpc.Dump(ci) +} + +func unmarshallChannelInfo(stored *ChannelInfo, value []byte) (*ChannelInfo, error) { + if err := stored.UnmarshalCBOR(bytes.NewReader(value)); err != nil { + return nil, err + } + + // See note above about CBOR marshalling address.Address + if stored.Channel != nil && *stored.Channel == emptyAddr { + stored.Channel = nil + } + + return stored, nil +} diff --git a/paychmgr/store_test.go b/paychmgr/store_test.go index 094226464..65be6f1b1 100644 --- a/paychmgr/store_test.go +++ b/paychmgr/store_test.go @@ -17,8 +17,9 @@ func TestStore(t *testing.T) { require.NoError(t, err) require.Len(t, addrs, 0) + ch := tutils.NewIDAddr(t, 100) ci := &ChannelInfo{ - Channel: tutils.NewIDAddr(t, 100), + Channel: &ch, Control: tutils.NewIDAddr(t, 101), Target: tutils.NewIDAddr(t, 102), @@ -26,8 +27,9 @@ func TestStore(t *testing.T) { Vouchers: []*VoucherInfo{{Voucher: nil, Proof: []byte{}}}, } + ch2 := tutils.NewIDAddr(t, 200) ci2 := &ChannelInfo{ - Channel: tutils.NewIDAddr(t, 200), + Channel: &ch2, Control: tutils.NewIDAddr(t, 201), Target: tutils.NewIDAddr(t, 202), @@ -55,7 +57,7 @@ func TestStore(t *testing.T) { require.Contains(t, addrsStrings(addrs), "t0200") // Request vouchers for channel - vouchers, err := store.VouchersForPaych(ci.Channel) + vouchers, err := store.VouchersForPaych(*ci.Channel) require.NoError(t, err) require.Len(t, vouchers, 1) @@ -64,12 +66,12 @@ func TestStore(t *testing.T) { require.Equal(t, err, ErrChannelNotTracked) // Allocate lane for channel - lane, err := store.AllocateLane(ci.Channel) + lane, err := store.AllocateLane(*ci.Channel) require.NoError(t, err) require.Equal(t, lane, uint64(0)) // Allocate next lane for channel - lane, err = store.AllocateLane(ci.Channel) + lane, err = store.AllocateLane(*ci.Channel) require.NoError(t, err) require.Equal(t, lane, uint64(1))