package sectorstorage

import (
	"context"
	"io"
	"sync"
	"time"

	"github.com/ipfs/go-cid"
	"go.opencensus.io/stats"
	"go.opencensus.io/tag"

	"github.com/filecoin-project/go-state-types/abi"
	"github.com/filecoin-project/specs-storage/storage"

	"github.com/filecoin-project/lotus/extern/sector-storage/sealtasks"
	"github.com/filecoin-project/lotus/extern/sector-storage/storiface"
	"github.com/filecoin-project/lotus/metrics"
)

type trackedWork struct {
	job            storiface.WorkerJob
	worker         WorkerID
	workerHostname string
}

type workTracker struct {
	lk sync.Mutex

	done    map[storiface.CallID]struct{}
	running map[storiface.CallID]trackedWork

	// TODO: done, aggregate stats, queue stats, scheduler feedback
}

func (wt *workTracker) onDone(ctx context.Context, callID storiface.CallID) {
	wt.lk.Lock()
	defer wt.lk.Unlock()

	t, ok := wt.running[callID]
	if !ok {
		wt.done[callID] = struct{}{}

		stats.Record(ctx, metrics.WorkerUntrackedCallsReturned.M(1))
		return
	}

	took := metrics.SinceInMilliseconds(t.job.Start)

	ctx, _ = tag.New(
		ctx,
		tag.Upsert(metrics.TaskType, string(t.job.Task)),
		tag.Upsert(metrics.WorkerHostname, t.workerHostname),
	)
	stats.Record(ctx, metrics.WorkerCallsReturnedCount.M(1), metrics.WorkerCallsReturnedDuration.M(took))

	delete(wt.running, callID)
}

func (wt *workTracker) track(ctx context.Context, wid WorkerID, wi storiface.WorkerInfo, sid storage.SectorRef, task sealtasks.TaskType) func(storiface.CallID, error) (storiface.CallID, error) {
	return func(callID storiface.CallID, err error) (storiface.CallID, error) {
		if err != nil {
			return callID, err
		}

		wt.lk.Lock()
		defer wt.lk.Unlock()

		_, done := wt.done[callID]
		if done {
			delete(wt.done, callID)
			return callID, err
		}

		wt.running[callID] = trackedWork{
			job: storiface.WorkerJob{
				ID:     callID,
				Sector: sid.ID,
				Task:   task,
				Start:  time.Now(),
			},
			worker:         wid,
			workerHostname: wi.Hostname,
		}

		ctx, _ = tag.New(
			ctx,
			tag.Upsert(metrics.TaskType, string(task)),
			tag.Upsert(metrics.WorkerHostname, wi.Hostname),
		)
		stats.Record(ctx, metrics.WorkerCallsStarted.M(1))

		return callID, err
	}
}

func (wt *workTracker) worker(wid WorkerID, wi storiface.WorkerInfo, w Worker) Worker {
	return &trackedWorker{
		Worker:     w,
		wid:        wid,
		workerInfo: wi,

		tracker: wt,
	}
}

func (wt *workTracker) Running() []trackedWork {
	wt.lk.Lock()
	defer wt.lk.Unlock()

	out := make([]trackedWork, 0, len(wt.running))
	for _, job := range wt.running {
		out = append(out, job)
	}

	return out
}

type trackedWorker struct {
	Worker
	wid        WorkerID
	workerInfo storiface.WorkerInfo

	tracker *workTracker
}

func (t *trackedWorker) SealPreCommit1(ctx context.Context, sector storage.SectorRef, ticket abi.SealRandomness, pieces []abi.PieceInfo) (storiface.CallID, error) {
	return t.tracker.track(ctx, t.wid, t.workerInfo, sector, sealtasks.TTPreCommit1)(t.Worker.SealPreCommit1(ctx, sector, ticket, pieces))
}

func (t *trackedWorker) SealPreCommit2(ctx context.Context, sector storage.SectorRef, pc1o storage.PreCommit1Out) (storiface.CallID, error) {
	return t.tracker.track(ctx, t.wid, t.workerInfo, sector, sealtasks.TTPreCommit2)(t.Worker.SealPreCommit2(ctx, sector, pc1o))
}

func (t *trackedWorker) SealCommit1(ctx context.Context, sector storage.SectorRef, ticket abi.SealRandomness, seed abi.InteractiveSealRandomness, pieces []abi.PieceInfo, cids storage.SectorCids) (storiface.CallID, error) {
	return t.tracker.track(ctx, t.wid, t.workerInfo, sector, sealtasks.TTCommit1)(t.Worker.SealCommit1(ctx, sector, ticket, seed, pieces, cids))
}

func (t *trackedWorker) SealCommit2(ctx context.Context, sector storage.SectorRef, c1o storage.Commit1Out) (storiface.CallID, error) {
	return t.tracker.track(ctx, t.wid, t.workerInfo, sector, sealtasks.TTCommit2)(t.Worker.SealCommit2(ctx, sector, c1o))
}

func (t *trackedWorker) FinalizeSector(ctx context.Context, sector storage.SectorRef, keepUnsealed []storage.Range) (storiface.CallID, error) {
	return t.tracker.track(ctx, t.wid, t.workerInfo, sector, sealtasks.TTFinalize)(t.Worker.FinalizeSector(ctx, sector, keepUnsealed))
}

func (t *trackedWorker) AddPiece(ctx context.Context, sector storage.SectorRef, pieceSizes []abi.UnpaddedPieceSize, newPieceSize abi.UnpaddedPieceSize, pieceData storage.Data) (storiface.CallID, error) {
	return t.tracker.track(ctx, t.wid, t.workerInfo, sector, sealtasks.TTAddPiece)(t.Worker.AddPiece(ctx, sector, pieceSizes, newPieceSize, pieceData))
}

func (t *trackedWorker) Fetch(ctx context.Context, s storage.SectorRef, ft storiface.SectorFileType, ptype storiface.PathType, am storiface.AcquireMode) (storiface.CallID, error) {
	return t.tracker.track(ctx, t.wid, t.workerInfo, s, sealtasks.TTFetch)(t.Worker.Fetch(ctx, s, ft, ptype, am))
}

func (t *trackedWorker) UnsealPiece(ctx context.Context, id storage.SectorRef, index storiface.UnpaddedByteIndex, size abi.UnpaddedPieceSize, randomness abi.SealRandomness, cid cid.Cid) (storiface.CallID, error) {
	return t.tracker.track(ctx, t.wid, t.workerInfo, id, sealtasks.TTUnseal)(t.Worker.UnsealPiece(ctx, id, index, size, randomness, cid))
}

func (t *trackedWorker) ReadPiece(ctx context.Context, writer io.Writer, id storage.SectorRef, index storiface.UnpaddedByteIndex, size abi.UnpaddedPieceSize) (storiface.CallID, error) {
	return t.tracker.track(ctx, t.wid, t.workerInfo, id, sealtasks.TTReadUnsealed)(t.Worker.ReadPiece(ctx, writer, id, index, size))
}

var _ Worker = &trackedWorker{}