trie, core/state: improve memory usage and performance (#3135)

* trie: store nodes as pointers

This avoids memory copies when unwrapping node interface values.

name      old time/op  new time/op  delta
Get        388ns ± 8%   215ns ± 2%  -44.56%  (p=0.000 n=15+15)
GetDB      363ns ± 3%   202ns ± 2%  -44.21%  (p=0.000 n=15+15)
UpdateBE  1.57µs ± 2%  1.29µs ± 3%  -17.80%  (p=0.000 n=13+15)
UpdateLE  1.92µs ± 2%  1.61µs ± 2%  -16.25%  (p=0.000 n=14+14)
HashBE    2.16µs ± 6%  2.18µs ± 6%     ~     (p=0.436 n=15+15)
HashLE    7.43µs ± 3%  7.21µs ± 3%   -2.96%  (p=0.000 n=15+13)

* trie: close temporary databases in GetDB benchmark

* trie: don't keep []byte from DB load around

Nodes decoded from a DB load kept hashes and values as sub-slices of
the DB value. This can be a problem because loading from leveldb often
returns []byte with a cap that's larger than necessary, increasing
memory usage.

* trie: unload old cached nodes

* trie, core/state: use cache unloading for account trie

* trie: use explicit private flags (fixes Go 1.5 reflection issue).

* trie: fixup cachegen overflow at request of nick

* core/state: rename journal size constant
This commit is contained in:
Felix Lange 2016-10-14 18:04:33 +02:00 committed by Péter Szilágyi
parent c2ddfb343a
commit 40cdcf1183
15 changed files with 249 additions and 142 deletions

View File

@ -269,7 +269,7 @@ func (self *BlockChain) FastSyncCommitHead(hash common.Hash) error {
if block == nil { if block == nil {
return fmt.Errorf("non existent block [%x…]", hash[:4]) return fmt.Errorf("non existent block [%x…]", hash[:4])
} }
if _, err := trie.NewSecure(block.Root(), self.chainDb); err != nil { if _, err := trie.NewSecure(block.Root(), self.chainDb, 0); err != nil {
return err return err
} }
// If all checks out, manually set the head block // If all checks out, manually set the head block

View File

@ -137,9 +137,9 @@ func (self *StateObject) markSuicided() {
func (c *StateObject) getTrie(db trie.Database) *trie.SecureTrie { func (c *StateObject) getTrie(db trie.Database) *trie.SecureTrie {
if c.trie == nil { if c.trie == nil {
var err error var err error
c.trie, err = trie.NewSecure(c.data.Root, db) c.trie, err = trie.NewSecure(c.data.Root, db, 0)
if err != nil { if err != nil {
c.trie, _ = trie.NewSecure(common.Hash{}, db) c.trie, _ = trie.NewSecure(common.Hash{}, db, 0)
c.setError(fmt.Errorf("can't create storage trie: %v", err)) c.setError(fmt.Errorf("can't create storage trie: %v", err))
} }
} }

View File

@ -41,7 +41,10 @@ var StartingNonce uint64
const ( const (
// Number of past tries to keep. The arbitrarily chosen value here // Number of past tries to keep. The arbitrarily chosen value here
// is max uncle depth + 1. // is max uncle depth + 1.
maxTrieCacheLength = 8 maxPastTries = 8
// Trie cache generation limit.
maxTrieCacheGen = 100
// Number of codehash->size associations to keep. // Number of codehash->size associations to keep.
codeSizeCacheSize = 100000 codeSizeCacheSize = 100000
@ -86,7 +89,7 @@ type StateDB struct {
// Create a new state from a given trie // Create a new state from a given trie
func New(root common.Hash, db ethdb.Database) (*StateDB, error) { func New(root common.Hash, db ethdb.Database) (*StateDB, error) {
tr, err := trie.NewSecure(root, db) tr, err := trie.NewSecure(root, db, maxTrieCacheGen)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -155,14 +158,14 @@ func (self *StateDB) openTrie(root common.Hash) (*trie.SecureTrie, error) {
return &tr, nil return &tr, nil
} }
} }
return trie.NewSecure(root, self.db) return trie.NewSecure(root, self.db, maxTrieCacheGen)
} }
func (self *StateDB) pushTrie(t *trie.SecureTrie) { func (self *StateDB) pushTrie(t *trie.SecureTrie) {
self.lock.Lock() self.lock.Lock()
defer self.lock.Unlock() defer self.lock.Unlock()
if len(self.pastTries) >= maxTrieCacheLength { if len(self.pastTries) >= maxPastTries {
copy(self.pastTries, self.pastTries[1:]) copy(self.pastTries, self.pastTries[1:])
self.pastTries[len(self.pastTries)-1] = t self.pastTries[len(self.pastTries)-1] = t
} else { } else {

View File

@ -286,7 +286,7 @@ func (dl *downloadTester) headFastBlock() *types.Block {
func (dl *downloadTester) commitHeadBlock(hash common.Hash) error { func (dl *downloadTester) commitHeadBlock(hash common.Hash) error {
// For now only check that the state trie is correct // For now only check that the state trie is correct
if block := dl.getBlock(hash); block != nil { if block := dl.getBlock(hash); block != nil {
_, err := trie.NewSecure(block.Root(), dl.stateDb) _, err := trie.NewSecure(block.Root(), dl.stateDb, 0)
return err return err
} }
return fmt.Errorf("non existent block: %x", hash[:4]) return fmt.Errorf("non existent block: %x", hash[:4])

View File

@ -79,7 +79,7 @@ func (t *LightTrie) do(ctx context.Context, fallbackKey []byte, fn func() error)
func (t *LightTrie) Get(ctx context.Context, key []byte) (res []byte, err error) { func (t *LightTrie) Get(ctx context.Context, key []byte) (res []byte, err error) {
err = t.do(ctx, key, func() (err error) { err = t.do(ctx, key, func() (err error) {
if t.trie == nil { if t.trie == nil {
t.trie, err = trie.NewSecure(t.originalRoot, t.db) t.trie, err = trie.NewSecure(t.originalRoot, t.db, 0)
} }
if err == nil { if err == nil {
res, err = t.trie.TryGet(key) res, err = t.trie.TryGet(key)
@ -98,7 +98,7 @@ func (t *LightTrie) Get(ctx context.Context, key []byte) (res []byte, err error)
func (t *LightTrie) Update(ctx context.Context, key, value []byte) (err error) { func (t *LightTrie) Update(ctx context.Context, key, value []byte) (err error) {
err = t.do(ctx, key, func() (err error) { err = t.do(ctx, key, func() (err error) {
if t.trie == nil { if t.trie == nil {
t.trie, err = trie.NewSecure(t.originalRoot, t.db) t.trie, err = trie.NewSecure(t.originalRoot, t.db, 0)
} }
if err == nil { if err == nil {
err = t.trie.TryUpdate(key, value) err = t.trie.TryUpdate(key, value)
@ -112,7 +112,7 @@ func (t *LightTrie) Update(ctx context.Context, key, value []byte) (err error) {
func (t *LightTrie) Delete(ctx context.Context, key []byte) (err error) { func (t *LightTrie) Delete(ctx context.Context, key []byte) (err error) {
err = t.do(ctx, key, func() (err error) { err = t.do(ctx, key, func() (err error) {
if t.trie == nil { if t.trie == nil {
t.trie, err = trie.NewSecure(t.originalRoot, t.db) t.trie, err = trie.NewSecure(t.originalRoot, t.db, 0)
} }
if err == nil { if err == nil {
err = t.trie.TryDelete(key) err = t.trie.TryDelete(key)

View File

@ -29,6 +29,7 @@ import (
type hasher struct { type hasher struct {
tmp *bytes.Buffer tmp *bytes.Buffer
sha hash.Hash sha hash.Hash
cachegen, cachelimit uint16
} }
// hashers live in a global pool. // hashers live in a global pool.
@ -38,8 +39,10 @@ var hasherPool = sync.Pool{
}, },
} }
func newHasher() *hasher { func newHasher(cachegen, cachelimit uint16) *hasher {
return hasherPool.Get().(*hasher) h := hasherPool.Get().(*hasher)
h.cachegen, h.cachelimit = cachegen, cachelimit
return h
} }
func returnHasherToPool(h *hasher) { func returnHasherToPool(h *hasher) {
@ -50,9 +53,19 @@ func returnHasherToPool(h *hasher) {
// original node initialzied with the computed hash to replace the original one. // original node initialzied with the computed hash to replace the original one.
func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) { func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) {
// If we're not storing the node, just hashing, use avaialble cached data // If we're not storing the node, just hashing, use avaialble cached data
if hash, dirty := n.cache(); hash != nil && (db == nil || !dirty) { if hash, dirty := n.cache(); hash != nil {
if db == nil {
return hash, n, nil return hash, n, nil
} }
if n.canUnload(h.cachegen, h.cachelimit) {
// Evict the node from cache. All of its subnodes will have a lower or equal
// cache generation number.
return hash, hash, nil
}
if !dirty {
return hash, n, nil
}
}
// Trie not processed yet or needs storage, walk the children // Trie not processed yet or needs storage, walk the children
collapsed, cached, err := h.hashChildren(n, db) collapsed, cached, err := h.hashChildren(n, db)
if err != nil { if err != nil {
@ -62,19 +75,21 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error)
if err != nil { if err != nil {
return hashNode{}, n, err return hashNode{}, n, err
} }
// Cache the hash and RLP blob of the ndoe for later reuse // Cache the hash of the ndoe for later reuse.
if hash, ok := hashed.(hashNode); ok && !force { if hash, ok := hashed.(hashNode); ok && !force {
switch cached := cached.(type) { switch cached := cached.(type) {
case shortNode: case *shortNode:
cached.hash = hash cached = cached.copy()
cached.flags.hash = hash
if db != nil { if db != nil {
cached.dirty = false cached.flags.dirty = false
} }
return hashed, cached, nil return hashed, cached, nil
case fullNode: case *fullNode:
cached.hash = hash cached = cached.copy()
cached.flags.hash = hash
if db != nil { if db != nil {
cached.dirty = false cached.flags.dirty = false
} }
return hashed, cached, nil return hashed, cached, nil
} }
@ -89,40 +104,42 @@ func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, err
var err error var err error
switch n := original.(type) { switch n := original.(type) {
case shortNode: case *shortNode:
// Hash the short node's child, caching the newly hashed subtree // Hash the short node's child, caching the newly hashed subtree
cached := n collapsed, cached := n.copy(), n.copy()
cached.Key = common.CopyBytes(cached.Key) collapsed.Key = compactEncode(n.Key)
cached.Key = common.CopyBytes(n.Key)
n.Key = compactEncode(n.Key)
if _, ok := n.Val.(valueNode); !ok { if _, ok := n.Val.(valueNode); !ok {
if n.Val, cached.Val, err = h.hash(n.Val, db, false); err != nil { collapsed.Val, cached.Val, err = h.hash(n.Val, db, false)
return n, original, err if err != nil {
return original, original, err
} }
} }
if n.Val == nil { if collapsed.Val == nil {
n.Val = valueNode(nil) // Ensure that nil children are encoded as empty strings. collapsed.Val = valueNode(nil) // Ensure that nil children are encoded as empty strings.
} }
return n, cached, nil return collapsed, cached, nil
case fullNode: case *fullNode:
// Hash the full node's children, caching the newly hashed subtrees // Hash the full node's children, caching the newly hashed subtrees
cached := fullNode{dirty: n.dirty} collapsed, cached := n.copy(), n.copy()
for i := 0; i < 16; i++ { for i := 0; i < 16; i++ {
if n.Children[i] != nil { if n.Children[i] != nil {
if n.Children[i], cached.Children[i], err = h.hash(n.Children[i], db, false); err != nil { collapsed.Children[i], cached.Children[i], err = h.hash(n.Children[i], db, false)
return n, original, err if err != nil {
return original, original, err
} }
} else { } else {
n.Children[i] = valueNode(nil) // Ensure that nil children are encoded as empty strings. collapsed.Children[i] = valueNode(nil) // Ensure that nil children are encoded as empty strings.
} }
} }
cached.Children[16] = n.Children[16] cached.Children[16] = n.Children[16]
if n.Children[16] == nil { if collapsed.Children[16] == nil {
n.Children[16] = valueNode(nil) collapsed.Children[16] = valueNode(nil)
} }
return n, cached, nil return collapsed, cached, nil
default: default:
// Value and hash nodes don't have children so they're left as were // Value and hash nodes don't have children so they're left as were
@ -140,6 +157,7 @@ func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) {
if err := rlp.Encode(h.tmp, n); err != nil { if err := rlp.Encode(h.tmp, n); err != nil {
panic("encode error: " + err.Error()) panic("encode error: " + err.Error())
} }
if h.tmp.Len() < 32 && !force { if h.tmp.Len() < 32 && !force {
return n, nil // Nodes smaller than 32 bytes are stored inside their parent return n, nil // Nodes smaller than 32 bytes are stored inside their parent
} }

View File

@ -56,11 +56,11 @@ func (it *Iterator) makeKey() []byte {
key := it.keyBuf[:0] key := it.keyBuf[:0]
for _, se := range it.nodeIt.stack { for _, se := range it.nodeIt.stack {
switch node := se.node.(type) { switch node := se.node.(type) {
case fullNode: case *fullNode:
if se.child <= 16 { if se.child <= 16 {
key = append(key, byte(se.child)) key = append(key, byte(se.child))
} }
case shortNode: case *shortNode:
if hasTerm(node.Key) { if hasTerm(node.Key) {
key = append(key, node.Key[:len(node.Key)-1]...) key = append(key, node.Key[:len(node.Key)-1]...)
} else { } else {
@ -148,7 +148,7 @@ func (it *NodeIterator) step() error {
if (ancestor == common.Hash{}) { if (ancestor == common.Hash{}) {
ancestor = parent.parent ancestor = parent.parent
} }
if node, ok := parent.node.(fullNode); ok { if node, ok := parent.node.(*fullNode); ok {
// Full node, traverse all children, then the node itself // Full node, traverse all children, then the node itself
if parent.child >= len(node.Children) { if parent.child >= len(node.Children) {
break break
@ -156,7 +156,7 @@ func (it *NodeIterator) step() error {
for parent.child++; parent.child < len(node.Children); parent.child++ { for parent.child++; parent.child < len(node.Children); parent.child++ {
if current := node.Children[parent.child]; current != nil { if current := node.Children[parent.child]; current != nil {
it.stack = append(it.stack, &nodeIteratorState{ it.stack = append(it.stack, &nodeIteratorState{
hash: common.BytesToHash(node.hash), hash: common.BytesToHash(node.flags.hash),
node: current, node: current,
parent: ancestor, parent: ancestor,
child: -1, child: -1,
@ -164,14 +164,14 @@ func (it *NodeIterator) step() error {
break break
} }
} }
} else if node, ok := parent.node.(shortNode); ok { } else if node, ok := parent.node.(*shortNode); ok {
// Short node, traverse the pointer singleton child, then the node itself // Short node, traverse the pointer singleton child, then the node itself
if parent.child >= 0 { if parent.child >= 0 {
break break
} }
parent.child++ parent.child++
it.stack = append(it.stack, &nodeIteratorState{ it.stack = append(it.stack, &nodeIteratorState{
hash: common.BytesToHash(node.hash), hash: common.BytesToHash(node.flags.hash),
node: node.Val, node: node.Val,
parent: ancestor, parent: ancestor,
child: -1, child: -1,

View File

@ -30,42 +30,60 @@ var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b
type node interface { type node interface {
fstring(string) string fstring(string) string
cache() (hashNode, bool) cache() (hashNode, bool)
canUnload(cachegen, cachelimit uint16) bool
} }
type ( type (
fullNode struct { fullNode struct {
Children [17]node // Actual trie node data to encode/decode (needs custom encoder) Children [17]node // Actual trie node data to encode/decode (needs custom encoder)
hash hashNode // Cached hash of the node to prevent rehashing (may be nil) flags nodeFlag
dirty bool // Cached flag whether the node's new or already stored
} }
shortNode struct { shortNode struct {
Key []byte Key []byte
Val node Val node
hash hashNode // Cached hash of the node to prevent rehashing (may be nil) flags nodeFlag
dirty bool // Cached flag whether the node's new or already stored
} }
hashNode []byte hashNode []byte
valueNode []byte valueNode []byte
) )
// EncodeRLP encodes a full node into the consensus RLP format. // EncodeRLP encodes a full node into the consensus RLP format.
func (n fullNode) EncodeRLP(w io.Writer) error { func (n *fullNode) EncodeRLP(w io.Writer) error {
return rlp.Encode(w, n.Children) return rlp.Encode(w, n.Children)
} }
// Cache accessors to retrieve precalculated values (avoid lengthy type switches). func (n *fullNode) copy() *fullNode { copy := *n; return &copy }
func (n fullNode) cache() (hashNode, bool) { return n.hash, n.dirty } func (n *shortNode) copy() *shortNode { copy := *n; return &copy }
func (n shortNode) cache() (hashNode, bool) { return n.hash, n.dirty }
// nodeFlag contains caching-related metadata about a node.
type nodeFlag struct {
hash hashNode // cached hash of the node (may be nil)
gen uint16 // cache generation counter
dirty bool // whether the node has changes that must be written to the database
}
// canUnload tells whether a node can be unloaded.
func (n *nodeFlag) canUnload(cachegen, cachelimit uint16) bool {
return !n.dirty && cachegen-n.gen >= cachelimit
}
func (n *fullNode) canUnload(gen, limit uint16) bool { return n.flags.canUnload(gen, limit) }
func (n *shortNode) canUnload(gen, limit uint16) bool { return n.flags.canUnload(gen, limit) }
func (n hashNode) canUnload(uint16, uint16) bool { return false }
func (n valueNode) canUnload(uint16, uint16) bool { return false }
func (n *fullNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty }
func (n *shortNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty }
func (n hashNode) cache() (hashNode, bool) { return nil, true } func (n hashNode) cache() (hashNode, bool) { return nil, true }
func (n valueNode) cache() (hashNode, bool) { return nil, true } func (n valueNode) cache() (hashNode, bool) { return nil, true }
// Pretty printing. // Pretty printing.
func (n fullNode) String() string { return n.fstring("") } func (n *fullNode) String() string { return n.fstring("") }
func (n shortNode) String() string { return n.fstring("") } func (n *shortNode) String() string { return n.fstring("") }
func (n hashNode) String() string { return n.fstring("") } func (n hashNode) String() string { return n.fstring("") }
func (n valueNode) String() string { return n.fstring("") } func (n valueNode) String() string { return n.fstring("") }
func (n fullNode) fstring(ind string) string { func (n *fullNode) fstring(ind string) string {
resp := fmt.Sprintf("[\n%s ", ind) resp := fmt.Sprintf("[\n%s ", ind)
for i, node := range n.Children { for i, node := range n.Children {
if node == nil { if node == nil {
@ -76,7 +94,7 @@ func (n fullNode) fstring(ind string) string {
} }
return resp + fmt.Sprintf("\n%s] ", ind) return resp + fmt.Sprintf("\n%s] ", ind)
} }
func (n shortNode) fstring(ind string) string { func (n *shortNode) fstring(ind string) string {
return fmt.Sprintf("{%x: %v} ", n.Key, n.Val.fstring(ind+" ")) return fmt.Sprintf("{%x: %v} ", n.Key, n.Val.fstring(ind+" "))
} }
func (n hashNode) fstring(ind string) string { func (n hashNode) fstring(ind string) string {
@ -120,6 +138,7 @@ func decodeShort(hash, buf, elems []byte) (node, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
flag := nodeFlag{hash: hash}
key := compactDecode(kbuf) key := compactDecode(kbuf)
if key[len(key)-1] == 16 { if key[len(key)-1] == 16 {
// value node // value node
@ -127,17 +146,17 @@ func decodeShort(hash, buf, elems []byte) (node, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid value node: %v", err) return nil, fmt.Errorf("invalid value node: %v", err)
} }
return shortNode{key, valueNode(val), hash, false}, nil return &shortNode{key, append(valueNode{}, val...), flag}, nil
} }
r, _, err := decodeRef(rest) r, _, err := decodeRef(rest)
if err != nil { if err != nil {
return nil, wrapError(err, "val") return nil, wrapError(err, "val")
} }
return shortNode{key, r, hash, false}, nil return &shortNode{key, r, flag}, nil
} }
func decodeFull(hash, buf, elems []byte) (fullNode, error) { func decodeFull(hash, buf, elems []byte) (*fullNode, error) {
n := fullNode{hash: hash} n := &fullNode{flags: nodeFlag{hash: hash}}
for i := 0; i < 16; i++ { for i := 0; i < 16; i++ {
cld, rest, err := decodeRef(elems) cld, rest, err := decodeRef(elems)
if err != nil { if err != nil {
@ -150,7 +169,7 @@ func decodeFull(hash, buf, elems []byte) (fullNode, error) {
return n, err return n, err
} }
if len(val) > 0 { if len(val) > 0 {
n.Children[16] = valueNode(val) n.Children[16] = append(valueNode{}, val...)
} }
return n, nil return n, nil
} }
@ -176,7 +195,7 @@ func decodeRef(buf []byte) (node, []byte, error) {
// empty node // empty node
return nil, rest, nil return nil, rest, nil
case kind == rlp.String && len(val) == 32: case kind == rlp.String && len(val) == 32:
return hashNode(val), rest, nil return append(hashNode{}, val...), rest, nil
default: default:
return nil, nil, fmt.Errorf("invalid RLP string size %d (want 0 or 32)", len(val)) return nil, nil, fmt.Errorf("invalid RLP string size %d (want 0 or 32)", len(val))
} }

58
trie/node_test.go Normal file
View File

@ -0,0 +1,58 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import "testing"
func TestCanUnload(t *testing.T) {
tests := []struct {
flag nodeFlag
cachegen, cachelimit uint16
want bool
}{
{
flag: nodeFlag{dirty: true, gen: 0},
want: false,
},
{
flag: nodeFlag{dirty: false, gen: 0},
cachegen: 0, cachelimit: 0,
want: true,
},
{
flag: nodeFlag{dirty: false, gen: 65534},
cachegen: 65535, cachelimit: 1,
want: true,
},
{
flag: nodeFlag{dirty: false, gen: 65534},
cachegen: 0, cachelimit: 1,
want: true,
},
{
flag: nodeFlag{dirty: false, gen: 1},
cachegen: 65535, cachelimit: 1,
want: true,
},
}
for _, test := range tests {
if got := test.flag.canUnload(test.cachegen, test.cachelimit); got != test.want {
t.Errorf("%+v\n got %t, want %t", test, got, test.want)
}
}
}

View File

@ -44,7 +44,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue {
tn := t.root tn := t.root
for len(key) > 0 && tn != nil { for len(key) > 0 && tn != nil {
switch n := tn.(type) { switch n := tn.(type) {
case shortNode: case *shortNode:
if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
// The trie doesn't contain the key. // The trie doesn't contain the key.
tn = nil tn = nil
@ -53,7 +53,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue {
key = key[len(n.Key):] key = key[len(n.Key):]
} }
nodes = append(nodes, n) nodes = append(nodes, n)
case fullNode: case *fullNode:
tn = n.Children[key[0]] tn = n.Children[key[0]]
key = key[1:] key = key[1:]
nodes = append(nodes, n) nodes = append(nodes, n)
@ -70,7 +70,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue {
panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
} }
} }
hasher := newHasher() hasher := newHasher(0, 0)
proof := make([]rlp.RawValue, 0, len(nodes)) proof := make([]rlp.RawValue, 0, len(nodes))
for i, n := range nodes { for i, n := range nodes {
// Don't bother checking for errors here since hasher panics // Don't bother checking for errors here since hasher panics
@ -130,13 +130,13 @@ func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value
func get(tn node, key []byte) ([]byte, node) { func get(tn node, key []byte) ([]byte, node) {
for len(key) > 0 { for len(key) > 0 {
switch n := tn.(type) { switch n := tn.(type) {
case shortNode: case *shortNode:
if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
return nil, nil return nil, nil
} }
tn = n.Val tn = n.Val
key = key[len(n.Key):] key = key[len(n.Key):]
case fullNode: case *fullNode:
tn = n.Children[key[0]] tn = n.Children[key[0]]
key = key[1:] key = key[1:]
case hashNode: case hashNode:

View File

@ -49,8 +49,12 @@ type SecureTrie struct {
// If root is the zero hash or the sha3 hash of an empty string, the // If root is the zero hash or the sha3 hash of an empty string, the
// trie is initially empty. Otherwise, New will panic if db is nil // trie is initially empty. Otherwise, New will panic if db is nil
// and returns MissingNodeError if the root node cannot be found. // and returns MissingNodeError if the root node cannot be found.
//
// Accessing the trie loads nodes from db on demand. // Accessing the trie loads nodes from db on demand.
func NewSecure(root common.Hash, db Database) (*SecureTrie, error) { // Loaded nodes are kept around until their 'cache generation' expires.
// A new cache generation is created by each call to Commit.
// cachelimit sets the number of past cache generations to keep.
func NewSecure(root common.Hash, db Database, cachelimit uint16) (*SecureTrie, error) {
if db == nil { if db == nil {
panic("NewSecure called with nil database") panic("NewSecure called with nil database")
} }
@ -58,9 +62,8 @@ func NewSecure(root common.Hash, db Database) (*SecureTrie, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &SecureTrie{ trie.SetCacheLimit(cachelimit)
trie: *trie, return &SecureTrie{trie: *trie}, nil
}, nil
} }
// Get returns the value for key stored in the trie. // Get returns the value for key stored in the trie.
@ -191,7 +194,7 @@ func (t *SecureTrie) secKey(key []byte) []byte {
// The caller must not hold onto the return value because it will become // The caller must not hold onto the return value because it will become
// invalid on the next call to hashKey or secKey. // invalid on the next call to hashKey or secKey.
func (t *SecureTrie) hashKey(key []byte) []byte { func (t *SecureTrie) hashKey(key []byte) []byte {
h := newHasher() h := newHasher(0, 0)
h.sha.Reset() h.sha.Reset()
h.sha.Write(key) h.sha.Write(key)
buf := h.sha.Sum(t.hashKeyBuf[:0]) buf := h.sha.Sum(t.hashKeyBuf[:0])

View File

@ -29,7 +29,7 @@ import (
func newEmptySecure() *SecureTrie { func newEmptySecure() *SecureTrie {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
trie, _ := NewSecure(common.Hash{}, db) trie, _ := NewSecure(common.Hash{}, db, 0)
return trie return trie
} }
@ -37,7 +37,7 @@ func newEmptySecure() *SecureTrie {
func makeTestSecureTrie() (ethdb.Database, *SecureTrie, map[string][]byte) { func makeTestSecureTrie() (ethdb.Database, *SecureTrie, map[string][]byte) {
// Create an empty trie // Create an empty trie
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
trie, _ := NewSecure(common.Hash{}, db) trie, _ := NewSecure(common.Hash{}, db, 0)
// Fill it with some arbitrary data // Fill it with some arbitrary data
content := make(map[string][]byte) content := make(map[string][]byte)

View File

@ -212,12 +212,12 @@ func (s *TrieSync) children(req *request) ([]*request, error) {
children := []child{} children := []child{}
switch node := (*req.object).(type) { switch node := (*req.object).(type) {
case shortNode: case *shortNode:
children = []child{{ children = []child{{
node: &node.Val, node: &node.Val,
depth: req.depth + len(node.Key), depth: req.depth + len(node.Key),
}} }}
case fullNode: case *fullNode:
for i := 0; i < 17; i++ { for i := 0; i < 17; i++ {
if node.Children[i] != nil { if node.Children[i] != nil {
children = append(children, child{ children = append(children, child{

View File

@ -62,6 +62,23 @@ type Trie struct {
root node root node
db Database db Database
originalRoot common.Hash originalRoot common.Hash
// Cache generation values.
// cachegen increase by one with each commit operation.
// new nodes are tagged with the current generation and unloaded
// when their generation is older than than cachegen-cachelimit.
cachegen, cachelimit uint16
}
// SetCacheLimit sets the number of 'cache generations' to keep.
// A cache generations is created by a call to Commit.
func (t *Trie) SetCacheLimit(l uint16) {
t.cachelimit = l
}
// newFlag returns the cache flag value for a newly created node.
func (t *Trie) newFlag() nodeFlag {
return nodeFlag{dirty: true, gen: t.cachegen}
} }
// New creates a trie with an existing root node from db. // New creates a trie with an existing root node from db.
@ -120,27 +137,25 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
return nil, nil, false, nil return nil, nil, false, nil
case valueNode: case valueNode:
return n, n, false, nil return n, n, false, nil
case shortNode: case *shortNode:
if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) { if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) {
// key not found in trie // key not found in trie
return nil, n, false, nil return nil, n, false, nil
} }
value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key)) value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key))
if err == nil && didResolve { if err == nil && didResolve {
n = n.copy()
n.Val = newnode n.Val = newnode
return value, n, didResolve, err
} else {
return value, origNode, didResolve, err
} }
case fullNode: return value, n, didResolve, err
child := n.Children[key[pos]] case *fullNode:
value, newnode, didResolve, err = t.tryGet(child, key, pos+1) value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1)
if err == nil && didResolve { if err == nil && didResolve {
n = n.copy()
n.Children[key[pos]] = newnode n.Children[key[pos]] = newnode
return value, n, didResolve, err
} else {
return value, origNode, didResolve, err
} }
return value, n, didResolve, err
case hashNode: case hashNode:
child, err := t.resolveHash(n, key[:pos], key[pos:]) child, err := t.resolveHash(n, key[:pos], key[pos:])
if err != nil { if err != nil {
@ -199,22 +214,19 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
return true, value, nil return true, value, nil
} }
switch n := n.(type) { switch n := n.(type) {
case shortNode: case *shortNode:
matchlen := prefixLen(key, n.Key) matchlen := prefixLen(key, n.Key)
// If the whole key matches, keep this short node as is // If the whole key matches, keep this short node as is
// and only update the value. // and only update the value.
if matchlen == len(n.Key) { if matchlen == len(n.Key) {
dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value) dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value)
if err != nil { if !dirty || err != nil {
return false, nil, err return false, n, err
} }
if !dirty { return true, &shortNode{n.Key, nn, t.newFlag()}, nil
return false, n, nil
}
return true, shortNode{n.Key, nn, nil, true}, nil
} }
// Otherwise branch out at the index where they differ. // Otherwise branch out at the index where they differ.
branch := fullNode{dirty: true} branch := &fullNode{flags: t.newFlag()}
var err error var err error
_, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val) _, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val)
if err != nil { if err != nil {
@ -229,21 +241,19 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
return true, branch, nil return true, branch, nil
} }
// Otherwise, replace it with a short node leading up to the branch. // Otherwise, replace it with a short node leading up to the branch.
return true, shortNode{key[:matchlen], branch, nil, true}, nil return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil
case fullNode: case *fullNode:
dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value) dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value)
if err != nil { if !dirty || err != nil {
return false, nil, err return false, n, err
} }
if !dirty { n = n.copy()
return false, n, nil n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true
}
n.Children[key[0]], n.hash, n.dirty = nn, nil, true
return true, n, nil return true, n, nil
case nil: case nil:
return true, shortNode{key, value, nil, true}, nil return true, &shortNode{key, value, t.newFlag()}, nil
case hashNode: case hashNode:
// We've hit a part of the trie that isn't loaded yet. Load // We've hit a part of the trie that isn't loaded yet. Load
@ -254,11 +264,8 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
return false, nil, err return false, nil, err
} }
dirty, nn, err := t.insert(rn, prefix, key, value) dirty, nn, err := t.insert(rn, prefix, key, value)
if err != nil { if !dirty || err != nil {
return false, nil, err return false, rn, err
}
if !dirty {
return false, rn, nil
} }
return true, nn, nil return true, nn, nil
@ -291,7 +298,7 @@ func (t *Trie) TryDelete(key []byte) error {
// nodes on the way up after deleting recursively. // nodes on the way up after deleting recursively.
func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
switch n := n.(type) { switch n := n.(type) {
case shortNode: case *shortNode:
matchlen := prefixLen(key, n.Key) matchlen := prefixLen(key, n.Key)
if matchlen < len(n.Key) { if matchlen < len(n.Key) {
return false, n, nil // don't replace n on mismatch return false, n, nil // don't replace n on mismatch
@ -304,34 +311,29 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
// subtrie must contain at least two other values with keys // subtrie must contain at least two other values with keys
// longer than n.Key. // longer than n.Key.
dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):]) dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):])
if err != nil { if !dirty || err != nil {
return false, nil, err return false, n, err
}
if !dirty {
return false, n, nil
} }
switch child := child.(type) { switch child := child.(type) {
case shortNode: case *shortNode:
// Deleting from the subtrie reduced it to another // Deleting from the subtrie reduced it to another
// short node. Merge the nodes to avoid creating a // short node. Merge the nodes to avoid creating a
// shortNode{..., shortNode{...}}. Use concat (which // shortNode{..., shortNode{...}}. Use concat (which
// always creates a new slice) instead of append to // always creates a new slice) instead of append to
// avoid modifying n.Key since it might be shared with // avoid modifying n.Key since it might be shared with
// other nodes. // other nodes.
return true, shortNode{concat(n.Key, child.Key...), child.Val, nil, true}, nil return true, &shortNode{concat(n.Key, child.Key...), child.Val, t.newFlag()}, nil
default: default:
return true, shortNode{n.Key, child, nil, true}, nil return true, &shortNode{n.Key, child, t.newFlag()}, nil
} }
case fullNode: case *fullNode:
dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:]) dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:])
if err != nil { if !dirty || err != nil {
return false, nil, err return false, n, err
} }
if !dirty { n = n.copy()
return false, n, nil n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true
}
n.Children[key[0]], n.hash, n.dirty = nn, nil, true
// Check how many non-nil entries are left after deleting and // Check how many non-nil entries are left after deleting and
// reduce the full node to a short node if only one entry is // reduce the full node to a short node if only one entry is
@ -365,14 +367,14 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
if err != nil { if err != nil {
return false, nil, err return false, nil, err
} }
if cnode, ok := cnode.(shortNode); ok { if cnode, ok := cnode.(*shortNode); ok {
k := append([]byte{byte(pos)}, cnode.Key...) k := append([]byte{byte(pos)}, cnode.Key...)
return true, shortNode{k, cnode.Val, nil, true}, nil return true, &shortNode{k, cnode.Val, t.newFlag()}, nil
} }
} }
// Otherwise, n is replaced by a one-nibble short node // Otherwise, n is replaced by a one-nibble short node
// containing the child. // containing the child.
return true, shortNode{[]byte{byte(pos)}, n.Children[pos], nil, true}, nil return true, &shortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil
} }
// n still contains at least two values and cannot be reduced. // n still contains at least two values and cannot be reduced.
return true, n, nil return true, n, nil
@ -392,11 +394,8 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
return false, nil, err return false, nil, err
} }
dirty, nn, err := t.delete(rn, prefix, key) dirty, nn, err := t.delete(rn, prefix, key)
if err != nil { if !dirty || err != nil {
return false, nil, err return false, rn, err
}
if !dirty {
return false, rn, nil
} }
return true, nn, nil return true, nn, nil
@ -471,6 +470,7 @@ func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) {
return (common.Hash{}), err return (common.Hash{}), err
} }
t.root = cached t.root = cached
t.cachegen++
return common.BytesToHash(hash.(hashNode)), nil return common.BytesToHash(hash.(hashNode)), nil
} }
@ -478,7 +478,7 @@ func (t *Trie) hashRoot(db DatabaseWriter) (node, node, error) {
if t.root == nil { if t.root == nil {
return hashNode(emptyRoot.Bytes()), nil, nil return hashNode(emptyRoot.Bytes()), nil, nil
} }
h := newHasher() h := newHasher(t.cachegen, t.cachelimit)
defer returnHasherToPool(h) defer returnHasherToPool(h)
return h.hash(t.root, db, true) return h.hash(t.root, db, true)
} }

View File

@ -460,8 +460,7 @@ const benchElemCount = 20000
func benchGet(b *testing.B, commit bool) { func benchGet(b *testing.B, commit bool) {
trie := new(Trie) trie := new(Trie)
if commit { if commit {
dir, tmpdb := tempDB() _, tmpdb := tempDB()
defer os.RemoveAll(dir)
trie, _ = New(common.Hash{}, tmpdb) trie, _ = New(common.Hash{}, tmpdb)
} }
k := make([]byte, 32) k := make([]byte, 32)
@ -478,6 +477,13 @@ func benchGet(b *testing.B, commit bool) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
trie.Get(k) trie.Get(k)
} }
b.StopTimer()
if commit {
ldb := trie.db.(*ethdb.LDBDatabase)
ldb.Close()
os.RemoveAll(ldb.Path())
}
} }
func benchUpdate(b *testing.B, e binary.ByteOrder) *Trie { func benchUpdate(b *testing.B, e binary.ByteOrder) *Trie {