diff --git a/trie/stacktrie.go b/trie/stacktrie.go index 0d65ee75e..781c84296 100644 --- a/trie/stacktrie.go +++ b/trie/stacktrie.go @@ -17,11 +17,7 @@ package trie import ( - "bufio" - "bytes" - "encoding/gob" "errors" - "io" "sync" "github.com/ethereum/go-ethereum/common" @@ -29,171 +25,96 @@ import ( "github.com/ethereum/go-ethereum/log" ) -var ErrCommitDisabled = errors.New("no database for committing") - -var stPool = sync.Pool{ - New: func() interface{} { - return NewStackTrie(nil) - }, -} +var ( + ErrCommitDisabled = errors.New("no database for committing") + stPool = sync.Pool{New: func() any { return new(stNode) }} + _ = types.TrieHasher((*StackTrie)(nil)) +) // NodeWriteFunc is used to provide all information of a dirty node for committing // so that callers can flush nodes into database with desired scheme. type NodeWriteFunc = func(owner common.Hash, path []byte, hash common.Hash, blob []byte) -func stackTrieFromPool(writeFn NodeWriteFunc, owner common.Hash) *StackTrie { - st := stPool.Get().(*StackTrie) - st.owner = owner - st.writeFn = writeFn - return st -} - -func returnToPool(st *StackTrie) { - st.Reset() - stPool.Put(st) -} - // StackTrie is a trie implementation that expects keys to be inserted // in order. Once it determines that a subtree will no longer be inserted // into, it will hash it and free up the memory it uses. type StackTrie struct { - owner common.Hash // the owner of the trie - nodeType uint8 // node type (as in branch, ext, leaf) - val []byte // value contained by this node if it's a leaf - key []byte // key chunk covered by this (leaf|ext) node - children [16]*StackTrie // list of children (for branch and exts) - writeFn NodeWriteFunc // function for committing nodes, can be nil + owner common.Hash // the owner of the trie + writeFn NodeWriteFunc // function for committing nodes, can be nil + root *stNode + h *hasher } // NewStackTrie allocates and initializes an empty trie. func NewStackTrie(writeFn NodeWriteFunc) *StackTrie { return &StackTrie{ - nodeType: emptyNode, - writeFn: writeFn, + writeFn: writeFn, + root: stPool.Get().(*stNode), + h: newHasher(false), } } // NewStackTrieWithOwner allocates and initializes an empty trie, but with // the additional owner field. func NewStackTrieWithOwner(writeFn NodeWriteFunc, owner common.Hash) *StackTrie { - return &StackTrie{ - owner: owner, - nodeType: emptyNode, - writeFn: writeFn, - } + stack := NewStackTrie(writeFn) + stack.owner = owner + return stack } -// NewFromBinary initialises a serialized stacktrie with the given db. -func NewFromBinary(data []byte, writeFn NodeWriteFunc) (*StackTrie, error) { - var st StackTrie - if err := st.UnmarshalBinary(data); err != nil { - return nil, err - } - // If a database is used, we need to recursively add it to every child - if writeFn != nil { - st.setWriter(writeFn) - } - return &st, nil -} - -// MarshalBinary implements encoding.BinaryMarshaler -func (st *StackTrie) MarshalBinary() (data []byte, err error) { - var ( - b bytes.Buffer - w = bufio.NewWriter(&b) - ) - if err := gob.NewEncoder(w).Encode(struct { - Owner common.Hash - NodeType uint8 - Val []byte - Key []byte - }{ - st.owner, - st.nodeType, - st.val, - st.key, - }); err != nil { - return nil, err - } - for _, child := range st.children { - if child == nil { - w.WriteByte(0) - continue - } - w.WriteByte(1) - if childData, err := child.MarshalBinary(); err != nil { - return nil, err - } else { - w.Write(childData) - } - } - w.Flush() - return b.Bytes(), nil -} - -// UnmarshalBinary implements encoding.BinaryUnmarshaler -func (st *StackTrie) UnmarshalBinary(data []byte) error { - r := bytes.NewReader(data) - return st.unmarshalBinary(r) -} - -func (st *StackTrie) unmarshalBinary(r io.Reader) error { - var dec struct { - Owner common.Hash - NodeType uint8 - Val []byte - Key []byte - } - if err := gob.NewDecoder(r).Decode(&dec); err != nil { - return err - } - st.owner = dec.Owner - st.nodeType = dec.NodeType - st.val = dec.Val - st.key = dec.Key - - var hasChild = make([]byte, 1) - for i := range st.children { - if _, err := r.Read(hasChild); err != nil { - return err - } else if hasChild[0] == 0 { - continue - } - var child StackTrie - if err := child.unmarshalBinary(r); err != nil { - return err - } - st.children[i] = &child +// Update inserts a (key, value) pair into the stack trie. +func (t *StackTrie) Update(key, value []byte) error { + k := keybytesToHex(key) + if len(value) == 0 { + panic("deletion not supported") } + t.insert(t.root, k[:len(k)-1], value, nil) return nil } -func (st *StackTrie) setWriter(writeFn NodeWriteFunc) { - st.writeFn = writeFn - for _, child := range st.children { - if child != nil { - child.setWriter(writeFn) - } +// MustUpdate is a wrapper of Update and will omit any encountered error but +// just print out an error message. +func (t *StackTrie) MustUpdate(key, value []byte) { + if err := t.Update(key, value); err != nil { + log.Error("Unhandled trie error in StackTrie.Update", "err", err) } } -func newLeaf(owner common.Hash, key, val []byte, writeFn NodeWriteFunc) *StackTrie { - st := stackTrieFromPool(writeFn, owner) - st.nodeType = leafNode +func (t *StackTrie) Reset() { + t.writeFn = nil + t.root = stPool.Get().(*stNode) +} + +// stNode represents a node within a StackTrie +type stNode struct { + typ uint8 // node type (as in branch, ext, leaf) + key []byte // key chunk covered by this (leaf|ext) node + val []byte // value contained by this node if it's a leaf + children [16]*stNode // list of children (for branch and exts) +} + +// newLeaf constructs a leaf node with provided node key and value. The key +// will be deep-copied in the function and safe to modify afterwards, but +// value is not. +func newLeaf(key, val []byte) *stNode { + st := stPool.Get().(*stNode) + st.typ = leafNode st.key = append(st.key, key...) st.val = val return st } -func newExt(owner common.Hash, key []byte, child *StackTrie, writeFn NodeWriteFunc) *StackTrie { - st := stackTrieFromPool(writeFn, owner) - st.nodeType = extNode +// newExt constructs an extension node with provided node key and child. The +// key will be deep-copied in the function and safe to modify afterwards. +func newExt(key []byte, child *stNode) *stNode { + st := stPool.Get().(*stNode) + st.typ = extNode st.key = append(st.key, key...) st.children[0] = child return st } -// List all values that StackTrie#nodeType can hold +// List all values that stNode#nodeType can hold const ( emptyNode = iota branchNode @@ -202,59 +123,40 @@ const ( hashedNode ) -// Update inserts a (key, value) pair into the stack trie. -func (st *StackTrie) Update(key, value []byte) error { - k := keybytesToHex(key) - if len(value) == 0 { - panic("deletion not supported") +func (n *stNode) reset() *stNode { + n.key = n.key[:0] + n.val = nil + for i := range n.children { + n.children[i] = nil } - st.insert(k[:len(k)-1], value, nil) - return nil -} - -// MustUpdate is a wrapper of Update and will omit any encountered error but -// just print out an error message. -func (st *StackTrie) MustUpdate(key, value []byte) { - if err := st.Update(key, value); err != nil { - log.Error("Unhandled trie error in StackTrie.Update", "err", err) - } -} - -func (st *StackTrie) Reset() { - st.owner = common.Hash{} - st.writeFn = nil - st.key = st.key[:0] - st.val = nil - for i := range st.children { - st.children[i] = nil - } - st.nodeType = emptyNode + n.typ = emptyNode + return n } // Helper function that, given a full key, determines the index // at which the chunk pointed by st.keyOffset is different from // the same chunk in the full key. -func (st *StackTrie) getDiffIndex(key []byte) int { - for idx, nibble := range st.key { +func (n *stNode) getDiffIndex(key []byte) int { + for idx, nibble := range n.key { if nibble != key[idx] { return idx } } - return len(st.key) + return len(n.key) } // Helper function to that inserts a (key, value) pair into // the trie. -func (st *StackTrie) insert(key, value []byte, prefix []byte) { - switch st.nodeType { +func (t *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) { + switch st.typ { case branchNode: /* Branch */ idx := int(key[0]) // Unresolve elder siblings for i := idx - 1; i >= 0; i-- { if st.children[i] != nil { - if st.children[i].nodeType != hashedNode { - st.children[i].hash(append(prefix, byte(i))) + if st.children[i].typ != hashedNode { + t.hash(st.children[i], append(prefix, byte(i))) } break } @@ -262,9 +164,9 @@ func (st *StackTrie) insert(key, value []byte, prefix []byte) { // Add new child if st.children[idx] == nil { - st.children[idx] = newLeaf(st.owner, key[1:], value, st.writeFn) + st.children[idx] = newLeaf(key[1:], value) } else { - st.children[idx].insert(key[1:], value, append(prefix, key[0])) + t.insert(st.children[idx], key[1:], value, append(prefix, key[0])) } case extNode: /* Ext */ @@ -279,46 +181,46 @@ func (st *StackTrie) insert(key, value []byte, prefix []byte) { if diffidx == len(st.key) { // Ext key and key segment are identical, recurse into // the child node. - st.children[0].insert(key[diffidx:], value, append(prefix, key[:diffidx]...)) + t.insert(st.children[0], key[diffidx:], value, append(prefix, key[:diffidx]...)) return } // Save the original part. Depending if the break is // at the extension's last byte or not, create an // intermediate extension or use the extension's child // node directly. - var n *StackTrie + var n *stNode if diffidx < len(st.key)-1 { // Break on the non-last byte, insert an intermediate // extension. The path prefix of the newly-inserted // extension should also contain the different byte. - n = newExt(st.owner, st.key[diffidx+1:], st.children[0], st.writeFn) - n.hash(append(prefix, st.key[:diffidx+1]...)) + n = newExt(st.key[diffidx+1:], st.children[0]) + t.hash(n, append(prefix, st.key[:diffidx+1]...)) } else { // Break on the last byte, no need to insert // an extension node: reuse the current node. // The path prefix of the original part should // still be same. n = st.children[0] - n.hash(append(prefix, st.key...)) + t.hash(n, append(prefix, st.key...)) } - var p *StackTrie + var p *stNode if diffidx == 0 { // the break is on the first byte, so // the current node is converted into // a branch node. st.children[0] = nil p = st - st.nodeType = branchNode + st.typ = branchNode } else { // the common prefix is at least one byte // long, insert a new intermediate branch // node. - st.children[0] = stackTrieFromPool(st.writeFn, st.owner) - st.children[0].nodeType = branchNode + st.children[0] = stPool.Get().(*stNode) + st.children[0].typ = branchNode p = st.children[0] } // Create a leaf for the inserted part - o := newLeaf(st.owner, key[diffidx+1:], value, st.writeFn) + o := newLeaf(key[diffidx+1:], value) // Insert both child leaves where they belong: origIdx := st.key[diffidx] @@ -344,18 +246,18 @@ func (st *StackTrie) insert(key, value []byte, prefix []byte) { // Check if the split occurs at the first nibble of the // chunk. In that case, no prefix extnode is necessary. // Otherwise, create that - var p *StackTrie + var p *stNode if diffidx == 0 { // Convert current leaf into a branch - st.nodeType = branchNode + st.typ = branchNode p = st st.children[0] = nil } else { // Convert current node into an ext, // and insert a child branch node. - st.nodeType = extNode - st.children[0] = NewStackTrieWithOwner(st.writeFn, st.owner) - st.children[0].nodeType = branchNode + st.typ = extNode + st.children[0] = stPool.Get().(*stNode) + st.children[0].typ = branchNode p = st.children[0] } @@ -363,11 +265,11 @@ func (st *StackTrie) insert(key, value []byte, prefix []byte) { // value and another containing the new value. The child leaf // is hashed directly in order to free up some memory. origIdx := st.key[diffidx] - p.children[origIdx] = newLeaf(st.owner, st.key[diffidx+1:], st.val, st.writeFn) - p.children[origIdx].hash(append(prefix, st.key[:diffidx+1]...)) + p.children[origIdx] = newLeaf(st.key[diffidx+1:], st.val) + t.hash(p.children[origIdx], append(prefix, st.key[:diffidx+1]...)) newIdx := key[diffidx] - p.children[newIdx] = newLeaf(st.owner, key[diffidx+1:], value, st.writeFn) + p.children[newIdx] = newLeaf(key[diffidx+1:], value) // Finally, cut off the key part that has been passed // over to the children. @@ -375,7 +277,7 @@ func (st *StackTrie) insert(key, value []byte, prefix []byte) { st.val = nil case emptyNode: /* Empty */ - st.nodeType = leafNode + st.typ = leafNode st.key = key st.val = value @@ -398,25 +300,18 @@ func (st *StackTrie) insert(key, value []byte, prefix []byte) { // - And the 'st.type' will be 'hashedNode' AGAIN // // This method also sets 'st.type' to hashedNode, and clears 'st.key'. -func (st *StackTrie) hash(path []byte) { - h := newHasher(false) - defer returnHasherToPool(h) - - st.hashRec(h, path) -} - -func (st *StackTrie) hashRec(hasher *hasher, path []byte) { +func (t *StackTrie) hash(st *stNode, path []byte) { // The switch below sets this to the RLP-encoding of this node. var encodedNode []byte - switch st.nodeType { + switch st.typ { case hashedNode: return case emptyNode: st.val = types.EmptyRootHash.Bytes() st.key = st.key[:0] - st.nodeType = hashedNode + st.typ = hashedNode return case branchNode: @@ -426,23 +321,21 @@ func (st *StackTrie) hashRec(hasher *hasher, path []byte) { nodes.Children[i] = nilValueNode continue } - child.hashRec(hasher, append(path, byte(i))) + t.hash(child, append(path, byte(i))) + if len(child.val) < 32 { nodes.Children[i] = rawNode(child.val) } else { nodes.Children[i] = hashNode(child.val) } - - // Release child back to pool. st.children[i] = nil - returnToPool(child) + stPool.Put(child.reset()) // Release child back to pool. } - - nodes.encode(hasher.encbuf) - encodedNode = hasher.encodedBytes() + nodes.encode(t.h.encbuf) + encodedNode = t.h.encodedBytes() case extNode: - st.children[0].hashRec(hasher, append(path, st.key...)) + t.hash(st.children[0], append(path, st.key...)) n := shortNode{Key: hexToCompactInPlace(st.key)} if len(st.children[0].val) < 32 { @@ -450,26 +343,24 @@ func (st *StackTrie) hashRec(hasher *hasher, path []byte) { } else { n.Val = hashNode(st.children[0].val) } + n.encode(t.h.encbuf) + encodedNode = t.h.encodedBytes() - n.encode(hasher.encbuf) - encodedNode = hasher.encodedBytes() - - // Release child back to pool. - returnToPool(st.children[0]) + stPool.Put(st.children[0].reset()) // Release child back to pool. st.children[0] = nil case leafNode: st.key = append(st.key, byte(16)) n := shortNode{Key: hexToCompactInPlace(st.key), Val: valueNode(st.val)} - n.encode(hasher.encbuf) - encodedNode = hasher.encodedBytes() + n.encode(t.h.encbuf) + encodedNode = t.h.encodedBytes() default: panic("invalid node type") } - st.nodeType = hashedNode + st.typ = hashedNode st.key = st.key[:0] if len(encodedNode) < 32 { st.val = common.CopyBytes(encodedNode) @@ -478,18 +369,16 @@ func (st *StackTrie) hashRec(hasher *hasher, path []byte) { // Write the hash to the 'val'. We allocate a new val here to not mutate // input values - st.val = hasher.hashData(encodedNode) - if st.writeFn != nil { - st.writeFn(st.owner, path, common.BytesToHash(st.val), encodedNode) + st.val = t.h.hashData(encodedNode) + if t.writeFn != nil { + t.writeFn(t.owner, path, common.BytesToHash(st.val), encodedNode) } } // Hash returns the hash of the current node. -func (st *StackTrie) Hash() (h common.Hash) { - hasher := newHasher(false) - defer returnHasherToPool(hasher) - - st.hashRec(hasher, nil) +func (t *StackTrie) Hash() (h common.Hash) { + st := t.root + t.hash(st, nil) if len(st.val) == 32 { copy(h[:], st.val) return h @@ -497,9 +386,9 @@ func (st *StackTrie) Hash() (h common.Hash) { // If the node's RLP isn't 32 bytes long, the node will not // be hashed, and instead contain the rlp-encoding of the // node. For the top level node, we need to force the hashing. - hasher.sha.Reset() - hasher.sha.Write(st.val) - hasher.sha.Read(h[:]) + t.h.sha.Reset() + t.h.sha.Write(st.val) + t.h.sha.Read(h[:]) return h } @@ -510,14 +399,12 @@ func (st *StackTrie) Hash() (h common.Hash) { // // The associated database is expected, otherwise the whole commit // functionality should be disabled. -func (st *StackTrie) Commit() (h common.Hash, err error) { - if st.writeFn == nil { +func (t *StackTrie) Commit() (h common.Hash, err error) { + if t.writeFn == nil { return common.Hash{}, ErrCommitDisabled } - hasher := newHasher(false) - defer returnHasherToPool(hasher) - - st.hashRec(hasher, nil) + st := t.root + t.hash(st, nil) if len(st.val) == 32 { copy(h[:], st.val) return h, nil @@ -525,10 +412,10 @@ func (st *StackTrie) Commit() (h common.Hash, err error) { // If the node's RLP isn't 32 bytes long, the node will not // be hashed (and committed), and instead contain the rlp-encoding of the // node. For the top level node, we need to force the hashing+commit. - hasher.sha.Reset() - hasher.sha.Write(st.val) - hasher.sha.Read(h[:]) + t.h.sha.Reset() + t.h.sha.Write(st.val) + t.h.sha.Read(h[:]) - st.writeFn(st.owner, nil, h, st.val) + t.writeFn(t.owner, nil, h, st.val) return h, nil } diff --git a/trie/stacktrie_marshalling.go b/trie/stacktrie_marshalling.go new file mode 100644 index 000000000..c0bb07f86 --- /dev/null +++ b/trie/stacktrie_marshalling.go @@ -0,0 +1,120 @@ +// Copyright 2023 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "bufio" + "bytes" + "encoding" + "encoding/gob" +) + +// Compile-time interface checks. +var ( + _ = encoding.BinaryMarshaler((*StackTrie)(nil)) + _ = encoding.BinaryUnmarshaler((*StackTrie)(nil)) +) + +// NewFromBinaryV2 initialises a serialized stacktrie with the given db. +// OBS! Format was changed along with the name of this constructor. +func NewFromBinaryV2(data []byte) (*StackTrie, error) { + stack := NewStackTrie(nil) + if err := stack.UnmarshalBinary(data); err != nil { + return nil, err + } + return stack, nil +} + +// MarshalBinary implements encoding.BinaryMarshaler. +func (t *StackTrie) MarshalBinary() (data []byte, err error) { + var ( + b bytes.Buffer + w = bufio.NewWriter(&b) + ) + if err := gob.NewEncoder(w).Encode(t.owner); err != nil { + return nil, err + } + if err := t.root.marshalInto(w); err != nil { + return nil, err + } + w.Flush() + return b.Bytes(), nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler. +func (t *StackTrie) UnmarshalBinary(data []byte) error { + r := bytes.NewReader(data) + if err := gob.NewDecoder(r).Decode(&t.owner); err != nil { + return err + } + if err := t.root.unmarshalFrom(r); err != nil { + return err + } + return nil +} + +type stackNodeMarshaling struct { + Typ uint8 + Key []byte + Val []byte +} + +func (n *stNode) marshalInto(w *bufio.Writer) (err error) { + enc := stackNodeMarshaling{ + Typ: n.typ, + Key: n.key, + Val: n.val, + } + if err := gob.NewEncoder(w).Encode(enc); err != nil { + return err + } + for _, child := range n.children { + if child == nil { + w.WriteByte(0) + continue + } + w.WriteByte(1) + if err := child.marshalInto(w); err != nil { + return err + } + } + return nil +} + +func (n *stNode) unmarshalFrom(r *bytes.Reader) error { + var dec stackNodeMarshaling + if err := gob.NewDecoder(r).Decode(&dec); err != nil { + return err + } + n.typ = dec.Typ + n.key = dec.Key + n.val = dec.Val + + for i := range n.children { + if b, err := r.ReadByte(); err != nil { + return err + } else if b == 0 { + continue + } + var child stNode + if err := child.unmarshalFrom(r); err != nil { + return err + } + n.children[i] = &child + } + return nil +} diff --git a/trie/stacktrie_test.go b/trie/stacktrie_test.go index 6bd0b83e3..5b86a971e 100644 --- a/trie/stacktrie_test.go +++ b/trie/stacktrie_test.go @@ -198,12 +198,11 @@ func TestStackTrieInsertAndHash(t *testing.T) { {"000003", "XXXXXXXXXXXXXXXXXXXXXXXXXXXX", "962c0fffdeef7612a4f7bff1950d67e3e81c878e48b9ae45b3b374253b050bd8"}, }, } - st := NewStackTrie(nil) for i, test := range tests { // The StackTrie does not allow Insert(), Hash(), Insert(), ... // so we will create new trie for every sequence length of inserts. for l := 1; l <= len(test); l++ { - st.Reset() + st := NewStackTrie(nil) for j := 0; j < l; j++ { kv := &test[j] if err := st.Update(common.FromHex(kv.K), []byte(kv.V)); err != nil { @@ -382,7 +381,7 @@ func TestStacktrieNotModifyValues(t *testing.T) { // serialize/unserialize it a lot func TestStacktrieSerialization(t *testing.T) { var ( - st = NewStackTrie(nil) + st = NewStackTrieWithOwner(nil, common.Hash{0x12}) nt = NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), nil)) keyB = big.NewInt(1) keyDelta = big.NewInt(1) @@ -411,7 +410,7 @@ func TestStacktrieSerialization(t *testing.T) { if err != nil { t.Fatal(err) } - newSt, err := NewFromBinary(blob, nil) + newSt, err := NewFromBinaryV2(blob) if err != nil { t.Fatal(err) } @@ -421,4 +420,7 @@ func TestStacktrieSerialization(t *testing.T) { if have, want := st.Hash(), nt.Hash(); have != want { t.Fatalf("have %#x want %#x", have, want) } + if have, want := st.owner, (common.Hash{0x12}); have != want { + t.Fatalf("have %#x want %#x", have, want) + } }