diff --git a/core/state/dump.go b/core/state/dump.go index 8294d61b9..ffa1a7283 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -22,6 +22,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" ) type DumpAccount struct { @@ -44,7 +45,7 @@ func (self *StateDB) RawDump() Dump { Accounts: make(map[string]DumpAccount), } - it := self.trie.Iterator() + it := trie.NewIterator(self.trie.NodeIterator(nil)) for it.Next() { addr := self.trie.GetKey(it.Key) var data Account @@ -61,7 +62,7 @@ func (self *StateDB) RawDump() Dump { Code: common.Bytes2Hex(obj.Code(self.db)), Storage: make(map[string]string), } - storageIt := obj.getTrie(self.db).Iterator() + storageIt := trie.NewIterator(obj.getTrie(self.db).NodeIterator(nil)) for storageIt.Next() { account.Storage[common.Bytes2Hex(self.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(storageIt.Value) } diff --git a/core/state/iterator.go b/core/state/iterator.go index 170aec983..a8a2722ae 100644 --- a/core/state/iterator.go +++ b/core/state/iterator.go @@ -75,7 +75,7 @@ func (it *NodeIterator) step() error { } // Initialize the iterator if we've just started if it.stateIt == nil { - it.stateIt = it.state.trie.NodeIterator() + it.stateIt = it.state.trie.NodeIterator(nil) } // If we had data nodes previously, we surely have at least state nodes if it.dataIt != nil { @@ -118,7 +118,7 @@ func (it *NodeIterator) step() error { if err != nil { return err } - it.dataIt = trie.NewNodeIterator(dataTrie) + it.dataIt = dataTrie.NodeIterator(nil) if !it.dataIt.Next(true) { it.dataIt = nil } diff --git a/core/state/state_object.go b/core/state/state_object.go index 7d3315303..dcad9d068 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -201,7 +201,7 @@ func (self *stateObject) setState(key, value common.Hash) { } // updateTrie writes cached storage modifications into the object's storage trie. -func (self *stateObject) updateTrie(db trie.Database) { +func (self *stateObject) updateTrie(db trie.Database) *trie.SecureTrie { tr := self.getTrie(db) for key, value := range self.dirtyStorage { delete(self.dirtyStorage, key) @@ -213,6 +213,7 @@ func (self *stateObject) updateTrie(db trie.Database) { v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00")) tr.Update(key[:], v) } + return tr } // UpdateRoot sets the trie root to the current root hash of @@ -280,7 +281,11 @@ func (c *stateObject) ReturnGas(gas *big.Int) {} func (self *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *stateObject { stateObject := newObject(db, self.address, self.data, onDirty) - stateObject.trie = self.trie + if self.trie != nil { + // A shallow copy makes the two tries independent. + cpy := *self.trie + stateObject.trie = &cpy + } stateObject.code = self.code stateObject.dirtyStorage = self.dirtyStorage.Copy() stateObject.cachedStorage = self.dirtyStorage.Copy() diff --git a/core/state/statedb.go b/core/state/statedb.go index 0c72fc6b0..3b753a2e6 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -296,6 +296,17 @@ func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash { return common.Hash{} } +// StorageTrie returns the storage trie of an account. +// The return value is a copy and is nil for non-existent accounts. +func (self *StateDB) StorageTrie(a common.Address) *trie.SecureTrie { + stateObject := self.getStateObject(a) + if stateObject == nil { + return nil + } + cpy := stateObject.deepCopy(self, nil) + return cpy.updateTrie(self.db) +} + func (self *StateDB) HasSuicided(addr common.Address) bool { stateObject := self.getStateObject(addr) if stateObject != nil { @@ -481,7 +492,7 @@ func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common cb(h, value) } - it := so.getTrie(db.db).Iterator() + it := trie.NewIterator(so.getTrie(db.db).NodeIterator(nil)) for it.Next() { // ignore cached values key := common.BytesToHash(db.trie.GetKey(it.Key)) diff --git a/eth/api.go b/eth/api.go index b386c08b4..61f7bdd92 100644 --- a/eth/api.go +++ b/eth/api.go @@ -20,7 +20,6 @@ import ( "bytes" "compress/gzip" "context" - "errors" "fmt" "io" "io/ioutil" @@ -41,6 +40,7 @@ import ( "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rpc" + "github.com/ethereum/go-ethereum/trie" ) const defaultTraceTimeout = 5 * time.Second @@ -526,59 +526,67 @@ func (api *PrivateDebugAPI) TraceTransaction(ctx context.Context, txHash common. if tx == nil { return nil, fmt.Errorf("transaction %x not found", txHash) } - block := api.eth.BlockChain().GetBlockByHash(blockHash) - if block == nil { - return nil, fmt.Errorf("block %x not found", blockHash) - } - // Create the state database to mutate and eventually trace - parent := api.eth.BlockChain().GetBlock(block.ParentHash(), block.NumberU64()-1) - if parent == nil { - return nil, fmt.Errorf("block parent %x not found", block.ParentHash()) - } - stateDb, err := api.eth.BlockChain().StateAt(parent.Root()) + msg, context, statedb, err := api.computeTxEnv(blockHash, int(txIndex)) if err != nil { return nil, err } - signer := types.MakeSigner(api.config, block.Number()) - // Mutate the state and trace the selected transaction - for idx, tx := range block.Transactions() { - // Assemble the transaction call message - msg, err := tx.AsMessage(signer) - if err != nil { - return nil, fmt.Errorf("sender retrieval failed: %v", err) - } - context := core.NewEVMContext(msg, block.Header(), api.eth.BlockChain(), nil) - - // Mutate the state if we haven't reached the tracing transaction yet - if uint64(idx) < txIndex { - vmenv := vm.NewEVM(context, stateDb, api.config, vm.Config{}) - _, _, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(tx.Gas())) - if err != nil { - return nil, fmt.Errorf("mutation failed: %v", err) - } - stateDb.DeleteSuicides() - continue - } - - vmenv := vm.NewEVM(context, stateDb, api.config, vm.Config{Debug: true, Tracer: tracer}) - ret, gas, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(tx.Gas())) - if err != nil { - return nil, fmt.Errorf("tracing failed: %v", err) - } - - switch tracer := tracer.(type) { - case *vm.StructLogger: - return ðapi.ExecutionResult{ - Gas: gas, - ReturnValue: fmt.Sprintf("%x", ret), - StructLogs: ethapi.FormatLogs(tracer.StructLogs()), - }, nil - case *ethapi.JavascriptTracer: - return tracer.GetResult() - } + // Run the transaction with tracing enabled. + vmenv := vm.NewEVM(context, statedb, api.config, vm.Config{Debug: true, Tracer: tracer}) + ret, gas, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(tx.Gas())) + if err != nil { + return nil, fmt.Errorf("tracing failed: %v", err) } - return nil, errors.New("database inconsistency") + switch tracer := tracer.(type) { + case *vm.StructLogger: + return ðapi.ExecutionResult{ + Gas: gas, + ReturnValue: fmt.Sprintf("%x", ret), + StructLogs: ethapi.FormatLogs(tracer.StructLogs()), + }, nil + case *ethapi.JavascriptTracer: + return tracer.GetResult() + default: + panic(fmt.Sprintf("bad tracer type %T", tracer)) + } +} + +// computeTxEnv returns the execution environment of a certain transaction. +func (api *PrivateDebugAPI) computeTxEnv(blockHash common.Hash, txIndex int) (core.Message, vm.Context, *state.StateDB, error) { + // Create the parent state. + block := api.eth.BlockChain().GetBlockByHash(blockHash) + if block == nil { + return nil, vm.Context{}, nil, fmt.Errorf("block %x not found", blockHash) + } + parent := api.eth.BlockChain().GetBlock(block.ParentHash(), block.NumberU64()-1) + if parent == nil { + return nil, vm.Context{}, nil, fmt.Errorf("block parent %x not found", block.ParentHash()) + } + statedb, err := api.eth.BlockChain().StateAt(parent.Root()) + if err != nil { + return nil, vm.Context{}, nil, err + } + txs := block.Transactions() + + // Recompute transactions up to the target index. + signer := types.MakeSigner(api.config, block.Number()) + for idx, tx := range txs { + // Assemble the transaction call message + msg, _ := tx.AsMessage(signer) + context := core.NewEVMContext(msg, block.Header(), api.eth.BlockChain(), nil) + if idx == txIndex { + return msg, context, statedb, nil + } + + vmenv := vm.NewEVM(context, statedb, api.config, vm.Config{}) + gp := new(core.GasPool).AddGas(tx.Gas()) + _, _, err := core.ApplyMessage(vmenv, msg, gp) + if err != nil { + return nil, vm.Context{}, nil, fmt.Errorf("tx %x failed: %v", tx.Hash(), err) + } + statedb.DeleteSuicides() + } + return nil, vm.Context{}, nil, fmt.Errorf("tx index %d out of range for block %x", txIndex, blockHash) } // Preimage is a debug API function that returns the preimage for a sha3 hash, if known. @@ -592,3 +600,48 @@ func (api *PrivateDebugAPI) Preimage(ctx context.Context, hash common.Hash) (hex func (api *PrivateDebugAPI) GetBadBlocks(ctx context.Context) ([]core.BadBlockArgs, error) { return api.eth.BlockChain().BadBlocks() } + +// StorageRangeResult is the result of a debug_storageRangeAt API call. +type StorageRangeResult struct { + Storage storageMap `json:"storage"` + NextKey *common.Hash `json:"nextKey"` // nil if Storage includes the last key in the trie. +} + +type storageMap map[common.Hash]storageEntry + +type storageEntry struct { + Key *common.Hash `json:"key"` + Value common.Hash `json:"value"` +} + +// StorageRangeAt returns the storage at the given block height and transaction index. +func (api *PrivateDebugAPI) StorageRangeAt(ctx context.Context, blockHash common.Hash, txIndex int, contractAddress common.Address, keyStart hexutil.Bytes, maxResult int) (StorageRangeResult, error) { + _, _, statedb, err := api.computeTxEnv(blockHash, txIndex) + if err != nil { + return StorageRangeResult{}, err + } + st := statedb.StorageTrie(contractAddress) + if st == nil { + return StorageRangeResult{}, fmt.Errorf("account %x doesn't exist", contractAddress) + } + return storageRangeAt(st, keyStart, maxResult), nil +} + +func storageRangeAt(st *trie.SecureTrie, start []byte, maxResult int) StorageRangeResult { + it := trie.NewIterator(st.NodeIterator(start)) + result := StorageRangeResult{Storage: storageMap{}} + for i := 0; i < maxResult && it.Next(); i++ { + e := storageEntry{Value: common.BytesToHash(it.Value)} + if preimage := st.GetKey(it.Key); preimage != nil { + preimage := common.BytesToHash(preimage) + e.Key = &preimage + } + result.Storage[common.BytesToHash(it.Key)] = e + } + // Add the 'next key' so clients can continue downloading. + if it.Next() { + next := common.BytesToHash(it.Key) + result.NextKey = &next + } + return result +} diff --git a/eth/api_test.go b/eth/api_test.go new file mode 100644 index 000000000..f8d2e9c76 --- /dev/null +++ b/eth/api_test.go @@ -0,0 +1,88 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package eth + +import ( + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/ethdb" +) + +var dumper = spew.ConfigState{Indent: " "} + +func TestStorageRangeAt(t *testing.T) { + // Create a state where account 0x010000... has a few storage entries. + var ( + db, _ = ethdb.NewMemDatabase() + state, _ = state.New(common.Hash{}, db) + addr = common.Address{0x01} + keys = []common.Hash{ // hashes of Keys of storage + common.HexToHash("340dd630ad21bf010b4e676dbfa9ba9a02175262d1fa356232cfde6cb5b47ef2"), + common.HexToHash("426fcb404ab2d5d8e61a3d918108006bbb0a9be65e92235bb10eefbdb6dcd053"), + common.HexToHash("48078cfed56339ea54962e72c37c7f588fc4f8e5bc173827ba75cb10a63a96a5"), + common.HexToHash("5723d2c3a83af9b735e3b7f21531e5623d183a9095a56604ead41f3582fdfb75"), + } + storage = storageMap{ + keys[0]: {Key: &common.Hash{0x02}, Value: common.Hash{0x01}}, + keys[1]: {Key: &common.Hash{0x04}, Value: common.Hash{0x02}}, + keys[2]: {Key: &common.Hash{0x01}, Value: common.Hash{0x03}}, + keys[3]: {Key: &common.Hash{0x03}, Value: common.Hash{0x04}}, + } + ) + for _, entry := range storage { + state.SetState(addr, *entry.Key, entry.Value) + } + + // Check a few combinations of limit and start/end. + tests := []struct { + start []byte + limit int + want StorageRangeResult + }{ + { + start: []byte{}, limit: 0, + want: StorageRangeResult{storageMap{}, &keys[0]}, + }, + { + start: []byte{}, limit: 100, + want: StorageRangeResult{storage, nil}, + }, + { + start: []byte{}, limit: 2, + want: StorageRangeResult{storageMap{keys[0]: storage[keys[0]], keys[1]: storage[keys[1]]}, &keys[2]}, + }, + { + start: []byte{0x00}, limit: 4, + want: StorageRangeResult{storage, nil}, + }, + { + start: []byte{0x40}, limit: 2, + want: StorageRangeResult{storageMap{keys[1]: storage[keys[1]], keys[2]: storage[keys[2]]}, &keys[3]}, + }, + } + for _, test := range tests { + result := storageRangeAt(state.StorageTrie(addr), test.start, test.limit) + if !reflect.DeepEqual(result, test.want) { + t.Fatalf("wrong result for range 0x%x.., limit %d:\ngot %s\nwant %s", + test.start, test.limit, dumper.Sdump(result), dumper.Sdump(&test.want)) + } + } +} diff --git a/internal/web3ext/web3ext.go b/internal/web3ext/web3ext.go index 72c2bd996..c9cac125d 100644 --- a/internal/web3ext/web3ext.go +++ b/internal/web3ext/web3ext.go @@ -345,6 +345,11 @@ web3._extend({ call: 'debug_getBadBlocks', params: 0, }), + new web3._extend.Method({ + name: 'storageRangeAt', + call: 'debug_storageRangeAt', + params: 5, + }), ], properties: [] }); diff --git a/light/trie.go b/light/trie.go index 1440f2fbf..2988a16cf 100644 --- a/light/trie.go +++ b/light/trie.go @@ -19,6 +19,7 @@ package light import ( "context" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/trie" ) @@ -46,26 +47,18 @@ func NewLightTrie(id *TrieID, odr OdrBackend, useFakeMap bool) *LightTrie { // retrieveKey retrieves a single key, returns true and stores nodes in local // database if successful func (t *LightTrie) retrieveKey(ctx context.Context, key []byte) bool { - r := &TrieRequest{Id: t.id, Key: key} + r := &TrieRequest{Id: t.id, Key: crypto.Keccak256(key)} return t.odr.Retrieve(ctx, r) == nil } // do tries and retries to execute a function until it returns with no error or // an error type other than MissingNodeError -func (t *LightTrie) do(ctx context.Context, fallbackKey []byte, fn func() error) error { +func (t *LightTrie) do(ctx context.Context, key []byte, fn func() error) error { err := fn() for err != nil { - mn, ok := err.(*trie.MissingNodeError) - if !ok { + if _, ok := err.(*trie.MissingNodeError); !ok { return err } - - var key []byte - if mn.PrefixLen+mn.SuffixLen > 0 { - key = mn.Key - } else { - key = fallbackKey - } if !t.retrieveKey(ctx, key) { break } diff --git a/trie/encoding.go b/trie/encoding.go index 2037118dd..e96a786e4 100644 --- a/trie/encoding.go +++ b/trie/encoding.go @@ -16,49 +16,54 @@ package trie -func compactEncode(hexSlice []byte) []byte { +// Trie keys are dealt with in three distinct encodings: +// +// KEYBYTES encoding contains the actual key and nothing else. This encoding is the +// input to most API functions. +// +// HEX encoding contains one byte for each nibble of the key and an optional trailing +// 'terminator' byte of value 0x10 which indicates whether or not the node at the key +// contains a value. Hex key encoding is used for nodes loaded in memory because it's +// convenient to access. +// +// COMPACT encoding is defined by the Ethereum Yellow Paper (it's called "hex prefix +// encoding" there) and contains the bytes of the key and a flag. The high nibble of the +// first byte contains the flag; the lowest bit encoding the oddness of the length and +// the second-lowest encoding whether the node at the key is a value node. The low nibble +// of the first byte is zero in the case of an even number of nibbles and the first nibble +// in the case of an odd number. All remaining nibbles (now an even number) fit properly +// into the remaining bytes. Compact encoding is used for nodes stored on disk. + +func hexToCompact(hex []byte) []byte { terminator := byte(0) - if hexSlice[len(hexSlice)-1] == 16 { + if hasTerm(hex) { terminator = 1 - hexSlice = hexSlice[:len(hexSlice)-1] + hex = hex[:len(hex)-1] } - var ( - odd = byte(len(hexSlice) % 2) - buflen = len(hexSlice)/2 + 1 - bi, hi = 0, 0 // indices - hs = byte(0) // shift: flips between 0 and 4 - ) - if odd == 0 { - bi = 1 - hs = 4 - } - buf := make([]byte, buflen) - buf[0] = terminator<<5 | byte(odd)<<4 - for bi < len(buf) && hi < len(hexSlice) { - buf[bi] |= hexSlice[hi] << hs - if hs == 0 { - bi++ - } - hi, hs = hi+1, hs^(1<<2) + buf := make([]byte, len(hex)/2+1) + buf[0] = terminator << 5 // the flag byte + if len(hex)&1 == 1 { + buf[0] |= 1 << 4 // odd flag + buf[0] |= hex[0] // first nibble is contained in the first byte + hex = hex[1:] } + decodeNibbles(hex, buf[1:]) return buf } -func compactDecode(str []byte) []byte { - base := compactHexDecode(str) +func compactToHex(compact []byte) []byte { + base := keybytesToHex(compact) base = base[:len(base)-1] + // apply terminator flag if base[0] >= 2 { base = append(base, 16) } - if base[0]%2 == 1 { - base = base[1:] - } else { - base = base[2:] - } - return base + // apply odd flag + chop := 2 - base[0]&1 + return base[chop:] } -func compactHexDecode(str []byte) []byte { +func keybytesToHex(str []byte) []byte { l := len(str)*2 + 1 var nibbles = make([]byte, l) for i, b := range str { @@ -69,35 +74,24 @@ func compactHexDecode(str []byte) []byte { return nibbles } -// compactHexEncode encodes a series of nibbles into a byte array -func compactHexEncode(nibbles []byte) []byte { - nl := len(nibbles) - if nl == 0 { - return nil +// hexToKeybytes turns hex nibbles into key bytes. +// This can only be used for keys of even length. +func hexToKeybytes(hex []byte) []byte { + if hasTerm(hex) { + hex = hex[:len(hex)-1] } - if nibbles[nl-1] == 16 { - nl-- + if len(hex)&1 != 0 { + panic("can't convert hex key of odd length") } - l := (nl + 1) / 2 - var str = make([]byte, l) - for i := range str { - b := nibbles[i*2] * 16 - if nl > i*2 { - b += nibbles[i*2+1] - } - str[i] = b - } - return str + key := make([]byte, (len(hex)+1)/2) + decodeNibbles(hex, key) + return key } -func decodeCompact(key []byte) []byte { - l := len(key) / 2 - var res = make([]byte, l) - for i := 0; i < l; i++ { - v1, v0 := key[2*i], key[2*i+1] - res[i] = v1*16 + v0 +func decodeNibbles(nibbles []byte, bytes []byte) { + for bi, ni := 0, 0; ni < len(nibbles); bi, ni = bi+1, ni+2 { + bytes[bi] = nibbles[ni]<<4 | nibbles[ni+1] } - return res } // prefixLen returns the length of the common prefix of a and b. @@ -114,15 +108,7 @@ func prefixLen(a, b []byte) int { return i } +// hasTerm returns whether a hex key has the terminator flag. func hasTerm(s []byte) bool { - return s[len(s)-1] == 16 -} - -func remTerm(s []byte) []byte { - if hasTerm(s) { - b := make([]byte, len(s)-1) - copy(b, s) - return b - } - return s + return len(s) > 0 && s[len(s)-1] == 16 } diff --git a/trie/encoding_test.go b/trie/encoding_test.go index 2f125ef2f..97d8da136 100644 --- a/trie/encoding_test.go +++ b/trie/encoding_test.go @@ -17,113 +17,88 @@ package trie import ( - "encoding/hex" + "bytes" "testing" - - checker "gopkg.in/check.v1" ) -func TestEncoding(t *testing.T) { checker.TestingT(t) } - -type TrieEncodingSuite struct{} - -var _ = checker.Suite(&TrieEncodingSuite{}) - -func (s *TrieEncodingSuite) TestCompactEncode(c *checker.C) { - // even compact encode - test1 := []byte{1, 2, 3, 4, 5} - res1 := compactEncode(test1) - c.Assert(res1, checker.DeepEquals, []byte("\x11\x23\x45")) - - // odd compact encode - test2 := []byte{0, 1, 2, 3, 4, 5} - res2 := compactEncode(test2) - c.Assert(res2, checker.DeepEquals, []byte("\x00\x01\x23\x45")) - - //odd terminated compact encode - test3 := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} - res3 := compactEncode(test3) - c.Assert(res3, checker.DeepEquals, []byte("\x20\x0f\x1c\xb8")) - - // even terminated compact encode - test4 := []byte{15, 1, 12, 11, 8 /*term*/, 16} - res4 := compactEncode(test4) - c.Assert(res4, checker.DeepEquals, []byte("\x3f\x1c\xb8")) -} - -func (s *TrieEncodingSuite) TestCompactHexDecode(c *checker.C) { - exp := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} - res := compactHexDecode([]byte("verb")) - c.Assert(res, checker.DeepEquals, exp) -} - -func (s *TrieEncodingSuite) TestCompactHexEncode(c *checker.C) { - exp := []byte("verb") - res := compactHexEncode([]byte{7, 6, 6, 5, 7, 2, 6, 2, 16}) - c.Assert(res, checker.DeepEquals, exp) -} - -func (s *TrieEncodingSuite) TestCompactDecode(c *checker.C) { - // odd compact decode - exp := []byte{1, 2, 3, 4, 5} - res := compactDecode([]byte("\x11\x23\x45")) - c.Assert(res, checker.DeepEquals, exp) - - // even compact decode - exp = []byte{0, 1, 2, 3, 4, 5} - res = compactDecode([]byte("\x00\x01\x23\x45")) - c.Assert(res, checker.DeepEquals, exp) - - // even terminated compact decode - exp = []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} - res = compactDecode([]byte("\x20\x0f\x1c\xb8")) - c.Assert(res, checker.DeepEquals, exp) - - // even terminated compact decode - exp = []byte{15, 1, 12, 11, 8 /*term*/, 16} - res = compactDecode([]byte("\x3f\x1c\xb8")) - c.Assert(res, checker.DeepEquals, exp) -} - -func (s *TrieEncodingSuite) TestDecodeCompact(c *checker.C) { - exp, _ := hex.DecodeString("012345") - res := decodeCompact([]byte{0, 1, 2, 3, 4, 5}) - c.Assert(res, checker.DeepEquals, exp) - - exp, _ = hex.DecodeString("012345") - res = decodeCompact([]byte{0, 1, 2, 3, 4, 5, 16}) - c.Assert(res, checker.DeepEquals, exp) - - exp, _ = hex.DecodeString("abcdef") - res = decodeCompact([]byte{10, 11, 12, 13, 14, 15}) - c.Assert(res, checker.DeepEquals, exp) -} - -func BenchmarkCompactEncode(b *testing.B) { - - testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} - for i := 0; i < b.N; i++ { - compactEncode(testBytes) +func TestHexCompact(t *testing.T) { + tests := []struct{ hex, compact []byte }{ + // empty keys, with and without terminator. + {hex: []byte{}, compact: []byte{0x00}}, + {hex: []byte{16}, compact: []byte{0x20}}, + // odd length, no terminator + {hex: []byte{1, 2, 3, 4, 5}, compact: []byte{0x11, 0x23, 0x45}}, + // even length, no terminator + {hex: []byte{0, 1, 2, 3, 4, 5}, compact: []byte{0x00, 0x01, 0x23, 0x45}}, + // odd length, terminator + {hex: []byte{15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x3f, 0x1c, 0xb8}}, + // even length, terminator + {hex: []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x20, 0x0f, 0x1c, 0xb8}}, + } + for _, test := range tests { + if c := hexToCompact(test.hex); !bytes.Equal(c, test.compact) { + t.Errorf("hexToCompact(%x) -> %x, want %x", test.hex, c, test.compact) + } + if h := compactToHex(test.compact); !bytes.Equal(h, test.hex) { + t.Errorf("compactToHex(%x) -> %x, want %x", test.compact, h, test.hex) + } } } -func BenchmarkCompactDecode(b *testing.B) { - testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} - for i := 0; i < b.N; i++ { - compactDecode(testBytes) +func TestHexKeybytes(t *testing.T) { + tests := []struct{ key, hexIn, hexOut []byte }{ + {key: []byte{}, hexIn: []byte{16}, hexOut: []byte{16}}, + {key: []byte{}, hexIn: []byte{}, hexOut: []byte{16}}, + { + key: []byte{0x12, 0x34, 0x56}, + hexIn: []byte{1, 2, 3, 4, 5, 6, 16}, + hexOut: []byte{1, 2, 3, 4, 5, 6, 16}, + }, + { + key: []byte{0x12, 0x34, 0x5}, + hexIn: []byte{1, 2, 3, 4, 0, 5, 16}, + hexOut: []byte{1, 2, 3, 4, 0, 5, 16}, + }, + { + key: []byte{0x12, 0x34, 0x56}, + hexIn: []byte{1, 2, 3, 4, 5, 6}, + hexOut: []byte{1, 2, 3, 4, 5, 6, 16}, + }, + } + for _, test := range tests { + if h := keybytesToHex(test.key); !bytes.Equal(h, test.hexOut) { + t.Errorf("keybytesToHex(%x) -> %x, want %x", test.key, h, test.hexOut) + } + if k := hexToKeybytes(test.hexIn); !bytes.Equal(k, test.key) { + t.Errorf("hexToKeybytes(%x) -> %x, want %x", test.hexIn, k, test.key) + } } } -func BenchmarkCompactHexDecode(b *testing.B) { +func BenchmarkHexToCompact(b *testing.B) { + testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/} + for i := 0; i < b.N; i++ { + hexToCompact(testBytes) + } +} + +func BenchmarkCompactToHex(b *testing.B) { + testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/} + for i := 0; i < b.N; i++ { + compactToHex(testBytes) + } +} + +func BenchmarkKeybytesToHex(b *testing.B) { testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} for i := 0; i < b.N; i++ { - compactHexDecode(testBytes) + keybytesToHex(testBytes) } } -func BenchmarkDecodeCompact(b *testing.B) { +func BenchmarkHexToKeybytes(b *testing.B) { testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} for i := 0; i < b.N; i++ { - decodeCompact(testBytes) + hexToKeybytes(testBytes) } } diff --git a/trie/errors.go b/trie/errors.go index 76129a70b..e23f9d563 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -30,10 +30,6 @@ import ( // // RootHash is the original root of the trie that contains the node // -// Key is a binary-encoded key that contains the prefix that leads to the first -// missing node and optionally a suffix that hints on which further nodes should -// also be retrieved -// // PrefixLen is the nibble length of the key prefix that leads from the root to // the missing node // @@ -42,7 +38,6 @@ import ( // such hints in the error message) type MissingNodeError struct { RootHash, NodeHash common.Hash - Key []byte PrefixLen, SuffixLen int } diff --git a/trie/hasher.go b/trie/hasher.go index 98c309531..85b6b60f5 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -105,7 +105,7 @@ func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, err case *shortNode: // Hash the short node's child, caching the newly hashed subtree collapsed, cached := n.copy(), n.copy() - collapsed.Key = compactEncode(n.Key) + collapsed.Key = hexToCompact(n.Key) cached.Key = common.CopyBytes(n.Key) if _, ok := n.Val.(valueNode); !ok { diff --git a/trie/iterator.go b/trie/iterator.go index 42149a7d3..26ae1d5ad 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -19,9 +19,13 @@ package trie import ( "bytes" "container/heap" + "errors" + "github.com/ethereum/go-ethereum/common" ) +var iteratorEnd = errors.New("end of iteration") + // Iterator is a key-value trie iterator that traverses a Trie. type Iterator struct { nodeIt NodeIterator @@ -30,15 +34,8 @@ type Iterator struct { Value []byte // Current data value on which the iterator is positioned on } -// NewIterator creates a new key-value iterator. -func NewIterator(trie *Trie) *Iterator { - return &Iterator{ - nodeIt: NewNodeIterator(trie), - } -} - -// FromNodeIterator creates a new key-value iterator from a node iterator -func NewIteratorFromNodeIterator(it NodeIterator) *Iterator { +// NewIterator creates a new key-value iterator from a node iterator +func NewIterator(it NodeIterator) *Iterator { return &Iterator{ nodeIt: it, } @@ -48,7 +45,7 @@ func NewIteratorFromNodeIterator(it NodeIterator) *Iterator { func (it *Iterator) Next() bool { for it.nodeIt.Next(true) { if it.nodeIt.Leaf() { - it.Key = decodeCompact(it.nodeIt.Path()) + it.Key = hexToKeybytes(it.nodeIt.Path()) it.Value = it.nodeIt.LeafBlob() return true } @@ -85,25 +82,24 @@ type nodeIteratorState struct { hash common.Hash // Hash of the node being iterated (nil if not standalone) node node // Trie node being iterated parent common.Hash // Hash of the first full ancestor node (nil if current is the root) - child int // Child to be processed next + index int // Child to be processed next pathlen int // Length of the path to this node } type nodeIterator struct { trie *Trie // Trie being iterated stack []*nodeIteratorState // Hierarchy of trie nodes persisting the iteration state - - err error // Failure set in case of an internal error in the iterator - - path []byte // Path to the current node + err error // Failure set in case of an internal error in the iterator + path []byte // Path to the current node } -// NewNodeIterator creates an post-order trie iterator. -func NewNodeIterator(trie *Trie) NodeIterator { +func newNodeIterator(trie *Trie, start []byte) NodeIterator { if trie.Hash() == emptyState { return new(nodeIterator) } - return &nodeIterator{trie: trie} + it := &nodeIterator{trie: trie} + it.seek(start) + return it } // Hash returns the hash of the current node @@ -153,6 +149,9 @@ func (it *nodeIterator) Path() []byte { // Error returns the error set in case of an internal error in the iterator func (it *nodeIterator) Error() error { + if it.err == iteratorEnd { + return nil + } return it.err } @@ -161,47 +160,54 @@ func (it *nodeIterator) Error() error { // sets the Error field to the encountered failure. If `descend` is false, // skips iterating over any subnodes of the current node. func (it *nodeIterator) Next(descend bool) bool { - // If the iterator failed previously, don't do anything if it.err != nil { return false } // Otherwise step forward with the iterator and report any errors - if err := it.step(descend); err != nil { + state, parentIndex, path, err := it.peek(descend) + if err != nil { it.err = err return false } - return it.trie != nil + it.push(state, parentIndex, path) + return true } -// step moves the iterator to the next node of the trie. -func (it *nodeIterator) step(descend bool) error { - if it.trie == nil { - // Abort if we reached the end of the iteration - return nil +func (it *nodeIterator) seek(prefix []byte) { + // The path we're looking for is the hex encoded key without terminator. + key := keybytesToHex(prefix) + key = key[:len(key)-1] + // Move forward until we're just before the closest match to key. + for { + state, parentIndex, path, err := it.peek(bytes.HasPrefix(key, it.path)) + if err != nil || bytes.Compare(path, key) >= 0 { + it.err = err + return + } + it.push(state, parentIndex, path) } +} + +// peek creates the next state of the iterator. +func (it *nodeIterator) peek(descend bool) (*nodeIteratorState, *int, []byte, error) { if len(it.stack) == 0 { // Initialize the iterator if we've just started. root := it.trie.Hash() - state := &nodeIteratorState{node: it.trie.root, child: -1} + state := &nodeIteratorState{node: it.trie.root, index: -1} if root != emptyRoot { state.hash = root } - it.stack = append(it.stack, state) - return nil + return state, nil, nil, nil } - if !descend { // If we're skipping children, pop the current node first - it.path = it.path[:it.stack[len(it.stack)-1].pathlen] - it.stack = it.stack[:len(it.stack)-1] + it.pop() } // Continue iteration to the next child -outer: for { if len(it.stack) == 0 { - it.trie = nil - return nil + return nil, nil, nil, iteratorEnd } parent := it.stack[len(it.stack)-1] ancestor := parent.hash @@ -209,63 +215,76 @@ outer: ancestor = parent.parent } if node, ok := parent.node.(*fullNode); ok { - // Full node, iterate over children - for parent.child++; parent.child < len(node.Children); parent.child++ { - child := node.Children[parent.child] + // Full node, move to the first non-nil child. + for i := parent.index + 1; i < len(node.Children); i++ { + child := node.Children[i] if child != nil { hash, _ := child.cache() - it.stack = append(it.stack, &nodeIteratorState{ + state := &nodeIteratorState{ hash: common.BytesToHash(hash), node: child, parent: ancestor, - child: -1, + index: -1, pathlen: len(it.path), - }) - it.path = append(it.path, byte(parent.child)) - break outer + } + path := append(it.path, byte(i)) + parent.index = i - 1 + return state, &parent.index, path, nil } } } else if node, ok := parent.node.(*shortNode); ok { // Short node, return the pointer singleton child - if parent.child < 0 { - parent.child++ + if parent.index < 0 { hash, _ := node.Val.cache() - it.stack = append(it.stack, &nodeIteratorState{ + state := &nodeIteratorState{ hash: common.BytesToHash(hash), node: node.Val, parent: ancestor, - child: -1, + index: -1, pathlen: len(it.path), - }) - if hasTerm(node.Key) { - it.path = append(it.path, node.Key[:len(node.Key)-1]...) - } else { - it.path = append(it.path, node.Key...) } - break + var path []byte + if hasTerm(node.Key) { + path = append(it.path, node.Key[:len(node.Key)-1]...) + } else { + path = append(it.path, node.Key...) + } + return state, &parent.index, path, nil } } else if hash, ok := parent.node.(hashNode); ok { // Hash node, resolve the hash child from the database - if parent.child < 0 { - parent.child++ + if parent.index < 0 { node, err := it.trie.resolveHash(hash, nil, nil) if err != nil { - return err + return it.stack[len(it.stack)-1], &parent.index, it.path, err } - it.stack = append(it.stack, &nodeIteratorState{ + state := &nodeIteratorState{ hash: common.BytesToHash(hash), node: node, parent: ancestor, - child: -1, + index: -1, pathlen: len(it.path), - }) - break + } + return state, &parent.index, it.path, nil } } - it.path = it.path[:parent.pathlen] - it.stack = it.stack[:len(it.stack)-1] + // No more child nodes, move back up. + it.pop() } - return nil +} + +func (it *nodeIterator) push(state *nodeIteratorState, parentIndex *int, path []byte) { + it.path = path + it.stack = append(it.stack, state) + if parentIndex != nil { + *parentIndex += 1 + } +} + +func (it *nodeIterator) pop() { + parent := it.stack[len(it.stack)-1] + it.path = it.path[:parent.pathlen] + it.stack = it.stack[:len(it.stack)-1] } func compareNodes(a, b NodeIterator) int { diff --git a/trie/iterator_test.go b/trie/iterator_test.go index c101bb7b0..f161fd99d 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -17,6 +17,8 @@ package trie import ( + "bytes" + "fmt" "testing" "github.com/ethereum/go-ethereum/common" @@ -42,7 +44,7 @@ func TestIterator(t *testing.T) { trie.Commit() found := make(map[string]string) - it := NewIterator(trie) + it := NewIterator(trie.NodeIterator(nil)) for it.Next() { found[string(it.Key)] = string(it.Value) } @@ -72,7 +74,7 @@ func TestIteratorLargeData(t *testing.T) { vals[string(value2.k)] = value2 } - it := NewIterator(trie) + it := NewIterator(trie.NodeIterator(nil)) for it.Next() { vals[string(it.Key)].t = true } @@ -99,7 +101,7 @@ func TestNodeIteratorCoverage(t *testing.T) { // Gather all the node hashes found by the iterator hashes := make(map[common.Hash]struct{}) - for it := NewNodeIterator(trie); it.Next(true); { + for it := trie.NodeIterator(nil); it.Next(true); { if it.Hash() != (common.Hash{}) { hashes[it.Hash()] = struct{}{} } @@ -117,18 +119,20 @@ func TestNodeIteratorCoverage(t *testing.T) { } } -var testdata1 = []struct{ k, v string }{ - {"bar", "b"}, +type kvs struct{ k, v string } + +var testdata1 = []kvs{ {"barb", "ba"}, - {"bars", "bb"}, {"bard", "bc"}, + {"bars", "bb"}, + {"bar", "b"}, {"fab", "z"}, - {"foo", "a"}, {"food", "ab"}, {"foos", "aa"}, + {"foo", "a"}, } -var testdata2 = []struct{ k, v string }{ +var testdata2 = []kvs{ {"aardvark", "c"}, {"bar", "b"}, {"barb", "bd"}, @@ -140,6 +144,47 @@ var testdata2 = []struct{ k, v string }{ {"jars", "d"}, } +func TestIteratorSeek(t *testing.T) { + trie := newEmpty() + for _, val := range testdata1 { + trie.Update([]byte(val.k), []byte(val.v)) + } + + // Seek to the middle. + it := NewIterator(trie.NodeIterator([]byte("fab"))) + if err := checkIteratorOrder(testdata1[4:], it); err != nil { + t.Fatal(err) + } + + // Seek to a non-existent key. + it = NewIterator(trie.NodeIterator([]byte("barc"))) + if err := checkIteratorOrder(testdata1[1:], it); err != nil { + t.Fatal(err) + } + + // Seek beyond the end. + it = NewIterator(trie.NodeIterator([]byte("z"))) + if err := checkIteratorOrder(nil, it); err != nil { + t.Fatal(err) + } +} + +func checkIteratorOrder(want []kvs, it *Iterator) error { + for it.Next() { + if len(want) == 0 { + return fmt.Errorf("didn't expect any more values, got key %q", it.Key) + } + if !bytes.Equal(it.Key, []byte(want[0].k)) { + return fmt.Errorf("wrong key: got %q, want %q", it.Key, want[0].k) + } + want = want[1:] + } + if len(want) > 0 { + return fmt.Errorf("iterator ended early, want key %q", want[0]) + } + return nil +} + func TestDifferenceIterator(t *testing.T) { triea := newEmpty() for _, val := range testdata1 { @@ -154,8 +199,8 @@ func TestDifferenceIterator(t *testing.T) { trieb.Commit() found := make(map[string]string) - di, _ := NewDifferenceIterator(NewNodeIterator(triea), NewNodeIterator(trieb)) - it := NewIteratorFromNodeIterator(di) + di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) + it := NewIterator(di) for it.Next() { found[string(it.Key)] = string(it.Value) } @@ -189,8 +234,8 @@ func TestUnionIterator(t *testing.T) { } trieb.Commit() - di, _ := NewUnionIterator([]NodeIterator{NewNodeIterator(triea), NewNodeIterator(trieb)}) - it := NewIteratorFromNodeIterator(di) + di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)}) + it := NewIterator(di) all := []struct{ k, v string }{ {"aardvark", "c"}, diff --git a/trie/node.go b/trie/node.go index 4aa0cab65..a7697fc0c 100644 --- a/trie/node.go +++ b/trie/node.go @@ -139,8 +139,8 @@ func decodeShort(hash, buf, elems []byte, cachegen uint16) (node, error) { return nil, err } flag := nodeFlag{hash: hash, gen: cachegen} - key := compactDecode(kbuf) - if key[len(key)-1] == 16 { + key := compactToHex(kbuf) + if hasTerm(key) { // value node val, _, err := rlp.SplitString(rest) if err != nil { diff --git a/trie/proof.go b/trie/proof.go index 06cf827ab..fb7734b86 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -38,7 +38,7 @@ import ( // absence of the key. func (t *Trie) Prove(key []byte) []rlp.RawValue { // Collect all nodes on the path to key. - key = compactHexDecode(key) + key = keybytesToHex(key) nodes := []node{} tn := t.root for len(key) > 0 && tn != nil { @@ -89,7 +89,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue { // returns an error if the proof contains invalid trie nodes or the // wrong value. func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value []byte, err error) { - key = compactHexDecode(key) + key = keybytesToHex(key) sha := sha3.NewKeccak256() wantHash := rootHash.Bytes() for i, buf := range proof { diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 113fb6a1a..37d1d4b09 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -156,12 +156,10 @@ func (t *SecureTrie) Root() []byte { return t.trie.Root() } -func (t *SecureTrie) Iterator() *Iterator { - return t.trie.Iterator() -} - -func (t *SecureTrie) NodeIterator() NodeIterator { - return NewNodeIterator(&t.trie) +// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration +// starts at the key after the given start key. +func (t *SecureTrie) NodeIterator(start []byte) NodeIterator { + return t.trie.NodeIterator(start) } // CommitTo writes all nodes and the secure hash pre-images to the given database. diff --git a/trie/sync_test.go b/trie/sync_test.go index acae039cd..1e27cbb67 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -80,7 +80,7 @@ func checkTrieConsistency(db Database, root common.Hash) error { if err != nil { return nil // // Consider a non existent state consistent } - it := NewNodeIterator(trie) + it := trie.NodeIterator(nil) for it.Next(true) { } return it.Error() diff --git a/trie/trie.go b/trie/trie.go index 2a6044068..5759f97e3 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -125,9 +125,10 @@ func New(root common.Hash, db Database) (*Trie, error) { return trie, nil } -// Iterator returns an iterator over all mappings in the trie. -func (t *Trie) Iterator() *Iterator { - return NewIterator(t) +// NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at +// the key after the given start key. +func (t *Trie) NodeIterator(start []byte) NodeIterator { + return newNodeIterator(t, start) } // Get returns the value for key stored in the trie. @@ -144,7 +145,7 @@ func (t *Trie) Get(key []byte) []byte { // The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryGet(key []byte) ([]byte, error) { - key = compactHexDecode(key) + key = keybytesToHex(key) value, newroot, didResolve, err := t.tryGet(t.root, key, 0) if err == nil && didResolve { t.root = newroot @@ -211,7 +212,7 @@ func (t *Trie) Update(key, value []byte) { // // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryUpdate(key, value []byte) error { - k := compactHexDecode(key) + k := keybytesToHex(key) if len(value) != 0 { _, n, err := t.insert(t.root, nil, k, valueNode(value)) if err != nil { @@ -307,7 +308,7 @@ func (t *Trie) Delete(key []byte) { // TryDelete removes any existing value for key from the trie. // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryDelete(key []byte) error { - k := compactHexDecode(key) + k := keybytesToHex(key) _, n, err := t.delete(t.root, nil, k) if err != nil { return err @@ -450,7 +451,6 @@ func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) { return nil, &MissingNodeError{ RootHash: t.originalRoot, NodeHash: common.BytesToHash(n), - Key: compactHexEncode(append(prefix, suffix...)), PrefixLen: len(prefix), SuffixLen: len(suffix), } diff --git a/trie/trie_test.go b/trie/trie_test.go index 01ae3a4e7..61adbba0c 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -439,7 +439,7 @@ func runRandTest(rt randTest) bool { tr = newtr case opItercheckhash: checktr, _ := New(common.Hash{}, nil) - it := tr.Iterator() + it := NewIterator(tr.NodeIterator(nil)) for it.Next() { checktr.Update(it.Key, it.Value) }