Merge pull request #3153 from fjl/trie-unload-fix

trie: improve cache unloading mechanism
This commit is contained in:
Jeffrey Wilcke 2016-10-19 13:35:49 +02:00 committed by GitHub
commit 25ac04a444
7 changed files with 130 additions and 70 deletions

View File

@ -39,12 +39,12 @@ import (
var StartingNonce uint64 var StartingNonce uint64
const ( const (
// Number of past tries to keep. The arbitrarily chosen value here // Number of past tries to keep. This value is chosen such that
// is max uncle depth + 1. // reasonable chain reorg depths will hit an existing trie.
maxPastTries = 8 maxPastTries = 12
// Trie cache generation limit. // Trie cache generation limit.
maxTrieCacheGen = 100 maxTrieCacheGen = 120
// Number of codehash->size associations to keep. // Number of codehash->size associations to keep.
codeSizeCacheSize = 100000 codeSizeCacheSize = 100000

View File

@ -58,7 +58,7 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error)
return hash, n, nil return hash, n, nil
} }
if n.canUnload(h.cachegen, h.cachelimit) { if n.canUnload(h.cachegen, h.cachelimit) {
// Evict the node from cache. All of its subnodes will have a lower or equal // Unload the node from cache. All of its subnodes will have a lower or equal
// cache generation number. // cache generation number.
return hash, hash, nil return hash, hash, nil
} }
@ -75,23 +75,20 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error)
if err != nil { if err != nil {
return hashNode{}, n, err return hashNode{}, n, err
} }
// Cache the hash of the ndoe for later reuse. // Cache the hash of the ndoe for later reuse and remove
if hash, ok := hashed.(hashNode); ok && !force { // the dirty flag in commit mode. It's fine to assign these values directly
switch cached := cached.(type) { // without copying the node first because hashChildren copies it.
case *shortNode: cachedHash, _ := hashed.(hashNode)
cached = cached.copy() switch cn := cached.(type) {
cached.flags.hash = hash case *shortNode:
if db != nil { cn.flags.hash = cachedHash
cached.flags.dirty = false if db != nil {
} cn.flags.dirty = false
return hashed, cached, nil }
case *fullNode: case *fullNode:
cached = cached.copy() cn.flags.hash = cachedHash
cached.flags.hash = hash if db != nil {
if db != nil { cn.flags.dirty = false
cached.flags.dirty = false
}
return hashed, cached, nil
} }
} }
return hashed, cached, nil return hashed, cached, nil

View File

