// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package enode

import (
	"encoding/binary"
	"runtime"
	"sync/atomic"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/p2p/enr"
)

// TestReadNodes reads a fixed number of nodes from a fresh genIter and
// verifies, via checkNodes, that exactly that many distinct nodes come back.
func TestReadNodes(t *testing.T) {
	const want = 10

	got := ReadNodes(new(genIter), want)
	checkNodes(t, got, want)
}

// This test checks that ReadNodes terminates when reading N nodes from an iterator
// which returns less than N nodes in an endless cycle.
//
// The underlying iterator cycles through three distinct nodes, so the result
// is expected to hold three nodes, and the wrapped iterator's Next is expected
// to have been called exactly N times.
func TestReadNodesCycle(t *testing.T) {
	// limit is both the number of nodes requested from ReadNodes and the
	// expected number of calls to Next. Using a single constant keeps the
	// check and the failure message in agreement.
	const limit = 10

	iter := &callCountIter{
		Iterator: CycleNodes([]*Node{
			testNode(0, 0),
			testNode(1, 0),
			testNode(2, 0),
		}),
	}
	nodes := ReadNodes(iter, limit)
	checkNodes(t, nodes, 3)
	if iter.count != limit {
		t.Fatalf("%d calls to Next, want %d", iter.count, limit)
	}
}

// TestFilterNodes builds nodes whose sequence number equals their index,
// filters them down to those with a sequence number of at least minSeq, and
// verifies that the filtered iterator yields exactly the tail of the input
// slice, in order, and then ends.
func TestFilterNodes(t *testing.T) {
	const (
		total  = 100
		minSeq = 50
	)

	all := make([]*Node, 0, total)
	for i := uint64(0); i < total; i++ {
		all = append(all, testNode(i, i))
	}

	keep := func(n *Node) bool { return n.Seq() >= minSeq }
	it := Filter(IterNodes(all), keep)

	// Every node from index minSeq onwards passes the predicate.
	for _, want := range all[minSeq:] {
		if !it.Next() {
			t.Fatal("Next returned false")
		}
		if got := it.Node(); got != want {
			t.Fatalf("iterator returned wrong node %v\nwant %v", got, want)
		}
	}
	if it.Next() {
		t.Fatal("Next returned true after underlying iterator has ended")
	}
}

// checkNodes reports a test error (without stopping the test) if nodes does
// not have exactly wantLen entries, contains a nil entry, or contains two
// entries with the same ID. Only the first problem found is reported.
func checkNodes(t *testing.T, nodes []*Node, wantLen int) {
	if got := len(nodes); got != wantLen {
		t.Errorf("slice has %d nodes, want %d", got, wantLen)
		return
	}

	seen := make(map[ID]struct{}, len(nodes))
	for i, n := range nodes {
		if n == nil {
			t.Errorf("nil node at index %d", i)
			return
		}
		id := n.ID()
		if _, dup := seen[id]; dup {
			t.Errorf("slice has duplicate node %v", id)
			return
		}
		seen[id] = struct{}{}
	}
}

// This test checks fairness of FairMix in the happy case where all sources return nodes
// within the context's deadline. The fairness check is repeated many times,
// each time on a freshly created mixer.
func TestFairMix(t *testing.T) {
	const rounds = 500

	for round := 0; round < rounds; round++ {
		testMixerFairness(t)
	}
}

// testMixerFairness runs a single fairness round: it mixes three genIter
// sources, reads 500 nodes and fails the test if any source's share of the
// result is more than 30 nodes away from an equal split.
func testMixerFairness(t *testing.T) {
	sourceIndexes := []uint32{1, 2, 3}

	mix := NewFairMix(1 * time.Second)
	defer mix.Close()
	for _, idx := range sourceIndexes {
		mix.AddSource(&genIter{index: idx})
	}

	const want = 500
	nodes := ReadNodes(mix, want)
	checkNodes(t, nodes, want)

	// Verify that the nodes slice contains an approximately equal number of nodes
	// from each source. Note that approxEqual returns true when its first two
	// arguments differ by MORE than the tolerance.
	expected := len(nodes) / len(sourceIndexes)
	dist := idPrefixDistribution(nodes)
	for _, got := range dist {
		if approxEqual(got, expected, 30) {
			t.Fatalf("ID distribution is unfair: %v", dist)
		}
	}
}

