117 lines
2.6 KiB
Go
117 lines
2.6 KiB
Go
package statemachine
|
|
|
|
import (
|
|
"context"
|
|
"reflect"
|
|
"sync"
|
|
|
|
"github.com/filecoin-project/go-statestore"
|
|
"github.com/ipfs/go-datastore"
|
|
"golang.org/x/xerrors"
|
|
)
|
|
|
|
type StateHandler interface {
|
|
// returns
|
|
Plan(events []Event, user interface{}) (interface{}, error)
|
|
}
|
|
|
|
// StateGroup manages a group of state machines sharing the same logic
type StateGroup struct {
	// sts is the state store that persists the user state of every machine
	// in the group.
	sts *statestore.StateStore
	// hnd is the handler whose Plan method is given to every machine the
	// group creates.
	hnd StateHandler
	// stateType is the user state type; machines without stored state start
	// from a zero value of this type.
	stateType reflect.Type

	// lk guards sms (Send and Stop both hold it while touching the map).
	lk sync.Mutex
	// sms holds the machines started so far, keyed by statestore.ToKey(id).
	sms map[datastore.Key]*StateMachine
}
|
|
|
|
// stateType: T - (reflect.TypeOf(MyStateStruct{}))
|
|
func New(ds datastore.Datastore, hnd StateHandler, stateType reflect.Type) *StateGroup {
|
|
return &StateGroup{
|
|
sts: statestore.New(ds),
|
|
hnd: hnd,
|
|
stateType: stateType,
|
|
|
|
sms: map[datastore.Key]*StateMachine{},
|
|
}
|
|
}
|
|
|
|
// Send sends an event to machine identified by `id`.
|
|
// `evt` is going to be passed into StateHandler.Planner, in the events[].User param
|
|
//
|
|
// If a state machine with the specified id doesn't exits, it's created, and it's
|
|
// state is set to zero-value of stateType provided in group constructor
|
|
func (s *StateGroup) Send(id interface{}, evt interface{}) (err error) {
|
|
s.lk.Lock()
|
|
defer s.lk.Unlock()
|
|
|
|
sm, exist := s.sms[statestore.ToKey(id)]
|
|
if !exist {
|
|
sm, err = s.loadOrCreate(id)
|
|
if err != nil {
|
|
return xerrors.Errorf("loadOrCreate state: %w", err)
|
|
}
|
|
s.sms[statestore.ToKey(id)] = sm
|
|
}
|
|
|
|
return sm.send(Event{User: evt})
|
|
}
|
|
|
|
func (s *StateGroup) loadOrCreate(name interface{}) (*StateMachine, error) {
|
|
exists, err := s.sts.Has(name)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("failed to check if state for %v exists: %w", name, err)
|
|
}
|
|
|
|
if !exists {
|
|
userState := reflect.New(s.stateType).Interface()
|
|
|
|
err = s.sts.Begin(name, userState)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("saving initial state: %w", err)
|
|
}
|
|
}
|
|
|
|
res := &StateMachine{
|
|
planner: s.hnd.Plan,
|
|
eventsIn: make(chan Event),
|
|
|
|
name: name,
|
|
st: s.sts.Get(name),
|
|
stateType: s.stateType,
|
|
|
|
stageDone: make(chan struct{}),
|
|
closing: make(chan struct{}),
|
|
closed: make(chan struct{}),
|
|
}
|
|
|
|
go res.run()
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// Stop stops all state machines in this group
|
|
func (s *StateGroup) Stop(ctx context.Context) error {
|
|
s.lk.Lock()
|
|
defer s.lk.Unlock()
|
|
|
|
for _, sm := range s.sms {
|
|
if err := sm.stop(ctx); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// List outputs states of all state machines in this group
|
|
// out: *[]StateT
|
|
func (s *StateGroup) List(out interface{}) error {
|
|
return s.sts.List(out)
|
|
}
|
|
|
|
// Get gets state for a single state machine
|
|
func (s *StateGroup) Get(id interface{}) *statestore.StoredState {
|
|
return s.sts.Get(id)
|
|
}
|