468 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			468 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package pss
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"crypto/ecdsa"
 | |
| 	"encoding/binary"
 | |
| 	"fmt"
 | |
| 	"sync"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/ethereum/go-ethereum/common"
 | |
| 	"github.com/ethereum/go-ethereum/common/hexutil"
 | |
| 	"github.com/ethereum/go-ethereum/log"
 | |
| 	"github.com/ethereum/go-ethereum/node"
 | |
| 	"github.com/ethereum/go-ethereum/p2p"
 | |
| 	"github.com/ethereum/go-ethereum/p2p/enode"
 | |
| 	"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
 | |
| 	"github.com/ethereum/go-ethereum/rpc"
 | |
| 	"github.com/ethereum/go-ethereum/swarm/network"
 | |
| 	"github.com/ethereum/go-ethereum/swarm/network/simulation"
 | |
| 	"github.com/ethereum/go-ethereum/swarm/pot"
 | |
| 	"github.com/ethereum/go-ethereum/swarm/state"
 | |
| )
 | |
| 
 | |
// handlerContextFunc builds the pss handler for a single node. It is handed
// the node's config so that the enode id of the receiving node is available
// to the handler for triggers (e.g. for reporting which node handled a message).
type handlerContextFunc func(*testData, *adapters.NodeConfig) *handler
 | |
| 
 | |
// handlerNotification is queued by a node's message handler to notify the
// simulation driver of the reception of a message.
//
// TODO To make code cleaner:
// - consider a separate pss unwrap to message event in sim framework (this will make eventual message propagation analysis with pss easier/possible in the future)
// - consider also test api calls to inspect handling results of messages
type handlerNotification struct {
	id     enode.ID // enode id of the node whose handler received the message
	serial uint64   // message serial decoded from the message payload
}
 | |
| 
 | |
| type testData struct {
 | |
| 	sim                *simulation.Simulation
 | |
| 	kademlias          map[enode.ID]*network.Kademlia
 | |
| 	nodeAddresses      map[enode.ID][]byte // make predictable overlay addresses from the generated random enode ids
 | |
| 	senders            map[int]enode.ID    // originating nodes of the messages (intention is to choose as far as possible from the receiving neighborhood)
 | |
| 	recipientAddresses [][]byte
 | |
| 
 | |
| 	requiredMsgCount int
 | |
| 	requiredMsgs     map[enode.ID][]uint64 // message serials we expect respective nodes to receive
 | |
| 	allowedMsgs      map[enode.ID][]uint64 // message serials we expect respective nodes to receive
 | |
| 
 | |
| 	notifications []handlerNotification // notification queue
 | |
| 	totalMsgCount int
 | |
| 	handlerDone   bool // set to true on termination of the simulation run
 | |
| 	mu            sync.Mutex
 | |
| }
 | |
| 
 | |
| var (
 | |
| 	pof   = pot.DefaultPof(256) // generate messages and index them
 | |
| 	topic = BytesToTopic([]byte{0xf3, 0x9e, 0x06, 0x82})
 | |
| )
 | |
| 
 | |
| func (td *testData) pushNotification(val handlerNotification) {
 | |
| 	td.mu.Lock()
 | |
| 	td.notifications = append(td.notifications, val)
 | |
| 	td.mu.Unlock()
 | |
| }
 | |
| 
 | |
| func (td *testData) popNotification() (first handlerNotification, exist bool) {
 | |
| 	td.mu.Lock()
 | |
| 	if len(td.notifications) > 0 {
 | |
| 		exist = true
 | |
| 		first = td.notifications[0]
 | |
| 		td.notifications = td.notifications[1:]
 | |
| 	}
 | |
| 	td.mu.Unlock()
 | |
| 	return first, exist
 | |
| }
 | |
| 
 | |
| func (td *testData) getMsgCount() int {
 | |
| 	td.mu.Lock()
 | |
| 	defer td.mu.Unlock()
 | |
| 	return td.totalMsgCount
 | |
| }
 | |
| 
 | |
| func (td *testData) incrementMsgCount() int {
 | |
| 	td.mu.Lock()
 | |
| 	defer td.mu.Unlock()
 | |
| 	td.totalMsgCount++
 | |
| 	return td.totalMsgCount
 | |
| }
 | |
| 
 | |
| func (td *testData) isDone() bool {
 | |
| 	td.mu.Lock()
 | |
| 	defer td.mu.Unlock()
 | |
| 	return td.handlerDone
 | |
| }
 | |
| 
 | |
