lotus/lib/statestore/store.go

166 lines
3.1 KiB
Go

package statestore
import (
"bytes"
"fmt"
"reflect"
"github.com/ipfs/go-datastore"
"github.com/ipfs/go-datastore/query"
cbg "github.com/whyrusleeping/cbor-gen"
"go.uber.org/multierr"
"golang.org/x/xerrors"
"github.com/filecoin-project/lotus/lib/cborutil"
)
type StateStore struct {
ds datastore.Datastore
}
func New(ds datastore.Datastore) *StateStore {
return &StateStore{ds: ds}
}
func toKey(k interface{}) datastore.Key {
switch t := k.(type) {
case uint64:
return datastore.NewKey(fmt.Sprint(t))
case fmt.Stringer:
return datastore.NewKey(t.String())
default:
panic("unexpected key type")
}
}
func (st *StateStore) Begin(i interface{}, state interface{}) error {
k := toKey(i)
has, err := st.ds.Has(k)
if err != nil {
return err
}
if has {
return xerrors.Errorf("already tracking state for %v", i)
}
b, err := cborutil.Dump(state)
if err != nil {
return err
}
return st.ds.Put(k, b)
}
func (st *StateStore) End(i interface{}) error {
k := toKey(i)
has, err := st.ds.Has(k)
if err != nil {
return err
}
if !has {
return xerrors.Errorf("No state for %s", i)
}
return st.ds.Delete(k)
}
func cborMutator(mutator interface{}) func([]byte) ([]byte, error) {
rmut := reflect.ValueOf(mutator)
return func(in []byte) ([]byte, error) {
state := reflect.New(rmut.Type().In(0).Elem())
err := cborutil.ReadCborRPC(bytes.NewReader(in), state.Interface())
if err != nil {
return nil, err
}
out := rmut.Call([]reflect.Value{state})
if err := out[0].Interface(); err != nil {
return nil, err.(error)
}
return cborutil.Dump(state.Interface())
}
}
// mutator func(*T) error
func (st *StateStore) Mutate(i interface{}, mutator interface{}) error {
return st.mutate(i, cborMutator(mutator))
}
func (st *StateStore) mutate(i interface{}, mutator func([]byte) ([]byte, error)) error {
k := toKey(i)
has, err := st.ds.Has(k)
if err != nil {
return err
}
if !has {
return xerrors.Errorf("No state for %s", i)
}
cur, err := st.ds.Get(k)
if err != nil {
return err
}
mutated, err := mutator(cur)
if err != nil {
return err
}
return st.ds.Put(k, mutated)
}
func (st *StateStore) Has(i interface{}) (bool, error) {
return st.ds.Has(toKey(i))
}
func (st *StateStore) Get(i interface{}, out cbg.CBORUnmarshaler) error {
k := toKey(i)
val, err := st.ds.Get(k)
if err != nil {
if xerrors.Is(err, datastore.ErrNotFound) {
return xerrors.Errorf("No state for %s: %w", i, err)
}
return err
}
return out.UnmarshalCBOR(bytes.NewReader(val))
}
// out: *[]T
func (st *StateStore) List(out interface{}) error {
res, err := st.ds.Query(query.Query{})
if err != nil {
return err
}
defer res.Close()
outT := reflect.TypeOf(out).Elem().Elem()
rout := reflect.ValueOf(out)
var errs error
for {
res, ok := res.NextSync()
if !ok {
break
}
if res.Error != nil {
return res.Error
}
elem := reflect.New(outT)
err := cborutil.ReadCborRPC(bytes.NewReader(res.Value), elem.Interface())
if err != nil {
errs = multierr.Append(errs, xerrors.Errorf("decoding state for key '%s': %w", res.Key, err))
continue
}
rout.Elem().Set(reflect.Append(rout.Elem(), elem.Elem()))
}
return nil
}