Merge pull request #1621 from ethereum/fix-discover-hangs

p2p/discover: fix two major bugs in reply packet handling
This commit is contained in:
Jeffrey Wilcke 2015-08-11 12:17:13 -07:00
commit 05c66529b2
4 changed files with 197 additions and 83 deletions

View File

@ -78,9 +78,8 @@ type transport interface {
close() close()
} }
// bucket contains nodes, ordered by their last activity. // bucket contains nodes, ordered by their last activity. the entry
// the entry that was most recently active is the last element // that was most recently active is the first element in entries.
// in entries.
type bucket struct { type bucket struct {
lastLookup time.Time lastLookup time.Time
entries []*Node entries []*Node
@ -235,7 +234,7 @@ func (tab *Table) Lookup(targetID NodeID) []*Node {
if fails >= maxFindnodeFailures { if fails >= maxFindnodeFailures {
glog.V(logger.Detail).Infof("Evacuating node %x: %d findnode failures", n.ID[:8], fails) glog.V(logger.Detail).Infof("Evacuating node %x: %d findnode failures", n.ID[:8], fails)
tab.del(n) tab.delete(n)
} }
} }
reply <- tab.bondall(r) reply <- tab.bondall(r)
@ -401,15 +400,11 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16
node = w.n node = w.n
} }
} }
// Even if bonding temporarily failed, give the node a chance
if node != nil { if node != nil {
tab.mutex.Lock() // Add the node to the table even if the bonding ping/pong
defer tab.mutex.Unlock() // fails. It will be relaced quickly if it continues to be
// unresponsive.
b := tab.buckets[logdist(tab.self.sha, node.sha)] tab.add(node)
if !b.bump(node) {
tab.pingreplace(node, b)
}
tab.db.updateFindFails(id, 0) tab.db.updateFindFails(id, 0)
} }
return node, result return node, result
@ -420,7 +415,7 @@ func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAdd
<-tab.bondslots <-tab.bondslots
defer func() { tab.bondslots <- struct{}{} }() defer func() { tab.bondslots <- struct{}{} }()
// Ping the remote side and wait for a pong // Ping the remote side and wait for a pong.
if w.err = tab.ping(id, addr); w.err != nil { if w.err = tab.ping(id, addr); w.err != nil {
close(w.done) close(w.done)
return return
@ -431,33 +426,14 @@ func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAdd
// waitping will simply time out. // waitping will simply time out.
tab.net.waitping(id) tab.net.waitping(id)
} }
// Bonding succeeded, update the node database // Bonding succeeded, update the node database.
w.n = newNode(id, addr.IP, uint16(addr.Port), tcpPort) w.n = newNode(id, addr.IP, uint16(addr.Port), tcpPort)
tab.db.updateNode(w.n) tab.db.updateNode(w.n)
close(w.done) close(w.done)
} }
func (tab *Table) pingreplace(new *Node, b *bucket) { // ping a remote endpoint and wait for a reply, also updating the node
if len(b.entries) == bucketSize { // database accordingly.
oldest := b.entries[bucketSize-1]
if err := tab.ping(oldest.ID, oldest.addr()); err == nil {
// The node responded, we don't need to replace it.
return
}
} else {
// Add a slot at the end so the last entry doesn't
// fall off when adding the new node.
b.entries = append(b.entries, nil)
}
copy(b.entries[1:], b.entries)
b.entries[0] = new
if tab.nodeAddedHook != nil {
tab.nodeAddedHook(new)
}
}
// ping a remote endpoint and wait for a reply, also updating the node database
// accordingly.
func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error { func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
// Update the last ping and send the message // Update the last ping and send the message
tab.db.updateLastPing(id, time.Now()) tab.db.updateLastPing(id, time.Now())
@ -467,24 +443,53 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
// Pong received, update the database and return // Pong received, update the database and return
tab.db.updateLastPong(id, time.Now()) tab.db.updateLastPong(id, time.Now())
tab.db.ensureExpirer() tab.db.ensureExpirer()
return nil return nil
} }
// add puts the entries into the table if their corresponding // add attempts to add the given node its corresponding bucket. If the
// bucket is not full. The caller must hold tab.mutex. // bucket has space available, adding the node succeeds immediately.
func (tab *Table) add(entries []*Node) { // Otherwise, the node is added if the least recently active node in
// the bucket does not respond to a ping packet.
//
// The caller must not hold tab.mutex.
func (tab *Table) add(new *Node) {
b := tab.buckets[logdist(tab.self.sha, new.sha)]
tab.mutex.Lock()
if b.bump(new) {
tab.mutex.Unlock()
return
}
var oldest *Node
if len(b.entries) == bucketSize {
oldest = b.entries[bucketSize-1]
// Let go of the mutex so other goroutines can access
// the table while we ping the least recently active node.
tab.mutex.Unlock()
if err := tab.ping(oldest.ID, oldest.addr()); err == nil {
// The node responded, don't replace it.
return
}
tab.mutex.Lock()
}
added := b.replace(new, oldest)
tab.mutex.Unlock()
if added && tab.nodeAddedHook != nil {
tab.nodeAddedHook(new)
}
}
// stuff adds nodes the table to the end of their corresponding bucket
// if the bucket is not full. The caller must hold tab.mutex.
func (tab *Table) stuff(nodes []*Node) {
outer: outer:
for _, n := range entries { for _, n := range nodes {
if n.ID == tab.self.ID { if n.ID == tab.self.ID {
// don't add self. continue // don't add self
continue
} }
bucket := tab.buckets[logdist(tab.self.sha, n.sha)] bucket := tab.buckets[logdist(tab.self.sha, n.sha)]
for i := range bucket.entries { for i := range bucket.entries {
if bucket.entries[i].ID == n.ID { if bucket.entries[i].ID == n.ID {
// already in bucket continue outer // already in bucket
continue outer
} }
} }
if len(bucket.entries) < bucketSize { if len(bucket.entries) < bucketSize {
@ -496,12 +501,11 @@ outer:
} }
} }
// del removes an entry from the node table (used to evacuate failed/non-bonded // delete removes an entry from the node table (used to evacuate
// discovery peers). // failed/non-bonded discovery peers).
func (tab *Table) del(node *Node) { func (tab *Table) delete(node *Node) {
tab.mutex.Lock() tab.mutex.Lock()
defer tab.mutex.Unlock() defer tab.mutex.Unlock()
bucket := tab.buckets[logdist(tab.self.sha, node.sha)] bucket := tab.buckets[logdist(tab.self.sha, node.sha)]
for i := range bucket.entries { for i := range bucket.entries {
if bucket.entries[i].ID == node.ID { if bucket.entries[i].ID == node.ID {
@ -511,6 +515,27 @@ func (tab *Table) del(node *Node) {
} }
} }
func (b *bucket) replace(n *Node, last *Node) bool {
// Don't add if b already contains n.
for i := range b.entries {
if b.entries[i].ID == n.ID {
return false
}
}
// Replace last if it is still the last entry or just add n if b
// isn't full. If is no longer the last entry, it has either been
// replaced with someone else or became active.
if len(b.entries) == bucketSize && (last == nil || b.entries[bucketSize-1].ID != last.ID) {
return false
}
if len(b.entries) < bucketSize {
b.entries = append(b.entries, nil)
}
copy(b.entries[1:], b.entries)
b.entries[0] = n
return true
}
func (b *bucket) bump(n *Node) bool { func (b *bucket) bump(n *Node) bool {
for i := range b.entries { for i := range b.entries {
if b.entries[i].ID == n.ID { if b.entries[i].ID == n.ID {

View File

@ -178,8 +178,8 @@ func TestTable_closest(t *testing.T) {
test := func(test *closeTest) bool { test := func(test *closeTest) bool {
// for any node table, Target and N // for any node table, Target and N
tab := newTable(nil, test.Self, &net.UDPAddr{}, "") tab := newTable(nil, test.Self, &net.UDPAddr{}, "")
tab.add(test.All)
defer tab.Close() defer tab.Close()
tab.stuff(test.All)
// check that doClosest(Target, N) returns nodes // check that doClosest(Target, N) returns nodes
result := tab.closest(test.Target, test.N).entries result := tab.closest(test.Target, test.N).entries
@ -240,7 +240,7 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {
defer tab.Close() defer tab.Close()
for i := 0; i < len(buf); i++ { for i := 0; i < len(buf); i++ {
ld := cfg.Rand.Intn(len(tab.buckets)) ld := cfg.Rand.Intn(len(tab.buckets))
tab.add([]*Node{nodeAtDistance(tab.self.sha, ld)}) tab.stuff([]*Node{nodeAtDistance(tab.self.sha, ld)})
} }
gotN := tab.ReadRandomNodes(buf) gotN := tab.ReadRandomNodes(buf)
if gotN != tab.len() { if gotN != tab.len() {
@ -288,7 +288,7 @@ func TestTable_Lookup(t *testing.T) {
} }
// seed table with initial node (otherwise lookup will terminate immediately) // seed table with initial node (otherwise lookup will terminate immediately)
seed := newNode(lookupTestnet.dists[256][0], net.IP{}, 256, 0) seed := newNode(lookupTestnet.dists[256][0], net.IP{}, 256, 0)
tab.add([]*Node{seed}) tab.stuff([]*Node{seed})
results := tab.Lookup(lookupTestnet.target) results := tab.Lookup(lookupTestnet.target)
t.Logf("results:") t.Logf("results:")

View File

@ -18,6 +18,7 @@ package discover
import ( import (
"bytes" "bytes"
"container/list"
"crypto/ecdsa" "crypto/ecdsa"
"errors" "errors"
"fmt" "fmt"
@ -43,6 +44,7 @@ var (
errUnsolicitedReply = errors.New("unsolicited reply") errUnsolicitedReply = errors.New("unsolicited reply")
errUnknownNode = errors.New("unknown node") errUnknownNode = errors.New("unknown node")
errTimeout = errors.New("RPC timeout") errTimeout = errors.New("RPC timeout")
errClockWarp = errors.New("reply deadline too far in the future")
errClosed = errors.New("socket closed") errClosed = errors.New("socket closed")
) )
@ -296,7 +298,7 @@ func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-
} }
func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool { func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool {
matched := make(chan bool) matched := make(chan bool, 1)
select { select {
case t.gotreply <- reply{from, ptype, req, matched}: case t.gotreply <- reply{from, ptype, req, matched}:
// loop will handle it // loop will handle it
@ -310,68 +312,82 @@ func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool {
// the refresh timer and the pending reply queue. // the refresh timer and the pending reply queue.
func (t *udp) loop() { func (t *udp) loop() {
var ( var (
pending []*pending plist = list.New()
nextDeadline time.Time
timeout = time.NewTimer(0) timeout = time.NewTimer(0)
nextTimeout *pending // head of plist when timeout was last reset
refresh = time.NewTicker(refreshInterval) refresh = time.NewTicker(refreshInterval)
) )
<-timeout.C // ignore first timeout <-timeout.C // ignore first timeout
defer refresh.Stop() defer refresh.Stop()
defer timeout.Stop() defer timeout.Stop()
rearmTimeout := func() { resetTimeout := func() {
now := time.Now() if plist.Front() == nil || nextTimeout == plist.Front().Value {
if len(pending) == 0 || now.Before(nextDeadline) {
return return
} }
nextDeadline = pending[0].deadline // Start the timer so it fires when the next pending reply has expired.
timeout.Reset(nextDeadline.Sub(now)) now := time.Now()
for el := plist.Front(); el != nil; el = el.Next() {
nextTimeout = el.Value.(*pending)
if dist := nextTimeout.deadline.Sub(now); dist < 2*respTimeout {
timeout.Reset(dist)
return
}
// Remove pending replies whose deadline is too far in the
// future. These can occur if the system clock jumped
// backwards after the deadline was assigned.
nextTimeout.errc <- errClockWarp
plist.Remove(el)
}
nextTimeout = nil
timeout.Stop()
} }
for { for {
resetTimeout()
select { select {
case <-refresh.C: case <-refresh.C:
go t.refresh() go t.refresh()
case <-t.closing: case <-t.closing:
for _, p := range pending { for el := plist.Front(); el != nil; el = el.Next() {
p.errc <- errClosed el.Value.(*pending).errc <- errClosed
} }
pending = nil
return return
case p := <-t.addpending: case p := <-t.addpending:
p.deadline = time.Now().Add(respTimeout) p.deadline = time.Now().Add(respTimeout)
pending = append(pending, p) plist.PushBack(p)
rearmTimeout()
case r := <-t.gotreply: case r := <-t.gotreply:
var matched bool var matched bool
for i := 0; i < len(pending); i++ { for el := plist.Front(); el != nil; el = el.Next() {
if p := pending[i]; p.from == r.from && p.ptype == r.ptype { p := el.Value.(*pending)
if p.from == r.from && p.ptype == r.ptype {
matched = true matched = true
// Remove the matcher if its callback indicates
// that all replies have been received. This is
// required for packet types that expect multiple
// reply packets.
if p.callback(r.data) { if p.callback(r.data) {
// callback indicates the request is done, remove it.
p.errc <- nil p.errc <- nil
copy(pending[i:], pending[i+1:]) plist.Remove(el)
pending = pending[:len(pending)-1]
i--
} }
} }
} }
r.matched <- matched r.matched <- matched
case now := <-timeout.C: case now := <-timeout.C:
// notify and remove callbacks whose deadline is in the past. nextTimeout = nil
i := 0 // Notify and remove callbacks whose deadline is in the past.
for ; i < len(pending) && now.After(pending[i].deadline); i++ { for el := plist.Front(); el != nil; el = el.Next() {
pending[i].errc <- errTimeout p := el.Value.(*pending)
if now.After(p.deadline) || now.Equal(p.deadline) {
p.errc <- errTimeout
plist.Remove(el)
} }
if i > 0 {
copy(pending, pending[i:])
pending = pending[:len(pending)-i]
} }
rearmTimeout()
} }
} }
} }
@ -385,7 +401,7 @@ const (
var ( var (
headSpace = make([]byte, headSize) headSpace = make([]byte, headSize)
// Neighbors responses are sent across multiple packets to // Neighbors replies are sent across multiple packets to
// stay below the 1280 byte limit. We compute the maximum number // stay below the 1280 byte limit. We compute the maximum number
// of entries by stuffing a packet until it grows too large. // of entries by stuffing a packet until it grows too large.
maxNeighbors int maxNeighbors int

View File

@ -19,10 +19,12 @@ package discover
import ( import (
"bytes" "bytes"
"crypto/ecdsa" "crypto/ecdsa"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
logpkg "log" logpkg "log"
"math/rand"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
@ -138,6 +140,77 @@ func TestUDP_pingTimeout(t *testing.T) {
} }
} }
func TestUDP_responseTimeouts(t *testing.T) {
t.Parallel()
test := newUDPTest(t)
defer test.table.Close()
rand.Seed(time.Now().UnixNano())
randomDuration := func(max time.Duration) time.Duration {
return time.Duration(rand.Int63n(int64(max)))
}
var (
nReqs = 200
nTimeouts = 0 // number of requests with ptype > 128
nilErr = make(chan error, nReqs) // for requests that get a reply
timeoutErr = make(chan error, nReqs) // for requests that time out
)
for i := 0; i < nReqs; i++ {
// Create a matcher for a random request in udp.loop. Requests
// with ptype <= 128 will not get a reply and should time out.
// For all other requests, a reply is scheduled to arrive
// within the timeout window.
p := &pending{
ptype: byte(rand.Intn(255)),
callback: func(interface{}) bool { return true },
}
binary.BigEndian.PutUint64(p.from[:], uint64(i))
if p.ptype <= 128 {
p.errc = timeoutErr
nTimeouts++
} else {
p.errc = nilErr
time.AfterFunc(randomDuration(60*time.Millisecond), func() {
if !test.udp.handleReply(p.from, p.ptype, nil) {
t.Logf("not matched: %v", p)
}
})
}
test.udp.addpending <- p
time.Sleep(randomDuration(30 * time.Millisecond))
}
// Check that all timeouts were delivered and that the rest got nil errors.
// The replies must be delivered.
var (
recvDeadline = time.After(20 * time.Second)
nTimeoutsRecv, nNil = 0, 0
)
for i := 0; i < nReqs; i++ {
select {
case err := <-timeoutErr:
if err != errTimeout {
t.Fatalf("got non-timeout error on timeoutErr %d: %v", i, err)
}
nTimeoutsRecv++
case err := <-nilErr:
if err != nil {
t.Fatalf("got non-nil error on nilErr %d: %v", i, err)
}
nNil++
case <-recvDeadline:
t.Fatalf("exceeded recv deadline")
}
}
if nTimeoutsRecv != nTimeouts {
t.Errorf("wrong number of timeout errors received: got %d, want %d", nTimeoutsRecv, nTimeouts)
}
if nNil != nReqs-nTimeouts {
t.Errorf("wrong number of successful replies: got %d, want %d", nNil, nReqs-nTimeouts)
}
}
func TestUDP_findnodeTimeout(t *testing.T) { func TestUDP_findnodeTimeout(t *testing.T) {
t.Parallel() t.Parallel()
test := newUDPTest(t) test := newUDPTest(t)
@ -167,7 +240,7 @@ func TestUDP_findnode(t *testing.T) {
for i := 0; i < bucketSize; i++ { for i := 0; i < bucketSize; i++ {
nodes.push(nodeAtDistance(test.table.self.sha, i+2), bucketSize) nodes.push(nodeAtDistance(test.table.self.sha, i+2), bucketSize)
} }
test.table.add(nodes.entries) test.table.stuff(nodes.entries)
// ensure there's a bond with the test node, // ensure there's a bond with the test node,
// findnode won't be accepted otherwise. // findnode won't be accepted otherwise.