@ -104,8 +104,8 @@ func (n valueNode) fstring(ind string) string {
return fmt.Sprintf("%x ", []byte(n)) return fmt.Sprintf("%x ", []byte(n))
} }
func mustDecodeNode(hash, buf []byte) node { func mustDecodeNode(hash, buf []byte, cachegen uint16) node {
n, err := decodeNode(hash, buf) n, err := decodeNode(hash, buf, cachegen)
if err != nil { if err != nil {
panic(fmt.Sprintf("node %x: %v", hash, err)) panic(fmt.Sprintf("node %x: %v", hash, err))
} }
@ -113,7 +113,7 @@ func mustDecodeNode(hash, buf []byte) node {
} }
// decodeNode parses the RLP encoding of a trie node. // decodeNode parses the RLP encoding of a trie node.
func decodeNode(hash, buf []byte) (node, error) { func decodeNode(hash, buf []byte, cachegen uint16) (node, error) {
if len(buf) == 0 { if len(buf) == 0 {
return nil, io.ErrUnexpectedEOF return nil, io.ErrUnexpectedEOF
} }
@ -123,22 +123,22 @@ func decodeNode(hash, buf []byte) (node, error) {
} }
switch c, _ := rlp.CountValues(elems); c { switch c, _ := rlp.CountValues(elems); c {
case 2: case 2:
n, err := decodeShort(hash, buf, elems) n, err := decodeShort(hash, buf, elems, cachegen)
return n, wrapError(err, "short") return n, wrapError(err, "short")
case 17: case 17:
n, err := decodeFull(hash, buf, elems) n, err := decodeFull(hash, buf, elems, cachegen)
return n, wrapError(err, "full") return n, wrapError(err, "full")
default: default:
return nil, fmt.Errorf("invalid number of list elements: %v", c) return nil, fmt.Errorf("invalid number of list elements: %v", c)
} }
} }
func decodeShort(hash, buf, elems []byte) (node, error) { func decodeShort(hash, buf, elems []byte, cachegen uint16) (node, error) {
kbuf, rest, err := rlp.SplitString(elems) kbuf, rest, err := rlp.SplitString(elems)
if err != nil { if err != nil {
return nil, err return nil, err
} }
flag := nodeFlag{hash: hash} flag := nodeFlag{hash: hash, gen: cachegen}
key := compactDecode(kbuf) key := compactDecode(kbuf)
if key[len(key)-1] == 16 { if key[len(key)-1] == 16 {
// value node // value node
@ -148,17 +148,17 @@ func decodeShort(hash, buf, elems []byte) (node, error) {
} }
return &shortNode{key, append(valueNode{}, val...), flag}, nil return &shortNode{key, append(valueNode{}, val...), flag}, nil
} }
r, _, err := decodeRef(rest) r, _, err := decodeRef(rest, cachegen)
if err != nil { if err != nil {
return nil, wrapError(err, "val") return nil, wrapError(err, "val")
} }
return &shortNode{key, r, flag}, nil return &shortNode{key, r, flag}, nil
} }
func decodeFull(hash, buf, elems []byte) (*fullNode, error) { func decodeFull(hash, buf, elems []byte, cachegen uint16) (*fullNode, error) {
n := &fullNode{flags: nodeFlag{hash: hash}} n := &fullNode{flags: nodeFlag{hash: hash, gen: cachegen}}
for i := 0; i < 16; i++ { for i := 0; i < 16; i++ {
cld, rest, err := decodeRef(elems) cld, rest, err := decodeRef(elems, cachegen)
if err != nil { if err != nil {
return n, wrapError(err, fmt.Sprintf("[%d]", i)) return n, wrapError(err, fmt.Sprintf("[%d]", i))
} }
@ -176,7 +176,7 @@ func decodeFull(hash, buf, elems []byte) (*fullNode, error) {
const hashLen = len(common.Hash{}) const hashLen = len(common.Hash{})
func decodeRef(buf []byte) (node, []byte, error) { func decodeRef(buf []byte, cachegen uint16) (node, []byte, error) {
kind, val, rest, err := rlp.Split(buf) kind, val, rest, err := rlp.Split(buf)
if err != nil { if err != nil {
return nil, buf, err return nil, buf, err
@ -189,7 +189,7 @@ func decodeRef(buf []byte) (node, []byte, error) {
err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen) err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen)
return nil, buf, err return nil, buf, err
} }
n, err := decodeNode(nil, buf) n, err := decodeNode(nil, buf, cachegen)
return n, rest, err return n, rest, err
case kind == rlp.String && len(val) == 0: case kind == rlp.String && len(val) == 0:
// empty node // empty node

View File

@ -101,7 +101,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value
if !bytes.Equal(sha.Sum(nil), wantHash) { if !bytes.Equal(sha.Sum(nil), wantHash) {
return nil, fmt.Errorf("bad proof node %d: hash mismatch", i) return nil, fmt.Errorf("bad proof node %d: hash mismatch", i)
} }
n, err := decodeNode(wantHash, buf) n, err := decodeNode(wantHash, buf, 0)
if err != nil { if err != nil {
return nil, fmt.Errorf("bad proof node %d: %v", i, err) return nil, fmt.Errorf("bad proof node %d: %v", i, err)
} }

View File

@ -82,7 +82,7 @@ func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, c
} }
key := root.Bytes() key := root.Bytes()
blob, _ := s.database.Get(key) blob, _ := s.database.Get(key)
if local, err := decodeNode(key, blob); local != nil && err == nil { if local, err := decodeNode(key, blob, 0); local != nil && err == nil {
return return
} }
// Assemble the new sub-trie sync request // Assemble the new sub-trie sync request
@ -158,7 +158,7 @@ func (s *TrieSync) Process(results []SyncResult) (int, error) {
continue continue
} }
// Decode the node data content and update the request // Decode the node data content and update the request
node, err := decodeNode(item.Hash[:], item.Data) node, err := decodeNode(item.Hash[:], item.Data, 0)
if err != nil { if err != nil {
return i, err return i, err
} }
@ -246,7 +246,7 @@ func (s *TrieSync) children(req *request) ([]*request, error) {
if node, ok := (*child.node).(hashNode); ok { if node, ok := (*child.node).(hashNode); ok {
// Try to resolve the node from the local database // Try to resolve the node from the local database
blob, _ := s.database.Get(node) blob, _ := s.database.Get(node)
if local, err := decodeNode(node[:], blob); local != nil && err == nil { if local, err := decodeNode(node[:], blob, 0); local != nil && err == nil {
*child.node = local *child.node = local
continue continue
} }

View File

@ -105,13 +105,11 @@ func New(root common.Hash, db Database) (*Trie, error) {
if db == nil { if db == nil {
panic("trie.New: cannot use existing root without a database") panic("trie.New: cannot use existing root without a database")
} }
if v, _ := trie.db.Get(root[:]); len(v) == 0 { rootnode, err := trie.resolveHash(root[:], nil, nil)
return nil, &MissingNodeError{ if err != nil {
RootHash: root, return nil, err
NodeHash: root,
}
} }
trie.root = hashNode(root.Bytes()) trie.root = rootnode
} }
return trie, nil return trie, nil
} }
@ -158,14 +156,15 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
if err == nil && didResolve { if err == nil && didResolve {
n = n.copy() n = n.copy()
n.Val = newnode n.Val = newnode
n.flags.gen = t.cachegen
} }
return value, n, didResolve, err return value, n, didResolve, err
case *fullNode: case *fullNode:
value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1) value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1)
if err == nil && didResolve { if err == nil && didResolve {
n = n.copy() n = n.copy()
n.flags.gen = t.cachegen
n.Children[key[pos]] = newnode n.Children[key[pos]] = newnode
} }
return value, n, didResolve, err return value, n, didResolve, err
case hashNode: case hashNode:
@ -261,7 +260,8 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
return false, n, err return false, n, err
} }
n = n.copy() n = n.copy()
n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true n.flags = t.newFlag()
n.Children[key[0]] = nn
return true, n, nil return true, n, nil
case nil: case nil:
@ -345,7 +345,8 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
return false, n, err return false, n, err
} }
n = n.copy() n = n.copy()
n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true n.flags = t.newFlag()
n.Children[key[0]] = nn
// Check how many non-nil entries are left after deleting and // Check how many non-nil entries are left after deleting and
// reduce the full node to a short node if only one entry is // reduce the full node to a short node if only one entry is
@ -443,7 +444,7 @@ func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) {
SuffixLen: len(suffix), SuffixLen: len(suffix),
} }
} }
dec := mustDecodeNode(n, enc) dec := mustDecodeNode(n, enc, t.cachegen)
return dec, nil return dec, nil
} }

