package store

import (
	"context"

	"github.com/filecoin-project/lotus/build"
	"github.com/filecoin-project/lotus/chain/types"
	"github.com/filecoin-project/specs-actors/actors/abi"
	"github.com/filecoin-project/specs-actors/actors/abi/big"
	"github.com/ipfs/go-cid"
	"golang.org/x/xerrors"
)

// computeNextBaseFee derives the base fee that follows baseFee, given the
// total gas limit used and the number of blocks it was spread over.
//
//	delta       = PackingEfficiencyDenom * gasLimitUsed / (noOfBlocks * PackingEfficiencyNum) - BlockGasTarget
//	delta       = delta limited to [-BlockGasTarget, +BlockGasTarget]
//	change      = baseFee * delta / BlockGasTarget / BaseFeeMaxChangeDenom
//	nextBaseFee = max(baseFee + change, MinimumBaseFee)
//
// noOfBlocks must be non-zero: it is part of the divisor of an int64 division,
// which panics on a zero divisor.
func computeNextBaseFee(baseFee types.BigInt, gasLimitUsed int64, noOfBlocks int) types.BigInt {
	// Per-block gas usage scaled by the inverse of the packing efficiency,
	// expressed as its distance from the per-block gas target.
	scaledGas := build.PackingEfficiencyDenom * gasLimitUsed
	divisor := int64(noOfBlocks) * build.PackingEfficiencyNum
	delta := scaledGas/divisor - build.BlockGasTarget

	// Limiting delta to one full target in either direction bounds the change
	// to baseFee/BaseFeeMaxChangeDenom (noted as 12.5% by the original author;
	// the value of BaseFeeMaxChangeDenom is not visible here).
	if delta > build.BlockGasTarget {
		delta = build.BlockGasTarget
	}
	if delta < -build.BlockGasTarget {
		delta = -build.BlockGasTarget
	}

	// Multiply first, then divide twice, so the order of the big-integer
	// operations matches the formula above exactly.
	adjustment := big.Div(
		big.Div(big.Mul(baseFee, big.NewInt(delta)), big.NewInt(build.BlockGasTarget)),
		big.NewInt(build.BaseFeeMaxChangeDenom),
	)

	// Never let the base fee drop below the configured minimum.
	floor := big.NewInt(build.MinimumBaseFee)
	if candidate := big.Add(baseFee, adjustment); big.Cmp(candidate, floor) >= 0 {
		return candidate
	}
	return floor
}

// ComputeBaseFee computes the base fee that follows the tipset ts.
//
// It sums the GasLimit of every distinct message (keyed by message CID) found
// in the blocks of ts, then passes that sum, the ParentBaseFee of the first
// block and the number of blocks to computeNextBaseFee.
//
// If the messages of a block cannot be loaded, it returns a zero token amount
// and an error wrapping the cause. ctx is currently not used by the body.
func (cs *ChainStore) ComputeBaseFee(ctx context.Context, ts *types.TipSet) (abi.TokenAmount, error) {
	blocks := ts.Blocks()

	// counted holds the CIDs whose GasLimit is already part of gasLimitSum, so
	// a message carried by several blocks of the tipset is only added once.
	counted := make(map[cid.Cid]struct{})
	var gasLimitSum int64

	// addOnce adds limit to gasLimitSum unless c has been seen before.
	addOnce := func(c cid.Cid, limit int64) {
		if _, dup := counted[c]; dup {
			return
		}
		counted[c] = struct{}{}
		gasLimitSum += limit
	}

	for _, blk := range blocks {
		// MessagesForBlock yields two message lists; entries of the second one
		// wrap the message in a Message field.
		plain, wrapped, err := cs.MessagesForBlock(blk)
		if err != nil {
			return abi.NewTokenAmount(0), xerrors.Errorf("error getting messages for: %s: %w", blk.Cid(), err)
		}
		for _, msg := range plain {
			addOnce(msg.Cid(), msg.GasLimit)
		}
		for _, msg := range wrapped {
			addOnce(msg.Cid(), msg.Message.GasLimit)
		}
	}

	// The fee being adjusted is the parent base fee recorded in the first block.
	return computeNextBaseFee(blocks[0].ParentBaseFee, gasLimitSum, len(blocks)), nil
}