p2p/enode: use unix timestamp as base ENR sequence number (#19903)

This PR ensures that wiping all data associated with a node (apart from its nodekey)
will not generate already used sequence number for the ENRs, since all remote nodes
would reject them until they out-number the previously published largest one.

The big complication with this scheme is that every local update to the ENR can
potentially bump the sequence number by one. In order to ensure that local updates
do not outrun the clock, the sequence number is a millisecond-precision timestamp,
and updates are throttled to occur at most once per millisecond.

Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
Péter Szilágyi 2021-09-07 13:36:48 +03:00 committed by GitHub
parent 794c6133ef
commit 6ef3a16869
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 69 additions and 21 deletions

View File

@ -512,9 +512,6 @@ func (n *handshakeTestNode) init(key *ecdsa.PrivateKey, ip net.IP, clock mclock.
db, _ := enode.OpenDB("") db, _ := enode.OpenDB("")
n.ln = enode.NewLocalNode(db, key) n.ln = enode.NewLocalNode(db, key)
n.ln.SetStaticIP(ip) n.ln.SetStaticIP(ip)
if n.ln.Node().Seq() != 1 {
panic(fmt.Errorf("unexpected seq %d", n.ln.Node().Seq()))
}
n.c = NewCodec(n.ln, key, clock) n.c = NewCodec(n.ln, key, clock)
} }

View File