// This test checks that FairMix falls back to an alternative source when
// the 'fair' choice doesn't return a node within the timeout.
//
// One source is a genIter with index 1; the other is a cycle over no nodes.
// All 500 nodes read from the mixer must therefore carry ID prefix 1.
func TestFairMixNextFromAll(t *testing.T) {
	mix := NewFairMix(1 * time.Millisecond)
	defer mix.Close()
	mix.AddSource(&genIter{index: 1})
	mix.AddSource(CycleNodes(nil))

	const want = 500
	nodes := ReadNodes(mix, want)
	checkNodes(t, nodes, want)

	dist := idPrefixDistribution(nodes)
	if onlyFirst := len(dist) <= 1 && dist[1] == len(nodes); !onlyFirst {
		t.Fatalf("wrong ID distribution: %v", dist)
	}
}

// This test ensures FairMix works for Next with no sources.
//
// A goroutine calls mix.Next and forwards mix.Node over ch, while the test
// goroutine adds the only source. The two are not explicitly synchronized, so
// Next is intended to (but is not guaranteed to) start before the source is
// added.
func TestFairMixEmpty(t *testing.T) {
	var (
		mix   = NewFairMix(1 * time.Second)
		testN = testNode(1, 1)
		ch    = make(chan *Node)
	)
	defer mix.Close()

	go func() {
		mix.Next()
		ch <- mix.Node()
	}()

	// The single-node cycle is the only source ever added, so the node
	// received on ch is expected to be testN.
	mix.AddSource(CycleNodes([]*Node{testN}))
	if n := <-ch; n != testN {
		t.Errorf("got wrong node: %v", n)
	}
}

// This test checks closing a source while Next runs.
//
// The first source is a blockingIter, whose Next blocks until the iterator is
// closed and then returns false. After it is closed, a second source is added
// and the mixer is expected to return that source's node and to be left with
// exactly one source.
func TestFairMixRemoveSource(t *testing.T) {
	mix := NewFairMix(1 * time.Second)
	source := make(blockingIter)
	mix.AddSource(source)

	// sig is used in both directions: the test goroutine sends nil to start
	// the reader goroutine, and the reader sends back the node it obtained.
	sig := make(chan *Node)
	go func() {
		<-sig
		mix.Next()
		sig <- mix.Node()
	}()

	sig <- nil
	// Yield so the reader goroutine gets a chance to enter mix.Next before
	// the source is closed. This is best-effort, not a guarantee.
	runtime.Gosched()
	source.Close()

	wantNode := testNode(0, 0)
	mix.AddSource(CycleNodes([]*Node{wantNode}))
	n := <-sig

	// NOTE(review): mix.sources is read here without any visible locking;
	// presumably safe because Next has already returned — confirm.
	if len(mix.sources) != 1 {
		t.Fatalf("have %d sources, want one", len(mix.sources))
	}
	if n != wantNode {
		t.Fatalf("mixer returned wrong node")
	}
}

// blockingIter is a test iterator that never yields a node. Its Next method
// blocks on the channel and returns false once a receive completes, which
// happens when Close closes the channel.
type blockingIter chan struct{}

// Next blocks until a receive on the channel completes, then reports that no
// node is available.
func (b blockingIter) Next() bool {
	<-b
	return false
}

// Node always returns nil because blockingIter never produces a node.
func (b blockingIter) Node() *Node {
	return nil
}

// Close closes the channel, releasing any pending Next call. Calling it a
// second time panics, as closing a closed channel does.
func (b blockingIter) Close() {
	close(b)
}

