package main

import (
	"context"
	"net/http"
	"time"

	manet "github.com/multiformats/go-multiaddr-net"

	"golang.org/x/xerrors"

	"github.com/filecoin-project/lotus/api"
	"github.com/filecoin-project/lotus/api/client"
	"github.com/filecoin-project/lotus/build"
	"github.com/filecoin-project/lotus/chain"
	"github.com/filecoin-project/lotus/chain/store"
	"github.com/filecoin-project/lotus/chain/types"
	"github.com/filecoin-project/lotus/lib/jsonrpc"
	"github.com/filecoin-project/lotus/node/repo"
)

// getAPI opens the filesystem repo at path and derives from it the websocket
// RPC URL of the node together with the HTTP headers to use when connecting.
//
// If the API token cannot be read, a warning is logged and the returned
// headers are nil; this is not treated as an error.
func getAPI(path string) (string, http.Header, error) {
	fsRepo, err := repo.NewFS(path)
	if err != nil {
		return "", nil, err
	}

	endpoint, err := fsRepo.APIEndpoint()
	if err != nil {
		return "", nil, xerrors.Errorf("failed to get api endpoint: %w", err)
	}

	// Only the address part of the dial arguments is needed to build the URL.
	_, hostport, err := manet.DialArgs(endpoint)
	if err != nil {
		return "", nil, err
	}

	// headers is deliberately left nil when no token is available.
	var headers http.Header
	if token, tokErr := fsRepo.APIToken(); tokErr == nil {
		headers = http.Header{}
		headers.Set("Authorization", "Bearer "+string(token))
	} else {
		log.Warnw("Couldn't load CLI token, capabilities may be limited", "error", tokErr)
	}

	return "ws://" + hostport + "/rpc/v0", headers, nil
}

// WaitForSyncComplete blocks until the node behind napi appears caught up, or
// until ctx is done.
//
// It polls in two phases, sleeping 5 seconds before every poll:
//
//  1. SyncState is queried until some active sync worker that has a target
//     reports StageSyncComplete. Every worker with a target is logged on each
//     poll (as an error when its stage is StageSyncErrored).
//  2. ChainHead is queried until the head's minimum timestamp is less than
//     build.BlockDelay*20 seconds behind the local clock.
//
// It returns ctx.Err() if ctx is done while waiting, or the first error
// returned by the node API.
func WaitForSyncComplete(ctx context.Context, napi api.FullNode) error {
	// Phase one: wait for a sync worker to report completion.
	for synced := false; !synced; {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(5 * time.Second):
		}

		syncState, err := napi.SyncState(ctx)
		if err != nil {
			return err
		}

		for idx, worker := range syncState.ActiveSyncs {
			// Workers without a target have nothing to report.
			if worker.Target == nil {
				continue
			}

			if worker.Stage != api.StageSyncErrored {
				log.Infow(
					"Syncing",
					"worker", idx,
					"base", worker.Base.Key(),
					"target", worker.Target.Key(),
					"target_height", worker.Target.Height(),
					"height", worker.Height,
					"stage", chain.SyncStageString(worker.Stage),
				)
			} else {
				log.Errorw(
					"Syncing",
					"worker", idx,
					"base", worker.Base.Key(),
					"target", worker.Target.Key(),
					"target_height", worker.Target.Height(),
					"height", worker.Height,
					"error", worker.Message,
					"stage", chain.SyncStageString(worker.Stage),
				)
			}

			if worker.Stage == api.StageSyncComplete {
				// Stop inspecting workers; the loop condition ends phase one.
				synced = true
				break
			}
		}
	}

	// Phase two: wait for the chain head to be recent enough.
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(5 * time.Second):
		}

		head, err := napi.ChainHead(ctx)
		if err != nil {
			return err
		}

		lag := time.Now().Unix() - int64(head.MinTimestamp())

		log.Infow(
			"Waiting for reasonable head height",
			"height", head.Height(),
			"timestamp_delta", lag,
		)

		// If we get within 20 blocks of the current expected block height we
		// consider sync complete. Block propagation is not always great, but
		// we still want to start recording stats as soon as we can.
		if lag < build.BlockDelay*20 {
			return nil
		}
	}
}

