package statestore

import (
	"bytes"
	"fmt"
	"reflect"

	"github.com/ipfs/go-datastore"
	"github.com/ipfs/go-datastore/query"
	cbg "github.com/whyrusleeping/cbor-gen"
	"golang.org/x/xerrors"

	"github.com/filecoin-project/lotus/lib/cborutil"
)

// StateStore tracks per-key state records, persisting each one CBOR-encoded
// in an underlying datastore. Keys may be uint64 values or fmt.Stringers.
type StateStore struct {
	// ds holds the encoded records, one entry per tracked key.
	ds datastore.Datastore
}

// New returns a StateStore whose records are kept in ds.
func New(ds datastore.Datastore) *StateStore {
	store := new(StateStore)
	store.ds = ds
	return store
}

// toKey converts a caller-supplied identifier into a datastore key.
//
// uint64 values are formatted in decimal; any fmt.Stringer contributes the
// result of its String method. Any other type is a programming error and
// causes a panic whose message names the offending type, so the bad call
// site is easy to find.
func toKey(k interface{}) datastore.Key {
	switch t := k.(type) {
	case uint64:
		return datastore.NewKey(fmt.Sprint(t))
	case fmt.Stringer:
		return datastore.NewKey(t.String())
	default:
		panic(fmt.Sprintf("unexpected key type %T", k))
	}
}

// Begin starts tracking state for key i by storing the CBOR encoding of
// state (produced by cborutil.Dump). It fails if a record already exists
// under that key; datastore and encoding errors are returned unchanged.
func (st *StateStore) Begin(i interface{}, state interface{}) error {
	key := toKey(i)

	exists, err := st.ds.Has(key)
	switch {
	case err != nil:
		return err
	case exists:
		return xerrors.Errorf("already tracking state for %v", i)
	}

	encoded, err := cborutil.Dump(state)
	if err != nil {
		return err
	}
	return st.ds.Put(key, encoded)
}

// End stops tracking key i by deleting its stored record. It returns an
// error if no record exists under that key; datastore errors are returned
// unchanged.
func (st *StateStore) End(i interface{}) error {
	k := toKey(i)
	has, err := st.ds.Has(k)
	if err != nil {
		return err
	}
	if !has {
		// %v rather than %s: keys may be plain uint64, which %s misformats
		// as "%!s(uint64=N)".
		return xerrors.Errorf("No state for %v", i)
	}
	return st.ds.Delete(k)
}

func cborMutator(mutator interface{}) func([]byte) ([]byte, error) {
	rmut := reflect.ValueOf(mutator)

	return func(in []byte) ([]byte, error) {
		state := reflect.New(rmut.Type().In(0).Elem())

		err := cborutil.ReadCborRPC(bytes.NewReader(in), state.Interface())
		if err != nil {
			return nil, err
		}

		out := rmut.Call([]reflect.Value{state})

		if err := out[0].Interface(); err != nil {
			return nil, err.(error)
		}

		return cborutil.Dump(state.Interface())
	}
}

// Mutate loads the state stored under key i, applies mutator to it, and
// writes the result back. mutator must be a func(*T) error, where T is the
// type the stored state decodes into. If mutator returns an error, nothing
// is written and that error is returned.
func (st *StateStore) Mutate(i interface{}, mutator interface{}) error {
	rawMutator := cborMutator(mutator)
	return st.mutate(i, rawMutator)
}

// mutate reads the raw record stored under key i, passes it through mutator,
// and stores mutator's output under the same key. If no record exists, an
// error wrapping datastore.ErrNotFound is returned. If mutator fails, nothing
// is written and its error is returned.
func (st *StateStore) mutate(i interface{}, mutator func([]byte) ([]byte, error)) error {
	k := toKey(i)

	// A single Get (instead of Has followed by Get) both saves a lookup and
	// avoids a window where the key could vanish between the two calls.
	cur, err := st.ds.Get(k)
	if err != nil {
		if xerrors.Is(err, datastore.ErrNotFound) {
			// %v rather than %s: keys may be plain uint64.
			return xerrors.Errorf("No state for %v: %w", i, err)
		}
		return err
	}

	mutated, err := mutator(cur)
	if err != nil {
		return err
	}

	return st.ds.Put(k, mutated)
}

// Has reports whether a record is currently stored under key i.
func (st *StateStore) Has(i interface{}) (bool, error) {
	key := toKey(i)
	return st.ds.Has(key)
}

// Get loads the record stored under key i and decodes it into out via
// out.UnmarshalCBOR. If no record exists, the returned error wraps
// datastore.ErrNotFound, so callers can test for it with xerrors.Is.
func (st *StateStore) Get(i interface{}, out cbg.CBORUnmarshaler) error {
	val, err := st.ds.Get(toKey(i))
	if xerrors.Is(err, datastore.ErrNotFound) {
		// %v rather than %s: keys may be plain uint64.
		return xerrors.Errorf("No state for %v: %w", i, err)
	}
	if err != nil {
		return err
	}

	return out.UnmarshalCBOR(bytes.NewReader(val))
}

// List decodes every record in the store and appends them to the slice that
// out points to. out must be a non-nil *[]T, where T is the record type.
// Elements already in *out are kept. Each record is decoded into a fresh *T
// with cborutil.ReadCborRPC.
//
// *out is assigned only after every record has been decoded, so a failed
// call leaves it unchanged.
func (st *StateStore) List(out interface{}) error {
	rout := reflect.ValueOf(out)
	// Kind() is Invalid for a nil interface, so IsNil is only reached for pointers.
	if rout.Kind() != reflect.Ptr || rout.IsNil() || rout.Elem().Kind() != reflect.Slice {
		return xerrors.Errorf("out must be a non-nil pointer to a slice, got %T", out)
	}
	elemType := rout.Elem().Type().Elem()

	results, err := st.ds.Query(query.Query{})
	if err != nil {
		return err
	}
	// Close error is ignored: this is best-effort cleanup of the iterator.
	defer results.Close()

	acc := rout.Elem()
	for {
		entry, ok := results.NextSync()
		if !ok {
			break
		}
		if entry.Error != nil {
			return entry.Error
		}

		elem := reflect.New(elemType)
		if err := cborutil.ReadCborRPC(bytes.NewReader(entry.Value), elem.Interface()); err != nil {
			return xerrors.Errorf("decoding state at key %s: %w", entry.Key, err)
		}
		acc = reflect.Append(acc, elem.Elem())
	}

	rout.Elem().Set(acc)
	return nil
}