@ -36,20 +36,25 @@ const (
iptrackMinStatements = 10 iptrackMinStatements = 10
iptrackWindow = 5 * time.Minute iptrackWindow = 5 * time.Minute
iptrackContactWindow = 10 * time.Minute iptrackContactWindow = 10 * time.Minute
// time needed to wait between two updates to the local ENR
recordUpdateThrottle = time.Millisecond
) )
// LocalNode produces the signed node record of a local node, i.e. a node run in the // LocalNode produces the signed node record of a local node, i.e. a node run in the
// current process. Setting ENR entries via the Set method updates the record. A new version // current process. Setting ENR entries via the Set method updates the record. A new version
// of the record is signed on demand when the Node method is called. // of the record is signed on demand when the Node method is called.
type LocalNode struct { type LocalNode struct {
cur atomic.Value // holds a non-nil node pointer while the record is up-to-date. cur atomic.Value // holds a non-nil node pointer while the record is up-to-date
id ID id ID
key *ecdsa.PrivateKey key *ecdsa.PrivateKey
db *DB db *DB
// everything below is protected by a lock // everything below is protected by a lock
mu sync.Mutex mu sync.RWMutex
seq uint64 seq uint64
update time.Time // timestamp when the record was last updated
entries map[string]enr.Entry entries map[string]enr.Entry
endpoint4 lnEndpoint endpoint4 lnEndpoint
endpoint6 lnEndpoint endpoint6 lnEndpoint
@ -76,7 +81,8 @@ func NewLocalNode(db *DB, key *ecdsa.PrivateKey) *LocalNode {
}, },
} }
ln.seq = db.localSeq(ln.id) ln.seq = db.localSeq(ln.id)
ln.invalidate() ln.update = time.Now()
ln.cur.Store((*Node)(nil))
return ln return ln
} }
@ -87,14 +93,34 @@ func (ln *LocalNode) Database() *DB {
// Node returns the current version of the local node record. // Node returns the current version of the local node record.
func (ln *LocalNode) Node() *Node { func (ln *LocalNode) Node() *Node {
// If we have a valid record, return that
n := ln.cur.Load().(*Node) n := ln.cur.Load().(*Node)
if n != nil { if n != nil {
return n return n
} }
// Record was invalidated, sign a new copy. // Record was invalidated, sign a new copy.
ln.mu.Lock() ln.mu.Lock()
defer ln.mu.Unlock() defer ln.mu.Unlock()
// Double check the current record, since multiple goroutines might be waiting
// on the write mutex.
if n = ln.cur.Load().(*Node); n != nil {
return n
}
// The initial sequence number is the current timestamp in milliseconds. To ensure
// that the initial sequence number will always be higher than any previous sequence
// number (assuming the clock is correct), we want to avoid updating the record faster
// than once per ms. So we need to sleep here until the next possible update time has
// arrived.
lastChange := time.Since(ln.update)
if lastChange < recordUpdateThrottle {
time.Sleep(recordUpdateThrottle - lastChange)
}
ln.sign() ln.sign()
ln.update = time.Now()
return ln.cur.Load().(*Node) return ln.cur.Load().(*Node)
} }
@ -114,6 +140,10 @@ func (ln *LocalNode) ID() ID {
// Set puts the given entry into the local record, overwriting any existing value. // Set puts the given entry into the local record, overwriting any existing value.
// Use Set*IP and SetFallbackUDP to set IP addresses and UDP port, otherwise they'll // Use Set*IP and SetFallbackUDP to set IP addresses and UDP port, otherwise they'll
// be overwritten by the endpoint predictor. // be overwritten by the endpoint predictor.
//
// Since node record updates are throttled to one per second, Set is asynchronous.
// Any update will be queued up and published when at least one second passes from
// the last change.
func (ln *LocalNode) Set(e enr.Entry) { func (ln *LocalNode) Set(e enr.Entry) {
ln.mu.Lock() ln.mu.Lock()
defer ln.mu.Unlock() defer ln.mu.Unlock()
@ -288,3 +318,12 @@ func (ln *LocalNode) bumpSeq() {
ln.seq++ ln.seq++
ln.db.storeLocalSeq(ln.id, ln.seq) ln.db.storeLocalSeq(ln.id, ln.seq)
} }
// nowMilliseconds gives the current timestamp at millisecond precision.
func nowMilliseconds() uint64 {
ns := time.Now().UnixNano()
if ns < 0 {
return 0
}
return uint64(ns / 1000 / 1000)
}

View File

@ -49,32 +49,39 @@ func TestLocalNode(t *testing.T) {
} }
} }
// This test checks that the sequence number is persisted between restarts.
func TestLocalNodeSeqPersist(t *testing.T) { func TestLocalNodeSeqPersist(t *testing.T) {
timestamp := nowMilliseconds()
ln, db := newLocalNodeForTesting() ln, db := newLocalNodeForTesting()
defer db.Close() defer db.Close()
if s := ln.Node().Seq(); s != 1 { initialSeq := ln.Node().Seq()
t.Fatalf("wrong initial seq %d, want 1", s) if initialSeq < timestamp {
t.Fatalf("wrong initial seq %d, want at least %d", initialSeq, timestamp)
} }
ln.Set(enr.WithEntry("x", uint(1))) ln.Set(enr.WithEntry("x", uint(1)))
if s := ln.Node().Seq(); s != 2 { if s := ln.Node().Seq(); s != initialSeq+1 {
t.Fatalf("wrong seq %d after set, want 2", s) t.Fatalf("wrong seq %d after set, want %d", s, initialSeq+1)
} }
// Create a new instance, it should reload the sequence number. // Create a new instance, it should reload the sequence number.
// The number increases just after that because a new record is // The number increases just after that because a new record is
// created without the "x" entry. // created without the "x" entry.
ln2 := NewLocalNode(db, ln.key) ln2 := NewLocalNode(db, ln.key)
if s := ln2.Node().Seq(); s != 3 { if s := ln2.Node().Seq(); s != initialSeq+2 {
t.Fatalf("wrong seq %d on new instance, want 3", s) t.Fatalf("wrong seq %d on new instance, want %d", s, initialSeq+2)
} }
finalSeq := ln2.Node().Seq()
// Create a new instance with a different node key on the same database. // Create a new instance with a different node key on the same database.
// This should reset the sequence number. // This should reset the sequence number.
key, _ := crypto.GenerateKey() key, _ := crypto.GenerateKey()
ln3 := NewLocalNode(db, key) ln3 := NewLocalNode(db, key)
if s := ln3.Node().Seq(); s != 1 { if s := ln3.Node().Seq(); s < finalSeq {
t.Fatalf("wrong seq %d on instance with changed key, want 1", s) t.Fatalf("wrong seq %d on instance with changed key, want >= %d", s, finalSeq)
} }
} }
@ -91,20 +98,20 @@ func TestLocalNodeEndpoint(t *testing.T) {
// Nothing is set initially. // Nothing is set initially.
assert.Equal(t, net.IP(nil), ln.Node().IP()) assert.Equal(t, net.IP(nil), ln.Node().IP())
assert.Equal(t, 0, ln.Node().UDP()) assert.Equal(t, 0, ln.Node().UDP())
assert.Equal(t, uint64(1), ln.Node().Seq()) initialSeq := ln.Node().Seq()
// Set up fallback address. // Set up fallback address.
ln.SetFallbackIP(fallback.IP) ln.SetFallbackIP(fallback.IP)
ln.SetFallbackUDP(fallback.Port) ln.SetFallbackUDP(fallback.Port)
assert.Equal(t, fallback.IP, ln.Node().IP()) assert.Equal(t, fallback.IP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP()) assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, uint64(2), ln.Node().Seq()) assert.Equal(t, initialSeq+1, ln.Node().Seq())
// Add endpoint statements from random hosts. // Add endpoint statements from random hosts.
for i := 0; i < iptrackMinStatements; i++ { for i := 0; i < iptrackMinStatements; i++ {
assert.Equal(t, fallback.IP, ln.Node().IP()) assert.Equal(t, fallback.IP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP()) assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, uint64(2), ln.Node().Seq()) assert.Equal(t, initialSeq+1, ln.Node().Seq())
from := &net.UDPAddr{IP: make(net.IP, 4), Port: 90} from := &net.UDPAddr{IP: make(net.IP, 4), Port: 90}
rand.Read(from.IP) rand.Read(from.IP)
@ -112,11 +119,11 @@ func TestLocalNodeEndpoint(t *testing.T) {
} }
assert.Equal(t, predicted.IP, ln.Node().IP()) assert.Equal(t, predicted.IP, ln.Node().IP())
assert.Equal(t, predicted.Port, ln.Node().UDP()) assert.Equal(t, predicted.Port, ln.Node().UDP())
assert.Equal(t, uint64(3), ln.Node().Seq()) assert.Equal(t, initialSeq+2, ln.Node().Seq())
// Static IP overrides prediction. // Static IP overrides prediction.
ln.SetStaticIP(staticIP) ln.SetStaticIP(staticIP)
assert.Equal(t, staticIP, ln.Node().IP()) assert.Equal(t, staticIP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP()) assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, uint64(4), ln.Node().Seq()) assert.Equal(t, initialSeq+3, ln.Node().Seq())
} }

View File

@ -427,9 +427,14 @@ func (db *DB) UpdateFindFailsV5(id ID, ip net.IP, fails int) error {
return db.storeInt64(v5Key(id, ip, dbNodeFindFails), int64(fails)) return db.storeInt64(v5Key(id, ip, dbNodeFindFails), int64(fails))
} }
// LocalSeq retrieves the local record sequence counter. // localSeq retrieves the local record sequence counter, defaulting to the current
// timestamp if no previous exists. This ensures that wiping all data associated
// with a node (apart from its key) will not generate already used sequence nums.
func (db *DB) localSeq(id ID) uint64 { func (db *DB) localSeq(id ID) uint64 {
return db.fetchUint64(localItemKey(id, dbLocalSeq)) if seq := db.fetchUint64(localItemKey(id, dbLocalSeq)); seq > 0 {
return seq
}
return nowMilliseconds()
} }
// storeLocalSeq stores the local record sequence counter. // storeLocalSeq stores the local record sequence counter.