package messagepool

import (
	"context"
	"sort"
	"time"

	"golang.org/x/xerrors"

	"github.com/filecoin-project/go-address"
	"github.com/filecoin-project/lotus/build"
	"github.com/filecoin-project/lotus/chain/messagepool/gasguess"
	"github.com/filecoin-project/lotus/chain/types"
	"github.com/ipfs/go-cid"
)

const repubMsgLimit = 30

var RepublishBatchDelay = 100 * time.Millisecond

func (mp *MessagePool) republishPendingMessages() error {
	mp.curTsLk.Lock()
	ts := mp.curTs

	baseFee, err := mp.api.ChainComputeBaseFee(context.TODO(), ts)
	if err != nil {
		mp.curTsLk.Unlock()
		return xerrors.Errorf("computing basefee: %w", err)
	}
	baseFeeLowerBound := getBaseFeeLowerBound(baseFee, baseFeeLowerBoundFactor)

	pending := make(map[address.Address]map[uint64]*types.SignedMessage)
	mp.lk.Lock()
	mp.republished = nil // clear this to avoid races triggering an early republish
	for actor := range mp.localAddrs {
		mset, ok := mp.pending[actor]
		if !ok {
			continue
		}
		if len(mset.msgs) == 0 {
			continue
		}
		// we need to copy this while holding the lock to avoid races with concurrent modification
		pend := make(map[uint64]*types.SignedMessage, len(mset.msgs))
		for nonce, m := range mset.msgs {
			pend[nonce] = m
		}
		pending[actor] = pend
	}
	mp.lk.Unlock()
	mp.curTsLk.Unlock()

	if len(pending) == 0 {
		return nil
	}

	var chains []*msgChain
	for actor, mset := range pending {
		// We use the baseFee lower bound for createChange so that we optimistically include
		// chains that might become profitable in the next 20 blocks.
		// We still check the lowerBound condition for individual messages so that we don't send
		// messages that will be rejected by the mpool spam protector, so this is safe to do.
		next := mp.createMessageChains(actor, mset, baseFeeLowerBound, ts)
		chains = append(chains, next...)
	}

	if len(chains) == 0 {
		return nil
	}

	sort.Slice(chains, func(i, j int) bool {
		return chains[i].Before(chains[j])
	})

	gasLimit := int64(build.BlockGasLimit)
	minGas := int64(gasguess.MinGas)
	var msgs []*types.SignedMessage
loop:
	for i := 0; i < len(chains); {
		chain := chains[i]

		// we can exceed this if we have picked (some) longer chain already
		if len(msgs) > repubMsgLimit {
			break
		}

		// there is not enough gas for any message
		if gasLimit <= minGas {
			break
		}

		// has the chain been invalidated?
		if !chain.valid {
			i++
			continue
		}

		// does it fit in a block?
		if chain.gasLimit <= gasLimit {
			// check the baseFee lower bound -- only republish messages that can be included in the chain
			// within the next 20 blocks.
			for _, m := range chain.msgs {
				if m.Message.GasFeeCap.LessThan(baseFeeLowerBound) {
					chain.Invalidate()
					continue loop
				}
				gasLimit -= m.Message.GasLimit
				msgs = append(msgs, m)
			}

			// we processed the whole chain, advance
			i++
			continue
		}

		// we can't fit the current chain but there is gas to spare
		// trim it and push it down
		chain.Trim(gasLimit, mp, baseFee)
		for j := i; j < len(chains)-1; j++ {
			if chains[j].Before(chains[j+1]) {
				break
			}
			chains[j], chains[j+1] = chains[j+1], chains[j]
		}
	}

	count := 0
	log.Infof("republishing %d messages", len(msgs))
	for _, m := range msgs {
		mb, err := m.Serialize()
		if err != nil {
			return xerrors.Errorf("cannot serialize message: %w", err)
		}

		err = mp.api.PubSubPublish(build.MessagesTopic(mp.netName), mb)
		if err != nil {
			return xerrors.Errorf("cannot publish: %w", err)
		}

		count++

		if count < len(msgs) {
			// this delay is here to encourage the pubsub subsystem to process the messages serially
			// and avoid creating nonce gaps because of concurrent validation.
			time.Sleep(RepublishBatchDelay)
		}
	}

	if len(msgs) > 0 {
		mp.journal.RecordEvent(mp.evtTypes[evtTypeMpoolRepub], func() interface{} {
			msgsEv := make([]MessagePoolEvtMessage, 0, len(msgs))
			for _, m := range msgs {
				msgsEv = append(msgsEv, MessagePoolEvtMessage{Message: m.Message, CID: m.Cid()})
			}
			return MessagePoolEvt{
				Action:   "repub",
				Messages: msgsEv,
			}
		})
	}

	// track most recently republished messages
	republished := make(map[cid.Cid]struct{})
	for _, m := range msgs[:count] {
		republished[m.Cid()] = struct{}{}
	}

	mp.lk.Lock()
	// update the republished set so that we can trigger early republish from head changes
	mp.republished = republished
	mp.lk.Unlock()

	return nil
}