trie: rework and document key encoding

'encode' and 'decode' are meaningless because the code deals with three
encodings. Document the encodings and give a name to each one.
This commit is contained in:
Felix Lange 2017-04-18 13:25:07 +02:00
parent a31d268b76
commit f958d7d482
7 changed files with 127 additions and 165 deletions

View File

@ -16,49 +16,54 @@
package trie package trie
func compactEncode(hexSlice []byte) []byte { // Trie keys are dealt with in three distinct encodings:
//
// KEYBYTES encoding contains the actual key and nothing else. This encoding is the
// input to most API functions.
//
// HEX encoding contains one byte for each nibble of the key and an optional trailing
// 'terminator' byte of value 0x10 which indicates whether or not the node at the key
// contains a value. Hex key encoding is used for nodes loaded in memory because it's
// convenient to access.
//
// COMPACT encoding is defined by the Ethereum Yellow Paper (it's called "hex prefix
// encoding" there) and contains the bytes of the key and a flag. The high nibble of the
// first byte contains the flag; the lowest bit encoding the oddness of the length and
// the second-lowest encoding whether the node at the key is a value node. The low nibble
// of the first byte is zero in the case of an even number of nibbles and the first nibble
// in the case of an odd number. All remaining nibbles (now an even number) fit properly
// into the remaining bytes. Compact encoding is used for nodes stored on disk.
func hexToCompact(hex []byte) []byte {
terminator := byte(0) terminator := byte(0)
if hexSlice[len(hexSlice)-1] == 16 { if hasTerm(hex) {
terminator = 1 terminator = 1
hexSlice = hexSlice[:len(hexSlice)-1] hex = hex[:len(hex)-1]
} }
var ( buf := make([]byte, len(hex)/2+1)
odd = byte(len(hexSlice) % 2) buf[0] = terminator << 5 // the flag byte
buflen = len(hexSlice)/2 + 1 if len(hex)&1 == 1 {
bi, hi = 0, 0 // indices buf[0] |= 1 << 4 // odd flag
hs = byte(0) // shift: flips between 0 and 4 buf[0] |= hex[0] // first nibble is contained in the first byte
) hex = hex[1:]
if odd == 0 {
bi = 1
hs = 4
}
buf := make([]byte, buflen)
buf[0] = terminator<<5 | byte(odd)<<4
for bi < len(buf) && hi < len(hexSlice) {
buf[bi] |= hexSlice[hi] << hs
if hs == 0 {
bi++
}
hi, hs = hi+1, hs^(1<<2)
} }
decodeNibbles(hex, buf[1:])
return buf return buf
} }
func compactDecode(str []byte) []byte { func compactToHex(compact []byte) []byte {
base := compactHexDecode(str) base := keybytesToHex(compact)
base = base[:len(base)-1] base = base[:len(base)-1]
// apply terminator flag
if base[0] >= 2 { if base[0] >= 2 {
base = append(base, 16) base = append(base, 16)
} }
if base[0]%2 == 1 { // apply odd flag
base = base[1:] chop := 2 - base[0]&1
} else { return base[chop:]
base = base[2:]
}
return base
} }
func compactHexDecode(str []byte) []byte { func keybytesToHex(str []byte) []byte {
l := len(str)*2 + 1 l := len(str)*2 + 1
var nibbles = make([]byte, l) var nibbles = make([]byte, l)
for i, b := range str { for i, b := range str {
@ -69,35 +74,24 @@ func compactHexDecode(str []byte) []byte {
return nibbles return nibbles
} }
// compactHexEncode encodes a series of nibbles into a byte array // hexToKeybytes turns hex nibbles into key bytes.
func compactHexEncode(nibbles []byte) []byte { // This can only be used for keys of even length.
nl := len(nibbles) func hexToKeybytes(hex []byte) []byte {
if nl == 0 { if hasTerm(hex) {
return nil hex = hex[:len(hex)-1]
} }
if nibbles[nl-1] == 16 { if len(hex)&1 != 0 {
nl-- panic("can't convert hex key of odd length")
} }
l := (nl + 1) / 2 key := make([]byte, (len(hex)+1)/2)
var str = make([]byte, l) decodeNibbles(hex, key)
for i := range str { return key
b := nibbles[i*2] * 16
if nl > i*2 {
b += nibbles[i*2+1]
}
str[i] = b
}
return str
} }
func decodeCompact(key []byte) []byte { func decodeNibbles(nibbles []byte, bytes []byte) {
l := len(key) / 2 for bi, ni := 0, 0; ni < len(nibbles); bi, ni = bi+1, ni+2 {
var res = make([]byte, l) bytes[bi] = nibbles[ni]<<4 | nibbles[ni+1]
for i := 0; i < l; i++ {
v1, v0 := key[2*i], key[2*i+1]
res[i] = v1*16 + v0
} }
return res
} }
// prefixLen returns the length of the common prefix of a and b. // prefixLen returns the length of the common prefix of a and b.
@ -114,15 +108,7 @@ func prefixLen(a, b []byte) int {
return i return i
} }
// hasTerm returns whether a hex key has the terminator flag.
func hasTerm(s []byte) bool { func hasTerm(s []byte) bool {
return s[len(s)-1] == 16 return len(s) > 0 && s[len(s)-1] == 16
}
func remTerm(s []byte) []byte {
if hasTerm(s) {
b := make([]byte, len(s)-1)
copy(b, s)
return b
}
return s
} }

View File

@ -17,113 +17,88 @@
package trie package trie
import ( import (
"encoding/hex" "bytes"
"testing" "testing"
checker "gopkg.in/check.v1"
) )
func TestEncoding(t *testing.T) { checker.TestingT(t) } func TestHexCompact(t *testing.T) {
tests := []struct{ hex, compact []byte }{
type TrieEncodingSuite struct{} // empty keys, with and without terminator.
{hex: []byte{}, compact: []byte{0x00}},
var _ = checker.Suite(&TrieEncodingSuite{}) {hex: []byte{16}, compact: []byte{0x20}},
// odd length, no terminator
func (s *TrieEncodingSuite) TestCompactEncode(c *checker.C) { {hex: []byte{1, 2, 3, 4, 5}, compact: []byte{0x11, 0x23, 0x45}},
// even compact encode // even length, no terminator
test1 := []byte{1, 2, 3, 4, 5} {hex: []byte{0, 1, 2, 3, 4, 5}, compact: []byte{0x00, 0x01, 0x23, 0x45}},
res1 := compactEncode(test1) // odd length, terminator
c.Assert(res1, checker.DeepEquals, []byte("\x11\x23\x45")) {hex: []byte{15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x3f, 0x1c, 0xb8}},
// even length, terminator
// odd compact encode {hex: []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x20, 0x0f, 0x1c, 0xb8}},
test2 := []byte{0, 1, 2, 3, 4, 5} }
res2 := compactEncode(test2) for _, test := range tests {
c.Assert(res2, checker.DeepEquals, []byte("\x00\x01\x23\x45")) if c := hexToCompact(test.hex); !bytes.Equal(c, test.compact) {
t.Errorf("hexToCompact(%x) -> %x, want %x", test.hex, c, test.compact)
//odd terminated compact encode }
test3 := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} if h := compactToHex(test.compact); !bytes.Equal(h, test.hex) {
res3 := compactEncode(test3) t.Errorf("compactToHex(%x) -> %x, want %x", test.compact, h, test.hex)
c.Assert(res3, checker.DeepEquals, []byte("\x20\x0f\x1c\xb8")) }
// even terminated compact encode
test4 := []byte{15, 1, 12, 11, 8 /*term*/, 16}
res4 := compactEncode(test4)
c.Assert(res4, checker.DeepEquals, []byte("\x3f\x1c\xb8"))
}
func (s *TrieEncodingSuite) TestCompactHexDecode(c *checker.C) {
exp := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
res := compactHexDecode([]byte("verb"))
c.Assert(res, checker.DeepEquals, exp)
}
func (s *TrieEncodingSuite) TestCompactHexEncode(c *checker.C) {
exp := []byte("verb")
res := compactHexEncode([]byte{7, 6, 6, 5, 7, 2, 6, 2, 16})
c.Assert(res, checker.DeepEquals, exp)
}
func (s *TrieEncodingSuite) TestCompactDecode(c *checker.C) {
// odd compact decode
exp := []byte{1, 2, 3, 4, 5}
res := compactDecode([]byte("\x11\x23\x45"))
c.Assert(res, checker.DeepEquals, exp)
// even compact decode
exp = []byte{0, 1, 2, 3, 4, 5}
res = compactDecode([]byte("\x00\x01\x23\x45"))
c.Assert(res, checker.DeepEquals, exp)
// even terminated compact decode
exp = []byte{0, 15, 1, 12, 11, 8 /*term*/, 16}
res = compactDecode([]byte("\x20\x0f\x1c\xb8"))
c.Assert(res, checker.DeepEquals, exp)
// even terminated compact decode
exp = []byte{15, 1, 12, 11, 8 /*term*/, 16}
res = compactDecode([]byte("\x3f\x1c\xb8"))
c.Assert(res, checker.DeepEquals, exp)
}
func (s *TrieEncodingSuite) TestDecodeCompact(c *checker.C) {
exp, _ := hex.DecodeString("012345")
res := decodeCompact([]byte{0, 1, 2, 3, 4, 5})
c.Assert(res, checker.DeepEquals, exp)
exp, _ = hex.DecodeString("012345")
res = decodeCompact([]byte{0, 1, 2, 3, 4, 5, 16})
c.Assert(res, checker.DeepEquals, exp)
exp, _ = hex.DecodeString("abcdef")
res = decodeCompact([]byte{10, 11, 12, 13, 14, 15})
c.Assert(res, checker.DeepEquals, exp)
}
func BenchmarkCompactEncode(b *testing.B) {
testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16}
for i := 0; i < b.N; i++ {
compactEncode(testBytes)
} }
} }
func BenchmarkCompactDecode(b *testing.B) { func TestHexKeybytes(t *testing.T) {
testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} tests := []struct{ key, hexIn, hexOut []byte }{
for i := 0; i < b.N; i++ { {key: []byte{}, hexIn: []byte{16}, hexOut: []byte{16}},
compactDecode(testBytes) {key: []byte{}, hexIn: []byte{}, hexOut: []byte{16}},
{
key: []byte{0x12, 0x34, 0x56},
hexIn: []byte{1, 2, 3, 4, 5, 6, 16},
hexOut: []byte{1, 2, 3, 4, 5, 6, 16},
},
{
key: []byte{0x12, 0x34, 0x5},
hexIn: []byte{1, 2, 3, 4, 0, 5, 16},
hexOut: []byte{1, 2, 3, 4, 0, 5, 16},
},
{
key: []byte{0x12, 0x34, 0x56},
hexIn: []byte{1, 2, 3, 4, 5, 6},
hexOut: []byte{1, 2, 3, 4, 5, 6, 16},
},
}
for _, test := range tests {
if h := keybytesToHex(test.key); !bytes.Equal(h, test.hexOut) {
t.Errorf("keybytesToHex(%x) -> %x, want %x", test.key, h, test.hexOut)
}
if k := hexToKeybytes(test.hexIn); !bytes.Equal(k, test.key) {
t.Errorf("hexToKeybytes(%x) -> %x, want %x", test.hexIn, k, test.key)
}
} }
} }
func BenchmarkCompactHexDecode(b *testing.B) { func BenchmarkHexToCompact(b *testing.B) {
testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
for i := 0; i < b.N; i++ {
hexToCompact(testBytes)
}
}
func BenchmarkCompactToHex(b *testing.B) {
testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
for i := 0; i < b.N; i++ {
compactToHex(testBytes)
}
}
func BenchmarkKeybytesToHex(b *testing.B) {
testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
compactHexDecode(testBytes) keybytesToHex(testBytes)
} }
} }
func BenchmarkDecodeCompact(b *testing.B) { func BenchmarkHexToKeybytes(b *testing.B) {
testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
decodeCompact(testBytes) hexToKeybytes(testBytes)
} }
} }