// TestFairMixClose runs the close scenario up to a fixed number of times,
// stopping early as soon as one attempt has marked the test as failed.
func TestFairMixClose(t *testing.T) {
	const attempts = 20

	for i := 0; i < attempts; i++ {
		if t.Failed() {
			break
		}
		testMixerClose(t)
	}
}

// testMixerClose checks that Close unblocks a concurrent call to Next, that
// Next then returns false, and that calling Close a second time does not
// crash.
func testMixerClose(t *testing.T) {
	// NOTE(review): the negative timeout presumably disables the timeout-based
	// fallback to other sources — confirm against NewFairMix.
	mix := NewFairMix(-1)
	mix.AddSource(CycleNodes(nil))
	mix.AddSource(CycleNodes(nil))

	// done is closed when the goroutine calling Next has returned.
	done := make(chan struct{})
	go func() {
		defer close(done)
		if mix.Next() {
			t.Error("Next returned true")
		}
	}()
	// This call is supposed to make it more likely that Next is
	// actually executing by the time we call Close.
	runtime.Gosched()

	mix.Close()
	select {
	case <-done:
	case <-time.After(3 * time.Second):
		t.Fatal("Next didn't unblock on Close")
	}

	mix.Close() // shouldn't crash
}

// idPrefixDistribution counts nodes by the first four bytes of their ID,
// interpreted as a big-endian uint32. The returned map goes from that prefix
// to the number of nodes carrying it.
func idPrefixDistribution(nodes []*Node) map[uint32]int {
	counts := make(map[uint32]int)
	for i := range nodes {
		id := nodes[i].ID()
		prefix := binary.BigEndian.Uint32(id[:4])
		counts[prefix]++
	}
	return counts
}

// approxEqual reports whether x and y differ by MORE than ε.
//
// NOTE: despite its name, a true result means the two values are NOT within
// ε of each other. testMixerFairness relies on this and treats a true result
// as an unfair distribution.
func approxEqual(x, y, ε int) bool {
	lo, hi := x, y
	if lo > hi {
		lo, hi = hi, lo
	}
	return hi-lo > ε
}

// genIter creates fake nodes with numbered IDs based on 'index' and 'gen'.
// The numeric ID passed to testNode has index in its upper 32 bits and gen in
// its lower 32 bits; gen is incremented after every generated node.
//
// Close marks the iterator as finished by storing the all-ones value into
// index, which is why index is accessed atomically.
type genIter struct {
	node       *Node
	index, gen uint32
}

// Next generates the next node and returns true, unless the iterator has been
// closed, in which case it clears the current node and returns false.
func (g *genIter) Next() bool {
	if idx := atomic.LoadUint32(&g.index); idx != ^uint32(0) {
		g.node = testNode(uint64(idx)<<32|uint64(g.gen), 0)
		g.gen++
		return true
	}
	// index holds the all-ones sentinel written by Close.
	g.node = nil
	return false
}

// Node returns the node produced by the most recent call to Next, or nil if
// that call returned false.
func (g *genIter) Node() *Node {
	return g.node
}

// Close stops the iterator: subsequent calls to Next return false.
func (g *genIter) Close() {
	atomic.StoreUint32(&g.index, ^uint32(0))
}

// testNode builds a node for tests. id is written big-endian into the leading
// eight bytes of the node ID, and seq becomes the sequence number of the
// node's record. The record is signed with SignNull.
func testNode(id, seq uint64) *Node {
	var (
		nodeID ID
		record enr.Record
	)
	binary.BigEndian.PutUint64(nodeID[:], id)
	record.SetSeq(seq)
	return SignNull(&record, nodeID)
}

// callCountIter wraps an Iterator and counts calls to Next.
type callCountIter struct {
	Iterator
	count int
}

// Next records the call and then delegates to the wrapped Iterator.
func (c *callCountIter) Next() bool {
	c.count++
	return c.Iterator.Next()
}