forked from cerc-io/plugeth
Merge pull request #592 from fjl/disco-ping-pong
Discovery bonding protocol
This commit is contained in:
commit
fd171eff7f
@ -32,8 +32,8 @@ var (
|
||||
defaultBootNodes = []*discover.Node{
|
||||
// ETH/DEV cmd/bootnode
|
||||
discover.MustParseNode("enode://09fbeec0d047e9a37e63f60f8618aa9df0e49271f3fadb2c070dc09e2099b95827b63a8b837c6fd01d0802d457dd83e3bd48bd3e6509f8209ed90dabbc30e3d3@52.16.188.185:30303"),
|
||||
// ETH/DEV cpp-ethereum (poc-8.ethdev.com)
|
||||
discover.MustParseNode("enode://4a44599974518ea5b0f14c31c4463692ac0329cb84851f3435e6d1b18ee4eae4aa495f846a0fa1219bd58035671881d44423876e57db2abd57254d0197da0ebe@5.1.83.226:30303"),
|
||||
// ETH/DEV cpp-ethereum (poc-9.ethdev.com)
|
||||
discover.MustParseNode("enode://487611428e6c99a11a9795a6abe7b529e81315ca6aad66e2a2fc76e3adf263faba0d35466c2f8f68d561dbefa8878d4df5f1f2ddb1fbeab7f42ffb8cd328bd4a@5.1.83.226:30303"),
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -13,6 +13,8 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
@ -30,7 +32,8 @@ type Node struct {
|
||||
DiscPort int // UDP listening port for discovery protocol
|
||||
TCPPort int // TCP listening port for RLPx
|
||||
|
||||
active time.Time
|
||||
// this must be set/read using atomic load and store.
|
||||
activeStamp int64
|
||||
}
|
||||
|
||||
func newNode(id NodeID, addr *net.UDPAddr) *Node {
|
||||
@ -39,7 +42,6 @@ func newNode(id NodeID, addr *net.UDPAddr) *Node {
|
||||
IP: addr.IP,
|
||||
DiscPort: addr.Port,
|
||||
TCPPort: addr.Port,
|
||||
active: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -48,6 +50,20 @@ func (n *Node) isValid() bool {
|
||||
return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
|
||||
}
|
||||
|
||||
func (n *Node) bumpActive() {
|
||||
stamp := time.Now().Unix()
|
||||
atomic.StoreInt64(&n.activeStamp, stamp)
|
||||
}
|
||||
|
||||
func (n *Node) active() time.Time {
|
||||
stamp := atomic.LoadInt64(&n.activeStamp)
|
||||
return time.Unix(stamp, 0)
|
||||
}
|
||||
|
||||
func (n *Node) addr() *net.UDPAddr {
|
||||
return &net.UDPAddr{IP: n.IP, Port: n.DiscPort}
|
||||
}
|
||||
|
||||
// The string representation of a Node is a URL.
|
||||
// Please see ParseNode for a description of the format.
|
||||
func (n *Node) String() string {
|
||||
@ -304,3 +320,26 @@ func randomID(a NodeID, n int) (b NodeID) {
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// nodeDB stores all nodes we know about.
|
||||
type nodeDB struct {
|
||||
mu sync.RWMutex
|
||||
byID map[NodeID]*Node
|
||||
}
|
||||
|
||||
func (db *nodeDB) get(id NodeID) *Node {
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
return db.byID[id]
|
||||
}
|
||||
|
||||
func (db *nodeDB) add(id NodeID, addr *net.UDPAddr, tcpPort uint16) *Node {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
if db.byID == nil {
|
||||
db.byID = make(map[NodeID]*Node)
|
||||
}
|
||||
n := &Node{ID: id, IP: addr.IP, DiscPort: addr.Port, TCPPort: int(tcpPort)}
|
||||
db.byID[n.ID] = n
|
||||
return n
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ const (
|
||||
alpha = 3 // Kademlia concurrency factor
|
||||
bucketSize = 16 // Kademlia bucket size
|
||||
nBuckets = nodeIDBits + 1 // Number of buckets
|
||||
maxBondingPingPongs = 10
|
||||
)
|
||||
|
||||
type Table struct {
|
||||
@ -24,27 +25,50 @@ type Table struct {
|
||||
buckets [nBuckets]*bucket // index of known nodes by distance
|
||||
nursery []*Node // bootstrap nodes
|
||||
|
||||
bondmu sync.Mutex
|
||||
bonding map[NodeID]*bondproc
|
||||
bondslots chan struct{} // limits total number of active bonding processes
|
||||
|
||||
net transport
|
||||
self *Node // metadata of the local node
|
||||
db *nodeDB
|
||||
}
|
||||
|
||||
type bondproc struct {
|
||||
err error
|
||||
n *Node
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// transport is implemented by the UDP transport.
|
||||
// it is an interface so we can test without opening lots of UDP
|
||||
// sockets and without generating a private key.
|
||||
type transport interface {
|
||||
ping(*Node) error
|
||||
findnode(e *Node, target NodeID) ([]*Node, error)
|
||||
ping(NodeID, *net.UDPAddr) error
|
||||
waitping(NodeID) error
|
||||
findnode(toid NodeID, addr *net.UDPAddr, target NodeID) ([]*Node, error)
|
||||
close()
|
||||
}
|
||||
|
||||
// bucket contains nodes, ordered by their last activity.
|
||||
// the entry that was most recently active is the last element
|
||||
// in entries.
|
||||
type bucket struct {
|
||||
lastLookup time.Time
|
||||
entries []*Node
|
||||
}
|
||||
|
||||
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
|
||||
tab := &Table{net: t, self: newNode(ourID, ourAddr)}
|
||||
tab := &Table{
|
||||
net: t,
|
||||
db: new(nodeDB),
|
||||
self: newNode(ourID, ourAddr),
|
||||
bonding: make(map[NodeID]*bondproc),
|
||||
bondslots: make(chan struct{}, maxBondingPingPongs),
|
||||
}
|
||||
for i := 0; i < cap(tab.bondslots); i++ {
|
||||
tab.bondslots <- struct{}{}
|
||||
}
|
||||
for i := range tab.buckets {
|
||||
tab.buckets[i] = new(bucket)
|
||||
}
|
||||
@ -107,8 +131,8 @@ func (tab *Table) Lookup(target NodeID) []*Node {
|
||||
asked[n.ID] = true
|
||||
pendingQueries++
|
||||
go func() {
|
||||
result, _ := tab.net.findnode(n, target)
|
||||
reply <- result
|
||||
r, _ := tab.net.findnode(n.ID, n.addr(), target)
|
||||
reply <- tab.bondall(r)
|
||||
}()
|
||||
}
|
||||
}
|
||||
@ -116,13 +140,11 @@ func (tab *Table) Lookup(target NodeID) []*Node {
|
||||
// we have asked all closest nodes, stop the search
|
||||
break
|
||||
}
|
||||
|
||||
// wait for the next reply
|
||||
for _, n := range <-reply {
|
||||
cn := n
|
||||
if !seen[n.ID] {
|
||||
if n != nil && !seen[n.ID] {
|
||||
seen[n.ID] = true
|
||||
result.push(cn, bucketSize)
|
||||
result.push(n, bucketSize)
|
||||
}
|
||||
}
|
||||
pendingQueries--
|
||||
@ -145,8 +167,9 @@ func (tab *Table) refresh() {
|
||||
result := tab.Lookup(randomID(tab.self.ID, ld))
|
||||
if len(result) == 0 {
|
||||
// bootstrap the table with a self lookup
|
||||
all := tab.bondall(tab.nursery)
|
||||
tab.mutex.Lock()
|
||||
tab.add(tab.nursery)
|
||||
tab.add(all)
|
||||
tab.mutex.Unlock()
|
||||
tab.Lookup(tab.self.ID)
|
||||
// TODO: the Kademlia paper says that we're supposed to perform
|
||||
@ -176,45 +199,105 @@ func (tab *Table) len() (n int) {
|
||||
return n
|
||||
}
|
||||
|
||||
// bumpOrAdd updates the activity timestamp for the given node and
|
||||
// attempts to insert the node into a bucket. The returned Node might
|
||||
// not be part of the table. The caller must hold tab.mutex.
|
||||
func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) {
|
||||
b := tab.buckets[logdist(tab.self.ID, node)]
|
||||
if n = b.bump(node); n == nil {
|
||||
n = newNode(node, from)
|
||||
if len(b.entries) == bucketSize {
|
||||
tab.pingReplace(n, b)
|
||||
} else {
|
||||
b.entries = append(b.entries, n)
|
||||
// bondall bonds with all given nodes concurrently and returns
|
||||
// those nodes for which bonding has probably succeeded.
|
||||
func (tab *Table) bondall(nodes []*Node) (result []*Node) {
|
||||
rc := make(chan *Node, len(nodes))
|
||||
for i := range nodes {
|
||||
go func(n *Node) {
|
||||
nn, _ := tab.bond(false, n.ID, n.addr(), uint16(n.TCPPort))
|
||||
rc <- nn
|
||||
}(nodes[i])
|
||||
}
|
||||
for _ = range nodes {
|
||||
if n := <-rc; n != nil {
|
||||
result = append(result, n)
|
||||
}
|
||||
}
|
||||
return n
|
||||
return result
|
||||
}
|
||||
|
||||
func (tab *Table) pingReplace(n *Node, b *bucket) {
|
||||
old := b.entries[bucketSize-1]
|
||||
go func() {
|
||||
if err := tab.net.ping(old); err == nil {
|
||||
// it responded, we don't need to replace it.
|
||||
// bond ensures the local node has a bond with the given remote node.
|
||||
// It also attempts to insert the node into the table if bonding succeeds.
|
||||
// The caller must not hold tab.mutex.
|
||||
//
|
||||
// A bond is must be established before sending findnode requests.
|
||||
// Both sides must have completed a ping/pong exchange for a bond to
|
||||
// exist. The total number of active bonding processes is limited in
|
||||
// order to restrain network use.
|
||||
//
|
||||
// bond is meant to operate idempotently in that bonding with a remote
|
||||
// node which still remembers a previously established bond will work.
|
||||
// The remote node will simply not send a ping back, causing waitping
|
||||
// to time out.
|
||||
//
|
||||
// If pinged is true, the remote node has just pinged us and one half
|
||||
// of the process can be skipped.
|
||||
func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) (*Node, error) {
|
||||
var n *Node
|
||||
if n = tab.db.get(id); n == nil {
|
||||
tab.bondmu.Lock()
|
||||
w := tab.bonding[id]
|
||||
if w != nil {
|
||||
// Wait for an existing bonding process to complete.
|
||||
tab.bondmu.Unlock()
|
||||
<-w.done
|
||||
} else {
|
||||
// Register a new bonding process.
|
||||
w = &bondproc{done: make(chan struct{})}
|
||||
tab.bonding[id] = w
|
||||
tab.bondmu.Unlock()
|
||||
// Do the ping/pong. The result goes into w.
|
||||
tab.pingpong(w, pinged, id, addr, tcpPort)
|
||||
// Unregister the process after it's done.
|
||||
tab.bondmu.Lock()
|
||||
delete(tab.bonding, id)
|
||||
tab.bondmu.Unlock()
|
||||
}
|
||||
n = w.n
|
||||
if w.err != nil {
|
||||
return nil, w.err
|
||||
}
|
||||
}
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
if b := tab.buckets[logdist(tab.self.ID, n.ID)]; !b.bump(n) {
|
||||
tab.pingreplace(n, b)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) {
|
||||
<-tab.bondslots
|
||||
defer func() { tab.bondslots <- struct{}{} }()
|
||||
if w.err = tab.net.ping(id, addr); w.err != nil {
|
||||
close(w.done)
|
||||
return
|
||||
}
|
||||
// it didn't respond, replace the node if it is still the oldest node.
|
||||
tab.mutex.Lock()
|
||||
if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old {
|
||||
// slide down other entries and put the new one in front.
|
||||
// TODO: insert in correct position to keep the order
|
||||
copy(b.entries[1:], b.entries)
|
||||
b.entries[0] = n
|
||||
if !pinged {
|
||||
// Give the remote node a chance to ping us before we start
|
||||
// sending findnode requests. If they still remember us,
|
||||
// waitping will simply time out.
|
||||
tab.net.waitping(id)
|
||||
}
|
||||
tab.mutex.Unlock()
|
||||
}()
|
||||
w.n = tab.db.add(id, addr, tcpPort)
|
||||
close(w.done)
|
||||
}
|
||||
|
||||
// bump updates the activity timestamp for the given node.
|
||||
// The caller must hold tab.mutex.
|
||||
func (tab *Table) bump(node NodeID) {
|
||||
tab.buckets[logdist(tab.self.ID, node)].bump(node)
|
||||
func (tab *Table) pingreplace(new *Node, b *bucket) {
|
||||
if len(b.entries) == bucketSize {
|
||||
oldest := b.entries[bucketSize-1]
|
||||
if err := tab.net.ping(oldest.ID, oldest.addr()); err == nil {
|
||||
// The node responded, we don't need to replace it.
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Add a slot at the end so the last entry doesn't
|
||||
// fall off when adding the new node.
|
||||
b.entries = append(b.entries, nil)
|
||||
}
|
||||
copy(b.entries[1:], b.entries)
|
||||
b.entries[0] = new
|
||||
}
|
||||
|
||||
// add puts the entries into the table if their corresponding
|
||||
@ -240,17 +323,17 @@ outer:
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bucket) bump(id NodeID) *Node {
|
||||
for i, n := range b.entries {
|
||||
if n.ID == id {
|
||||
n.active = time.Now()
|
||||
func (b *bucket) bump(n *Node) bool {
|
||||
for i := range b.entries {
|
||||
if b.entries[i].ID == n.ID {
|
||||
n.bumpActive()
|
||||
// move it to the front
|
||||
copy(b.entries[1:], b.entries[:i+1])
|
||||
copy(b.entries[1:], b.entries[:i])
|
||||
b.entries[0] = n
|
||||
return n
|
||||
return true
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return false
|
||||
}
|
||||
|
||||
// nodesByDistance is a list of nodes, ordered by
|
||||
|
@ -2,79 +2,110 @@ package discover
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
)
|
||||
|
||||
func TestTable_bumpOrAddBucketAssign(t *testing.T) {
|
||||
tab := newTable(nil, NodeID{}, &net.UDPAddr{})
|
||||
for i := 1; i < len(tab.buckets); i++ {
|
||||
tab.bumpOrAdd(randomID(tab.self.ID, i), &net.UDPAddr{})
|
||||
}
|
||||
for i, b := range tab.buckets {
|
||||
if i > 0 && len(b.entries) != 1 {
|
||||
t.Errorf("bucket %d has %d entries, want 1", i, len(b.entries))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_bumpOrAddPingReplace(t *testing.T) {
|
||||
pingC := make(pingC)
|
||||
tab := newTable(pingC, NodeID{}, &net.UDPAddr{})
|
||||
func TestTable_pingReplace(t *testing.T) {
|
||||
doit := func(newNodeIsResponding, lastInBucketIsResponding bool) {
|
||||
transport := newPingRecorder()
|
||||
tab := newTable(transport, NodeID{}, &net.UDPAddr{})
|
||||
last := fillBucket(tab, 200)
|
||||
pingSender := randomID(tab.self.ID, 200)
|
||||
|
||||
// this bumpOrAdd should not replace the last node
|
||||
// because the node replies to ping.
|
||||
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
|
||||
// this gotPing should replace the last node
|
||||
// if the last node is not responding.
|
||||
transport.responding[last.ID] = lastInBucketIsResponding
|
||||
transport.responding[pingSender] = newNodeIsResponding
|
||||
tab.bond(true, pingSender, &net.UDPAddr{}, 0)
|
||||
|
||||
pinged := <-pingC
|
||||
if pinged != last.ID {
|
||||
t.Fatalf("pinged wrong node: %v\nwant %v", pinged, last.ID)
|
||||
// first ping goes to sender (bonding pingback)
|
||||
if !transport.pinged[pingSender] {
|
||||
t.Error("table did not ping back sender")
|
||||
}
|
||||
if newNodeIsResponding {
|
||||
// second ping goes to oldest node in bucket
|
||||
// to see whether it is still alive.
|
||||
if !transport.pinged[last.ID] {
|
||||
t.Error("table did not ping last node in bucket")
|
||||
}
|
||||
}
|
||||
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
if l := len(tab.buckets[200].entries); l != bucketSize {
|
||||
t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
|
||||
t.Errorf("wrong bucket size after gotPing: got %d, want %d", bucketSize, l)
|
||||
}
|
||||
|
||||
if lastInBucketIsResponding || !newNodeIsResponding {
|
||||
if !contains(tab.buckets[200].entries, last.ID) {
|
||||
t.Error("last entry was removed")
|
||||
}
|
||||
if contains(tab.buckets[200].entries, new.ID) {
|
||||
if contains(tab.buckets[200].entries, pingSender) {
|
||||
t.Error("new entry was added")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_bumpOrAddPingTimeout(t *testing.T) {
|
||||
tab := newTable(pingC(nil), NodeID{}, &net.UDPAddr{})
|
||||
last := fillBucket(tab, 200)
|
||||
|
||||
// this bumpOrAdd should replace the last node
|
||||
// because the node does not reply to ping.
|
||||
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
|
||||
|
||||
// wait for async bucket update. damn. this needs to go away.
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
if l := len(tab.buckets[200].entries); l != bucketSize {
|
||||
t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
|
||||
}
|
||||
} else {
|
||||
if contains(tab.buckets[200].entries, last.ID) {
|
||||
t.Error("last entry was not removed")
|
||||
}
|
||||
if !contains(tab.buckets[200].entries, new.ID) {
|
||||
if !contains(tab.buckets[200].entries, pingSender) {
|
||||
t.Error("new entry was not added")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
doit(true, true)
|
||||
doit(false, true)
|
||||
doit(false, true)
|
||||
doit(false, false)
|
||||
}
|
||||
|
||||
func TestBucket_bumpNoDuplicates(t *testing.T) {
|
||||
t.Parallel()
|
||||
cfg := &quick.Config{
|
||||
MaxCount: 1000,
|
||||
Rand: quickrand,
|
||||
Values: func(args []reflect.Value, rand *rand.Rand) {
|
||||
// generate a random list of nodes. this will be the content of the bucket.
|
||||
n := rand.Intn(bucketSize-1) + 1
|
||||
nodes := make([]*Node, n)
|
||||
for i := range nodes {
|
||||
nodes[i] = &Node{ID: randomID(NodeID{}, 200)}
|
||||
}
|
||||
args[0] = reflect.ValueOf(nodes)
|
||||
// generate random bump positions.
|
||||
bumps := make([]int, rand.Intn(100))
|
||||
for i := range bumps {
|
||||
bumps[i] = rand.Intn(len(nodes))
|
||||
}
|
||||
args[1] = reflect.ValueOf(bumps)
|
||||
},
|
||||
}
|
||||
|
||||
prop := func(nodes []*Node, bumps []int) (ok bool) {
|
||||
b := &bucket{entries: make([]*Node, len(nodes))}
|
||||
copy(b.entries, nodes)
|
||||
for i, pos := range bumps {
|
||||
b.bump(b.entries[pos])
|
||||
if hasDuplicates(b.entries) {
|
||||
t.Logf("bucket has duplicates after %d/%d bumps:", i+1, len(bumps))
|
||||
for _, n := range b.entries {
|
||||
t.Logf(" %p", n)
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
if err := quick.Check(prop, cfg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func fillBucket(tab *Table, ld int) (last *Node) {
|
||||
@ -85,44 +116,27 @@ func fillBucket(tab *Table, ld int) (last *Node) {
|
||||
return b.entries[bucketSize-1]
|
||||
}
|
||||
|
||||
type pingC chan NodeID
|
||||
type pingRecorder struct{ responding, pinged map[NodeID]bool }
|
||||
|
||||
func (t pingC) findnode(n *Node, target NodeID) ([]*Node, error) {
|
||||
func newPingRecorder() *pingRecorder {
|
||||
return &pingRecorder{make(map[NodeID]bool), make(map[NodeID]bool)}
|
||||
}
|
||||
|
||||
func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
|
||||
panic("findnode called on pingRecorder")
|
||||
}
|
||||
func (t pingC) close() {
|
||||
func (t *pingRecorder) close() {
|
||||
panic("close called on pingRecorder")
|
||||
}
|
||||
func (t pingC) ping(n *Node) error {
|
||||
if t == nil {
|
||||
return errTimeout
|
||||
}
|
||||
t <- n.ID
|
||||
return nil
|
||||
func (t *pingRecorder) waitping(from NodeID) error {
|
||||
return nil // remote always pings
|
||||
}
|
||||
|
||||
func TestTable_bump(t *testing.T) {
|
||||
tab := newTable(nil, NodeID{}, &net.UDPAddr{})
|
||||
|
||||
// add an old entry and two recent ones
|
||||
oldactive := time.Now().Add(-2 * time.Minute)
|
||||
old := &Node{ID: randomID(tab.self.ID, 200), active: oldactive}
|
||||
others := []*Node{
|
||||
&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
|
||||
&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
|
||||
}
|
||||
tab.add(append(others, old))
|
||||
if tab.buckets[200].entries[0] == old {
|
||||
t.Fatal("old entry is at front of bucket")
|
||||
}
|
||||
|
||||
// bumping the old entry should move it to the front
|
||||
tab.bump(old.ID)
|
||||
if old.active == oldactive {
|
||||
t.Error("activity timestamp not updated")
|
||||
}
|
||||
if tab.buckets[200].entries[0] != old {
|
||||
t.Errorf("bumped entry did not move to the front of bucket")
|
||||
func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
|
||||
t.pinged[toid] = true
|
||||
if t.responding[toid] {
|
||||
return nil
|
||||
} else {
|
||||
return errTimeout
|
||||
}
|
||||
}
|
||||
|
||||
@ -210,7 +224,7 @@ func TestTable_Lookup(t *testing.T) {
|
||||
t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
|
||||
}
|
||||
// seed table with initial node (otherwise lookup will terminate immediately)
|
||||
tab.bumpOrAdd(randomID(target, 200), &net.UDPAddr{Port: 200})
|
||||
tab.add([]*Node{newNode(randomID(target, 200), &net.UDPAddr{Port: 200})})
|
||||
|
||||
results := tab.Lookup(target)
|
||||
t.Logf("results:")
|
||||
@ -238,16 +252,16 @@ type findnodeOracle struct {
|
||||
target NodeID
|
||||
}
|
||||
|
||||
func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
|
||||
t.t.Logf("findnode query at dist %d", n.DiscPort)
|
||||
func (t findnodeOracle) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
|
||||
t.t.Logf("findnode query at dist %d", toaddr.Port)
|
||||
// current log distance is encoded in port number
|
||||
var result []*Node
|
||||
switch n.DiscPort {
|
||||
switch toaddr.Port {
|
||||
case 0:
|
||||
panic("query to node at distance 0")
|
||||
default:
|
||||
// TODO: add more randomness to distances
|
||||
next := n.DiscPort - 1
|
||||
next := toaddr.Port - 1
|
||||
for i := 0; i < bucketSize; i++ {
|
||||
result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next})
|
||||
}
|
||||
@ -256,10 +270,8 @@ func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
|
||||
}
|
||||
|
||||
func (t findnodeOracle) close() {}
|
||||
|
||||
func (t findnodeOracle) ping(n *Node) error {
|
||||
return errors.New("ping is not supported by this transport")
|
||||
}
|
||||
func (t findnodeOracle) waitping(from NodeID) error { return nil }
|
||||
func (t findnodeOracle) ping(toid NodeID, toaddr *net.UDPAddr) error { return nil }
|
||||
|
||||
func hasDuplicates(slice []*Node) bool {
|
||||
seen := make(map[NodeID]bool)
|
||||
|
@ -16,11 +16,16 @@ import (
|
||||
|
||||
var log = logger.NewLogger("P2P Discovery")
|
||||
|
||||
const Version = 3
|
||||
|
||||
// Errors
|
||||
var (
|
||||
errPacketTooSmall = errors.New("too small")
|
||||
errBadHash = errors.New("bad hash")
|
||||
errExpired = errors.New("expired")
|
||||
errBadVersion = errors.New("version mismatch")
|
||||
errUnsolicitedReply = errors.New("unsolicited reply")
|
||||
errUnknownNode = errors.New("unknown node")
|
||||
errTimeout = errors.New("RPC timeout")
|
||||
errClosed = errors.New("socket closed")
|
||||
)
|
||||
@ -45,6 +50,7 @@ const (
|
||||
// RPC request structures
|
||||
type (
|
||||
ping struct {
|
||||
Version uint // must match Version
|
||||
IP string // our IP
|
||||
Port uint16 // our port
|
||||
Expiration uint64
|
||||
@ -76,12 +82,25 @@ type rpcNode struct {
|
||||
ID NodeID
|
||||
}
|
||||
|
||||
type packet interface {
|
||||
handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
|
||||
}
|
||||
|
||||
type conn interface {
|
||||
ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
|
||||
WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
|
||||
Close() error
|
||||
LocalAddr() net.Addr
|
||||
}
|
||||
|
||||
// udp implements the RPC protocol.
|
||||
type udp struct {
|
||||
conn *net.UDPConn
|
||||
conn conn
|
||||
priv *ecdsa.PrivateKey
|
||||
|
||||
addpending chan *pending
|
||||
replies chan reply
|
||||
gotreply chan reply
|
||||
|
||||
closing chan struct{}
|
||||
nat nat.Interface
|
||||
|
||||
@ -120,6 +139,9 @@ type reply struct {
|
||||
from NodeID
|
||||
ptype byte
|
||||
data interface{}
|
||||
// loop indicates whether there was
|
||||
// a matching request by sending on this channel.
|
||||
matched chan<- bool
|
||||
}
|
||||
|
||||
// ListenUDP returns a new table that listens for UDP packets on laddr.
|
||||
@ -132,15 +154,20 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tab, _ := newUDP(priv, conn, natm)
|
||||
log.Infoln("Listening,", tab.self)
|
||||
return tab, nil
|
||||
}
|
||||
|
||||
func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface) (*Table, *udp) {
|
||||
udp := &udp{
|
||||
conn: conn,
|
||||
conn: c,
|
||||
priv: priv,
|
||||
closing: make(chan struct{}),
|
||||
gotreply: make(chan reply),
|
||||
addpending: make(chan *pending),
|
||||
replies: make(chan reply),
|
||||
}
|
||||
|
||||
realaddr := conn.LocalAddr().(*net.UDPAddr)
|
||||
realaddr := c.LocalAddr().(*net.UDPAddr)
|
||||
if natm != nil {
|
||||
if !realaddr.IP.IsLoopback() {
|
||||
go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
|
||||
@ -151,11 +178,9 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table
|
||||
}
|
||||
}
|
||||
udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr)
|
||||
|
||||
go udp.loop()
|
||||
go udp.readLoop()
|
||||
log.Infoln("Listening, ", udp.self)
|
||||
return udp.Table, nil
|
||||
return udp.Table, udp
|
||||
}
|
||||
|
||||
func (t *udp) close() {
|
||||
@ -165,10 +190,11 @@ func (t *udp) close() {
|
||||
}
|
||||
|
||||
// ping sends a ping message to the given node and waits for a reply.
|
||||
func (t *udp) ping(e *Node) error {
|
||||
func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
|
||||
// TODO: maybe check for ReplyTo field in callback to measure RTT
|
||||
errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
|
||||
t.send(e, pingPacket, ping{
|
||||
errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
|
||||
t.send(toaddr, pingPacket, ping{
|
||||
Version: Version,
|
||||
IP: t.self.IP.String(),
|
||||
Port: uint16(t.self.TCPPort),
|
||||
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||
@ -176,12 +202,16 @@ func (t *udp) ping(e *Node) error {
|
||||
return <-errc
|
||||
}
|
||||
|
||||
func (t *udp) waitping(from NodeID) error {
|
||||
return <-t.pending(from, pingPacket, func(interface{}) bool { return true })
|
||||
}
|
||||
|
||||
// findnode sends a findnode request to the given node and waits until
|
||||
// the node has sent up to k neighbors.
|
||||
func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
|
||||
func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
|
||||
nodes := make([]*Node, 0, bucketSize)
|
||||
nreceived := 0
|
||||
errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool {
|
||||
errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
|
||||
reply := r.(*neighbors)
|
||||
for _, n := range reply.Nodes {
|
||||
nreceived++
|
||||
@ -191,8 +221,7 @@ func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
|
||||
}
|
||||
return nreceived >= bucketSize
|
||||
})
|
||||
|
||||
t.send(to, findnodePacket, findnode{
|
||||
t.send(toaddr, findnodePacket, findnode{
|
||||
Target: target,
|
||||
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||
})
|
||||
@ -214,6 +243,17 @@ func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-
|
||||
return ch
|
||||
}
|
||||
|
||||
func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool {
|
||||
matched := make(chan bool)
|
||||
select {
|
||||
case t.gotreply <- reply{from, ptype, req, matched}:
|
||||
// loop will handle it
|
||||
return <-matched
|
||||
case <-t.closing:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// loop runs in its own goroutin. it keeps track of
|
||||
// the refresh timer and the pending reply queue.
|
||||
func (t *udp) loop() {
|
||||
@ -244,6 +284,7 @@ func (t *udp) loop() {
|
||||
for _, p := range pending {
|
||||
p.errc <- errClosed
|
||||
}
|
||||
pending = nil
|
||||
return
|
||||
|
||||
case p := <-t.addpending:
|
||||
@ -251,18 +292,21 @@ func (t *udp) loop() {
|
||||
pending = append(pending, p)
|
||||
rearmTimeout()
|
||||
|
||||
case reply := <-t.replies:
|
||||
// run matching callbacks, remove if they return false.
|
||||
case r := <-t.gotreply:
|
||||
var matched bool
|
||||
for i := 0; i < len(pending); i++ {
|
||||
p := pending[i]
|
||||
if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) {
|
||||
if p := pending[i]; p.from == r.from && p.ptype == r.ptype {
|
||||
matched = true
|
||||
if p.callback(r.data) {
|
||||
// callback indicates the request is done, remove it.
|
||||
p.errc <- nil
|
||||
copy(pending[i:], pending[i+1:])
|
||||
pending = pending[:len(pending)-1]
|
||||
i--
|
||||
}
|
||||
}
|
||||
rearmTimeout()
|
||||
}
|
||||
r.matched <- matched
|
||||
|
||||
case now := <-timeout.C:
|
||||
// notify and remove callbacks whose deadline is in the past.
|
||||
@ -287,28 +331,11 @@ const (
|
||||
|
||||
var headSpace = make([]byte, headSize)
|
||||
|
||||
func (t *udp) send(to *Node, ptype byte, req interface{}) error {
|
||||
b := new(bytes.Buffer)
|
||||
b.Write(headSpace)
|
||||
b.WriteByte(ptype)
|
||||
if err := rlp.Encode(b, req); err != nil {
|
||||
log.Errorln("error encoding packet:", err)
|
||||
return err
|
||||
}
|
||||
|
||||
packet := b.Bytes()
|
||||
sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv)
|
||||
func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req interface{}) error {
|
||||
packet, err := encodePacket(t.priv, ptype, req)
|
||||
if err != nil {
|
||||
log.Errorln("could not sign packet:", err)
|
||||
return err
|
||||
}
|
||||
copy(packet[macSize:], sig)
|
||||
// add the hash to the front. Note: this doesn't protect the
|
||||
// packet in any way. Our public key will be part of this hash in
|
||||
// the future.
|
||||
copy(packet, crypto.Sha3(packet[macSize:]))
|
||||
|
||||
toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort}
|
||||
log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
|
||||
if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
|
||||
log.DebugDetailln("UDP send failed:", err)
|
||||
@ -316,6 +343,28 @@ func (t *udp) send(to *Node, ptype byte, req interface{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) {
|
||||
b := new(bytes.Buffer)
|
||||
b.Write(headSpace)
|
||||
b.WriteByte(ptype)
|
||||
if err := rlp.Encode(b, req); err != nil {
|
||||
log.Errorln("error encoding packet:", err)
|
||||
return nil, err
|
||||
}
|
||||
packet := b.Bytes()
|
||||
sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), priv)
|
||||
if err != nil {
|
||||
log.Errorln("could not sign packet:", err)
|
||||
return nil, err
|
||||
}
|
||||
copy(packet[macSize:], sig)
|
||||
// add the hash to the front. Note: this doesn't protect the
|
||||
// packet in any way. Our public key will be part of this hash in
|
||||
// The future.
|
||||
copy(packet, crypto.Sha3(packet[macSize:]))
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
// readLoop runs in its own goroutine. it handles incoming UDP packets.
|
||||
func (t *udp) readLoop() {
|
||||
defer t.conn.Close()
|
||||
@ -325,29 +374,34 @@ func (t *udp) readLoop() {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err := t.packetIn(from, buf[:nbytes]); err != nil {
|
||||
packet, fromID, hash, err := decodePacket(buf[:nbytes])
|
||||
if err != nil {
|
||||
log.Debugf("Bad packet from %v: %v\n", from, err)
|
||||
continue
|
||||
}
|
||||
log.DebugDetailf("<<< %v %T %v\n", from, packet, packet)
|
||||
go func() {
|
||||
if err := packet.handle(t, from, fromID, hash); err != nil {
|
||||
log.Debugf("error handling %T from %v: %v", packet, from, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
|
||||
func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
|
||||
if len(buf) < headSize+1 {
|
||||
return errPacketTooSmall
|
||||
return nil, NodeID{}, nil, errPacketTooSmall
|
||||
}
|
||||
hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
|
||||
shouldhash := crypto.Sha3(buf[macSize:])
|
||||
if !bytes.Equal(hash, shouldhash) {
|
||||
return errBadHash
|
||||
return nil, NodeID{}, nil, errBadHash
|
||||
}
|
||||
fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req interface {
|
||||
handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
|
||||
return nil, NodeID{}, hash, err
|
||||
}
|
||||
var req packet
|
||||
switch ptype := sigdata[0]; ptype {
|
||||
case pingPacket:
|
||||
req = new(ping)
|
||||
@ -358,31 +412,27 @@ func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
|
||||
case neighborsPacket:
|
||||
req = new(neighbors)
|
||||
default:
|
||||
return fmt.Errorf("unknown type: %d", ptype)
|
||||
return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype)
|
||||
}
|
||||
if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil {
|
||||
return err
|
||||
}
|
||||
log.DebugDetailf("<<< %v %T %v\n", from, req, req)
|
||||
return req.handle(t, from, fromID, hash)
|
||||
err = rlp.Decode(bytes.NewReader(sigdata[1:]), req)
|
||||
return req, fromID, hash, err
|
||||
}
|
||||
|
||||
func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
|
||||
if expired(req.Expiration) {
|
||||
return errExpired
|
||||
}
|
||||
t.mutex.Lock()
|
||||
// Note: we're ignoring the provided IP address right now
|
||||
n := t.bumpOrAdd(fromID, from)
|
||||
if req.Port != 0 {
|
||||
n.TCPPort = int(req.Port)
|
||||
if req.Version != Version {
|
||||
return errBadVersion
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.send(n, pongPacket, pong{
|
||||
t.send(from, pongPacket, pong{
|
||||
ReplyTok: mac,
|
||||
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||
})
|
||||
if !t.handleReply(fromID, pingPacket, req) {
|
||||
// Note: we're ignoring the provided IP address right now
|
||||
t.bond(true, fromID, from, req.Port)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -390,11 +440,9 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
|
||||
if expired(req.Expiration) {
|
||||
return errExpired
|
||||
}
|
||||
t.mutex.Lock()
|
||||
t.bump(fromID)
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.replies <- reply{fromID, pongPacket, req}
|
||||
if !t.handleReply(fromID, pongPacket, req) {
|
||||
return errUnsolicitedReply
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -402,12 +450,21 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
|
||||
if expired(req.Expiration) {
|
||||
return errExpired
|
||||
}
|
||||
if t.db.get(fromID) == nil {
|
||||
// No bond exists, we don't process the packet. This prevents
|
||||
// an attack vector where the discovery protocol could be used
|
||||
// to amplify traffic in a DDOS attack. A malicious actor
|
||||
// would send a findnode request with the IP address and UDP
|
||||
// port of the target as the source address. The recipient of
|
||||
// the findnode packet would then send a neighbors packet
|
||||
// (which is a much bigger packet than findnode) to the victim.
|
||||
return errUnknownNode
|
||||
}
|
||||
t.mutex.Lock()
|
||||
e := t.bumpOrAdd(fromID, from)
|
||||
closest := t.closest(req.Target, bucketSize).entries
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.send(e, neighborsPacket, neighbors{
|
||||
t.send(from, neighborsPacket, neighbors{
|
||||
Nodes: closest,
|
||||
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||
})
|
||||
@ -418,12 +475,9 @@ func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byt
|
||||
if expired(req.Expiration) {
|
||||
return errExpired
|
||||
}
|
||||
t.mutex.Lock()
|
||||
t.bump(fromID)
|
||||
t.add(req.Nodes)
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.replies <- reply{fromID, neighborsPacket, req}
|
||||
if !t.handleReply(fromID, neighborsPacket, req) {
|
||||
return errUnsolicitedReply
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1,10 +1,18 @@
|
||||
package discover
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
logpkg "log"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -15,22 +23,243 @@ func init() {
|
||||
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
|
||||
}
|
||||
|
||||
func TestUDP_ping(t *testing.T) {
|
||||
type udpTest struct {
|
||||
t *testing.T
|
||||
pipe *dgramPipe
|
||||
table *Table
|
||||
udp *udp
|
||||
sent [][]byte
|
||||
localkey, remotekey *ecdsa.PrivateKey
|
||||
remoteaddr *net.UDPAddr
|
||||
}
|
||||
|
||||
func newUDPTest(t *testing.T) *udpTest {
|
||||
test := &udpTest{
|
||||
t: t,
|
||||
pipe: newpipe(),
|
||||
localkey: newkey(),
|
||||
remotekey: newkey(),
|
||||
remoteaddr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 30303},
|
||||
}
|
||||
test.table, test.udp = newUDP(test.localkey, test.pipe, nil)
|
||||
return test
|
||||
}
|
||||
|
||||
// handles a packet as if it had been sent to the transport.
|
||||
func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
|
||||
enc, err := encodePacket(test.remotekey, ptype, data)
|
||||
if err != nil {
|
||||
return test.errorf("packet (%d) encode error: %v", err)
|
||||
}
|
||||
test.sent = append(test.sent, enc)
|
||||
err = data.handle(test.udp, test.remoteaddr, PubkeyID(&test.remotekey.PublicKey), enc[:macSize])
|
||||
if err != wantError {
|
||||
return test.errorf("error mismatch: got %q, want %q", err, wantError)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// waits for a packet to be sent by the transport.
|
||||
// validate should have type func(*udpTest, X) error, where X is a packet type.
|
||||
func (test *udpTest) waitPacketOut(validate interface{}) error {
|
||||
dgram := test.pipe.waitPacketOut()
|
||||
p, _, _, err := decodePacket(dgram)
|
||||
if err != nil {
|
||||
return test.errorf("sent packet decode error: %v", err)
|
||||
}
|
||||
fn := reflect.ValueOf(validate)
|
||||
exptype := fn.Type().In(0)
|
||||
if reflect.TypeOf(p) != exptype {
|
||||
return test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
|
||||
}
|
||||
fn.Call([]reflect.Value{reflect.ValueOf(p)})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (test *udpTest) errorf(format string, args ...interface{}) error {
|
||||
_, file, line, ok := runtime.Caller(2) // errorf + waitPacketOut
|
||||
if ok {
|
||||
file = path.Base(file)
|
||||
} else {
|
||||
file = "???"
|
||||
line = 1
|
||||
}
|
||||
err := fmt.Errorf(format, args...)
|
||||
fmt.Printf("\t%s:%d: %v\n", file, line, err)
|
||||
test.t.Fail()
|
||||
return err
|
||||
}
|
||||
|
||||
// shared test variables
|
||||
var (
|
||||
futureExp = uint64(time.Now().Add(10 * time.Hour).Unix())
|
||||
testTarget = MustHexID("01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101")
|
||||
)
|
||||
|
||||
func TestUDP_packetErrors(t *testing.T) {
|
||||
test := newUDPTest(t)
|
||||
defer test.table.Close()
|
||||
|
||||
test.packetIn(errExpired, pingPacket, &ping{IP: "foo", Port: 99, Version: Version})
|
||||
test.packetIn(errBadVersion, pingPacket, &ping{IP: "foo", Port: 99, Version: 99, Expiration: futureExp})
|
||||
test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp})
|
||||
test.packetIn(errUnknownNode, findnodePacket, &findnode{Expiration: futureExp})
|
||||
test.packetIn(errUnsolicitedReply, neighborsPacket, &neighbors{Expiration: futureExp})
|
||||
}
|
||||
|
||||
func TestUDP_pingTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
test := newUDPTest(t)
|
||||
defer test.table.Close()
|
||||
|
||||
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
defer n1.Close()
|
||||
defer n2.Close()
|
||||
toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
|
||||
toid := NodeID{1, 2, 3, 4}
|
||||
if err := test.udp.ping(toid, toaddr); err != errTimeout {
|
||||
t.Error("expected timeout error, got", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := n1.net.ping(n2.self); err != nil {
|
||||
t.Fatalf("ping error: %v", err)
|
||||
func TestUDP_findnodeTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
test := newUDPTest(t)
|
||||
defer test.table.Close()
|
||||
|
||||
toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
|
||||
toid := NodeID{1, 2, 3, 4}
|
||||
target := NodeID{4, 5, 6, 7}
|
||||
result, err := test.udp.findnode(toid, toaddr, target)
|
||||
if err != errTimeout {
|
||||
t.Error("expected timeout error, got", err)
|
||||
}
|
||||
if find(n2, n1.self.ID) == nil {
|
||||
t.Errorf("node 2 does not contain id of node 1")
|
||||
if len(result) > 0 {
|
||||
t.Error("expected empty result, got", result)
|
||||
}
|
||||
if e := find(n1, n2.self.ID); e != nil {
|
||||
t.Errorf("node 1 does contains id of node 2: %v", e)
|
||||
}
|
||||
|
||||
func TestUDP_findnode(t *testing.T) {
|
||||
test := newUDPTest(t)
|
||||
defer test.table.Close()
|
||||
|
||||
// put a few nodes into the table. their exact
|
||||
// distribution shouldn't matter much, altough we need to
|
||||
// take care not to overflow any bucket.
|
||||
target := testTarget
|
||||
nodes := &nodesByDistance{target: target}
|
||||
for i := 0; i < bucketSize; i++ {
|
||||
nodes.push(&Node{
|
||||
IP: net.IP{1, 2, 3, byte(i)},
|
||||
DiscPort: i + 2,
|
||||
TCPPort: i + 2,
|
||||
ID: randomID(test.table.self.ID, i+2),
|
||||
}, bucketSize)
|
||||
}
|
||||
test.table.add(nodes.entries)
|
||||
|
||||
// ensure there's a bond with the test node,
|
||||
// findnode won't be accepted otherwise.
|
||||
test.table.db.add(PubkeyID(&test.remotekey.PublicKey), test.remoteaddr, 99)
|
||||
|
||||
// check that closest neighbors are returned.
|
||||
test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
|
||||
test.waitPacketOut(func(p *neighbors) {
|
||||
expected := test.table.closest(testTarget, bucketSize)
|
||||
if len(p.Nodes) != bucketSize {
|
||||
t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize)
|
||||
}
|
||||
for i := range p.Nodes {
|
||||
if p.Nodes[i].ID != expected.entries[i].ID {
|
||||
t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, p.Nodes[i], expected.entries[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUDP_findnodeMultiReply(t *testing.T) {
|
||||
test := newUDPTest(t)
|
||||
defer test.table.Close()
|
||||
|
||||
// queue a pending findnode request
|
||||
resultc, errc := make(chan []*Node), make(chan error)
|
||||
go func() {
|
||||
rid := PubkeyID(&test.remotekey.PublicKey)
|
||||
ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget)
|
||||
if err != nil && len(ns) == 0 {
|
||||
errc <- err
|
||||
} else {
|
||||
resultc <- ns
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for the findnode to be sent.
|
||||
// after it is sent, the transport is waiting for a reply
|
||||
test.waitPacketOut(func(p *findnode) {
|
||||
if p.Target != testTarget {
|
||||
t.Errorf("wrong target: got %v, want %v", p.Target, testTarget)
|
||||
}
|
||||
})
|
||||
|
||||
// send the reply as two packets.
|
||||
list := []*Node{
|
||||
MustParseNode("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303"),
|
||||
MustParseNode("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303"),
|
||||
MustParseNode("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301"),
|
||||
MustParseNode("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303"),
|
||||
}
|
||||
test.packetIn(nil, neighborsPacket, &neighbors{Expiration: futureExp, Nodes: list[:2]})
|
||||
test.packetIn(nil, neighborsPacket, &neighbors{Expiration: futureExp, Nodes: list[2:]})
|
||||
|
||||
// check that the sent neighbors are all returned by findnode
|
||||
select {
|
||||
case result := <-resultc:
|
||||
if !reflect.DeepEqual(result, list) {
|
||||
t.Errorf("neighbors mismatch:\n got: %v\n want: %v", result, list)
|
||||
}
|
||||
case err := <-errc:
|
||||
t.Errorf("findnode error: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("findnode did not return within 5 seconds")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDP_successfulPing(t *testing.T) {
|
||||
test := newUDPTest(t)
|
||||
defer test.table.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
test.packetIn(nil, pingPacket, &ping{IP: "foo", Port: 99, Version: Version, Expiration: futureExp})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// the ping is replied to.
|
||||
test.waitPacketOut(func(p *pong) {
|
||||
pinghash := test.sent[0][:macSize]
|
||||
if !bytes.Equal(p.ReplyTok, pinghash) {
|
||||
t.Errorf("got ReplyTok %x, want %x", p.ReplyTok, pinghash)
|
||||
}
|
||||
})
|
||||
|
||||
// remote is unknown, the table pings back.
|
||||
test.waitPacketOut(func(p *ping) error { return nil })
|
||||
test.packetIn(nil, pongPacket, &pong{Expiration: futureExp})
|
||||
|
||||
// ping should return shortly after getting the pong packet.
|
||||
<-done
|
||||
|
||||
// check that the node was added.
|
||||
rid := PubkeyID(&test.remotekey.PublicKey)
|
||||
rnode := find(test.table, rid)
|
||||
if rnode == nil {
|
||||
t.Fatalf("node %v not found in table", rid)
|
||||
}
|
||||
if !bytes.Equal(rnode.IP, test.remoteaddr.IP) {
|
||||
t.Errorf("node has wrong IP: got %v, want: %v", rnode.IP, test.remoteaddr.IP)
|
||||
}
|
||||
if rnode.DiscPort != test.remoteaddr.Port {
|
||||
t.Errorf("node has wrong Port: got %v, want: %v", rnode.DiscPort, test.remoteaddr.Port)
|
||||
}
|
||||
if rnode.TCPPort != 99 {
|
||||
t.Errorf("node has wrong Port: got %v, want: %v", rnode.TCPPort, 99)
|
||||
}
|
||||
}
|
||||
|
||||
@ -45,167 +274,66 @@ func find(tab *Table, id NodeID) *Node {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUDP_findnode(t *testing.T) {
|
||||
t.Parallel()
|
||||
// dgramPipe is a fake UDP socket. It queues all sent datagrams.
|
||||
type dgramPipe struct {
|
||||
mu *sync.Mutex
|
||||
cond *sync.Cond
|
||||
closing chan struct{}
|
||||
closed bool
|
||||
queue [][]byte
|
||||
}
|
||||
|
||||
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
defer n1.Close()
|
||||
defer n2.Close()
|
||||
|
||||
// put a few nodes into n2. the exact distribution shouldn't
|
||||
// matter much, altough we need to take care not to overflow
|
||||
// any bucket.
|
||||
target := randomID(n1.self.ID, 100)
|
||||
nodes := &nodesByDistance{target: target}
|
||||
for i := 0; i < bucketSize; i++ {
|
||||
n2.add([]*Node{&Node{
|
||||
IP: net.IP{1, 2, 3, byte(i)},
|
||||
DiscPort: i + 2,
|
||||
TCPPort: i + 2,
|
||||
ID: randomID(n2.self.ID, i+2),
|
||||
}})
|
||||
}
|
||||
n2.add(nodes.entries)
|
||||
n2.bumpOrAdd(n1.self.ID, &net.UDPAddr{IP: n1.self.IP, Port: n1.self.DiscPort})
|
||||
expected := n2.closest(target, bucketSize)
|
||||
|
||||
err := runUDP(10, func() error {
|
||||
result, _ := n1.net.findnode(n2.self, target)
|
||||
if len(result) != bucketSize {
|
||||
return fmt.Errorf("wrong number of results: got %d, want %d", len(result), bucketSize)
|
||||
}
|
||||
for i := range result {
|
||||
if result[i].ID != expected.entries[i].ID {
|
||||
return fmt.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, result[i], expected.entries[i])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
func newpipe() *dgramPipe {
|
||||
mu := new(sync.Mutex)
|
||||
return &dgramPipe{
|
||||
closing: make(chan struct{}),
|
||||
cond: &sync.Cond{L: mu},
|
||||
mu: mu,
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDP_replytimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// reserve a port so we don't talk to an existing service by accident
|
||||
addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
fd, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fd.Close()
|
||||
|
||||
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
defer n1.Close()
|
||||
n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr))
|
||||
|
||||
if err := n1.net.ping(n2); err != errTimeout {
|
||||
t.Error("expected timeout error, got", err)
|
||||
}
|
||||
|
||||
if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout {
|
||||
t.Error("expected timeout error, got", err)
|
||||
} else if len(result) > 0 {
|
||||
t.Error("expected empty result, got", result)
|
||||
// WriteToUDP queues a datagram.
|
||||
func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) {
|
||||
msg := make([]byte, len(b))
|
||||
copy(msg, b)
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closed {
|
||||
return 0, errors.New("closed")
|
||||
}
|
||||
c.queue = append(c.queue, msg)
|
||||
c.cond.Signal()
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func TestUDP_findnodeMultiReply(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
udp2 := n2.net.(*udp)
|
||||
defer n1.Close()
|
||||
defer n2.Close()
|
||||
|
||||
err := runUDP(10, func() error {
|
||||
nodes := make([]*Node, bucketSize)
|
||||
for i := range nodes {
|
||||
nodes[i] = &Node{
|
||||
IP: net.IP{1, 2, 3, 4},
|
||||
DiscPort: i + 1,
|
||||
TCPPort: i + 1,
|
||||
ID: randomID(n2.self.ID, i+1),
|
||||
}
|
||||
}
|
||||
|
||||
// ask N2 for neighbors. it will send an empty reply back.
|
||||
// the request will wait for up to bucketSize replies.
|
||||
resultc := make(chan []*Node)
|
||||
errc := make(chan error)
|
||||
go func() {
|
||||
ns, err := n1.net.findnode(n2.self, n1.self.ID)
|
||||
if err != nil {
|
||||
errc <- err
|
||||
} else {
|
||||
resultc <- ns
|
||||
}
|
||||
}()
|
||||
|
||||
// send a few more neighbors packets to N1.
|
||||
// it should collect those.
|
||||
for end := 0; end < len(nodes); {
|
||||
off := end
|
||||
if end = end + 5; end > len(nodes) {
|
||||
end = len(nodes)
|
||||
}
|
||||
udp2.send(n1.self, neighborsPacket, neighbors{
|
||||
Nodes: nodes[off:end],
|
||||
Expiration: uint64(time.Now().Add(10 * time.Second).Unix()),
|
||||
})
|
||||
}
|
||||
|
||||
// check that they are all returned. we cannot just check for
|
||||
// equality because they might not be returned in the order they
|
||||
// were sent.
|
||||
var result []*Node
|
||||
select {
|
||||
case result = <-resultc:
|
||||
case err := <-errc:
|
||||
return err
|
||||
}
|
||||
if hasDuplicates(result) {
|
||||
return fmt.Errorf("result slice contains duplicates")
|
||||
}
|
||||
if len(result) != len(nodes) {
|
||||
return fmt.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes))
|
||||
}
|
||||
matched := make(map[NodeID]bool)
|
||||
for _, n := range result {
|
||||
for _, expn := range nodes {
|
||||
if n.ID == expn.ID { // && bytes.Equal(n.Addr.IP, expn.Addr.IP) && n.Addr.Port == expn.Addr.Port {
|
||||
matched[n.ID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(matched) != len(nodes) {
|
||||
return fmt.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
// ReadFromUDP just hangs until the pipe is closed.
|
||||
func (c *dgramPipe) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) {
|
||||
<-c.closing
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
|
||||
// runUDP runs a test n times and returns an error if the test failed
|
||||
// in all n runs. This is necessary because UDP is unreliable even for
|
||||
// connections on the local machine, causing test failures.
|
||||
func runUDP(n int, test func() error) error {
|
||||
errcount := 0
|
||||
errors := ""
|
||||
for i := 0; i < n; i++ {
|
||||
if err := test(); err != nil {
|
||||
errors += fmt.Sprintf("\n#%d: %v", i, err)
|
||||
errcount++
|
||||
}
|
||||
}
|
||||
if errcount == n {
|
||||
return fmt.Errorf("failed on all %d iterations:%s", n, errors)
|
||||
func (c *dgramPipe) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if !c.closed {
|
||||
close(c.closing)
|
||||
c.closed = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *dgramPipe) LocalAddr() net.Addr {
|
||||
return &net.UDPAddr{}
|
||||
}
|
||||
|
||||
func (c *dgramPipe) waitPacketOut() []byte {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for len(c.queue) == 0 {
|
||||
c.cond.Wait()
|
||||
}
|
||||
p := c.queue[0]
|
||||
copy(c.queue, c.queue[1:])
|
||||
c.queue = c.queue[:len(c.queue)-1]
|
||||
return p
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user