package beacon

import (
	"fmt"
	"math/big"
	"testing"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/consensus"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/params"
)

// mockChain is a minimal in-memory chain stand-in for the tests in this file.
// Only Config and GetTd return data; the header lookup methods panic.
type mockChain struct {
	// config is the chain configuration handed out by Config.
	config *params.ChainConfig
	// tds maps a block number to the total difficulty returned by GetTd.
	tds    map[uint64]*big.Int
}

// newMockChain builds a mockChain with a zero-valued chain config and an
// empty total-difficulty table, ready for a test to populate.
func newMockChain() *mockChain {
	chain := new(mockChain)
	chain.config = new(params.ChainConfig)
	chain.tds = make(map[uint64]*big.Int)
	return chain
}

// Config returns the chain configuration held by the mock. The pointer is
// shared, so changes made through it are visible to later callers.
func (m *mockChain) Config() *params.ChainConfig {
	cfg := m.config
	return cfg
}

// CurrentHeader is an unimplemented stub; it always panics.
func (m *mockChain) CurrentHeader() *types.Header {
	panic("not implemented")
}

// GetHeader is an unimplemented stub; it always panics.
func (m *mockChain) GetHeader(hash common.Hash, number uint64) *types.Header { panic("not implemented") }

// GetHeaderByNumber is an unimplemented stub; it always panics.
func (m *mockChain) GetHeaderByNumber(number uint64) *types.Header {
	panic("not implemented")
}

// GetHeaderByHash is an unimplemented stub; it always panics.
func (m *mockChain) GetHeaderByHash(hash common.Hash) *types.Header {
	panic("not implemented")
}

// GetTd looks up the total difficulty registered for the given block number.
// The hash argument is ignored. A fresh copy of the stored value is returned
// so callers can mutate the result without touching the table; nil is returned
// when no entry exists for the number.
func (m *mockChain) GetTd(hash common.Hash, number uint64) *big.Int {
	td, found := m.tds[number]
	if !found {
		return nil
	}
	return new(big.Int).Set(td)
}

// TestVerifyTerminalBlock drives verifyTerminalPoWBlock through a table of
// header sequences. The mock chain knows a single total difficulty: 10 at
// block 0. Each case sets the terminal total difficulty on the shared chain
// config, passes its headers to the verifier, and checks both the returned
// error and the returned index.
//
// Every case runs as its own subtest, so one failing case does not hide the
// results of the cases after it, and failures are reported under the case name.
func TestVerifyTerminalBlock(t *testing.T) {
	chain := newMockChain()
	chain.tds[0] = big.NewInt(10)

	tests := []struct {
		name       string
		preHeaders []*types.Header
		ttd        *big.Int
		err        error
		index      int
	}{
		{
			// 10 (block 0) + 4*10 = 50, first reached on the last header.
			name: "valid ttd",
			preHeaders: []*types.Header{
				{Number: big.NewInt(1), Difficulty: big.NewInt(10)},
				{Number: big.NewInt(2), Difficulty: big.NewInt(10)},
				{Number: big.NewInt(3), Difficulty: big.NewInt(10)},
				{Number: big.NewInt(4), Difficulty: big.NewInt(10)},
			},
			ttd: big.NewInt(50),
		},
		{
			// 10 + 10 + 10 + 10 + 9 = 49, one short of the ttd.
			name: "last block doesn't reach ttd",
			preHeaders: []*types.Header{
				{Number: big.NewInt(1), Difficulty: big.NewInt(10)},
				{Number: big.NewInt(2), Difficulty: big.NewInt(10)},
				{Number: big.NewInt(3), Difficulty: big.NewInt(10)},
				{Number: big.NewInt(4), Difficulty: big.NewInt(9)},
			},
			ttd:   big.NewInt(50),
			err:   consensus.ErrInvalidTerminalBlock,
			index: 3,
		},
		{
			// The third header already brings the sum to 50.
			name: "two blocks reach ttd",
			preHeaders: []*types.Header{
				{Number: big.NewInt(1), Difficulty: big.NewInt(10)},
				{Number: big.NewInt(2), Difficulty: big.NewInt(10)},
				{Number: big.NewInt(3), Difficulty: big.NewInt(20)},
				{Number: big.NewInt(4), Difficulty: big.NewInt(10)},
			},
			ttd:   big.NewInt(50),
			err:   consensus.ErrInvalidTerminalBlock,
			index: 3,
		},
		{
			// NOTE(review): the last two headers both carry Number 4; the
			// final one was presumably meant to be 5 — confirm before changing.
			name: "three blocks reach ttd",
			preHeaders: []*types.Header{
				{Number: big.NewInt(1), Difficulty: big.NewInt(10)},
				{Number: big.NewInt(2), Difficulty: big.NewInt(10)},
				{Number: big.NewInt(3), Difficulty: big.NewInt(20)},
				{Number: big.NewInt(4), Difficulty: big.NewInt(10)},
				{Number: big.NewInt(4), Difficulty: big.NewInt(10)},
			},
			ttd:   big.NewInt(50),
			err:   consensus.ErrInvalidTerminalBlock,
			index: 3,
		},
		{
			// Block 0 already has total difficulty 10, above a ttd of 9.
			name: "parent reached ttd",
			preHeaders: []*types.Header{
				{Number: big.NewInt(1), Difficulty: big.NewInt(10)},
			},
			ttd:   big.NewInt(9),
			err:   consensus.ErrInvalidTerminalBlock,
			index: 0,
		},
		{
			// The mock holds no total difficulty for block 3, which is
			// presumably what the verifier looks up as the parent of block 4.
			name: "unknown parent",
			preHeaders: []*types.Header{
				{Number: big.NewInt(4), Difficulty: big.NewInt(10)},
			},
			ttd:   big.NewInt(9),
			err:   consensus.ErrUnknownAncestor,
			index: 0,
		},
	}

	for i, test := range tests {
		test := test // per-iteration copy for the closure on pre-1.22 toolchains
		t.Run(fmt.Sprintf("%d-%s", i, test.name), func(t *testing.T) {
			// Subtests run sequentially, so mutating the shared config is safe.
			chain.config.TerminalTotalDifficulty = test.ttd
			index, err := verifyTerminalPoWBlock(chain, test.preHeaders)
			if err != test.err {
				t.Errorf("Invalid error encountered, expected %v got %v", test.err, err)
			}
			if index != test.index {
				t.Errorf("Invalid index, expected %v got %v", test.index, index)
			}
		})
	}
}