make market diffs work across version upgrades

This commit is contained in:
Steven Allen 2020-09-17 16:08:54 -07:00
parent dc58f71604
commit 5bcfee0042
3 changed files with 121 additions and 91 deletions

View File

@ -8,6 +8,7 @@ import (
"github.com/filecoin-project/go-state-types/cbor" "github.com/filecoin-project/go-state-types/cbor"
v0builtin "github.com/filecoin-project/specs-actors/actors/builtin" v0builtin "github.com/filecoin-project/specs-actors/actors/builtin"
"github.com/ipfs/go-cid" "github.com/ipfs/go-cid"
cbg "github.com/whyrusleeping/cbor-gen"
"github.com/filecoin-project/lotus/chain/actors/adt" "github.com/filecoin-project/lotus/chain/actors/adt"
"github.com/filecoin-project/lotus/chain/types" "github.com/filecoin-project/lotus/chain/types"
@ -51,12 +52,18 @@ type BalanceTable interface {
type DealStates interface { type DealStates interface {
Get(id abi.DealID) (*DealState, bool, error) Get(id abi.DealID) (*DealState, bool, error)
Diff(DealStates) (*DealStateChanges, error) Diff(DealStates) (*DealStateChanges, error)
array() adt.Array
decode(*cbg.Deferred) (*DealState, error)
} }
type DealProposals interface { type DealProposals interface {
ForEach(cb func(id abi.DealID, dp DealProposal) error) error ForEach(cb func(id abi.DealID, dp DealProposal) error) error
Get(id abi.DealID) (*DealProposal, bool, error) Get(id abi.DealID) (*DealProposal, bool, error)
Diff(DealProposals) (*DealProposalChanges, error) Diff(DealProposals) (*DealProposalChanges, error)
array() adt.Array
decode(*cbg.Deferred) (*DealProposal, error)
} }
type DealState struct { type DealState struct {

View File

@ -0,0 +1,91 @@
package market
import (
"fmt"
"github.com/filecoin-project/go-state-types/abi"
"github.com/filecoin-project/lotus/chain/actors/adt"
cbg "github.com/whyrusleeping/cbor-gen"
)
func diffDealProposals(pre, cur DealProposals) (*DealProposalChanges, error) {
results := new(DealProposalChanges)
if err := adt.DiffAdtArray(pre.array(), cur.array(), &marketProposalsDiffer{results, pre, cur}); err != nil {
return nil, fmt.Errorf("diffing deal states: %w", err)
}
return results, nil
}
type marketProposalsDiffer struct {
Results *DealProposalChanges
pre, cur DealProposals
}
func (d *marketProposalsDiffer) Add(key uint64, val *cbg.Deferred) error {
dp, err := d.cur.decode(val)
if err != nil {
return err
}
d.Results.Added = append(d.Results.Added, ProposalIDState{abi.DealID(key), *dp})
return nil
}
func (d *marketProposalsDiffer) Modify(key uint64, from, to *cbg.Deferred) error {
// short circuit, DealProposals are static
return nil
}
func (d *marketProposalsDiffer) Remove(key uint64, val *cbg.Deferred) error {
dp, err := d.pre.decode(val)
if err != nil {
return err
}
d.Results.Removed = append(d.Results.Removed, ProposalIDState{abi.DealID(key), *dp})
return nil
}
func diffDealStates(pre, cur DealStates) (*DealStateChanges, error) {
results := new(DealStateChanges)
if err := adt.DiffAdtArray(pre.array(), cur.array(), &marketStatesDiffer{results, pre, cur}); err != nil {
return nil, fmt.Errorf("diffing deal states: %w", err)
}
return results, nil
}
type marketStatesDiffer struct {
Results *DealStateChanges
pre, cur DealStates
}
func (d *marketStatesDiffer) Add(key uint64, val *cbg.Deferred) error {
ds, err := d.cur.decode(val)
if err != nil {
return err
}
d.Results.Added = append(d.Results.Added, DealIDState{abi.DealID(key), *ds})
return nil
}
func (d *marketStatesDiffer) Modify(key uint64, from, to *cbg.Deferred) error {
dsFrom, err := d.pre.decode(from)
if err != nil {
return err
}
dsTo, err := d.cur.decode(to)
if err != nil {
return err
}
if *dsFrom != *dsTo {
d.Results.Modified = append(d.Results.Modified, DealStateChange{abi.DealID(key), dsFrom, dsTo})
}
return nil
}
func (d *marketStatesDiffer) Remove(key uint64, val *cbg.Deferred) error {
ds, err := d.pre.decode(val)
if err != nil {
return err
}
d.Results.Removed = append(d.Results.Removed, DealIDState{abi.DealID(key), *ds})
return nil
}

View File

@ -2,8 +2,6 @@ package market
import ( import (
"bytes" "bytes"
"errors"
"fmt"
"github.com/filecoin-project/go-address" "github.com/filecoin-project/go-address"
"github.com/filecoin-project/go-state-types/abi" "github.com/filecoin-project/go-state-types/abi"
@ -11,7 +9,7 @@ import (
"github.com/filecoin-project/lotus/chain/types" "github.com/filecoin-project/lotus/chain/types"
"github.com/filecoin-project/specs-actors/actors/builtin/market" "github.com/filecoin-project/specs-actors/actors/builtin/market"
v0adt "github.com/filecoin-project/specs-actors/actors/util/adt" v0adt "github.com/filecoin-project/specs-actors/actors/util/adt"
typegen "github.com/whyrusleeping/cbor-gen" cbg "github.com/whyrusleeping/cbor-gen"
) )
type v0State struct { type v0State struct {
@ -127,60 +125,20 @@ func (s *v0DealStates) Get(dealID abi.DealID) (*DealState, bool, error) {
} }
func (s *v0DealStates) Diff(other DealStates) (*DealStateChanges, error) { func (s *v0DealStates) Diff(other DealStates) (*DealStateChanges, error) {
v0other, ok := other.(*v0DealStates) return diffDealStates(s, other)
if !ok {
// TODO handle this if possible on a case by case basis but for now, just fail
return nil, errors.New("cannot compare deal states across versions")
}
results := new(DealStateChanges)
if err := adt.DiffAdtArray(s.Array, v0other.Array, &v0MarketStatesDiffer{results}); err != nil {
return nil, fmt.Errorf("diffing deal states: %w", err)
}
return results, nil
} }
type v0MarketStatesDiffer struct { func (s *v0DealStates) decode(val *cbg.Deferred) (*DealState, error) {
Results *DealStateChanges var v0ds market.DealState
if err := v0ds.UnmarshalCBOR(bytes.NewReader(val.Raw)); err != nil {
return nil, err
}
ds := fromV0DealState(v0ds)
return &ds, nil
} }
func (d *v0MarketStatesDiffer) Add(key uint64, val *typegen.Deferred) error { func (s *v0DealStates) array() adt.Array {
v0ds := new(market.DealState) return s.Array
err := v0ds.UnmarshalCBOR(bytes.NewReader(val.Raw))
if err != nil {
return err
}
d.Results.Added = append(d.Results.Added, DealIDState{abi.DealID(key), fromV0DealState(*v0ds)})
return nil
}
func (d *v0MarketStatesDiffer) Modify(key uint64, from, to *typegen.Deferred) error {
v0dsFrom := new(market.DealState)
if err := v0dsFrom.UnmarshalCBOR(bytes.NewReader(from.Raw)); err != nil {
return err
}
v0dsTo := new(market.DealState)
if err := v0dsTo.UnmarshalCBOR(bytes.NewReader(to.Raw)); err != nil {
return err
}
if *v0dsFrom != *v0dsTo {
dsFrom := fromV0DealState(*v0dsFrom)
dsTo := fromV0DealState(*v0dsTo)
d.Results.Modified = append(d.Results.Modified, DealStateChange{abi.DealID(key), &dsFrom, &dsTo})
}
return nil
}
func (d *v0MarketStatesDiffer) Remove(key uint64, val *typegen.Deferred) error {
v0ds := new(market.DealState)
err := v0ds.UnmarshalCBOR(bytes.NewReader(val.Raw))
if err != nil {
return err
}
d.Results.Removed = append(d.Results.Removed, DealIDState{abi.DealID(key), fromV0DealState(*v0ds)})
return nil
} }
func fromV0DealState(v0 market.DealState) DealState { func fromV0DealState(v0 market.DealState) DealState {
@ -192,17 +150,7 @@ type v0DealProposals struct {
} }
func (s *v0DealProposals) Diff(other DealProposals) (*DealProposalChanges, error) { func (s *v0DealProposals) Diff(other DealProposals) (*DealProposalChanges, error) {
v0other, ok := other.(*v0DealProposals) return diffDealProposals(s, other)
if !ok {
// TODO handle this if possible on a case by case basis but for now, just fail
return nil, errors.New("cannot compare deal proposals across versions")
}
results := new(DealProposalChanges)
if err := adt.DiffAdtArray(s.Array, v0other.Array, &v0MarketProposalsDiffer{results}); err != nil {
return nil, fmt.Errorf("diffing deal proposals: %w", err)
}
return results, nil
} }
func (s *v0DealProposals) Get(dealID abi.DealID) (*DealProposal, bool, error) { func (s *v0DealProposals) Get(dealID abi.DealID) (*DealProposal, bool, error) {
@ -225,35 +173,19 @@ func (s *v0DealProposals) ForEach(cb func(dealID abi.DealID, dp DealProposal) er
}) })
} }
type v0MarketProposalsDiffer struct { func (s *v0DealProposals) decode(val *cbg.Deferred) (*DealProposal, error) {
Results *DealProposalChanges var v0dp market.DealProposal
if err := v0dp.UnmarshalCBOR(bytes.NewReader(val.Raw)); err != nil {
return nil, err
}
dp := fromV0DealProposal(v0dp)
return &dp, nil
}
func (s *v0DealProposals) array() adt.Array {
return s.Array
} }
func fromV0DealProposal(v0 market.DealProposal) DealProposal { func fromV0DealProposal(v0 market.DealProposal) DealProposal {
return (DealProposal)(v0) return (DealProposal)(v0)
} }
func (d *v0MarketProposalsDiffer) Add(key uint64, val *typegen.Deferred) error {
v0dp := new(market.DealProposal)
err := v0dp.UnmarshalCBOR(bytes.NewReader(val.Raw))
if err != nil {
return err
}
d.Results.Added = append(d.Results.Added, ProposalIDState{abi.DealID(key), fromV0DealProposal(*v0dp)})
return nil
}
func (d *v0MarketProposalsDiffer) Modify(key uint64, from, to *typegen.Deferred) error {
// short circuit, DealProposals are static
return nil
}
func (d *v0MarketProposalsDiffer) Remove(key uint64, val *typegen.Deferred) error {
v0dp := new(market.DealProposal)
err := v0dp.UnmarshalCBOR(bytes.NewReader(val.Raw))
if err != nil {
return err
}
d.Results.Removed = append(d.Results.Removed, ProposalIDState{abi.DealID(key), fromV0DealProposal(*v0dp)})
return nil
}