Patch for concurrent iterator & others (onto v1.11.6) #386
@@ -512,9 +512,6 @@ func (n *handshakeTestNode) init(key *ecdsa.PrivateKey, ip net.IP, clock mclock.
 	db, _ := enode.OpenDB("")
 	n.ln = enode.NewLocalNode(db, key)
 	n.ln.SetStaticIP(ip)
-	if n.ln.Node().Seq() != 1 {
-		panic(fmt.Errorf("unexpected seq %d", n.ln.Node().Seq()))
-	}
 	n.c = NewCodec(n.ln, key, clock)
 }
 
@@ -36,20 +36,25 @@ const (
 	iptrackMinStatements = 10
 	iptrackWindow        = 5 * time.Minute
 	iptrackContactWindow = 10 * time.Minute
+
+	// time needed to wait between two updates to the local ENR
+	recordUpdateThrottle = time.Millisecond
 )
 
 // LocalNode produces the signed node record of a local node, i.e. a node run in the
 // current process. Setting ENR entries via the Set method updates the record. A new version
 // of the record is signed on demand when the Node method is called.
 type LocalNode struct {
-	cur atomic.Value // holds a non-nil node pointer while the record is up-to-date.
+	cur atomic.Value // holds a non-nil node pointer while the record is up-to-date
+
 	id  ID
 	key *ecdsa.PrivateKey
 	db  *DB
 
 	// everything below is protected by a lock
-	mu        sync.Mutex
+	mu        sync.RWMutex
 	seq       uint64
+	update    time.Time // timestamp when the record was last updated
 	entries   map[string]enr.Entry
 	endpoint4 lnEndpoint
 	endpoint6 lnEndpoint
@@ -76,7 +81,8 @@ func NewLocalNode(db *DB, key *ecdsa.PrivateKey) *LocalNode {
 		},
 	}
 	ln.seq = db.localSeq(ln.id)
-	ln.invalidate()
+	ln.update = time.Now()
+	ln.cur.Store((*Node)(nil))
 	return ln
 }
 
@@ -87,14 +93,34 @@ func (ln *LocalNode) Database() *DB {
 
 // Node returns the current version of the local node record.
 func (ln *LocalNode) Node() *Node {
+	// If we have a valid record, return that
 	n := ln.cur.Load().(*Node)
 	if n != nil {
 		return n
 	}
+
 	// Record was invalidated, sign a new copy.
 	ln.mu.Lock()
 	defer ln.mu.Unlock()
+
+	// Double check the current record, since multiple goroutines might be waiting
+	// on the write mutex.
+	if n = ln.cur.Load().(*Node); n != nil {
+		return n
+	}
+
+	// The initial sequence number is the current timestamp in milliseconds. To ensure
+	// that the initial sequence number will always be higher than any previous sequence
+	// number (assuming the clock is correct), we want to avoid updating the record faster
+	// than once per ms. So we need to sleep here until the next possible update time has
+	// arrived.
+	lastChange := time.Since(ln.update)
+	if lastChange < recordUpdateThrottle {
+		time.Sleep(recordUpdateThrottle - lastChange)
+	}
+
 	ln.sign()
+	ln.update = time.Now()
 	return ln.cur.Load().(*Node)
 }
 
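For context, the Node() hunk above combines double-checked locking with a sleep-based throttle: the fast path reads the cached record through the atomic value, only the goroutine that wins the write lock re-signs the record, and consecutive re-signs are kept at least one millisecond apart. Below is a minimal standalone sketch of the same pattern; the names (cachedRecord, rebuildThrottle, Get) are hypothetical and this is not code from the patch:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

const rebuildThrottle = time.Millisecond

// cachedRecord is a toy stand-in for LocalNode: an expensive-to-rebuild value
// cached in an atomic.Value and rebuilt at most once per millisecond.
type cachedRecord struct {
	cur    atomic.Value // holds a *string; nil (or unset) while invalidated
	mu     sync.Mutex
	seq    uint64
	update time.Time
}

func (c *cachedRecord) Get() string {
	// Fast path: return the cached value if it is still valid.
	if v, _ := c.cur.Load().(*string); v != nil {
		return *v
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	// Double check: another goroutine may have rebuilt the value while we
	// were waiting for the lock.
	if v, _ := c.cur.Load().(*string); v != nil {
		return *v
	}

	// Throttle rebuilds so consecutive versions are at least 1ms apart.
	if since := time.Since(c.update); since < rebuildThrottle {
		time.Sleep(rebuildThrottle - since)
	}

	c.seq++
	s := fmt.Sprintf("record seq=%d", c.seq)
	c.cur.Store(&s)
	c.update = time.Now()
	return s
}

func main() {
	var c cachedRecord
	fmt.Println(c.Get()) // builds the record once
	fmt.Println(c.Get()) // served from the cache
}
```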
@@ -114,6 +140,10 @@ func (ln *LocalNode) ID() ID {
 // Set puts the given entry into the local record, overwriting any existing value.
 // Use Set*IP and SetFallbackUDP to set IP addresses and UDP port, otherwise they'll
 // be overwritten by the endpoint predictor.
+//
+// Since node record updates are throttled to one per second, Set is asynchronous.
+// Any update will be queued up and published when at least one second passes from
+// the last change.
 func (ln *LocalNode) Set(e enr.Entry) {
 	ln.mu.Lock()
 	defer ln.mu.Unlock()
@@ -288,3 +318,12 @@ func (ln *LocalNode) bumpSeq() {
 	ln.seq++
 	ln.db.storeLocalSeq(ln.id, ln.seq)
 }
+
+// nowMilliseconds gives the current timestamp at millisecond precision.
+func nowMilliseconds() uint64 {
+	ns := time.Now().UnixNano()
+	if ns < 0 {
+		return 0
+	}
+	return uint64(ns / 1000 / 1000)
+}
@@ -49,32 +49,39 @@ func TestLocalNode(t *testing.T) {
 	}
 }
 
+// This test checks that the sequence number is persisted between restarts.
 func TestLocalNodeSeqPersist(t *testing.T) {
+	timestamp := nowMilliseconds()
+
 	ln, db := newLocalNodeForTesting()
 	defer db.Close()
 
-	if s := ln.Node().Seq(); s != 1 {
-		t.Fatalf("wrong initial seq %d, want 1", s)
+	initialSeq := ln.Node().Seq()
+	if initialSeq < timestamp {
+		t.Fatalf("wrong initial seq %d, want at least %d", initialSeq, timestamp)
 	}
+
 	ln.Set(enr.WithEntry("x", uint(1)))
-	if s := ln.Node().Seq(); s != 2 {
-		t.Fatalf("wrong seq %d after set, want 2", s)
+	if s := ln.Node().Seq(); s != initialSeq+1 {
+		t.Fatalf("wrong seq %d after set, want %d", s, initialSeq+1)
 	}
 
 	// Create a new instance, it should reload the sequence number.
 	// The number increases just after that because a new record is
 	// created without the "x" entry.
 	ln2 := NewLocalNode(db, ln.key)
-	if s := ln2.Node().Seq(); s != 3 {
-		t.Fatalf("wrong seq %d on new instance, want 3", s)
+	if s := ln2.Node().Seq(); s != initialSeq+2 {
+		t.Fatalf("wrong seq %d on new instance, want %d", s, initialSeq+2)
 	}
 
+	finalSeq := ln2.Node().Seq()
+
 	// Create a new instance with a different node key on the same database.
 	// This should reset the sequence number.
 	key, _ := crypto.GenerateKey()
 	ln3 := NewLocalNode(db, key)
-	if s := ln3.Node().Seq(); s != 1 {
-		t.Fatalf("wrong seq %d on instance with changed key, want 1", s)
+	if s := ln3.Node().Seq(); s < finalSeq {
+		t.Fatalf("wrong seq %d on instance with changed key, want >= %d", s, finalSeq)
 	}
 }
 
@@ -91,20 +98,20 @@ func TestLocalNodeEndpoint(t *testing.T) {
 	// Nothing is set initially.
 	assert.Equal(t, net.IP(nil), ln.Node().IP())
 	assert.Equal(t, 0, ln.Node().UDP())
-	assert.Equal(t, uint64(1), ln.Node().Seq())
+	initialSeq := ln.Node().Seq()
 
 	// Set up fallback address.
 	ln.SetFallbackIP(fallback.IP)
 	ln.SetFallbackUDP(fallback.Port)
 	assert.Equal(t, fallback.IP, ln.Node().IP())
 	assert.Equal(t, fallback.Port, ln.Node().UDP())
-	assert.Equal(t, uint64(2), ln.Node().Seq())
+	assert.Equal(t, initialSeq+1, ln.Node().Seq())
 
 	// Add endpoint statements from random hosts.
 	for i := 0; i < iptrackMinStatements; i++ {
 		assert.Equal(t, fallback.IP, ln.Node().IP())
 		assert.Equal(t, fallback.Port, ln.Node().UDP())
-		assert.Equal(t, uint64(2), ln.Node().Seq())
+		assert.Equal(t, initialSeq+1, ln.Node().Seq())
 
 		from := &net.UDPAddr{IP: make(net.IP, 4), Port: 90}
 		rand.Read(from.IP)
@@ -112,11 +119,11 @@ func TestLocalNodeEndpoint(t *testing.T) {
 	}
 	assert.Equal(t, predicted.IP, ln.Node().IP())
 	assert.Equal(t, predicted.Port, ln.Node().UDP())
-	assert.Equal(t, uint64(3), ln.Node().Seq())
+	assert.Equal(t, initialSeq+2, ln.Node().Seq())
 
 	// Static IP overrides prediction.
 	ln.SetStaticIP(staticIP)
 	assert.Equal(t, staticIP, ln.Node().IP())
 	assert.Equal(t, fallback.Port, ln.Node().UDP())
-	assert.Equal(t, uint64(4), ln.Node().Seq())
+	assert.Equal(t, initialSeq+3, ln.Node().Seq())
 }
@@ -427,9 +427,14 @@ func (db *DB) UpdateFindFailsV5(id ID, ip net.IP, fails int) error {
 	return db.storeInt64(v5Key(id, ip, dbNodeFindFails), int64(fails))
 }
 
-// LocalSeq retrieves the local record sequence counter.
+// localSeq retrieves the local record sequence counter, defaulting to the current
+// timestamp if no previous exists. This ensures that wiping all data associated
+// with a node (apart from its key) will not generate already used sequence nums.
 func (db *DB) localSeq(id ID) uint64 {
-	return db.fetchUint64(localItemKey(id, dbLocalSeq))
+	if seq := db.fetchUint64(localItemKey(id, dbLocalSeq)); seq > 0 {
+		return seq
+	}
+	return nowMilliseconds()
 }
 
 // storeLocalSeq stores the local record sequence counter.