package sealer

import (
	"fmt"
	"io"

	cbg "github.com/whyrusleeping/cbor-gen"
	"golang.org/x/xerrors"

	"github.com/filecoin-project/go-statestore"

	"github.com/filecoin-project/lotus/storage/sealer/storiface"
)

type workerCallTracker struct {
	st *statestore.StateStore // by CallID
}

type CallState uint64

const (
	CallStarted CallState = iota
	CallDone
	// returned -> remove
)

type Call struct {
	ID      storiface.CallID
	RetType ReturnType

	State CallState

	Result *ManyBytes // json bytes
}

func (wt *workerCallTracker) onStart(ci storiface.CallID, rt ReturnType) error {
	return wt.st.Begin(ci, &Call{
		ID:      ci,
		RetType: rt,
		State:   CallStarted,
	})
}

func (wt *workerCallTracker) onDone(ci storiface.CallID, ret []byte) error {
	st := wt.st.Get(ci)
	return st.Mutate(func(cs *Call) error {
		cs.State = CallDone
		cs.Result = &ManyBytes{ret}
		return nil
	})
}

func (wt *workerCallTracker) onReturned(ci storiface.CallID) error {
	st := wt.st.Get(ci)
	return st.End()
}

func (wt *workerCallTracker) unfinished() ([]Call, error) {
	var out []Call
	return out, wt.st.List(&out)
}

// Ideally this would be a tag on the struct field telling cbor-gen to enforce higher max-len
type ManyBytes struct {
	b []byte
}

const many = 100 << 20

func (t *ManyBytes) MarshalCBOR(w io.Writer) error {
	if t == nil {
		t = &ManyBytes{}
	}

	if len(t.b) > many {
		return xerrors.Errorf("byte array in field t.Result was too long")
	}

	scratch := make([]byte, 9)

	if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajByteString, uint64(len(t.b))); err != nil {
		return err
	}

	if _, err := w.Write(t.b[:]); err != nil {
		return err
	}
	return nil
}

func (t *ManyBytes) UnmarshalCBOR(r io.Reader) error {
	*t = ManyBytes{}

	br := cbg.GetPeeker(r)
	scratch := make([]byte, 9)

	maj, extra, err := cbg.CborReadHeaderBuf(br, scratch)
	if err != nil {
		return err
	}

	if extra > many {
		return fmt.Errorf("byte array too large (%d)", extra)
	}
	if maj != cbg.MajByteString {
		return fmt.Errorf("expected byte array")
	}

	if extra > 0 {
		t.b = make([]uint8, extra)
	}

	if _, err := io.ReadFull(br, t.b[:]); err != nil {
		return err
	}

	return nil
}