View File

@ -105,7 +105,7 @@ func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, err
case *shortNode: case *shortNode:
// Hash the short node's child, caching the newly hashed subtree // Hash the short node's child, caching the newly hashed subtree
collapsed, cached := n.copy(), n.copy() collapsed, cached := n.copy(), n.copy()
collapsed.Key = compactEncode(n.Key) collapsed.Key = hexToCompact(n.Key)
cached.Key = common.CopyBytes(n.Key) cached.Key = common.CopyBytes(n.Key)
if _, ok := n.Val.(valueNode); !ok { if _, ok := n.Val.(valueNode); !ok {

View File

@ -19,6 +19,7 @@ package trie
import ( import (
"bytes" "bytes"
"container/heap" "container/heap"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
) )
@ -48,7 +49,7 @@ func NewIteratorFromNodeIterator(it NodeIterator) *Iterator {
func (it *Iterator) Next() bool { func (it *Iterator) Next() bool {
for it.nodeIt.Next(true) { for it.nodeIt.Next(true) {
if it.nodeIt.Leaf() { if it.nodeIt.Leaf() {
it.Key = decodeCompact(it.nodeIt.Path()) it.Key = hexToKeybytes(it.nodeIt.Path())
it.Value = it.nodeIt.LeafBlob() it.Value = it.nodeIt.LeafBlob()
return true return true
} }

View File

@ -139,8 +139,8 @@ func decodeShort(hash, buf, elems []byte, cachegen uint16) (node, error) {
return nil, err return nil, err
} }
flag := nodeFlag{hash: hash, gen: cachegen} flag := nodeFlag{hash: hash, gen: cachegen}
key := compactDecode(kbuf) key := compactToHex(kbuf)
if key[len(key)-1] == 16 { if hasTerm(key) {
// value node // value node
val, _, err := rlp.SplitString(rest) val, _, err := rlp.SplitString(rest)
if err != nil { if err != nil {

View File

@ -38,7 +38,7 @@ import (
// absence of the key. // absence of the key.
func (t *Trie) Prove(key []byte) []rlp.RawValue { func (t *Trie) Prove(key []byte) []rlp.RawValue {
// Collect all nodes on the path to key. // Collect all nodes on the path to key.
key = compactHexDecode(key) key = keybytesToHex(key)
nodes := []node{} nodes := []node{}
tn := t.root tn := t.root
for len(key) > 0 && tn != nil { for len(key) > 0 && tn != nil {
@ -89,7 +89,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue {
// returns an error if the proof contains invalid trie nodes or the // returns an error if the proof contains invalid trie nodes or the
// wrong value. // wrong value.
func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value []byte, err error) { func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value []byte, err error) {
key = compactHexDecode(key) key = keybytesToHex(key)
sha := sha3.NewKeccak256() sha := sha3.NewKeccak256()
wantHash := rootHash.Bytes() wantHash := rootHash.Bytes()
for i, buf := range proof { for i, buf := range proof {

View File

@ -144,7 +144,7 @@ func (t *Trie) Get(key []byte) []byte {
// The value bytes must not be modified by the caller. // The value bytes must not be modified by the caller.
// If a node was not found in the database, a MissingNodeError is returned. // If a node was not found in the database, a MissingNodeError is returned.
func (t *Trie) TryGet(key []byte) ([]byte, error) { func (t *Trie) TryGet(key []byte) ([]byte, error) {
key = compactHexDecode(key) key = keybytesToHex(key)
value, newroot, didResolve, err := t.tryGet(t.root, key, 0) value, newroot, didResolve, err := t.tryGet(t.root, key, 0)
if err == nil && didResolve { if err == nil && didResolve {
t.root = newroot t.root = newroot
@ -211,7 +211,7 @@ func (t *Trie) Update(key, value []byte) {
// //
// If a node was not found in the database, a MissingNodeError is returned. // If a node was not found in the database, a MissingNodeError is returned.
func (t *Trie) TryUpdate(key, value []byte) error { func (t *Trie) TryUpdate(key, value []byte) error {
k := compactHexDecode(key) k := keybytesToHex(key)
if len(value) != 0 { if len(value) != 0 {
_, n, err := t.insert(t.root, nil, k, valueNode(value)) _, n, err := t.insert(t.root, nil, k, valueNode(value))
if err != nil { if err != nil {
@ -307,7 +307,7 @@ func (t *Trie) Delete(key []byte) {
// TryDelete removes any existing value for key from the trie. // TryDelete removes any existing value for key from the trie.
// If a node was not found in the database, a MissingNodeError is returned. // If a node was not found in the database, a MissingNodeError is returned.
func (t *Trie) TryDelete(key []byte) error { func (t *Trie) TryDelete(key []byte) error {
k := compactHexDecode(key) k := keybytesToHex(key)
_, n, err := t.delete(t.root, nil, k) _, n, err := t.delete(t.root, nil, k)
if err != nil { if err != nil {
return err return err