refactor: simplify state management
This commit is contained in:
parent
1cdb008bd5
commit
6c70ef7c7d
@ -40,18 +40,9 @@ type StateManagerApi interface {
|
||||
Call(ctx context.Context, msg *types.Message, ts *types.TipSet) (*api.InvocResult, error)
|
||||
}
|
||||
|
||||
//type StateApi interface {
|
||||
// StateWaitMsg(ctx context.Context, msg cid.Cid, confidence uint64) (*api.MsgLookup, error)
|
||||
//}
|
||||
//
|
||||
//type MpoolApi interface {
|
||||
// MpoolPushMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error)
|
||||
//}
|
||||
|
||||
type Manager struct {
|
||||
store *Store
|
||||
//sm *stmgr.StateManager
|
||||
sm StateManagerApi
|
||||
sm StateManagerApi
|
||||
|
||||
mpool full.MpoolAPI
|
||||
wallet full.WalletAPI
|
||||
@ -74,85 +65,19 @@ func newManager(sm StateManagerApi, pchstore *Store) *Manager {
|
||||
return &Manager{
|
||||
store: pchstore,
|
||||
sm: sm,
|
||||
|
||||
//mpool: api.MpoolAPI,
|
||||
//wallet: api.WalletAPI,
|
||||
//state: api.StateAPI,
|
||||
}
|
||||
}
|
||||
|
||||
func nextLaneFromState(st *paych.State) uint64 {
|
||||
if len(st.LaneStates) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
maxLane := st.LaneStates[0].ID
|
||||
for _, state := range st.LaneStates {
|
||||
if state.ID > maxLane {
|
||||
maxLane = state.ID
|
||||
}
|
||||
}
|
||||
return maxLane + 1
|
||||
}
|
||||
|
||||
func (pm *Manager) TrackInboundChannel(ctx context.Context, ch address.Address) error {
|
||||
_, st, err := pm.loadPaychState(ctx, ch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var account account.State
|
||||
_, err = pm.sm.LoadActorState(ctx, st.From, &account, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
from := account.Address
|
||||
_, err = pm.sm.LoadActorState(ctx, st.To, &account, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
to := account.Address
|
||||
|
||||
return pm.store.TrackChannel(&ChannelInfo{
|
||||
Channel: ch,
|
||||
Control: to,
|
||||
Target: from,
|
||||
|
||||
Direction: DirInbound,
|
||||
NextLane: nextLaneFromState(st),
|
||||
})
|
||||
}
|
||||
|
||||
func (pm *Manager) loadOutboundChannelInfo(ctx context.Context, ch address.Address) (*ChannelInfo, error) {
|
||||
_, st, err := pm.loadPaychState(ctx, ch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var account account.State
|
||||
_, err = pm.sm.LoadActorState(ctx, st.From, &account, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
from := account.Address
|
||||
_, err = pm.sm.LoadActorState(ctx, st.To, &account, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
to := account.Address
|
||||
|
||||
return &ChannelInfo{
|
||||
Channel: ch,
|
||||
Control: from,
|
||||
Target: to,
|
||||
|
||||
Direction: DirOutbound,
|
||||
NextLane: nextLaneFromState(st),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (pm *Manager) TrackOutboundChannel(ctx context.Context, ch address.Address) error {
|
||||
ci, err := pm.loadOutboundChannelInfo(ctx, ch)
|
||||
return pm.trackChannel(ctx, ch, DirOutbound)
|
||||
}
|
||||
|
||||
func (pm *Manager) TrackInboundChannel(ctx context.Context, ch address.Address) error {
|
||||
return pm.trackChannel(ctx, ch, DirInbound)
|
||||
}
|
||||
|
||||
func (pm *Manager) trackChannel(ctx context.Context, ch address.Address, dir uint64) error {
|
||||
ci, err := pm.loadStateChannelInfo(ctx, ch, dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -170,58 +95,68 @@ func (pm *Manager) GetChannelInfo(addr address.Address) (*ChannelInfo, error) {
|
||||
|
||||
// checks if the given voucher is valid (is or could become spendable at some point)
|
||||
func (pm *Manager) CheckVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) error {
|
||||
_, err := pm.checkVoucherValid(ctx, ch, sv)
|
||||
return err
|
||||
}
|
||||
|
||||
func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (*paych.State, error) {
|
||||
act, pca, err := pm.loadPaychState(ctx, ch)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var account account.State
|
||||
_, err = pm.sm.LoadActorState(ctx, pca.From, &account, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
from := account.Address
|
||||
|
||||
// verify signature
|
||||
vb, err := sv.SigningBytes()
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: technically, either party may create and sign a voucher.
|
||||
// However, for now, we only accept them from the channel creator.
|
||||
// More complex handling logic can be added later
|
||||
if err := sigs.Verify(sv.Signature, from, vb); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sendAmount := sv.Amount
|
||||
|
||||
// now check the lane state
|
||||
// TODO: should check against vouchers in our local store too
|
||||
// there might be something conflicting
|
||||
ls := findLane(pca.LaneStates, uint64(sv.Lane))
|
||||
if ls == nil {
|
||||
} else {
|
||||
if (ls.Nonce) >= sv.Nonce {
|
||||
return fmt.Errorf("nonce too low")
|
||||
}
|
||||
|
||||
// TODO: return error if ls.Redeemed > vs.Amount
|
||||
sendAmount = types.BigSub(sv.Amount, ls.Redeemed)
|
||||
// Check the voucher against the highest known voucher nonce / value
|
||||
ls, err := pm.laneState(pca, ch, sv.Lane)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// If there has been at least once voucher redeemed, and the voucher
|
||||
// nonce value is less than the highest known nonce
|
||||
if ls.Redeemed.Int64() > 0 && sv.Nonce <= ls.Nonce {
|
||||
return nil, fmt.Errorf("nonce too low")
|
||||
}
|
||||
// If the voucher amount is less than the highest known voucher amount
|
||||
if sv.Amount.LessThanEqual(ls.Redeemed) {
|
||||
return nil, fmt.Errorf("voucher amount is lower than amount for voucher with lower nonce")
|
||||
}
|
||||
|
||||
// Only send the difference between the voucher amount and what has already
|
||||
// been redeemed
|
||||
sendAmount = types.BigSub(sv.Amount, ls.Redeemed)
|
||||
|
||||
// TODO: also account for vouchers on other lanes we've received
|
||||
newTotal := types.BigAdd(sendAmount, pca.ToSend)
|
||||
if act.Balance.LessThan(newTotal) {
|
||||
return fmt.Errorf("not enough funds in channel to cover voucher")
|
||||
return nil, fmt.Errorf("not enough funds in channel to cover voucher")
|
||||
}
|
||||
|
||||
if len(sv.Merges) != 0 {
|
||||
return fmt.Errorf("dont currently support paych lane merges")
|
||||
return nil, fmt.Errorf("dont currently support paych lane merges")
|
||||
}
|
||||
|
||||
return nil
|
||||
return pca, nil
|
||||
}
|
||||
|
||||
// checks if the given voucher is currently spendable
|
||||
@ -289,10 +224,6 @@ func (pm *Manager) getPaychOwner(ctx context.Context, ch address.Address) (addre
|
||||
}
|
||||
|
||||
func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) {
|
||||
if err := pm.CheckVoucherValid(ctx, ch, sv); err != nil {
|
||||
return types.NewInt(0), err
|
||||
}
|
||||
|
||||
pm.store.lk.Lock()
|
||||
defer pm.store.lk.Unlock()
|
||||
|
||||
@ -301,25 +232,7 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych
|
||||
return types.NewInt(0), err
|
||||
}
|
||||
|
||||
laneState, err := pm.laneState(ctx, ch, uint64(sv.Lane))
|
||||
if err != nil {
|
||||
return types.NewInt(0), err
|
||||
}
|
||||
|
||||
// TODO: I believe this check is redundant because
|
||||
// CheckVoucherValid() already returns an error if laneState.Nonce >= sv.Nonce
|
||||
if minDelta.GreaterThan(types.NewInt(0)) && laneState.Nonce > sv.Nonce {
|
||||
return types.NewInt(0), xerrors.Errorf("already storing voucher with higher nonce; %d > %d", laneState.Nonce, sv.Nonce)
|
||||
}
|
||||
|
||||
// TODO:
|
||||
// It's possible to repeatedly add a voucher with the same proof:
|
||||
// 1. add a voucher with proof P1
|
||||
// 2. add a voucher with proof P2
|
||||
// 3. add a voucher with proof P2 (again)
|
||||
// Voucher with proof P2 has been added twice
|
||||
//
|
||||
// look for duplicates
|
||||
// Check if the voucher has already been added
|
||||
for i, v := range ci.Vouchers {
|
||||
eq, err := cborutil.Equals(sv, v.Voucher)
|
||||
if err != nil {
|
||||
@ -328,24 +241,35 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych
|
||||
if !eq {
|
||||
continue
|
||||
}
|
||||
// TODO: CBOR encoding / decoding changes nil into []byte{}, so instead of
|
||||
// checking v.Proof against nil we should check len(v.Proof) == 0
|
||||
if v.Proof != nil {
|
||||
if !bytes.Equal(v.Proof, proof) {
|
||||
log.Warnf("AddVoucher: multiple proofs for single voucher, storing both")
|
||||
break
|
||||
|
||||
// This is a duplicate voucher.
|
||||
// Update the proof on the existing voucher
|
||||
if len(proof) > 0 && !bytes.Equal(v.Proof, proof) {
|
||||
log.Warnf("AddVoucher: adding proof to stored voucher")
|
||||
ci.Vouchers[i] = &VoucherInfo{
|
||||
Voucher: v.Voucher,
|
||||
Proof: proof,
|
||||
}
|
||||
log.Warnf("AddVoucher: voucher re-added with matching proof")
|
||||
return types.NewInt(0), nil
|
||||
|
||||
return types.NewInt(0), pm.store.putChannelInfo(ci)
|
||||
}
|
||||
|
||||
log.Warnf("AddVoucher: adding proof to stored voucher")
|
||||
ci.Vouchers[i] = &VoucherInfo{
|
||||
Voucher: v.Voucher,
|
||||
Proof: proof,
|
||||
}
|
||||
// Otherwise just ignore the duplicate voucher
|
||||
log.Warnf("AddVoucher: voucher re-added with matching proof")
|
||||
return types.NewInt(0), nil
|
||||
}
|
||||
|
||||
return types.NewInt(0), pm.store.putChannelInfo(ci)
|
||||
// Check voucher validity
|
||||
pchState, err := pm.checkVoucherValid(ctx, ch, sv)
|
||||
if err != nil {
|
||||
return types.NewInt(0), err
|
||||
}
|
||||
|
||||
// The change in value is the delta between the voucher amount and
|
||||
// the highest previous voucher amount
|
||||
laneState, err := pm.laneState(pchState, ch, sv.Lane)
|
||||
if err != nil {
|
||||
return types.NewInt(0), err
|
||||
}
|
||||
|
||||
delta := types.BigSub(sv.Amount, laneState.Redeemed)
|
||||
|
@ -435,7 +435,7 @@ func TestAddVoucherProof(t *testing.T) {
|
||||
|
||||
// Add same voucher with proof
|
||||
proof = []byte{1}
|
||||
_, err = mgr.AddVoucher(ctx, ch, sv, nil, minDelta)
|
||||
_, err = mgr.AddVoucher(ctx, ch, sv, proof, minDelta)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should add proof to existing voucher
|
||||
@ -478,31 +478,32 @@ func TestNextNonceForLane(t *testing.T) {
|
||||
voucherAmount = big.NewInt(2)
|
||||
|
||||
// Add vouchers such that we have
|
||||
// lane 1: nonce 3
|
||||
// lane 1: nonce 2
|
||||
// lane 2: nonce 5
|
||||
// lane 1: nonce 4
|
||||
// lane 2: nonce 7
|
||||
voucherLane := uint64(1)
|
||||
for _, nonce := range []uint64{3, 2} {
|
||||
for _, nonce := range []uint64{2, 4} {
|
||||
voucherAmount = big.Add(voucherAmount, big.NewInt(1))
|
||||
sv := testCreateVoucher(t, voucherLane, nonce, voucherAmount, key)
|
||||
_, err := mgr.AddVoucher(ctx, ch, sv, nil, minDelta)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
voucherLane = uint64(2)
|
||||
nonce := uint64(5)
|
||||
nonce := uint64(7)
|
||||
sv := testCreateVoucher(t, voucherLane, nonce, voucherAmount, key)
|
||||
_, err = mgr.AddVoucher(ctx, ch, sv, nil, minDelta)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Expect next nonce for lane 1 to be 4
|
||||
// Expect next nonce for lane 1 to be 5
|
||||
next, err = mgr.NextNonceForLane(ctx, ch, 1)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, next, 4)
|
||||
require.EqualValues(t, next, 5)
|
||||
|
||||
// Expect next nonce for lane 2 to be 6
|
||||
// Expect next nonce for lane 2 to be 8
|
||||
next, err = mgr.NextNonceForLane(ctx, ch, 2)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, next, 6)
|
||||
require.EqualValues(t, next, 8)
|
||||
}
|
||||
|
||||
func testSetupMgrWithChannel(t *testing.T, ctx context.Context) (*Manager, address.Address, []byte) {
|
||||
|
@ -75,7 +75,7 @@ func (pm *Manager) waitForPaychCreateMsg(ctx context.Context, mcid cid.Cid) {
|
||||
}
|
||||
paychaddr := decodedReturn.RobustAddress
|
||||
|
||||
ci, err := pm.loadOutboundChannelInfo(ctx, paychaddr)
|
||||
ci, err := pm.loadStateChannelInfo(ctx, paychaddr, DirOutbound)
|
||||
if err != nil {
|
||||
log.Errorf("loading channel info: %w", err)
|
||||
return
|
||||
|
@ -3,6 +3,8 @@ package paychmgr
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/filecoin-project/specs-actors/actors/builtin/account"
|
||||
|
||||
"github.com/filecoin-project/go-address"
|
||||
"github.com/filecoin-project/specs-actors/actors/builtin/paych"
|
||||
xerrors "golang.org/x/xerrors"
|
||||
@ -20,6 +22,55 @@ func (pm *Manager) loadPaychState(ctx context.Context, ch address.Address) (*typ
|
||||
return act, &pcast, nil
|
||||
}
|
||||
|
||||
func (pm *Manager) loadStateChannelInfo(ctx context.Context, ch address.Address, dir uint64) (*ChannelInfo, error) {
|
||||
_, st, err := pm.loadPaychState(ctx, ch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var account account.State
|
||||
_, err = pm.sm.LoadActorState(ctx, st.From, &account, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
from := account.Address
|
||||
_, err = pm.sm.LoadActorState(ctx, st.To, &account, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
to := account.Address
|
||||
|
||||
ci := &ChannelInfo{
|
||||
Channel: ch,
|
||||
Direction: dir,
|
||||
NextLane: nextLaneFromState(st),
|
||||
}
|
||||
|
||||
if dir == DirOutbound {
|
||||
ci.Control = from
|
||||
ci.Target = to
|
||||
} else {
|
||||
ci.Control = to
|
||||
ci.Target = from
|
||||
}
|
||||
|
||||
return ci, nil
|
||||
}
|
||||
|
||||
func nextLaneFromState(st *paych.State) uint64 {
|
||||
if len(st.LaneStates) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
maxLane := st.LaneStates[0].ID
|
||||
for _, state := range st.LaneStates {
|
||||
if state.ID > maxLane {
|
||||
maxLane = state.ID
|
||||
}
|
||||
}
|
||||
return maxLane + 1
|
||||
}
|
||||
|
||||
func findLane(states []*paych.LaneState, lane uint64) *paych.LaneState {
|
||||
var ls *paych.LaneState
|
||||
for _, laneState := range states {
|
||||
@ -31,16 +82,12 @@ func findLane(states []*paych.LaneState, lane uint64) *paych.LaneState {
|
||||
return ls
|
||||
}
|
||||
|
||||
func (pm *Manager) laneState(ctx context.Context, ch address.Address, lane uint64) (paych.LaneState, error) {
|
||||
_, state, err := pm.loadPaychState(ctx, ch)
|
||||
if err != nil {
|
||||
return paych.LaneState{}, err
|
||||
}
|
||||
|
||||
func (pm *Manager) laneState(state *paych.State, ch address.Address, lane uint64) (paych.LaneState, error) {
|
||||
// TODO: we probably want to call UpdateChannelState with all vouchers to be fully correct
|
||||
// (but technically dont't need to)
|
||||
// TODO: make sure this is correct
|
||||
|
||||
// Get the lane state from the chain
|
||||
ls := findLane(state.LaneStates, lane)
|
||||
if ls == nil {
|
||||
ls = &paych.LaneState{
|
||||
@ -50,6 +97,7 @@ func (pm *Manager) laneState(ctx context.Context, ch address.Address, lane uint6
|
||||
}
|
||||
}
|
||||
|
||||
// Apply locally stored vouchers
|
||||
vouchers, err := pm.store.VouchersForPaych(ch)
|
||||
if err != nil {
|
||||
if err == ErrChannelNotTracked {
|
||||
|
Loading…
Reference in New Issue
Block a user