From a31d268b76ff13df8e7d060163a842b8ed569793 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Tue, 18 Apr 2017 13:08:17 +0200 Subject: [PATCH 1/5] trie: remove Key in MissingNodeError The key was constructed from nibbles, which isn't possible for all nodes. Remove the only use of Key in LightTrie by always retrying with the original key that was looked up. --- light/trie.go | 15 ++++----------- trie/errors.go | 5 ----- trie/trie.go | 1 - 3 files changed, 4 insertions(+), 17 deletions(-) diff --git a/light/trie.go b/light/trie.go index 1440f2fbf..2988a16cf 100644 --- a/light/trie.go +++ b/light/trie.go @@ -19,6 +19,7 @@ package light import ( "context" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/trie" ) @@ -46,26 +47,18 @@ func NewLightTrie(id *TrieID, odr OdrBackend, useFakeMap bool) *LightTrie { // retrieveKey retrieves a single key, returns true and stores nodes in local // database if successful func (t *LightTrie) retrieveKey(ctx context.Context, key []byte) bool { - r := &TrieRequest{Id: t.id, Key: key} + r := &TrieRequest{Id: t.id, Key: crypto.Keccak256(key)} return t.odr.Retrieve(ctx, r) == nil } // do tries and retries to execute a function until it returns with no error or // an error type other than MissingNodeError -func (t *LightTrie) do(ctx context.Context, fallbackKey []byte, fn func() error) error { +func (t *LightTrie) do(ctx context.Context, key []byte, fn func() error) error { err := fn() for err != nil { - mn, ok := err.(*trie.MissingNodeError) - if !ok { + if _, ok := err.(*trie.MissingNodeError); !ok { return err } - - var key []byte - if mn.PrefixLen+mn.SuffixLen > 0 { - key = mn.Key - } else { - key = fallbackKey - } if !t.retrieveKey(ctx, key) { break } diff --git a/trie/errors.go b/trie/errors.go index 76129a70b..e23f9d563 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -30,10 +30,6 @@ import ( // // RootHash is the original root of the trie that contains the node // -// Key is a binary-encoded key that contains the prefix that leads to the first -// missing node and optionally a suffix that hints on which further nodes should -// also be retrieved -// // PrefixLen is the nibble length of the key prefix that leads from the root to // the missing node // @@ -42,7 +38,6 @@ import ( // such hints in the error message) type MissingNodeError struct { RootHash, NodeHash common.Hash - Key []byte PrefixLen, SuffixLen int } diff --git a/trie/trie.go b/trie/trie.go index 2a6044068..0979eb625 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -450,7 +450,6 @@ func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) { return nil, &MissingNodeError{ RootHash: t.originalRoot, NodeHash: common.BytesToHash(n), - Key: compactHexEncode(append(prefix, suffix...)), PrefixLen: len(prefix), SuffixLen: len(suffix), } From f958d7d4822d257598ae36fc3b381040faa5bb30 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Tue, 18 Apr 2017 13:25:07 +0200 Subject: [PATCH 2/5] trie: rework and document key encoding 'encode' and 'decode' are meaningless because the code deals with three encodings. Document the encodings and give a name to each one. --- trie/encoding.go | 114 +++++++++++++----------------- trie/encoding_test.go | 159 ++++++++++++++++++------------------------ trie/hasher.go | 2 +- trie/iterator.go | 3 +- trie/node.go | 4 +- trie/proof.go | 4 +- trie/trie.go | 6 +- 7 files changed, 127 insertions(+), 165 deletions(-) diff --git a/trie/encoding.go b/trie/encoding.go index 2037118dd..e96a786e4 100644 --- a/trie/encoding.go +++ b/trie/encoding.go @@ -16,49 +16,54 @@ package trie -func compactEncode(hexSlice []byte) []byte { +// Trie keys are dealt with in three distinct encodings: +// +// KEYBYTES encoding contains the actual key and nothing else. This encoding is the +// input to most API functions. +// +// HEX encoding contains one byte for each nibble of the key and an optional trailing +// 'terminator' byte of value 0x10 which indicates whether or not the node at the key +// contains a value. Hex key encoding is used for nodes loaded in memory because it's +// convenient to access. +// +// COMPACT encoding is defined by the Ethereum Yellow Paper (it's called "hex prefix +// encoding" there) and contains the bytes of the key and a flag. The high nibble of the +// first byte contains the flag; the lowest bit encoding the oddness of the length and +// the second-lowest encoding whether the node at the key is a value node. The low nibble +// of the first byte is zero in the case of an even number of nibbles and the first nibble +// in the case of an odd number. All remaining nibbles (now an even number) fit properly +// into the remaining bytes. Compact encoding is used for nodes stored on disk. + +func hexToCompact(hex []byte) []byte { terminator := byte(0) - if hexSlice[len(hexSlice)-1] == 16 { + if hasTerm(hex) { terminator = 1 - hexSlice = hexSlice[:len(hexSlice)-1] + hex = hex[:len(hex)-1] } - var ( - odd = byte(len(hexSlice) % 2) - buflen = len(hexSlice)/2 + 1 - bi, hi = 0, 0 // indices - hs = byte(0) // shift: flips between 0 and 4 - ) - if odd == 0 { - bi = 1 - hs = 4 - } - buf := make([]byte, buflen) - buf[0] = terminator<<5 | byte(odd)<<4 - for bi < len(buf) && hi < len(hexSlice) { - buf[bi] |= hexSlice[hi] << hs - if hs == 0 { - bi++ - } - hi, hs = hi+1, hs^(1<<2) + buf := make([]byte, len(hex)/2+1) + buf[0] = terminator << 5 // the flag byte + if len(hex)&1 == 1 { + buf[0] |= 1 << 4 // odd flag + buf[0] |= hex[0] // first nibble is contained in the first byte + hex = hex[1:] } + decodeNibbles(hex, buf[1:]) return buf } -func compactDecode(str []byte) []byte { - base := compactHexDecode(str) +func compactToHex(compact []byte) []byte { + base := keybytesToHex(compact) base = base[:len(base)-1] + // apply terminator flag if base[0] >= 2 { base = append(base, 16) } - if base[0]%2 == 1 { - base = base[1:] - } else { - base = base[2:] - } - return base + // apply odd flag + chop := 2 - base[0]&1 + return base[chop:] } -func compactHexDecode(str []byte) []byte { +func keybytesToHex(str []byte) []byte { l := len(str)*2 + 1 var nibbles = make([]byte, l) for i, b := range str { @@ -69,35 +74,24 @@ func compactHexDecode(str []byte) []byte { return nibbles } -// compactHexEncode encodes a series of nibbles into a byte array -func compactHexEncode(nibbles []byte) []byte { - nl := len(nibbles) - if nl == 0 { - return nil +// hexToKeybytes turns hex nibbles into key bytes. +// This can only be used for keys of even length. +func hexToKeybytes(hex []byte) []byte { + if hasTerm(hex) { + hex = hex[:len(hex)-1] } - if nibbles[nl-1] == 16 { - nl-- + if len(hex)&1 != 0 { + panic("can't convert hex key of odd length") } - l := (nl + 1) / 2 - var str = make([]byte, l) - for i := range str { - b := nibbles[i*2] * 16 - if nl > i*2 { - b += nibbles[i*2+1] - } - str[i] = b - } - return str + key := make([]byte, (len(hex)+1)/2) + decodeNibbles(hex, key) + return key } -func decodeCompact(key []byte) []byte { - l := len(key) / 2 - var res = make([]byte, l) - for i := 0; i < l; i++ { - v1, v0 := key[2*i], key[2*i+1] - res[i] = v1*16 + v0 +func decodeNibbles(nibbles []byte, bytes []byte) { + for bi, ni := 0, 0; ni < len(nibbles); bi, ni = bi+1, ni+2 { + bytes[bi] = nibbles[ni]<<4 | nibbles[ni+1] } - return res } // prefixLen returns the length of the common prefix of a and b. @@ -114,15 +108,7 @@ func prefixLen(a, b []byte) int { return i } +// hasTerm returns whether a hex key has the terminator flag. func hasTerm(s []byte) bool { - return s[len(s)-1] == 16 -} - -func remTerm(s []byte) []byte { - if hasTerm(s) { - b := make([]byte, len(s)-1) - copy(b, s) - return b - } - return s + return len(s) > 0 && s[len(s)-1] == 16 } diff --git a/trie/encoding_test.go b/trie/encoding_test.go index 2f125ef2f..97d8da136 100644 --- a/trie/encoding_test.go +++ b/trie/encoding_test.go @@ -17,113 +17,88 @@ package trie import ( - "encoding/hex" + "bytes" "testing" - - checker "gopkg.in/check.v1" ) -func TestEncoding(t *testing.T) { checker.TestingT(t) } - -type TrieEncodingSuite struct{} - -var _ = checker.Suite(&TrieEncodingSuite{}) - -func (s *TrieEncodingSuite) TestCompactEncode(c *checker.C) { - // even compact encode - test1 := []byte{1, 2, 3, 4, 5} - res1 := compactEncode(test1) - c.Assert(res1, checker.DeepEquals, []byte("\x11\x23\x45")) - - // odd compact encode - test2 := []byte{0, 1, 2, 3, 4, 5} - res2 := compactEncode(test2) - c.Assert(res2, checker.DeepEquals, []byte("\x00\x01\x23\x45")) - - //odd terminated compact encode - test3 := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} - res3 := compactEncode(test3) - c.Assert(res3, checker.DeepEquals, []byte("\x20\x0f\x1c\xb8")) - - // even terminated compact encode - test4 := []byte{15, 1, 12, 11, 8 /*term*/, 16} - res4 := compactEncode(test4) - c.Assert(res4, checker.DeepEquals, []byte("\x3f\x1c\xb8")) -} - -func (s *TrieEncodingSuite) TestCompactHexDecode(c *checker.C) { - exp := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} - res := compactHexDecode([]byte("verb")) - c.Assert(res, checker.DeepEquals, exp) -} - -func (s *TrieEncodingSuite) TestCompactHexEncode(c *checker.C) { - exp := []byte("verb") - res := compactHexEncode([]byte{7, 6, 6, 5, 7, 2, 6, 2, 16}) - c.Assert(res, checker.DeepEquals, exp) -} - -func (s *TrieEncodingSuite) TestCompactDecode(c *checker.C) { - // odd compact decode - exp := []byte{1, 2, 3, 4, 5} - res := compactDecode([]byte("\x11\x23\x45")) - c.Assert(res, checker.DeepEquals, exp) - - // even compact decode - exp = []byte{0, 1, 2, 3, 4, 5} - res = compactDecode([]byte("\x00\x01\x23\x45")) - c.Assert(res, checker.DeepEquals, exp) - - // even terminated compact decode - exp = []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} - res = compactDecode([]byte("\x20\x0f\x1c\xb8")) - c.Assert(res, checker.DeepEquals, exp) - - // even terminated compact decode - exp = []byte{15, 1, 12, 11, 8 /*term*/, 16} - res = compactDecode([]byte("\x3f\x1c\xb8")) - c.Assert(res, checker.DeepEquals, exp) -} - -func (s *TrieEncodingSuite) TestDecodeCompact(c *checker.C) { - exp, _ := hex.DecodeString("012345") - res := decodeCompact([]byte{0, 1, 2, 3, 4, 5}) - c.Assert(res, checker.DeepEquals, exp) - - exp, _ = hex.DecodeString("012345") - res = decodeCompact([]byte{0, 1, 2, 3, 4, 5, 16}) - c.Assert(res, checker.DeepEquals, exp) - - exp, _ = hex.DecodeString("abcdef") - res = decodeCompact([]byte{10, 11, 12, 13, 14, 15}) - c.Assert(res, checker.DeepEquals, exp) -} - -func BenchmarkCompactEncode(b *testing.B) { - - testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} - for i := 0; i < b.N; i++ { - compactEncode(testBytes) +func TestHexCompact(t *testing.T) { + tests := []struct{ hex, compact []byte }{ + // empty keys, with and without terminator. + {hex: []byte{}, compact: []byte{0x00}}, + {hex: []byte{16}, compact: []byte{0x20}}, + // odd length, no terminator + {hex: []byte{1, 2, 3, 4, 5}, compact: []byte{0x11, 0x23, 0x45}}, + // even length, no terminator + {hex: []byte{0, 1, 2, 3, 4, 5}, compact: []byte{0x00, 0x01, 0x23, 0x45}}, + // odd length, terminator + {hex: []byte{15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x3f, 0x1c, 0xb8}}, + // even length, terminator + {hex: []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x20, 0x0f, 0x1c, 0xb8}}, + } + for _, test := range tests { + if c := hexToCompact(test.hex); !bytes.Equal(c, test.compact) { + t.Errorf("hexToCompact(%x) -> %x, want %x", test.hex, c, test.compact) + } + if h := compactToHex(test.compact); !bytes.Equal(h, test.hex) { + t.Errorf("compactToHex(%x) -> %x, want %x", test.compact, h, test.hex) + } } } -func BenchmarkCompactDecode(b *testing.B) { - testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} - for i := 0; i < b.N; i++ { - compactDecode(testBytes) +func TestHexKeybytes(t *testing.T) { + tests := []struct{ key, hexIn, hexOut []byte }{ + {key: []byte{}, hexIn: []byte{16}, hexOut: []byte{16}}, + {key: []byte{}, hexIn: []byte{}, hexOut: []byte{16}}, + { + key: []byte{0x12, 0x34, 0x56}, + hexIn: []byte{1, 2, 3, 4, 5, 6, 16}, + hexOut: []byte{1, 2, 3, 4, 5, 6, 16}, + }, + { + key: []byte{0x12, 0x34, 0x5}, + hexIn: []byte{1, 2, 3, 4, 0, 5, 16}, + hexOut: []byte{1, 2, 3, 4, 0, 5, 16}, + }, + { + key: []byte{0x12, 0x34, 0x56}, + hexIn: []byte{1, 2, 3, 4, 5, 6}, + hexOut: []byte{1, 2, 3, 4, 5, 6, 16}, + }, + } + for _, test := range tests { + if h := keybytesToHex(test.key); !bytes.Equal(h, test.hexOut) { + t.Errorf("keybytesToHex(%x) -> %x, want %x", test.key, h, test.hexOut) + } + if k := hexToKeybytes(test.hexIn); !bytes.Equal(k, test.key) { + t.Errorf("hexToKeybytes(%x) -> %x, want %x", test.hexIn, k, test.key) + } } } -func BenchmarkCompactHexDecode(b *testing.B) { +func BenchmarkHexToCompact(b *testing.B) { + testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/} + for i := 0; i < b.N; i++ { + hexToCompact(testBytes) + } +} + +func BenchmarkCompactToHex(b *testing.B) { + testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/} + for i := 0; i < b.N; i++ { + compactToHex(testBytes) + } +} + +func BenchmarkKeybytesToHex(b *testing.B) { testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} for i := 0; i < b.N; i++ { - compactHexDecode(testBytes) + keybytesToHex(testBytes) } } -func BenchmarkDecodeCompact(b *testing.B) { +func BenchmarkHexToKeybytes(b *testing.B) { testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} for i := 0; i < b.N; i++ { - decodeCompact(testBytes) + hexToKeybytes(testBytes) } } diff --git a/trie/hasher.go b/trie/hasher.go index 98c309531..85b6b60f5 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -105,7 +105,7 @@ func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, err case *shortNode: // Hash the short node's child, caching the newly hashed subtree collapsed, cached := n.copy(), n.copy() - collapsed.Key = compactEncode(n.Key) + collapsed.Key = hexToCompact(n.Key) cached.Key = common.CopyBytes(n.Key) if _, ok := n.Val.(valueNode); !ok { diff --git a/trie/iterator.go b/trie/iterator.go index 42149a7d3..dd63a0c5a 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -19,6 +19,7 @@ package trie import ( "bytes" "container/heap" + "github.com/ethereum/go-ethereum/common" ) @@ -48,7 +49,7 @@ func NewIteratorFromNodeIterator(it NodeIterator) *Iterator { func (it *Iterator) Next() bool { for it.nodeIt.Next(true) { if it.nodeIt.Leaf() { - it.Key = decodeCompact(it.nodeIt.Path()) + it.Key = hexToKeybytes(it.nodeIt.Path()) it.Value = it.nodeIt.LeafBlob() return true } diff --git a/trie/node.go b/trie/node.go index 4aa0cab65..a7697fc0c 100644 --- a/trie/node.go +++ b/trie/node.go @@ -139,8 +139,8 @@ func decodeShort(hash, buf, elems []byte, cachegen uint16) (node, error) { return nil, err } flag := nodeFlag{hash: hash, gen: cachegen} - key := compactDecode(kbuf) - if key[len(key)-1] == 16 { + key := compactToHex(kbuf) + if hasTerm(key) { // value node val, _, err := rlp.SplitString(rest) if err != nil { diff --git a/trie/proof.go b/trie/proof.go index 06cf827ab..fb7734b86 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -38,7 +38,7 @@ import ( // absence of the key. func (t *Trie) Prove(key []byte) []rlp.RawValue { // Collect all nodes on the path to key. - key = compactHexDecode(key) + key = keybytesToHex(key) nodes := []node{} tn := t.root for len(key) > 0 && tn != nil { @@ -89,7 +89,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue { // returns an error if the proof contains invalid trie nodes or the // wrong value. func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value []byte, err error) { - key = compactHexDecode(key) + key = keybytesToHex(key) sha := sha3.NewKeccak256() wantHash := rootHash.Bytes() for i, buf := range proof { diff --git a/trie/trie.go b/trie/trie.go index 0979eb625..e61bd0383 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -144,7 +144,7 @@ func (t *Trie) Get(key []byte) []byte { // The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryGet(key []byte) ([]byte, error) { - key = compactHexDecode(key) + key = keybytesToHex(key) value, newroot, didResolve, err := t.tryGet(t.root, key, 0) if err == nil && didResolve { t.root = newroot @@ -211,7 +211,7 @@ func (t *Trie) Update(key, value []byte) { // // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryUpdate(key, value []byte) error { - k := compactHexDecode(key) + k := keybytesToHex(key) if len(value) != 0 { _, n, err := t.insert(t.root, nil, k, valueNode(value)) if err != nil { @@ -307,7 +307,7 @@ func (t *Trie) Delete(key []byte) { // TryDelete removes any existing value for key from the trie. // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryDelete(key []byte) error { - k := compactHexDecode(key) + k := keybytesToHex(key) _, n, err := t.delete(t.root, nil, k) if err != nil { return err From a13e920af01692cb07a520cda688f1cc5b5469dd Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Tue, 18 Apr 2017 13:37:10 +0200 Subject: [PATCH 3/5] trie: clean up iterator constructors Make it so each iterator has exactly one public constructor: - NodeIterators can be created through a method. - Iterators can be created through NewIterator on any NodeIterator. --- core/state/dump.go | 5 +++-- core/state/iterator.go | 2 +- core/state/statedb.go | 2 +- trie/iterator.go | 15 ++++----------- trie/iterator_test.go | 14 +++++++------- trie/secure_trie.go | 6 +----- trie/sync_test.go | 2 +- trie/trie.go | 4 ++-- trie/trie_test.go | 2 +- 9 files changed, 21 insertions(+), 31 deletions(-) diff --git a/core/state/dump.go b/core/state/dump.go index 8294d61b9..6338ddf88 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -22,6 +22,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" ) type DumpAccount struct { @@ -44,7 +45,7 @@ func (self *StateDB) RawDump() Dump { Accounts: make(map[string]DumpAccount), } - it := self.trie.Iterator() + it := trie.NewIterator(self.trie.NodeIterator()) for it.Next() { addr := self.trie.GetKey(it.Key) var data Account @@ -61,7 +62,7 @@ func (self *StateDB) RawDump() Dump { Code: common.Bytes2Hex(obj.Code(self.db)), Storage: make(map[string]string), } - storageIt := obj.getTrie(self.db).Iterator() + storageIt := trie.NewIterator(obj.getTrie(self.db).NodeIterator()) for storageIt.Next() { account.Storage[common.Bytes2Hex(self.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(storageIt.Value) } diff --git a/core/state/iterator.go b/core/state/iterator.go index 170aec983..d2dd5a74e 100644 --- a/core/state/iterator.go +++ b/core/state/iterator.go @@ -118,7 +118,7 @@ func (it *NodeIterator) step() error { if err != nil { return err } - it.dataIt = trie.NewNodeIterator(dataTrie) + it.dataIt = dataTrie.NodeIterator() if !it.dataIt.Next(true) { it.dataIt = nil } diff --git a/core/state/statedb.go b/core/state/statedb.go index 0c72fc6b0..24381ced5 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -481,7 +481,7 @@ func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common cb(h, value) } - it := so.getTrie(db.db).Iterator() + it := trie.NewIterator(so.getTrie(db.db).NodeIterator()) for it.Next() { // ignore cached values key := common.BytesToHash(db.trie.GetKey(it.Key)) diff --git a/trie/iterator.go b/trie/iterator.go index dd63a0c5a..fef5b2593 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -31,15 +31,8 @@ type Iterator struct { Value []byte // Current data value on which the iterator is positioned on } -// NewIterator creates a new key-value iterator. -func NewIterator(trie *Trie) *Iterator { - return &Iterator{ - nodeIt: NewNodeIterator(trie), - } -} - -// FromNodeIterator creates a new key-value iterator from a node iterator -func NewIteratorFromNodeIterator(it NodeIterator) *Iterator { +// NewIterator creates a new key-value iterator from a node iterator +func NewIterator(it NodeIterator) *Iterator { return &Iterator{ nodeIt: it, } @@ -99,8 +92,8 @@ type nodeIterator struct { path []byte // Path to the current node } -// NewNodeIterator creates an post-order trie iterator. -func NewNodeIterator(trie *Trie) NodeIterator { +// newNodeIterator creates an post-order trie iterator. +func newNodeIterator(trie *Trie) NodeIterator { if trie.Hash() == emptyState { return new(nodeIterator) } diff --git a/trie/iterator_test.go b/trie/iterator_test.go index c101bb7b0..04d51aaf5 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -42,7 +42,7 @@ func TestIterator(t *testing.T) { trie.Commit() found := make(map[string]string) - it := NewIterator(trie) + it := NewIterator(trie.NodeIterator()) for it.Next() { found[string(it.Key)] = string(it.Value) } @@ -72,7 +72,7 @@ func TestIteratorLargeData(t *testing.T) { vals[string(value2.k)] = value2 } - it := NewIterator(trie) + it := NewIterator(trie.NodeIterator()) for it.Next() { vals[string(it.Key)].t = true } @@ -99,7 +99,7 @@ func TestNodeIteratorCoverage(t *testing.T) { // Gather all the node hashes found by the iterator hashes := make(map[common.Hash]struct{}) - for it := NewNodeIterator(trie); it.Next(true); { + for it := trie.NodeIterator(); it.Next(true); { if it.Hash() != (common.Hash{}) { hashes[it.Hash()] = struct{}{} } @@ -154,8 +154,8 @@ func TestDifferenceIterator(t *testing.T) { trieb.Commit() found := make(map[string]string) - di, _ := NewDifferenceIterator(NewNodeIterator(triea), NewNodeIterator(trieb)) - it := NewIteratorFromNodeIterator(di) + di, _ := NewDifferenceIterator(triea.NodeIterator(), trieb.NodeIterator()) + it := NewIterator(di) for it.Next() { found[string(it.Key)] = string(it.Value) } @@ -189,8 +189,8 @@ func TestUnionIterator(t *testing.T) { } trieb.Commit() - di, _ := NewUnionIterator([]NodeIterator{NewNodeIterator(triea), NewNodeIterator(trieb)}) - it := NewIteratorFromNodeIterator(di) + di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(), trieb.NodeIterator()}) + it := NewIterator(di) all := []struct{ k, v string }{ {"aardvark", "c"}, diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 113fb6a1a..201716d18 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -156,12 +156,8 @@ func (t *SecureTrie) Root() []byte { return t.trie.Root() } -func (t *SecureTrie) Iterator() *Iterator { - return t.trie.Iterator() -} - func (t *SecureTrie) NodeIterator() NodeIterator { - return NewNodeIterator(&t.trie) + return t.trie.NodeIterator() } // CommitTo writes all nodes and the secure hash pre-images to the given database. diff --git a/trie/sync_test.go b/trie/sync_test.go index acae039cd..6d345ad3f 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -80,7 +80,7 @@ func checkTrieConsistency(db Database, root common.Hash) error { if err != nil { return nil // // Consider a non existent state consistent } - it := NewNodeIterator(trie) + it := trie.NodeIterator() for it.Next(true) { } return it.Error() diff --git a/trie/trie.go b/trie/trie.go index e61bd0383..dbffc0ac3 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -126,8 +126,8 @@ func New(root common.Hash, db Database) (*Trie, error) { } // Iterator returns an iterator over all mappings in the trie. -func (t *Trie) Iterator() *Iterator { - return NewIterator(t) +func (t *Trie) NodeIterator() NodeIterator { + return newNodeIterator(t) } // Get returns the value for key stored in the trie. diff --git a/trie/trie_test.go b/trie/trie_test.go index 01ae3a4e7..cacb08824 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -439,7 +439,7 @@ func runRandTest(rt randTest) bool { tr = newtr case opItercheckhash: checktr, _ := New(common.Hash{}, nil) - it := tr.Iterator() + it := NewIterator(tr.NodeIterator()) for it.Next() { checktr.Update(it.Key, it.Value) } From 4047ccad2fb73fd2cfd69bf5b8cbfa788871ce0f Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Thu, 13 Apr 2017 14:41:24 +0200 Subject: [PATCH 4/5] trie: add start key to NodeIterator constructors The 'step' method is split into two parts, 'peek' and 'push'. peek returns the next state but doesn't make it current. The end of iteration was previously tracked by setting 'trie' to nil. End of iteration is now tracked using the 'iteratorEnd' error, which is slightly cleaner and requires less code. --- core/state/dump.go | 4 +- core/state/iterator.go | 4 +- core/state/statedb.go | 2 +- trie/iterator.go | 131 ++++++++++++++++++++++++----------------- trie/iterator_test.go | 65 ++++++++++++++++---- trie/secure_trie.go | 6 +- trie/sync_test.go | 2 +- trie/trie.go | 7 ++- trie/trie_test.go | 2 +- 9 files changed, 148 insertions(+), 75 deletions(-) diff --git a/core/state/dump.go b/core/state/dump.go index 6338ddf88..ffa1a7283 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -45,7 +45,7 @@ func (self *StateDB) RawDump() Dump { Accounts: make(map[string]DumpAccount), } - it := trie.NewIterator(self.trie.NodeIterator()) + it := trie.NewIterator(self.trie.NodeIterator(nil)) for it.Next() { addr := self.trie.GetKey(it.Key) var data Account @@ -62,7 +62,7 @@ func (self *StateDB) RawDump() Dump { Code: common.Bytes2Hex(obj.Code(self.db)), Storage: make(map[string]string), } - storageIt := trie.NewIterator(obj.getTrie(self.db).NodeIterator()) + storageIt := trie.NewIterator(obj.getTrie(self.db).NodeIterator(nil)) for storageIt.Next() { account.Storage[common.Bytes2Hex(self.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(storageIt.Value) } diff --git a/core/state/iterator.go b/core/state/iterator.go index d2dd5a74e..a8a2722ae 100644 --- a/core/state/iterator.go +++ b/core/state/iterator.go @@ -75,7 +75,7 @@ func (it *NodeIterator) step() error { } // Initialize the iterator if we've just started if it.stateIt == nil { - it.stateIt = it.state.trie.NodeIterator() + it.stateIt = it.state.trie.NodeIterator(nil) } // If we had data nodes previously, we surely have at least state nodes if it.dataIt != nil { @@ -118,7 +118,7 @@ func (it *NodeIterator) step() error { if err != nil { return err } - it.dataIt = dataTrie.NodeIterator() + it.dataIt = dataTrie.NodeIterator(nil) if !it.dataIt.Next(true) { it.dataIt = nil } diff --git a/core/state/statedb.go b/core/state/statedb.go index 24381ced5..431f33e02 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -481,7 +481,7 @@ func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common cb(h, value) } - it := trie.NewIterator(so.getTrie(db.db).NodeIterator()) + it := trie.NewIterator(so.getTrie(db.db).NodeIterator(nil)) for it.Next() { // ignore cached values key := common.BytesToHash(db.trie.GetKey(it.Key)) diff --git a/trie/iterator.go b/trie/iterator.go index fef5b2593..26ae1d5ad 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -19,10 +19,13 @@ package trie import ( "bytes" "container/heap" + "errors" "github.com/ethereum/go-ethereum/common" ) +var iteratorEnd = errors.New("end of iteration") + // Iterator is a key-value trie iterator that traverses a Trie. type Iterator struct { nodeIt NodeIterator @@ -79,25 +82,24 @@ type nodeIteratorState struct { hash common.Hash // Hash of the node being iterated (nil if not standalone) node node // Trie node being iterated parent common.Hash // Hash of the first full ancestor node (nil if current is the root) - child int // Child to be processed next + index int // Child to be processed next pathlen int // Length of the path to this node } type nodeIterator struct { trie *Trie // Trie being iterated stack []*nodeIteratorState // Hierarchy of trie nodes persisting the iteration state - - err error // Failure set in case of an internal error in the iterator - - path []byte // Path to the current node + err error // Failure set in case of an internal error in the iterator + path []byte // Path to the current node } -// newNodeIterator creates an post-order trie iterator. -func newNodeIterator(trie *Trie) NodeIterator { +func newNodeIterator(trie *Trie, start []byte) NodeIterator { if trie.Hash() == emptyState { return new(nodeIterator) } - return &nodeIterator{trie: trie} + it := &nodeIterator{trie: trie} + it.seek(start) + return it } // Hash returns the hash of the current node @@ -147,6 +149,9 @@ func (it *nodeIterator) Path() []byte { // Error returns the error set in case of an internal error in the iterator func (it *nodeIterator) Error() error { + if it.err == iteratorEnd { + return nil + } return it.err } @@ -155,47 +160,54 @@ func (it *nodeIterator) Error() error { // sets the Error field to the encountered failure. If `descend` is false, // skips iterating over any subnodes of the current node. func (it *nodeIterator) Next(descend bool) bool { - // If the iterator failed previously, don't do anything if it.err != nil { return false } // Otherwise step forward with the iterator and report any errors - if err := it.step(descend); err != nil { + state, parentIndex, path, err := it.peek(descend) + if err != nil { it.err = err return false } - return it.trie != nil + it.push(state, parentIndex, path) + return true } -// step moves the iterator to the next node of the trie. -func (it *nodeIterator) step(descend bool) error { - if it.trie == nil { - // Abort if we reached the end of the iteration - return nil +func (it *nodeIterator) seek(prefix []byte) { + // The path we're looking for is the hex encoded key without terminator. + key := keybytesToHex(prefix) + key = key[:len(key)-1] + // Move forward until we're just before the closest match to key. + for { + state, parentIndex, path, err := it.peek(bytes.HasPrefix(key, it.path)) + if err != nil || bytes.Compare(path, key) >= 0 { + it.err = err + return + } + it.push(state, parentIndex, path) } +} + +// peek creates the next state of the iterator. +func (it *nodeIterator) peek(descend bool) (*nodeIteratorState, *int, []byte, error) { if len(it.stack) == 0 { // Initialize the iterator if we've just started. root := it.trie.Hash() - state := &nodeIteratorState{node: it.trie.root, child: -1} + state := &nodeIteratorState{node: it.trie.root, index: -1} if root != emptyRoot { state.hash = root } - it.stack = append(it.stack, state) - return nil + return state, nil, nil, nil } - if !descend { // If we're skipping children, pop the current node first - it.path = it.path[:it.stack[len(it.stack)-1].pathlen] - it.stack = it.stack[:len(it.stack)-1] + it.pop() } // Continue iteration to the next child -outer: for { if len(it.stack) == 0 { - it.trie = nil - return nil + return nil, nil, nil, iteratorEnd } parent := it.stack[len(it.stack)-1] ancestor := parent.hash @@ -203,63 +215,76 @@ outer: ancestor = parent.parent } if node, ok := parent.node.(*fullNode); ok { - // Full node, iterate over children - for parent.child++; parent.child < len(node.Children); parent.child++ { - child := node.Children[parent.child] + // Full node, move to the first non-nil child. + for i := parent.index + 1; i < len(node.Children); i++ { + child := node.Children[i] if child != nil { hash, _ := child.cache() - it.stack = append(it.stack, &nodeIteratorState{ + state := &nodeIteratorState{ hash: common.BytesToHash(hash), node: child, parent: ancestor, - child: -1, + index: -1, pathlen: len(it.path), - }) - it.path = append(it.path, byte(parent.child)) - break outer + } + path := append(it.path, byte(i)) + parent.index = i - 1 + return state, &parent.index, path, nil } } } else if node, ok := parent.node.(*shortNode); ok { // Short node, return the pointer singleton child - if parent.child < 0 { - parent.child++ + if parent.index < 0 { hash, _ := node.Val.cache() - it.stack = append(it.stack, &nodeIteratorState{ + state := &nodeIteratorState{ hash: common.BytesToHash(hash), node: node.Val, parent: ancestor, - child: -1, + index: -1, pathlen: len(it.path), - }) - if hasTerm(node.Key) { - it.path = append(it.path, node.Key[:len(node.Key)-1]...) - } else { - it.path = append(it.path, node.Key...) } - break + var path []byte + if hasTerm(node.Key) { + path = append(it.path, node.Key[:len(node.Key)-1]...) + } else { + path = append(it.path, node.Key...) + } + return state, &parent.index, path, nil } } else if hash, ok := parent.node.(hashNode); ok { // Hash node, resolve the hash child from the database - if parent.child < 0 { - parent.child++ + if parent.index < 0 { node, err := it.trie.resolveHash(hash, nil, nil) if err != nil { - return err + return it.stack[len(it.stack)-1], &parent.index, it.path, err } - it.stack = append(it.stack, &nodeIteratorState{ + state := &nodeIteratorState{ hash: common.BytesToHash(hash), node: node, parent: ancestor, - child: -1, + index: -1, pathlen: len(it.path), - }) - break + } + return state, &parent.index, it.path, nil } } - it.path = it.path[:parent.pathlen] - it.stack = it.stack[:len(it.stack)-1] + // No more child nodes, move back up. + it.pop() } - return nil +} + +func (it *nodeIterator) push(state *nodeIteratorState, parentIndex *int, path []byte) { + it.path = path + it.stack = append(it.stack, state) + if parentIndex != nil { + *parentIndex += 1 + } +} + +func (it *nodeIterator) pop() { + parent := it.stack[len(it.stack)-1] + it.path = it.path[:parent.pathlen] + it.stack = it.stack[:len(it.stack)-1] } func compareNodes(a, b NodeIterator) int { diff --git a/trie/iterator_test.go b/trie/iterator_test.go index 04d51aaf5..f161fd99d 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -17,6 +17,8 @@ package trie import ( + "bytes" + "fmt" "testing" "github.com/ethereum/go-ethereum/common" @@ -42,7 +44,7 @@ func TestIterator(t *testing.T) { trie.Commit() found := make(map[string]string) - it := NewIterator(trie.NodeIterator()) + it := NewIterator(trie.NodeIterator(nil)) for it.Next() { found[string(it.Key)] = string(it.Value) } @@ -72,7 +74,7 @@ func TestIteratorLargeData(t *testing.T) { vals[string(value2.k)] = value2 } - it := NewIterator(trie.NodeIterator()) + it := NewIterator(trie.NodeIterator(nil)) for it.Next() { vals[string(it.Key)].t = true } @@ -99,7 +101,7 @@ func TestNodeIteratorCoverage(t *testing.T) { // Gather all the node hashes found by the iterator hashes := make(map[common.Hash]struct{}) - for it := trie.NodeIterator(); it.Next(true); { + for it := trie.NodeIterator(nil); it.Next(true); { if it.Hash() != (common.Hash{}) { hashes[it.Hash()] = struct{}{} } @@ -117,18 +119,20 @@ func TestNodeIteratorCoverage(t *testing.T) { } } -var testdata1 = []struct{ k, v string }{ - {"bar", "b"}, +type kvs struct{ k, v string } + +var testdata1 = []kvs{ {"barb", "ba"}, - {"bars", "bb"}, {"bard", "bc"}, + {"bars", "bb"}, + {"bar", "b"}, {"fab", "z"}, - {"foo", "a"}, {"food", "ab"}, {"foos", "aa"}, + {"foo", "a"}, } -var testdata2 = []struct{ k, v string }{ +var testdata2 = []kvs{ {"aardvark", "c"}, {"bar", "b"}, {"barb", "bd"}, @@ -140,6 +144,47 @@ var testdata2 = []struct{ k, v string }{ {"jars", "d"}, } +func TestIteratorSeek(t *testing.T) { + trie := newEmpty() + for _, val := range testdata1 { + trie.Update([]byte(val.k), []byte(val.v)) + } + + // Seek to the middle. + it := NewIterator(trie.NodeIterator([]byte("fab"))) + if err := checkIteratorOrder(testdata1[4:], it); err != nil { + t.Fatal(err) + } + + // Seek to a non-existent key. + it = NewIterator(trie.NodeIterator([]byte("barc"))) + if err := checkIteratorOrder(testdata1[1:], it); err != nil { + t.Fatal(err) + } + + // Seek beyond the end. + it = NewIterator(trie.NodeIterator([]byte("z"))) + if err := checkIteratorOrder(nil, it); err != nil { + t.Fatal(err) + } +} + +func checkIteratorOrder(want []kvs, it *Iterator) error { + for it.Next() { + if len(want) == 0 { + return fmt.Errorf("didn't expect any more values, got key %q", it.Key) + } + if !bytes.Equal(it.Key, []byte(want[0].k)) { + return fmt.Errorf("wrong key: got %q, want %q", it.Key, want[0].k) + } + want = want[1:] + } + if len(want) > 0 { + return fmt.Errorf("iterator ended early, want key %q", want[0]) + } + return nil +} + func TestDifferenceIterator(t *testing.T) { triea := newEmpty() for _, val := range testdata1 { @@ -154,7 +199,7 @@ func TestDifferenceIterator(t *testing.T) { trieb.Commit() found := make(map[string]string) - di, _ := NewDifferenceIterator(triea.NodeIterator(), trieb.NodeIterator()) + di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) it := NewIterator(di) for it.Next() { found[string(it.Key)] = string(it.Value) @@ -189,7 +234,7 @@ func TestUnionIterator(t *testing.T) { } trieb.Commit() - di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(), trieb.NodeIterator()}) + di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)}) it := NewIterator(di) all := []struct{ k, v string }{ diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 201716d18..37d1d4b09 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -156,8 +156,10 @@ func (t *SecureTrie) Root() []byte { return t.trie.Root() } -func (t *SecureTrie) NodeIterator() NodeIterator { - return t.trie.NodeIterator() +// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration +// starts at the key after the given start key. +func (t *SecureTrie) NodeIterator(start []byte) NodeIterator { + return t.trie.NodeIterator(start) } // CommitTo writes all nodes and the secure hash pre-images to the given database. diff --git a/trie/sync_test.go b/trie/sync_test.go index 6d345ad3f..1e27cbb67 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -80,7 +80,7 @@ func checkTrieConsistency(db Database, root common.Hash) error { if err != nil { return nil // // Consider a non existent state consistent } - it := trie.NodeIterator() + it := trie.NodeIterator(nil) for it.Next(true) { } return it.Error() diff --git a/trie/trie.go b/trie/trie.go index dbffc0ac3..5759f97e3 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -125,9 +125,10 @@ func New(root common.Hash, db Database) (*Trie, error) { return trie, nil } -// Iterator returns an iterator over all mappings in the trie. -func (t *Trie) NodeIterator() NodeIterator { - return newNodeIterator(t) +// NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at +// the key after the given start key. +func (t *Trie) NodeIterator(start []byte) NodeIterator { + return newNodeIterator(t, start) } // Get returns the value for key stored in the trie. diff --git a/trie/trie_test.go b/trie/trie_test.go index cacb08824..61adbba0c 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -439,7 +439,7 @@ func runRandTest(rt randTest) bool { tr = newtr case opItercheckhash: checktr, _ := New(common.Hash{}, nil) - it := NewIterator(tr.NodeIterator()) + it := NewIterator(tr.NodeIterator(nil)) for it.Next() { checktr.Update(it.Key, it.Value) } From 207bd7d2cddbf16ac2cb870fd6a1c558f02fd8ac Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Wed, 19 Apr 2017 12:09:04 +0200 Subject: [PATCH 5/5] eth: add debug_storageRangeAt --- core/state/state_object.go | 9 ++- core/state/statedb.go | 11 +++ eth/api.go | 151 ++++++++++++++++++++++++------------ eth/api_test.go | 88 +++++++++++++++++++++ internal/web3ext/web3ext.go | 5 ++ 5 files changed, 213 insertions(+), 51 deletions(-) create mode 100644 eth/api_test.go diff --git a/core/state/state_object.go b/core/state/state_object.go index 7d3315303..dcad9d068 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -201,7 +201,7 @@ func (self *stateObject) setState(key, value common.Hash) { } // updateTrie writes cached storage modifications into the object's storage trie. -func (self *stateObject) updateTrie(db trie.Database) { +func (self *stateObject) updateTrie(db trie.Database) *trie.SecureTrie { tr := self.getTrie(db) for key, value := range self.dirtyStorage { delete(self.dirtyStorage, key) @@ -213,6 +213,7 @@ func (self *stateObject) updateTrie(db trie.Database) { v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00")) tr.Update(key[:], v) } + return tr } // UpdateRoot sets the trie root to the current root hash of @@ -280,7 +281,11 @@ func (c *stateObject) ReturnGas(gas *big.Int) {} func (self *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *stateObject { stateObject := newObject(db, self.address, self.data, onDirty) - stateObject.trie = self.trie + if self.trie != nil { + // A shallow copy makes the two tries independent. + cpy := *self.trie + stateObject.trie = &cpy + } stateObject.code = self.code stateObject.dirtyStorage = self.dirtyStorage.Copy() stateObject.cachedStorage = self.dirtyStorage.Copy() diff --git a/core/state/statedb.go b/core/state/statedb.go index 431f33e02..3b753a2e6 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -296,6 +296,17 @@ func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash { return common.Hash{} } +// StorageTrie returns the storage trie of an account. +// The return value is a copy and is nil for non-existent accounts. +func (self *StateDB) StorageTrie(a common.Address) *trie.SecureTrie { + stateObject := self.getStateObject(a) + if stateObject == nil { + return nil + } + cpy := stateObject.deepCopy(self, nil) + return cpy.updateTrie(self.db) +} + func (self *StateDB) HasSuicided(addr common.Address) bool { stateObject := self.getStateObject(addr) if stateObject != nil { diff --git a/eth/api.go b/eth/api.go index b386c08b4..61f7bdd92 100644 --- a/eth/api.go +++ b/eth/api.go @@ -20,7 +20,6 @@ import ( "bytes" "compress/gzip" "context" - "errors" "fmt" "io" "io/ioutil" @@ -41,6 +40,7 @@ import ( "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rpc" + "github.com/ethereum/go-ethereum/trie" ) const defaultTraceTimeout = 5 * time.Second @@ -526,59 +526,67 @@ func (api *PrivateDebugAPI) TraceTransaction(ctx context.Context, txHash common. if tx == nil { return nil, fmt.Errorf("transaction %x not found", txHash) } - block := api.eth.BlockChain().GetBlockByHash(blockHash) - if block == nil { - return nil, fmt.Errorf("block %x not found", blockHash) - } - // Create the state database to mutate and eventually trace - parent := api.eth.BlockChain().GetBlock(block.ParentHash(), block.NumberU64()-1) - if parent == nil { - return nil, fmt.Errorf("block parent %x not found", block.ParentHash()) - } - stateDb, err := api.eth.BlockChain().StateAt(parent.Root()) + msg, context, statedb, err := api.computeTxEnv(blockHash, int(txIndex)) if err != nil { return nil, err } - signer := types.MakeSigner(api.config, block.Number()) - // Mutate the state and trace the selected transaction - for idx, tx := range block.Transactions() { - // Assemble the transaction call message - msg, err := tx.AsMessage(signer) - if err != nil { - return nil, fmt.Errorf("sender retrieval failed: %v", err) - } - context := core.NewEVMContext(msg, block.Header(), api.eth.BlockChain(), nil) - - // Mutate the state if we haven't reached the tracing transaction yet - if uint64(idx) < txIndex { - vmenv := vm.NewEVM(context, stateDb, api.config, vm.Config{}) - _, _, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(tx.Gas())) - if err != nil { - return nil, fmt.Errorf("mutation failed: %v", err) - } - stateDb.DeleteSuicides() - continue - } - - vmenv := vm.NewEVM(context, stateDb, api.config, vm.Config{Debug: true, Tracer: tracer}) - ret, gas, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(tx.Gas())) - if err != nil { - return nil, fmt.Errorf("tracing failed: %v", err) - } - - switch tracer := tracer.(type) { - case *vm.StructLogger: - return ðapi.ExecutionResult{ - Gas: gas, - ReturnValue: fmt.Sprintf("%x", ret), - StructLogs: ethapi.FormatLogs(tracer.StructLogs()), - }, nil - case *ethapi.JavascriptTracer: - return tracer.GetResult() - } + // Run the transaction with tracing enabled. + vmenv := vm.NewEVM(context, statedb, api.config, vm.Config{Debug: true, Tracer: tracer}) + ret, gas, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(tx.Gas())) + if err != nil { + return nil, fmt.Errorf("tracing failed: %v", err) } - return nil, errors.New("database inconsistency") + switch tracer := tracer.(type) { + case *vm.StructLogger: + return ðapi.ExecutionResult{ + Gas: gas, + ReturnValue: fmt.Sprintf("%x", ret), + StructLogs: ethapi.FormatLogs(tracer.StructLogs()), + }, nil + case *ethapi.JavascriptTracer: + return tracer.GetResult() + default: + panic(fmt.Sprintf("bad tracer type %T", tracer)) + } +} + +// computeTxEnv returns the execution environment of a certain transaction. +func (api *PrivateDebugAPI) computeTxEnv(blockHash common.Hash, txIndex int) (core.Message, vm.Context, *state.StateDB, error) { + // Create the parent state. + block := api.eth.BlockChain().GetBlockByHash(blockHash) + if block == nil { + return nil, vm.Context{}, nil, fmt.Errorf("block %x not found", blockHash) + } + parent := api.eth.BlockChain().GetBlock(block.ParentHash(), block.NumberU64()-1) + if parent == nil { + return nil, vm.Context{}, nil, fmt.Errorf("block parent %x not found", block.ParentHash()) + } + statedb, err := api.eth.BlockChain().StateAt(parent.Root()) + if err != nil { + return nil, vm.Context{}, nil, err + } + txs := block.Transactions() + + // Recompute transactions up to the target index. + signer := types.MakeSigner(api.config, block.Number()) + for idx, tx := range txs { + // Assemble the transaction call message + msg, _ := tx.AsMessage(signer) + context := core.NewEVMContext(msg, block.Header(), api.eth.BlockChain(), nil) + if idx == txIndex { + return msg, context, statedb, nil + } + + vmenv := vm.NewEVM(context, statedb, api.config, vm.Config{}) + gp := new(core.GasPool).AddGas(tx.Gas()) + _, _, err := core.ApplyMessage(vmenv, msg, gp) + if err != nil { + return nil, vm.Context{}, nil, fmt.Errorf("tx %x failed: %v", tx.Hash(), err) + } + statedb.DeleteSuicides() + } + return nil, vm.Context{}, nil, fmt.Errorf("tx index %d out of range for block %x", txIndex, blockHash) } // Preimage is a debug API function that returns the preimage for a sha3 hash, if known. @@ -592,3 +600,48 @@ func (api *PrivateDebugAPI) Preimage(ctx context.Context, hash common.Hash) (hex func (api *PrivateDebugAPI) GetBadBlocks(ctx context.Context) ([]core.BadBlockArgs, error) { return api.eth.BlockChain().BadBlocks() } + +// StorageRangeResult is the result of a debug_storageRangeAt API call. +type StorageRangeResult struct { + Storage storageMap `json:"storage"` + NextKey *common.Hash `json:"nextKey"` // nil if Storage includes the last key in the trie. +} + +type storageMap map[common.Hash]storageEntry + +type storageEntry struct { + Key *common.Hash `json:"key"` + Value common.Hash `json:"value"` +} + +// StorageRangeAt returns the storage at the given block height and transaction index. +func (api *PrivateDebugAPI) StorageRangeAt(ctx context.Context, blockHash common.Hash, txIndex int, contractAddress common.Address, keyStart hexutil.Bytes, maxResult int) (StorageRangeResult, error) { + _, _, statedb, err := api.computeTxEnv(blockHash, txIndex) + if err != nil { + return StorageRangeResult{}, err + } + st := statedb.StorageTrie(contractAddress) + if st == nil { + return StorageRangeResult{}, fmt.Errorf("account %x doesn't exist", contractAddress) + } + return storageRangeAt(st, keyStart, maxResult), nil +} + +func storageRangeAt(st *trie.SecureTrie, start []byte, maxResult int) StorageRangeResult { + it := trie.NewIterator(st.NodeIterator(start)) + result := StorageRangeResult{Storage: storageMap{}} + for i := 0; i < maxResult && it.Next(); i++ { + e := storageEntry{Value: common.BytesToHash(it.Value)} + if preimage := st.GetKey(it.Key); preimage != nil { + preimage := common.BytesToHash(preimage) + e.Key = &preimage + } + result.Storage[common.BytesToHash(it.Key)] = e + } + // Add the 'next key' so clients can continue downloading. + if it.Next() { + next := common.BytesToHash(it.Key) + result.NextKey = &next + } + return result +} diff --git a/eth/api_test.go b/eth/api_test.go new file mode 100644 index 000000000..f8d2e9c76 --- /dev/null +++ b/eth/api_test.go @@ -0,0 +1,88 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package eth + +import ( + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/ethdb" +) + +var dumper = spew.ConfigState{Indent: " "} + +func TestStorageRangeAt(t *testing.T) { + // Create a state where account 0x010000... has a few storage entries. + var ( + db, _ = ethdb.NewMemDatabase() + state, _ = state.New(common.Hash{}, db) + addr = common.Address{0x01} + keys = []common.Hash{ // hashes of Keys of storage + common.HexToHash("340dd630ad21bf010b4e676dbfa9ba9a02175262d1fa356232cfde6cb5b47ef2"), + common.HexToHash("426fcb404ab2d5d8e61a3d918108006bbb0a9be65e92235bb10eefbdb6dcd053"), + common.HexToHash("48078cfed56339ea54962e72c37c7f588fc4f8e5bc173827ba75cb10a63a96a5"), + common.HexToHash("5723d2c3a83af9b735e3b7f21531e5623d183a9095a56604ead41f3582fdfb75"), + } + storage = storageMap{ + keys[0]: {Key: &common.Hash{0x02}, Value: common.Hash{0x01}}, + keys[1]: {Key: &common.Hash{0x04}, Value: common.Hash{0x02}}, + keys[2]: {Key: &common.Hash{0x01}, Value: common.Hash{0x03}}, + keys[3]: {Key: &common.Hash{0x03}, Value: common.Hash{0x04}}, + } + ) + for _, entry := range storage { + state.SetState(addr, *entry.Key, entry.Value) + } + + // Check a few combinations of limit and start/end. + tests := []struct { + start []byte + limit int + want StorageRangeResult + }{ + { + start: []byte{}, limit: 0, + want: StorageRangeResult{storageMap{}, &keys[0]}, + }, + { + start: []byte{}, limit: 100, + want: StorageRangeResult{storage, nil}, + }, + { + start: []byte{}, limit: 2, + want: StorageRangeResult{storageMap{keys[0]: storage[keys[0]], keys[1]: storage[keys[1]]}, &keys[2]}, + }, + { + start: []byte{0x00}, limit: 4, + want: StorageRangeResult{storage, nil}, + }, + { + start: []byte{0x40}, limit: 2, + want: StorageRangeResult{storageMap{keys[1]: storage[keys[1]], keys[2]: storage[keys[2]]}, &keys[3]}, + }, + } + for _, test := range tests { + result := storageRangeAt(state.StorageTrie(addr), test.start, test.limit) + if !reflect.DeepEqual(result, test.want) { + t.Fatalf("wrong result for range 0x%x.., limit %d:\ngot %s\nwant %s", + test.start, test.limit, dumper.Sdump(result), dumper.Sdump(&test.want)) + } + } +} diff --git a/internal/web3ext/web3ext.go b/internal/web3ext/web3ext.go index 72c2bd996..c9cac125d 100644 --- a/internal/web3ext/web3ext.go +++ b/internal/web3ext/web3ext.go @@ -345,6 +345,11 @@ web3._extend({ call: 'debug_getBadBlocks', params: 0, }), + new web3._extend.Method({ + name: 'storageRangeAt', + call: 'debug_storageRangeAt', + params: 5, + }), ], properties: [] });