forked from cerc-io/plugeth
Merge pull request #592 from fjl/disco-ping-pong
Discovery bonding protocol
This commit is contained in:
commit
fd171eff7f
@ -32,8 +32,8 @@ var (
|
|||||||
defaultBootNodes = []*discover.Node{
|
defaultBootNodes = []*discover.Node{
|
||||||
// ETH/DEV cmd/bootnode
|
// ETH/DEV cmd/bootnode
|
||||||
discover.MustParseNode("enode://09fbeec0d047e9a37e63f60f8618aa9df0e49271f3fadb2c070dc09e2099b95827b63a8b837c6fd01d0802d457dd83e3bd48bd3e6509f8209ed90dabbc30e3d3@52.16.188.185:30303"),
|
discover.MustParseNode("enode://09fbeec0d047e9a37e63f60f8618aa9df0e49271f3fadb2c070dc09e2099b95827b63a8b837c6fd01d0802d457dd83e3bd48bd3e6509f8209ed90dabbc30e3d3@52.16.188.185:30303"),
|
||||||
// ETH/DEV cpp-ethereum (poc-8.ethdev.com)
|
// ETH/DEV cpp-ethereum (poc-9.ethdev.com)
|
||||||
discover.MustParseNode("enode://4a44599974518ea5b0f14c31c4463692ac0329cb84851f3435e6d1b18ee4eae4aa495f846a0fa1219bd58035671881d44423876e57db2abd57254d0197da0ebe@5.1.83.226:30303"),
|
discover.MustParseNode("enode://487611428e6c99a11a9795a6abe7b529e81315ca6aad66e2a2fc76e3adf263faba0d35466c2f8f68d561dbefa8878d4df5f1f2ddb1fbeab7f42ffb8cd328bd4a@5.1.83.226:30303"),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,6 +13,8 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/crypto"
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
@ -30,7 +32,8 @@ type Node struct {
|
|||||||
DiscPort int // UDP listening port for discovery protocol
|
DiscPort int // UDP listening port for discovery protocol
|
||||||
TCPPort int // TCP listening port for RLPx
|
TCPPort int // TCP listening port for RLPx
|
||||||
|
|
||||||
active time.Time
|
// this must be set/read using atomic load and store.
|
||||||
|
activeStamp int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNode(id NodeID, addr *net.UDPAddr) *Node {
|
func newNode(id NodeID, addr *net.UDPAddr) *Node {
|
||||||
@ -39,7 +42,6 @@ func newNode(id NodeID, addr *net.UDPAddr) *Node {
|
|||||||
IP: addr.IP,
|
IP: addr.IP,
|
||||||
DiscPort: addr.Port,
|
DiscPort: addr.Port,
|
||||||
TCPPort: addr.Port,
|
TCPPort: addr.Port,
|
||||||
active: time.Now(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,6 +50,20 @@ func (n *Node) isValid() bool {
|
|||||||
return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
|
return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *Node) bumpActive() {
|
||||||
|
stamp := time.Now().Unix()
|
||||||
|
atomic.StoreInt64(&n.activeStamp, stamp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) active() time.Time {
|
||||||
|
stamp := atomic.LoadInt64(&n.activeStamp)
|
||||||
|
return time.Unix(stamp, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) addr() *net.UDPAddr {
|
||||||
|
return &net.UDPAddr{IP: n.IP, Port: n.DiscPort}
|
||||||
|
}
|
||||||
|
|
||||||
// The string representation of a Node is a URL.
|
// The string representation of a Node is a URL.
|
||||||
// Please see ParseNode for a description of the format.
|
// Please see ParseNode for a description of the format.
|
||||||
func (n *Node) String() string {
|
func (n *Node) String() string {
|
||||||
@ -304,3 +320,26 @@ func randomID(a NodeID, n int) (b NodeID) {
|
|||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nodeDB stores all nodes we know about.
|
||||||
|
type nodeDB struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
byID map[NodeID]*Node
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *nodeDB) get(id NodeID) *Node {
|
||||||
|
db.mu.RLock()
|
||||||
|
defer db.mu.RUnlock()
|
||||||
|
return db.byID[id]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *nodeDB) add(id NodeID, addr *net.UDPAddr, tcpPort uint16) *Node {
|
||||||
|
db.mu.Lock()
|
||||||
|
defer db.mu.Unlock()
|
||||||
|
if db.byID == nil {
|
||||||
|
db.byID = make(map[NodeID]*Node)
|
||||||
|
}
|
||||||
|
n := &Node{ID: id, IP: addr.IP, DiscPort: addr.Port, TCPPort: int(tcpPort)}
|
||||||
|
db.byID[n.ID] = n
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
@ -14,9 +14,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
alpha = 3 // Kademlia concurrency factor
|
alpha = 3 // Kademlia concurrency factor
|
||||||
bucketSize = 16 // Kademlia bucket size
|
bucketSize = 16 // Kademlia bucket size
|
||||||
nBuckets = nodeIDBits + 1 // Number of buckets
|
nBuckets = nodeIDBits + 1 // Number of buckets
|
||||||
|
maxBondingPingPongs = 10
|
||||||
)
|
)
|
||||||
|
|
||||||
type Table struct {
|
type Table struct {
|
||||||
@ -24,27 +25,50 @@ type Table struct {
|
|||||||
buckets [nBuckets]*bucket // index of known nodes by distance
|
buckets [nBuckets]*bucket // index of known nodes by distance
|
||||||
nursery []*Node // bootstrap nodes
|
nursery []*Node // bootstrap nodes
|
||||||
|
|
||||||
|
bondmu sync.Mutex
|
||||||
|
bonding map[NodeID]*bondproc
|
||||||
|
bondslots chan struct{} // limits total number of active bonding processes
|
||||||
|
|
||||||
net transport
|
net transport
|
||||||
self *Node // metadata of the local node
|
self *Node // metadata of the local node
|
||||||
|
db *nodeDB
|
||||||
|
}
|
||||||
|
|
||||||
|
type bondproc struct {
|
||||||
|
err error
|
||||||
|
n *Node
|
||||||
|
done chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// transport is implemented by the UDP transport.
|
// transport is implemented by the UDP transport.
|
||||||
// it is an interface so we can test without opening lots of UDP
|
// it is an interface so we can test without opening lots of UDP
|
||||||
// sockets and without generating a private key.
|
// sockets and without generating a private key.
|
||||||
type transport interface {
|
type transport interface {
|
||||||
ping(*Node) error
|
ping(NodeID, *net.UDPAddr) error
|
||||||
findnode(e *Node, target NodeID) ([]*Node, error)
|
waitping(NodeID) error
|
||||||
|
findnode(toid NodeID, addr *net.UDPAddr, target NodeID) ([]*Node, error)
|
||||||
close()
|
close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// bucket contains nodes, ordered by their last activity.
|
// bucket contains nodes, ordered by their last activity.
|
||||||
|
// the entry that was most recently active is the last element
|
||||||
|
// in entries.
|
||||||
type bucket struct {
|
type bucket struct {
|
||||||
lastLookup time.Time
|
lastLookup time.Time
|
||||||
entries []*Node
|
entries []*Node
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
|
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
|
||||||
tab := &Table{net: t, self: newNode(ourID, ourAddr)}
|
tab := &Table{
|
||||||
|
net: t,
|
||||||
|
db: new(nodeDB),
|
||||||
|
self: newNode(ourID, ourAddr),
|
||||||
|
bonding: make(map[NodeID]*bondproc),
|
||||||
|
bondslots: make(chan struct{}, maxBondingPingPongs),
|
||||||
|
}
|
||||||
|
for i := 0; i < cap(tab.bondslots); i++ {
|
||||||
|
tab.bondslots <- struct{}{}
|
||||||
|
}
|
||||||
for i := range tab.buckets {
|
for i := range tab.buckets {
|
||||||
tab.buckets[i] = new(bucket)
|
tab.buckets[i] = new(bucket)
|
||||||
}
|
}
|
||||||
@ -107,8 +131,8 @@ func (tab *Table) Lookup(target NodeID) []*Node {
|
|||||||
asked[n.ID] = true
|
asked[n.ID] = true
|
||||||
pendingQueries++
|
pendingQueries++
|
||||||
go func() {
|
go func() {
|
||||||
result, _ := tab.net.findnode(n, target)
|
r, _ := tab.net.findnode(n.ID, n.addr(), target)
|
||||||
reply <- result
|
reply <- tab.bondall(r)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -116,13 +140,11 @@ func (tab *Table) Lookup(target NodeID) []*Node {
|
|||||||
// we have asked all closest nodes, stop the search
|
// we have asked all closest nodes, stop the search
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for the next reply
|
// wait for the next reply
|
||||||
for _, n := range <-reply {
|
for _, n := range <-reply {
|
||||||
cn := n
|
if n != nil && !seen[n.ID] {
|
||||||
if !seen[n.ID] {
|
|
||||||
seen[n.ID] = true
|
seen[n.ID] = true
|
||||||
result.push(cn, bucketSize)
|
result.push(n, bucketSize)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pendingQueries--
|
pendingQueries--
|
||||||
@ -145,8 +167,9 @@ func (tab *Table) refresh() {
|
|||||||
result := tab.Lookup(randomID(tab.self.ID, ld))
|
result := tab.Lookup(randomID(tab.self.ID, ld))
|
||||||
if len(result) == 0 {
|
if len(result) == 0 {
|
||||||
// bootstrap the table with a self lookup
|
// bootstrap the table with a self lookup
|
||||||
|
all := tab.bondall(tab.nursery)
|
||||||
tab.mutex.Lock()
|
tab.mutex.Lock()
|
||||||
tab.add(tab.nursery)
|
tab.add(all)
|
||||||
tab.mutex.Unlock()
|
tab.mutex.Unlock()
|
||||||
tab.Lookup(tab.self.ID)
|
tab.Lookup(tab.self.ID)
|
||||||
// TODO: the Kademlia paper says that we're supposed to perform
|
// TODO: the Kademlia paper says that we're supposed to perform
|
||||||
@ -176,45 +199,105 @@ func (tab *Table) len() (n int) {
|
|||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
// bumpOrAdd updates the activity timestamp for the given node and
|
// bondall bonds with all given nodes concurrently and returns
|
||||||
// attempts to insert the node into a bucket. The returned Node might
|
// those nodes for which bonding has probably succeeded.
|
||||||
// not be part of the table. The caller must hold tab.mutex.
|
func (tab *Table) bondall(nodes []*Node) (result []*Node) {
|
||||||
func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) {
|
rc := make(chan *Node, len(nodes))
|
||||||
b := tab.buckets[logdist(tab.self.ID, node)]
|
for i := range nodes {
|
||||||
if n = b.bump(node); n == nil {
|
go func(n *Node) {
|
||||||
n = newNode(node, from)
|
nn, _ := tab.bond(false, n.ID, n.addr(), uint16(n.TCPPort))
|
||||||
if len(b.entries) == bucketSize {
|
rc <- nn
|
||||||
tab.pingReplace(n, b)
|
}(nodes[i])
|
||||||
} else {
|
}
|
||||||
b.entries = append(b.entries, n)
|
for _ = range nodes {
|
||||||
|
if n := <-rc; n != nil {
|
||||||
|
result = append(result, n)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return n
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tab *Table) pingReplace(n *Node, b *bucket) {
|
// bond ensures the local node has a bond with the given remote node.
|
||||||
old := b.entries[bucketSize-1]
|
// It also attempts to insert the node into the table if bonding succeeds.
|
||||||
go func() {
|
// The caller must not hold tab.mutex.
|
||||||
if err := tab.net.ping(old); err == nil {
|
//
|
||||||
// it responded, we don't need to replace it.
|
// A bond is must be established before sending findnode requests.
|
||||||
|
// Both sides must have completed a ping/pong exchange for a bond to
|
||||||
|
// exist. The total number of active bonding processes is limited in
|
||||||
|
// order to restrain network use.
|
||||||
|
//
|
||||||
|
// bond is meant to operate idempotently in that bonding with a remote
|
||||||
|
// node which still remembers a previously established bond will work.
|
||||||
|
// The remote node will simply not send a ping back, causing waitping
|
||||||
|
// to time out.
|
||||||
|
//
|
||||||
|
// If pinged is true, the remote node has just pinged us and one half
|
||||||
|
// of the process can be skipped.
|
||||||
|
func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) (*Node, error) {
|
||||||
|
var n *Node
|
||||||
|
if n = tab.db.get(id); n == nil {
|
||||||
|
tab.bondmu.Lock()
|
||||||
|
w := tab.bonding[id]
|
||||||
|
if w != nil {
|
||||||
|
// Wait for an existing bonding process to complete.
|
||||||
|
tab.bondmu.Unlock()
|
||||||
|
<-w.done
|
||||||
|
} else {
|
||||||
|
// Register a new bonding process.
|
||||||
|
w = &bondproc{done: make(chan struct{})}
|
||||||
|
tab.bonding[id] = w
|
||||||
|
tab.bondmu.Unlock()
|
||||||
|
// Do the ping/pong. The result goes into w.
|
||||||
|
tab.pingpong(w, pinged, id, addr, tcpPort)
|
||||||
|
// Unregister the process after it's done.
|
||||||
|
tab.bondmu.Lock()
|
||||||
|
delete(tab.bonding, id)
|
||||||
|
tab.bondmu.Unlock()
|
||||||
|
}
|
||||||
|
n = w.n
|
||||||
|
if w.err != nil {
|
||||||
|
return nil, w.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tab.mutex.Lock()
|
||||||
|
defer tab.mutex.Unlock()
|
||||||
|
if b := tab.buckets[logdist(tab.self.ID, n.ID)]; !b.bump(n) {
|
||||||
|
tab.pingreplace(n, b)
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) {
|
||||||
|
<-tab.bondslots
|
||||||
|
defer func() { tab.bondslots <- struct{}{} }()
|
||||||
|
if w.err = tab.net.ping(id, addr); w.err != nil {
|
||||||
|
close(w.done)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !pinged {
|
||||||
|
// Give the remote node a chance to ping us before we start
|
||||||
|
// sending findnode requests. If they still remember us,
|
||||||
|
// waitping will simply time out.
|
||||||
|
tab.net.waitping(id)
|
||||||
|
}
|
||||||
|
w.n = tab.db.add(id, addr, tcpPort)
|
||||||
|
close(w.done)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tab *Table) pingreplace(new *Node, b *bucket) {
|
||||||
|
if len(b.entries) == bucketSize {
|
||||||
|
oldest := b.entries[bucketSize-1]
|
||||||
|
if err := tab.net.ping(oldest.ID, oldest.addr()); err == nil {
|
||||||
|
// The node responded, we don't need to replace it.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// it didn't respond, replace the node if it is still the oldest node.
|
} else {
|
||||||
tab.mutex.Lock()
|
// Add a slot at the end so the last entry doesn't
|
||||||
if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old {
|
// fall off when adding the new node.
|
||||||
// slide down other entries and put the new one in front.
|
b.entries = append(b.entries, nil)
|
||||||
// TODO: insert in correct position to keep the order
|
}
|
||||||
copy(b.entries[1:], b.entries)
|
copy(b.entries[1:], b.entries)
|
||||||
b.entries[0] = n
|
b.entries[0] = new
|
||||||
}
|
|
||||||
tab.mutex.Unlock()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// bump updates the activity timestamp for the given node.
|
|
||||||
// The caller must hold tab.mutex.
|
|
||||||
func (tab *Table) bump(node NodeID) {
|
|
||||||
tab.buckets[logdist(tab.self.ID, node)].bump(node)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// add puts the entries into the table if their corresponding
|
// add puts the entries into the table if their corresponding
|
||||||
@ -240,17 +323,17 @@ outer:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *bucket) bump(id NodeID) *Node {
|
func (b *bucket) bump(n *Node) bool {
|
||||||
for i, n := range b.entries {
|
for i := range b.entries {
|
||||||
if n.ID == id {
|
if b.entries[i].ID == n.ID {
|
||||||
n.active = time.Now()
|
n.bumpActive()
|
||||||
// move it to the front
|
// move it to the front
|
||||||
copy(b.entries[1:], b.entries[:i+1])
|
copy(b.entries[1:], b.entries[:i])
|
||||||
b.entries[0] = n
|
b.entries[0] = n
|
||||||
return n
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// nodesByDistance is a list of nodes, ordered by
|
// nodesByDistance is a list of nodes, ordered by
|
||||||
|
@ -2,78 +2,109 @@ package discover
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"testing/quick"
|
"testing/quick"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/crypto"
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTable_bumpOrAddBucketAssign(t *testing.T) {
|
func TestTable_pingReplace(t *testing.T) {
|
||||||
tab := newTable(nil, NodeID{}, &net.UDPAddr{})
|
doit := func(newNodeIsResponding, lastInBucketIsResponding bool) {
|
||||||
for i := 1; i < len(tab.buckets); i++ {
|
transport := newPingRecorder()
|
||||||
tab.bumpOrAdd(randomID(tab.self.ID, i), &net.UDPAddr{})
|
tab := newTable(transport, NodeID{}, &net.UDPAddr{})
|
||||||
}
|
last := fillBucket(tab, 200)
|
||||||
for i, b := range tab.buckets {
|
pingSender := randomID(tab.self.ID, 200)
|
||||||
if i > 0 && len(b.entries) != 1 {
|
|
||||||
t.Errorf("bucket %d has %d entries, want 1", i, len(b.entries))
|
// this gotPing should replace the last node
|
||||||
|
// if the last node is not responding.
|
||||||
|
transport.responding[last.ID] = lastInBucketIsResponding
|
||||||
|
transport.responding[pingSender] = newNodeIsResponding
|
||||||
|
tab.bond(true, pingSender, &net.UDPAddr{}, 0)
|
||||||
|
|
||||||
|
// first ping goes to sender (bonding pingback)
|
||||||
|
if !transport.pinged[pingSender] {
|
||||||
|
t.Error("table did not ping back sender")
|
||||||
|
}
|
||||||
|
if newNodeIsResponding {
|
||||||
|
// second ping goes to oldest node in bucket
|
||||||
|
// to see whether it is still alive.
|
||||||
|
if !transport.pinged[last.ID] {
|
||||||
|
t.Error("table did not ping last node in bucket")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tab.mutex.Lock()
|
||||||
|
defer tab.mutex.Unlock()
|
||||||
|
if l := len(tab.buckets[200].entries); l != bucketSize {
|
||||||
|
t.Errorf("wrong bucket size after gotPing: got %d, want %d", bucketSize, l)
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastInBucketIsResponding || !newNodeIsResponding {
|
||||||
|
if !contains(tab.buckets[200].entries, last.ID) {
|
||||||
|
t.Error("last entry was removed")
|
||||||
|
}
|
||||||
|
if contains(tab.buckets[200].entries, pingSender) {
|
||||||
|
t.Error("new entry was added")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if contains(tab.buckets[200].entries, last.ID) {
|
||||||
|
t.Error("last entry was not removed")
|
||||||
|
}
|
||||||
|
if !contains(tab.buckets[200].entries, pingSender) {
|
||||||
|
t.Error("new entry was not added")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
doit(true, true)
|
||||||
|
doit(false, true)
|
||||||
|
doit(false, true)
|
||||||
|
doit(false, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTable_bumpOrAddPingReplace(t *testing.T) {
|
func TestBucket_bumpNoDuplicates(t *testing.T) {
|
||||||
pingC := make(pingC)
|
t.Parallel()
|
||||||
tab := newTable(pingC, NodeID{}, &net.UDPAddr{})
|
cfg := &quick.Config{
|
||||||
last := fillBucket(tab, 200)
|
MaxCount: 1000,
|
||||||
|
Rand: quickrand,
|
||||||
// this bumpOrAdd should not replace the last node
|
Values: func(args []reflect.Value, rand *rand.Rand) {
|
||||||
// because the node replies to ping.
|
// generate a random list of nodes. this will be the content of the bucket.
|
||||||
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
|
n := rand.Intn(bucketSize-1) + 1
|
||||||
|
nodes := make([]*Node, n)
|
||||||
pinged := <-pingC
|
for i := range nodes {
|
||||||
if pinged != last.ID {
|
nodes[i] = &Node{ID: randomID(NodeID{}, 200)}
|
||||||
t.Fatalf("pinged wrong node: %v\nwant %v", pinged, last.ID)
|
}
|
||||||
|
args[0] = reflect.ValueOf(nodes)
|
||||||
|
// generate random bump positions.
|
||||||
|
bumps := make([]int, rand.Intn(100))
|
||||||
|
for i := range bumps {
|
||||||
|
bumps[i] = rand.Intn(len(nodes))
|
||||||
|
}
|
||||||
|
args[1] = reflect.ValueOf(bumps)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
tab.mutex.Lock()
|
prop := func(nodes []*Node, bumps []int) (ok bool) {
|
||||||
defer tab.mutex.Unlock()
|
b := &bucket{entries: make([]*Node, len(nodes))}
|
||||||
if l := len(tab.buckets[200].entries); l != bucketSize {
|
copy(b.entries, nodes)
|
||||||
t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
|
for i, pos := range bumps {
|
||||||
|
b.bump(b.entries[pos])
|
||||||
|
if hasDuplicates(b.entries) {
|
||||||
|
t.Logf("bucket has duplicates after %d/%d bumps:", i+1, len(bumps))
|
||||||
|
for _, n := range b.entries {
|
||||||
|
t.Logf(" %p", n)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
if !contains(tab.buckets[200].entries, last.ID) {
|
if err := quick.Check(prop, cfg); err != nil {
|
||||||
t.Error("last entry was removed")
|
t.Error(err)
|
||||||
}
|
|
||||||
if contains(tab.buckets[200].entries, new.ID) {
|
|
||||||
t.Error("new entry was added")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTable_bumpOrAddPingTimeout(t *testing.T) {
|
|
||||||
tab := newTable(pingC(nil), NodeID{}, &net.UDPAddr{})
|
|
||||||
last := fillBucket(tab, 200)
|
|
||||||
|
|
||||||
// this bumpOrAdd should replace the last node
|
|
||||||
// because the node does not reply to ping.
|
|
||||||
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
|
|
||||||
|
|
||||||
// wait for async bucket update. damn. this needs to go away.
|
|
||||||
time.Sleep(2 * time.Millisecond)
|
|
||||||
|
|
||||||
tab.mutex.Lock()
|
|
||||||
defer tab.mutex.Unlock()
|
|
||||||
if l := len(tab.buckets[200].entries); l != bucketSize {
|
|
||||||
t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
|
|
||||||
}
|
|
||||||
if contains(tab.buckets[200].entries, last.ID) {
|
|
||||||
t.Error("last entry was not removed")
|
|
||||||
}
|
|
||||||
if !contains(tab.buckets[200].entries, new.ID) {
|
|
||||||
t.Error("new entry was not added")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -85,44 +116,27 @@ func fillBucket(tab *Table, ld int) (last *Node) {
|
|||||||
return b.entries[bucketSize-1]
|
return b.entries[bucketSize-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
type pingC chan NodeID
|
type pingRecorder struct{ responding, pinged map[NodeID]bool }
|
||||||
|
|
||||||
func (t pingC) findnode(n *Node, target NodeID) ([]*Node, error) {
|
func newPingRecorder() *pingRecorder {
|
||||||
|
return &pingRecorder{make(map[NodeID]bool), make(map[NodeID]bool)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
|
||||||
panic("findnode called on pingRecorder")
|
panic("findnode called on pingRecorder")
|
||||||
}
|
}
|
||||||
func (t pingC) close() {
|
func (t *pingRecorder) close() {
|
||||||
panic("close called on pingRecorder")
|
panic("close called on pingRecorder")
|
||||||
}
|
}
|
||||||
func (t pingC) ping(n *Node) error {
|
func (t *pingRecorder) waitping(from NodeID) error {
|
||||||
if t == nil {
|
return nil // remote always pings
|
||||||
return errTimeout
|
|
||||||
}
|
|
||||||
t <- n.ID
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
|
||||||
func TestTable_bump(t *testing.T) {
|
t.pinged[toid] = true
|
||||||
tab := newTable(nil, NodeID{}, &net.UDPAddr{})
|
if t.responding[toid] {
|
||||||
|
return nil
|
||||||
// add an old entry and two recent ones
|
} else {
|
||||||
oldactive := time.Now().Add(-2 * time.Minute)
|
return errTimeout
|
||||||
old := &Node{ID: randomID(tab.self.ID, 200), active: oldactive}
|
|
||||||
others := []*Node{
|
|
||||||
&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
|
|
||||||
&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
|
|
||||||
}
|
|
||||||
tab.add(append(others, old))
|
|
||||||
if tab.buckets[200].entries[0] == old {
|
|
||||||
t.Fatal("old entry is at front of bucket")
|
|
||||||
}
|
|
||||||
|
|
||||||
// bumping the old entry should move it to the front
|
|
||||||
tab.bump(old.ID)
|
|
||||||
if old.active == oldactive {
|
|
||||||
t.Error("activity timestamp not updated")
|
|
||||||
}
|
|
||||||
if tab.buckets[200].entries[0] != old {
|
|
||||||
t.Errorf("bumped entry did not move to the front of bucket")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -210,7 +224,7 @@ func TestTable_Lookup(t *testing.T) {
|
|||||||
t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
|
t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
|
||||||
}
|
}
|
||||||
// seed table with initial node (otherwise lookup will terminate immediately)
|
// seed table with initial node (otherwise lookup will terminate immediately)
|
||||||
tab.bumpOrAdd(randomID(target, 200), &net.UDPAddr{Port: 200})
|
tab.add([]*Node{newNode(randomID(target, 200), &net.UDPAddr{Port: 200})})
|
||||||
|
|
||||||
results := tab.Lookup(target)
|
results := tab.Lookup(target)
|
||||||
t.Logf("results:")
|
t.Logf("results:")
|
||||||
@ -238,16 +252,16 @@ type findnodeOracle struct {
|
|||||||
target NodeID
|
target NodeID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
|
func (t findnodeOracle) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
|
||||||
t.t.Logf("findnode query at dist %d", n.DiscPort)
|
t.t.Logf("findnode query at dist %d", toaddr.Port)
|
||||||
// current log distance is encoded in port number
|
// current log distance is encoded in port number
|
||||||
var result []*Node
|
var result []*Node
|
||||||
switch n.DiscPort {
|
switch toaddr.Port {
|
||||||
case 0:
|
case 0:
|
||||||
panic("query to node at distance 0")
|
panic("query to node at distance 0")
|
||||||
default:
|
default:
|
||||||
// TODO: add more randomness to distances
|
// TODO: add more randomness to distances
|
||||||
next := n.DiscPort - 1
|
next := toaddr.Port - 1
|
||||||
for i := 0; i < bucketSize; i++ {
|
for i := 0; i < bucketSize; i++ {
|
||||||
result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next})
|
result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next})
|
||||||
}
|
}
|
||||||
@ -255,11 +269,9 @@ func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t findnodeOracle) close() {}
|
func (t findnodeOracle) close() {}
|
||||||
|
func (t findnodeOracle) waitping(from NodeID) error { return nil }
|
||||||
func (t findnodeOracle) ping(n *Node) error {
|
func (t findnodeOracle) ping(toid NodeID, toaddr *net.UDPAddr) error { return nil }
|
||||||
return errors.New("ping is not supported by this transport")
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasDuplicates(slice []*Node) bool {
|
func hasDuplicates(slice []*Node) bool {
|
||||||
seen := make(map[NodeID]bool)
|
seen := make(map[NodeID]bool)
|
||||||
|
@ -16,13 +16,18 @@ import (
|
|||||||
|
|
||||||
var log = logger.NewLogger("P2P Discovery")
|
var log = logger.NewLogger("P2P Discovery")
|
||||||
|
|
||||||
|
const Version = 3
|
||||||
|
|
||||||
// Errors
|
// Errors
|
||||||
var (
|
var (
|
||||||
errPacketTooSmall = errors.New("too small")
|
errPacketTooSmall = errors.New("too small")
|
||||||
errBadHash = errors.New("bad hash")
|
errBadHash = errors.New("bad hash")
|
||||||
errExpired = errors.New("expired")
|
errExpired = errors.New("expired")
|
||||||
errTimeout = errors.New("RPC timeout")
|
errBadVersion = errors.New("version mismatch")
|
||||||
errClosed = errors.New("socket closed")
|
errUnsolicitedReply = errors.New("unsolicited reply")
|
||||||
|
errUnknownNode = errors.New("unknown node")
|
||||||
|
errTimeout = errors.New("RPC timeout")
|
||||||
|
errClosed = errors.New("socket closed")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Timeouts
|
// Timeouts
|
||||||
@ -45,6 +50,7 @@ const (
|
|||||||
// RPC request structures
|
// RPC request structures
|
||||||
type (
|
type (
|
||||||
ping struct {
|
ping struct {
|
||||||
|
Version uint // must match Version
|
||||||
IP string // our IP
|
IP string // our IP
|
||||||
Port uint16 // our port
|
Port uint16 // our port
|
||||||
Expiration uint64
|
Expiration uint64
|
||||||
@ -76,14 +82,27 @@ type rpcNode struct {
|
|||||||
ID NodeID
|
ID NodeID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type packet interface {
|
||||||
|
handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type conn interface {
|
||||||
|
ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
|
||||||
|
WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
|
||||||
|
Close() error
|
||||||
|
LocalAddr() net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
// udp implements the RPC protocol.
|
// udp implements the RPC protocol.
|
||||||
type udp struct {
|
type udp struct {
|
||||||
conn *net.UDPConn
|
conn conn
|
||||||
priv *ecdsa.PrivateKey
|
priv *ecdsa.PrivateKey
|
||||||
|
|
||||||
addpending chan *pending
|
addpending chan *pending
|
||||||
replies chan reply
|
gotreply chan reply
|
||||||
closing chan struct{}
|
|
||||||
nat nat.Interface
|
closing chan struct{}
|
||||||
|
nat nat.Interface
|
||||||
|
|
||||||
*Table
|
*Table
|
||||||
}
|
}
|
||||||
@ -120,6 +139,9 @@ type reply struct {
|
|||||||
from NodeID
|
from NodeID
|
||||||
ptype byte
|
ptype byte
|
||||||
data interface{}
|
data interface{}
|
||||||
|
// loop indicates whether there was
|
||||||
|
// a matching request by sending on this channel.
|
||||||
|
matched chan<- bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenUDP returns a new table that listens for UDP packets on laddr.
|
// ListenUDP returns a new table that listens for UDP packets on laddr.
|
||||||
@ -132,15 +154,20 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
tab, _ := newUDP(priv, conn, natm)
|
||||||
|
log.Infoln("Listening,", tab.self)
|
||||||
|
return tab, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface) (*Table, *udp) {
|
||||||
udp := &udp{
|
udp := &udp{
|
||||||
conn: conn,
|
conn: c,
|
||||||
priv: priv,
|
priv: priv,
|
||||||
closing: make(chan struct{}),
|
closing: make(chan struct{}),
|
||||||
|
gotreply: make(chan reply),
|
||||||
addpending: make(chan *pending),
|
addpending: make(chan *pending),
|
||||||
replies: make(chan reply),
|
|
||||||
}
|
}
|
||||||
|
realaddr := c.LocalAddr().(*net.UDPAddr)
|
||||||
realaddr := conn.LocalAddr().(*net.UDPAddr)
|
|
||||||
if natm != nil {
|
if natm != nil {
|
||||||
if !realaddr.IP.IsLoopback() {
|
if !realaddr.IP.IsLoopback() {
|
||||||
go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
|
go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
|
||||||
@ -151,11 +178,9 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr)
|
udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr)
|
||||||
|
|
||||||
go udp.loop()
|
go udp.loop()
|
||||||
go udp.readLoop()
|
go udp.readLoop()
|
||||||
log.Infoln("Listening, ", udp.self)
|
return udp.Table, udp
|
||||||
return udp.Table, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *udp) close() {
|
func (t *udp) close() {
|
||||||
@ -165,10 +190,11 @@ func (t *udp) close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ping sends a ping message to the given node and waits for a reply.
|
// ping sends a ping message to the given node and waits for a reply.
|
||||||
func (t *udp) ping(e *Node) error {
|
func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
|
||||||
// TODO: maybe check for ReplyTo field in callback to measure RTT
|
// TODO: maybe check for ReplyTo field in callback to measure RTT
|
||||||
errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
|
errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
|
||||||
t.send(e, pingPacket, ping{
|
t.send(toaddr, pingPacket, ping{
|
||||||
|
Version: Version,
|
||||||
IP: t.self.IP.String(),
|
IP: t.self.IP.String(),
|
||||||
Port: uint16(t.self.TCPPort),
|
Port: uint16(t.self.TCPPort),
|
||||||
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||||
@ -176,12 +202,16 @@ func (t *udp) ping(e *Node) error {
|
|||||||
return <-errc
|
return <-errc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *udp) waitping(from NodeID) error {
|
||||||
|
return <-t.pending(from, pingPacket, func(interface{}) bool { return true })
|
||||||
|
}
|
||||||
|
|
||||||
// findnode sends a findnode request to the given node and waits until
|
// findnode sends a findnode request to the given node and waits until
|
||||||
// the node has sent up to k neighbors.
|
// the node has sent up to k neighbors.
|
||||||
func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
|
func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
|
||||||
nodes := make([]*Node, 0, bucketSize)
|
nodes := make([]*Node, 0, bucketSize)
|
||||||
nreceived := 0
|
nreceived := 0
|
||||||
errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool {
|
errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
|
||||||
reply := r.(*neighbors)
|
reply := r.(*neighbors)
|
||||||
for _, n := range reply.Nodes {
|
for _, n := range reply.Nodes {
|
||||||
nreceived++
|
nreceived++
|
||||||
@ -191,8 +221,7 @@ func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
|
|||||||
}
|
}
|
||||||
return nreceived >= bucketSize
|
return nreceived >= bucketSize
|
||||||
})
|
})
|
||||||
|
t.send(toaddr, findnodePacket, findnode{
|
||||||
t.send(to, findnodePacket, findnode{
|
|
||||||
Target: target,
|
Target: target,
|
||||||
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||||
})
|
})
|
||||||
@ -214,6 +243,17 @@ func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-
|
|||||||
return ch
|
return ch
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool {
|
||||||
|
matched := make(chan bool)
|
||||||
|
select {
|
||||||
|
case t.gotreply <- reply{from, ptype, req, matched}:
|
||||||
|
// loop will handle it
|
||||||
|
return <-matched
|
||||||
|
case <-t.closing:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// loop runs in its own goroutin. it keeps track of
|
// loop runs in its own goroutin. it keeps track of
|
||||||
// the refresh timer and the pending reply queue.
|
// the refresh timer and the pending reply queue.
|
||||||
func (t *udp) loop() {
|
func (t *udp) loop() {
|
||||||
@ -244,6 +284,7 @@ func (t *udp) loop() {
|
|||||||
for _, p := range pending {
|
for _, p := range pending {
|
||||||
p.errc <- errClosed
|
p.errc <- errClosed
|
||||||
}
|
}
|
||||||
|
pending = nil
|
||||||
return
|
return
|
||||||
|
|
||||||
case p := <-t.addpending:
|
case p := <-t.addpending:
|
||||||
@ -251,18 +292,21 @@ func (t *udp) loop() {
|
|||||||
pending = append(pending, p)
|
pending = append(pending, p)
|
||||||
rearmTimeout()
|
rearmTimeout()
|
||||||
|
|
||||||
case reply := <-t.replies:
|
case r := <-t.gotreply:
|
||||||
// run matching callbacks, remove if they return false.
|
var matched bool
|
||||||
for i := 0; i < len(pending); i++ {
|
for i := 0; i < len(pending); i++ {
|
||||||
p := pending[i]
|
if p := pending[i]; p.from == r.from && p.ptype == r.ptype {
|
||||||
if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) {
|
matched = true
|
||||||
p.errc <- nil
|
if p.callback(r.data) {
|
||||||
copy(pending[i:], pending[i+1:])
|
// callback indicates the request is done, remove it.
|
||||||
pending = pending[:len(pending)-1]
|
p.errc <- nil
|
||||||
i--
|
copy(pending[i:], pending[i+1:])
|
||||||
|
pending = pending[:len(pending)-1]
|
||||||
|
i--
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rearmTimeout()
|
r.matched <- matched
|
||||||
|
|
||||||
case now := <-timeout.C:
|
case now := <-timeout.C:
|
||||||
// notify and remove callbacks whose deadline is in the past.
|
// notify and remove callbacks whose deadline is in the past.
|
||||||
@ -287,28 +331,11 @@ const (
|
|||||||
|
|
||||||
var headSpace = make([]byte, headSize)
|
var headSpace = make([]byte, headSize)
|
||||||
|
|
||||||
func (t *udp) send(to *Node, ptype byte, req interface{}) error {
|
func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req interface{}) error {
|
||||||
b := new(bytes.Buffer)
|
packet, err := encodePacket(t.priv, ptype, req)
|
||||||
b.Write(headSpace)
|
|
||||||
b.WriteByte(ptype)
|
|
||||||
if err := rlp.Encode(b, req); err != nil {
|
|
||||||
log.Errorln("error encoding packet:", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
packet := b.Bytes()
|
|
||||||
sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorln("could not sign packet:", err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
copy(packet[macSize:], sig)
|
|
||||||
// add the hash to the front. Note: this doesn't protect the
|
|
||||||
// packet in any way. Our public key will be part of this hash in
|
|
||||||
// the future.
|
|
||||||
copy(packet, crypto.Sha3(packet[macSize:]))
|
|
||||||
|
|
||||||
toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort}
|
|
||||||
log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
|
log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
|
||||||
if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
|
if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
|
||||||
log.DebugDetailln("UDP send failed:", err)
|
log.DebugDetailln("UDP send failed:", err)
|
||||||
@ -316,6 +343,28 @@ func (t *udp) send(to *Node, ptype byte, req interface{}) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) {
|
||||||
|
b := new(bytes.Buffer)
|
||||||
|
b.Write(headSpace)
|
||||||
|
b.WriteByte(ptype)
|
||||||
|
if err := rlp.Encode(b, req); err != nil {
|
||||||
|
log.Errorln("error encoding packet:", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
packet := b.Bytes()
|
||||||
|
sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), priv)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorln("could not sign packet:", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
copy(packet[macSize:], sig)
|
||||||
|
// add the hash to the front. Note: this doesn't protect the
|
||||||
|
// packet in any way. Our public key will be part of this hash in
|
||||||
|
// The future.
|
||||||
|
copy(packet, crypto.Sha3(packet[macSize:]))
|
||||||
|
return packet, nil
|
||||||
|
}
|
||||||
|
|
||||||
// readLoop runs in its own goroutine. it handles incoming UDP packets.
|
// readLoop runs in its own goroutine. it handles incoming UDP packets.
|
||||||
func (t *udp) readLoop() {
|
func (t *udp) readLoop() {
|
||||||
defer t.conn.Close()
|
defer t.conn.Close()
|
||||||
@ -325,29 +374,34 @@ func (t *udp) readLoop() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := t.packetIn(from, buf[:nbytes]); err != nil {
|
packet, fromID, hash, err := decodePacket(buf[:nbytes])
|
||||||
|
if err != nil {
|
||||||
log.Debugf("Bad packet from %v: %v\n", from, err)
|
log.Debugf("Bad packet from %v: %v\n", from, err)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
log.DebugDetailf("<<< %v %T %v\n", from, packet, packet)
|
||||||
|
go func() {
|
||||||
|
if err := packet.handle(t, from, fromID, hash); err != nil {
|
||||||
|
log.Debugf("error handling %T from %v: %v", packet, from, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
|
func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
|
||||||
if len(buf) < headSize+1 {
|
if len(buf) < headSize+1 {
|
||||||
return errPacketTooSmall
|
return nil, NodeID{}, nil, errPacketTooSmall
|
||||||
}
|
}
|
||||||
hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
|
hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
|
||||||
shouldhash := crypto.Sha3(buf[macSize:])
|
shouldhash := crypto.Sha3(buf[macSize:])
|
||||||
if !bytes.Equal(hash, shouldhash) {
|
if !bytes.Equal(hash, shouldhash) {
|
||||||
return errBadHash
|
return nil, NodeID{}, nil, errBadHash
|
||||||
}
|
}
|
||||||
fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
|
fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, NodeID{}, hash, err
|
||||||
}
|
|
||||||
|
|
||||||
var req interface {
|
|
||||||
handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
|
|
||||||
}
|
}
|
||||||
|
var req packet
|
||||||
switch ptype := sigdata[0]; ptype {
|
switch ptype := sigdata[0]; ptype {
|
||||||
case pingPacket:
|
case pingPacket:
|
||||||
req = new(ping)
|
req = new(ping)
|
||||||
@ -358,31 +412,27 @@ func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
|
|||||||
case neighborsPacket:
|
case neighborsPacket:
|
||||||
req = new(neighbors)
|
req = new(neighbors)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unknown type: %d", ptype)
|
return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype)
|
||||||
}
|
}
|
||||||
if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil {
|
err = rlp.Decode(bytes.NewReader(sigdata[1:]), req)
|
||||||
return err
|
return req, fromID, hash, err
|
||||||
}
|
|
||||||
log.DebugDetailf("<<< %v %T %v\n", from, req, req)
|
|
||||||
return req.handle(t, from, fromID, hash)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
|
func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
|
||||||
if expired(req.Expiration) {
|
if expired(req.Expiration) {
|
||||||
return errExpired
|
return errExpired
|
||||||
}
|
}
|
||||||
t.mutex.Lock()
|
if req.Version != Version {
|
||||||
// Note: we're ignoring the provided IP address right now
|
return errBadVersion
|
||||||
n := t.bumpOrAdd(fromID, from)
|
|
||||||
if req.Port != 0 {
|
|
||||||
n.TCPPort = int(req.Port)
|
|
||||||
}
|
}
|
||||||
t.mutex.Unlock()
|
t.send(from, pongPacket, pong{
|
||||||
|
|
||||||
t.send(n, pongPacket, pong{
|
|
||||||
ReplyTok: mac,
|
ReplyTok: mac,
|
||||||
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||||
})
|
})
|
||||||
|
if !t.handleReply(fromID, pingPacket, req) {
|
||||||
|
// Note: we're ignoring the provided IP address right now
|
||||||
|
t.bond(true, fromID, from, req.Port)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -390,11 +440,9 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
|
|||||||
if expired(req.Expiration) {
|
if expired(req.Expiration) {
|
||||||
return errExpired
|
return errExpired
|
||||||
}
|
}
|
||||||
t.mutex.Lock()
|
if !t.handleReply(fromID, pongPacket, req) {
|
||||||
t.bump(fromID)
|
return errUnsolicitedReply
|
||||||
t.mutex.Unlock()
|
}
|
||||||
|
|
||||||
t.replies <- reply{fromID, pongPacket, req}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -402,12 +450,21 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
|
|||||||
if expired(req.Expiration) {
|
if expired(req.Expiration) {
|
||||||
return errExpired
|
return errExpired
|
||||||
}
|
}
|
||||||
|
if t.db.get(fromID) == nil {
|
||||||
|
// No bond exists, we don't process the packet. This prevents
|
||||||
|
// an attack vector where the discovery protocol could be used
|
||||||
|
// to amplify traffic in a DDOS attack. A malicious actor
|
||||||
|
// would send a findnode request with the IP address and UDP
|
||||||
|
// port of the target as the source address. The recipient of
|
||||||
|
// the findnode packet would then send a neighbors packet
|
||||||
|
// (which is a much bigger packet than findnode) to the victim.
|
||||||
|
return errUnknownNode
|
||||||
|
}
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
e := t.bumpOrAdd(fromID, from)
|
|
||||||
closest := t.closest(req.Target, bucketSize).entries
|
closest := t.closest(req.Target, bucketSize).entries
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
t.send(e, neighborsPacket, neighbors{
|
t.send(from, neighborsPacket, neighbors{
|
||||||
Nodes: closest,
|
Nodes: closest,
|
||||||
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||||
})
|
})
|
||||||
@ -418,12 +475,9 @@ func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byt
|
|||||||
if expired(req.Expiration) {
|
if expired(req.Expiration) {
|
||||||
return errExpired
|
return errExpired
|
||||||
}
|
}
|
||||||
t.mutex.Lock()
|
if !t.handleReply(fromID, neighborsPacket, req) {
|
||||||
t.bump(fromID)
|
return errUnsolicitedReply
|
||||||
t.add(req.Nodes)
|
}
|
||||||
t.mutex.Unlock()
|
|
||||||
|
|
||||||
t.replies <- reply{fromID, neighborsPacket, req}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,10 +1,18 @@
|
|||||||
package discover
|
package discover
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
logpkg "log"
|
logpkg "log"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"path"
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -15,22 +23,243 @@ func init() {
|
|||||||
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
|
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUDP_ping(t *testing.T) {
|
type udpTest struct {
|
||||||
|
t *testing.T
|
||||||
|
pipe *dgramPipe
|
||||||
|
table *Table
|
||||||
|
udp *udp
|
||||||
|
sent [][]byte
|
||||||
|
localkey, remotekey *ecdsa.PrivateKey
|
||||||
|
remoteaddr *net.UDPAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUDPTest(t *testing.T) *udpTest {
|
||||||
|
test := &udpTest{
|
||||||
|
t: t,
|
||||||
|
pipe: newpipe(),
|
||||||
|
localkey: newkey(),
|
||||||
|
remotekey: newkey(),
|
||||||
|
remoteaddr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 30303},
|
||||||
|
}
|
||||||
|
test.table, test.udp = newUDP(test.localkey, test.pipe, nil)
|
||||||
|
return test
|
||||||
|
}
|
||||||
|
|
||||||
|
// handles a packet as if it had been sent to the transport.
|
||||||
|
func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
|
||||||
|
enc, err := encodePacket(test.remotekey, ptype, data)
|
||||||
|
if err != nil {
|
||||||
|
return test.errorf("packet (%d) encode error: %v", err)
|
||||||
|
}
|
||||||
|
test.sent = append(test.sent, enc)
|
||||||
|
err = data.handle(test.udp, test.remoteaddr, PubkeyID(&test.remotekey.PublicKey), enc[:macSize])
|
||||||
|
if err != wantError {
|
||||||
|
return test.errorf("error mismatch: got %q, want %q", err, wantError)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// waits for a packet to be sent by the transport.
|
||||||
|
// validate should have type func(*udpTest, X) error, where X is a packet type.
|
||||||
|
func (test *udpTest) waitPacketOut(validate interface{}) error {
|
||||||
|
dgram := test.pipe.waitPacketOut()
|
||||||
|
p, _, _, err := decodePacket(dgram)
|
||||||
|
if err != nil {
|
||||||
|
return test.errorf("sent packet decode error: %v", err)
|
||||||
|
}
|
||||||
|
fn := reflect.ValueOf(validate)
|
||||||
|
exptype := fn.Type().In(0)
|
||||||
|
if reflect.TypeOf(p) != exptype {
|
||||||
|
return test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
|
||||||
|
}
|
||||||
|
fn.Call([]reflect.Value{reflect.ValueOf(p)})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (test *udpTest) errorf(format string, args ...interface{}) error {
|
||||||
|
_, file, line, ok := runtime.Caller(2) // errorf + waitPacketOut
|
||||||
|
if ok {
|
||||||
|
file = path.Base(file)
|
||||||
|
} else {
|
||||||
|
file = "???"
|
||||||
|
line = 1
|
||||||
|
}
|
||||||
|
err := fmt.Errorf(format, args...)
|
||||||
|
fmt.Printf("\t%s:%d: %v\n", file, line, err)
|
||||||
|
test.t.Fail()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// shared test variables
|
||||||
|
var (
|
||||||
|
futureExp = uint64(time.Now().Add(10 * time.Hour).Unix())
|
||||||
|
testTarget = MustHexID("01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101")
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUDP_packetErrors(t *testing.T) {
|
||||||
|
test := newUDPTest(t)
|
||||||
|
defer test.table.Close()
|
||||||
|
|
||||||
|
test.packetIn(errExpired, pingPacket, &ping{IP: "foo", Port: 99, Version: Version})
|
||||||
|
test.packetIn(errBadVersion, pingPacket, &ping{IP: "foo", Port: 99, Version: 99, Expiration: futureExp})
|
||||||
|
test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp})
|
||||||
|
test.packetIn(errUnknownNode, findnodePacket, &findnode{Expiration: futureExp})
|
||||||
|
test.packetIn(errUnsolicitedReply, neighborsPacket, &neighbors{Expiration: futureExp})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDP_pingTimeout(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
test := newUDPTest(t)
|
||||||
|
defer test.table.Close()
|
||||||
|
|
||||||
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
|
||||||
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
toid := NodeID{1, 2, 3, 4}
|
||||||
defer n1.Close()
|
if err := test.udp.ping(toid, toaddr); err != errTimeout {
|
||||||
defer n2.Close()
|
t.Error("expected timeout error, got", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := n1.net.ping(n2.self); err != nil {
|
func TestUDP_findnodeTimeout(t *testing.T) {
|
||||||
t.Fatalf("ping error: %v", err)
|
t.Parallel()
|
||||||
|
test := newUDPTest(t)
|
||||||
|
defer test.table.Close()
|
||||||
|
|
||||||
|
toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
|
||||||
|
toid := NodeID{1, 2, 3, 4}
|
||||||
|
target := NodeID{4, 5, 6, 7}
|
||||||
|
result, err := test.udp.findnode(toid, toaddr, target)
|
||||||
|
if err != errTimeout {
|
||||||
|
t.Error("expected timeout error, got", err)
|
||||||
}
|
}
|
||||||
if find(n2, n1.self.ID) == nil {
|
if len(result) > 0 {
|
||||||
t.Errorf("node 2 does not contain id of node 1")
|
t.Error("expected empty result, got", result)
|
||||||
}
|
}
|
||||||
if e := find(n1, n2.self.ID); e != nil {
|
}
|
||||||
t.Errorf("node 1 does contains id of node 2: %v", e)
|
|
||||||
|
func TestUDP_findnode(t *testing.T) {
|
||||||
|
test := newUDPTest(t)
|
||||||
|
defer test.table.Close()
|
||||||
|
|
||||||
|
// put a few nodes into the table. their exact
|
||||||
|
// distribution shouldn't matter much, altough we need to
|
||||||
|
// take care not to overflow any bucket.
|
||||||
|
target := testTarget
|
||||||
|
nodes := &nodesByDistance{target: target}
|
||||||
|
for i := 0; i < bucketSize; i++ {
|
||||||
|
nodes.push(&Node{
|
||||||
|
IP: net.IP{1, 2, 3, byte(i)},
|
||||||
|
DiscPort: i + 2,
|
||||||
|
TCPPort: i + 2,
|
||||||
|
ID: randomID(test.table.self.ID, i+2),
|
||||||
|
}, bucketSize)
|
||||||
|
}
|
||||||
|
test.table.add(nodes.entries)
|
||||||
|
|
||||||
|
// ensure there's a bond with the test node,
|
||||||
|
// findnode won't be accepted otherwise.
|
||||||
|
test.table.db.add(PubkeyID(&test.remotekey.PublicKey), test.remoteaddr, 99)
|
||||||
|
|
||||||
|
// check that closest neighbors are returned.
|
||||||
|
test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
|
||||||
|
test.waitPacketOut(func(p *neighbors) {
|
||||||
|
expected := test.table.closest(testTarget, bucketSize)
|
||||||
|
if len(p.Nodes) != bucketSize {
|
||||||
|
t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize)
|
||||||
|
}
|
||||||
|
for i := range p.Nodes {
|
||||||
|
if p.Nodes[i].ID != expected.entries[i].ID {
|
||||||
|
t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, p.Nodes[i], expected.entries[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDP_findnodeMultiReply(t *testing.T) {
|
||||||
|
test := newUDPTest(t)
|
||||||
|
defer test.table.Close()
|
||||||
|
|
||||||
|
// queue a pending findnode request
|
||||||
|
resultc, errc := make(chan []*Node), make(chan error)
|
||||||
|
go func() {
|
||||||
|
rid := PubkeyID(&test.remotekey.PublicKey)
|
||||||
|
ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget)
|
||||||
|
if err != nil && len(ns) == 0 {
|
||||||
|
errc <- err
|
||||||
|
} else {
|
||||||
|
resultc <- ns
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// wait for the findnode to be sent.
|
||||||
|
// after it is sent, the transport is waiting for a reply
|
||||||
|
test.waitPacketOut(func(p *findnode) {
|
||||||
|
if p.Target != testTarget {
|
||||||
|
t.Errorf("wrong target: got %v, want %v", p.Target, testTarget)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// send the reply as two packets.
|
||||||
|
list := []*Node{
|
||||||
|
MustParseNode("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303"),
|
||||||
|
MustParseNode("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303"),
|
||||||
|
MustParseNode("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301"),
|
||||||
|
MustParseNode("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303"),
|
||||||
|
}
|
||||||
|
test.packetIn(nil, neighborsPacket, &neighbors{Expiration: futureExp, Nodes: list[:2]})
|
||||||
|
test.packetIn(nil, neighborsPacket, &neighbors{Expiration: futureExp, Nodes: list[2:]})
|
||||||
|
|
||||||
|
// check that the sent neighbors are all returned by findnode
|
||||||
|
select {
|
||||||
|
case result := <-resultc:
|
||||||
|
if !reflect.DeepEqual(result, list) {
|
||||||
|
t.Errorf("neighbors mismatch:\n got: %v\n want: %v", result, list)
|
||||||
|
}
|
||||||
|
case err := <-errc:
|
||||||
|
t.Errorf("findnode error: %v", err)
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Error("findnode did not return within 5 seconds")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDP_successfulPing(t *testing.T) {
|
||||||
|
test := newUDPTest(t)
|
||||||
|
defer test.table.Close()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
test.packetIn(nil, pingPacket, &ping{IP: "foo", Port: 99, Version: Version, Expiration: futureExp})
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// the ping is replied to.
|
||||||
|
test.waitPacketOut(func(p *pong) {
|
||||||
|
pinghash := test.sent[0][:macSize]
|
||||||
|
if !bytes.Equal(p.ReplyTok, pinghash) {
|
||||||
|
t.Errorf("got ReplyTok %x, want %x", p.ReplyTok, pinghash)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// remote is unknown, the table pings back.
|
||||||
|
test.waitPacketOut(func(p *ping) error { return nil })
|
||||||
|
test.packetIn(nil, pongPacket, &pong{Expiration: futureExp})
|
||||||
|
|
||||||
|
// ping should return shortly after getting the pong packet.
|
||||||
|
<-done
|
||||||
|
|
||||||
|
// check that the node was added.
|
||||||
|
rid := PubkeyID(&test.remotekey.PublicKey)
|
||||||
|
rnode := find(test.table, rid)
|
||||||
|
if rnode == nil {
|
||||||
|
t.Fatalf("node %v not found in table", rid)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(rnode.IP, test.remoteaddr.IP) {
|
||||||
|
t.Errorf("node has wrong IP: got %v, want: %v", rnode.IP, test.remoteaddr.IP)
|
||||||
|
}
|
||||||
|
if rnode.DiscPort != test.remoteaddr.Port {
|
||||||
|
t.Errorf("node has wrong Port: got %v, want: %v", rnode.DiscPort, test.remoteaddr.Port)
|
||||||
|
}
|
||||||
|
if rnode.TCPPort != 99 {
|
||||||
|
t.Errorf("node has wrong Port: got %v, want: %v", rnode.TCPPort, 99)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,167 +274,66 @@ func find(tab *Table, id NodeID) *Node {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUDP_findnode(t *testing.T) {
|
// dgramPipe is a fake UDP socket. It queues all sent datagrams.
|
||||||
t.Parallel()
|
type dgramPipe struct {
|
||||||
|
mu *sync.Mutex
|
||||||
|
cond *sync.Cond
|
||||||
|
closing chan struct{}
|
||||||
|
closed bool
|
||||||
|
queue [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
func newpipe() *dgramPipe {
|
||||||
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
mu := new(sync.Mutex)
|
||||||
defer n1.Close()
|
return &dgramPipe{
|
||||||
defer n2.Close()
|
closing: make(chan struct{}),
|
||||||
|
cond: &sync.Cond{L: mu},
|
||||||
// put a few nodes into n2. the exact distribution shouldn't
|
mu: mu,
|
||||||
// matter much, altough we need to take care not to overflow
|
|
||||||
// any bucket.
|
|
||||||
target := randomID(n1.self.ID, 100)
|
|
||||||
nodes := &nodesByDistance{target: target}
|
|
||||||
for i := 0; i < bucketSize; i++ {
|
|
||||||
n2.add([]*Node{&Node{
|
|
||||||
IP: net.IP{1, 2, 3, byte(i)},
|
|
||||||
DiscPort: i + 2,
|
|
||||||
TCPPort: i + 2,
|
|
||||||
ID: randomID(n2.self.ID, i+2),
|
|
||||||
}})
|
|
||||||
}
|
|
||||||
n2.add(nodes.entries)
|
|
||||||
n2.bumpOrAdd(n1.self.ID, &net.UDPAddr{IP: n1.self.IP, Port: n1.self.DiscPort})
|
|
||||||
expected := n2.closest(target, bucketSize)
|
|
||||||
|
|
||||||
err := runUDP(10, func() error {
|
|
||||||
result, _ := n1.net.findnode(n2.self, target)
|
|
||||||
if len(result) != bucketSize {
|
|
||||||
return fmt.Errorf("wrong number of results: got %d, want %d", len(result), bucketSize)
|
|
||||||
}
|
|
||||||
for i := range result {
|
|
||||||
if result[i].ID != expected.entries[i].ID {
|
|
||||||
return fmt.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, result[i], expected.entries[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUDP_replytimeout(t *testing.T) {
|
// WriteToUDP queues a datagram.
|
||||||
t.Parallel()
|
func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) {
|
||||||
|
msg := make([]byte, len(b))
|
||||||
// reserve a port so we don't talk to an existing service by accident
|
copy(msg, b)
|
||||||
addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
c.mu.Lock()
|
||||||
fd, err := net.ListenUDP("udp", addr)
|
defer c.mu.Unlock()
|
||||||
if err != nil {
|
if c.closed {
|
||||||
t.Fatal(err)
|
return 0, errors.New("closed")
|
||||||
}
|
|
||||||
defer fd.Close()
|
|
||||||
|
|
||||||
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
|
||||||
defer n1.Close()
|
|
||||||
n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr))
|
|
||||||
|
|
||||||
if err := n1.net.ping(n2); err != errTimeout {
|
|
||||||
t.Error("expected timeout error, got", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout {
|
|
||||||
t.Error("expected timeout error, got", err)
|
|
||||||
} else if len(result) > 0 {
|
|
||||||
t.Error("expected empty result, got", result)
|
|
||||||
}
|
}
|
||||||
|
c.queue = append(c.queue, msg)
|
||||||
|
c.cond.Signal()
|
||||||
|
return len(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUDP_findnodeMultiReply(t *testing.T) {
|
// ReadFromUDP just hangs until the pipe is closed.
|
||||||
t.Parallel()
|
func (c *dgramPipe) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) {
|
||||||
|
<-c.closing
|
||||||
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
return 0, nil, io.EOF
|
||||||
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
|
||||||
udp2 := n2.net.(*udp)
|
|
||||||
defer n1.Close()
|
|
||||||
defer n2.Close()
|
|
||||||
|
|
||||||
err := runUDP(10, func() error {
|
|
||||||
nodes := make([]*Node, bucketSize)
|
|
||||||
for i := range nodes {
|
|
||||||
nodes[i] = &Node{
|
|
||||||
IP: net.IP{1, 2, 3, 4},
|
|
||||||
DiscPort: i + 1,
|
|
||||||
TCPPort: i + 1,
|
|
||||||
ID: randomID(n2.self.ID, i+1),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ask N2 for neighbors. it will send an empty reply back.
|
|
||||||
// the request will wait for up to bucketSize replies.
|
|
||||||
resultc := make(chan []*Node)
|
|
||||||
errc := make(chan error)
|
|
||||||
go func() {
|
|
||||||
ns, err := n1.net.findnode(n2.self, n1.self.ID)
|
|
||||||
if err != nil {
|
|
||||||
errc <- err
|
|
||||||
} else {
|
|
||||||
resultc <- ns
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// send a few more neighbors packets to N1.
|
|
||||||
// it should collect those.
|
|
||||||
for end := 0; end < len(nodes); {
|
|
||||||
off := end
|
|
||||||
if end = end + 5; end > len(nodes) {
|
|
||||||
end = len(nodes)
|
|
||||||
}
|
|
||||||
udp2.send(n1.self, neighborsPacket, neighbors{
|
|
||||||
Nodes: nodes[off:end],
|
|
||||||
Expiration: uint64(time.Now().Add(10 * time.Second).Unix()),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// check that they are all returned. we cannot just check for
|
|
||||||
// equality because they might not be returned in the order they
|
|
||||||
// were sent.
|
|
||||||
var result []*Node
|
|
||||||
select {
|
|
||||||
case result = <-resultc:
|
|
||||||
case err := <-errc:
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if hasDuplicates(result) {
|
|
||||||
return fmt.Errorf("result slice contains duplicates")
|
|
||||||
}
|
|
||||||
if len(result) != len(nodes) {
|
|
||||||
return fmt.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes))
|
|
||||||
}
|
|
||||||
matched := make(map[NodeID]bool)
|
|
||||||
for _, n := range result {
|
|
||||||
for _, expn := range nodes {
|
|
||||||
if n.ID == expn.ID { // && bytes.Equal(n.Addr.IP, expn.Addr.IP) && n.Addr.Port == expn.Addr.Port {
|
|
||||||
matched[n.ID] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(matched) != len(nodes) {
|
|
||||||
return fmt.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// runUDP runs a test n times and returns an error if the test failed
|
func (c *dgramPipe) Close() error {
|
||||||
// in all n runs. This is necessary because UDP is unreliable even for
|
c.mu.Lock()
|
||||||
// connections on the local machine, causing test failures.
|
defer c.mu.Unlock()
|
||||||
func runUDP(n int, test func() error) error {
|
if !c.closed {
|
||||||
errcount := 0
|
close(c.closing)
|
||||||
errors := ""
|
c.closed = true
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
if err := test(); err != nil {
|
|
||||||
errors += fmt.Sprintf("\n#%d: %v", i, err)
|
|
||||||
errcount++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if errcount == n {
|
|
||||||
return fmt.Errorf("failed on all %d iterations:%s", n, errors)
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *dgramPipe) LocalAddr() net.Addr {
|
||||||
|
return &net.UDPAddr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dgramPipe) waitPacketOut() []byte {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
for len(c.queue) == 0 {
|
||||||
|
c.cond.Wait()
|
||||||
|
}
|
||||||
|
p := c.queue[0]
|
||||||
|
copy(c.queue, c.queue[1:])
|
||||||
|
c.queue = c.queue[:len(c.queue)-1]
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user