refactor: simplify state management

This commit is contained in:
Dirk McCormick 2020-07-09 18:27:39 -04:00
parent 1cdb008bd5
commit 6c70ef7c7d
4 changed files with 130 additions and 157 deletions

View File

@ -40,17 +40,8 @@ type StateManagerApi interface {
Call(ctx context.Context, msg *types.Message, ts *types.TipSet) (*api.InvocResult, error)
}
//type StateApi interface {
// StateWaitMsg(ctx context.Context, msg cid.Cid, confidence uint64) (*api.MsgLookup, error)
//}
//
//type MpoolApi interface {
// MpoolPushMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error)
//}
type Manager struct {
store *Store
//sm *stmgr.StateManager
sm StateManagerApi
mpool full.MpoolAPI
@ -74,85 +65,19 @@ func newManager(sm StateManagerApi, pchstore *Store) *Manager {
return &Manager{
store: pchstore,
sm: sm,
//mpool: api.MpoolAPI,
//wallet: api.WalletAPI,
//state: api.StateAPI,
}
}
func nextLaneFromState(st *paych.State) uint64 {
if len(st.LaneStates) == 0 {
return 0
}
maxLane := st.LaneStates[0].ID
for _, state := range st.LaneStates {
if state.ID > maxLane {
maxLane = state.ID
}
}
return maxLane + 1
}
func (pm *Manager) TrackInboundChannel(ctx context.Context, ch address.Address) error {
_, st, err := pm.loadPaychState(ctx, ch)
if err != nil {
return err
}
var account account.State
_, err = pm.sm.LoadActorState(ctx, st.From, &account, nil)
if err != nil {
return err
}
from := account.Address
_, err = pm.sm.LoadActorState(ctx, st.To, &account, nil)
if err != nil {
return err
}
to := account.Address
return pm.store.TrackChannel(&ChannelInfo{
Channel: ch,
Control: to,
Target: from,
Direction: DirInbound,
NextLane: nextLaneFromState(st),
})
}
func (pm *Manager) loadOutboundChannelInfo(ctx context.Context, ch address.Address) (*ChannelInfo, error) {
_, st, err := pm.loadPaychState(ctx, ch)
if err != nil {
return nil, err
}
var account account.State
_, err = pm.sm.LoadActorState(ctx, st.From, &account, nil)
if err != nil {
return nil, err
}
from := account.Address
_, err = pm.sm.LoadActorState(ctx, st.To, &account, nil)
if err != nil {
return nil, err
}
to := account.Address
return &ChannelInfo{
Channel: ch,
Control: from,
Target: to,
Direction: DirOutbound,
NextLane: nextLaneFromState(st),
}, nil
}
func (pm *Manager) TrackOutboundChannel(ctx context.Context, ch address.Address) error {
ci, err := pm.loadOutboundChannelInfo(ctx, ch)
return pm.trackChannel(ctx, ch, DirOutbound)
}
func (pm *Manager) TrackInboundChannel(ctx context.Context, ch address.Address) error {
return pm.trackChannel(ctx, ch, DirInbound)
}
func (pm *Manager) trackChannel(ctx context.Context, ch address.Address, dir uint64) error {
ci, err := pm.loadStateChannelInfo(ctx, ch, dir)
if err != nil {
return err
}
@ -170,58 +95,68 @@ func (pm *Manager) GetChannelInfo(addr address.Address) (*ChannelInfo, error) {
// checks if the given voucher is valid (is or could become spendable at some point)
func (pm *Manager) CheckVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) error {
_, err := pm.checkVoucherValid(ctx, ch, sv)
return err
}
func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (*paych.State, error) {
act, pca, err := pm.loadPaychState(ctx, ch)
if err != nil {
return err
return nil, err
}
var account account.State
_, err = pm.sm.LoadActorState(ctx, pca.From, &account, nil)
if err != nil {
return err
return nil, err
}
from := account.Address
// verify signature
vb, err := sv.SigningBytes()
if err != nil {
return err
return nil, err
}
// TODO: technically, either party may create and sign a voucher.
// However, for now, we only accept them from the channel creator.
// More complex handling logic can be added later
if err := sigs.Verify(sv.Signature, from, vb); err != nil {
return err
return nil, err
}
sendAmount := sv.Amount
// now check the lane state
// TODO: should check against vouchers in our local store too
// there might be something conflicting
ls := findLane(pca.LaneStates, uint64(sv.Lane))
if ls == nil {
} else {
if (ls.Nonce) >= sv.Nonce {
return fmt.Errorf("nonce too low")
// Check the voucher against the highest known voucher nonce / value
ls, err := pm.laneState(pca, ch, sv.Lane)
if err != nil {
return nil, err
}
// If there has been at least once voucher redeemed, and the voucher
// nonce value is less than the highest known nonce
if ls.Redeemed.Int64() > 0 && sv.Nonce <= ls.Nonce {
return nil, fmt.Errorf("nonce too low")
}
// If the voucher amount is less than the highest known voucher amount
if sv.Amount.LessThanEqual(ls.Redeemed) {
return nil, fmt.Errorf("voucher amount is lower than amount for voucher with lower nonce")
}
// TODO: return error if ls.Redeemed > vs.Amount
// Only send the difference between the voucher amount and what has already
// been redeemed
sendAmount = types.BigSub(sv.Amount, ls.Redeemed)
}
// TODO: also account for vouchers on other lanes we've received
newTotal := types.BigAdd(sendAmount, pca.ToSend)
if act.Balance.LessThan(newTotal) {
return fmt.Errorf("not enough funds in channel to cover voucher")
return nil, fmt.Errorf("not enough funds in channel to cover voucher")
}
if len(sv.Merges) != 0 {
return fmt.Errorf("dont currently support paych lane merges")
return nil, fmt.Errorf("dont currently support paych lane merges")
}
return nil
return pca, nil
}
// checks if the given voucher is currently spendable
@ -289,10 +224,6 @@ func (pm *Manager) getPaychOwner(ctx context.Context, ch address.Address) (addre
}
func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) {
if err := pm.CheckVoucherValid(ctx, ch, sv); err != nil {
return types.NewInt(0), err
}
pm.store.lk.Lock()
defer pm.store.lk.Unlock()
@ -301,25 +232,7 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych
return types.NewInt(0), err
}
laneState, err := pm.laneState(ctx, ch, uint64(sv.Lane))
if err != nil {
return types.NewInt(0), err
}
// TODO: I believe this check is redundant because
// CheckVoucherValid() already returns an error if laneState.Nonce >= sv.Nonce
if minDelta.GreaterThan(types.NewInt(0)) && laneState.Nonce > sv.Nonce {
return types.NewInt(0), xerrors.Errorf("already storing voucher with higher nonce; %d > %d", laneState.Nonce, sv.Nonce)
}
// TODO:
// It's possible to repeatedly add a voucher with the same proof:
// 1. add a voucher with proof P1
// 2. add a voucher with proof P2
// 3. add a voucher with proof P2 (again)
// Voucher with proof P2 has been added twice
//
// look for duplicates
// Check if the voucher has already been added
for i, v := range ci.Vouchers {
eq, err := cborutil.Equals(sv, v.Voucher)
if err != nil {
@ -328,17 +241,10 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych
if !eq {
continue
}
// TODO: CBOR encoding / decoding changes nil into []byte{}, so instead of
// checking v.Proof against nil we should check len(v.Proof) == 0
if v.Proof != nil {
if !bytes.Equal(v.Proof, proof) {
log.Warnf("AddVoucher: multiple proofs for single voucher, storing both")
break
}
log.Warnf("AddVoucher: voucher re-added with matching proof")
return types.NewInt(0), nil
}
// This is a duplicate voucher.
// Update the proof on the existing voucher
if len(proof) > 0 && !bytes.Equal(v.Proof, proof) {
log.Warnf("AddVoucher: adding proof to stored voucher")
ci.Vouchers[i] = &VoucherInfo{
Voucher: v.Voucher,
@ -348,6 +254,24 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych
return types.NewInt(0), pm.store.putChannelInfo(ci)
}
// Otherwise just ignore the duplicate voucher
log.Warnf("AddVoucher: voucher re-added with matching proof")
return types.NewInt(0), nil
}
// Check voucher validity
pchState, err := pm.checkVoucherValid(ctx, ch, sv)
if err != nil {
return types.NewInt(0), err
}
// The change in value is the delta between the voucher amount and
// the highest previous voucher amount
laneState, err := pm.laneState(pchState, ch, sv.Lane)
if err != nil {
return types.NewInt(0), err
}
delta := types.BigSub(sv.Amount, laneState.Redeemed)
if minDelta.GreaterThan(delta) {
return delta, xerrors.Errorf("addVoucher: supplied token amount too low; minD=%s, D=%s; laneAmt=%s; v.Amt=%s", minDelta, delta, laneState.Redeemed, sv.Amount)

View File

@ -435,7 +435,7 @@ func TestAddVoucherProof(t *testing.T) {
// Add same voucher with proof
proof = []byte{1}
_, err = mgr.AddVoucher(ctx, ch, sv, nil, minDelta)
_, err = mgr.AddVoucher(ctx, ch, sv, proof, minDelta)
require.NoError(t, err)
// Should add proof to existing voucher
@ -478,31 +478,32 @@ func TestNextNonceForLane(t *testing.T) {
voucherAmount = big.NewInt(2)
// Add vouchers such that we have
// lane 1: nonce 3
// lane 1: nonce 2
// lane 2: nonce 5
// lane 1: nonce 4
// lane 2: nonce 7
voucherLane := uint64(1)
for _, nonce := range []uint64{3, 2} {
for _, nonce := range []uint64{2, 4} {
voucherAmount = big.Add(voucherAmount, big.NewInt(1))
sv := testCreateVoucher(t, voucherLane, nonce, voucherAmount, key)
_, err := mgr.AddVoucher(ctx, ch, sv, nil, minDelta)
require.NoError(t, err)
}
voucherLane = uint64(2)
nonce := uint64(5)
nonce := uint64(7)
sv := testCreateVoucher(t, voucherLane, nonce, voucherAmount, key)
_, err = mgr.AddVoucher(ctx, ch, sv, nil, minDelta)
require.NoError(t, err)
// Expect next nonce for lane 1 to be 4
// Expect next nonce for lane 1 to be 5
next, err = mgr.NextNonceForLane(ctx, ch, 1)
require.NoError(t, err)
require.EqualValues(t, next, 4)
require.EqualValues(t, next, 5)
// Expect next nonce for lane 2 to be 6
// Expect next nonce for lane 2 to be 8
next, err = mgr.NextNonceForLane(ctx, ch, 2)
require.NoError(t, err)
require.EqualValues(t, next, 6)
require.EqualValues(t, next, 8)
}
func testSetupMgrWithChannel(t *testing.T, ctx context.Context) (*Manager, address.Address, []byte) {

View File

@ -75,7 +75,7 @@ func (pm *Manager) waitForPaychCreateMsg(ctx context.Context, mcid cid.Cid) {
}
paychaddr := decodedReturn.RobustAddress
ci, err := pm.loadOutboundChannelInfo(ctx, paychaddr)
ci, err := pm.loadStateChannelInfo(ctx, paychaddr, DirOutbound)
if err != nil {
log.Errorf("loading channel info: %w", err)
return

View File

@ -3,6 +3,8 @@ package paychmgr
import (
"context"
"github.com/filecoin-project/specs-actors/actors/builtin/account"
"github.com/filecoin-project/go-address"
"github.com/filecoin-project/specs-actors/actors/builtin/paych"
xerrors "golang.org/x/xerrors"
@ -20,6 +22,55 @@ func (pm *Manager) loadPaychState(ctx context.Context, ch address.Address) (*typ
return act, &pcast, nil
}
func (pm *Manager) loadStateChannelInfo(ctx context.Context, ch address.Address, dir uint64) (*ChannelInfo, error) {
_, st, err := pm.loadPaychState(ctx, ch)
if err != nil {
return nil, err
}
var account account.State
_, err = pm.sm.LoadActorState(ctx, st.From, &account, nil)
if err != nil {
return nil, err
}
from := account.Address
_, err = pm.sm.LoadActorState(ctx, st.To, &account, nil)
if err != nil {
return nil, err
}
to := account.Address
ci := &ChannelInfo{
Channel: ch,
Direction: dir,
NextLane: nextLaneFromState(st),
}
if dir == DirOutbound {
ci.Control = from
ci.Target = to
} else {
ci.Control = to
ci.Target = from
}
return ci, nil
}
func nextLaneFromState(st *paych.State) uint64 {
if len(st.LaneStates) == 0 {
return 0
}
maxLane := st.LaneStates[0].ID
for _, state := range st.LaneStates {
if state.ID > maxLane {
maxLane = state.ID
}
}
return maxLane + 1
}
func findLane(states []*paych.LaneState, lane uint64) *paych.LaneState {
var ls *paych.LaneState
for _, laneState := range states {
@ -31,16 +82,12 @@ func findLane(states []*paych.LaneState, lane uint64) *paych.LaneState {
return ls
}
func (pm *Manager) laneState(ctx context.Context, ch address.Address, lane uint64) (paych.LaneState, error) {
_, state, err := pm.loadPaychState(ctx, ch)
if err != nil {
return paych.LaneState{}, err
}
func (pm *Manager) laneState(state *paych.State, ch address.Address, lane uint64) (paych.LaneState, error) {
// TODO: we probably want to call UpdateChannelState with all vouchers to be fully correct
// (but technically dont't need to)
// TODO: make sure this is correct
// Get the lane state from the chain
ls := findLane(state.LaneStates, lane)
if ls == nil {
ls = &paych.LaneState{
@ -50,6 +97,7 @@ func (pm *Manager) laneState(ctx context.Context, ch address.Address, lane uint6
}
}
// Apply locally stored vouchers
vouchers, err := pm.store.VouchersForPaych(ch)
if err != nil {
if err == ErrChannelNotTracked {