diff --git a/swarm/network/discovery.go b/swarm/network/discovery.go index 4c503047a..54ecf257c 100644 --- a/swarm/network/discovery.go +++ b/swarm/network/discovery.go @@ -26,6 +26,8 @@ import ( // discovery bzz extension for requesting and relaying node address records +var sortPeers = noSortPeers + // Peer wraps BzzPeer and embeds Kademlia overlay connectivity driver type Peer struct { *BzzPeer @@ -156,28 +158,39 @@ func (msg subPeersMsg) String() string { return fmt.Sprintf("%T: request peers > PO%02d. ", msg, msg.Depth) } +// handleSubPeersMsg handles incoming subPeersMsg +// this message represents the saturation depth of the remote peer +// saturation depth is the radius within which the peer subscribes to peers +// the first time this is received we send peer info on all +// our connected peers that fall within peers saturation depth +// otherwise this depth is just recorded on the peer, so that +// subsequent new connections are sent iff they fall within the radius func (d *Peer) handleSubPeersMsg(msg *subPeersMsg) error { + d.setDepth(msg.Depth) + // only send peers after the initial subPeersMsg if !d.sentPeers { - d.setDepth(msg.Depth) var peers []*BzzAddr + // iterate connection in ascending order of disctance from the remote address d.kad.EachConn(d.Over(), 255, func(p *Peer, po int) bool { - if pob, _ := Pof(d, d.kad.BaseAddr(), 0); pob > po { + // terminate if we are beyond the radius + if uint8(po) < msg.Depth { return false } - if !d.seen(p.BzzAddr) { + if !d.seen(p.BzzAddr) { // here just records the peer sent peers = append(peers, p.BzzAddr) } return true }) + // if useful peers are found, send them over if len(peers) > 0 { - go d.Send(context.TODO(), &peersMsg{Peers: peers}) + go d.Send(context.TODO(), &peersMsg{Peers: sortPeers(peers)}) } } d.sentPeers = true return nil } -// seen takes an peer address and checks if it was sent to a peer already +// seen takes a peer address and checks if it was sent to a peer already // if not, marks the peer as sent func (d *Peer) seen(p *BzzAddr) bool { d.mtx.Lock() @@ -201,3 +214,7 @@ func (d *Peer) setDepth(depth uint8) { defer d.mtx.Unlock() d.depth = depth } + +func noSortPeers(peers []*BzzAddr) []*BzzAddr { + return peers +} diff --git a/swarm/network/discovery_test.go b/swarm/network/discovery_test.go index ea0d776e6..04e1b36fe 100644 --- a/swarm/network/discovery_test.go +++ b/swarm/network/discovery_test.go @@ -17,9 +17,22 @@ package network import ( + "crypto/ecdsa" + crand "crypto/rand" + "encoding/binary" + "fmt" + "math/rand" + "net" + "sort" "testing" + "time" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/protocols" p2ptest "github.com/ethereum/go-ethereum/p2p/testing" + "github.com/ethereum/go-ethereum/swarm/pot" ) /*** @@ -27,9 +40,9 @@ import ( * - after connect, that outgoing subpeersmsg is sent * */ -func TestDiscovery(t *testing.T) { +func TestSubPeersMsg(t *testing.T) { params := NewHiveParams() - s, pp, err := newHiveTester(t, params, 1, nil) + s, pp, err := newHiveTester(params, 1, nil) if err != nil { t.Fatal(err) } @@ -58,3 +71,192 @@ func TestDiscovery(t *testing.T) { t.Fatal(err) } } + +const ( + maxPO = 8 // PO of pivot and control; chosen to test enough cases but not run too long + maxPeerPO = 6 // pivot has no peers closer than this to the control peer + maxPeersPerPO = 3 +) + +// TestInitialPeersMsg tests if peersMsg response to incoming subPeersMsg is correct +func TestInitialPeersMsg(t *testing.T) { + for po := 0; po < maxPO; po++ { + for depth := 0; depth < maxPO; depth++ { + t.Run(fmt.Sprintf("PO=%d,advertised depth=%d", po, depth), func(t *testing.T) { + testInitialPeersMsg(t, po, depth) + }) + } + } +} + +// testInitialPeersMsg tests that the correct set of peer info is sent +// to another peer after receiving their subPeersMsg request +func testInitialPeersMsg(t *testing.T, peerPO, peerDepth int) { + // generate random pivot address + prvkey, err := crypto.GenerateKey() + if err != nil { + t.Fatal(err) + } + + defer func(orig func([]*BzzAddr) []*BzzAddr) { + sortPeers = orig + }(sortPeers) + sortPeers = testSortPeers + pivotAddr := pot.NewAddressFromBytes(PrivateKeyToBzzKey(prvkey)) + // generate control peers address at peerPO wrt pivot + peerAddr := pot.RandomAddressAt(pivotAddr, peerPO) + // construct kademlia and hive + to := NewKademlia(pivotAddr[:], NewKadParams()) + hive := NewHive(NewHiveParams(), to, nil) + + // expected addrs in peersMsg response + var expBzzAddrs []*BzzAddr + connect := func(a pot.Address, po int) (addrs []*BzzAddr) { + n := rand.Intn(maxPeersPerPO) + for i := 0; i < n; i++ { + peer, err := newDiscPeer(pot.RandomAddressAt(a, po)) + if err != nil { + t.Fatal(err) + } + hive.On(peer) + addrs = append(addrs, peer.BzzAddr) + } + return addrs + } + register := func(a pot.Address, po int) { + addr := pot.RandomAddressAt(a, po) + hive.Register(&BzzAddr{OAddr: addr[:]}) + } + + // generate connected and just registered peers + for po := maxPeerPO; po >= 0; po-- { + // create a fake connected peer at po from peerAddr + ons := connect(peerAddr, po) + // create a fake registered address at po from peerAddr + register(peerAddr, po) + // we collect expected peer addresses only up till peerPO + if po < peerDepth { + continue + } + expBzzAddrs = append(expBzzAddrs, ons...) + } + + // add extra connections closer to pivot than control + for po := peerPO + 1; po < maxPO; po++ { + ons := connect(pivotAddr, po) + if peerDepth <= peerPO { + expBzzAddrs = append(expBzzAddrs, ons...) + } + } + + // create a special bzzBaseTester in which we can associate `enode.ID` to the `bzzAddr` we created above + s, _, err := newBzzBaseTesterWithAddrs(prvkey, [][]byte{peerAddr[:]}, DiscoverySpec, hive.Run) + if err != nil { + t.Fatal(err) + } + + // peerID to use in the protocol tester testExchange expect/trigger + peerID := s.Nodes[0].ID() + // block until control peer is found among hive peers + found := false + for attempts := 0; attempts < 20; attempts++ { + if _, found = hive.peers[peerID]; found { + break + } + time.Sleep(1 * time.Millisecond) + } + + if !found { + t.Fatal("timeout waiting for peer connection to start") + } + + // pivotDepth is the advertised depth of the pivot node we expect in the outgoing subPeersMsg + pivotDepth := hive.saturation() + // the test exchange is as follows: + // 1. pivot sends to the control peer a `subPeersMsg` advertising its depth (ignored) + // 2. peer sends to pivot a `subPeersMsg` advertising its own depth (arbitrarily chosen) + // 3. pivot responds with `peersMsg` with the set of expected peers + err = s.TestExchanges( + p2ptest.Exchange{ + Label: "outgoing subPeersMsg", + Expects: []p2ptest.Expect{ + { + Code: 1, + Msg: &subPeersMsg{Depth: uint8(pivotDepth)}, + Peer: peerID, + }, + }, + }, + p2ptest.Exchange{ + Label: "trigger subPeersMsg and expect peersMsg", + Triggers: []p2ptest.Trigger{ + { + Code: 1, + Msg: &subPeersMsg{Depth: uint8(peerDepth)}, + Peer: peerID, + }, + }, + Expects: []p2ptest.Expect{ + { + Code: 0, + Msg: &peersMsg{Peers: testSortPeers(expBzzAddrs)}, + Peer: peerID, + Timeout: 100 * time.Millisecond, + }, + }, + }) + + // for values MaxPeerPO < peerPO < MaxPO the pivot has no peers to offer to the control peer + // in this case, no peersMsg will be sent out, and we would run into a time out + if len(expBzzAddrs) == 0 { + if err != nil { + if err.Error() != "exchange #1 \"trigger subPeersMsg and expect peersMsg\": timed out" { + t.Fatalf("expected timeout, got %v", err) + } + return + } + t.Fatalf("expected timeout, got no error") + } + + if err != nil { + t.Fatal(err) + } +} + +func testSortPeers(peers []*BzzAddr) []*BzzAddr { + comp := func(i, j int) bool { + vi := binary.BigEndian.Uint64(peers[i].OAddr) + vj := binary.BigEndian.Uint64(peers[j].OAddr) + return vi < vj + } + sort.Slice(peers, comp) + return peers +} + +// as we are not creating a real node via the protocol, +// we need to create the discovery peer objects for the additional kademlia +// nodes manually +func newDiscPeer(addr pot.Address) (*Peer, error) { + pKey, err := ecdsa.GenerateKey(crypto.S256(), crand.Reader) + if err != nil { + return nil, err + } + pubKey := pKey.PublicKey + nod := enode.NewV4(&pubKey, net.IPv4(127, 0, 0, 1), 0, 0) + bzzAddr := &BzzAddr{OAddr: addr[:], UAddr: []byte(nod.String())} + id := nod.ID() + p2pPeer := p2p.NewPeer(id, id.String(), nil) + return NewPeer(&BzzPeer{ + Peer: protocols.NewPeer(p2pPeer, &dummyMsgRW{}, DiscoverySpec), + BzzAddr: bzzAddr, + }, nil), nil +} + +type dummyMsgRW struct{} + +func (d *dummyMsgRW) ReadMsg() (p2p.Msg, error) { + return p2p.Msg{}, nil +} +func (d *dummyMsgRW) WriteMsg(msg p2p.Msg) error { + return nil +} diff --git a/swarm/network/hive_test.go b/swarm/network/hive_test.go index ddae95a45..d03db42bc 100644 --- a/swarm/network/hive_test.go +++ b/swarm/network/hive_test.go @@ -23,11 +23,12 @@ import ( "time" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p" p2ptest "github.com/ethereum/go-ethereum/p2p/testing" "github.com/ethereum/go-ethereum/swarm/state" ) -func newHiveTester(t *testing.T, params *HiveParams, n int, store state.Store) (*bzzTester, *Hive, error) { +func newHiveTester(params *HiveParams, n int, store state.Store) (*bzzTester, *Hive, error) { // setup prvkey, err := crypto.GenerateKey() if err != nil { @@ -37,7 +38,7 @@ func newHiveTester(t *testing.T, params *HiveParams, n int, store state.Store) ( to := NewKademlia(addr, NewKadParams()) pp := NewHive(params, to, store) // hive - bt, err := newBzzBaseTester(t, n, prvkey, DiscoverySpec, pp.Run) + bt, err := newBzzBaseTester(n, prvkey, DiscoverySpec, pp.Run) if err != nil { return nil, nil, err } @@ -48,7 +49,7 @@ func newHiveTester(t *testing.T, params *HiveParams, n int, store state.Store) ( // and that the peer connection exists afterwards func TestRegisterAndConnect(t *testing.T) { params := NewHiveParams() - s, pp, err := newHiveTester(t, params, 1, nil) + s, pp, err := newHiveTester(params, 1, nil) if err != nil { t.Fatal(err) } @@ -108,65 +109,66 @@ func TestRegisterAndConnect(t *testing.T) { // Actual connectivity is not in scope for this test, as the peers loaded from state are not known to // the simulation; the test only verifies that the peers are known to the node func TestHiveStatePersistance(t *testing.T) { - dir, err := ioutil.TempDir("", "hive_test_store") if err != nil { - panic(err) + t.Fatal(err) } defer os.RemoveAll(dir) - store, err := state.NewDBStore(dir) //start the hive with an empty dbstore - if err != nil { - t.Fatal(err) + const peersCount = 5 + + startHive := func(t *testing.T, dir string) (h *Hive) { + store, err := state.NewDBStore(dir) + if err != nil { + t.Fatal(err) + } + + params := NewHiveParams() + params.Discovery = false + + prvkey, err := crypto.GenerateKey() + if err != nil { + t.Fatal(err) + } + + h = NewHive(params, NewKademlia(PrivateKeyToBzzKey(prvkey), NewKadParams()), store) + s := p2ptest.NewProtocolTester(prvkey, 0, func(p *p2p.Peer, rw p2p.MsgReadWriter) error { return nil }) + + if err := h.Start(s.Server); err != nil { + t.Fatal(err) + } + return h } - params := NewHiveParams() - s, pp, err := newHiveTester(t, params, 5, store) - if err != nil { - t.Fatal(err) - } + h1 := startHive(t, dir) peers := make(map[string]bool) - for _, node := range s.Nodes { - raddr := NewAddr(node) - pp.Register(raddr) + for i := 0; i < peersCount; i++ { + raddr := RandomAddr() + h1.Register(raddr) peers[raddr.String()] = true } - - // start and stop the hive - // the known peers should be saved upon stopping - err = pp.Start(s.Server) - if err != nil { - t.Fatal(err) - } - pp.Stop() - store.Close() - - // start the hive with an empty dbstore - persistedStore, err := state.NewDBStore(dir) - if err != nil { - t.Fatal(err) - } - - s1, pp, err := newHiveTester(t, params, 0, persistedStore) - if err != nil { + if err = h1.Stop(); err != nil { t.Fatal(err) } // start the hive and check that we know of all expected peers - pp.Start(s1.Server) + h2 := startHive(t, dir) + defer func() { + if err = h2.Stop(); err != nil { + t.Fatal(err) + } + }() + i := 0 - pp.Kademlia.EachAddr(nil, 256, func(addr *BzzAddr, po int) bool { + h2.Kademlia.EachAddr(nil, 256, func(addr *BzzAddr, po int) bool { delete(peers, addr.String()) i++ return true }) - // TODO remove this line when verified that test passes - time.Sleep(time.Second) - if i != 5 { - t.Fatalf("invalid number of entries: got %v, want %v", i, 5) + if i != peersCount { + t.Fatalf("invalid number of entries: got %v, want %v", i, peersCount) } if len(peers) != 0 { t.Fatalf("%d peers left over: %v", len(peers), peers) } - } diff --git a/swarm/network/protocol.go b/swarm/network/protocol.go index fcceb5c31..ad3f8df8f 100644 --- a/swarm/network/protocol.go +++ b/swarm/network/protocol.go @@ -20,6 +20,7 @@ import ( "context" "errors" "fmt" + "math/rand" "sync" "time" @@ -37,6 +38,8 @@ const ( bzzHandshakeTimeout = 3000 * time.Millisecond ) +var DefaultTestNetworkID = rand.Uint64() + // BzzSpec is the spec of the generic swarm handshake var BzzSpec = &protocols.Spec{ Name: "bzz", diff --git a/swarm/network/protocol_test.go b/swarm/network/protocol_test.go index 1e7bb04aa..b562a4253 100644 --- a/swarm/network/protocol_test.go +++ b/swarm/network/protocol_test.go @@ -21,6 +21,7 @@ import ( "flag" "fmt" "os" + "sync" "testing" "time" @@ -31,13 +32,15 @@ import ( "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/protocols" p2ptest "github.com/ethereum/go-ethereum/p2p/testing" + "github.com/ethereum/go-ethereum/swarm/pot" ) const ( - TestProtocolVersion = 8 - TestProtocolNetworkID = 3 + TestProtocolVersion = 8 ) +var TestProtocolNetworkID = DefaultTestNetworkID + var ( loglevel = flag.Int("loglevel", 2, "verbosity of logs") ) @@ -70,20 +73,37 @@ func HandshakeMsgExchange(lhs, rhs *HandshakeMsg, id enode.ID) []p2ptest.Exchang } } -func newBzzBaseTester(t *testing.T, n int, prvkey *ecdsa.PrivateKey, spec *protocols.Spec, run func(*BzzPeer) error) (*bzzTester, error) { - cs := make(map[string]chan bool) +func newBzzBaseTester(n int, prvkey *ecdsa.PrivateKey, spec *protocols.Spec, run func(*BzzPeer) error) (*bzzTester, error) { + var addrs [][]byte + for i := 0; i < n; i++ { + addr := pot.RandomAddress() + addrs = append(addrs, addr[:]) + } + pt, _, err := newBzzBaseTesterWithAddrs(prvkey, addrs, spec, run) + return pt, err +} + +func newBzzBaseTesterWithAddrs(prvkey *ecdsa.PrivateKey, addrs [][]byte, spec *protocols.Spec, run func(*BzzPeer) error) (*bzzTester, [][]byte, error) { + n := len(addrs) + cs := make(map[enode.ID]chan bool) srv := func(p *BzzPeer) error { defer func() { - if cs[p.ID().String()] != nil { - close(cs[p.ID().String()]) + if cs[p.ID()] != nil { + close(cs[p.ID()]) } }() return run(p) } - + mu := &sync.Mutex{} + nodeToAddr := make(map[enode.ID][]byte) protocol := func(p *p2p.Peer, rw p2p.MsgReadWriter) error { - return srv(&BzzPeer{Peer: protocols.NewPeer(p, rw, spec), BzzAddr: NewAddr(p.Node())}) + mu.Lock() + defer mu.Unlock() + nodeToAddr[p.ID()] = addrs[0] + bzzAddr := &BzzAddr{addrs[0], []byte(p.Node().String())} + addrs = addrs[1:] + return srv(&BzzPeer{Peer: protocols.NewPeer(p, rw, spec), BzzAddr: bzzAddr}) } s := p2ptest.NewProtocolTester(prvkey, n, protocol) @@ -92,30 +112,36 @@ func newBzzBaseTester(t *testing.T, n int, prvkey *ecdsa.PrivateKey, spec *proto record.Set(NewENRAddrEntry(bzzKey)) err := enode.SignV4(&record, prvkey) if err != nil { - return nil, fmt.Errorf("unable to generate ENR: %v", err) + return nil, nil, fmt.Errorf("unable to generate ENR: %v", err) } nod, err := enode.New(enode.V4ID{}, &record) if err != nil { - return nil, fmt.Errorf("unable to create enode: %v", err) + return nil, nil, fmt.Errorf("unable to create enode: %v", err) } addr := getENRBzzAddr(nod) for _, node := range s.Nodes { log.Warn("node", "node", node) - cs[node.ID().String()] = make(chan bool) + cs[node.ID()] = make(chan bool) } - return &bzzTester{ + var nodeAddrs [][]byte + pt := &bzzTester{ addr: addr, ProtocolTester: s, cs: cs, - }, nil + } + for _, n := range pt.Nodes { + nodeAddrs = append(nodeAddrs, nodeToAddr[n.ID()]) + } + + return pt, nodeAddrs, nil } type bzzTester struct { *p2ptest.ProtocolTester addr *BzzAddr - cs map[string]chan bool + cs map[enode.ID]chan bool bzz *Bzz } @@ -124,7 +150,7 @@ func newBzz(addr *BzzAddr, lightNode bool) *Bzz { OverlayAddr: addr.Over(), UnderlayAddr: addr.Under(), HiveParams: NewHiveParams(), - NetworkID: DefaultNetworkID, + NetworkID: DefaultTestNetworkID, LightNode: lightNode, } kad := NewKademlia(addr.OAddr, NewKadParams()) @@ -207,7 +233,7 @@ func TestBzzHandshakeNetworkIDMismatch(t *testing.T) { err = s.testHandshake( correctBzzHandshake(s.addr, lightNode), &HandshakeMsg{Version: TestProtocolVersion, NetworkID: 321, Addr: NewAddr(node)}, - &p2ptest.Disconnect{Peer: node.ID(), Error: fmt.Errorf("Handshake error: Message handler error: (msg code 0): network id mismatch 321 (!= 3)")}, + &p2ptest.Disconnect{Peer: node.ID(), Error: fmt.Errorf("Handshake error: Message handler error: (msg code 0): network id mismatch 321 (!= %v)", TestProtocolNetworkID)}, ) if err != nil {