package test

import (
	"context"
	"fmt"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/filecoin-project/go-state-types/abi"

	"github.com/filecoin-project/lotus/chain/types"
	"github.com/filecoin-project/lotus/node/impl"
)

func TestCCUpgrade(t *testing.T, b APIBuilder, blocktime time.Duration) {
	for _, height := range []abi.ChainEpoch{
		-1,   // before
		162,  // while sealing
		530,  // after upgrade deal
		5000, // after
	} {
		height := height // make linters happy by copying
		t.Run(fmt.Sprintf("upgrade-%d", height), func(t *testing.T) {
			testCCUpgrade(t, b, blocktime, height)
		})
	}
}

func testCCUpgrade(t *testing.T, b APIBuilder, blocktime time.Duration, upgradeHeight abi.ChainEpoch) {
	ctx := context.Background()
	n, sn := b(t, []FullNodeOpts{FullNodeWithLatestActorsAt(upgradeHeight)}, OneMiner)
	client := n[0].FullNode.(*impl.FullNodeAPI)
	miner := sn[0]

	addrinfo, err := client.NetAddrsListen(ctx)
	if err != nil {
		t.Fatal(err)
	}

	if err := miner.NetConnect(ctx, addrinfo); err != nil {
		t.Fatal(err)
	}
	time.Sleep(time.Second)

	mine := int64(1)
	done := make(chan struct{})
	go func() {
		defer close(done)
		for atomic.LoadInt64(&mine) == 1 {
			time.Sleep(blocktime)
			if err := sn[0].MineOne(ctx, MineNext); err != nil {
				t.Error(err)
			}
		}
	}()

	maddr, err := miner.ActorAddress(ctx)
	if err != nil {
		t.Fatal(err)
	}

	CC := abi.SectorNumber(GenesisPreseals + 1)
	Upgraded := CC + 1

	pledgeSectors(t, ctx, miner, 1, 0, nil)

	sl, err := miner.SectorsList(ctx)
	if err != nil {
		t.Fatal(err)
	}
	if len(sl) != 1 {
		t.Fatal("expected 1 sector")
	}

	if sl[0] != CC {
		t.Fatal("bad")
	}

	{
		si, err := client.StateSectorGetInfo(ctx, maddr, CC, types.EmptyTSK)
		require.NoError(t, err)
		require.Less(t, 50000, int(si.Expiration))
	}

	if err := miner.SectorMarkForUpgrade(ctx, sl[0]); err != nil {
		t.Fatal(err)
	}

	MakeDeal(t, ctx, 6, client, miner, false, false, 0)

	// Validate upgrade

	{
		exp, err := client.StateSectorExpiration(ctx, maddr, CC, types.EmptyTSK)
		require.NoError(t, err)
		require.NotNil(t, exp)
		require.Greater(t, 50000, int(exp.OnTime))
	}
	{
		exp, err := client.StateSectorExpiration(ctx, maddr, Upgraded, types.EmptyTSK)
		require.NoError(t, err)
		require.Less(t, 50000, int(exp.OnTime))
	}

	dlInfo, err := client.StateMinerProvingDeadline(ctx, maddr, types.EmptyTSK)
	require.NoError(t, err)

	// Sector should expire.
	for {
		// Wait for the sector to expire.
		status, err := miner.SectorsStatus(ctx, CC, true)
		require.NoError(t, err)
		if status.OnTime == 0 && status.Early == 0 {
			break
		}
		t.Log("waiting for sector to expire")
		// wait one deadline per loop.
		time.Sleep(time.Duration(dlInfo.WPoStChallengeWindow) * blocktime)
	}

	fmt.Println("shutting down mining")
	atomic.AddInt64(&mine, -1)
	<-done
}