From 6667e62765783d508b9bc76e963b5a25ef6bdd6a Mon Sep 17 00:00:00 2001 From: philip-morlier Date: Tue, 4 Apr 2023 21:25:44 -0700 Subject: [PATCH] added hasher package --- restricted/hasher/encoding.go | 145 +++++++++ restricted/hasher/hasher.go | 209 +++++++++++++ restricted/hasher/node.go | 280 +++++++++++++++++ restricted/hasher/node_enc.go | 87 ++++++ restricted/hasher/statetrie.go | 532 +++++++++++++++++++++++++++++++++ 5 files changed, 1253 insertions(+) create mode 100644 restricted/hasher/encoding.go create mode 100644 restricted/hasher/hasher.go create mode 100644 restricted/hasher/node.go create mode 100644 restricted/hasher/node_enc.go create mode 100644 restricted/hasher/statetrie.go diff --git a/restricted/hasher/encoding.go b/restricted/hasher/encoding.go new file mode 100644 index 0000000..f55bf11 --- /dev/null +++ b/restricted/hasher/encoding.go @@ -0,0 +1,145 @@ +// Copyright 2014 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package hasher + +// Trie keys are dealt with in three distinct encodings: +// +// KEYBYTES encoding contains the actual key and nothing else. This encoding is the +// input to most API functions. 
+// +// HEX encoding contains one byte for each nibble of the key and an optional trailing +// 'terminator' byte of value 0x10 which indicates whether or not the node at the key +// contains a value. Hex key encoding is used for nodes loaded in memory because it's +// convenient to access. +// +// COMPACT encoding is defined by the Ethereum Yellow Paper (it's called "hex prefix +// encoding" there) and contains the bytes of the key and a flag. The high nibble of the +// first byte contains the flag; the lowest bit encoding the oddness of the length and +// the second-lowest encoding whether the node at the key is a value node. The low nibble +// of the first byte is zero in the case of an even number of nibbles and the first nibble +// in the case of an odd number. All remaining nibbles (now an even number) fit properly +// into the remaining bytes. Compact encoding is used for nodes stored on disk. + +func hexToCompact(hex []byte) []byte { + terminator := byte(0) + if hasTerm(hex) { + terminator = 1 + hex = hex[:len(hex)-1] + } + buf := make([]byte, len(hex)/2+1) + buf[0] = terminator << 5 // the flag byte + if len(hex)&1 == 1 { + buf[0] |= 1 << 4 // odd flag + buf[0] |= hex[0] // first nibble is contained in the first byte + hex = hex[1:] + } + decodeNibbles(hex, buf[1:]) + return buf +} + +// hexToCompactInPlace places the compact key in input buffer, returning the length +// needed for the representation +func hexToCompactInPlace(hex []byte) int { + var ( + hexLen = len(hex) // length of the hex input + firstByte = byte(0) + ) + // Check if we have a terminator there + if hexLen > 0 && hex[hexLen-1] == 16 { + firstByte = 1 << 5 + hexLen-- // last part was the terminator, ignore that + } + var ( + binLen = hexLen/2 + 1 + ni = 0 // index in hex + bi = 1 // index in bin (compact) + ) + if hexLen&1 == 1 { + firstByte |= 1 << 4 // odd flag + firstByte |= hex[0] // first nibble is contained in the first byte + ni++ + } + for ; ni < hexLen; bi, ni = bi+1, ni+2 { + 
hex[bi] = hex[ni]<<4 | hex[ni+1] + } + hex[0] = firstByte + return binLen +} + +func compactToHex(compact []byte) []byte { + if len(compact) == 0 { + return compact + } + base := keybytesToHex(compact) + // delete terminator flag + if base[0] < 2 { + base = base[:len(base)-1] + } + // apply odd flag + chop := 2 - base[0]&1 + return base[chop:] +} + +func keybytesToHex(str []byte) []byte { + l := len(str)*2 + 1 + var nibbles = make([]byte, l) + for i, b := range str { + nibbles[i*2] = b / 16 + nibbles[i*2+1] = b % 16 + } + nibbles[l-1] = 16 + return nibbles +} + +// hexToKeybytes turns hex nibbles into key bytes. +// This can only be used for keys of even length. +func hexToKeybytes(hex []byte) []byte { + if hasTerm(hex) { + hex = hex[:len(hex)-1] + } + if len(hex)&1 != 0 { + panic("can't convert hex key of odd length") + } + key := make([]byte, len(hex)/2) + decodeNibbles(hex, key) + return key +} + +func decodeNibbles(nibbles []byte, bytes []byte) { + for bi, ni := 0, 0; ni < len(nibbles); bi, ni = bi+1, ni+2 { + bytes[bi] = nibbles[ni]<<4 | nibbles[ni+1] + } +} + +// prefixLen returns the length of the common prefix of a and b. +func prefixLen(a, b []byte) int { + var i, length = 0, len(a) + if len(b) < length { + length = len(b) + } + for ; i < length; i++ { + if a[i] != b[i] { + break + } + } + return i +} + +// hasTerm returns whether a hex key has the terminator flag. +func hasTerm(s []byte) bool { + return len(s) > 0 && s[len(s)-1] == 16 +} diff --git a/restricted/hasher/hasher.go b/restricted/hasher/hasher.go new file mode 100644 index 0000000..85c9d56 --- /dev/null +++ b/restricted/hasher/hasher.go @@ -0,0 +1,209 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. 
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package hasher

import (
	"sync"

	"github.com/openrelayxyz/plugeth-utils/restricted/crypto"
	"github.com/openrelayxyz/plugeth-utils/restricted/rlp"
	"golang.org/x/crypto/sha3"
)

// Hasher is a type used for the trie Hash operation. A Hasher has some
// internal preallocated temp space: tmp receives the bytes of the most recent
// encoding and encbuf is the RLP encoder buffer nodes are encoded into.
type Hasher struct {
	sha      crypto.KeccakState
	tmp      []byte
	encbuf   rlp.EncoderBuffer
	parallel bool // Whether to use parallel threads when hashing
}

// HasherPool holds reusable Hashers. Each new Hasher is created with a
// legacy Keccak-256 state, an empty encoder buffer and a 550-byte capacity
// scratch slice.
var HasherPool = sync.Pool{
	New: func() interface{} {
		return &Hasher{
			tmp:    make([]byte, 0, 550), // cap is as large as a full fullNode.
			sha:    sha3.NewLegacyKeccak256().(crypto.KeccakState),
			encbuf: rlp.NewEncoderBuffer(nil),
		}
	},
}

// newHasher takes a Hasher from HasherPool and sets its parallel flag.
// The caller is expected to hand it back with returnHasherToPool.
func newHasher(parallel bool) *Hasher {
	h := HasherPool.Get().(*Hasher)
	h.parallel = parallel
	return h
}

// returnHasherToPool puts h back into HasherPool for reuse.
func returnHasherToPool(h *Hasher) {
	HasherPool.Put(h)
}

// hash collapses a node down into a hash node, also returning a copy of the
// original node initialized with the computed hash to replace the original one.
+func (h *Hasher) hash(n node, force bool) (hashed node, cached node) { + // Return the cached hash if it's available + if hash, _ := n.cache(); hash != nil { + return hash, n + } + // Trie not processed yet, walk the children + switch n := n.(type) { + case *shortNode: + collapsed, cached := h.hashShortNodeChildren(n) + hashed := h.shortnodeToHash(collapsed, force) + // We need to retain the possibly _not_ hashed node, in case it was too + // small to be hashed + if hn, ok := hashed.(hashNode); ok { + cached.flags.hash = hn + } else { + cached.flags.hash = nil + } + return hashed, cached + case *fullNode: + collapsed, cached := h.hashFullNodeChildren(n) + hashed = h.fullnodeToHash(collapsed, force) + if hn, ok := hashed.(hashNode); ok { + cached.flags.hash = hn + } else { + cached.flags.hash = nil + } + return hashed, cached + default: + // Value and hash nodes don't have children so they're left as were + return n, n + } +} + +// hashShortNodeChildren collapses the short node. The returned collapsed node +// holds a live reference to the Key, and must not be modified. +// The cached +func (h *Hasher) hashShortNodeChildren(n *shortNode) (collapsed, cached *shortNode) { + // Hash the short node's child, caching the newly hashed subtree + collapsed, cached = n.copy(), n.copy() + // Previously, we did copy this one. 
We don't seem to need to actually + // do that, since we don't overwrite/reuse keys + //cached.Key = common.CopyBytes(n.Key) + collapsed.Key = hexToCompact(n.Key) + // Unless the child is a valuenode or hashnode, hash it + switch n.Val.(type) { + case *fullNode, *shortNode: + collapsed.Val, cached.Val = h.hash(n.Val, false) + } + return collapsed, cached +} + +func (h *Hasher) hashFullNodeChildren(n *fullNode) (collapsed *fullNode, cached *fullNode) { + // Hash the full node's children, caching the newly hashed subtrees + cached = n.copy() + collapsed = n.copy() + if h.parallel { + var wg sync.WaitGroup + wg.Add(16) + for i := 0; i < 16; i++ { + go func(i int) { + Hasher := newHasher(false) + if child := n.Children[i]; child != nil { + collapsed.Children[i], cached.Children[i] = Hasher.hash(child, false) + } else { + collapsed.Children[i] = nilValueNode + } + returnHasherToPool(Hasher) + wg.Done() + }(i) + } + wg.Wait() + } else { + for i := 0; i < 16; i++ { + if child := n.Children[i]; child != nil { + collapsed.Children[i], cached.Children[i] = h.hash(child, false) + } else { + collapsed.Children[i] = nilValueNode + } + } + } + return collapsed, cached +} + +// shortnodeToHash creates a hashNode from a shortNode. The supplied shortnode +// should have hex-type Key, which will be converted (without modification) +// into compact form for RLP encoding. +// If the rlp data is smaller than 32 bytes, `nil` is returned. 
+func (h *Hasher) shortnodeToHash(n *shortNode, force bool) node { + n.encode(h.encbuf) + enc := h.encodedBytes() + + if len(enc) < 32 && !force { + return n // Nodes smaller than 32 bytes are stored inside their parent + } + return h.hashData(enc) +} + +// shortnodeToHash is used to creates a hashNode from a set of hashNodes, (which +// may contain nil values) +func (h *Hasher) fullnodeToHash(n *fullNode, force bool) node { + n.encode(h.encbuf) + enc := h.encodedBytes() + + if len(enc) < 32 && !force { + return n // Nodes smaller than 32 bytes are stored inside their parent + } + return h.hashData(enc) +} + +// encodedBytes returns the result of the last encoding operation on h.encbuf. +// This also resets the encoder buffer. +// +// All node encoding must be done like this: +// +// node.encode(h.encbuf) +// enc := h.encodedBytes() +// +// This convention exists because node.encode can only be inlined/escape-analyzed when +// called on a concrete receiver type. +func (h *Hasher) encodedBytes() []byte { + h.tmp = h.encbuf.AppendToBytes(h.tmp[:0]) + h.encbuf.Reset(nil) + return h.tmp +} + +// hashData hashes the provided data +func (h *Hasher) hashData(data []byte) hashNode { + n := make(hashNode, 32) + h.sha.Reset() + h.sha.Write(data) + h.sha.Read(n) + return n +} + +// proofHash is used to construct trie proofs, and returns the 'collapsed' +// node (for later RLP encoding) as well as the hashed node -- unless the +// node is smaller than 32 bytes, in which case it will be returned as is. +// This method does not do anything on value- or hash-nodes. 
+func (h *Hasher) proofHash(original node) (collapsed, hashed node) { + switch n := original.(type) { + case *shortNode: + sn, _ := h.hashShortNodeChildren(n) + return sn, h.shortnodeToHash(sn, false) + case *fullNode: + fn, _ := h.hashFullNodeChildren(n) + return fn, h.fullnodeToHash(fn, false) + default: + // Value and hash nodes don't have children so they're left as were + return n, n + } +} diff --git a/restricted/hasher/node.go b/restricted/hasher/node.go new file mode 100644 index 0000000..419ea8f --- /dev/null +++ b/restricted/hasher/node.go @@ -0,0 +1,280 @@ +// Copyright 2014 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +package hasher + +import ( + "fmt" + "io" + "strings" + + "github.com/openrelayxyz/plugeth-utils/core" + "github.com/openrelayxyz/plugeth-utils/restricted/rlp" +) + +var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"} + +type node interface { + cache() (hashNode, bool) + encode(w rlp.EncoderBuffer) + fstring(string) string +} + +type ( + fullNode struct { + Children [17]node // Actual trie node data to encode/decode (needs custom encoder) + flags nodeFlag + } + shortNode struct { + Key []byte + Val node + flags nodeFlag + } + hashNode []byte + valueNode []byte +) + +// nilValueNode is used when collapsing internal trie nodes for hashing, since +// unset children need to serialize correctly. +var nilValueNode = valueNode(nil) + +// EncodeRLP encodes a full node into the consensus RLP format. +func (n *fullNode) EncodeRLP(w io.Writer) error { + eb := rlp.NewEncoderBuffer(w) + n.encode(eb) + return eb.Flush() +} + +func (n *fullNode) copy() *fullNode { copy := *n; return © } +func (n *shortNode) copy() *shortNode { copy := *n; return © } + +// nodeFlag contains caching-related metadata about a node. +type nodeFlag struct { + hash hashNode // cached hash of the node (may be nil) + dirty bool // whether the node has changes that must be written to the database +} + +func (n *fullNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty } +func (n *shortNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty } +func (n hashNode) cache() (hashNode, bool) { return nil, true } +func (n valueNode) cache() (hashNode, bool) { return nil, true } + +// Pretty printing. 
+func (n *fullNode) String() string { return n.fstring("") } +func (n *shortNode) String() string { return n.fstring("") } +func (n hashNode) String() string { return n.fstring("") } +func (n valueNode) String() string { return n.fstring("") } + +func (n *fullNode) fstring(ind string) string { + resp := fmt.Sprintf("[\n%s ", ind) + for i, node := range &n.Children { + if node == nil { + resp += fmt.Sprintf("%s: ", indices[i]) + } else { + resp += fmt.Sprintf("%s: %v", indices[i], node.fstring(ind+" ")) + } + } + return resp + fmt.Sprintf("\n%s] ", ind) +} +func (n *shortNode) fstring(ind string) string { + return fmt.Sprintf("{%x: %v} ", n.Key, n.Val.fstring(ind+" ")) +} +func (n hashNode) fstring(ind string) string { + return fmt.Sprintf("<%x> ", []byte(n)) +} +func (n valueNode) fstring(ind string) string { + return fmt.Sprintf("%x ", []byte(n)) +} + +// mustDecodeNode is a wrapper of decodeNode and panic if any error is encountered. +func mustDecodeNode(hash, buf []byte) node { + n, err := decodeNode(hash, buf) + if err != nil { + panic(fmt.Sprintf("node %x: %v", hash, err)) + } + return n +} + +// mustDecodeNodeUnsafe is a wrapper of decodeNodeUnsafe and panic if any error is +// encountered. +func mustDecodeNodeUnsafe(hash, buf []byte) node { + n, err := decodeNodeUnsafe(hash, buf) + if err != nil { + panic(fmt.Sprintf("node %x: %v", hash, err)) + } + return n +} + +// decodeNode parses the RLP encoding of a trie node. It will deep-copy the passed +// byte slice for decoding, so it's safe to modify the byte slice afterwards. The- +// decode performance of this function is not optimal, but it is suitable for most +// scenarios with low performance requirements and hard to determine whether the +// byte slice be modified or not. +func decodeNode(hash, buf []byte) (node, error) { + return decodeNodeUnsafe(hash, core.CopyBytes(buf)) +} + +// decodeNodeUnsafe parses the RLP encoding of a trie node. 
The passed byte slice +// will be directly referenced by node without bytes deep copy, so the input MUST +// not be changed after. +func decodeNodeUnsafe(hash, buf []byte) (node, error) { + if len(buf) == 0 { + return nil, io.ErrUnexpectedEOF + } + elems, _, err := rlp.SplitList(buf) + if err != nil { + return nil, fmt.Errorf("decode error: %v", err) + } + switch c, _ := rlp.CountValues(elems); c { + case 2: + n, err := decodeShort(hash, elems) + return n, wrapError(err, "short") + case 17: + n, err := decodeFull(hash, elems) + return n, wrapError(err, "full") + default: + return nil, fmt.Errorf("invalid number of list elements: %v", c) + } +} + +func decodeShort(hash, elems []byte) (node, error) { + kbuf, rest, err := rlp.SplitString(elems) + if err != nil { + return nil, err + } + flag := nodeFlag{hash: hash} + key := compactToHex(kbuf) + if hasTerm(key) { + // value node + val, _, err := rlp.SplitString(rest) + if err != nil { + return nil, fmt.Errorf("invalid value node: %v", err) + } + return &shortNode{key, valueNode(val), flag}, nil + } + r, _, err := decodeRef(rest) + if err != nil { + return nil, wrapError(err, "val") + } + return &shortNode{key, r, flag}, nil +} + +func decodeFull(hash, elems []byte) (*fullNode, error) { + n := &fullNode{flags: nodeFlag{hash: hash}} + for i := 0; i < 16; i++ { + cld, rest, err := decodeRef(elems) + if err != nil { + return n, wrapError(err, fmt.Sprintf("[%d]", i)) + } + n.Children[i], elems = cld, rest + } + val, _, err := rlp.SplitString(elems) + if err != nil { + return n, err + } + if len(val) > 0 { + n.Children[16] = valueNode(val) + } + return n, nil +} + +const hashLen = len(core.Hash{}) + +func decodeRef(buf []byte) (node, []byte, error) { + kind, val, rest, err := rlp.Split(buf) + if err != nil { + return nil, buf, err + } + switch { + case kind == rlp.List: + // 'embedded' node reference. The encoding must be smaller + // than a hash in order to be valid. 
+ if size := len(buf) - len(rest); size > hashLen { + err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen) + return nil, buf, err + } + n, err := decodeNode(nil, buf) + return n, rest, err + case kind == rlp.String && len(val) == 0: + // empty node + return nil, rest, nil + case kind == rlp.String && len(val) == 32: + return hashNode(val), rest, nil + default: + return nil, nil, fmt.Errorf("invalid RLP string size %d (want 0 or 32)", len(val)) + } +} + +// wraps a decoding error with information about the path to the +// invalid child node (for debugging encoding issues). +type decodeError struct { + what error + stack []string +} + +func wrapError(err error, ctx string) error { + if err == nil { + return nil + } + if decErr, ok := err.(*decodeError); ok { + decErr.stack = append(decErr.stack, ctx) + return decErr + } + return &decodeError{err, []string{ctx}} +} + +func (err *decodeError) Error() string { + return fmt.Sprintf("%v (decode path: %s)", err.what, strings.Join(err.stack, "<-")) +} + +// rawNode is a simple binary blob used to differentiate between collapsed trie +// nodes and already encoded RLP binary blobs (while at the same time store them +// in the same cache fields). +type rawNode []byte + +func (n rawNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } +func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") } + +func (n rawNode) EncodeRLP(w io.Writer) error { + _, err := w.Write(n) + return err +} + +// rawFullNode represents only the useful data content of a full node, with the +// caches and flags stripped out to minimize its data storage. This type honors +// the same RLP encoding as the original parent. 
+type rawFullNode [17]node + +func (n rawFullNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } +func (n rawFullNode) fstring(ind string) string { panic("this should never end up in a live trie") } + +func (n rawFullNode) EncodeRLP(w io.Writer) error { + eb := rlp.NewEncoderBuffer(w) + n.encode(eb) + return eb.Flush() +} + +// rawShortNode represents only the useful data content of a short node, with the +// caches and flags stripped out to minimize its data storage. This type honors +// the same RLP encoding as the original parent. +type rawShortNode struct { + Key []byte + Val node +} + +func (n rawShortNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } +func (n rawShortNode) fstring(ind string) string { panic("this should never end up in a live trie") } + diff --git a/restricted/hasher/node_enc.go b/restricted/hasher/node_enc.go new file mode 100644 index 0000000..4da5fe6 --- /dev/null +++ b/restricted/hasher/node_enc.go @@ -0,0 +1,87 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +package hasher + +import ( + "github.com/openrelayxyz/plugeth-utils/restricted/rlp" +) + +func nodeToBytes(n node) []byte { + w := rlp.NewEncoderBuffer(nil) + n.encode(w) + result := w.ToBytes() + w.Flush() + return result +} + +func (n *fullNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + for _, c := range n.Children { + if c != nil { + c.encode(w) + } else { + w.Write(rlp.EmptyString) + } + } + w.ListEnd(offset) +} + +func (n *shortNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + w.WriteBytes(n.Key) + if n.Val != nil { + n.Val.encode(w) + } else { + w.Write(rlp.EmptyString) + } + w.ListEnd(offset) +} + +func (n hashNode) encode(w rlp.EncoderBuffer) { + w.WriteBytes(n) +} + +func (n valueNode) encode(w rlp.EncoderBuffer) { + w.WriteBytes(n) +} + +func (n rawFullNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + for _, c := range n { + if c != nil { + c.encode(w) + } else { + w.Write(rlp.EmptyString) + } + } + w.ListEnd(offset) +} + +func (n *rawShortNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + w.WriteBytes(n.Key) + if n.Val != nil { + n.Val.encode(w) + } else { + w.Write(rlp.EmptyString) + } + w.ListEnd(offset) +} + +func (n rawNode) encode(w rlp.EncoderBuffer) { + w.Write(n) +} diff --git a/restricted/hasher/statetrie.go b/restricted/hasher/statetrie.go new file mode 100644 index 0000000..25fff53 --- /dev/null +++ b/restricted/hasher/statetrie.go @@ -0,0 +1,532 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package hasher

import (
	"fmt"
	"bufio"
	"bytes"
	"encoding/gob"
	"errors"
	"io"
	"sync"

	"github.com/openrelayxyz/plugeth-utils/core"
	"github.com/openrelayxyz/plugeth-utils/restricted/types"
)

// ErrCommitDisabled is the error value for committing without a database.
var ErrCommitDisabled = errors.New("no database for committing")

// stPool recycles StackTrie instances; new ones are created without a
// write function.
var stPool = sync.Pool{
	New: func() interface{} {
		return NewStackTrie(nil)
	},
}

// NodeWriteFunc is used to provide all information of a dirty node for committing
// so that callers can flush nodes into database with desired scheme.
type NodeWriteFunc = func(owner core.Hash, path []byte, hash core.Hash, blob []byte)

// stackTrieFromPool takes a StackTrie from stPool and sets its owner and
// write function.
func stackTrieFromPool(writeFn NodeWriteFunc, owner core.Hash) *StackTrie {
	st := stPool.Get().(*StackTrie)
	st.owner = owner
	st.writeFn = writeFn
	return st
}

// returnToPool resets st and puts it back into stPool.
func returnToPool(st *StackTrie) {
	st.Reset()
	stPool.Put(st)
}

// StackTrie is a trie implementation that expects keys to be inserted
// in order. Once it determines that a subtree will no longer be inserted
// into, it will hash it and free up the memory it uses.
type StackTrie struct {
	owner    core.Hash      // the owner of the trie
	nodeType uint8          // node type (as in branch, ext, leaf)
	val      []byte         // value contained by this node if it's a leaf
	key      []byte         // key chunk covered by this (leaf|ext) node
	children [16]*StackTrie // list of children (for branch and exts)
	writeFn  NodeWriteFunc  // function for committing nodes, can be nil
}

// NewStackTrie allocates and initializes an empty trie.
+func NewStackTrie(writeFn NodeWriteFunc) *StackTrie { + return &StackTrie{ + nodeType: emptyNode, + writeFn: writeFn, + } +} + +// NewStackTrieWithOwner allocates and initializes an empty trie, but with +// the additional owner field. +func NewStackTrieWithOwner(writeFn NodeWriteFunc, owner core.Hash) *StackTrie { + return &StackTrie{ + owner: owner, + nodeType: emptyNode, + writeFn: writeFn, + } +} + +// NewFromBinary initialises a serialized stacktrie with the given db. +func NewFromBinary(data []byte, writeFn NodeWriteFunc) (*StackTrie, error) { + var st StackTrie + if err := st.UnmarshalBinary(data); err != nil { + return nil, err + } + // If a database is used, we need to recursively add it to every child + if writeFn != nil { + st.setWriter(writeFn) + } + return &st, nil +} + +// MarshalBinary implements encoding.BinaryMarshaler +func (st *StackTrie) MarshalBinary() (data []byte, err error) { + var ( + b bytes.Buffer + w = bufio.NewWriter(&b) + ) + if err := gob.NewEncoder(w).Encode(struct { + Owner core.Hash + NodeType uint8 + Val []byte + Key []byte + }{ + st.owner, + st.nodeType, + st.val, + st.key, + }); err != nil { + return nil, err + } + for _, child := range st.children { + if child == nil { + w.WriteByte(0) + continue + } + w.WriteByte(1) + if childData, err := child.MarshalBinary(); err != nil { + return nil, err + } else { + w.Write(childData) + } + } + w.Flush() + return b.Bytes(), nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (st *StackTrie) UnmarshalBinary(data []byte) error { + r := bytes.NewReader(data) + return st.unmarshalBinary(r) +} + +func (st *StackTrie) unmarshalBinary(r io.Reader) error { + var dec struct { + Owner core.Hash + NodeType uint8 + Val []byte + Key []byte + } + if err := gob.NewDecoder(r).Decode(&dec); err != nil { + return err + } + st.owner = dec.Owner + st.nodeType = dec.NodeType + st.val = dec.Val + st.key = dec.Key + + var hasChild = make([]byte, 1) + for i := range st.children { + if _, err 
:= r.Read(hasChild); err != nil { + return err + } else if hasChild[0] == 0 { + continue + } + var child StackTrie + if err := child.unmarshalBinary(r); err != nil { + return err + } + st.children[i] = &child + } + return nil +} + +func (st *StackTrie) setWriter(writeFn NodeWriteFunc) { + st.writeFn = writeFn + for _, child := range st.children { + if child != nil { + child.setWriter(writeFn) + } + } +} + +func newLeaf(owner core.Hash, key, val []byte, writeFn NodeWriteFunc) *StackTrie { + st := stackTrieFromPool(writeFn, owner) + st.nodeType = leafNode + st.key = append(st.key, key...) + st.val = val + return st +} + +func newExt(owner core.Hash, key []byte, child *StackTrie, writeFn NodeWriteFunc) *StackTrie { + st := stackTrieFromPool(writeFn, owner) + st.nodeType = extNode + st.key = append(st.key, key...) + st.children[0] = child + return st +} + +// List all values that StackTrie#nodeType can hold +const ( + emptyNode = iota + branchNode + extNode + leafNode + hashedNode +) + +// TryUpdate inserts a (key, value) pair into the stack trie +func (st *StackTrie) TryUpdate(key, value []byte) error { + k := keybytesToHex(key) + if len(value) == 0 { + panic("deletion not supported") + } + st.insert(k[:len(k)-1], value, nil) + return nil +} + +func (st *StackTrie) Update(key, value []byte) { + if err := st.TryUpdate(key, value); err != nil { + fmt.Errorf("Unhandled trie error in StackTrie.Update", "err", err) + } +} + +func (st *StackTrie) Reset() { + st.owner = core.Hash{} + st.writeFn = nil + st.key = st.key[:0] + st.val = nil + for i := range st.children { + st.children[i] = nil + } + st.nodeType = emptyNode +} + +// Helper function that, given a full key, determines the index +// at which the chunk pointed by st.keyOffset is different from +// the same chunk in the full key. 
func (st *StackTrie) getDiffIndex(key []byte) int {
	for idx, nibble := range st.key {
		if nibble != key[idx] {
			return idx
		}
	}
	// No difference found: the whole of st.key is a prefix of key.
	return len(st.key)
}

// Helper function that inserts a (key, value) pair into the trie.
// key is the part of the key not yet consumed by ancestor nodes and
// prefix is the path leading to this node.
func (st *StackTrie) insert(key, value []byte, prefix []byte) {
	switch st.nodeType {
	case branchNode: /* Branch */
		idx := int(key[0])

		// Hash the nearest existing elder sibling (the closest non-nil
		// child below idx), unless it is already hashed.
		for i := idx - 1; i >= 0; i-- {
			if st.children[i] != nil {
				if st.children[i].nodeType != hashedNode {
					st.children[i].hash(append(prefix, byte(i)))
				}
				break
			}
		}

		// Add new child
		if st.children[idx] == nil {
			st.children[idx] = newLeaf(st.owner, key[1:], value, st.writeFn)
		} else {
			st.children[idx].insert(key[1:], value, append(prefix, key[0]))
		}

	case extNode: /* Ext */
		// Compare both key chunks and see where they differ
		diffidx := st.getDiffIndex(key)

		// Check if chunks are identical. If so, recurse into
		// the child node. Otherwise, the key has to be split
		// into 1) an optional common prefix, 2) the fullnode
		// representing the two differing path, and 3) a leaf
		// for each of the differentiated subtrees.
		if diffidx == len(st.key) {
			// Ext key and key segment are identical, recurse into
			// the child node.
			st.children[0].insert(key[diffidx:], value, append(prefix, key[:diffidx]...))
			return
		}
		// Save the original part. Depending if the break is
		// at the extension's last byte or not, create an
		// intermediate extension or use the extension's child
		// node directly.
		var n *StackTrie
		if diffidx < len(st.key)-1 {
			// Break on the non-last byte, insert an intermediate
			// extension. The path prefix of the newly-inserted
			// extension should also contain the different byte.
			n = newExt(st.owner, st.key[diffidx+1:], st.children[0], st.writeFn)
			n.hash(append(prefix, st.key[:diffidx+1]...))
		} else {
			// Break on the last byte, no need to insert
			// an extension node: reuse the current node.
			// The path prefix of the original part should
			// still be same.
			n = st.children[0]
			n.hash(append(prefix, st.key...))
		}
		var p *StackTrie
		if diffidx == 0 {
			// the break is on the first byte, so
			// the current node is converted into
			// a branch node.
			st.children[0] = nil
			p = st
			st.nodeType = branchNode
		} else {
			// the common prefix is at least one byte
			// long, insert a new intermediate branch
			// node.
			st.children[0] = stackTrieFromPool(st.writeFn, st.owner)
			st.children[0].nodeType = branchNode
			p = st.children[0]
		}
		// Create a leaf for the inserted part
		o := newLeaf(st.owner, key[diffidx+1:], value, st.writeFn)

		// Insert both child leaves where they belong:
		origIdx := st.key[diffidx]
		newIdx := key[diffidx]
		p.children[origIdx] = n
		p.children[newIdx] = o
		// Keep only the common prefix as this node's key.
		st.key = st.key[:diffidx]

	case leafNode: /* Leaf */
		// Compare both key chunks and see where they differ
		diffidx := st.getDiffIndex(key)

		// Overwriting a key isn't supported, which means that
		// the current leaf is expected to be split into 1) an
		// optional extension for the common prefix of these 2
		// keys, 2) a fullnode selecting the path on which the
		// keys differ, and 3) one leaf for the differentiated
		// component of each key.
		if diffidx >= len(st.key) {
			panic("Trying to insert into existing key")
		}

		// Check if the split occurs at the first nibble of the
		// chunk. In that case, no prefix extnode is necessary.
		// Otherwise, create that
		var p *StackTrie
		if diffidx == 0 {
			// Convert current leaf into a branch
			st.nodeType = branchNode
			p = st
			st.children[0] = nil
		} else {
			// Convert current node into an ext,
			// and insert a child branch node.
			st.nodeType = extNode
			st.children[0] = NewStackTrieWithOwner(st.writeFn, st.owner)
			st.children[0].nodeType = branchNode
			p = st.children[0]
		}

		// Create the two child leaves: one containing the original
		// value and another containing the new value. The child leaf
		// is hashed directly in order to free up some memory.
		origIdx := st.key[diffidx]
		p.children[origIdx] = newLeaf(st.owner, st.key[diffidx+1:], st.val, st.writeFn)
		p.children[origIdx].hash(append(prefix, st.key[:diffidx+1]...))

		newIdx := key[diffidx]
		p.children[newIdx] = newLeaf(st.owner, key[diffidx+1:], value, st.writeFn)

		// Finally, cut off the key part that has been passed
		// over to the children.
		st.key = st.key[:diffidx]
		st.val = nil

	case emptyNode: /* Empty */
		// The first insertion turns the empty node into a leaf. key and
		// value are stored as passed in, without copying.
		st.nodeType = leafNode
		st.key = key
		st.val = value

	case hashedNode:
		panic("trying to insert into hash")

	default:
		panic("invalid type")
	}
}

// hash converts st into a 'hashedNode', if possible. Possible outcomes:
//
// 1. The rlp-encoded value was >= 32 bytes:
//   - Then the 32-byte `hash` will be accessible in `st.val`.
//   - And the 'st.nodeType' will be 'hashedNode'
//
// 2. The rlp-encoded value was < 32 bytes
//   - Then the <32 byte rlp-encoded value will be accessible in 'st.val'.
//   - And the 'st.nodeType' will be 'hashedNode' AGAIN
//
// This method also sets 'st.nodeType' to hashedNode, and clears 'st.key'.
// path is the path leading to st; it is handed to the write function when
// the node is written out.
func (st *StackTrie) hash(path []byte) {
	h := newHasher(false)
	defer returnHasherToPool(h)

	st.hashRec(h, path)
}

// hashRec does the work of hash using the supplied hasher. Children are
// hashed recursively first and then released back to the pool.
func (st *StackTrie) hashRec(hasher *Hasher, path []byte) {
	// The switch below sets this to the RLP-encoding of this node.
	var encodedNode []byte

	switch st.nodeType {
	case hashedNode:
		// Already hashed, nothing to do.
		return

	case emptyNode:
		// An empty trie is represented by the empty root hash.
		st.val = types.EmptyRootHash.Bytes()
		st.key = st.key[:0]
		st.nodeType = hashedNode
		return

	case branchNode:
		var nodes rawFullNode
		for i, child := range st.children {
			if child == nil {
				nodes[i] = nilValueNode
				continue
			}
			child.hashRec(hasher, append(path, byte(i)))
			// After hashRec, child.val holds either the child's short RLP
			// encoding (embedded verbatim) or its 32-byte hash.
			if len(child.val) < 32 {
				nodes[i] = rawNode(child.val)
			} else {
				nodes[i] = hashNode(child.val)
			}

			// Release child back to pool.
			st.children[i] = nil
			returnToPool(child)
		}

		nodes.encode(hasher.encbuf)
		encodedNode = hasher.encodedBytes()

	case extNode:
		st.children[0].hashRec(hasher, append(path, st.key...))

		n := rawShortNode{Key: hexToCompact(st.key)}
		if len(st.children[0].val) < 32 {
			n.Val = rawNode(st.children[0].val)
		} else {
			n.Val = hashNode(st.children[0].val)
		}

		n.encode(hasher.encbuf)
		encodedNode = hasher.encodedBytes()

		// Release child back to pool.
		returnToPool(st.children[0])
		st.children[0] = nil

	case leafNode:
		// Append the terminator nibble (16) before the compact encoding.
		st.key = append(st.key, byte(16))
		n := rawShortNode{Key: hexToCompact(st.key), Val: valueNode(st.val)}

		n.encode(hasher.encbuf)
		encodedNode = hasher.encodedBytes()

	default:
		panic("invalid node type")
	}

	st.nodeType = hashedNode
	st.key = st.key[:0]
	if len(encodedNode) < 32 {
		// Small nodes keep their RLP encoding instead of a hash. The bytes
		// are copied because encodedNode is the hasher's reusable buffer.
		// Such nodes are not passed to the write function.
		st.val = core.CopyBytes(encodedNode)
		return
	}

	// Write the hash to the 'val'. We allocate a new val here to not mutate
	// input values
	st.val = hasher.hashData(encodedNode)
	if st.writeFn != nil {
		st.writeFn(st.owner, path, core.BytesToHash(st.val), encodedNode)
	}
}

// Hash returns the hash of the current node.
func (st *StackTrie) Hash() (h core.Hash) {
	hasher := newHasher(false)
	defer returnHasherToPool(hasher)

	st.hashRec(hasher, nil)
	if len(st.val) == 32 {
		copy(h[:], st.val)
		return h
	}
	// If the node's RLP isn't 32 bytes long, the node will not
	// be hashed, and instead contain the rlp-encoding of the
	// node. For the top level node, we need to force the hashing.
	hasher.sha.Reset()
	hasher.sha.Write(st.val)
	hasher.sha.Read(h[:])
	return h
}

// Commit will firstly hash the entire trie if it's still not hashed
// and then commit all nodes to the associated database. Actually most
// of the trie nodes MAY have been committed already. The main purpose
// here is to commit the root node.
+// +// The associated database is expected, otherwise the whole commit +// functionality should be disabled. +func (st *StackTrie) Commit() (h core.Hash, err error) { + if st.writeFn == nil { + return core.Hash{}, ErrCommitDisabled + } + hasher := newHasher(false) + defer returnHasherToPool(hasher) + + st.hashRec(hasher, nil) + if len(st.val) == 32 { + copy(h[:], st.val) + return h, nil + } + // If the node's RLP isn't 32 bytes long, the node will not + // be hashed (and committed), and instead contain the rlp-encoding of the + // node. For the top level node, we need to force the hashing+commit. + hasher.sha.Reset() + hasher.sha.Write(st.val) + hasher.sha.Read(h[:]) + + st.writeFn(st.owner, nil, h, st.val) + return h, nil +}