View File

@ -300,25 +300,6 @@ func TestReplication(t *testing.T) {
} }
} }
// Not an actual test
func TestOutput(t *testing.T) {
t.Skip()
base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
trie := newEmpty()
for i := 0; i < 50; i++ {
updateString(trie, fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee")
}
fmt.Println("############################## FULL ################################")
fmt.Println(trie.root)
trie.Commit()
fmt.Println("############################## SMALL ################################")
trie2, _ := New(trie.Hash(), trie.db)
getString(trie2, base+"20")
fmt.Println(trie2.root)
}
func TestLargeValue(t *testing.T) { func TestLargeValue(t *testing.T) {
trie := newEmpty() trie := newEmpty()
trie.Update([]byte("key1"), []byte{99, 99, 99, 99}) trie.Update([]byte("key1"), []byte{99, 99, 99, 99})
@ -326,14 +307,56 @@ func TestLargeValue(t *testing.T) {
trie.Hash() trie.Hash()
} }
type countingDB struct {
Database
gets map[string]int
}
func (db *countingDB) Get(key []byte) ([]byte, error) {
db.gets[string(key)]++
return db.Database.Get(key)
}
// TestCacheUnload checks that decoded nodes are unloaded after a
// certain number of commit operations.
func TestCacheUnload(t *testing.T) {
// Create test trie with two branches.
trie := newEmpty()
key1 := "---------------------------------"
key2 := "---some other branch"
updateString(trie, key1, "this is the branch of key1.")
updateString(trie, key2, "this is the branch of key2.")
root, _ := trie.Commit()
// Commit the trie repeatedly and access key1.
// The branch containing it is loaded from DB exactly two times:
// in the 0th and 6th iteration.
db := &countingDB{Database: trie.db, gets: make(map[string]int)}
trie, _ = New(root, db)
trie.SetCacheLimit(5)
for i := 0; i < 12; i++ {
getString(trie, key1)
trie.Commit()
}
// Check that it got loaded two times.
for dbkey, count := range db.gets {
if count != 2 {
t.Errorf("db key %x loaded %d times, want %d times", []byte(dbkey), count, 2)
}
}
}
// randTest performs random trie operations.
// Instances of this test are created by Generate.
type randTest []randTestStep
type randTestStep struct { type randTestStep struct {
op int op int
key []byte // for opUpdate, opDelete, opGet key []byte // for opUpdate, opDelete, opGet
value []byte // for opUpdate value []byte // for opUpdate
} }
type randTest []randTestStep
const ( const (
opUpdate = iota opUpdate = iota
opDelete opDelete
@ -342,6 +365,7 @@ const (
opHash opHash
opReset opReset
opItercheckhash opItercheckhash
opCheckCacheInvariant
opMax // boundary value, not an actual op opMax // boundary value, not an actual op
) )
@ -437,6 +461,44 @@ func runRandTest(rt randTest) bool {
fmt.Println("hashes not equal") fmt.Println("hashes not equal")
return false return false
} }
case opCheckCacheInvariant:
return checkCacheInvariant(tr.root, nil, tr.cachegen, false, 0)
}
}
return true
}
func checkCacheInvariant(n, parent node, parentCachegen uint16, parentDirty bool, depth int) bool {
var children []node
var flag nodeFlag
switch n := n.(type) {
case *shortNode:
flag = n.flags
children = []node{n.Val}
case *fullNode:
flag = n.flags
children = n.Children[:]
default:
return true
}
showerror := func() {
fmt.Printf("at depth %d node %s", depth, spew.Sdump(n))
fmt.Printf("parent: %s", spew.Sdump(parent))
}
if flag.gen > parentCachegen {
fmt.Printf("cache invariant violation: %d > %d\n", flag.gen, parentCachegen)
showerror()
return false
}
if depth > 0 && !parentDirty && flag.dirty {
fmt.Printf("cache invariant violation: child is dirty but parent isn't\n")
showerror()
return false
}
for _, child := range children {
if !checkCacheInvariant(child, n, flag.gen, flag.dirty, depth+1) {
return false
} }
} }
return true return true