refactor: simplify state management

This commit is contained in:
Dirk McCormick 2020-07-09 18:27:39 -04:00
parent 1cdb008bd5
commit 6c70ef7c7d
4 changed files with 130 additions and 157 deletions

View File

@ -40,18 +40,9 @@ type StateManagerApi interface {
Call(ctx context.Context, msg *types.Message, ts *types.TipSet) (*api.InvocResult, error) Call(ctx context.Context, msg *types.Message, ts *types.TipSet) (*api.InvocResult, error)
} }
//type StateApi interface {
// StateWaitMsg(ctx context.Context, msg cid.Cid, confidence uint64) (*api.MsgLookup, error)
//}
//
//type MpoolApi interface {
// MpoolPushMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error)
//}
type Manager struct { type Manager struct {
store *Store store *Store
//sm *stmgr.StateManager sm StateManagerApi
sm StateManagerApi
mpool full.MpoolAPI mpool full.MpoolAPI
wallet full.WalletAPI wallet full.WalletAPI
@ -74,85 +65,19 @@ func newManager(sm StateManagerApi, pchstore *Store) *Manager {
return &Manager{ return &Manager{
store: pchstore, store: pchstore,
sm: sm, sm: sm,
//mpool: api.MpoolAPI,
//wallet: api.WalletAPI,
//state: api.StateAPI,
} }
} }
func nextLaneFromState(st *paych.State) uint64 {
if len(st.LaneStates) == 0 {
return 0
}
maxLane := st.LaneStates[0].ID
for _, state := range st.LaneStates {
if state.ID > maxLane {
maxLane = state.ID
}
}
return maxLane + 1
}
func (pm *Manager) TrackInboundChannel(ctx context.Context, ch address.Address) error {
_, st, err := pm.loadPaychState(ctx, ch)
if err != nil {
return err
}
var account account.State
_, err = pm.sm.LoadActorState(ctx, st.From, &account, nil)
if err != nil {
return err
}
from := account.Address
_, err = pm.sm.LoadActorState(ctx, st.To, &account, nil)
if err != nil {
return err
}
to := account.Address
return pm.store.TrackChannel(&ChannelInfo{
Channel: ch,
Control: to,
Target: from,
Direction: DirInbound,
NextLane: nextLaneFromState(st),
})
}
func (pm *Manager) loadOutboundChannelInfo(ctx context.Context, ch address.Address) (*ChannelInfo, error) {
_, st, err := pm.loadPaychState(ctx, ch)
if err != nil {
return nil, err
}
var account account.State
_, err = pm.sm.LoadActorState(ctx, st.From, &account, nil)
if err != nil {
return nil, err
}
from := account.Address
_, err = pm.sm.LoadActorState(ctx, st.To, &account, nil)
if err != nil {
return nil, err
}
to := account.Address
return &ChannelInfo{
Channel: ch,
Control: from,
Target: to,
Direction: DirOutbound,
NextLane: nextLaneFromState(st),
}, nil
}
func (pm *Manager) TrackOutboundChannel(ctx context.Context, ch address.Address) error { func (pm *Manager) TrackOutboundChannel(ctx context.Context, ch address.Address) error {
ci, err := pm.loadOutboundChannelInfo(ctx, ch) return pm.trackChannel(ctx, ch, DirOutbound)
}
func (pm *Manager) TrackInboundChannel(ctx context.Context, ch address.Address) error {
return pm.trackChannel(ctx, ch, DirInbound)
}
func (pm *Manager) trackChannel(ctx context.Context, ch address.Address, dir uint64) error {
ci, err := pm.loadStateChannelInfo(ctx, ch, dir)
if err != nil { if err != nil {
return err return err
} }
@ -170,58 +95,68 @@ func (pm *Manager) GetChannelInfo(addr address.Address) (*ChannelInfo, error) {
// checks if the given voucher is valid (is or could become spendable at some point) // checks if the given voucher is valid (is or could become spendable at some point)
func (pm *Manager) CheckVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) error { func (pm *Manager) CheckVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) error {
_, err := pm.checkVoucherValid(ctx, ch, sv)
return err
}
func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (*paych.State, error) {
act, pca, err := pm.loadPaychState(ctx, ch) act, pca, err := pm.loadPaychState(ctx, ch)
if err != nil { if err != nil {
return err return nil, err
} }
var account account.State var account account.State
_, err = pm.sm.LoadActorState(ctx, pca.From, &account, nil) _, err = pm.sm.LoadActorState(ctx, pca.From, &account, nil)
if err != nil { if err != nil {
return err return nil, err
} }
from := account.Address from := account.Address
// verify signature // verify signature
vb, err := sv.SigningBytes() vb, err := sv.SigningBytes()
if err != nil { if err != nil {
return err return nil, err
} }
// TODO: technically, either party may create and sign a voucher. // TODO: technically, either party may create and sign a voucher.
// However, for now, we only accept them from the channel creator. // However, for now, we only accept them from the channel creator.
// More complex handling logic can be added later // More complex handling logic can be added later
if err := sigs.Verify(sv.Signature, from, vb); err != nil { if err := sigs.Verify(sv.Signature, from, vb); err != nil {
return err return nil, err
} }
sendAmount := sv.Amount sendAmount := sv.Amount
// now check the lane state // Check the voucher against the highest known voucher nonce / value
// TODO: should check against vouchers in our local store too ls, err := pm.laneState(pca, ch, sv.Lane)
// there might be something conflicting if err != nil {
ls := findLane(pca.LaneStates, uint64(sv.Lane)) return nil, err
if ls == nil {
} else {
if (ls.Nonce) >= sv.Nonce {
return fmt.Errorf("nonce too low")
}
// TODO: return error if ls.Redeemed > vs.Amount
sendAmount = types.BigSub(sv.Amount, ls.Redeemed)
} }
// If there has been at least once voucher redeemed, and the voucher
// nonce value is less than the highest known nonce
if ls.Redeemed.Int64() > 0 && sv.Nonce <= ls.Nonce {
return nil, fmt.Errorf("nonce too low")
}
// If the voucher amount is less than the highest known voucher amount
if sv.Amount.LessThanEqual(ls.Redeemed) {
return nil, fmt.Errorf("voucher amount is lower than amount for voucher with lower nonce")
}
// Only send the difference between the voucher amount and what has already
// been redeemed
sendAmount = types.BigSub(sv.Amount, ls.Redeemed)
// TODO: also account for vouchers on other lanes we've received // TODO: also account for vouchers on other lanes we've received
newTotal := types.BigAdd(sendAmount, pca.ToSend) newTotal := types.BigAdd(sendAmount, pca.ToSend)
if act.Balance.LessThan(newTotal) { if act.Balance.LessThan(newTotal) {
return fmt.Errorf("not enough funds in channel to cover voucher") return nil, fmt.Errorf("not enough funds in channel to cover voucher")
} }
if len(sv.Merges) != 0 { if len(sv.Merges) != 0 {
return fmt.Errorf("dont currently support paych lane merges") return nil, fmt.Errorf("dont currently support paych lane merges")
} }
return nil return pca, nil
} }
// checks if the given voucher is currently spendable // checks if the given voucher is currently spendable
@ -289,10 +224,6 @@ func (pm *Manager) getPaychOwner(ctx context.Context, ch address.Address) (addre
} }
func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) { func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) {
if err := pm.CheckVoucherValid(ctx, ch, sv); err != nil {
return types.NewInt(0), err
}
pm.store.lk.Lock() pm.store.lk.Lock()
defer pm.store.lk.Unlock() defer pm.store.lk.Unlock()
@ -301,25 +232,7 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych
return types.NewInt(0), err return types.NewInt(0), err
} }
laneState, err := pm.laneState(ctx, ch, uint64(sv.Lane)) // Check if the voucher has already been added
if err != nil {
return types.NewInt(0), err
}
// TODO: I believe this check is redundant because
// CheckVoucherValid() already returns an error if laneState.Nonce >= sv.Nonce
if minDelta.GreaterThan(types.NewInt(0)) && laneState.Nonce > sv.Nonce {
return types.NewInt(0), xerrors.Errorf("already storing voucher with higher nonce; %d > %d", laneState.Nonce, sv.Nonce)
}
// TODO:
// It's possible to repeatedly add a voucher with the same proof:
// 1. add a voucher with proof P1
// 2. add a voucher with proof P2
// 3. add a voucher with proof P2 (again)
// Voucher with proof P2 has been added twice
//
// look for duplicates
for i, v := range ci.Vouchers { for i, v := range ci.Vouchers {
eq, err := cborutil.Equals(sv, v.Voucher) eq, err := cborutil.Equals(sv, v.Voucher)
if err != nil { if err != nil {
@ -328,24 +241,35 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych
if !eq { if !eq {
continue continue
} }
// TODO: CBOR encoding / decoding changes nil into []byte{}, so instead of
// checking v.Proof against nil we should check len(v.Proof) == 0 // This is a duplicate voucher.
if v.Proof != nil { // Update the proof on the existing voucher
if !bytes.Equal(v.Proof, proof) { if len(proof) > 0 && !bytes.Equal(v.Proof, proof) {
log.Warnf("AddVoucher: multiple proofs for single voucher, storing both") log.Warnf("AddVoucher: adding proof to stored voucher")
break ci.Vouchers[i] = &VoucherInfo{
Voucher: v.Voucher,
Proof: proof,
} }
log.Warnf("AddVoucher: voucher re-added with matching proof")
return types.NewInt(0), nil return types.NewInt(0), pm.store.putChannelInfo(ci)
} }
log.Warnf("AddVoucher: adding proof to stored voucher") // Otherwise just ignore the duplicate voucher
ci.Vouchers[i] = &VoucherInfo{ log.Warnf("AddVoucher: voucher re-added with matching proof")
Voucher: v.Voucher, return types.NewInt(0), nil
Proof: proof, }
}
return types.NewInt(0), pm.store.putChannelInfo(ci) // Check voucher validity
pchState, err := pm.checkVoucherValid(ctx, ch, sv)
if err != nil {
return types.NewInt(0), err
}
// The change in value is the delta between the voucher amount and
// the highest previous voucher amount
laneState, err := pm.laneState(pchState, ch, sv.Lane)
if err != nil {
return types.NewInt(0), err
} }
delta := types.BigSub(sv.Amount, laneState.Redeemed) delta := types.BigSub(sv.Amount, laneState.Redeemed)

View File

@ -435,7 +435,7 @@ func TestAddVoucherProof(t *testing.T) {
// Add same voucher with proof // Add same voucher with proof
proof = []byte{1} proof = []byte{1}
_, err = mgr.AddVoucher(ctx, ch, sv, nil, minDelta) _, err = mgr.AddVoucher(ctx, ch, sv, proof, minDelta)
require.NoError(t, err) require.NoError(t, err)
// Should add proof to existing voucher // Should add proof to existing voucher
@ -478,31 +478,32 @@ func TestNextNonceForLane(t *testing.T) {
voucherAmount = big.NewInt(2) voucherAmount = big.NewInt(2)
// Add vouchers such that we have // Add vouchers such that we have
// lane 1: nonce 3
// lane 1: nonce 2 // lane 1: nonce 2
// lane 2: nonce 5 // lane 1: nonce 4
// lane 2: nonce 7
voucherLane := uint64(1) voucherLane := uint64(1)
for _, nonce := range []uint64{3, 2} { for _, nonce := range []uint64{2, 4} {
voucherAmount = big.Add(voucherAmount, big.NewInt(1))
sv := testCreateVoucher(t, voucherLane, nonce, voucherAmount, key) sv := testCreateVoucher(t, voucherLane, nonce, voucherAmount, key)
_, err := mgr.AddVoucher(ctx, ch, sv, nil, minDelta) _, err := mgr.AddVoucher(ctx, ch, sv, nil, minDelta)
require.NoError(t, err) require.NoError(t, err)
} }
voucherLane = uint64(2) voucherLane = uint64(2)
nonce := uint64(5) nonce := uint64(7)
sv := testCreateVoucher(t, voucherLane, nonce, voucherAmount, key) sv := testCreateVoucher(t, voucherLane, nonce, voucherAmount, key)
_, err = mgr.AddVoucher(ctx, ch, sv, nil, minDelta) _, err = mgr.AddVoucher(ctx, ch, sv, nil, minDelta)
require.NoError(t, err) require.NoError(t, err)
// Expect next nonce for lane 1 to be 4 // Expect next nonce for lane 1 to be 5
next, err = mgr.NextNonceForLane(ctx, ch, 1) next, err = mgr.NextNonceForLane(ctx, ch, 1)
require.NoError(t, err) require.NoError(t, err)
require.EqualValues(t, next, 4) require.EqualValues(t, next, 5)
// Expect next nonce for lane 2 to be 6 // Expect next nonce for lane 2 to be 8
next, err = mgr.NextNonceForLane(ctx, ch, 2) next, err = mgr.NextNonceForLane(ctx, ch, 2)
require.NoError(t, err) require.NoError(t, err)
require.EqualValues(t, next, 6) require.EqualValues(t, next, 8)
} }
func testSetupMgrWithChannel(t *testing.T, ctx context.Context) (*Manager, address.Address, []byte) { func testSetupMgrWithChannel(t *testing.T, ctx context.Context) (*Manager, address.Address, []byte) {

View File

@ -75,7 +75,7 @@ func (pm *Manager) waitForPaychCreateMsg(ctx context.Context, mcid cid.Cid) {
} }
paychaddr := decodedReturn.RobustAddress paychaddr := decodedReturn.RobustAddress
ci, err := pm.loadOutboundChannelInfo(ctx, paychaddr) ci, err := pm.loadStateChannelInfo(ctx, paychaddr, DirOutbound)
if err != nil { if err != nil {
log.Errorf("loading channel info: %w", err) log.Errorf("loading channel info: %w", err)
return return

View File

@ -3,6 +3,8 @@ package paychmgr
import ( import (
"context" "context"
"github.com/filecoin-project/specs-actors/actors/builtin/account"
"github.com/filecoin-project/go-address" "github.com/filecoin-project/go-address"
"github.com/filecoin-project/specs-actors/actors/builtin/paych" "github.com/filecoin-project/specs-actors/actors/builtin/paych"
xerrors "golang.org/x/xerrors" xerrors "golang.org/x/xerrors"
@ -20,6 +22,55 @@ func (pm *Manager) loadPaychState(ctx context.Context, ch address.Address) (*typ
return act, &pcast, nil return act, &pcast, nil
} }
func (pm *Manager) loadStateChannelInfo(ctx context.Context, ch address.Address, dir uint64) (*ChannelInfo, error) {
_, st, err := pm.loadPaychState(ctx, ch)
if err != nil {
return nil, err
}
var account account.State
_, err = pm.sm.LoadActorState(ctx, st.From, &account, nil)
if err != nil {
return nil, err
}
from := account.Address
_, err = pm.sm.LoadActorState(ctx, st.To, &account, nil)
if err != nil {
return nil, err
}
to := account.Address
ci := &ChannelInfo{
Channel: ch,
Direction: dir,
NextLane: nextLaneFromState(st),
}
if dir == DirOutbound {
ci.Control = from
ci.Target = to
} else {
ci.Control = to
ci.Target = from
}
return ci, nil
}
func nextLaneFromState(st *paych.State) uint64 {
if len(st.LaneStates) == 0 {
return 0
}
maxLane := st.LaneStates[0].ID
for _, state := range st.LaneStates {
if state.ID > maxLane {
maxLane = state.ID
}
}
return maxLane + 1
}
func findLane(states []*paych.LaneState, lane uint64) *paych.LaneState { func findLane(states []*paych.LaneState, lane uint64) *paych.LaneState {
var ls *paych.LaneState var ls *paych.LaneState
for _, laneState := range states { for _, laneState := range states {
@ -31,16 +82,12 @@ func findLane(states []*paych.LaneState, lane uint64) *paych.LaneState {
return ls return ls
} }
func (pm *Manager) laneState(ctx context.Context, ch address.Address, lane uint64) (paych.LaneState, error) { func (pm *Manager) laneState(state *paych.State, ch address.Address, lane uint64) (paych.LaneState, error) {
_, state, err := pm.loadPaychState(ctx, ch)
if err != nil {
return paych.LaneState{}, err
}
// TODO: we probably want to call UpdateChannelState with all vouchers to be fully correct // TODO: we probably want to call UpdateChannelState with all vouchers to be fully correct
// (but technically dont't need to) // (but technically dont't need to)
// TODO: make sure this is correct // TODO: make sure this is correct
// Get the lane state from the chain
ls := findLane(state.LaneStates, lane) ls := findLane(state.LaneStates, lane)
if ls == nil { if ls == nil {
ls = &paych.LaneState{ ls = &paych.LaneState{
@ -50,6 +97,7 @@ func (pm *Manager) laneState(ctx context.Context, ch address.Address, lane uint6
} }
} }
// Apply locally stored vouchers
vouchers, err := pm.store.VouchersForPaych(ch) vouchers, err := pm.store.VouchersForPaych(ch)
if err != nil { if err != nil {
if err == ErrChannelNotTracked { if err == ErrChannelNotTracked {