diff --git a/p2p/discover/node.go b/p2p/discover/node.go index 7ddf04fe8..8d4af166b 100644 --- a/p2p/discover/node.go +++ b/p2p/discover/node.go @@ -33,7 +33,8 @@ import ( // The fields of Node may not be modified. type node struct { enode.Node - addedAt time.Time // time when the node was added to the table + addedAt time.Time // time when the node was added to the table + livenessChecks uint // how often liveness was checked } type encPubkey [64]byte diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 9f7f1d41b..ba4c06327 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -75,8 +75,10 @@ type Table struct { net transport refreshReq chan chan struct{} initDone chan struct{} - closeReq chan struct{} - closed chan struct{} + + closeOnce sync.Once + closeReq chan struct{} + closed chan struct{} nodeAddedHook func(*node) // for testing } @@ -180,16 +182,14 @@ func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) { // Close terminates the network listener and flushes the node database. func (tab *Table) Close() { - if tab.net != nil { - tab.net.close() - } - - select { - case <-tab.closed: - // already closed. - case tab.closeReq <- struct{}{}: - <-tab.closed // wait for refreshLoop to end. - } + tab.closeOnce.Do(func() { + if tab.net != nil { + tab.net.close() + } + // Wait for loop to end. + close(tab.closeReq) + <-tab.closed + }) } // setFallbackNodes sets the initial points of contact. These nodes @@ -290,12 +290,16 @@ func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node { // we have asked all closest nodes, stop the search break } - // wait for the next reply - for _, n := range <-reply { - if n != nil && !seen[n.ID()] { - seen[n.ID()] = true - result.push(n, bucketSize) + select { + case nodes := <-reply: + for _, n := range nodes { + if n != nil && !seen[n.ID()] { + seen[n.ID()] = true + result.push(n, bucketSize) + } } + case <-tab.closeReq: + return nil // shutdown, no need to continue. 
} pendingQueries-- } @@ -303,18 +307,22 @@ func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node { } func (tab *Table) findnode(n *node, targetKey encPubkey, reply chan<- []*node) { - fails := tab.db.FindFails(n.ID()) + fails := tab.db.FindFails(n.ID(), n.IP()) r, err := tab.net.findnode(n.ID(), n.addr(), targetKey) - if err != nil || len(r) == 0 { + if err == errClosed { + // Avoid recording failures on shutdown. + reply <- nil + return + } else if err != nil || len(r) == 0 { fails++ - tab.db.UpdateFindFails(n.ID(), fails) + tab.db.UpdateFindFails(n.ID(), n.IP(), fails) log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err) if fails >= maxFindnodeFailures { log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails) tab.delete(n) } } else if fails > 0 { - tab.db.UpdateFindFails(n.ID(), fails-1) + tab.db.UpdateFindFails(n.ID(), n.IP(), fails-1) } // Grab as many nodes as possible. Some of them might not be alive anymore, but we'll @@ -329,7 +337,7 @@ func (tab *Table) refresh() <-chan struct{} { done := make(chan struct{}) select { case tab.refreshReq <- done: - case <-tab.closed: + case <-tab.closeReq: close(done) } return done @@ -433,7 +441,7 @@ func (tab *Table) loadSeedNodes() { seeds = append(seeds, tab.nursery...) for i := range seeds { seed := seeds[i] - age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.LastPongReceived(seed.ID())) }} + age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP())) }} log.Trace("Found seed node in database", "id", seed.ID(), "addr", seed.addr(), "age", age) tab.add(seed) } @@ -458,16 +466,17 @@ func (tab *Table) doRevalidate(done chan<- struct{}) { b := tab.buckets[bi] if err == nil { // The node responded, move it to the front. 
- log.Debug("Revalidated node", "b", bi, "id", last.ID()) + last.livenessChecks++ + log.Debug("Revalidated node", "b", bi, "id", last.ID(), "checks", last.livenessChecks) b.bump(last) return } // No reply received, pick a replacement or delete the node if there aren't // any replacements. if r := tab.replace(b, last); r != nil { - log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "r", r.ID(), "rip", r.IP()) + log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks, "r", r.ID(), "rip", r.IP()) } else { - log.Debug("Removed dead node", "b", bi, "id", last.ID(), "ip", last.IP()) + log.Debug("Removed dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks) } } @@ -502,7 +511,7 @@ func (tab *Table) copyLiveNodes() { now := time.Now() for _, b := range &tab.buckets { for _, n := range b.entries { - if now.Sub(n.addedAt) >= seedMinTableTime { + if n.livenessChecks > 0 && now.Sub(n.addedAt) >= seedMinTableTime { tab.db.UpdateNode(unwrapNode(n)) } } @@ -518,7 +527,9 @@ func (tab *Table) closest(target enode.ID, nresults int) *nodesByDistance { close := &nodesByDistance{target: target} for _, b := range &tab.buckets { for _, n := range b.entries { - close.push(n, nresults) + if n.livenessChecks > 0 { + close.push(n, nresults) + } } } return close @@ -572,23 +583,6 @@ func (tab *Table) addThroughPing(n *node) { tab.add(n) } -// stuff adds nodes the table to the end of their corresponding bucket -// if the bucket is not full. The caller must not hold tab.mutex. -func (tab *Table) stuff(nodes []*node) { - tab.mutex.Lock() - defer tab.mutex.Unlock() - - for _, n := range nodes { - if n.ID() == tab.self().ID() { - continue // don't add self - } - b := tab.bucket(n.ID()) - if len(b.entries) < bucketSize { - tab.bumpOrAdd(b, n) - } - } -} - // delete removes an entry from the node table. It is used to evacuate dead nodes. 
func (tab *Table) delete(node *node) { tab.mutex.Lock() diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index 6b4cd2d18..b00a93211 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -50,8 +50,8 @@ func TestTable_pingReplace(t *testing.T) { func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) { transport := newPingRecorder() tab, db := newTestTable(transport) - defer tab.Close() defer db.Close() + defer tab.Close() <-tab.initDone @@ -137,8 +137,8 @@ func TestBucket_bumpNoDuplicates(t *testing.T) { func TestTable_IPLimit(t *testing.T) { transport := newPingRecorder() tab, db := newTestTable(transport) - defer tab.Close() defer db.Close() + defer tab.Close() for i := 0; i < tableIPLimit+1; i++ { n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)}) @@ -153,8 +153,8 @@ func TestTable_IPLimit(t *testing.T) { func TestTable_BucketIPLimit(t *testing.T) { transport := newPingRecorder() tab, db := newTestTable(transport) - defer tab.Close() defer db.Close() + defer tab.Close() d := 3 for i := 0; i < bucketIPLimit+1; i++ { @@ -173,9 +173,9 @@ func TestTable_closest(t *testing.T) { // for any node table, Target and N transport := newPingRecorder() tab, db := newTestTable(transport) - defer tab.Close() defer db.Close() - tab.stuff(test.All) + defer tab.Close() + fillTable(tab, test.All) // check that closest(Target, N) returns nodes result := tab.closest(test.Target, test.N).entries @@ -234,13 +234,13 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) { test := func(buf []*enode.Node) bool { transport := newPingRecorder() tab, db := newTestTable(transport) - defer tab.Close() defer db.Close() + defer tab.Close() <-tab.initDone for i := 0; i < len(buf); i++ { ld := cfg.Rand.Intn(len(tab.buckets)) - tab.stuff([]*node{nodeAtDistance(tab.self().ID(), ld, intIP(ld))}) + fillTable(tab, []*node{nodeAtDistance(tab.self().ID(), ld, intIP(ld))}) } gotN := tab.ReadRandomNodes(buf) if gotN != 
tab.len() { @@ -272,16 +272,19 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value { N: rand.Intn(bucketSize), } for _, id := range gen([]enode.ID{}, rand).([]enode.ID) { - n := enode.SignNull(new(enr.Record), id) - t.All = append(t.All, wrapNode(n)) + r := new(enr.Record) + r.Set(enr.IP(genIP(rand))) + n := wrapNode(enode.SignNull(r, id)) + n.livenessChecks = 1 + t.All = append(t.All, n) } return reflect.ValueOf(t) } func TestTable_Lookup(t *testing.T) { tab, db := newTestTable(lookupTestnet) - defer tab.Close() defer db.Close() + defer tab.Close() // lookup on empty table returns no nodes if results := tab.lookup(lookupTestnet.target, false); len(results) > 0 { @@ -289,8 +292,9 @@ func TestTable_Lookup(t *testing.T) { } // seed table with initial node (otherwise lookup will terminate immediately) seedKey, _ := decodePubkey(lookupTestnet.dists[256][0]) - seed := wrapNode(enode.NewV4(seedKey, net.IP{}, 0, 256)) - tab.stuff([]*node{seed}) + seed := wrapNode(enode.NewV4(seedKey, net.IP{127, 0, 0, 1}, 0, 256)) + seed.livenessChecks = 1 + fillTable(tab, []*node{seed}) results := tab.lookup(lookupTestnet.target, true) t.Logf("results:") @@ -578,6 +582,12 @@ func gen(typ interface{}, rand *rand.Rand) interface{} { return v.Interface() } +func genIP(rand *rand.Rand) net.IP { + ip := make(net.IP, 4) + rand.Read(ip) + return ip +} + func quickcfg() *quick.Config { return &quick.Config{ MaxCount: 5000, diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go index d41519452..3ce582b99 100644 --- a/p2p/discover/table_util_test.go +++ b/p2p/discover/table_util_test.go @@ -83,6 +83,23 @@ func fillBucket(tab *Table, n *node) (last *node) { return b.entries[bucketSize-1] } +// fillTable adds nodes the table to the end of their corresponding bucket +// if the bucket is not full. The caller must not hold tab.mutex. 
+func fillTable(tab *Table, nodes []*node) { + tab.mutex.Lock() + defer tab.mutex.Unlock() + + for _, n := range nodes { + if n.ID() == tab.self().ID() { + continue // don't add self + } + b := tab.bucket(n.ID()) + if len(b.entries) < bucketSize { + tab.bumpOrAdd(b, n) + } + } +} + type pingRecorder struct { mu sync.Mutex dead, pinged map[enode.ID]bool @@ -109,10 +126,6 @@ func (t *pingRecorder) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPu return nil, nil } -func (t *pingRecorder) waitping(from enode.ID) error { - return nil // remote always pings -} - func (t *pingRecorder) ping(toid enode.ID, toaddr *net.UDPAddr) error { t.mu.Lock() defer t.mu.Unlock() diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index 37a044902..5ce4c43dc 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -67,6 +67,8 @@ const ( // RPC request structures type ( ping struct { + senderKey *ecdsa.PublicKey // filled in by preverify + Version uint From, To rpcEndpoint Expiration uint64 @@ -155,8 +157,13 @@ func nodeToRPC(n *node) rpcNode { return rpcNode{ID: ekey, IP: n.IP(), UDP: uint16(n.UDP()), TCP: uint16(n.TCP())} } +// packet is implemented by all protocol messages. type packet interface { - handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error + // preverify checks whether the packet is valid and should be handled at all. + preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error + // handle handles the packet. + handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) + // name returns the name of the packet for logging purposes. name() string } @@ -177,43 +184,48 @@ type udp struct { tab *Table wg sync.WaitGroup - addpending chan *pending - gotreply chan reply - closing chan struct{} + addReplyMatcher chan *replyMatcher + gotreply chan reply + closing chan struct{} } // pending represents a pending reply. // -// some implementations of the protocol wish to send more than one -// reply packet to findnode. 
in general, any neighbors packet cannot +// Some implementations of the protocol wish to send more than one +// reply packet to findnode. In general, any neighbors packet cannot // be matched up with a specific findnode packet. // -// our implementation handles this by storing a callback function for -// each pending reply. incoming packets from a node are dispatched -// to all the callback functions for that node. -type pending struct { +// Our implementation handles this by storing a callback function for +// each pending reply. Incoming packets from a node are dispatched +// to all callback functions for that node. +type replyMatcher struct { // these fields must match in the reply. from enode.ID + ip net.IP ptype byte // time when the request must complete deadline time.Time - // callback is called when a matching reply arrives. if it returns - // true, the callback is removed from the pending reply queue. - // if it returns false, the reply is considered incomplete and - // the callback will be invoked again for the next matching reply. - callback func(resp interface{}) (done bool) + // callback is called when a matching reply arrives. If it returns matched == true, the + // reply was acceptable. The second return value indicates whether the callback should + // be removed from the pending reply queue. If it returns false, the reply is considered + // incomplete and the callback will be invoked again for the next matching reply. + callback replyMatchFunc // errc receives nil when the callback indicates completion or an // error if no further reply is received within the timeout. errc chan<- error } +type replyMatchFunc func(interface{}) (matched bool, requestDone bool) + type reply struct { from enode.ID + ip net.IP ptype byte - data interface{} + data packet + // loop indicates whether there was // a matching request by sending on this channel. 
matched chan<- bool @@ -247,14 +259,14 @@ func ListenUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, error) { func newUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, *udp, error) { udp := &udp{ - conn: c, - priv: cfg.PrivateKey, - netrestrict: cfg.NetRestrict, - localNode: ln, - db: ln.Database(), - closing: make(chan struct{}), - gotreply: make(chan reply), - addpending: make(chan *pending), + conn: c, + priv: cfg.PrivateKey, + netrestrict: cfg.NetRestrict, + localNode: ln, + db: ln.Database(), + closing: make(chan struct{}), + gotreply: make(chan reply), + addReplyMatcher: make(chan *replyMatcher), } tab, err := newTable(udp, ln.Database(), cfg.Bootnodes) if err != nil { @@ -304,35 +316,37 @@ func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-ch errc <- err return errc } - errc := t.pending(toid, pongPacket, func(p interface{}) bool { - ok := bytes.Equal(p.(*pong).ReplyTok, hash) - if ok && callback != nil { + // Add a matcher for the reply to the pending reply queue. Pongs are matched if they + // reference the ping we're about to send. + errc := t.pending(toid, toaddr.IP, pongPacket, func(p interface{}) (matched bool, requestDone bool) { + matched = bytes.Equal(p.(*pong).ReplyTok, hash) + if matched && callback != nil { callback() } - return ok + return matched, matched }) + // Send the packet. t.localNode.UDPContact(toaddr) - t.write(toaddr, req.name(), packet) + t.write(toaddr, toid, req.name(), packet) return errc } -func (t *udp) waitping(from enode.ID) error { - return <-t.pending(from, pingPacket, func(interface{}) bool { return true }) -} - // findnode sends a findnode request to the given node and waits until // the node has sent up to k neighbors. func (t *udp) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) { // If we haven't seen a ping from the destination node for a while, it won't remember // our endpoint proof and reject findnode. Solicit a ping first. 
- if time.Since(t.db.LastPingReceived(toid)) > bondExpiration { + if time.Since(t.db.LastPingReceived(toid, toaddr.IP)) > bondExpiration { t.ping(toid, toaddr) - t.waitping(toid) + // Wait for them to ping back and process our pong. + time.Sleep(respTimeout) } + // Add a matcher for 'neighbours' replies to the pending reply queue. The matcher is + // active until enough nodes have been received. nodes := make([]*node, 0, bucketSize) nreceived := 0 - errc := t.pending(toid, neighborsPacket, func(r interface{}) bool { + errc := t.pending(toid, toaddr.IP, neighborsPacket, func(r interface{}) (matched bool, requestDone bool) { reply := r.(*neighbors) for _, rn := range reply.Nodes { nreceived++ @@ -343,22 +357,22 @@ func (t *udp) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([] } nodes = append(nodes, n) } - return nreceived >= bucketSize + return true, nreceived >= bucketSize }) - t.send(toaddr, findnodePacket, &findnode{ + t.send(toaddr, toid, findnodePacket, &findnode{ Target: target, Expiration: uint64(time.Now().Add(expiration).Unix()), }) return nodes, <-errc } -// pending adds a reply callback to the pending reply queue. -// see the documentation of type pending for a detailed explanation. -func (t *udp) pending(id enode.ID, ptype byte, callback func(interface{}) bool) <-chan error { +// pending adds a reply matcher to the pending reply queue. +// see the documentation of type replyMatcher for a detailed explanation. 
+func (t *udp) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchFunc) <-chan error { ch := make(chan error, 1) - p := &pending{from: id, ptype: ptype, callback: callback, errc: ch} + p := &replyMatcher{from: id, ip: ip, ptype: ptype, callback: callback, errc: ch} select { - case t.addpending <- p: + case t.addReplyMatcher <- p: // loop will handle it case <-t.closing: ch <- errClosed @@ -366,10 +380,12 @@ func (t *udp) pending(id enode.ID, ptype byte, callback func(interface{}) bool) return ch } -func (t *udp) handleReply(from enode.ID, ptype byte, req packet) bool { +// handleReply dispatches a reply packet, invoking reply matchers. It returns +// whether any matcher considered the packet acceptable. +func (t *udp) handleReply(from enode.ID, fromIP net.IP, ptype byte, req packet) bool { matched := make(chan bool, 1) select { - case t.gotreply <- reply{from, ptype, req, matched}: + case t.gotreply <- reply{from, fromIP, ptype, req, matched}: // loop will handle it return <-matched case <-t.closing: @@ -385,8 +401,8 @@ func (t *udp) loop() { var ( plist = list.New() timeout = time.NewTimer(0) - nextTimeout *pending // head of plist when timeout was last reset - contTimeouts = 0 // number of continuous timeouts to do NTP checks + nextTimeout *replyMatcher // head of plist when timeout was last reset + contTimeouts = 0 // number of continuous timeouts to do NTP checks ntpWarnTime = time.Unix(0, 0) ) <-timeout.C // ignore first timeout @@ -399,7 +415,7 @@ func (t *udp) loop() { // Start the timer so it fires when the next pending reply has expired. 
now := time.Now() for el := plist.Front(); el != nil; el = el.Next() { - nextTimeout = el.Value.(*pending) + nextTimeout = el.Value.(*replyMatcher) if dist := nextTimeout.deadline.Sub(now); dist < 2*respTimeout { timeout.Reset(dist) return @@ -420,25 +436,23 @@ func (t *udp) loop() { select { case <-t.closing: for el := plist.Front(); el != nil; el = el.Next() { - el.Value.(*pending).errc <- errClosed + el.Value.(*replyMatcher).errc <- errClosed } return - case p := <-t.addpending: + case p := <-t.addReplyMatcher: p.deadline = time.Now().Add(respTimeout) plist.PushBack(p) case r := <-t.gotreply: - var matched bool + var matched bool // whether any replyMatcher considered the reply acceptable. for el := plist.Front(); el != nil; el = el.Next() { - p := el.Value.(*pending) - if p.from == r.from && p.ptype == r.ptype { - matched = true - // Remove the matcher if its callback indicates - // that all replies have been received. This is - // required for packet types that expect multiple - // reply packets. - if p.callback(r.data) { + p := el.Value.(*replyMatcher) + if p.from == r.from && p.ptype == r.ptype && p.ip.Equal(r.ip) { + ok, requestDone := p.callback(r.data) + matched = matched || ok + // Remove the matcher if callback indicates that all replies have been received. + if requestDone { p.errc <- nil plist.Remove(el) } @@ -453,7 +467,7 @@ func (t *udp) loop() { // Notify and remove callbacks whose deadline is in the past. 
for el := plist.Front(); el != nil; el = el.Next() { - p := el.Value.(*pending) + p := el.Value.(*replyMatcher) if now.After(p.deadline) || now.Equal(p.deadline) { p.errc <- errTimeout plist.Remove(el) @@ -504,17 +518,17 @@ func init() { } } -func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) { +func (t *udp) send(toaddr *net.UDPAddr, toid enode.ID, ptype byte, req packet) ([]byte, error) { packet, hash, err := encodePacket(t.priv, ptype, req) if err != nil { return hash, err } - return hash, t.write(toaddr, req.name(), packet) + return hash, t.write(toaddr, toid, req.name(), packet) } -func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error { +func (t *udp) write(toaddr *net.UDPAddr, toid enode.ID, what string, packet []byte) error { _, err := t.conn.WriteToUDP(packet, toaddr) - log.Trace(">> "+what, "addr", toaddr, "err", err) + log.Trace(">> "+what, "id", toid, "addr", toaddr, "err", err) return err } @@ -573,13 +587,19 @@ func (t *udp) readLoop(unhandled chan<- ReadPacket) { } func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error { - packet, fromID, hash, err := decodePacket(buf) + packet, fromKey, hash, err := decodePacket(buf) if err != nil { log.Debug("Bad discv4 packet", "addr", from, "err", err) return err } - err = packet.handle(t, from, fromID, hash) - log.Trace("<< "+packet.name(), "addr", from, "err", err) + fromID := fromKey.id() + if err == nil { + err = packet.preverify(t, from, fromID, fromKey) + } + log.Trace("<< "+packet.name(), "id", fromID, "addr", from, "err", err) + if err == nil { + packet.handle(t, from, fromID, hash) + } return err } @@ -615,54 +635,67 @@ func decodePacket(buf []byte) (packet, encPubkey, []byte, error) { return req, fromKey, hash, err } -func (req *ping) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error { +// Packet Handlers + +func (req *ping) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error { if 
expired(req.Expiration) { return errExpired } key, err := decodePubkey(fromKey) if err != nil { - return fmt.Errorf("invalid public key: %v", err) + return errors.New("invalid public key") } - t.send(from, pongPacket, &pong{ + req.senderKey = key + return nil +} + +func (req *ping) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) { + // Reply. + t.send(from, fromID, pongPacket, &pong{ To: makeEndpoint(from, req.From.TCP), ReplyTok: mac, Expiration: uint64(time.Now().Add(expiration).Unix()), }) - n := wrapNode(enode.NewV4(key, from.IP, int(req.From.TCP), from.Port)) - t.handleReply(n.ID(), pingPacket, req) - if time.Since(t.db.LastPongReceived(n.ID())) > bondExpiration { - t.sendPing(n.ID(), from, func() { t.tab.addThroughPing(n) }) + + // Ping back if our last pong on file is too far in the past. + n := wrapNode(enode.NewV4(req.senderKey, from.IP, int(req.From.TCP), from.Port)) + if time.Since(t.db.LastPongReceived(n.ID(), from.IP)) > bondExpiration { + t.sendPing(fromID, from, func() { + t.tab.addThroughPing(n) + }) } else { t.tab.addThroughPing(n) } + + // Update node database and endpoint predictor. 
+ t.db.UpdateLastPingReceived(n.ID(), from.IP, time.Now()) t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}) - t.db.UpdateLastPingReceived(n.ID(), time.Now()) - return nil } func (req *ping) name() string { return "PING/v4" } -func (req *pong) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error { +func (req *pong) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error { if expired(req.Expiration) { return errExpired } - fromID := fromKey.id() - if !t.handleReply(fromID, pongPacket, req) { + if !t.handleReply(fromID, from.IP, pongPacket, req) { return errUnsolicitedReply } - t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}) - t.db.UpdateLastPongReceived(fromID, time.Now()) return nil } +func (req *pong) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) { + t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}) + t.db.UpdateLastPongReceived(fromID, from.IP, time.Now()) +} + func (req *pong) name() string { return "PONG/v4" } -func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error { +func (req *findnode) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error { if expired(req.Expiration) { return errExpired } - fromID := fromKey.id() - if time.Since(t.db.LastPongReceived(fromID)) > bondExpiration { + if time.Since(t.db.LastPongReceived(fromID, from.IP)) > bondExpiration { // No endpoint proof pong exists, we don't process the packet. This prevents an // attack vector where the discovery protocol could be used to amplify traffic in a // DDOS attack. A malicious actor would send a findnode request with the IP address @@ -671,43 +704,50 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac [] // findnode) to the victim. 
return errUnknownNode } + return nil +} + +func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) { + // Determine closest nodes. target := enode.ID(crypto.Keccak256Hash(req.Target[:])) t.tab.mutex.Lock() closest := t.tab.closest(target, bucketSize).entries t.tab.mutex.Unlock() - p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} - var sent bool // Send neighbors in chunks with at most maxNeighbors per packet // to stay below the 1280 byte limit. + p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} + var sent bool for _, n := range closest { if netutil.CheckRelayIP(from.IP, n.IP()) == nil { p.Nodes = append(p.Nodes, nodeToRPC(n)) } if len(p.Nodes) == maxNeighbors { - t.send(from, neighborsPacket, &p) + t.send(from, fromID, neighborsPacket, &p) p.Nodes = p.Nodes[:0] sent = true } } if len(p.Nodes) > 0 || !sent { - t.send(from, neighborsPacket, &p) + t.send(from, fromID, neighborsPacket, &p) } - return nil } func (req *findnode) name() string { return "FINDNODE/v4" } -func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error { +func (req *neighbors) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error { if expired(req.Expiration) { return errExpired } - if !t.handleReply(fromKey.id(), neighborsPacket, req) { + if !t.handleReply(fromID, from.IP, neighborsPacket, req) { return errUnsolicitedReply } return nil } +func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) { +} + func (req *neighbors) name() string { return "NEIGHBORS/v4" } func expired(ts uint64) bool { diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index a4ddaf750..3d53c9309 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -19,6 +19,7 @@ package discover import ( "bytes" "crypto/ecdsa" + crand "crypto/rand" "encoding/binary" "encoding/hex" "errors" @@ -57,6 +58,7 @@ type udpTest struct { t *testing.T pipe 
*dgramPipe table *Table + db *enode.DB udp *udp sent [][]byte localkey, remotekey *ecdsa.PrivateKey @@ -71,22 +73,32 @@ func newUDPTest(t *testing.T) *udpTest { remotekey: newkey(), remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, } - db, _ := enode.OpenDB("") - ln := enode.NewLocalNode(db, test.localkey) + test.db, _ = enode.OpenDB("") + ln := enode.NewLocalNode(test.db, test.localkey) test.table, test.udp, _ = newUDP(test.pipe, ln, Config{PrivateKey: test.localkey}) // Wait for initial refresh so the table doesn't send unexpected findnode. <-test.table.initDone return test } +func (test *udpTest) close() { + test.table.Close() + test.db.Close() +} + // handles a packet as if it had been sent to the transport. func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error { - enc, _, err := encodePacket(test.remotekey, ptype, data) + return test.packetInFrom(wantError, test.remotekey, test.remoteaddr, ptype, data) +} + +// handles a packet as if it had been sent to the transport by the key/endpoint. +func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr *net.UDPAddr, ptype byte, data packet) error { + enc, _, err := encodePacket(key, ptype, data) if err != nil { return test.errorf("packet (%d) encode error: %v", ptype, err) } test.sent = append(test.sent, enc) - if err = test.udp.handlePacket(test.remoteaddr, enc); err != wantError { + if err = test.udp.handlePacket(addr, enc); err != wantError { return test.errorf("error mismatch: got %q, want %q", err, wantError) } return nil @@ -94,19 +106,19 @@ func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error { // waits for a packet to be sent by the transport. // validate should have type func(*udpTest, X) error, where X is a packet type. 
-func (test *udpTest) waitPacketOut(validate interface{}) ([]byte, error) { +func (test *udpTest) waitPacketOut(validate interface{}) (*net.UDPAddr, []byte, error) { dgram := test.pipe.waitPacketOut() - p, _, hash, err := decodePacket(dgram) + p, _, hash, err := decodePacket(dgram.data) if err != nil { - return hash, test.errorf("sent packet decode error: %v", err) + return &dgram.to, hash, test.errorf("sent packet decode error: %v", err) } fn := reflect.ValueOf(validate) exptype := fn.Type().In(0) if reflect.TypeOf(p) != exptype { - return hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) + return &dgram.to, hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) } fn.Call([]reflect.Value{reflect.ValueOf(p)}) - return hash, nil + return &dgram.to, hash, nil } func (test *udpTest) errorf(format string, args ...interface{}) error { @@ -125,7 +137,7 @@ func (test *udpTest) errorf(format string, args ...interface{}) error { func TestUDP_packetErrors(t *testing.T) { test := newUDPTest(t) - defer test.table.Close() + defer test.close() test.packetIn(errExpired, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4}) test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp}) @@ -136,7 +148,7 @@ func TestUDP_packetErrors(t *testing.T) { func TestUDP_pingTimeout(t *testing.T) { t.Parallel() test := newUDPTest(t) - defer test.table.Close() + defer test.close() toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222} toid := enode.ID{1, 2, 3, 4} @@ -148,7 +160,7 @@ func TestUDP_pingTimeout(t *testing.T) { func TestUDP_responseTimeouts(t *testing.T) { t.Parallel() test := newUDPTest(t) - defer test.table.Close() + defer test.close() rand.Seed(time.Now().UnixNano()) randomDuration := func(max time.Duration) time.Duration { @@ -166,20 +178,20 @@ func TestUDP_responseTimeouts(t *testing.T) { // with ptype <= 128 will not get a reply and 
should time out. // For all other requests, a reply is scheduled to arrive // within the timeout window. - p := &pending{ + p := &replyMatcher{ ptype: byte(rand.Intn(255)), - callback: func(interface{}) bool { return true }, + callback: func(interface{}) (bool, bool) { return true, true }, } binary.BigEndian.PutUint64(p.from[:], uint64(i)) if p.ptype <= 128 { p.errc = timeoutErr - test.udp.addpending <- p + test.udp.addReplyMatcher <- p nTimeouts++ } else { p.errc = nilErr - test.udp.addpending <- p + test.udp.addReplyMatcher <- p time.AfterFunc(randomDuration(60*time.Millisecond), func() { - if !test.udp.handleReply(p.from, p.ptype, nil) { + if !test.udp.handleReply(p.from, p.ip, p.ptype, nil) { t.Logf("not matched: %v", p) } }) @@ -220,7 +232,7 @@ func TestUDP_responseTimeouts(t *testing.T) { func TestUDP_findnodeTimeout(t *testing.T) { t.Parallel() test := newUDPTest(t) - defer test.table.Close() + defer test.close() toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222} toid := enode.ID{1, 2, 3, 4} @@ -236,50 +248,65 @@ func TestUDP_findnodeTimeout(t *testing.T) { func TestUDP_findnode(t *testing.T) { test := newUDPTest(t) - defer test.table.Close() + defer test.close() // put a few nodes into the table. their exact // distribution shouldn't matter much, although we need to // take care not to overflow any bucket. nodes := &nodesByDistance{target: testTarget.id()} - for i := 0; i < bucketSize; i++ { + live := make(map[enode.ID]bool) + numCandidates := 2 * bucketSize + for i := 0; i < numCandidates; i++ { key := newkey() - n := wrapNode(enode.NewV4(&key.PublicKey, net.IP{10, 13, 0, 1}, 0, i)) - nodes.push(n, bucketSize) + ip := net.IP{10, 13, 0, byte(i)} + n := wrapNode(enode.NewV4(&key.PublicKey, ip, 0, 2000)) + // Ensure half of table content isn't verified live yet. 
+ if i > numCandidates/2 { + n.livenessChecks = 1 + live[n.ID()] = true + } + nodes.push(n, numCandidates) } - test.table.stuff(nodes.entries) + fillTable(test.table, nodes.entries) // ensure there's a bond with the test node, // findnode won't be accepted otherwise. remoteID := encodePubkey(&test.remotekey.PublicKey).id() - test.table.db.UpdateLastPongReceived(remoteID, time.Now()) + test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.IP, time.Now()) // check that closest neighbors are returned. - test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp}) expected := test.table.closest(testTarget.id(), bucketSize) - + test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp}) waitNeighbors := func(want []*node) { test.waitPacketOut(func(p *neighbors) { if len(p.Nodes) != len(want) { - t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize) + t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), len(want)) } - for i := range p.Nodes { - if p.Nodes[i].ID.id() != want[i].ID() { - t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, p.Nodes[i], expected.entries[i]) + for i, n := range p.Nodes { + if n.ID.id() != want[i].ID() { + t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, n, want[i]) + } + if !live[n.ID.id()] { + t.Errorf("result includes dead node %v", n.ID.id()) } } }) } - waitNeighbors(expected.entries[:maxNeighbors]) - waitNeighbors(expected.entries[maxNeighbors:]) + // Receive replies. 
+ want := expected.entries + if len(want) > maxNeighbors { + waitNeighbors(want[:maxNeighbors]) + want = want[maxNeighbors:] + } + waitNeighbors(want) } func TestUDP_findnodeMultiReply(t *testing.T) { test := newUDPTest(t) - defer test.table.Close() + defer test.close() rid := enode.PubkeyToIDV4(&test.remotekey.PublicKey) - test.table.db.UpdateLastPingReceived(rid, time.Now()) + test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.IP, time.Now()) // queue a pending findnode request resultc, errc := make(chan []*node), make(chan error) @@ -329,11 +356,40 @@ func TestUDP_findnodeMultiReply(t *testing.T) { } } +func TestUDP_pingMatch(t *testing.T) { + test := newUDPTest(t) + defer test.close() + + randToken := make([]byte, 32) + crand.Read(randToken) + + test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) + test.waitPacketOut(func(*pong) error { return nil }) + test.waitPacketOut(func(*ping) error { return nil }) + test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: randToken, To: testLocalAnnounced, Expiration: futureExp}) +} + +func TestUDP_pingMatchIP(t *testing.T) { + test := newUDPTest(t) + defer test.close() + + test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) + test.waitPacketOut(func(*pong) error { return nil }) + + _, hash, _ := test.waitPacketOut(func(*ping) error { return nil }) + wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 1, 2}, Port: 30000} + test.packetInFrom(errUnsolicitedReply, test.remotekey, wrongAddr, pongPacket, &pong{ + ReplyTok: hash, + To: testLocalAnnounced, + Expiration: futureExp, + }) +} + func TestUDP_successfulPing(t *testing.T) { test := newUDPTest(t) added := make(chan *node, 1) test.table.nodeAddedHook = func(n *node) { added <- n } - defer test.table.Close() + defer test.close() // The remote side sends a ping packet to initiate the exchange. 
go test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) @@ -356,7 +412,7 @@ func TestUDP_successfulPing(t *testing.T) { }) // remote is unknown, the table pings back. - hash, _ := test.waitPacketOut(func(p *ping) error { + _, hash, _ := test.waitPacketOut(func(p *ping) error { if !reflect.DeepEqual(p.From, test.udp.ourEndpoint()) { t.Errorf("got ping.From %#v, want %#v", p.From, test.udp.ourEndpoint()) } @@ -510,7 +566,12 @@ type dgramPipe struct { cond *sync.Cond closing chan struct{} closed bool - queue [][]byte + queue []dgram +} + +type dgram struct { + to net.UDPAddr + data []byte } func newpipe() *dgramPipe { @@ -531,7 +592,7 @@ func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) { if c.closed { return 0, errors.New("closed") } - c.queue = append(c.queue, msg) + c.queue = append(c.queue, dgram{*to, b}) c.cond.Signal() return len(b), nil } @@ -556,7 +617,7 @@ func (c *dgramPipe) LocalAddr() net.Addr { return &net.UDPAddr{IP: testLocal.IP, Port: int(testLocal.UDP)} } -func (c *dgramPipe) waitPacketOut() []byte { +func (c *dgramPipe) waitPacketOut() dgram { c.mu.Lock() defer c.mu.Unlock() for len(c.queue) == 0 { diff --git a/p2p/enode/nodedb.go b/p2p/enode/nodedb.go index 7ee0c09a9..9353b155c 100644 --- a/p2p/enode/nodedb.go +++ b/p2p/enode/nodedb.go @@ -21,11 +21,11 @@ import ( "crypto/rand" "encoding/binary" "fmt" + "net" "os" "sync" "time" - "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" "github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb/errors" @@ -37,24 +37,31 @@ import ( // Keys in the node database. 
const ( - dbVersionKey = "version" // Version of the database to flush if changes - dbItemPrefix = "n:" // Identifier to prefix node entries with + dbVersionKey = "version" // Version of the database to flush if changes + dbNodePrefix = "n:" // Identifier to prefix node entries with + dbLocalPrefix = "local:" + dbDiscoverRoot = "v4" - dbDiscoverRoot = ":discover" - dbDiscoverSeq = dbDiscoverRoot + ":seq" - dbDiscoverPing = dbDiscoverRoot + ":lastping" - dbDiscoverPong = dbDiscoverRoot + ":lastpong" - dbDiscoverFindFails = dbDiscoverRoot + ":findfail" - dbLocalRoot = ":local" - dbLocalSeq = dbLocalRoot + ":seq" + // These fields are stored per ID and IP, the full key is "n:<id>:v4:<ip>:findfail". + // Use nodeItemKey to create those keys. + dbNodeFindFails = "findfail" + dbNodePing = "lastping" + dbNodePong = "lastpong" + dbNodeSeq = "seq" + + // Local information is keyed by ID only, the full key is "local:<id>:seq". + // Use localItemKey to create those keys. + dbLocalSeq = "seq" ) -var ( +const ( dbNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped. dbCleanupCycle = time.Hour // Time period for running the expiration task. - dbVersion = 7 + dbVersion = 8 ) +var zeroIP = make(net.IP, 16) + // DB is the node database, storing previously seen nodes and any collected metadata about // them for QoS purposes. type DB struct { @@ -119,27 +126,58 @@ func newPersistentDB(path string) (*DB, error) { return &DB{lvl: db, quit: make(chan struct{})}, nil } -// makeKey generates the leveldb key-blob from a node id and its particular -// field of interest. -func makeKey(id ID, field string) []byte { - if (id == ID{}) { - return []byte(field) - } - return append([]byte(dbItemPrefix), append(id[:], field...)...) +// nodeKey returns the database key for a node record. +func nodeKey(id ID) []byte { + key := append([]byte(dbNodePrefix), id[:]...) + key = append(key, ':') + key = append(key, dbDiscoverRoot...) 
+ return key } -// splitKey tries to split a database key into a node id and a field part. -func splitKey(key []byte) (id ID, field string) { - // If the key is not of a node, return it plainly - if !bytes.HasPrefix(key, []byte(dbItemPrefix)) { - return ID{}, string(key) +// splitNodeKey returns the node ID of a key created by nodeKey. +func splitNodeKey(key []byte) (id ID, rest []byte) { + if !bytes.HasPrefix(key, []byte(dbNodePrefix)) { + return ID{}, nil } - // Otherwise split the id and field - item := key[len(dbItemPrefix):] + item := key[len(dbNodePrefix):] copy(id[:], item[:len(id)]) - field = string(item[len(id):]) + return id, item[len(id)+1:] +} - return id, field +// nodeItemKey returns the database key for a node metadata field. +func nodeItemKey(id ID, ip net.IP, field string) []byte { + ip16 := ip.To16() + if ip16 == nil { + panic(fmt.Errorf("invalid IP (length %d)", len(ip))) + } + return bytes.Join([][]byte{nodeKey(id), ip16, []byte(field)}, []byte{':'}) +} + +// splitNodeItemKey returns the components of a key created by nodeItemKey. +func splitNodeItemKey(key []byte) (id ID, ip net.IP, field string) { + id, key = splitNodeKey(key) + // Skip discover root. + if string(key) == dbDiscoverRoot { + return id, nil, "" + } + key = key[len(dbDiscoverRoot)+1:] + // Split out the IP. + ip = net.IP(key[:16]) + if ip4 := ip.To4(); ip4 != nil { + ip = ip4 + } + key = key[16+1:] + // Field is the remainder of key. + field = string(key) + return id, ip, field +} + +// localItemKey returns the key of a local node item. +func localItemKey(id ID, field string) []byte { + key := append([]byte(dbLocalPrefix), id[:]...) + key = append(key, ':') + key = append(key, field...) + return key } // fetchInt64 retrieves an integer associated with a particular key. @@ -181,7 +219,7 @@ func (db *DB) storeUint64(key []byte, n uint64) error { // Node retrieves a node with a given id from the database. 
func (db *DB) Node(id ID) *Node { - blob, err := db.lvl.Get(makeKey(id, dbDiscoverRoot), nil) + blob, err := db.lvl.Get(nodeKey(id), nil) if err != nil { return nil } @@ -207,15 +245,15 @@ func (db *DB) UpdateNode(node *Node) error { if err != nil { return err } - if err := db.lvl.Put(makeKey(node.ID(), dbDiscoverRoot), blob, nil); err != nil { + if err := db.lvl.Put(nodeKey(node.ID()), blob, nil); err != nil { return err } - return db.storeUint64(makeKey(node.ID(), dbDiscoverSeq), node.Seq()) + return db.storeUint64(nodeItemKey(node.ID(), zeroIP, dbNodeSeq), node.Seq()) } // NodeSeq returns the stored record sequence number of the given node. func (db *DB) NodeSeq(id ID) uint64 { - return db.fetchUint64(makeKey(id, dbDiscoverSeq)) + return db.fetchUint64(nodeItemKey(id, zeroIP, dbNodeSeq)) } // Resolve returns the stored record of the node if it has a larger sequence @@ -227,15 +265,17 @@ func (db *DB) Resolve(n *Node) *Node { return db.Node(n.ID()) } -// DeleteNode deletes all information/keys associated with a node. -func (db *DB) DeleteNode(id ID) error { - deleter := db.lvl.NewIterator(util.BytesPrefix(makeKey(id, "")), nil) - for deleter.Next() { - if err := db.lvl.Delete(deleter.Key(), nil); err != nil { - return err - } +// DeleteNode deletes all information associated with a node. 
+func (db *DB) DeleteNode(id ID) { + deleteRange(db.lvl, nodeKey(id)) +} + +func deleteRange(db *leveldb.DB, prefix []byte) { + it := db.NewIterator(util.BytesPrefix(prefix), nil) + defer it.Release() + for it.Next() { + db.Delete(it.Key(), nil) } - return nil } // ensureExpirer is a small helper method ensuring that the data expiration @@ -259,9 +299,7 @@ func (db *DB) expirer() { for { select { case <-tick.C: - if err := db.expireNodes(); err != nil { - log.Error("Failed to expire nodedb items", "err", err) - } + db.expireNodes() case <-db.quit: return } @@ -269,71 +307,85 @@ func (db *DB) expirer() { } // expireNodes iterates over the database and deletes all nodes that have not -// been seen (i.e. received a pong from) for some allotted time. -func (db *DB) expireNodes() error { - threshold := time.Now().Add(-dbNodeExpiration) - - // Find discovered nodes that are older than the allowance - it := db.lvl.NewIterator(nil, nil) +// been seen (i.e. received a pong from) for some time. +func (db *DB) expireNodes() { + it := db.lvl.NewIterator(util.BytesPrefix([]byte(dbNodePrefix)), nil) defer it.Release() - - for it.Next() { - // Skip the item if not a discovery node - id, field := splitKey(it.Key()) - if field != dbDiscoverRoot { - continue - } - // Skip the node if not expired yet (and not self) - if seen := db.LastPongReceived(id); seen.After(threshold) { - continue - } - // Otherwise delete all associated information - db.DeleteNode(id) + if !it.Next() { + return + } + + var ( + threshold = time.Now().Add(-dbNodeExpiration).Unix() + youngestPong int64 + atEnd = false + ) + for !atEnd { + id, ip, field := splitNodeItemKey(it.Key()) + if field == dbNodePong { + time, _ := binary.Varint(it.Value()) + if time > youngestPong { + youngestPong = time + } + if time < threshold { + // Last pong from this IP older than threshold, remove fields belonging to it. 
+ deleteRange(db.lvl, nodeItemKey(id, ip, "")) + } + } + atEnd = !it.Next() + nextID, _ := splitNodeKey(it.Key()) + if atEnd || nextID != id { + // We've moved beyond the last entry of the current ID. + // Remove everything if there was no recent enough pong. + if youngestPong > 0 && youngestPong < threshold { + deleteRange(db.lvl, nodeKey(id)) + } + youngestPong = 0 + } } - return nil } // LastPingReceived retrieves the time of the last ping packet received from // a remote node. -func (db *DB) LastPingReceived(id ID) time.Time { - return time.Unix(db.fetchInt64(makeKey(id, dbDiscoverPing)), 0) +func (db *DB) LastPingReceived(id ID, ip net.IP) time.Time { + return time.Unix(db.fetchInt64(nodeItemKey(id, ip, dbNodePing)), 0) } // UpdateLastPingReceived updates the last time we tried contacting a remote node. -func (db *DB) UpdateLastPingReceived(id ID, instance time.Time) error { - return db.storeInt64(makeKey(id, dbDiscoverPing), instance.Unix()) +func (db *DB) UpdateLastPingReceived(id ID, ip net.IP, instance time.Time) error { + return db.storeInt64(nodeItemKey(id, ip, dbNodePing), instance.Unix()) } // LastPongReceived retrieves the time of the last successful pong from remote node. -func (db *DB) LastPongReceived(id ID) time.Time { +func (db *DB) LastPongReceived(id ID, ip net.IP) time.Time { // Launch expirer db.ensureExpirer() - return time.Unix(db.fetchInt64(makeKey(id, dbDiscoverPong)), 0) + return time.Unix(db.fetchInt64(nodeItemKey(id, ip, dbNodePong)), 0) } // UpdateLastPongReceived updates the last pong time of a node. -func (db *DB) UpdateLastPongReceived(id ID, instance time.Time) error { - return db.storeInt64(makeKey(id, dbDiscoverPong), instance.Unix()) +func (db *DB) UpdateLastPongReceived(id ID, ip net.IP, instance time.Time) error { + return db.storeInt64(nodeItemKey(id, ip, dbNodePong), instance.Unix()) } // FindFails retrieves the number of findnode failures since bonding. 
-func (db *DB) FindFails(id ID) int { - return int(db.fetchInt64(makeKey(id, dbDiscoverFindFails))) +func (db *DB) FindFails(id ID, ip net.IP) int { + return int(db.fetchInt64(nodeItemKey(id, ip, dbNodeFindFails))) } // UpdateFindFails updates the number of findnode failures since bonding. -func (db *DB) UpdateFindFails(id ID, fails int) error { - return db.storeInt64(makeKey(id, dbDiscoverFindFails), int64(fails)) +func (db *DB) UpdateFindFails(id ID, ip net.IP, fails int) error { + return db.storeInt64(nodeItemKey(id, ip, dbNodeFindFails), int64(fails)) } // LocalSeq retrieves the local record sequence counter. func (db *DB) localSeq(id ID) uint64 { - return db.fetchUint64(makeKey(id, dbLocalSeq)) + return db.fetchUint64(localItemKey(id, dbLocalSeq)) } // storeLocalSeq stores the local record sequence counter. func (db *DB) storeLocalSeq(id ID, n uint64) { - db.storeUint64(makeKey(id, dbLocalSeq), n) + db.storeUint64(localItemKey(id, dbLocalSeq), n) } // QuerySeeds retrieves random nodes to be used as potential seed nodes @@ -355,14 +407,14 @@ seek: ctr := id[0] rand.Read(id[:]) id[0] = ctr + id[0]%16 - it.Seek(makeKey(id, dbDiscoverRoot)) + it.Seek(nodeKey(id)) n := nextNode(it) if n == nil { id[0] = 0 continue seek // iterator exhausted } - if now.Sub(db.LastPongReceived(n.ID())) > maxAge { + if now.Sub(db.LastPongReceived(n.ID(), n.IP())) > maxAge { continue seek } for i := range nodes { @@ -379,8 +431,8 @@ seek: // database entries. 
func nextNode(it iterator.Iterator) *Node { for end := false; !end; end = !it.Next() { - id, field := splitKey(it.Key()) - if field != dbDiscoverRoot { + id, rest := splitNodeKey(it.Key()) + if string(rest) != dbDiscoverRoot { continue } return mustDecodeNode(id[:], it.Value()) diff --git a/p2p/enode/nodedb_test.go b/p2p/enode/nodedb_test.go index 96794827c..341b61a28 100644 --- a/p2p/enode/nodedb_test.go +++ b/p2p/enode/nodedb_test.go @@ -28,42 +28,54 @@ import ( "time" ) -var nodeDBKeyTests = []struct { - id ID - field string - key []byte -}{ - { - id: ID{}, - field: "version", - key: []byte{0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e}, // field - }, - { - id: HexID("51232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), - field: ":discover", - key: []byte{ - 0x6e, 0x3a, // prefix - 0x51, 0x23, 0x2b, 0x8d, 0x78, 0x21, 0x61, 0x7d, // node id - 0x2b, 0x29, 0xb5, 0x4b, 0x81, 0xcd, 0xef, 0xb9, // - 0xb3, 0xe9, 0xc3, 0x7d, 0x7f, 0xd5, 0xf6, 0x32, // - 0x70, 0xbc, 0xc9, 0xe1, 0xa6, 0xf6, 0xa4, 0x39, // - 0x3a, 0x64, 0x69, 0x73, 0x63, 0x6f, 0x76, 0x65, 0x72, // field - }, - }, +var keytestID = HexID("51232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439") + +func TestDBNodeKey(t *testing.T) { + enc := nodeKey(keytestID) + want := []byte{ + 'n', ':', + 0x51, 0x23, 0x2b, 0x8d, 0x78, 0x21, 0x61, 0x7d, // node id + 0x2b, 0x29, 0xb5, 0x4b, 0x81, 0xcd, 0xef, 0xb9, // + 0xb3, 0xe9, 0xc3, 0x7d, 0x7f, 0xd5, 0xf6, 0x32, // + 0x70, 0xbc, 0xc9, 0xe1, 0xa6, 0xf6, 0xa4, 0x39, // + ':', 'v', '4', + } + if !bytes.Equal(enc, want) { + t.Errorf("wrong encoded key:\ngot %q\nwant %q", enc, want) + } + id, _ := splitNodeKey(enc) + if id != keytestID { + t.Errorf("wrong ID from splitNodeKey") + } } -func TestDBKeys(t *testing.T) { - for i, tt := range nodeDBKeyTests { - if key := makeKey(tt.id, tt.field); !bytes.Equal(key, tt.key) { - t.Errorf("make test %d: key mismatch: have 0x%x, want 0x%x", i, key, tt.key) - } - id, field := splitKey(tt.key) - if !bytes.Equal(id[:], 
tt.id[:]) { - t.Errorf("split test %d: id mismatch: have 0x%x, want 0x%x", i, id, tt.id) - } - if field != tt.field { - t.Errorf("split test %d: field mismatch: have 0x%x, want 0x%x", i, field, tt.field) - } +func TestDBNodeItemKey(t *testing.T) { + wantIP := net.IP{127, 0, 0, 3} + wantField := "foobar" + enc := nodeItemKey(keytestID, wantIP, wantField) + want := []byte{ + 'n', ':', + 0x51, 0x23, 0x2b, 0x8d, 0x78, 0x21, 0x61, 0x7d, // node id + 0x2b, 0x29, 0xb5, 0x4b, 0x81, 0xcd, 0xef, 0xb9, // + 0xb3, 0xe9, 0xc3, 0x7d, 0x7f, 0xd5, 0xf6, 0x32, // + 0x70, 0xbc, 0xc9, 0xe1, 0xa6, 0xf6, 0xa4, 0x39, // + ':', 'v', '4', ':', + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // IP + 0x00, 0x00, 0xff, 0xff, 0x7f, 0x00, 0x00, 0x03, // + ':', 'f', 'o', 'o', 'b', 'a', 'r', + } + if !bytes.Equal(enc, want) { + t.Errorf("wrong encoded key:\ngot %q\nwant %q", enc, want) + } + id, ip, field := splitNodeItemKey(enc) + if id != keytestID { + t.Errorf("splitNodeItemKey returned wrong ID: %v", id) + } + if !bytes.Equal(ip, wantIP) { + t.Errorf("splitNodeItemKey returned wrong IP: %v", ip) + } + if field != wantField { + t.Errorf("splitNodeItemKey returned wrong field: %q", field) } } @@ -113,33 +125,33 @@ func TestDBFetchStore(t *testing.T) { defer db.Close() // Check fetch/store operations on a node ping object - if stored := db.LastPingReceived(node.ID()); stored.Unix() != 0 { + if stored := db.LastPingReceived(node.ID(), node.IP()); stored.Unix() != 0 { t.Errorf("ping: non-existing object: %v", stored) } - if err := db.UpdateLastPingReceived(node.ID(), inst); err != nil { + if err := db.UpdateLastPingReceived(node.ID(), node.IP(), inst); err != nil { t.Errorf("ping: failed to update: %v", err) } - if stored := db.LastPingReceived(node.ID()); stored.Unix() != inst.Unix() { + if stored := db.LastPingReceived(node.ID(), node.IP()); stored.Unix() != inst.Unix() { t.Errorf("ping: value mismatch: have %v, want %v", stored, inst) } // Check fetch/store operations on a node pong object - 
if stored := db.LastPongReceived(node.ID()); stored.Unix() != 0 { + if stored := db.LastPongReceived(node.ID(), node.IP()); stored.Unix() != 0 { t.Errorf("pong: non-existing object: %v", stored) } - if err := db.UpdateLastPongReceived(node.ID(), inst); err != nil { + if err := db.UpdateLastPongReceived(node.ID(), node.IP(), inst); err != nil { t.Errorf("pong: failed to update: %v", err) } - if stored := db.LastPongReceived(node.ID()); stored.Unix() != inst.Unix() { + if stored := db.LastPongReceived(node.ID(), node.IP()); stored.Unix() != inst.Unix() { t.Errorf("pong: value mismatch: have %v, want %v", stored, inst) } // Check fetch/store operations on a node findnode-failure object - if stored := db.FindFails(node.ID()); stored != 0 { + if stored := db.FindFails(node.ID(), node.IP()); stored != 0 { t.Errorf("find-node fails: non-existing object: %v", stored) } - if err := db.UpdateFindFails(node.ID(), num); err != nil { + if err := db.UpdateFindFails(node.ID(), node.IP(), num); err != nil { t.Errorf("find-node fails: failed to update: %v", err) } - if stored := db.FindFails(node.ID()); stored != num { + if stored := db.FindFails(node.ID(), node.IP()); stored != num { t.Errorf("find-node fails: value mismatch: have %v, want %v", stored, num) } // Check fetch/store operations on an actual node object @@ -256,7 +268,7 @@ func testSeedQuery() error { if err := db.UpdateNode(seed.node); err != nil { return fmt.Errorf("node %d: failed to insert: %v", i, err) } - if err := db.UpdateLastPongReceived(seed.node.ID(), seed.pong); err != nil { + if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IP(), seed.pong); err != nil { return fmt.Errorf("node %d: failed to insert bondTime: %v", i, err) } } @@ -321,10 +333,12 @@ func TestDBPersistency(t *testing.T) { } var nodeDBExpirationNodes = []struct { - node *Node - pong time.Time - exp bool + node *Node + pong time.Time + storeNode bool + exp bool }{ + // Node has new enough pong time and isn't expired: { node: NewV4( 
hexPubkey("8d110e2ed4b446d9b5fb50f117e5f37fb7597af455e1dab0e6f045a6eeaa786a6781141659020d38bdc5e698ed3d4d2bafa8b5061810dfa63e8ac038db2e9b67"), @@ -332,17 +346,79 @@ var nodeDBExpirationNodes = []struct { 30303, 30303, ), - pong: time.Now().Add(-dbNodeExpiration + time.Minute), - exp: false, - }, { + storeNode: true, + pong: time.Now().Add(-dbNodeExpiration + time.Minute), + exp: false, + }, + // Node with pong time before expiration is removed: + { node: NewV4( hexPubkey("913a205579c32425b220dfba999d215066e5bdbf900226b11da1907eae5e93eb40616d47412cf819664e9eacbdfcca6b0c6e07e09847a38472d4be46ab0c3672"), net.IP{127, 0, 0, 2}, 30303, 30303, ), - pong: time.Now().Add(-dbNodeExpiration - time.Minute), - exp: true, + storeNode: true, + pong: time.Now().Add(-dbNodeExpiration - time.Minute), + exp: true, + }, + // Just pong time, no node stored: + { + node: NewV4( + hexPubkey("b56670e0b6bad2c5dab9f9fe6f061a16cf78d68b6ae2cfda3144262d08d97ce5f46fd8799b6d1f709b1abe718f2863e224488bd7518e5e3b43809ac9bd1138ca"), + net.IP{127, 0, 0, 3}, + 30303, + 30303, + ), + storeNode: false, + pong: time.Now().Add(-dbNodeExpiration - time.Minute), + exp: true, + }, + // Node with multiple pong times, all older than expiration. + { + node: NewV4( + hexPubkey("29f619cebfd32c9eab34aec797ed5e3fe15b9b45be95b4df3f5fe6a9ae892f433eb08d7698b2ef3621568b0fb70d57b515ab30d4e72583b798298e0f0a66b9d1"), + net.IP{127, 0, 0, 4}, + 30303, + 30303, + ), + storeNode: true, + pong: time.Now().Add(-dbNodeExpiration - time.Minute), + exp: true, + }, + { + node: NewV4( + hexPubkey("29f619cebfd32c9eab34aec797ed5e3fe15b9b45be95b4df3f5fe6a9ae892f433eb08d7698b2ef3621568b0fb70d57b515ab30d4e72583b798298e0f0a66b9d1"), + net.IP{127, 0, 0, 5}, + 30303, + 30303, + ), + storeNode: false, + pong: time.Now().Add(-dbNodeExpiration - 2*time.Minute), + exp: true, + }, + // Node with multiple pong times, one newer, one older than expiration. 
+ { + node: NewV4( + hexPubkey("3b73a9e5f4af6c4701c57c73cc8cfa0f4802840b24c11eba92aac3aef65644a3728b4b2aec8199f6d72bd66be2c65861c773129039bd47daa091ca90a6d4c857"), + net.IP{127, 0, 0, 6}, + 30303, + 30303, + ), + storeNode: true, + pong: time.Now().Add(-dbNodeExpiration + time.Minute), + exp: false, + }, + { + node: NewV4( + hexPubkey("3b73a9e5f4af6c4701c57c73cc8cfa0f4802840b24c11eba92aac3aef65644a3728b4b2aec8199f6d72bd66be2c65861c773129039bd47daa091ca90a6d4c857"), + net.IP{127, 0, 0, 7}, + 30303, + 30303, + ), + storeNode: false, + pong: time.Now().Add(-dbNodeExpiration - time.Minute), + exp: true, }, } @@ -350,23 +426,39 @@ func TestDBExpiration(t *testing.T) { db, _ := OpenDB("") defer db.Close() - // Add all the test nodes and set their last pong time + // Add all the test nodes and set their last pong time. for i, seed := range nodeDBExpirationNodes { - if err := db.UpdateNode(seed.node); err != nil { - t.Fatalf("node %d: failed to insert: %v", i, err) + if seed.storeNode { + if err := db.UpdateNode(seed.node); err != nil { + t.Fatalf("node %d: failed to insert: %v", i, err) + } } - if err := db.UpdateLastPongReceived(seed.node.ID(), seed.pong); err != nil { + if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IP(), seed.pong); err != nil { t.Fatalf("node %d: failed to update bondTime: %v", i, err) } } - // Expire some of them, and check the rest - if err := db.expireNodes(); err != nil { - t.Fatalf("failed to expire nodes: %v", err) - } + + db.expireNodes() + + // Check that expired entries have been removed. 
+ unixZeroTime := time.Unix(0, 0) for i, seed := range nodeDBExpirationNodes { node := db.Node(seed.node.ID()) - if (node == nil && !seed.exp) || (node != nil && seed.exp) { - t.Errorf("node %d: expiration mismatch: have %v, want %v", i, node, seed.exp) + pong := db.LastPongReceived(seed.node.ID(), seed.node.IP()) + if seed.exp { + if seed.storeNode && node != nil { + t.Errorf("node %d (%s) shouldn't be present after expiration", i, seed.node.ID().TerminalString()) + } + if !pong.Equal(unixZeroTime) { + t.Errorf("pong time %d (%s %v) shouldn't be present after expiration", i, seed.node.ID().TerminalString(), seed.node.IP()) + } + } else { + if seed.storeNode && node == nil { + t.Errorf("node %d (%s) should be present after expiration", i, seed.node.ID().TerminalString()) + } + if !pong.Equal(seed.pong.Truncate(1 * time.Second)) { + t.Errorf("pong time %d (%s) should be %v after expiration, but is %v", i, seed.node.ID().TerminalString(), seed.pong, pong) + } } } }