// GetTips subscribes to head-change notifications from the node and returns a
// channel on which tipsets are delivered.
//
// Head changes are handled by type:
//   - store.HCCurrent: loadTipsets walks back from the reported tipset to
//     lastHeight and the resulting tipsets are sent in the order it returns.
//   - store.HCApply: the change is pushed into a head buffer created with
//     headlag; whatever the buffer hands back (if anything) is sent.
//     Presumably this delays delivery by headlag tipsets; see HeadBuffer.
//   - store.HCRevert: the most recent entry is popped from the head buffer.
//
// Every 30 seconds the node is health-checked with api.ID under a 5 second
// timeout.
//
// The returned channel is closed when the background goroutine exits, which
// happens when ctx is done, the notification channel is closed, loading
// tipsets fails, or a health check fails. An error is returned only if the
// initial ChainNotify subscription fails.
func GetTips(ctx context.Context, api api.FullNode, lastHeight uint64, headlag int) (<-chan *types.TipSet, error) {
	chmain := make(chan *types.TipSet)

	hb := NewHeadBuffer(headlag)

	notif, err := api.ChainNotify(ctx)
	if err != nil {
		return nil, err
	}

	go func() {
		defer close(chmain)

		// A stoppable ticker (rather than time.Tick) so it is released when
		// the goroutine exits.
		ping := time.NewTicker(30 * time.Second)
		defer ping.Stop()

		// send delivers ts to the consumer, giving up if ctx is done so the
		// goroutine cannot block forever on an abandoned channel.
		send := func(ts *types.TipSet) bool {
			select {
			case chmain <- ts:
				return true
			case <-ctx.Done():
				return false
			}
		}

		for {
			select {
			case changes, ok := <-notif:
				if !ok {
					// A closed channel would otherwise yield nil forever and
					// spin this loop.
					log.Error("Head notification channel closed")
					return
				}

				for _, change := range changes {
					log.Infow("Head event", "height", change.Val.Height(), "type", change.Type)

					switch change.Type {
					case store.HCCurrent:
						tipsets, err := loadTipsets(ctx, api, change.Val, lastHeight)
						if err != nil {
							log.Info(err)
							return
						}

						for _, tipset := range tipsets {
							if !send(tipset) {
								return
							}
						}
					case store.HCApply:
						if out := hb.Push(change); out != nil {
							if !send(out.Val) {
								return
							}
						}
					case store.HCRevert:
						hb.Pop()
					}
				}
			case <-ping.C:
				log.Info("Running health check")

				cctx, cancel := context.WithTimeout(ctx, 5*time.Second)
				_, err := api.ID(cctx)
				// Release the timeout context on both the success and the
				// failure path.
				cancel()
				if err != nil {
					log.Error("Health check failed")
					return
				}

				log.Info("Node online")
			case <-ctx.Done():
				return
			}
		}
	}()

	return chmain, nil
}

// loadTipsets walks the chain backwards from curr, following parent links via
// api.ChainGetTipSet, and collects every tipset whose height is non-zero and
// strictly greater than lowestHeight.
//
// On success the collected tipsets are returned in the reverse of the order
// they were visited, so the tipset reached last comes first and curr comes
// last. If fetching a parent fails, the tipsets collected so far are returned
// in visit order (starting with curr) together with the error.
func loadTipsets(ctx context.Context, api api.FullNode, curr *types.TipSet, lowestHeight uint64) ([]*types.TipSet, error) {
	walked := []*types.TipSet{}

	for cursor := curr; cursor.Height() != 0 && cursor.Height() > lowestHeight; {
		log.Infow("Walking back", "height", cursor.Height())
		walked = append(walked, cursor)

		parent, err := api.ChainGetTipSet(ctx, cursor.Parents())
		if err != nil {
			return walked, err
		}

		cursor = parent
	}

	// Reverse in place by swapping mirrored positions.
	for lo := 0; lo < len(walked)/2; lo++ {
		hi := len(walked) - 1 - lo
		walked[lo], walked[hi] = walked[hi], walked[lo]
	}

	return walked, nil
}

// GetFullNodeAPI builds a full-node RPC client for the node whose repo is at
// the given path. The endpoint address and request headers come from getAPI;
// any error from getAPI is returned unchanged with nil client and closer.
func GetFullNodeAPI(repo string) (api.FullNode, jsonrpc.ClientCloser, error) {
	endpoint, authHeaders, err := getAPI(repo)
	if err != nil {
		return nil, nil, err
	}

	return client.NewFullNodeRPC(endpoint, authHeaders)
}