Patch for concurrent iterator & others (onto v1.11.6) #386

Closed
roysc wants to merge 1565 commits from v1.11.6-statediff-v5 into master
4 changed files with 69 additions and 21 deletions
Showing only changes of commit 6ef3a16869 - Show all commits

View File

@ -512,9 +512,6 @@ func (n *handshakeTestNode) init(key *ecdsa.PrivateKey, ip net.IP, clock mclock.
db, _ := enode.OpenDB("") db, _ := enode.OpenDB("")
n.ln = enode.NewLocalNode(db, key) n.ln = enode.NewLocalNode(db, key)
n.ln.SetStaticIP(ip) n.ln.SetStaticIP(ip)
if n.ln.Node().Seq() != 1 {
panic(fmt.Errorf("unexpected seq %d", n.ln.Node().Seq()))
}
n.c = NewCodec(n.ln, key, clock) n.c = NewCodec(n.ln, key, clock)
} }

View File

@ -36,20 +36,25 @@ const (
iptrackMinStatements = 10 iptrackMinStatements = 10
iptrackWindow = 5 * time.Minute iptrackWindow = 5 * time.Minute
iptrackContactWindow = 10 * time.Minute iptrackContactWindow = 10 * time.Minute
// time needed to wait between two updates to the local ENR
recordUpdateThrottle = time.Millisecond
) )
// LocalNode produces the signed node record of a local node, i.e. a node run in the // LocalNode produces the signed node record of a local node, i.e. a node run in the
// current process. Setting ENR entries via the Set method updates the record. A new version // current process. Setting ENR entries via the Set method updates the record. A new version
// of the record is signed on demand when the Node method is called. // of the record is signed on demand when the Node method is called.
type LocalNode struct { type LocalNode struct {
cur atomic.Value // holds a non-nil node pointer while the record is up-to-date. cur atomic.Value // holds a non-nil node pointer while the record is up-to-date
id ID id ID
key *ecdsa.PrivateKey key *ecdsa.PrivateKey
db *DB db *DB
// everything below is protected by a lock // everything below is protected by a lock
mu sync.Mutex mu sync.RWMutex
seq uint64 seq uint64
update time.Time // timestamp when the record was last updated
entries map[string]enr.Entry entries map[string]enr.Entry
endpoint4 lnEndpoint endpoint4 lnEndpoint
endpoint6 lnEndpoint endpoint6 lnEndpoint
@ -76,7 +81,8 @@ func NewLocalNode(db *DB, key *ecdsa.PrivateKey) *LocalNode {
}, },
} }
ln.seq = db.localSeq(ln.id) ln.seq = db.localSeq(ln.id)
ln.invalidate() ln.update = time.Now()
ln.cur.Store((*Node)(nil))
return ln return ln
} }
@ -87,14 +93,34 @@ func (ln *LocalNode) Database() *DB {
// Node returns the current version of the local node record. // Node returns the current version of the local node record.
func (ln *LocalNode) Node() *Node { func (ln *LocalNode) Node() *Node {
// If we have a valid record, return that
n := ln.cur.Load().(*Node) n := ln.cur.Load().(*Node)
if n != nil { if n != nil {
return n return n
} }
// Record was invalidated, sign a new copy. // Record was invalidated, sign a new copy.
ln.mu.Lock() ln.mu.Lock()
defer ln.mu.Unlock() defer ln.mu.Unlock()
// Double check the current record, since multiple goroutines might be waiting
// on the write mutex.
if n = ln.cur.Load().(*Node); n != nil {
return n
}
// The initial sequence number is the current timestamp in milliseconds. To ensure
// that the initial sequence number will always be higher than any previous sequence
// number (assuming the clock is correct), we want to avoid updating the record faster
// than once per ms. So we need to sleep here until the next possible update time has
// arrived.
lastChange := time.Since(ln.update)
if lastChange < recordUpdateThrottle {
time.Sleep(recordUpdateThrottle - lastChange)
}
ln.sign() ln.sign()
ln.update = time.Now()
return ln.cur.Load().(*Node) return ln.cur.Load().(*Node)
} }
@ -114,6 +140,10 @@ func (ln *LocalNode) ID() ID {
// Set puts the given entry into the local record, overwriting any existing value. // Set puts the given entry into the local record, overwriting any existing value.
// Use Set*IP and SetFallbackUDP to set IP addresses and UDP port, otherwise they'll // Use Set*IP and SetFallbackUDP to set IP addresses and UDP port, otherwise they'll
// be overwritten by the endpoint predictor. // be overwritten by the endpoint predictor.
//
// Since node record updates are throttled to one per second, Set is asynchronous.
// Any update will be queued up and published when at least one second passes from
// the last change.
func (ln *LocalNode) Set(e enr.Entry) { func (ln *LocalNode) Set(e enr.Entry) {
ln.mu.Lock() ln.mu.Lock()
defer ln.mu.Unlock() defer ln.mu.Unlock()
@ -288,3 +318,12 @@ func (ln *LocalNode) bumpSeq() {
ln.seq++ ln.seq++
ln.db.storeLocalSeq(ln.id, ln.seq) ln.db.storeLocalSeq(ln.id, ln.seq)
} }
// nowMilliseconds gives the current timestamp at millisecond precision.
func nowMilliseconds() uint64 {
ns := time.Now().UnixNano()
if ns < 0 {
return 0
}
return uint64(ns / 1000 / 1000)
}

View File

@ -49,32 +49,39 @@ func TestLocalNode(t *testing.T) {
} }
} }
// This test checks that the sequence number is persisted between restarts.
func TestLocalNodeSeqPersist(t *testing.T) { func TestLocalNodeSeqPersist(t *testing.T) {
timestamp := nowMilliseconds()
ln, db := newLocalNodeForTesting() ln, db := newLocalNodeForTesting()
defer db.Close() defer db.Close()
if s := ln.Node().Seq(); s != 1 { initialSeq := ln.Node().Seq()
t.Fatalf("wrong initial seq %d, want 1", s) if initialSeq < timestamp {
t.Fatalf("wrong initial seq %d, want at least %d", initialSeq, timestamp)
} }
ln.Set(enr.WithEntry("x", uint(1))) ln.Set(enr.WithEntry("x", uint(1)))
if s := ln.Node().Seq(); s != 2 { if s := ln.Node().Seq(); s != initialSeq+1 {
t.Fatalf("wrong seq %d after set, want 2", s) t.Fatalf("wrong seq %d after set, want %d", s, initialSeq+1)
} }
// Create a new instance, it should reload the sequence number. // Create a new instance, it should reload the sequence number.
// The number increases just after that because a new record is // The number increases just after that because a new record is
// created without the "x" entry. // created without the "x" entry.
ln2 := NewLocalNode(db, ln.key) ln2 := NewLocalNode(db, ln.key)
if s := ln2.Node().Seq(); s != 3 { if s := ln2.Node().Seq(); s != initialSeq+2 {
t.Fatalf("wrong seq %d on new instance, want 3", s) t.Fatalf("wrong seq %d on new instance, want %d", s, initialSeq+2)
} }
finalSeq := ln2.Node().Seq()
// Create a new instance with a different node key on the same database. // Create a new instance with a different node key on the same database.
// This should reset the sequence number. // This should reset the sequence number.
key, _ := crypto.GenerateKey() key, _ := crypto.GenerateKey()
ln3 := NewLocalNode(db, key) ln3 := NewLocalNode(db, key)
if s := ln3.Node().Seq(); s != 1 { if s := ln3.Node().Seq(); s < finalSeq {
t.Fatalf("wrong seq %d on instance with changed key, want 1", s) t.Fatalf("wrong seq %d on instance with changed key, want >= %d", s, finalSeq)
} }
} }
@ -91,20 +98,20 @@ func TestLocalNodeEndpoint(t *testing.T) {
// Nothing is set initially. // Nothing is set initially.
assert.Equal(t, net.IP(nil), ln.Node().IP()) assert.Equal(t, net.IP(nil), ln.Node().IP())
assert.Equal(t, 0, ln.Node().UDP()) assert.Equal(t, 0, ln.Node().UDP())
assert.Equal(t, uint64(1), ln.Node().Seq()) initialSeq := ln.Node().Seq()
// Set up fallback address. // Set up fallback address.
ln.SetFallbackIP(fallback.IP) ln.SetFallbackIP(fallback.IP)
ln.SetFallbackUDP(fallback.Port) ln.SetFallbackUDP(fallback.Port)
assert.Equal(t, fallback.IP, ln.Node().IP()) assert.Equal(t, fallback.IP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP()) assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, uint64(2), ln.Node().Seq()) assert.Equal(t, initialSeq+1, ln.Node().Seq())
// Add endpoint statements from random hosts. // Add endpoint statements from random hosts.
for i := 0; i < iptrackMinStatements; i++ { for i := 0; i < iptrackMinStatements; i++ {
assert.Equal(t, fallback.IP, ln.Node().IP()) assert.Equal(t, fallback.IP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP()) assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, uint64(2), ln.Node().Seq()) assert.Equal(t, initialSeq+1, ln.Node().Seq())
from := &net.UDPAddr{IP: make(net.IP, 4), Port: 90} from := &net.UDPAddr{IP: make(net.IP, 4), Port: 90}
rand.Read(from.IP) rand.Read(from.IP)
@ -112,11 +119,11 @@ func TestLocalNodeEndpoint(t *testing.T) {
} }
assert.Equal(t, predicted.IP, ln.Node().IP()) assert.Equal(t, predicted.IP, ln.Node().IP())
assert.Equal(t, predicted.Port, ln.Node().UDP()) assert.Equal(t, predicted.Port, ln.Node().UDP())
assert.Equal(t, uint64(3), ln.Node().Seq()) assert.Equal(t, initialSeq+2, ln.Node().Seq())
// Static IP overrides prediction. // Static IP overrides prediction.
ln.SetStaticIP(staticIP) ln.SetStaticIP(staticIP)
assert.Equal(t, staticIP, ln.Node().IP()) assert.Equal(t, staticIP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP()) assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, uint64(4), ln.Node().Seq()) assert.Equal(t, initialSeq+3, ln.Node().Seq())
} }

View File

@ -427,9 +427,14 @@ func (db *DB) UpdateFindFailsV5(id ID, ip net.IP, fails int) error {
return db.storeInt64(v5Key(id, ip, dbNodeFindFails), int64(fails)) return db.storeInt64(v5Key(id, ip, dbNodeFindFails), int64(fails))
} }
// LocalSeq retrieves the local record sequence counter. // localSeq retrieves the local record sequence counter, defaulting to the current
// timestamp if no previous exists. This ensures that wiping all data associated
// with a node (apart from its key) will not generate already used sequence nums.
func (db *DB) localSeq(id ID) uint64 { func (db *DB) localSeq(id ID) uint64 {
return db.fetchUint64(localItemKey(id, dbLocalSeq)) if seq := db.fetchUint64(localItemKey(id, dbLocalSeq)); seq > 0 {
return seq
}
return nowMilliseconds()
} }
// storeLocalSeq stores the local record sequence counter. // storeLocalSeq stores the local record sequence counter.