package splitstore

import (
	"io/ioutil"
	"testing"

	cid "github.com/ipfs/go-cid"
	"github.com/multiformats/go-multihash"

	"github.com/filecoin-project/go-state-types/abi"
)

func TestBoltTrackingStore(t *testing.T) {
	testTrackingStore(t, "bolt")
}

func testTrackingStore(t *testing.T, tsType string) {
	t.Helper()

	makeCid := func(key string) cid.Cid {
		h, err := multihash.Sum([]byte(key), multihash.SHA2_256, -1)
		if err != nil {
			t.Fatal(err)
		}

		return cid.NewCidV1(cid.Raw, h)
	}

	mustHave := func(s TrackingStore, cid cid.Cid, epoch abi.ChainEpoch) {
		val, err := s.Get(cid)
		if err != nil {
			t.Fatal(err)
		}

		if val != epoch {
			t.Fatal("epoch mismatch")
		}
	}

	mustNotHave := func(s TrackingStore, cid cid.Cid) {
		_, err := s.Get(cid)
		if err == nil {
			t.Fatal("expected error")
		}
	}

	path, err := ioutil.TempDir("", "snoop-test.*")
	if err != nil {
		t.Fatal(err)
	}

	s, err := OpenTrackingStore(path, tsType)
	if err != nil {
		t.Fatal(err)
	}

	k1 := makeCid("a")
	k2 := makeCid("b")
	k3 := makeCid("c")
	k4 := makeCid("d")

	s.Put(k1, 1) //nolint
	s.Put(k2, 2) //nolint
	s.Put(k3, 3) //nolint
	s.Put(k4, 4) //nolint

	mustHave(s, k1, 1)
	mustHave(s, k2, 2)
	mustHave(s, k3, 3)
	mustHave(s, k4, 4)

	s.Delete(k1) // nolint
	s.Delete(k2) // nolint

	mustNotHave(s, k1)
	mustNotHave(s, k2)
	mustHave(s, k3, 3)
	mustHave(s, k4, 4)

	s.PutBatch([]cid.Cid{k1}, 1) //nolint
	s.PutBatch([]cid.Cid{k2}, 2) //nolint

	mustHave(s, k1, 1)
	mustHave(s, k2, 2)
	mustHave(s, k3, 3)
	mustHave(s, k4, 4)

	allKeys := map[string]struct{}{
		k1.String(): {},
		k2.String(): {},
		k3.String(): {},
		k4.String(): {},
	}

	err = s.ForEach(func(k cid.Cid, _ abi.ChainEpoch) error {
		_, ok := allKeys[k.String()]
		if !ok {
			t.Fatal("unexpected key")
		}

		delete(allKeys, k.String())
		return nil
	})

	if err != nil {
		t.Fatal(err)
	}

	if len(allKeys) != 0 {
		t.Fatal("not all keys were returned")
	}

	// no close and reopen and ensure the keys still exist
	err = s.Close()
	if err != nil {
		t.Fatal(err)
	}

	s, err = OpenTrackingStore(path, tsType)
	if err != nil {
		t.Fatal(err)
	}

	mustHave(s, k1, 1)
	mustHave(s, k2, 2)
	mustHave(s, k3, 3)
	mustHave(s, k4, 4)

	s.Close() //nolint:errcheck
}