diff --git a/trie/encoding.go b/trie/encoding.go index 2037118dd..e96a786e4 100644 --- a/trie/encoding.go +++ b/trie/encoding.go @@ -16,49 +16,54 @@ package trie -func compactEncode(hexSlice []byte) []byte { +// Trie keys are dealt with in three distinct encodings: +// +// KEYBYTES encoding contains the actual key and nothing else. This encoding is the +// input to most API functions. +// +// HEX encoding contains one byte for each nibble of the key and an optional trailing +// 'terminator' byte of value 0x10 which indicates whether or not the node at the key +// contains a value. Hex key encoding is used for nodes loaded in memory because it's +// convenient to access. +// +// COMPACT encoding is defined by the Ethereum Yellow Paper (it's called "hex prefix +// encoding" there) and contains the bytes of the key and a flag. The high nibble of the +// first byte contains the flag; the lowest bit encoding the oddness of the length and +// the second-lowest encoding whether the node at the key is a value node. The low nibble +// of the first byte is zero in the case of an even number of nibbles and the first nibble +// in the case of an odd number. All remaining nibbles (now an even number) fit properly +// into the remaining bytes. Compact encoding is used for nodes stored on disk. + +func hexToCompact(hex []byte) []byte { terminator := byte(0) - if hexSlice[len(hexSlice)-1] == 16 { + if hasTerm(hex) { terminator = 1 - hexSlice = hexSlice[:len(hexSlice)-1] + hex = hex[:len(hex)-1] } - var ( - odd = byte(len(hexSlice) % 2) - buflen = len(hexSlice)/2 + 1 - bi, hi = 0, 0 // indices - hs = byte(0) // shift: flips between 0 and 4 - ) - if odd == 0 { - bi = 1 - hs = 4 - } - buf := make([]byte, buflen) - buf[0] = terminator<<5 | byte(odd)<<4 - for bi < len(buf) && hi < len(hexSlice) { - buf[bi] |= hexSlice[hi] << hs - if hs == 0 { - bi++ - } - hi, hs = hi+1, hs^(1<<2) + buf := make([]byte, len(hex)/2+1) + buf[0] = terminator << 5 // the flag byte + if len(hex)&1 == 1 { + buf[0] |= 1 << 4 // odd flag + buf[0] |= hex[0] // first nibble is contained in the first byte + hex = hex[1:] } + decodeNibbles(hex, buf[1:]) return buf } -func compactDecode(str []byte) []byte { - base := compactHexDecode(str) +func compactToHex(compact []byte) []byte { + base := keybytesToHex(compact) base = base[:len(base)-1] + // apply terminator flag if base[0] >= 2 { base = append(base, 16) } - if base[0]%2 == 1 { - base = base[1:] - } else { - base = base[2:] - } - return base + // apply odd flag + chop := 2 - base[0]&1 + return base[chop:] } -func compactHexDecode(str []byte) []byte { +func keybytesToHex(str []byte) []byte { l := len(str)*2 + 1 var nibbles = make([]byte, l) for i, b := range str { @@ -69,35 +74,24 @@ func compactHexDecode(str []byte) []byte { return nibbles } -// compactHexEncode encodes a series of nibbles into a byte array -func compactHexEncode(nibbles []byte) []byte { - nl := len(nibbles) - if nl == 0 { - return nil +// hexToKeybytes turns hex nibbles into key bytes. +// This can only be used for keys of even length. +func hexToKeybytes(hex []byte) []byte { + if hasTerm(hex) { + hex = hex[:len(hex)-1] } - if nibbles[nl-1] == 16 { - nl-- + if len(hex)&1 != 0 { + panic("can't convert hex key of odd length") } - l := (nl + 1) / 2 - var str = make([]byte, l) - for i := range str { - b := nibbles[i*2] * 16 - if nl > i*2 { - b += nibbles[i*2+1] - } - str[i] = b - } - return str + key := make([]byte, (len(hex)+1)/2) + decodeNibbles(hex, key) + return key } -func decodeCompact(key []byte) []byte { - l := len(key) / 2 - var res = make([]byte, l) - for i := 0; i < l; i++ { - v1, v0 := key[2*i], key[2*i+1] - res[i] = v1*16 + v0 +func decodeNibbles(nibbles []byte, bytes []byte) { + for bi, ni := 0, 0; ni < len(nibbles); bi, ni = bi+1, ni+2 { + bytes[bi] = nibbles[ni]<<4 | nibbles[ni+1] } - return res } // prefixLen returns the length of the common prefix of a and b. @@ -114,15 +108,7 @@ func prefixLen(a, b []byte) int { return i } +// hasTerm returns whether a hex key has the terminator flag. func hasTerm(s []byte) bool { - return s[len(s)-1] == 16 -} - -func remTerm(s []byte) []byte { - if hasTerm(s) { - b := make([]byte, len(s)-1) - copy(b, s) - return b - } - return s + return len(s) > 0 && s[len(s)-1] == 16 } diff --git a/trie/encoding_test.go b/trie/encoding_test.go index 2f125ef2f..97d8da136 100644 --- a/trie/encoding_test.go +++ b/trie/encoding_test.go @@ -17,113 +17,88 @@ package trie import ( - "encoding/hex" + "bytes" "testing" - - checker "gopkg.in/check.v1" ) -func TestEncoding(t *testing.T) { checker.TestingT(t) } - -type TrieEncodingSuite struct{} - -var _ = checker.Suite(&TrieEncodingSuite{}) - -func (s *TrieEncodingSuite) TestCompactEncode(c *checker.C) { - // even compact encode - test1 := []byte{1, 2, 3, 4, 5} - res1 := compactEncode(test1) - c.Assert(res1, checker.DeepEquals, []byte("\x11\x23\x45")) - - // odd compact encode - test2 := []byte{0, 1, 2, 3, 4, 5} - res2 := compactEncode(test2) - c.Assert(res2, checker.DeepEquals, []byte("\x00\x01\x23\x45")) - - //odd terminated compact encode - test3 := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} - res3 := compactEncode(test3) - c.Assert(res3, checker.DeepEquals, []byte("\x20\x0f\x1c\xb8")) - - // even terminated compact encode - test4 := []byte{15, 1, 12, 11, 8 /*term*/, 16} - res4 := compactEncode(test4) - c.Assert(res4, checker.DeepEquals, []byte("\x3f\x1c\xb8")) -} - -func (s *TrieEncodingSuite) TestCompactHexDecode(c *checker.C) { - exp := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} - res := compactHexDecode([]byte("verb")) - c.Assert(res, checker.DeepEquals, exp) -} - -func (s *TrieEncodingSuite) TestCompactHexEncode(c *checker.C) { - exp := []byte("verb") - res := compactHexEncode([]byte{7, 6, 6, 5, 7, 2, 6, 2, 16}) - c.Assert(res, checker.DeepEquals, exp) -} - -func (s *TrieEncodingSuite) TestCompactDecode(c *checker.C) { - // odd compact decode - exp := []byte{1, 2, 3, 4, 5} - res := compactDecode([]byte("\x11\x23\x45")) - c.Assert(res, checker.DeepEquals, exp) - - // even compact decode - exp = []byte{0, 1, 2, 3, 4, 5} - res = compactDecode([]byte("\x00\x01\x23\x45")) - c.Assert(res, checker.DeepEquals, exp) - - // even terminated compact decode - exp = []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} - res = compactDecode([]byte("\x20\x0f\x1c\xb8")) - c.Assert(res, checker.DeepEquals, exp) - - // even terminated compact decode - exp = []byte{15, 1, 12, 11, 8 /*term*/, 16} - res = compactDecode([]byte("\x3f\x1c\xb8")) - c.Assert(res, checker.DeepEquals, exp) -} - -func (s *TrieEncodingSuite) TestDecodeCompact(c *checker.C) { - exp, _ := hex.DecodeString("012345") - res := decodeCompact([]byte{0, 1, 2, 3, 4, 5}) - c.Assert(res, checker.DeepEquals, exp) - - exp, _ = hex.DecodeString("012345") - res = decodeCompact([]byte{0, 1, 2, 3, 4, 5, 16}) - c.Assert(res, checker.DeepEquals, exp) - - exp, _ = hex.DecodeString("abcdef") - res = decodeCompact([]byte{10, 11, 12, 13, 14, 15}) - c.Assert(res, checker.DeepEquals, exp) -} - -func BenchmarkCompactEncode(b *testing.B) { - - testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} - for i := 0; i < b.N; i++ { - compactEncode(testBytes) +func TestHexCompact(t *testing.T) { + tests := []struct{ hex, compact []byte }{ + // empty keys, with and without terminator. + {hex: []byte{}, compact: []byte{0x00}}, + {hex: []byte{16}, compact: []byte{0x20}}, + // odd length, no terminator + {hex: []byte{1, 2, 3, 4, 5}, compact: []byte{0x11, 0x23, 0x45}}, + // even length, no terminator + {hex: []byte{0, 1, 2, 3, 4, 5}, compact: []byte{0x00, 0x01, 0x23, 0x45}}, + // odd length, terminator + {hex: []byte{15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x3f, 0x1c, 0xb8}}, + // even length, terminator + {hex: []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x20, 0x0f, 0x1c, 0xb8}}, + } + for _, test := range tests { + if c := hexToCompact(test.hex); !bytes.Equal(c, test.compact) { + t.Errorf("hexToCompact(%x) -> %x, want %x", test.hex, c, test.compact) + } + if h := compactToHex(test.compact); !bytes.Equal(h, test.hex) { + t.Errorf("compactToHex(%x) -> %x, want %x", test.compact, h, test.hex) + } } } -func BenchmarkCompactDecode(b *testing.B) { - testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} - for i := 0; i < b.N; i++ { - compactDecode(testBytes) +func TestHexKeybytes(t *testing.T) { + tests := []struct{ key, hexIn, hexOut []byte }{ + {key: []byte{}, hexIn: []byte{16}, hexOut: []byte{16}}, + {key: []byte{}, hexIn: []byte{}, hexOut: []byte{16}}, + { + key: []byte{0x12, 0x34, 0x56}, + hexIn: []byte{1, 2, 3, 4, 5, 6, 16}, + hexOut: []byte{1, 2, 3, 4, 5, 6, 16}, + }, + { + key: []byte{0x12, 0x34, 0x5}, + hexIn: []byte{1, 2, 3, 4, 0, 5, 16}, + hexOut: []byte{1, 2, 3, 4, 0, 5, 16}, + }, + { + key: []byte{0x12, 0x34, 0x56}, + hexIn: []byte{1, 2, 3, 4, 5, 6}, + hexOut: []byte{1, 2, 3, 4, 5, 6, 16}, + }, + } + for _, test := range tests { + if h := keybytesToHex(test.key); !bytes.Equal(h, test.hexOut) { + t.Errorf("keybytesToHex(%x) -> %x, want %x", test.key, h, test.hexOut) + } + if k := hexToKeybytes(test.hexIn); !bytes.Equal(k, test.key) { + t.Errorf("hexToKeybytes(%x) -> %x, want %x", test.hexIn, k, test.key) + } } } -func BenchmarkCompactHexDecode(b *testing.B) { +func BenchmarkHexToCompact(b *testing.B) { + testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/} + for i := 0; i < b.N; i++ { + hexToCompact(testBytes) + } +} + +func BenchmarkCompactToHex(b *testing.B) { + testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/} + for i := 0; i < b.N; i++ { + compactToHex(testBytes) + } +} + +func BenchmarkKeybytesToHex(b *testing.B) { testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} for i := 0; i < b.N; i++ { - compactHexDecode(testBytes) + keybytesToHex(testBytes) } } -func BenchmarkDecodeCompact(b *testing.B) { +func BenchmarkHexToKeybytes(b *testing.B) { testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} for i := 0; i < b.N; i++ { - decodeCompact(testBytes) + hexToKeybytes(testBytes) } } diff --git a/trie/hasher.go b/trie/hasher.go index 98c309531..85b6b60f5 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -105,7 +105,7 @@ func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, err case *shortNode: // Hash the short node's child, caching the newly hashed subtree collapsed, cached := n.copy(), n.copy() - collapsed.Key = compactEncode(n.Key) + collapsed.Key = hexToCompact(n.Key) cached.Key = common.CopyBytes(n.Key) if _, ok := n.Val.(valueNode); !ok { diff --git a/trie/iterator.go b/trie/iterator.go index 42149a7d3..dd63a0c5a 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -19,6 +19,7 @@ package trie import ( "bytes" "container/heap" + "github.com/ethereum/go-ethereum/common" ) @@ -48,7 +49,7 @@ func NewIteratorFromNodeIterator(it NodeIterator) *Iterator { func (it *Iterator) Next() bool { for it.nodeIt.Next(true) { if it.nodeIt.Leaf() { - it.Key = decodeCompact(it.nodeIt.Path()) + it.Key = hexToKeybytes(it.nodeIt.Path()) it.Value = it.nodeIt.LeafBlob() return true } diff --git a/trie/node.go b/trie/node.go index 4aa0cab65..a7697fc0c 100644 --- a/trie/node.go +++ b/trie/node.go @@ -139,8 +139,8 @@ func decodeShort(hash, buf, elems []byte, cachegen uint16) (node, error) { return nil, err } flag := nodeFlag{hash: hash, gen: cachegen} - key := compactDecode(kbuf) - if key[len(key)-1] == 16 { + key := compactToHex(kbuf) + if hasTerm(key) { // value node val, _, err := rlp.SplitString(rest) if err != nil { diff --git a/trie/proof.go b/trie/proof.go index 06cf827ab..fb7734b86 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -38,7 +38,7 @@ import ( // absence of the key. func (t *Trie) Prove(key []byte) []rlp.RawValue { // Collect all nodes on the path to key. - key = compactHexDecode(key) + key = keybytesToHex(key) nodes := []node{} tn := t.root for len(key) > 0 && tn != nil { @@ -89,7 +89,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue { // returns an error if the proof contains invalid trie nodes or the // wrong value. func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value []byte, err error) { - key = compactHexDecode(key) + key = keybytesToHex(key) sha := sha3.NewKeccak256() wantHash := rootHash.Bytes() for i, buf := range proof { diff --git a/trie/trie.go b/trie/trie.go index 0979eb625..e61bd0383 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -144,7 +144,7 @@ func (t *Trie) Get(key []byte) []byte { // The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryGet(key []byte) ([]byte, error) { - key = compactHexDecode(key) + key = keybytesToHex(key) value, newroot, didResolve, err := t.tryGet(t.root, key, 0) if err == nil && didResolve { t.root = newroot @@ -211,7 +211,7 @@ func (t *Trie) Update(key, value []byte) { // // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryUpdate(key, value []byte) error { - k := compactHexDecode(key) + k := keybytesToHex(key) if len(value) != 0 { _, n, err := t.insert(t.root, nil, k, valueNode(value)) if err != nil { @@ -307,7 +307,7 @@ func (t *Trie) Delete(key []byte) { // TryDelete removes any existing value for key from the trie. // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryDelete(key []byte) error { - k := compactHexDecode(key) + k := keybytesToHex(key) _, n, err := t.delete(t.root, nil, k) if err != nil { return err