Merge pull request #2510 from fjl/p2p-fixups

p2p/discover: prevent bonding self
This commit is contained in:
Jeffrey Wilcke 2016-05-03 13:30:51 +02:00
commit 258cc73ea9
3 changed files with 82 additions and 29 deletions

View File

@ -25,6 +25,7 @@ package discover
import ( import (
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"net" "net"
"sort" "sort"
@ -457,6 +458,9 @@ func (tab *Table) bondall(nodes []*Node) (result []*Node) {
// If pinged is true, the remote node has just pinged us and one half // If pinged is true, the remote node has just pinged us and one half
// of the process can be skipped. // of the process can be skipped.
func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) (*Node, error) { func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) (*Node, error) {
if id == tab.self.ID {
return nil, errors.New("is self")
}
// Retrieve a previously known node and any recent findnode failures // Retrieve a previously known node and any recent findnode failures
node, fails := tab.db.node(id), 0 node, fails := tab.db.node(id), 0
if node != nil { if node != nil {

View File

@ -400,10 +400,9 @@ func (srv *Server) run(dialstate dialer) {
var ( var (
peers = make(map[discover.NodeID]*Peer) peers = make(map[discover.NodeID]*Peer)
trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes)) trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes))
tasks []task
pendingTasks []task
taskdone = make(chan task, maxActiveDialTasks) taskdone = make(chan task, maxActiveDialTasks)
runningTasks []task
queuedTasks []task // tasks that can't run yet
) )
// Put trusted nodes into a map to speed up checks. // Put trusted nodes into a map to speed up checks.
// Trusted peers are loaded on startup and cannot be // Trusted peers are loaded on startup and cannot be
@ -412,39 +411,39 @@ func (srv *Server) run(dialstate dialer) {
trusted[n.ID] = true trusted[n.ID] = true
} }
// Some task list helpers. // removes t from runningTasks
delTask := func(t task) { delTask := func(t task) {
for i := range tasks { for i := range runningTasks {
if tasks[i] == t { if runningTasks[i] == t {
tasks = append(tasks[:i], tasks[i+1:]...) runningTasks = append(runningTasks[:i], runningTasks[i+1:]...)
break break
} }
} }
} }
scheduleTasks := func(new []task) { // starts until max number of active tasks is satisfied
pt := append(pendingTasks, new...) startTasks := func(ts []task) (rest []task) {
start := maxActiveDialTasks - len(tasks) i := 0
if len(pt) < start { for ; len(runningTasks) < maxActiveDialTasks && i < len(ts); i++ {
start = len(pt) t := ts[i]
}
if start > 0 {
tasks = append(tasks, pt[:start]...)
for _, t := range pt[:start] {
t := t
glog.V(logger.Detail).Infoln("new task:", t) glog.V(logger.Detail).Infoln("new task:", t)
go func() { t.Do(srv); taskdone <- t }() go func() { t.Do(srv); taskdone <- t }()
runningTasks = append(runningTasks, t)
} }
copy(pt, pt[start:]) return ts[i:]
pendingTasks = pt[:len(pt)-start] }
scheduleTasks := func() {
// Start from queue first.
queuedTasks = append(queuedTasks[:0], startTasks(queuedTasks)...)
// Query dialer for new tasks and start as many as possible now.
if len(runningTasks) < maxActiveDialTasks {
nt := dialstate.newTasks(len(runningTasks)+len(queuedTasks), peers, time.Now())
queuedTasks = append(queuedTasks, startTasks(nt)...)
} }
} }
running: running:
for { for {
// Query the dialer for new tasks and launch them. scheduleTasks()
now := time.Now()
nt := dialstate.newTasks(len(pendingTasks)+len(tasks), peers, now)
scheduleTasks(nt)
select { select {
case <-srv.quit: case <-srv.quit:
@ -466,7 +465,7 @@ running:
// can update its state and remove it from the active // can update its state and remove it from the active
// tasks list. // tasks list.
glog.V(logger.Detail).Infoln("<-taskdone:", t) glog.V(logger.Detail).Infoln("<-taskdone:", t)
dialstate.taskDone(t, now) dialstate.taskDone(t, time.Now())
delTask(t) delTask(t)
case c := <-srv.posthandshake: case c := <-srv.posthandshake:
// A connection has passed the encryption handshake so // A connection has passed the encryption handshake so
@ -513,7 +512,7 @@ running:
// Wait for peers to shut down. Pending connections and tasks are // Wait for peers to shut down. Pending connections and tasks are
// not handled here and will terminate soon-ish because srv.quit // not handled here and will terminate soon-ish because srv.quit
// is closed. // is closed.
glog.V(logger.Detail).Infof("ignoring %d pending tasks at spindown", len(tasks)) glog.V(logger.Detail).Infof("ignoring %d pending tasks at spindown", len(runningTasks))
for len(peers) > 0 { for len(peers) > 0 {
p := <-srv.delpeer p := <-srv.delpeer
glog.V(logger.Detail).Infoln("<-delpeer (spindown):", p) glog.V(logger.Detail).Infoln("<-delpeer (spindown):", p)

View File

@ -235,6 +235,56 @@ func TestServerTaskScheduling(t *testing.T) {
} }
} }
// This test checks that Server doesn't drop tasks,
// even if newTasks returns more than the maximum number of tasks.
func TestServerManyTasks(t *testing.T) {
alltasks := make([]task, 300)
for i := range alltasks {
alltasks[i] = &testTask{index: i}
}
var (
srv = &Server{quit: make(chan struct{}), ntab: fakeTable{}, running: true}
done = make(chan *testTask)
start, end = 0, 0
)
defer srv.Stop()
srv.loopWG.Add(1)
go srv.run(taskgen{
newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
start, end = end, end+maxActiveDialTasks+10
if end > len(alltasks) {
end = len(alltasks)
}
return alltasks[start:end]
},
doneFunc: func(tt task) {
done <- tt.(*testTask)
},
})
doneset := make(map[int]bool)
timeout := time.After(2 * time.Second)
for len(doneset) < len(alltasks) {
select {
case tt := <-done:
if doneset[tt.index] {
t.Errorf("task %d got done more than once", tt.index)
} else {
doneset[tt.index] = true
}
case <-timeout:
t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks))
for i := 0; i < len(alltasks); i++ {
if !doneset[i] {
t.Logf("task %d not done", i)
}
}
return
}
}
}
type taskgen struct { type taskgen struct {
newFunc func(running int, peers map[discover.NodeID]*Peer) []task newFunc func(running int, peers map[discover.NodeID]*Peer) []task
doneFunc func(task) doneFunc func(task)