refactor: add wallet param to FundManager methods

This commit is contained in:
Dirk McCormick 2020-11-06 16:58:45 +01:00 committed by hannahhoward
parent 4d3cd7dcb8
commit 1182927fe5
4 changed files with 257 additions and 183 deletions

View File

@ -81,6 +81,7 @@ type DealProposals interface {
type PublishStorageDealsParams = market0.PublishStorageDealsParams type PublishStorageDealsParams = market0.PublishStorageDealsParams
type PublishStorageDealsReturn = market0.PublishStorageDealsReturn type PublishStorageDealsReturn = market0.PublishStorageDealsReturn
type VerifyDealsForActivationParams = market0.VerifyDealsForActivationParams type VerifyDealsForActivationParams = market0.VerifyDealsForActivationParams
type WithdrawBalanceParams = market0.WithdrawBalanceParams
type ClientDealProposal = market0.ClientDealProposal type ClientDealProposal = market0.ClientDealProposal

View File

@ -12,7 +12,7 @@ import (
var _ = xerrors.Errorf var _ = xerrors.Errorf
var lengthBufFundedAddressState = []byte{132} var lengthBufFundedAddressState = []byte{131}
func (t *FundedAddressState) MarshalCBOR(w io.Writer) error { func (t *FundedAddressState) MarshalCBOR(w io.Writer) error {
if t == nil { if t == nil {
@ -25,11 +25,6 @@ func (t *FundedAddressState) MarshalCBOR(w io.Writer) error {
scratch := make([]byte, 9) scratch := make([]byte, 9)
// t.Wallet (address.Address) (struct)
if err := t.Wallet.MarshalCBOR(w); err != nil {
return err
}
// t.Addr (address.Address) (struct) // t.Addr (address.Address) (struct)
if err := t.Addr.MarshalCBOR(w); err != nil { if err := t.Addr.MarshalCBOR(w); err != nil {
return err return err
@ -69,19 +64,10 @@ func (t *FundedAddressState) UnmarshalCBOR(r io.Reader) error {
return fmt.Errorf("cbor input should be of type array") return fmt.Errorf("cbor input should be of type array")
} }
if extra != 4 { if extra != 3 {
return fmt.Errorf("cbor input had wrong number of fields") return fmt.Errorf("cbor input had wrong number of fields")
} }
// t.Wallet (address.Address) (struct)
{
if err := t.Wallet.UnmarshalCBOR(br); err != nil {
return xerrors.Errorf("unmarshaling t.Wallet: %w", err)
}
}
// t.Addr (address.Address) (struct) // t.Addr (address.Address) (struct)
{ {

View File

@ -46,24 +46,22 @@ type FundManager struct {
ctx context.Context ctx context.Context
shutdown context.CancelFunc shutdown context.CancelFunc
api fundManagerAPI api fundManagerAPI
wallet address.Address
str *Store str *Store
lk sync.Mutex lk sync.Mutex
fundedAddrs map[address.Address]*fundedAddress fundedAddrs map[address.Address]*fundedAddress
} }
type waitSentinel cid.Cid type WaitSentinel cid.Cid
var waitSentinelUndef = waitSentinel(cid.Undef) var WaitSentinelUndef = WaitSentinel(cid.Undef)
func NewFundManager(api fundManagerAPI, ds datastore.Batching, wallet address.Address) *FundManager { func NewFundManager(api fundManagerAPI, ds datastore.Batching) *FundManager {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &FundManager{ return &FundManager{
ctx: ctx, ctx: ctx,
shutdown: cancel, shutdown: cancel,
api: api, api: api,
wallet: wallet,
str: newStore(ds), str: newStore(ds),
fundedAddrs: make(map[address.Address]*fundedAddress), fundedAddrs: make(map[address.Address]*fundedAddress),
} }
@ -103,25 +101,25 @@ func (fm *FundManager) getFundedAddress(addr address.Address) *fundedAddress {
return fa return fa
} }
// Reserve adds amt to `reserved`. If there is not enough available funds for // Reserve adds amt to `reserved`. If there are not enough available funds for
// the address, submits a message on chain to top up available funds. // the address, submits a message on chain to top up available funds.
func (fm *FundManager) Reserve(ctx context.Context, addr address.Address, amt abi.TokenAmount) (waitSentinel, error) { func (fm *FundManager) Reserve(ctx context.Context, wallet, addr address.Address, amt abi.TokenAmount) (WaitSentinel, error) {
return fm.getFundedAddress(addr).reserve(ctx, amt) return fm.getFundedAddress(addr).reserve(ctx, wallet, amt)
} }
// Subtract from `reserved`. // Subtract from `reserved`.
func (fm *FundManager) Release(ctx context.Context, addr address.Address, amt abi.TokenAmount) error { func (fm *FundManager) Release(addr address.Address, amt abi.TokenAmount) error {
return fm.getFundedAddress(addr).release(ctx, amt) return fm.getFundedAddress(addr).release(amt)
} }
// Withdraw unreserved funds. Only succeeds if there are enough unreserved // Withdraw unreserved funds. Only succeeds if there are enough unreserved
// funds for the address. // funds for the address.
func (fm *FundManager) Withdraw(ctx context.Context, addr address.Address, amt abi.TokenAmount) (waitSentinel, error) { func (fm *FundManager) Withdraw(ctx context.Context, wallet, addr address.Address, amt abi.TokenAmount) (WaitSentinel, error) {
return fm.getFundedAddress(addr).withdraw(ctx, amt) return fm.getFundedAddress(addr).withdraw(ctx, wallet, amt)
} }
// Waits for a reserve or withdraw to complete. // Waits for a reserve or withdraw to complete.
func (fm *FundManager) Wait(ctx context.Context, sentinel waitSentinel) error { func (fm *FundManager) Wait(ctx context.Context, sentinel WaitSentinel) error {
_, err := fm.api.StateWaitMsg(ctx, cid.Cid(sentinel), build.MessageConfidence) _, err := fm.api.StateWaitMsg(ctx, cid.Cid(sentinel), build.MessageConfidence)
return err return err
} }
@ -129,8 +127,6 @@ func (fm *FundManager) Wait(ctx context.Context, sentinel waitSentinel) error {
// FundedAddressState keeps track of the state of an address with funds in the // FundedAddressState keeps track of the state of an address with funds in the
// datastore // datastore
type FundedAddressState struct { type FundedAddressState struct {
// Wallet is the wallet from which funds are added to the address
Wallet address.Address
Addr address.Address Addr address.Address
// AmtReserved is the amount that must be kept in the address (cannot be // AmtReserved is the amount that must be kept in the address (cannot be
// withdrawn) // withdrawn)
@ -164,14 +160,13 @@ func newFundedAddress(fm *FundManager, addr address.Address) *fundedAddress {
env: &fundManagerEnvironment{api: fm.api}, env: &fundManagerEnvironment{api: fm.api},
str: fm.str, str: fm.str,
state: &FundedAddressState{ state: &FundedAddressState{
Wallet: fm.wallet,
Addr: addr, Addr: addr,
AmtReserved: abi.NewTokenAmount(0), AmtReserved: abi.NewTokenAmount(0),
}, },
} }
} }
// If there is a in-progress on-chain message, don't submit any more messages // If there is an in-progress on-chain message, don't submit any more messages
// on chain until it completes // on chain until it completes
func (a *fundedAddress) start() { func (a *fundedAddress) start() {
a.lk.Lock() a.lk.Lock()
@ -183,22 +178,22 @@ func (a *fundedAddress) start() {
} }
} }
func (a *fundedAddress) reserve(ctx context.Context, amt abi.TokenAmount) (waitSentinel, error) { func (a *fundedAddress) reserve(ctx context.Context, wallet address.Address, amt abi.TokenAmount) (WaitSentinel, error) {
return a.requestAndWait(ctx, amt, &a.reservations) return a.requestAndWait(ctx, wallet, amt, &a.reservations)
} }
func (a *fundedAddress) release(ctx context.Context, amt abi.TokenAmount) error { func (a *fundedAddress) release(amt abi.TokenAmount) error {
_, err := a.requestAndWait(ctx, amt, &a.releases) _, err := a.requestAndWait(context.Background(), address.Undef, amt, &a.releases)
return err return err
} }
func (a *fundedAddress) withdraw(ctx context.Context, amt abi.TokenAmount) (waitSentinel, error) { func (a *fundedAddress) withdraw(ctx context.Context, wallet address.Address, amt abi.TokenAmount) (WaitSentinel, error) {
return a.requestAndWait(ctx, amt, &a.withdrawals) return a.requestAndWait(ctx, wallet, amt, &a.withdrawals)
} }
func (a *fundedAddress) requestAndWait(ctx context.Context, amt abi.TokenAmount, reqs *[]*fundRequest) (waitSentinel, error) { func (a *fundedAddress) requestAndWait(ctx context.Context, wallet address.Address, amt abi.TokenAmount, reqs *[]*fundRequest) (WaitSentinel, error) {
// Create a request and add it to the request queue // Create a request and add it to the request queue
req := newFundRequest(ctx, amt) req := newFundRequest(ctx, wallet, amt)
a.lk.Lock() a.lk.Lock()
*reqs = append(*reqs, req) *reqs = append(*reqs, req)
@ -210,9 +205,9 @@ func (a *fundedAddress) requestAndWait(ctx context.Context, amt abi.TokenAmount,
// Wait for the results // Wait for the results
select { select {
case <-ctx.Done(): case <-ctx.Done():
return waitSentinelUndef, ctx.Err() return WaitSentinelUndef, ctx.Err()
case r := <-req.Result: case r := <-req.Result:
return waitSentinel(r.msgCid), r.err return WaitSentinel(r.msgCid), r.err
} }
} }
@ -243,17 +238,40 @@ func (a *fundedAddress) process() {
} }
// Check if there's anything to do // Check if there's anything to do
if len(a.reservations) == 0 && len(a.releases) == 0 && len(a.withdrawals) == 0 { haveReservations := len(a.reservations) > 0 || len(a.releases) > 0
haveWithdrawals := len(a.withdrawals) > 0
if !haveReservations && !haveWithdrawals {
return return
} }
res, _ := a.processRequests() // Process reservations / releases
if haveReservations {
res, err := a.processReservations(a.reservations, a.releases)
if err == nil {
a.applyStateChange(res.msgCid, res.amtReserved)
}
a.reservations = filterOutProcessedReqs(a.reservations) a.reservations = filterOutProcessedReqs(a.reservations)
a.releases = filterOutProcessedReqs(a.releases) a.releases = filterOutProcessedReqs(a.releases)
a.withdrawals = filterOutProcessedReqs(a.withdrawals) }
a.applyStateChange(res) // If there was no message sent on chain by adding reservations, and all
// reservations have completed processing, process withdrawals
if haveWithdrawals && a.state.MsgCid == nil && len(a.reservations) == 0 {
withdrawalCid, err := a.processWithdrawals(a.withdrawals)
if err == nil && withdrawalCid != cid.Undef {
a.applyStateChange(&withdrawalCid, types.EmptyInt)
}
a.withdrawals = filterOutProcessedReqs(a.withdrawals)
}
// If a message was sent on-chain
if a.state.MsgCid != nil {
// Start waiting for results of message (async)
a.startWaitForResults(*a.state.MsgCid)
}
// Process any remaining queued requests
go a.process()
} }
// Filter out completed requests // Filter out completed requests
@ -268,9 +286,11 @@ func filterOutProcessedReqs(reqs []*fundRequest) []*fundRequest {
} }
// Apply the results of processing queues and save to the datastore // Apply the results of processing queues and save to the datastore
func (a *fundedAddress) applyStateChange(res *processResult) { func (a *fundedAddress) applyStateChange(msgCid *cid.Cid, amtReserved abi.TokenAmount) {
a.state.MsgCid = res.msgCid a.state.MsgCid = msgCid
a.state.AmtReserved = res.amtReserved if !amtReserved.Nil() {
a.state.AmtReserved = amtReserved
}
a.saveState() a.saveState()
} }
@ -289,59 +309,67 @@ func (a *fundedAddress) saveState() {
} }
} }
// The result of processing the request queues // The result of processing the reservation / release queues
type processResult struct { type processResult struct {
// Requests that completed without adding funds
cancelled []*fundRequest
// Requests that added funds
added []*fundRequest
// The new reserved amount // The new reserved amount
amtReserved abi.TokenAmount amtReserved abi.TokenAmount
// The message cid, if a message was pushed // The message cid, if a message was submitted on-chain
msgCid *cid.Cid msgCid *cid.Cid
} }
// process request queues and return the resulting changes to state // process reservations and releases, and return the resulting changes to state
func (a *fundedAddress) processRequests() (pr *processResult, prerr error) { func (a *fundedAddress) processReservations(reservations []*fundRequest, releases []*fundRequest) (pr *processResult, prerr error) {
// If there's an error, mark reserve requests as errored // When the function returns
defer func() { defer func() {
// If there's an error, mark all requests as errored
if prerr != nil { if prerr != nil {
for _, req := range a.reservations { for _, req := range append(reservations, releases...) {
req.Complete(cid.Undef, prerr) req.Complete(cid.Undef, prerr)
} }
return
}
// Complete all release requests
for _, req := range releases {
req.Complete(cid.Undef, nil)
}
// Complete all cancelled requests
for _, req := range pr.cancelled {
req.Complete(cid.Undef, nil)
}
// If a message was sent
if pr.msgCid != nil {
// Complete all add funds requests
for _, req := range pr.added {
req.Complete(*pr.msgCid, nil)
}
} }
}() }()
// Start with the reserved amount in state // Split reservations into those to cancel (because they are covered by
reserved := a.state.AmtReserved // released amounts) and those to add
toCancel, toAdd, reservedDelta := splitReservations(reservations, releases)
// Add the amount of each reserve request // Apply the reserved delta to the reserved amount
for _, req := range a.reservations { reserved := types.BigAdd(a.state.AmtReserved, reservedDelta)
amt := req.Amount()
a.debugf("reserve %d", amt)
reserved = types.BigAdd(reserved, amt)
}
// Subtract the amount of each release request
for _, req := range a.releases {
amt := req.Amount()
a.debugf("release %d", amt)
reserved = types.BigSub(reserved, amt)
// Mark release as complete
req.Complete(cid.Undef, nil)
}
// If reserved amount is negative, set it to zero
if reserved.LessThan(abi.NewTokenAmount(0)) { if reserved.LessThan(abi.NewTokenAmount(0)) {
reserved = abi.NewTokenAmount(0) reserved = abi.NewTokenAmount(0)
} }
res := &processResult{
res := &processResult{amtReserved: reserved} amtReserved: reserved,
cancelled: toCancel,
}
// Work out the amount to add to the balance // Work out the amount to add to the balance
toAdd := abi.NewTokenAmount(0) amtToAdd := abi.NewTokenAmount(0)
if reserved.GreaterThan(abi.NewTokenAmount(0)) {
// If the new reserved amount is greater than the existing amount
if reserved.GreaterThan(a.state.AmtReserved) {
a.debugf("reserved %d > state.AmtReserved %d", reserved, a.state.AmtReserved)
// Get available funds for address // Get available funds for address
avail, err := a.env.AvailableFunds(a.ctx, a.state.Addr) avail, err := a.env.AvailableFunds(a.ctx, a.state.Addr)
if err != nil { if err != nil {
@ -349,63 +377,98 @@ func (a *fundedAddress) processRequests() (pr *processResult, prerr error) {
} }
// amount to add = new reserved amount - available // amount to add = new reserved amount - available
toAdd = types.BigSub(reserved, avail) amtToAdd = types.BigSub(reserved, avail)
a.debugf("reserved %d - avail %d = %d", reserved, avail, toAdd) a.debugf("reserved %d - avail %d = to add %d", reserved, avail, amtToAdd)
} }
// If there's nothing to add to the balance // If there's nothing to add to the balance, bail out
if toAdd.LessThanEqual(abi.NewTokenAmount(0)) { if amtToAdd.LessThanEqual(abi.NewTokenAmount(0)) {
// Mark reserve requests as complete a.debugf(" queued for cancel %d", len(toAdd))
for _, req := range a.reservations { res.cancelled = append(res.cancelled, toAdd...)
req.Complete(cid.Undef, nil) return res, nil
}
// Process withdrawals
return a.processWithdrawals(reserved)
} }
// Add funds to address // Add funds to address
a.debugf("add funds %d", toAdd) a.debugf("add funds %d", amtToAdd)
addFundsCid, err := a.env.AddFunds(a.ctx, a.state.Wallet, a.state.Addr, toAdd) addFundsCid, err := a.env.AddFunds(a.ctx, toAdd[0].Wallet, a.state.Addr, amtToAdd)
if err != nil { if err != nil {
return res, err return res, err
} }
// Mark reserve requests as complete // Mark reserve requests as complete
for _, req := range a.reservations { res.added = toAdd
req.Complete(addFundsCid, nil)
}
// Start waiting for results (async)
defer a.startWaitForResults(addFundsCid)
// Save the message CID to state // Save the message CID to state
res.msgCid = &addFundsCid res.msgCid = &addFundsCid
return res, nil return res, nil
} }
// Split reservations into those that are under the total release amount and
// those that exceed it
func splitReservations(reservations []*fundRequest, releases []*fundRequest) ([]*fundRequest, []*fundRequest, abi.TokenAmount) {
toCancel := make([]*fundRequest, 0, len(reservations))
toAdd := make([]*fundRequest, 0, len(reservations))
toAddAmt := abi.NewTokenAmount(0)
// Sum release amounts
releaseAmt := abi.NewTokenAmount(0)
for _, req := range releases {
releaseAmt = types.BigAdd(releaseAmt, req.Amount())
}
// We only want to combine requests that come from the same wallet
wallet := address.Undef
for _, req := range reservations {
amt := req.Amount()
// If the amount to add to the reserve is cancelled out by a release
if amt.LessThanEqual(releaseAmt) {
// Cancel the request and update the release total
releaseAmt = types.BigSub(releaseAmt, amt)
toCancel = append(toCancel, req)
} else {
// The amount to add is greater that the release total so we want
// to send an add funds request
// The first time the wallet will be undefined
if wallet == address.Undef {
wallet = req.Wallet
}
// If this request's wallet is the same as the first request's
// wallet, the requests will be combined
if wallet == req.Wallet {
delta := types.BigSub(amt, releaseAmt)
toAddAmt = types.BigAdd(toAddAmt, delta)
releaseAmt = abi.NewTokenAmount(0)
toAdd = append(toAdd, req)
}
}
}
// The change in the reserved amount is "amount to add" - "amount to release"
reservedDelta := types.BigSub(toAddAmt, releaseAmt)
return toCancel, toAdd, reservedDelta
}
// process withdrawal queue // process withdrawal queue
func (a *fundedAddress) processWithdrawals(reserved abi.TokenAmount) (pr *processResult, prerr error) { func (a *fundedAddress) processWithdrawals(withdrawals []*fundRequest) (msgCid cid.Cid, prerr error) {
// If there's an error, mark withdrawal requests as errored // If there's an error, mark all withdrawal requests as errored
defer func() { defer func() {
if prerr != nil { if prerr != nil {
for _, req := range a.withdrawals { for _, req := range withdrawals {
req.Complete(cid.Undef, prerr) req.Complete(cid.Undef, prerr)
} }
} }
}() }()
res := &processResult{
amtReserved: reserved,
}
// Get the net available balance // Get the net available balance
avail, err := a.env.AvailableFunds(a.ctx, a.state.Addr) avail, err := a.env.AvailableFunds(a.ctx, a.state.Addr)
if err != nil { if err != nil {
return res, err return cid.Undef, err
} }
netAvail := types.BigSub(avail, reserved) netAvail := types.BigSub(avail, a.state.AmtReserved)
// Fit as many withdrawals as possible into the available balance, and fail // Fit as many withdrawals as possible into the available balance, and fail
// the rest // the rest
@ -428,18 +491,18 @@ func (a *fundedAddress) processWithdrawals(reserved abi.TokenAmount) (pr *proces
// Check if there is anything to withdraw // Check if there is anything to withdraw
if allowedAmt.Equals(abi.NewTokenAmount(0)) { if allowedAmt.Equals(abi.NewTokenAmount(0)) {
// Mark allowed requests as complete // Mark allowed requests as cancelled
for _, req := range allowed { for _, req := range allowed {
req.Complete(cid.Undef, nil) req.Complete(cid.Undef, nil)
} }
return res, nil return cid.Undef, nil
} }
// Withdraw funds // Withdraw funds
a.debugf("withdraw funds %d", allowedAmt) a.debugf("withdraw funds %d", allowedAmt)
withdrawFundsCid, err := a.env.WithdrawFunds(a.ctx, a.state.Wallet, a.state.Addr, allowedAmt) withdrawFundsCid, err := a.env.WithdrawFunds(a.ctx, allowed[0].Wallet, a.state.Addr, allowedAmt)
if err != nil { if err != nil {
return res, err return cid.Undef, err
} }
// Mark allowed requests as complete // Mark allowed requests as complete
@ -447,12 +510,8 @@ func (a *fundedAddress) processWithdrawals(reserved abi.TokenAmount) (pr *proces
req.Complete(withdrawFundsCid, nil) req.Complete(withdrawFundsCid, nil)
} }
// Start waiting for results of message (async)
defer a.startWaitForResults(withdrawFundsCid)
// Save the message CID to state // Save the message CID to state
res.msgCid = &withdrawFundsCid return withdrawFundsCid, nil
return res, nil
} }
// asynchonously wait for results of message // asynchonously wait for results of message
@ -491,13 +550,15 @@ type fundRequest struct {
ctx context.Context ctx context.Context
amt abi.TokenAmount amt abi.TokenAmount
completed chan struct{} completed chan struct{}
Wallet address.Address
Result chan reqResult Result chan reqResult
} }
func newFundRequest(ctx context.Context, amt abi.TokenAmount) *fundRequest { func newFundRequest(ctx context.Context, wallet address.Address, amt abi.TokenAmount) *fundRequest {
return &fundRequest{ return &fundRequest{
ctx: ctx, ctx: ctx,
amt: amt, amt: amt,
Wallet: wallet,
Result: make(chan reqResult), Result: make(chan reqResult),
completed: make(chan struct{}), completed: make(chan struct{}),
} }
@ -532,10 +593,6 @@ func (frp *fundRequest) Completed() bool {
} }
} }
func (frp *fundRequest) Equals(other *fundRequest) bool {
return frp == other
}
// fundManagerEnvironment simplifies some API calls // fundManagerEnvironment simplifies some API calls
type fundManagerEnvironment struct { type fundManagerEnvironment struct {
api fundManagerAPI api fundManagerAPI
@ -556,7 +613,24 @@ func (env *fundManagerEnvironment) AddFunds(
addr address.Address, addr address.Address,
amt abi.TokenAmount, amt abi.TokenAmount,
) (cid.Cid, error) { ) (cid.Cid, error) {
return env.sendFunds(ctx, wallet, addr, amt) params, err := actors.SerializeParams(&addr)
if err != nil {
return cid.Undef, err
}
smsg, aerr := env.api.MpoolPushMessage(ctx, &types.Message{
To: market.Address,
From: wallet,
Value: amt,
Method: market.Methods.AddBalance,
Params: params,
}, nil)
if aerr != nil {
return cid.Undef, aerr
}
return smsg.Cid(), nil
} }
func (env *fundManagerEnvironment) WithdrawFunds( func (env *fundManagerEnvironment) WithdrawFunds(
@ -565,25 +639,19 @@ func (env *fundManagerEnvironment) WithdrawFunds(
addr address.Address, addr address.Address,
amt abi.TokenAmount, amt abi.TokenAmount,
) (cid.Cid, error) { ) (cid.Cid, error) {
return env.sendFunds(ctx, addr, wallet, amt) params, err := actors.SerializeParams(&market.WithdrawBalanceParams{
} ProviderOrClientAddress: addr,
Amount: amt,
func (env *fundManagerEnvironment) sendFunds( })
ctx context.Context,
from address.Address,
to address.Address,
amt abi.TokenAmount,
) (cid.Cid, error) {
params, err := actors.SerializeParams(&to)
if err != nil { if err != nil {
return cid.Undef, err return cid.Undef, xerrors.Errorf("serializing params: %w", err)
} }
smsg, aerr := env.api.MpoolPushMessage(ctx, &types.Message{ smsg, aerr := env.api.MpoolPushMessage(ctx, &types.Message{
To: market.Address, To: market.Address,
From: from, From: wallet,
Value: amt, Value: types.NewInt(0),
Method: market.Methods.AddBalance, Method: market.Methods.WithdrawBalance,
Params: params, Params: params,
}, nil) }, nil)

View File

@ -35,11 +35,11 @@ func TestFundManagerBasic(t *testing.T) {
// balance: 0 -> 10 // balance: 0 -> 10
// reserved: 0 -> 10 // reserved: 0 -> 10
amt := abi.NewTokenAmount(10) amt := abi.NewTokenAmount(10)
sentinel, err := s.fm.Reserve(s.ctx, s.acctAddr, amt) sentinel, err := s.fm.Reserve(s.ctx, s.walletAddr, s.acctAddr, amt)
require.NoError(t, err) require.NoError(t, err)
msg := s.mockApi.getSentMessage(cid.Cid(sentinel)) msg := s.mockApi.getSentMessage(cid.Cid(sentinel))
checkMessageFields(t, msg, s.walletAddr, s.acctAddr, amt) checkAddMessageFields(t, msg, s.walletAddr, s.acctAddr, amt)
s.mockApi.completeMsg(cid.Cid(sentinel)) s.mockApi.completeMsg(cid.Cid(sentinel))
err = s.fm.Wait(s.ctx, sentinel) err = s.fm.Wait(s.ctx, sentinel)
@ -49,11 +49,11 @@ func TestFundManagerBasic(t *testing.T) {
// balance: 10 -> 17 // balance: 10 -> 17
// reserved: 10 -> 17 // reserved: 10 -> 17
amt = abi.NewTokenAmount(7) amt = abi.NewTokenAmount(7)
sentinel, err = s.fm.Reserve(s.ctx, s.acctAddr, amt) sentinel, err = s.fm.Reserve(s.ctx, s.walletAddr, s.acctAddr, amt)
require.NoError(t, err) require.NoError(t, err)
msg = s.mockApi.getSentMessage(cid.Cid(sentinel)) msg = s.mockApi.getSentMessage(cid.Cid(sentinel))
checkMessageFields(t, msg, s.walletAddr, s.acctAddr, amt) checkAddMessageFields(t, msg, s.walletAddr, s.acctAddr, amt)
s.mockApi.completeMsg(cid.Cid(sentinel)) s.mockApi.completeMsg(cid.Cid(sentinel))
err = s.fm.Wait(s.ctx, sentinel) err = s.fm.Wait(s.ctx, sentinel)
@ -63,18 +63,18 @@ func TestFundManagerBasic(t *testing.T) {
// balance: 17 // balance: 17
// reserved: 17 -> 12 // reserved: 17 -> 12
amt = abi.NewTokenAmount(5) amt = abi.NewTokenAmount(5)
err = s.fm.Release(s.ctx, s.acctAddr, amt) err = s.fm.Release(s.acctAddr, amt)
require.NoError(t, err) require.NoError(t, err)
// Withdraw 2 // Withdraw 2
// balance: 17 -> 15 // balance: 17 -> 15
// reserved: 12 // reserved: 12
amt = abi.NewTokenAmount(2) amt = abi.NewTokenAmount(2)
sentinel, err = s.fm.Withdraw(s.ctx, s.acctAddr, amt) sentinel, err = s.fm.Withdraw(s.ctx, s.walletAddr, s.acctAddr, amt)
require.NoError(t, err) require.NoError(t, err)
msg = s.mockApi.getSentMessage(cid.Cid(sentinel)) msg = s.mockApi.getSentMessage(cid.Cid(sentinel))
checkMessageFields(t, msg, s.acctAddr, s.walletAddr, amt) checkWithdrawMessageFields(t, msg, s.walletAddr, s.acctAddr, amt)
s.mockApi.completeMsg(cid.Cid(sentinel)) s.mockApi.completeMsg(cid.Cid(sentinel))
err = s.fm.Wait(s.ctx, sentinel) err = s.fm.Wait(s.ctx, sentinel)
@ -87,10 +87,10 @@ func TestFundManagerBasic(t *testing.T) {
// message // message
msgCount := s.mockApi.messageCount() msgCount := s.mockApi.messageCount()
amt = abi.NewTokenAmount(3) amt = abi.NewTokenAmount(3)
sentinel, err = s.fm.Reserve(s.ctx, s.acctAddr, amt) sentinel, err = s.fm.Reserve(s.ctx, s.walletAddr, s.acctAddr, amt)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, msgCount, s.mockApi.messageCount()) require.Equal(t, msgCount, s.mockApi.messageCount())
require.Equal(t, sentinel, waitSentinelUndef) require.Equal(t, sentinel, WaitSentinelUndef)
// Reserve 1 // Reserve 1
// balance: 15 -> 16 // balance: 15 -> 16
@ -99,12 +99,12 @@ func TestFundManagerBasic(t *testing.T) {
// message to top up balance // message to top up balance
amt = abi.NewTokenAmount(1) amt = abi.NewTokenAmount(1)
topUp := abi.NewTokenAmount(1) topUp := abi.NewTokenAmount(1)
sentinel, err = s.fm.Reserve(s.ctx, s.acctAddr, amt) sentinel, err = s.fm.Reserve(s.ctx, s.walletAddr, s.acctAddr, amt)
require.NoError(t, err) require.NoError(t, err)
s.mockApi.completeMsg(cid.Cid(sentinel)) s.mockApi.completeMsg(cid.Cid(sentinel))
msg = s.mockApi.getSentMessage(cid.Cid(sentinel)) msg = s.mockApi.getSentMessage(cid.Cid(sentinel))
checkMessageFields(t, msg, s.walletAddr, s.acctAddr, topUp) checkAddMessageFields(t, msg, s.walletAddr, s.acctAddr, topUp)
// Withdraw 1 // Withdraw 1
// balance: 16 // balance: 16
@ -112,7 +112,7 @@ func TestFundManagerBasic(t *testing.T) {
// Note: Expect failure because there is no available balance to withdraw: // Note: Expect failure because there is no available balance to withdraw:
// balance - reserved = 16 - 16 = 0 // balance - reserved = 16 - 16 = 0
amt = abi.NewTokenAmount(1) amt = abi.NewTokenAmount(1)
sentinel, err = s.fm.Withdraw(s.ctx, s.acctAddr, amt) sentinel, err = s.fm.Withdraw(s.ctx, s.walletAddr, s.acctAddr, amt)
require.Error(t, err) require.Error(t, err)
} }
@ -123,7 +123,7 @@ func TestFundManagerParallel(t *testing.T) {
// Reserve 10 // Reserve 10
amt := abi.NewTokenAmount(10) amt := abi.NewTokenAmount(10)
sentinelReserve10, err := s.fm.Reserve(s.ctx, s.acctAddr, amt) sentinelReserve10, err := s.fm.Reserve(s.ctx, s.walletAddr, s.acctAddr, amt)
require.NoError(t, err) require.NoError(t, err)
// Wait until all the subsequent requests are queued up // Wait until all the subsequent requests are queued up
@ -141,16 +141,16 @@ func TestFundManagerParallel(t *testing.T) {
withdrawReady := make(chan error) withdrawReady := make(chan error)
go func() { go func() {
amt = abi.NewTokenAmount(5) amt = abi.NewTokenAmount(5)
_, err := s.fm.Withdraw(s.ctx, s.acctAddr, amt) _, err := s.fm.Withdraw(s.ctx, s.walletAddr, s.acctAddr, amt)
withdrawReady <- err withdrawReady <- err
}() }()
reserveSentinels := make(chan waitSentinel) reserveSentinels := make(chan WaitSentinel)
// Reserve 3 // Reserve 3
go func() { go func() {
amt := abi.NewTokenAmount(3) amt := abi.NewTokenAmount(3)
sentinelReserve3, err := s.fm.Reserve(s.ctx, s.acctAddr, amt) sentinelReserve3, err := s.fm.Reserve(s.ctx, s.walletAddr, s.acctAddr, amt)
require.NoError(t, err) require.NoError(t, err)
reserveSentinels <- sentinelReserve3 reserveSentinels <- sentinelReserve3
}() }()
@ -158,7 +158,7 @@ func TestFundManagerParallel(t *testing.T) {
// Reserve 5 // Reserve 5
go func() { go func() {
amt := abi.NewTokenAmount(5) amt := abi.NewTokenAmount(5)
sentinelReserve5, err := s.fm.Reserve(s.ctx, s.acctAddr, amt) sentinelReserve5, err := s.fm.Reserve(s.ctx, s.walletAddr, s.acctAddr, amt)
require.NoError(t, err) require.NoError(t, err)
reserveSentinels <- sentinelReserve5 reserveSentinels <- sentinelReserve5
}() }()
@ -166,7 +166,7 @@ func TestFundManagerParallel(t *testing.T) {
// Release 2 // Release 2
go func() { go func() {
amt := abi.NewTokenAmount(2) amt := abi.NewTokenAmount(2)
err = s.fm.Release(s.ctx, s.acctAddr, amt) err = s.fm.Release(s.acctAddr, amt)
require.NoError(t, err) require.NoError(t, err)
}() }()
@ -176,7 +176,7 @@ func TestFundManagerParallel(t *testing.T) {
// Complete the "Reserve 10" message // Complete the "Reserve 10" message
s.mockApi.completeMsg(cid.Cid(sentinelReserve10)) s.mockApi.completeMsg(cid.Cid(sentinelReserve10))
msg := s.mockApi.getSentMessage(cid.Cid(sentinelReserve10)) msg := s.mockApi.getSentMessage(cid.Cid(sentinelReserve10))
checkMessageFields(t, msg, s.walletAddr, s.acctAddr, abi.NewTokenAmount(10)) checkAddMessageFields(t, msg, s.walletAddr, s.acctAddr, abi.NewTokenAmount(10))
// The other requests should now be combined and be submitted on-chain as // The other requests should now be combined and be submitted on-chain as
// a single message // a single message
@ -200,7 +200,7 @@ func TestFundManagerParallel(t *testing.T) {
// "Reserve 5" +5 // "Reserve 5" +5
// "Release 2" -2 // "Release 2" -2
// Result: 6 // Result: 6
checkMessageFields(t, msg, s.walletAddr, s.acctAddr, abi.NewTokenAmount(6)) checkAddMessageFields(t, msg, s.walletAddr, s.acctAddr, abi.NewTokenAmount(6))
// Expect withdraw to fail because not enough available funds // Expect withdraw to fail because not enough available funds
err = <-withdrawReady err = <-withdrawReady
@ -215,21 +215,21 @@ func TestFundManagerWithdrawal(t *testing.T) {
// Reserve 10 // Reserve 10
amt := abi.NewTokenAmount(10) amt := abi.NewTokenAmount(10)
sentinelReserve10, err := s.fm.Reserve(s.ctx, s.acctAddr, amt) sentinelReserve10, err := s.fm.Reserve(s.ctx, s.walletAddr, s.acctAddr, amt)
require.NoError(t, err) require.NoError(t, err)
// Complete the "Reserve 10" message // Complete the "Reserve 10" message
s.mockApi.completeMsg(cid.Cid(sentinelReserve10)) s.mockApi.completeMsg(cid.Cid(sentinelReserve10))
// Release 10 // Release 10
err = s.fm.Release(s.ctx, s.acctAddr, amt) err = s.fm.Release(s.acctAddr, amt)
require.NoError(t, err) require.NoError(t, err)
// Available 10 // Available 10
// Withdraw 6 // Withdraw 6
// Expect success // Expect success
amt = abi.NewTokenAmount(6) amt = abi.NewTokenAmount(6)
sentinelWithdraw, err := s.fm.Withdraw(s.ctx, s.acctAddr, amt) sentinelWithdraw, err := s.fm.Withdraw(s.ctx, s.walletAddr, s.acctAddr, amt)
require.NoError(t, err) require.NoError(t, err)
s.mockApi.completeMsg(cid.Cid(sentinelWithdraw)) s.mockApi.completeMsg(cid.Cid(sentinelWithdraw))
@ -240,7 +240,7 @@ func TestFundManagerWithdrawal(t *testing.T) {
// Withdraw 4 // Withdraw 4
// Expect success // Expect success
amt = abi.NewTokenAmount(4) amt = abi.NewTokenAmount(4)
sentinelWithdraw, err = s.fm.Withdraw(s.ctx, s.acctAddr, amt) sentinelWithdraw, err = s.fm.Withdraw(s.ctx, s.walletAddr, s.acctAddr, amt)
require.NoError(t, err) require.NoError(t, err)
s.mockApi.completeMsg(cid.Cid(sentinelWithdraw)) s.mockApi.completeMsg(cid.Cid(sentinelWithdraw))
@ -251,7 +251,7 @@ func TestFundManagerWithdrawal(t *testing.T) {
// Withdraw 1 // Withdraw 1
// Expect FAIL // Expect FAIL
amt = abi.NewTokenAmount(1) amt = abi.NewTokenAmount(1)
sentinelWithdraw, err = s.fm.Withdraw(s.ctx, s.acctAddr, amt) sentinelWithdraw, err = s.fm.Withdraw(s.ctx, s.walletAddr, s.acctAddr, amt)
require.Error(t, err) require.Error(t, err)
} }
@ -265,19 +265,19 @@ func TestFundManagerRestart(t *testing.T) {
// Address 1: Reserve 10 // Address 1: Reserve 10
amt := abi.NewTokenAmount(10) amt := abi.NewTokenAmount(10)
sentinelAddr1, err := s.fm.Reserve(s.ctx, s.acctAddr, amt) sentinelAddr1, err := s.fm.Reserve(s.ctx, s.walletAddr, s.acctAddr, amt)
require.NoError(t, err) require.NoError(t, err)
msg := s.mockApi.getSentMessage(cid.Cid(sentinelAddr1)) msg := s.mockApi.getSentMessage(cid.Cid(sentinelAddr1))
checkMessageFields(t, msg, s.walletAddr, s.acctAddr, amt) checkAddMessageFields(t, msg, s.walletAddr, s.acctAddr, amt)
// Address 2: Reserve 7 // Address 2: Reserve 7
amt2 := abi.NewTokenAmount(7) amt2 := abi.NewTokenAmount(7)
sentinelAddr2Res7, err := s.fm.Reserve(s.ctx, acctAddr2, amt2) sentinelAddr2Res7, err := s.fm.Reserve(s.ctx, s.walletAddr, acctAddr2, amt2)
require.NoError(t, err) require.NoError(t, err)
msg2 := s.mockApi.getSentMessage(cid.Cid(sentinelAddr2Res7)) msg2 := s.mockApi.getSentMessage(cid.Cid(sentinelAddr2Res7))
checkMessageFields(t, msg2, s.walletAddr, acctAddr2, amt2) checkAddMessageFields(t, msg2, s.walletAddr, acctAddr2, amt2)
// Complete "Address 1: Reserve 10" // Complete "Address 1: Reserve 10"
s.mockApi.completeMsg(cid.Cid(sentinelAddr1)) s.mockApi.completeMsg(cid.Cid(sentinelAddr1))
@ -289,14 +289,15 @@ func TestFundManagerRestart(t *testing.T) {
// Restart // Restart
mockApiAfter := s.mockApi mockApiAfter := s.mockApi
fmAfter := NewFundManager(mockApiAfter, s.ds, s.walletAddr) fmAfter := NewFundManager(mockApiAfter, s.ds)
fmAfter.Start() err = fmAfter.Start()
require.NoError(t, err)
amt3 := abi.NewTokenAmount(9) amt3 := abi.NewTokenAmount(9)
reserveSentinel := make(chan waitSentinel) reserveSentinel := make(chan WaitSentinel)
go func() { go func() {
// Address 2: Reserve 9 // Address 2: Reserve 9
sentinel3, err := fmAfter.Reserve(s.ctx, acctAddr2, amt3) sentinel3, err := fmAfter.Reserve(s.ctx, s.walletAddr, acctAddr2, amt3)
require.NoError(t, err) require.NoError(t, err)
reserveSentinel <- sentinel3 reserveSentinel <- sentinel3
}() }()
@ -317,7 +318,7 @@ func TestFundManagerRestart(t *testing.T) {
// Expect waiting message to now be sent // Expect waiting message to now be sent
sentinel3 := <-reserveSentinel sentinel3 := <-reserveSentinel
msg3 := mockApiAfter.getSentMessage(cid.Cid(sentinel3)) msg3 := mockApiAfter.getSentMessage(cid.Cid(sentinel3))
checkMessageFields(t, msg3, s.walletAddr, acctAddr2, amt3) checkAddMessageFields(t, msg3, s.walletAddr, acctAddr2, amt3)
} }
type scaffold struct { type scaffold struct {
@ -346,7 +347,7 @@ func setup(t *testing.T) *scaffold {
mockApi := newMockFundManagerAPI(walletAddr) mockApi := newMockFundManagerAPI(walletAddr)
ds := ds_sync.MutexWrap(ds.NewMapDatastore()) ds := ds_sync.MutexWrap(ds.NewMapDatastore())
fm := NewFundManager(mockApi, ds, walletAddr) fm := NewFundManager(mockApi, ds)
return &scaffold{ return &scaffold{
ctx: ctx, ctx: ctx,
ds: ds, ds: ds,
@ -357,7 +358,7 @@ func setup(t *testing.T) *scaffold {
} }
} }
func checkMessageFields(t *testing.T, msg *types.Message, from address.Address, to address.Address, amt abi.TokenAmount) { func checkAddMessageFields(t *testing.T, msg *types.Message, from address.Address, to address.Address, amt abi.TokenAmount) {
require.Equal(t, from, msg.From) require.Equal(t, from, msg.From)
require.Equal(t, market.Address, msg.To) require.Equal(t, market.Address, msg.To)
require.Equal(t, amt, msg.Value) require.Equal(t, amt, msg.Value)
@ -368,6 +369,18 @@ func checkMessageFields(t *testing.T, msg *types.Message, from address.Address,
require.Equal(t, to, paramsTo) require.Equal(t, to, paramsTo)
} }
func checkWithdrawMessageFields(t *testing.T, msg *types.Message, from address.Address, addr address.Address, amt abi.TokenAmount) {
require.Equal(t, from, msg.From)
require.Equal(t, market.Address, msg.To)
require.Equal(t, abi.NewTokenAmount(0), msg.Value)
var params market.WithdrawBalanceParams
err := params.UnmarshalCBOR(bytes.NewReader(msg.Params))
require.NoError(t, err)
require.Equal(t, addr, params.ProviderOrClientAddress)
require.Equal(t, amt, params.Amount)
}
type sentMsg struct { type sentMsg struct {
msg *types.SignedMessage msg *types.SignedMessage
ready chan struct{} ready chan struct{}
@ -428,7 +441,7 @@ func (mapi *mockFundManagerAPI) completeMsg(msgCid cid.Cid) {
pmsg, ok := mapi.sentMsgs[msgCid] pmsg, ok := mapi.sentMsgs[msgCid]
if ok { if ok {
if pmsg.msg.Message.From == mapi.wallet { if pmsg.msg.Message.Method == market.Methods.AddBalance {
var escrowAcct address.Address var escrowAcct address.Address
err := escrowAcct.UnmarshalCBOR(bytes.NewReader(pmsg.msg.Message.Params)) err := escrowAcct.UnmarshalCBOR(bytes.NewReader(pmsg.msg.Message.Params))
if err != nil { if err != nil {
@ -441,10 +454,16 @@ func (mapi *mockFundManagerAPI) completeMsg(msgCid cid.Cid) {
mapi.escrow[escrowAcct] = escrow mapi.escrow[escrowAcct] = escrow
log.Debugf("%s: escrow %d -> %d", escrowAcct, before, escrow) log.Debugf("%s: escrow %d -> %d", escrowAcct, before, escrow)
} else { } else {
escrowAcct := pmsg.msg.Message.From var params market.WithdrawBalanceParams
err := params.UnmarshalCBOR(bytes.NewReader(pmsg.msg.Message.Params))
if err != nil {
panic(err)
}
escrowAcct := params.ProviderOrClientAddress
escrow := mapi.getEscrow(escrowAcct) escrow := mapi.getEscrow(escrowAcct)
before := escrow before := escrow
escrow = types.BigSub(escrow, pmsg.msg.Message.Value) escrow = types.BigSub(escrow, params.Amount)
mapi.escrow[escrowAcct] = escrow mapi.escrow[escrowAcct] = escrow
log.Debugf("%s: escrow %d -> %d", escrowAcct, before, escrow) log.Debugf("%s: escrow %d -> %d", escrowAcct, before, escrow)
} }