| func (td *testData) setDone() {
 | |
| 	td.mu.Lock()
 | |
| 	defer td.mu.Unlock()
 | |
| 	td.handlerDone = true
 | |
| }
 | |
| 
 | |
| func newTestData() *testData {
 | |
| 	return &testData{
 | |
| 		kademlias:     make(map[enode.ID]*network.Kademlia),
 | |
| 		nodeAddresses: make(map[enode.ID][]byte),
 | |
| 		requiredMsgs:  make(map[enode.ID][]uint64),
 | |
| 		allowedMsgs:   make(map[enode.ID][]uint64),
 | |
| 		senders:       make(map[int]enode.ID),
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (td *testData) getKademlia(nodeId *enode.ID) (*network.Kademlia, error) {
 | |
| 	kadif, ok := td.sim.NodeItem(*nodeId, simulation.BucketKeyKademlia)
 | |
| 	if !ok {
 | |
| 		return nil, fmt.Errorf("no kademlia entry for %v", nodeId)
 | |
| 	}
 | |
| 	kad, ok := kadif.(*network.Kademlia)
 | |
| 	if !ok {
 | |
| 		return nil, fmt.Errorf("invalid kademlia entry for %v", nodeId)
 | |
| 	}
 | |
| 	return kad, nil
 | |
| }
 | |
| 
 | |
| func (td *testData) init(msgCount int) error {
 | |
| 	log.Debug("TestProxNetwork start")
 | |
| 
 | |
| 	for _, nodeId := range td.sim.NodeIDs() {
 | |
| 		kad, err := td.getKademlia(&nodeId)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		td.nodeAddresses[nodeId] = kad.BaseAddr()
 | |
| 	}
 | |
| 
 | |
| 	for i := 0; i < int(msgCount); i++ {
 | |
| 		msgAddr := pot.RandomAddress() // we choose message addresses randomly
 | |
| 		td.recipientAddresses = append(td.recipientAddresses, msgAddr.Bytes())
 | |
| 		smallestPo := 256
 | |
| 		var targets []enode.ID
 | |
| 		var closestPO int
 | |
| 
 | |
| 		// loop through all nodes and find the required and allowed recipients of each message
 | |
| 		// (for more information, please see the comment to the main test function)
 | |
| 		for _, nod := range td.sim.Net.GetNodes() {
 | |
| 			po, _ := pof(td.recipientAddresses[i], td.nodeAddresses[nod.ID()], 0)
 | |
| 			depth := td.kademlias[nod.ID()].NeighbourhoodDepth()
 | |
| 
 | |
| 			// only nodes with closest IDs (wrt the msg address) will be required recipients
 | |
| 			if po > closestPO {
 | |
| 				closestPO = po
 | |
| 				targets = nil
 | |
| 				targets = append(targets, nod.ID())
 | |
| 			} else if po == closestPO {
 | |
| 				targets = append(targets, nod.ID())
 | |
| 			}
 | |
| 
 | |
| 			if po >= depth {
 | |
| 				td.allowedMsgs[nod.ID()] = append(td.allowedMsgs[nod.ID()], uint64(i))
 | |
| 			}
 | |
| 
 | |
| 			// a node with the smallest PO (wrt msg) will be the sender,
 | |
| 			// in order to increase the distance the msg must travel
 | |
| 			if po < smallestPo {
 | |
| 				smallestPo = po
 | |
| 				td.senders[i] = nod.ID()
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		td.requiredMsgCount += len(targets)
 | |
| 		for _, id := range targets {
 | |
| 			td.requiredMsgs[id] = append(td.requiredMsgs[id], uint64(i))
 | |
| 		}
 | |
| 
 | |
| 		log.Debug("nn for msg", "targets", len(targets), "msgidx", i, "msg", common.Bytes2Hex(msgAddr[:8]), "sender", td.senders[i], "senderpo", smallestPo)
 | |
| 	}
 | |
| 	log.Debug("recipientAddresses to receive", "count", td.requiredMsgCount)
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // Here we test specific functionality of the pss, setting the prox property of
 | |
| // the handler. The tests generate a number of messages with random addresses.
 | |
| // Then, for each message it calculates which nodes have the msg address
 | |
| // within its nearest neighborhood depth, and stores those nodes as possible
 | |
| // recipients. Those nodes that are the closest to the message address (nodes
 | |
| // belonging to the deepest PO wrt the msg address) are stored as required
 | |
| // recipients. The difference between allowed and required recipients results
 | |
| // from the fact that the nearest neighbours are not necessarily reciprocal.
 | |
| // Upon sending the messages, the test verifies that the respective message is
 | |
| // passed to the message handlers of these required recipients. The test fails
 | |
| // if a message is handled by recipient which is not listed among the allowed
 | |
| // recipients of this particular message. It also fails after timeout, if not
 | |
| // all the required recipients have received their respective messages.
 | |
| //
 | |
| // For example, if proximity order of certain msg address is 4, and node X
 | |
| // has PO=5 wrt the message address, and nodes Y and Z have PO=6, then:
 | |
| // nodes Y and Z will be considered required recipients of the msg,
 | |
| // whereas nodes X, Y and Z will be allowed recipients.
 | |
| func TestProxNetwork(t *testing.T) {
 | |
| 	t.Run("16_nodes,_16_messages,_16_seconds", func(t *testing.T) {
 | |
| 		testProxNetwork(t, 16, 16, 16*time.Second)
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func TestProxNetworkLong(t *testing.T) {
 | |
| 	if !*longrunning {
 | |
| 		t.Skip("run with --longrunning flag to run extensive network tests")
 | |
| 	}
 | |
| 	t.Run("8_nodes,_100_messages,_30_seconds", func(t *testing.T) {
 | |
| 		testProxNetwork(t, 8, 100, 30*time.Second)
 | |
| 	})
 | |
| 	t.Run("16_nodes,_100_messages,_30_seconds", func(t *testing.T) {
 | |
| 		testProxNetwork(t, 16, 100, 30*time.Second)
 | |
| 	})
 | |
| 	t.Run("32_nodes,_100_messages,_60_seconds", func(t *testing.T) {
 | |
| 		testProxNetwork(t, 32, 100, 1*time.Minute)
 | |
| 	})
 | |
| 	t.Run("64_nodes,_100_messages,_60_seconds", func(t *testing.T) {
 | |
| 		testProxNetwork(t, 64, 100, 1*time.Minute)
 | |
| 	})
 | |
| 	t.Run("128_nodes,_100_messages,_120_seconds", func(t *testing.T) {
 | |
| 		testProxNetwork(t, 128, 100, 2*time.Minute)
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func testProxNetwork(t *testing.T, nodeCount int, msgCount int, timeout time.Duration) {
 | |
| 	td := newTestData()
 | |
| 	handlerContextFuncs := make(map[Topic]handlerContextFunc)
 | |
| 	handlerContextFuncs[topic] = nodeMsgHandler
 | |
| 	services := newProxServices(td, true, handlerContextFuncs, td.kademlias)
 | |
| 	td.sim = simulation.New(services)
 | |
| 	defer td.sim.Close()
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), timeout)
 | |
| 	defer cancel()
 | |
| 	filename := fmt.Sprintf("testdata/snapshot_%d.json", nodeCount)
 | |
| 	err := td.sim.UploadSnapshot(ctx, filename)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	err = td.init(msgCount) // initialize the test data
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	wrapper := func(c context.Context, _ *simulation.Simulation) error {
 | |
| 		return testRoutine(td, c)
 | |
| 	}
 | |
| 	result := td.sim.Run(ctx, wrapper) // call the main test function
 | |
| 	if result.Error != nil {
 | |
| 		timedOut := result.Error == context.DeadlineExceeded
 | |
| 		if !timedOut || td.getMsgCount() < td.requiredMsgCount {
 | |
| 			t.Fatal(result.Error)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (td *testData) sendAllMsgs() error {
 | |
| 	nodes := make(map[int]*rpc.Client)
 | |
| 	for i := range td.recipientAddresses {
 | |
| 		nodeClient, err := td.sim.Net.GetNode(td.senders[i]).Client()
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		nodes[i] = nodeClient
 | |
| 	}
 | |
| 
 | |
| 	for i, msg := range td.recipientAddresses {
 | |
| 		log.Debug("sending msg", "idx", i, "from", td.senders[i])
 | |
| 		nodeClient := nodes[i]
 | |
| 		var uvarByte [8]byte
 | |
| 		binary.PutUvarint(uvarByte[:], uint64(i))
 | |
| 		nodeClient.Call(nil, "pss_sendRaw", hexutil.Encode(msg), hexutil.Encode(topic[:]), hexutil.Encode(uvarByte[:]))
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
// isMoreTimeLeft reports whether ctx is still live, i.e. it has been neither
// cancelled nor has its deadline passed. It never blocks.
func isMoreTimeLeft(ctx context.Context) bool {
	// Err is nil exactly until Done is closed.
	return ctx.Err() == nil
}
 | |
| 
 | |
| // testRoutine is the main test function, called by Simulation.Run()
 | |
| func testRoutine(td *testData, ctx context.Context) error {
 | |
| 
 | |
| 	hasMoreRound := func(err error, hadMessage bool) bool {
 | |
| 		return err == nil && (hadMessage || isMoreTimeLeft(ctx))
 | |
| 	}
 | |
| 
 | |
| 	if err := td.sendAllMsgs(); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	var err error
 | |
| 	received := 0
 | |
| 	hadMessage := false
 | |
| 
 | |
| 	for oneMoreRound := true; oneMoreRound; oneMoreRound = hasMoreRound(err, hadMessage) {
 | |
| 		message, hadMessage := td.popNotification()
 | |
| 
 | |
| 		if !isMoreTimeLeft(ctx) {
 | |
| 			// Stop handlers from sending more messages.
 | |
| 			// Note: only best effort, race is possible.
 | |
| 			td.setDone()
 | |
| 		}
 | |
| 
 | |
| 		if hadMessage {
 | |
| 			if td.isAllowedMessage(message) {
 | |
| 				received++
 | |
| 				log.Debug("msg received", "msgs_received", received, "total_expected", td.requiredMsgCount, "id", message.id, "serial", message.serial)
 | |
| 			} else {
 | |
| 				err = fmt.Errorf("message %d received by wrong recipient %v", message.serial, message.id)
 | |
| 			}
 | |
| 		} else {
 | |
| 			time.Sleep(32 * time.Millisecond)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if td.getMsgCount() < td.requiredMsgCount {
 | |
| 		return ctx.Err()
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (td *testData) isAllowedMessage(n handlerNotification) bool {
 | |
| 	// check if message serial is in expected messages for this recipient
 | |
| 	for _, s := range td.allowedMsgs[n.id] {
 | |
| 		if n.serial == s {
 | |
| 			return true
 | |
| 		}
 | |
| 	}
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| func (td *testData) removeAllowedMessage(id enode.ID, index int) {
 | |
| 	last := len(td.allowedMsgs[id]) - 1
 | |
| 	td.allowedMsgs[id][index] = td.allowedMsgs[id][last]
 | |
| 	td.allowedMsgs[id] = td.allowedMsgs[id][:last]
 | |
| }
 | |
| 
 | |
| func nodeMsgHandler(td *testData, config *adapters.NodeConfig) *handler {
 | |
| 	return &handler{
 | |
| 		f: func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error {
 | |
| 			if td.isDone() {
 | |
| 				return nil // terminate if simulation is over
 | |
| 			}
 | |
| 
 | |
| 			td.incrementMsgCount()
 | |
| 
 | |
| 			// using simple serial in message body, makes it easy to keep track of who's getting what
 | |
| 			serial, c := binary.Uvarint(msg)
 | |
| 			if c <= 0 {
 | |
| 				log.Crit(fmt.Sprintf("corrupt message received by %x (uvarint parse returned %d)", config.ID, c))
 | |
| 			}
 | |
| 
 | |
| 			td.pushNotification(handlerNotification{id: config.ID, serial: serial})
 | |
| 			return nil
 | |
| 		},
 | |
| 		caps: &handlerCaps{
 | |
| 			raw:  true, // we use raw messages for simplicity
 | |
| 			prox: true,
 | |
| 		},
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // an adaptation of the same services setup as in pss_test.go
 | |
| // replaces pss_test.go when those tests are rewritten to the new swarm/network/simulation package
 | |
| func newProxServices(td *testData, allowRaw bool, handlerContextFuncs map[Topic]handlerContextFunc, kademlias map[enode.ID]*network.Kademlia) map[string]simulation.ServiceFunc {
 | |
| 	stateStore := state.NewInmemoryStore()
 | |
| 	kademlia := func(id enode.ID, bzzkey []byte) *network.Kademlia {
 | |
| 		if k, ok := kademlias[id]; ok {
 | |
| 			return k
 | |
| 		}
 | |
| 		params := network.NewKadParams()
 | |
| 		params.MaxBinSize = 3
 | |
| 		params.MinBinSize = 1
 | |
| 		params.MaxRetries = 1000
 | |
| 		params.RetryExponent = 2
 | |
| 		params.RetryInterval = 1000000
 | |
| 		kademlias[id] = network.NewKademlia(bzzkey, params)
 | |
| 		return kademlias[id]
 | |
| 	}
 | |
| 	return map[string]simulation.ServiceFunc{
 | |
| 		"bzz": func(ctx *adapters.ServiceContext, b *sync.Map) (node.Service, func(), error) {
 | |
| 			var err error
 | |
| 			var bzzPrivateKey *ecdsa.PrivateKey
 | |
| 			// normally translation of enode id to swarm address is concealed by the network package
 | |
| 			// however, we need to keep track of it in the test driver as well.
 | |
| 			// if the translation in the network package changes, that can cause these tests to unpredictably fail
 | |
| 			// therefore we keep a local copy of the translation here
 | |
| 			addr := network.NewAddr(ctx.Config.Node())
 | |
| 			bzzPrivateKey, err = simulation.BzzPrivateKeyFromConfig(ctx.Config)
 | |
| 			if err != nil {
 | |
| 				return nil, nil, err
 | |
| 			}
 | |
| 			addr.OAddr = network.PrivateKeyToBzzKey(bzzPrivateKey)
 | |
| 			b.Store(simulation.BucketKeyBzzPrivateKey, bzzPrivateKey)
 | |
| 			hp := network.NewHiveParams()
 | |
| 			hp.Discovery = false
 | |
| 			config := &network.BzzConfig{
 | |
| 				OverlayAddr:  addr.Over(),
 | |
| 				UnderlayAddr: addr.Under(),
 | |
| 				HiveParams:   hp,
 | |
| 			}
 | |
| 			bzzKey := network.PrivateKeyToBzzKey(bzzPrivateKey)
 | |
| 			pskad := kademlia(ctx.Config.ID, bzzKey)
 | |
| 			b.Store(simulation.BucketKeyKademlia, pskad)
 | |
| 			return network.NewBzz(config, kademlia(ctx.Config.ID, addr.OAddr), stateStore, nil, nil), nil, nil
 | |
| 		},
 | |
| 		"pss": func(ctx *adapters.ServiceContext, b *sync.Map) (node.Service, func(), error) {
 | |
| 			// execadapter does not exec init()
 | |
| 			initTest()
 | |
| 
 | |
| 			// create keys in whisper and set up the pss object
 | |
| 			ctxlocal, cancel := context.WithTimeout(context.Background(), time.Second*3)
 | |
| 			defer cancel()
 | |
| 			keys, err := wapi.NewKeyPair(ctxlocal)
 | |
| 			privkey, err := w.GetPrivateKey(keys)
 | |
| 			pssp := NewPssParams().WithPrivateKey(privkey)
 | |
| 			pssp.AllowRaw = allowRaw
 | |
| 			bzzPrivateKey, err := simulation.BzzPrivateKeyFromConfig(ctx.Config)
 | |
| 			if err != nil {
 | |
| 				return nil, nil, err
 | |
| 			}
 | |
| 			bzzKey := network.PrivateKeyToBzzKey(bzzPrivateKey)
 | |
| 			pskad := kademlia(ctx.Config.ID, bzzKey)
 | |
| 			b.Store(simulation.BucketKeyKademlia, pskad)
 | |
| 			ps, err := NewPss(pskad, pssp)
 | |
| 			if err != nil {
 | |
| 				return nil, nil, err
 | |
| 			}
 | |
| 
 | |
| 			// register the handlers we've been passed
 | |
| 			var deregisters []func()
 | |
| 			for tpc, hndlrFunc := range handlerContextFuncs {
 | |
| 				deregisters = append(deregisters, ps.Register(&tpc, hndlrFunc(td, ctx.Config)))
 | |
| 			}
 | |
| 
 | |
| 			// if handshake mode is set, add the controller
 | |
| 			// TODO: This should be hooked to the handshake test file
 | |
| 			if useHandshake {
 | |
| 				SetHandshakeController(ps, NewHandshakeParams())
 | |
| 			}
 | |
| 
 | |
| 			// we expose some api calls for cheating
 | |
| 			ps.addAPI(rpc.API{
 | |
| 				Namespace: "psstest",
 | |
| 				Version:   "0.3",
 | |
| 				Service:   NewAPITest(ps),
 | |
| 				Public:    false,
 | |
| 			})
 | |
| 
 | |
| 			// return Pss and cleanups
 | |
| 			return ps, func() {
 | |
| 				// run the handler deregister functions in reverse order
 | |
| 				for i := len(deregisters); i > 0; i-- {
 | |
| 					deregisters[i-1]()
 | |
| 				}
 | |
| 			}, nil
 | |
| 		},
 | |
| 	}
 | |
| }
 |