diff --git a/api/cbor_gen.go b/api/cbor_gen.go index 8d309a6cc..a65f97aaf 100644 --- a/api/cbor_gen.go +++ b/api/cbor_gen.go @@ -23,6 +23,10 @@ func (t *PaymentInfo) MarshalCBOR(w io.Writer) error { } // t.Channel (address.Address) (struct) + if len("Channel") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Channel\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("Channel")))); err != nil { return err } @@ -35,6 +39,10 @@ func (t *PaymentInfo) MarshalCBOR(w io.Writer) error { } // t.ChannelMessage (cid.Cid) (struct) + if len("ChannelMessage") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"ChannelMessage\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("ChannelMessage")))); err != nil { return err } @@ -53,6 +61,10 @@ func (t *PaymentInfo) MarshalCBOR(w io.Writer) error { } // t.Vouchers ([]*types.SignedVoucher) (slice) + if len("Vouchers") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Vouchers\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("Vouchers")))); err != nil { return err } @@ -60,6 +72,10 @@ func (t *PaymentInfo) MarshalCBOR(w io.Writer) error { return err } + if len(t.Vouchers) > cbg.MaxLength { + return xerrors.Errorf("Slice value in field t.Vouchers was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajArray, uint64(len(t.Vouchers)))); err != nil { return err } @@ -200,6 +216,10 @@ func (t *SealedRef) MarshalCBOR(w io.Writer) error { } // t.SectorID (uint64) (uint64) + if len("SectorID") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"SectorID\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("SectorID")))); err != nil { return err } @@ -212,6 +232,10 @@ func (t *SealedRef) MarshalCBOR(w io.Writer) error { } // t.Offset (uint64) (uint64) + if len("Offset") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Offset\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("Offset")))); err != nil { return err } @@ -224,6 +248,10 @@ func (t *SealedRef) MarshalCBOR(w io.Writer) error { } // t.Size (uint64) (uint64) + if len("Size") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Size\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("Size")))); err != nil { return err } @@ -336,6 +364,10 @@ func (t *SealedRefs) MarshalCBOR(w io.Writer) error { } // t.Refs ([]api.SealedRef) (slice) + if len("Refs") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Refs\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("Refs")))); err != nil { return err } @@ -343,6 +375,10 @@ func (t *SealedRefs) MarshalCBOR(w io.Writer) error { return err } + if len(t.Refs) > cbg.MaxLength { + return xerrors.Errorf("Slice value in field t.Refs was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajArray, uint64(len(t.Refs)))); err != nil { return err } diff --git a/chain/deals/client.go b/chain/deals/client.go index afbb317f0..754a2d021 100644 --- a/chain/deals/client.go +++ b/chain/deals/client.go @@ -149,7 +149,7 @@ func (c *Client) onIncoming(deal *ClientDeal) { func (c *Client) onUpdated(ctx context.Context, update clientDealUpdate) { log.Infof("Client deal %s updated state to %s", update.id, api.DealStates[update.newState]) var deal ClientDeal - err := c.deals.Mutate(update.id, func(d *ClientDeal) error { + err := c.deals.Get(update.id).Mutate(func(d *ClientDeal) error { d.State = update.newState if update.mut != nil { update.mut(d) @@ -299,7 +299,7 @@ func (c *Client) List() ([]ClientDeal, error) { func (c *Client) GetDeal(d cid.Cid) (*ClientDeal, error) { var out ClientDeal - if err := c.deals.Get(d, &out); err != nil { + if err := c.deals.Get(d).Get(&out); err != nil { return nil, err } return &out, nil diff --git a/chain/deals/client_utils.go b/chain/deals/client_utils.go index 256738366..d2fbce4c1 100644 --- a/chain/deals/client_utils.go +++ b/chain/deals/client_utils.go @@ -151,7 +151,7 @@ func (c *ClientRequestValidator) ValidatePull( } var deal ClientDeal - err := c.deals.Get(dealVoucher.Proposal, &deal) + err := c.deals.Get(dealVoucher.Proposal).Get(&deal) if err != nil { return xerrors.Errorf("Proposal CID %s: %w", dealVoucher.Proposal.String(), ErrNoDeal) } diff --git a/chain/deals/provider.go b/chain/deals/provider.go index a12f9fa07..b75af8429 100644 --- a/chain/deals/provider.go +++ b/chain/deals/provider.go @@ -184,7 +184,7 @@ func (p *Provider) onUpdated(ctx context.Context, update minerDealUpdate) { return } var deal MinerDeal - err := p.deals.Mutate(update.id, func(d *MinerDeal) error { + err := p.deals.Get(update.id).Mutate(func(d *MinerDeal) error { d.State = update.newState if update.mut != nil { update.mut(d) diff --git a/chain/deals/provider_utils.go b/chain/deals/provider_utils.go index c3e77b760..4607d609a 100644 --- a/chain/deals/provider_utils.go +++ b/chain/deals/provider_utils.go @@ -23,7 +23,7 @@ import ( ) func (p *Provider) failDeal(id cid.Cid, cerr error) { - if err := p.deals.End(id); err != nil { + if err := p.deals.Get(id).End(); err != nil { log.Warnf("deals.End: %s", err) } @@ -167,7 +167,7 @@ func (m *ProviderRequestValidator) ValidatePush( } var deal MinerDeal - err := m.deals.Get(dealVoucher.Proposal, &deal) + err := m.deals.Get(dealVoucher.Proposal).Get(&deal) if err != nil { return xerrors.Errorf("Proposal CID %s: %w", dealVoucher.Proposal.String(), ErrNoDeal) } diff --git a/gen/main.go b/gen/main.go index 52fcf4665..202b22d5f 100644 --- a/gen/main.go +++ b/gen/main.go @@ -11,6 +11,7 @@ import ( "github.com/filecoin-project/lotus/chain/blocksync" "github.com/filecoin-project/lotus/chain/deals" "github.com/filecoin-project/lotus/chain/types" + "github.com/filecoin-project/lotus/lib/evtsm" "github.com/filecoin-project/lotus/paych" "github.com/filecoin-project/lotus/retrieval" "github.com/filecoin-project/lotus/storage" @@ -164,4 +165,13 @@ func main() { fmt.Println(err) os.Exit(1) } + + err = gen.WriteMapEncodersToFile("./lib/evtsm/cbor_gen.go", "evtsm", + evtsm.TestState{}, + evtsm.TestEvent{}, + ) + if err != nil { + fmt.Printf("%+v\n", err) + os.Exit(1) + } } diff --git a/go.mod b/go.mod index c0bf799d6..f7e7281aa 100644 --- a/go.mod +++ b/go.mod @@ -103,6 +103,7 @@ require ( golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 // indirect golang.org/x/sys v0.0.0-20191210023423-ac6580df4449 // indirect golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 + golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 gopkg.in/cheggaaa/pb.v1 v1.0.28 gopkg.in/urfave/cli.v2 v2.0.0-20180128182452-d3ae77c26ac8 @@ -113,3 +114,7 @@ require ( replace github.com/golangci/golangci-lint => github.com/golangci/golangci-lint v1.18.0 replace github.com/filecoin-project/filecoin-ffi => ./extern/filecoin-ffi + +replace github.com/whyrusleeping/cbor-gen => /home/magik6k/gohack/github.com/whyrusleeping/cbor-gen + +replace github.com/filecoin-project/go-cbor-util => /home/magik6k/gohack/github.com/filecoin-project/go-cbor-util diff --git a/go.sum b/go.sum index 3f09e3afb..7ec4fd5d3 100644 --- a/go.sum +++ b/go.sum @@ -85,12 +85,10 @@ github.com/filecoin-project/go-address v0.0.0-20191219011437-af739c490b4f h1:L2j github.com/filecoin-project/go-address v0.0.0-20191219011437-af739c490b4f/go.mod h1:rCbpXPva2NKF9/J4X6sr7hbKBgQCxyFtRj7KOZqoIms= github.com/filecoin-project/go-amt-ipld v0.0.0-20191205011053-79efc22d6cdc h1:cODZD2YzpTUtrOSxbEnWFcQHidNRZiRdvLxySjGvG/M= github.com/filecoin-project/go-amt-ipld v0.0.0-20191205011053-79efc22d6cdc/go.mod h1:KsFPWjF+UUYl6n9A+qbg4bjFgAOneicFZtDH/LQEX2U= -github.com/filecoin-project/go-paramfetch v0.0.0-20200102181131-b20d579f2878 h1:YicJT9xhPzZ1SBGiJFNUCkfwqK/G9vFyY1ytKBSjNJA= -github.com/filecoin-project/go-paramfetch v0.0.0-20200102181131-b20d579f2878/go.mod h1:40kI2Gv16mwcRsHptI3OAV4nlOEU7wVDc4RgMylNFjU= github.com/filecoin-project/go-crypto v0.0.0-20191218222705-effae4ea9f03 h1:2pMXdBnCiXjfCYx/hLqFxccPoqsSveQFxVLvNxy9bus= github.com/filecoin-project/go-crypto v0.0.0-20191218222705-effae4ea9f03/go.mod h1:+viYnvGtUTgJRdy6oaeF4MTFKAfatX071MPDPBL11EQ= -github.com/filecoin-project/go-cbor-util v0.0.0-20191219014500-08c40a1e63a2 h1:av5fw6wmm58FYMgJeoB/lK9XXrgdugYiTqkdxjTy9k8= -github.com/filecoin-project/go-cbor-util v0.0.0-20191219014500-08c40a1e63a2/go.mod h1:pqTiPHobNkOVM5thSRsHYjyQfq7O5QSCMhvuu9JoDlg= +github.com/filecoin-project/go-paramfetch v0.0.0-20200102181131-b20d579f2878 h1:YicJT9xhPzZ1SBGiJFNUCkfwqK/G9vFyY1ytKBSjNJA= +github.com/filecoin-project/go-paramfetch v0.0.0-20200102181131-b20d579f2878/go.mod h1:40kI2Gv16mwcRsHptI3OAV4nlOEU7wVDc4RgMylNFjU= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/gbrlsnchs/jwt/v3 v3.0.0-beta.1 h1:EzDjxMg43q1tA2c0MV3tNbaontnHLplHyFF6M5KiVP0= @@ -629,11 +627,6 @@ github.com/whyrusleeping/base32 v0.0.0-20170828182744-c30ac30633cc h1:BCPnHtcboa github.com/whyrusleeping/base32 v0.0.0-20170828182744-c30ac30633cc/go.mod h1:r45hJU7yEoA81k6MWNhpMj/kms0n14dkzkxYHoB96UM= github.com/whyrusleeping/bencher v0.0.0-20190829221104-bb6607aa8bba h1:X4n8JG2e2biEZZXdBKt9HX7DN3bYGFUqljqqy0DqgnY= github.com/whyrusleeping/bencher v0.0.0-20190829221104-bb6607aa8bba/go.mod h1:CHQnYnQUEPydYCwuy8lmTHfGmdw9TKrhWV0xLx8l0oM= -github.com/whyrusleeping/cbor-gen v0.0.0-20190910031516-c1cbffdb01bb/go.mod h1:xdlJQaiqipF0HW+Mzpg7XRM3fWbGvfgFlcppuvlkIvY= -github.com/whyrusleeping/cbor-gen v0.0.0-20190917003517-d78d67427694/go.mod h1:xdlJQaiqipF0HW+Mzpg7XRM3fWbGvfgFlcppuvlkIvY= -github.com/whyrusleeping/cbor-gen v0.0.0-20191116002219-891f55cd449d/go.mod h1:xdlJQaiqipF0HW+Mzpg7XRM3fWbGvfgFlcppuvlkIvY= -github.com/whyrusleeping/cbor-gen v0.0.0-20191216205031-b047b6acb3c0 h1:efb/4CnrubzNGqQOeHErxyQ6rIsJb7GcgeSDF7fqWeI= -github.com/whyrusleeping/cbor-gen v0.0.0-20191216205031-b047b6acb3c0/go.mod h1:xdlJQaiqipF0HW+Mzpg7XRM3fWbGvfgFlcppuvlkIvY= github.com/whyrusleeping/chunker v0.0.0-20181014151217-fe64bd25879f h1:jQa4QT2UP9WYv2nzyawpKMOCl+Z/jW7djv2/J50lj9E= github.com/whyrusleeping/chunker v0.0.0-20181014151217-fe64bd25879f/go.mod h1:p9UJB6dDgdPgMJZs7UjUOdulKyRr9fqkS+6JKAInPy8= github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 h1:EKhdznlJHPMoKr0XTrX+IlJs1LH3lyx2nfr1dOlZ79k= diff --git a/lib/evtsm/cbor_gen.go b/lib/evtsm/cbor_gen.go new file mode 100644 index 000000000..72a54252d --- /dev/null +++ b/lib/evtsm/cbor_gen.go @@ -0,0 +1,238 @@ +package evtsm + +import ( + "fmt" + "io" + + cbg "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" +) + +// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. + +var _ = xerrors.Errorf + +func (t *TestState) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + if _, err := w.Write([]byte{162}); err != nil { + return err + } + + // t.A (uint64) (uint64) + if len("A") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"A\" was too long") + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("A")))); err != nil { + return err + } + if _, err := w.Write([]byte("A")); err != nil { + return err + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, uint64(t.A))); err != nil { + return err + } + + // t.B (uint64) (uint64) + if len("B") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"B\" was too long") + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("B")))); err != nil { + return err + } + if _, err := w.Write([]byte("B")); err != nil { + return err + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, uint64(t.B))); err != nil { + return err + } + return nil +} + +func (t *TestState) UnmarshalCBOR(r io.Reader) error { + br := cbg.GetPeeker(r) + + maj, extra, err := cbg.CborReadHeader(br) + if err != nil { + return err + } + if maj != cbg.MajMap { + return fmt.Errorf("cbor input should be of type map") + } + + if extra != 2 { + return fmt.Errorf("cbor input had wrong number of fields") + } + + var name string + + // t.A (uint64) (uint64) + + { + sval, err := cbg.ReadString(br) + if err != nil { + return err + } + + name = string(sval) + } + + if name != "A" { + return fmt.Errorf("expected struct map entry %s to be A", name) + } + + maj, extra, err = cbg.CborReadHeader(br) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.A = uint64(extra) + // t.B (uint64) (uint64) + + { + sval, err := cbg.ReadString(br) + if err != nil { + return err + } + + name = string(sval) + } + + if name != "B" { + return fmt.Errorf("expected struct map entry %s to be B", name) + } + + maj, extra, err = cbg.CborReadHeader(br) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.B = uint64(extra) + return nil +} + +func (t *TestEvent) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + if _, err := w.Write([]byte{162}); err != nil { + return err + } + + // t.A (string) (string) + if len("A") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"A\" was too long") + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("A")))); err != nil { + return err + } + if _, err := w.Write([]byte("A")); err != nil { + return err + } + + if len(t.A) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.A was too long") + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len(t.A)))); err != nil { + return err + } + if _, err := w.Write([]byte(t.A)); err != nil { + return err + } + + // t.Val (uint64) (uint64) + if len("Val") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Val\" was too long") + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("Val")))); err != nil { + return err + } + if _, err := w.Write([]byte("Val")); err != nil { + return err + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, uint64(t.Val))); err != nil { + return err + } + return nil +} + +func (t *TestEvent) UnmarshalCBOR(r io.Reader) error { + br := cbg.GetPeeker(r) + + maj, extra, err := cbg.CborReadHeader(br) + if err != nil { + return err + } + if maj != cbg.MajMap { + return fmt.Errorf("cbor input should be of type map") + } + + if extra != 2 { + return fmt.Errorf("cbor input had wrong number of fields") + } + + var name string + + // t.A (string) (string) + + { + sval, err := cbg.ReadString(br) + if err != nil { + return err + } + + name = string(sval) + } + + if name != "A" { + return fmt.Errorf("expected struct map entry %s to be A", name) + } + + { + sval, err := cbg.ReadString(br) + if err != nil { + return err + } + + t.A = string(sval) + } + // t.Val (uint64) (uint64) + + { + sval, err := cbg.ReadString(br) + if err != nil { + return err + } + + name = string(sval) + } + + if name != "Val" { + return fmt.Errorf("expected struct map entry %s to be Val", name) + } + + maj, extra, err = cbg.CborReadHeader(br) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.Val = uint64(extra) + return nil +} diff --git a/lib/evtsm/ctx.go b/lib/evtsm/ctx.go new file mode 100644 index 000000000..a3e1f0521 --- /dev/null +++ b/lib/evtsm/ctx.go @@ -0,0 +1,16 @@ +package evtsm + +import "context" + +type Context struct { + ctx context.Context + send func(evt interface{}) error +} + +func (ctx *Context) Context() context.Context { + return ctx.ctx +} + +func (ctx *Context) Send(evt interface{}) error { + return ctx.send(evt) +} diff --git a/lib/evtsm/event.go b/lib/evtsm/event.go new file mode 100644 index 000000000..bc54b4a01 --- /dev/null +++ b/lib/evtsm/event.go @@ -0,0 +1,5 @@ +package evtsm + +type Event struct { + User interface{} +} diff --git a/lib/evtsm/evtsm.go b/lib/evtsm/evtsm.go new file mode 100644 index 000000000..f0467acde --- /dev/null +++ b/lib/evtsm/evtsm.go @@ -0,0 +1,120 @@ +package evtsm + +import ( + "context" + "reflect" + "sync/atomic" + + "github.com/filecoin-project/lotus/lib/statestore" + logging "github.com/ipfs/go-log" +) + +var log = logging.Logger("evtsm") + +// returns func(ctx Context, st ) (func(*), error), where is the typeOf(User) param +type Planner func(events []Event, user interface{}) (interface{}, error) + +type ESm struct { + planner Planner + eventsIn chan Event + + name interface{} + st *statestore.StoredState + stateType reflect.Type + + stageDone chan struct{} + closing chan struct{} + closed chan struct{} + + busy int32 +} + +func (fsm *ESm) run() { + defer close(fsm.closed) + + var pendingEvents []Event + + for { + // NOTE: This requires at least one event to be sent to trigger a stage + // This means that after restarting the state machine users of this + // code must send a 'restart' event + select { + case evt := <-fsm.eventsIn: + pendingEvents = append(pendingEvents, evt) + case <-fsm.stageDone: + if len(pendingEvents) == 0 { + continue + } + case <-fsm.closing: + return + } + + if atomic.CompareAndSwapInt32(&fsm.busy, 0, 1) { + var nextStep interface{} + var ustate interface{} + + err := fsm.mutateUser(func(user interface{}) (err error) { + nextStep, err = fsm.planner(pendingEvents, user) + ustate = user + return err + }) + if err != nil { + log.Errorf("Executing event planner failed: %+v", err) + return + } + + pendingEvents = nil + + if nextStep == nil { + continue + } + + ctx := Context{ + ctx: context.TODO(), + send: func(evt interface{}) error { + return fsm.send(Event{User: evt}) + }, + } + + go func() { + res := reflect.ValueOf(nextStep).Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(ustate).Elem()}) + + if res[0].Interface() != nil { + log.Errorf("executing step: %+v", res[0].Interface().(error)) // TODO: propagate top level + return + } + + atomic.StoreInt32(&fsm.busy, 0) + fsm.stageDone <- struct{}{} + }() + + } + } +} + +func (fsm *ESm) mutateUser(cb func(user interface{}) error) error { + mutt := reflect.FuncOf([]reflect.Type{reflect.PtrTo(fsm.stateType)}, []reflect.Type{reflect.TypeOf(new(error)).Elem()}, false) + + mutf := reflect.MakeFunc(mutt, func(args []reflect.Value) (results []reflect.Value) { + err := cb(args[0].Interface()) + return []reflect.Value{reflect.ValueOf(&err).Elem()} + }) + + return fsm.st.Mutate(mutf.Interface()) +} + +func (fsm *ESm) send(evt Event) error { + fsm.eventsIn <- evt // TODO: ctx, at least + return nil +} + +func (fsm *ESm) stop(ctx context.Context) error { + close(fsm.closing) + + select { + case <-fsm.closed: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/lib/evtsm/evtsm_test.go b/lib/evtsm/evtsm_test.go new file mode 100644 index 000000000..bcf4aab69 --- /dev/null +++ b/lib/evtsm/evtsm_test.go @@ -0,0 +1,106 @@ +package evtsm + +import ( + "context" + "reflect" + "testing" + + "github.com/ipfs/go-datastore" + logging "github.com/ipfs/go-log" + "gotest.tools/assert" +) + +func init() { + logging.SetLogLevel("*", "INFO") +} + +type testHandler struct { + t *testing.T + proceed chan struct{} + done chan struct{} +} + +func (t *testHandler) Plan(events []Event, state interface{}) (interface{}, error) { + return t.plan(events, state.(*TestState)) +} + +func (t *testHandler) plan(events []Event, state *TestState) (interface{}, error) { + for _, event := range events { + e := event.User.(*TestEvent) + switch e.A { + case "restart": + case "start": + state.A = 1 + case "b": + state.A = 2 + state.B = e.Val + } + } + + switch state.A { + case 1: + return t.step0, nil + case 2: + return t.step1, nil + default: + t.t.Fatal(state.A) + } + panic("how?") +} + +func (t *testHandler) step0(ctx Context, st TestState) error { + ctx.Send(&TestEvent{A: "b", Val: 55}) + <-t.proceed + return nil +} + +func (t *testHandler) step1(ctx Context, st TestState) error { + assert.Equal(t.t, uint64(2), st.A) + + close(t.done) + return nil +} + +func TestBasic(t *testing.T) { + for i := 0; i < 1000; i++ { // run a few times to expose any races + ds := datastore.NewMapDatastore() + + th := &testHandler{t: t, done: make(chan struct{}), proceed: make(chan struct{})} + close(th.proceed) + smm := New(ds, th, reflect.TypeOf(TestState{})) + + if err := smm.Send(uint64(2), &TestEvent{A: "start"}); err != nil { + t.Fatalf("%+v", err) + } + + <-th.done + } +} + +func TestPersist(t *testing.T) { + for i := 0; i < 1000; i++ { // run a few times to expose any races + ds := datastore.NewMapDatastore() + + th := &testHandler{t: t, done: make(chan struct{}), proceed: make(chan struct{})} + smm := New(ds, th, reflect.TypeOf(TestState{})) + + if err := smm.Send(uint64(2), &TestEvent{A: "start"}); err != nil { + t.Fatalf("%+v", err) + } + + if err := smm.Stop(context.Background()); err != nil { + t.Fatal(err) + return + } + + smm = New(ds, th, reflect.TypeOf(TestState{})) + if err := smm.Send(uint64(2), &TestEvent{A: "restart"}); err != nil { + t.Fatalf("%+v", err) + } + close(th.proceed) + + <-th.done + } +} + +var _ StateHandler = &testHandler{} diff --git a/lib/evtsm/sched.go b/lib/evtsm/sched.go new file mode 100644 index 000000000..39dd3027e --- /dev/null +++ b/lib/evtsm/sched.go @@ -0,0 +1,99 @@ +package evtsm + +import ( + "context" + "reflect" + "sync" + + "github.com/ipfs/go-datastore" + "golang.org/x/xerrors" + + "github.com/filecoin-project/lotus/lib/statestore" +) + +type StateHandler interface { + // returns + Plan(events []Event, user interface{}) (interface{}, error) +} + +type Sched struct { + sts *statestore.StateStore + hnd StateHandler + stateType reflect.Type + + lk sync.Mutex + sms map[datastore.Key]*ESm +} + +// stateType: T - (reflect.TypeOf(MyStateStruct{})) +func New(ds datastore.Datastore, hnd StateHandler, stateType reflect.Type) *Sched { + return &Sched{ + sts: statestore.New(ds), + hnd: hnd, + stateType: stateType, + + sms: map[datastore.Key]*ESm{}, + } +} + +func (s *Sched) Send(to interface{}, evt interface{}) (err error) { + s.lk.Lock() + defer s.lk.Unlock() + + sm, exist := s.sms[statestore.ToKey(to)] + if !exist { + sm, err = s.loadOrCreate(to) + if err != nil { + return xerrors.Errorf("loadOrCreate state: %w", err) + } + s.sms[statestore.ToKey(to)] = sm + } + + return sm.send(Event{User: evt}) +} + +func (s *Sched) loadOrCreate(name interface{}) (*ESm, error) { + exists, err := s.sts.Has(name) + if err != nil { + return nil, xerrors.Errorf("failed to check if state for %v exists: %w", name, err) + } + + if !exists { + userState := reflect.New(s.stateType).Interface() + + err = s.sts.Begin(name, userState) + if err != nil { + return nil, xerrors.Errorf("saving initial state: %w", err) + } + } + + res := &ESm{ + planner: s.hnd.Plan, + eventsIn: make(chan Event), + + name: name, + st: s.sts.Get(name), + stateType: s.stateType, + + stageDone: make(chan struct{}), + closing: make(chan struct{}), + closed: make(chan struct{}), + } + + go res.run() + + return res, nil +} + +func (s *Sched) Stop(ctx context.Context) error { + s.lk.Lock() + defer s.lk.Unlock() + + for _, sm := range s.sms { + if err := sm.stop(ctx); err != nil { + return err + } + } + + return nil +} diff --git a/lib/evtsm/testing.go b/lib/evtsm/testing.go new file mode 100644 index 000000000..8bf0e1733 --- /dev/null +++ b/lib/evtsm/testing.go @@ -0,0 +1,11 @@ +package evtsm + +type TestState struct { + A uint64 + B uint64 +} + +type TestEvent struct { + A string + Val uint64 +} diff --git a/lib/statestore/state.go b/lib/statestore/state.go new file mode 100644 index 000000000..8c2604e2e --- /dev/null +++ b/lib/statestore/state.go @@ -0,0 +1,93 @@ +package statestore + +import ( + "bytes" + "reflect" + + cborutil "github.com/filecoin-project/go-cbor-util" + "github.com/ipfs/go-datastore" + cbg "github.com/whyrusleeping/cbor-gen" + "golang.org/x/xerrors" +) + +type StoredState struct { + ds datastore.Datastore + name datastore.Key +} + +func (st *StoredState) End() error { + has, err := st.ds.Has(st.name) + if err != nil { + return err + } + if !has { + return xerrors.Errorf("No state for %s", st.name) + } + if err := st.ds.Delete(st.name); err != nil { + return xerrors.Errorf("removing state from datastore: %w", err) + } + st.name = datastore.Key{} + st.ds = nil + + return nil +} + +func (st *StoredState) Get(out cbg.CBORUnmarshaler) error { + val, err := st.ds.Get(st.name) + if err != nil { + if xerrors.Is(err, datastore.ErrNotFound) { + return xerrors.Errorf("No state for %s: %w", st.name, err) + } + return err + } + + return out.UnmarshalCBOR(bytes.NewReader(val)) +} + +// mutator func(*T) error +func (st *StoredState) Mutate(mutator interface{}) error { + return st.mutate(cborMutator(mutator)) +} + +func (st *StoredState) mutate(mutator func([]byte) ([]byte, error)) error { + has, err := st.ds.Has(st.name) + if err != nil { + return err + } + if !has { + return xerrors.Errorf("No state for %s", st.name) + } + + cur, err := st.ds.Get(st.name) + if err != nil { + return err + } + + mutated, err := mutator(cur) + if err != nil { + return err + } + + return st.ds.Put(st.name, mutated) +} + +func cborMutator(mutator interface{}) func([]byte) ([]byte, error) { + rmut := reflect.ValueOf(mutator) + + return func(in []byte) ([]byte, error) { + state := reflect.New(rmut.Type().In(0).Elem()) + + err := cborutil.ReadCborRPC(bytes.NewReader(in), state.Interface()) + if err != nil { + return nil, err + } + + out := rmut.Call([]reflect.Value{state}) + + if err := out[0].Interface(); err != nil { + return nil, err.(error) + } + + return cborutil.Dump(state.Interface()) + } +} diff --git a/lib/statestore/store.go b/lib/statestore/store.go index 38ce17b39..1761b899f 100644 --- a/lib/statestore/store.go +++ b/lib/statestore/store.go @@ -5,13 +5,11 @@ import ( "fmt" "reflect" + "github.com/filecoin-project/go-cbor-util" "github.com/ipfs/go-datastore" "github.com/ipfs/go-datastore/query" - cbg "github.com/whyrusleeping/cbor-gen" "go.uber.org/multierr" "golang.org/x/xerrors" - - "github.com/filecoin-project/go-cbor-util" ) type StateStore struct { @@ -22,7 +20,7 @@ func New(ds datastore.Datastore) *StateStore { return &StateStore{ds: ds} } -func toKey(k interface{}) datastore.Key { +func ToKey(k interface{}) datastore.Key { switch t := k.(type) { case uint64: return datastore.NewKey(fmt.Sprint(t)) @@ -34,7 +32,7 @@ func toKey(k interface{}) datastore.Key { } func (st *StateStore) Begin(i interface{}, state interface{}) error { - k := toKey(i) + k := ToKey(i) has, err := st.ds.Has(k) if err != nil { return err @@ -51,82 +49,15 @@ func (st *StateStore) Begin(i interface{}, state interface{}) error { return st.ds.Put(k, b) } -func (st *StateStore) End(i interface{}) error { - k := toKey(i) - has, err := st.ds.Has(k) - if err != nil { - return err +func (st *StateStore) Get(i interface{}) *StoredState { + return &StoredState{ + ds: st.ds, + name: ToKey(i), } - if !has { - return xerrors.Errorf("No state for %s", i) - } - return st.ds.Delete(k) -} - -func cborMutator(mutator interface{}) func([]byte) ([]byte, error) { - rmut := reflect.ValueOf(mutator) - - return func(in []byte) ([]byte, error) { - state := reflect.New(rmut.Type().In(0).Elem()) - - err := cborutil.ReadCborRPC(bytes.NewReader(in), state.Interface()) - if err != nil { - return nil, err - } - - out := rmut.Call([]reflect.Value{state}) - - if err := out[0].Interface(); err != nil { - return nil, err.(error) - } - - return cborutil.Dump(state.Interface()) - } -} - -// mutator func(*T) error -func (st *StateStore) Mutate(i interface{}, mutator interface{}) error { - return st.mutate(i, cborMutator(mutator)) -} - -func (st *StateStore) mutate(i interface{}, mutator func([]byte) ([]byte, error)) error { - k := toKey(i) - has, err := st.ds.Has(k) - if err != nil { - return err - } - if !has { - return xerrors.Errorf("No state for %s", i) - } - - cur, err := st.ds.Get(k) - if err != nil { - return err - } - - mutated, err := mutator(cur) - if err != nil { - return err - } - - return st.ds.Put(k, mutated) } func (st *StateStore) Has(i interface{}) (bool, error) { - return st.ds.Has(toKey(i)) -} - -func (st *StateStore) Get(i interface{}, out cbg.CBORUnmarshaler) error { - k := toKey(i) - val, err := st.ds.Get(k) - if err != nil { - if xerrors.Is(err, datastore.ErrNotFound) { - return xerrors.Errorf("No state for %s: %w", i, err) - } - return err - } - - return out.UnmarshalCBOR(bytes.NewReader(val)) + return st.ds.Has(ToKey(i)) } // out: *[]T diff --git a/storage/cbor_gen.go b/storage/cbor_gen.go index 43262ad2b..d2a6b9ee3 100644 --- a/storage/cbor_gen.go +++ b/storage/cbor_gen.go @@ -22,6 +22,10 @@ func (t *SealTicket) MarshalCBOR(w io.Writer) error { } // t.BlockHeight (uint64) (uint64) + if len("BlockHeight") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"BlockHeight\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("BlockHeight")))); err != nil { return err } @@ -34,6 +38,10 @@ func (t *SealTicket) MarshalCBOR(w io.Writer) error { } // t.TicketBytes ([]uint8) (slice) + if len("TicketBytes") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"TicketBytes\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("TicketBytes")))); err != nil { return err } @@ -41,6 +49,10 @@ func (t *SealTicket) MarshalCBOR(w io.Writer) error { return err } + if len(t.TicketBytes) > cbg.ByteArrayMaxLen { + return xerrors.Errorf("Byte array in field t.TicketBytes was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajByteString, uint64(len(t.TicketBytes)))); err != nil { return err } @@ -133,6 +145,10 @@ func (t *SealSeed) MarshalCBOR(w io.Writer) error { } // t.BlockHeight (uint64) (uint64) + if len("BlockHeight") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"BlockHeight\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("BlockHeight")))); err != nil { return err } @@ -145,6 +161,10 @@ func (t *SealSeed) MarshalCBOR(w io.Writer) error { } // t.TicketBytes ([]uint8) (slice) + if len("TicketBytes") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"TicketBytes\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("TicketBytes")))); err != nil { return err } @@ -152,6 +172,10 @@ func (t *SealSeed) MarshalCBOR(w io.Writer) error { return err } + if len(t.TicketBytes) > cbg.ByteArrayMaxLen { + return xerrors.Errorf("Byte array in field t.TicketBytes was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajByteString, uint64(len(t.TicketBytes)))); err != nil { return err } @@ -244,6 +268,10 @@ func (t *Piece) MarshalCBOR(w io.Writer) error { } // t.DealID (uint64) (uint64) + if len("DealID") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"DealID\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("DealID")))); err != nil { return err } @@ -256,6 +284,10 @@ func (t *Piece) MarshalCBOR(w io.Writer) error { } // t.Size (uint64) (uint64) + if len("Size") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Size\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("Size")))); err != nil { return err } @@ -268,6 +300,10 @@ func (t *Piece) MarshalCBOR(w io.Writer) error { } // t.CommP ([]uint8) (slice) + if len("CommP") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"CommP\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("CommP")))); err != nil { return err } @@ -275,6 +311,10 @@ func (t *Piece) MarshalCBOR(w io.Writer) error { return err } + if len(t.CommP) > cbg.ByteArrayMaxLen { + return xerrors.Errorf("Byte array in field t.CommP was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajByteString, uint64(len(t.CommP)))); err != nil { return err } @@ -390,6 +430,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { } // t.State (uint64) (uint64) + if len("State") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"State\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("State")))); err != nil { return err } @@ -402,6 +446,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { } // t.SectorID (uint64) (uint64) + if len("SectorID") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"SectorID\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("SectorID")))); err != nil { return err } @@ -414,6 +462,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { } // t.Nonce (uint64) (uint64) + if len("Nonce") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Nonce\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("Nonce")))); err != nil { return err } @@ -426,6 +478,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { } // t.Pieces ([]storage.Piece) (slice) + if len("Pieces") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Pieces\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("Pieces")))); err != nil { return err } @@ -433,6 +489,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { return err } + if len(t.Pieces) > cbg.MaxLength { + return xerrors.Errorf("Slice value in field t.Pieces was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajArray, uint64(len(t.Pieces)))); err != nil { return err } @@ -443,6 +503,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { } // t.CommD ([]uint8) (slice) + if len("CommD") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"CommD\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("CommD")))); err != nil { return err } @@ -450,6 +514,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { return err } + if len(t.CommD) > cbg.ByteArrayMaxLen { + return xerrors.Errorf("Byte array in field t.CommD was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajByteString, uint64(len(t.CommD)))); err != nil { return err } @@ -458,6 +526,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { } // t.CommR ([]uint8) (slice) + if len("CommR") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"CommR\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("CommR")))); err != nil { return err } @@ -465,6 +537,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { return err } + if len(t.CommR) > cbg.ByteArrayMaxLen { + return xerrors.Errorf("Byte array in field t.CommR was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajByteString, uint64(len(t.CommR)))); err != nil { return err } @@ -473,6 +549,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { } // t.Proof ([]uint8) (slice) + if len("Proof") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Proof\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("Proof")))); err != nil { return err } @@ -480,6 +560,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { return err } + if len(t.Proof) > cbg.ByteArrayMaxLen { + return xerrors.Errorf("Byte array in field t.Proof was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajByteString, uint64(len(t.Proof)))); err != nil { return err } @@ -488,6 +572,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { } // t.Ticket (storage.SealTicket) (struct) + if len("Ticket") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Ticket\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("Ticket")))); err != nil { return err } @@ -500,6 +588,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { } // t.PreCommitMessage (cid.Cid) (struct) + if len("PreCommitMessage") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"PreCommitMessage\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("PreCommitMessage")))); err != nil { return err } @@ -518,6 +610,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { } // t.Seed (storage.SealSeed) (struct) + if len("Seed") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Seed\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("Seed")))); err != nil { return err } @@ -530,6 +626,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { } // t.CommitMessage (cid.Cid) (struct) + if len("CommitMessage") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"CommitMessage\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("CommitMessage")))); err != nil { return err } @@ -548,6 +648,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { } // t.FaultReportMsg (cid.Cid) (struct) + if len("FaultReportMsg") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"FaultReportMsg\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("FaultReportMsg")))); err != nil { return err } @@ -566,6 +670,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { } // t.LastErr (string) (string) + if len("LastErr") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"LastErr\" was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("LastErr")))); err != nil { return err } @@ -573,6 +681,10 @@ func (t *SectorInfo) MarshalCBOR(w io.Writer) error { return err } + if len(t.LastErr) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.LastErr was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len(t.LastErr)))); err != nil { return err } diff --git a/storage/sector_utils.go b/storage/sector_utils.go index d8d5225bb..cc52f073a 100644 --- a/storage/sector_utils.go +++ b/storage/sector_utils.go @@ -52,6 +52,6 @@ func (m *Miner) ListSectors() ([]SectorInfo, error) { func (m *Miner) GetSectorInfo(sid uint64) (SectorInfo, error) { var out SectorInfo - err := m.sectors.Get(sid, &out) + err := m.sectors.Get(sid).Get(&out) return out, err } diff --git a/storage/sectors.go b/storage/sectors.go index d52dc70b6..0a322164d 100644 --- a/storage/sectors.go +++ b/storage/sectors.go @@ -161,7 +161,7 @@ func (m *Miner) onSectorIncoming(sector *SectorInfo) { func (m *Miner) onSectorUpdated(ctx context.Context, update sectorUpdate) { log.Infof("Sector %d updated state to %s", update.id, api.SectorStates[update.newState]) var sector SectorInfo - err := m.sectors.Mutate(update.id, func(s *SectorInfo) error { + err := m.sectors.Get(update.id).Mutate(func(s *SectorInfo) error { if update.nonce < s.Nonce { return xerrors.Errorf("update nonce too low, ignoring (%d < %d)", update.nonce, s.Nonce) }