Patch for concurrent iterator & others (onto v1.11.6) #386

Closed
roysc wants to merge 1565 commits from v1.11.6-statediff-v5 into master
28 changed files with 260 additions and 233 deletions
Showing only changes of commit 99f81d2724 - Show all commits

View File

@ -45,9 +45,10 @@ func (h *testHasher) Reset() {
h.hasher.Reset() h.hasher.Reset()
} }
func (h *testHasher) Update(key, val []byte) { func (h *testHasher) Update(key, val []byte) error {
h.hasher.Write(key) h.hasher.Write(key)
h.hasher.Write(val) h.hasher.Write(val)
return nil
} }
func (h *testHasher) Hash() common.Hash { func (h *testHasher) Hash() common.Hash {

View File

@ -371,7 +371,7 @@ func stackTrieGenerate(db ethdb.KeyValueWriter, scheme string, owner common.Hash
} }
t := trie.NewStackTrieWithOwner(nodeWriter, owner) t := trie.NewStackTrieWithOwner(nodeWriter, owner)
for leaf := range in { for leaf := range in {
t.TryUpdate(leaf.key[:], leaf.value) t.Update(leaf.key[:], leaf.value)
} }
var root common.Hash var root common.Hash
if db == nil { if db == nil {

View File

@ -230,7 +230,7 @@ func (dl *diskLayer) proveRange(ctx *generatorContext, trieId *trie.ID, prefix [
if origin == nil && !diskMore { if origin == nil && !diskMore {
stackTr := trie.NewStackTrie(nil) stackTr := trie.NewStackTrie(nil)
for i, key := range keys { for i, key := range keys {
stackTr.TryUpdate(key, vals[i]) stackTr.Update(key, vals[i])
} }
if gotRoot := stackTr.Hash(); gotRoot != root { if gotRoot := stackTr.Hash(); gotRoot != root {
return &proofResult{ return &proofResult{

View File

@ -161,7 +161,7 @@ func newHelper() *testHelper {
func (t *testHelper) addTrieAccount(acckey string, acc *Account) { func (t *testHelper) addTrieAccount(acckey string, acc *Account) {
val, _ := rlp.EncodeToBytes(acc) val, _ := rlp.EncodeToBytes(acc)
t.accTrie.Update([]byte(acckey), val) t.accTrie.MustUpdate([]byte(acckey), val)
} }
func (t *testHelper) addSnapAccount(acckey string, acc *Account) { func (t *testHelper) addSnapAccount(acckey string, acc *Account) {
@ -186,7 +186,7 @@ func (t *testHelper) makeStorageTrie(stateRoot, owner common.Hash, keys []string
id := trie.StorageTrieID(stateRoot, owner, common.Hash{}) id := trie.StorageTrieID(stateRoot, owner, common.Hash{})
stTrie, _ := trie.NewStateTrie(id, t.triedb) stTrie, _ := trie.NewStateTrie(id, t.triedb)
for i, k := range keys { for i, k := range keys {
stTrie.Update([]byte(k), []byte(vals[i])) stTrie.MustUpdate([]byte(k), []byte(vals[i]))
} }
if !commit { if !commit {
return stTrie.Hash().Bytes() return stTrie.Hash().Bytes()
@ -491,7 +491,7 @@ func TestGenerateWithExtraAccounts(t *testing.T) {
) )
acc := &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()} acc := &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()}
val, _ := rlp.EncodeToBytes(acc) val, _ := rlp.EncodeToBytes(acc)
helper.accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e helper.accTrie.MustUpdate([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e
// Identical in the snap // Identical in the snap
key := hashData([]byte("acc-1")) key := hashData([]byte("acc-1"))
@ -562,7 +562,7 @@ func TestGenerateWithManyExtraAccounts(t *testing.T) {
) )
acc := &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()} acc := &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()}
val, _ := rlp.EncodeToBytes(acc) val, _ := rlp.EncodeToBytes(acc)
helper.accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e helper.accTrie.MustUpdate([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e
// Identical in the snap // Identical in the snap
key := hashData([]byte("acc-1")) key := hashData([]byte("acc-1"))
@ -613,8 +613,8 @@ func TestGenerateWithExtraBeforeAndAfter(t *testing.T) {
{ {
acc := &Account{Balance: big.NewInt(1), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()} acc := &Account{Balance: big.NewInt(1), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()}
val, _ := rlp.EncodeToBytes(acc) val, _ := rlp.EncodeToBytes(acc)
helper.accTrie.Update(common.HexToHash("0x03").Bytes(), val) helper.accTrie.MustUpdate(common.HexToHash("0x03").Bytes(), val)
helper.accTrie.Update(common.HexToHash("0x07").Bytes(), val) helper.accTrie.MustUpdate(common.HexToHash("0x07").Bytes(), val)
rawdb.WriteAccountSnapshot(helper.diskdb, common.HexToHash("0x01"), val) rawdb.WriteAccountSnapshot(helper.diskdb, common.HexToHash("0x01"), val)
rawdb.WriteAccountSnapshot(helper.diskdb, common.HexToHash("0x02"), val) rawdb.WriteAccountSnapshot(helper.diskdb, common.HexToHash("0x02"), val)
@ -650,7 +650,7 @@ func TestGenerateWithMalformedSnapdata(t *testing.T) {
{ {
acc := &Account{Balance: big.NewInt(1), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()} acc := &Account{Balance: big.NewInt(1), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()}
val, _ := rlp.EncodeToBytes(acc) val, _ := rlp.EncodeToBytes(acc)
helper.accTrie.Update(common.HexToHash("0x03").Bytes(), val) helper.accTrie.MustUpdate(common.HexToHash("0x03").Bytes(), val)
junk := make([]byte, 100) junk := make([]byte, 100)
copy(junk, []byte{0xde, 0xad}) copy(junk, []byte{0xde, 0xad})

View File

@ -213,14 +213,14 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) {
for i, node := range nodeElements { for i, node := range nodeElements {
if bypath { if bypath {
if len(node.syncPath) == 1 { if len(node.syncPath) == 1 {
data, _, err := srcTrie.TryGetNode(node.syncPath[0]) data, _, err := srcTrie.GetNode(node.syncPath[0])
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for path %x: %v", node.syncPath[0], err) t.Fatalf("failed to retrieve node data for path %x: %v", node.syncPath[0], err)
} }
nodeResults[i] = trie.NodeSyncResult{Path: node.path, Data: data} nodeResults[i] = trie.NodeSyncResult{Path: node.path, Data: data}
} else { } else {
var acc types.StateAccount var acc types.StateAccount
if err := rlp.DecodeBytes(srcTrie.Get(node.syncPath[0]), &acc); err != nil { if err := rlp.DecodeBytes(srcTrie.MustGet(node.syncPath[0]), &acc); err != nil {
t.Fatalf("failed to decode account on path %x: %v", node.syncPath[0], err) t.Fatalf("failed to decode account on path %x: %v", node.syncPath[0], err)
} }
id := trie.StorageTrieID(srcRoot, common.BytesToHash(node.syncPath[0]), acc.Root) id := trie.StorageTrieID(srcRoot, common.BytesToHash(node.syncPath[0]), acc.Root)
@ -228,7 +228,7 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) {
if err != nil { if err != nil {
t.Fatalf("failed to retriev storage trie for path %x: %v", node.syncPath[1], err) t.Fatalf("failed to retriev storage trie for path %x: %v", node.syncPath[1], err)
} }
data, _, err := stTrie.TryGetNode(node.syncPath[1]) data, _, err := stTrie.GetNode(node.syncPath[1])
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for path %x: %v", node.syncPath[1], err) t.Fatalf("failed to retrieve node data for path %x: %v", node.syncPath[1], err)
} }

View File

@ -232,9 +232,10 @@ func (h *testHasher) Reset() {
h.hasher.Reset() h.hasher.Reset()
} }
func (h *testHasher) Update(key, val []byte) { func (h *testHasher) Update(key, val []byte) error {
h.hasher.Write(key) h.hasher.Write(key)
h.hasher.Write(val) h.hasher.Write(val)
return nil
} }
func (h *testHasher) Hash() common.Hash { func (h *testHasher) Hash() common.Hash {

View File

@ -62,7 +62,7 @@ func prefixedRlpHash(prefix byte, x interface{}) (h common.Hash) {
// This is internal, do not use. // This is internal, do not use.
type TrieHasher interface { type TrieHasher interface {
Reset() Reset()
Update([]byte, []byte) Update([]byte, []byte) error
Hash() common.Hash Hash() common.Hash
} }
@ -93,6 +93,9 @@ func DeriveSha(list DerivableList, hasher TrieHasher) common.Hash {
// StackTrie requires values to be inserted in increasing hash order, which is not the // StackTrie requires values to be inserted in increasing hash order, which is not the
// order that `list` provides hashes in. This insertion sequence ensures that the // order that `list` provides hashes in. This insertion sequence ensures that the
// order is correct. // order is correct.
//
// The error returned by hasher is omitted because hasher will produce an incorrect
// hash in case any error occurs.
var indexBuf []byte var indexBuf []byte
for i := 1; i < list.Len() && i <= 0x7f; i++ { for i := 1; i < list.Len() && i <= 0x7f; i++ {
indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i)) indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i))

View File

@ -219,9 +219,10 @@ func (d *hashToHumanReadable) Reset() {
d.data = make([]byte, 0) d.data = make([]byte, 0)
} }
func (d *hashToHumanReadable) Update(i []byte, i2 []byte) { func (d *hashToHumanReadable) Update(i []byte, i2 []byte) error {
l := fmt.Sprintf("%x %x\n", i, i2) l := fmt.Sprintf("%x %x\n", i, i2)
d.data = append(d.data, []byte(l)...) d.data = append(d.data, []byte(l)...)
return nil
} }
func (d *hashToHumanReadable) Hash() common.Hash { func (d *hashToHumanReadable) Hash() common.Hash {

View File

@ -216,7 +216,7 @@ func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash,
for _, pathset := range paths { for _, pathset := range paths {
switch len(pathset) { switch len(pathset) {
case 1: case 1:
blob, _, err := t.accountTrie.TryGetNode(pathset[0]) blob, _, err := t.accountTrie.GetNode(pathset[0])
if err != nil { if err != nil {
t.logger.Info("Error handling req", "error", err) t.logger.Info("Error handling req", "error", err)
break break
@ -225,7 +225,7 @@ func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash,
default: default:
account := t.storageTries[(common.BytesToHash(pathset[0]))] account := t.storageTries[(common.BytesToHash(pathset[0]))]
for _, path := range pathset[1:] { for _, path := range pathset[1:] {
blob, _, err := account.TryGetNode(path) blob, _, err := account.GetNode(path)
if err != nil { if err != nil {
t.logger.Info("Error handling req", "error", err) t.logger.Info("Error handling req", "error", err)
break break
@ -1381,7 +1381,7 @@ func makeAccountTrieNoStorage(n int) (string, *trie.Trie, entrySlice) {
}) })
key := key32(i) key := key32(i)
elem := &kv{key, value} elem := &kv{key, value}
accTrie.Update(elem.k, elem.v) accTrie.MustUpdate(elem.k, elem.v)
entries = append(entries, elem) entries = append(entries, elem)
} }
sort.Sort(entries) sort.Sort(entries)
@ -1431,7 +1431,7 @@ func makeBoundaryAccountTrie(n int) (string, *trie.Trie, entrySlice) {
CodeHash: getCodeHash(uint64(i)), CodeHash: getCodeHash(uint64(i)),
}) })
elem := &kv{boundaries[i].Bytes(), value} elem := &kv{boundaries[i].Bytes(), value}
accTrie.Update(elem.k, elem.v) accTrie.MustUpdate(elem.k, elem.v)
entries = append(entries, elem) entries = append(entries, elem)
} }
// Fill other accounts if required // Fill other accounts if required
@ -1443,7 +1443,7 @@ func makeBoundaryAccountTrie(n int) (string, *trie.Trie, entrySlice) {
CodeHash: getCodeHash(i), CodeHash: getCodeHash(i),
}) })
elem := &kv{key32(i), value} elem := &kv{key32(i), value}
accTrie.Update(elem.k, elem.v) accTrie.MustUpdate(elem.k, elem.v)
entries = append(entries, elem) entries = append(entries, elem)
} }
sort.Sort(entries) sort.Sort(entries)
@ -1487,7 +1487,7 @@ func makeAccountTrieWithStorageWithUniqueStorage(accounts, slots int, code bool)
CodeHash: codehash, CodeHash: codehash,
}) })
elem := &kv{key, value} elem := &kv{key, value}
accTrie.Update(elem.k, elem.v) accTrie.MustUpdate(elem.k, elem.v)
entries = append(entries, elem) entries = append(entries, elem)
storageRoots[common.BytesToHash(key)] = stRoot storageRoots[common.BytesToHash(key)] = stRoot
@ -1551,7 +1551,7 @@ func makeAccountTrieWithStorage(accounts, slots int, code, boundary bool) (strin
CodeHash: codehash, CodeHash: codehash,
}) })
elem := &kv{key, value} elem := &kv{key, value}
accTrie.Update(elem.k, elem.v) accTrie.MustUpdate(elem.k, elem.v)
entries = append(entries, elem) entries = append(entries, elem)
// we reuse the same one for all accounts // we reuse the same one for all accounts
@ -1599,7 +1599,7 @@ func makeStorageTrieWithSeed(owner common.Hash, n, seed uint64, db *trie.Databas
key := crypto.Keccak256Hash(slotKey[:]) key := crypto.Keccak256Hash(slotKey[:])
elem := &kv{key[:], rlpSlotValue} elem := &kv{key[:], rlpSlotValue}
trie.Update(elem.k, elem.v) trie.MustUpdate(elem.k, elem.v)
entries = append(entries, elem) entries = append(entries, elem)
} }
sort.Sort(entries) sort.Sort(entries)
@ -1638,7 +1638,7 @@ func makeBoundaryStorageTrie(owner common.Hash, n int, db *trie.Database) (commo
val := []byte{0xde, 0xad, 0xbe, 0xef} val := []byte{0xde, 0xad, 0xbe, 0xef}
elem := &kv{key[:], val} elem := &kv{key[:], val}
trie.Update(elem.k, elem.v) trie.MustUpdate(elem.k, elem.v)
entries = append(entries, elem) entries = append(entries, elem)
} }
// Fill other slots if required // Fill other slots if required
@ -1650,7 +1650,7 @@ func makeBoundaryStorageTrie(owner common.Hash, n int, db *trie.Database) (commo
rlpSlotValue, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(slotValue[:])) rlpSlotValue, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(slotValue[:]))
elem := &kv{key[:], rlpSlotValue} elem := &kv{key[:], rlpSlotValue}
trie.Update(elem.k, elem.v) trie.MustUpdate(elem.k, elem.v)
entries = append(entries, elem) entries = append(entries, elem)
} }
sort.Sort(entries) sort.Sort(entries)

View File

@ -364,7 +364,7 @@ func getAccount(triedb *trie.Database, root, hash common.Hash) (types.StateAccou
if err != nil { if err != nil {
return types.StateAccount{}, err return types.StateAccount{}, err
} }
blob, err := trie.TryGet(hash[:]) blob, err := trie.Get(hash[:])
if err != nil { if err != nil {
return types.StateAccount{}, err return types.StateAccount{}, err
} }

View File

@ -206,8 +206,7 @@ func (c *ChtIndexerBackend) Process(ctx context.Context, header *types.Header) e
var encNumber [8]byte var encNumber [8]byte
binary.BigEndian.PutUint64(encNumber[:], num) binary.BigEndian.PutUint64(encNumber[:], num)
data, _ := rlp.EncodeToBytes(ChtNode{hash, td}) data, _ := rlp.EncodeToBytes(ChtNode{hash, td})
c.trie.Update(encNumber[:], data) return c.trie.Update(encNumber[:], data)
return nil
} }
// Commit implements core.ChainIndexerBackend // Commit implements core.ChainIndexerBackend
@ -450,10 +449,15 @@ func (b *BloomTrieIndexerBackend) Commit() error {
decompSize += uint64(len(decomp)) decompSize += uint64(len(decomp))
compSize += uint64(len(comp)) compSize += uint64(len(comp))
var terr error
if len(comp) > 0 { if len(comp) > 0 {
b.trie.Update(encKey[:], comp) terr = b.trie.Update(encKey[:], comp)
} else { } else {
b.trie.Delete(encKey[:]) terr = b.trie.Delete(encKey[:])
}
if terr != nil {
return terr
} }
} }
root, nodes := b.trie.Commit(false) root, nodes := b.trie.Commit(false)

View File

@ -109,7 +109,7 @@ func (t *odrTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) {
key = crypto.Keccak256(key) key = crypto.Keccak256(key)
var res []byte var res []byte
err := t.do(key, func() (err error) { err := t.do(key, func() (err error) {
res, err = t.trie.TryGet(key) res, err = t.trie.Get(key)
return err return err
}) })
return res, err return res, err
@ -119,7 +119,7 @@ func (t *odrTrie) GetAccount(address common.Address) (*types.StateAccount, error
var res types.StateAccount var res types.StateAccount
key := crypto.Keccak256(address.Bytes()) key := crypto.Keccak256(address.Bytes())
err := t.do(key, func() (err error) { err := t.do(key, func() (err error) {
value, err := t.trie.TryGet(key) value, err := t.trie.Get(key)
if err != nil { if err != nil {
return err return err
} }
@ -138,21 +138,21 @@ func (t *odrTrie) UpdateAccount(address common.Address, acc *types.StateAccount)
return fmt.Errorf("decoding error in account update: %w", err) return fmt.Errorf("decoding error in account update: %w", err)
} }
return t.do(key, func() error { return t.do(key, func() error {
return t.trie.TryUpdate(key, value) return t.trie.Update(key, value)
}) })
} }
func (t *odrTrie) UpdateStorage(_ common.Address, key, value []byte) error { func (t *odrTrie) UpdateStorage(_ common.Address, key, value []byte) error {
key = crypto.Keccak256(key) key = crypto.Keccak256(key)
return t.do(key, func() error { return t.do(key, func() error {
return t.trie.TryUpdate(key, value) return t.trie.Update(key, value)
}) })
} }
func (t *odrTrie) DeleteStorage(_ common.Address, key []byte) error { func (t *odrTrie) DeleteStorage(_ common.Address, key []byte) error {
key = crypto.Keccak256(key) key = crypto.Keccak256(key)
return t.do(key, func() error { return t.do(key, func() error {
return t.trie.TryDelete(key) return t.trie.Delete(key)
}) })
} }
@ -160,7 +160,7 @@ func (t *odrTrie) DeleteStorage(_ common.Address, key []byte) error {
func (t *odrTrie) DeleteAccount(address common.Address) error { func (t *odrTrie) DeleteAccount(address common.Address) error {
key := crypto.Keccak256(address.Bytes()) key := crypto.Keccak256(address.Bytes())
return t.do(key, func() error { return t.do(key, func() error {
return t.trie.TryDelete(key) return t.trie.Delete(key)
}) })
} }

View File

@ -93,13 +93,13 @@ func makeTries() (chtTrie *trie.Trie, bloomTrie *trie.Trie, chtKeys, bloomKeys [
// The element in CHT is <big-endian block number> -> <block hash> // The element in CHT is <big-endian block number> -> <block hash>
key := make([]byte, 8) key := make([]byte, 8)
binary.BigEndian.PutUint64(key, uint64(i+1)) binary.BigEndian.PutUint64(key, uint64(i+1))
chtTrie.Update(key, []byte{0x1, 0xf}) chtTrie.MustUpdate(key, []byte{0x1, 0xf})
chtKeys = append(chtKeys, key) chtKeys = append(chtKeys, key)
// The element in Bloom trie is <2 byte bit index> + <big-endian block number> -> bloom // The element in Bloom trie is <2 byte bit index> + <big-endian block number> -> bloom
key2 := make([]byte, 10) key2 := make([]byte, 10)
binary.BigEndian.PutUint64(key2[2:], uint64(i+1)) binary.BigEndian.PutUint64(key2[2:], uint64(i+1))
bloomTrie.Update(key2, []byte{0x2, 0xe}) bloomTrie.MustUpdate(key2, []byte{0x2, 0xe})
bloomKeys = append(bloomKeys, key2) bloomKeys = append(bloomKeys, key2)
} }
return return

View File

@ -69,8 +69,8 @@ func (f *fuzzer) randomTrie(n int) (*trie.Trie, map[string]*kv) {
for i := byte(0); i < byte(size); i++ { for i := byte(0); i < byte(size); i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false} value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false}
trie.Update(value.k, value.v) trie.MustUpdate(value.k, value.v)
trie.Update(value2.k, value2.v) trie.MustUpdate(value2.k, value2.v)
vals[string(value.k)] = value vals[string(value.k)] = value
vals[string(value2.k)] = value2 vals[string(value2.k)] = value2
} }
@ -82,7 +82,7 @@ func (f *fuzzer) randomTrie(n int) (*trie.Trie, map[string]*kv) {
k := f.randBytes(32) k := f.randBytes(32)
v := f.randBytes(20) v := f.randBytes(20)
value := &kv{k, v, false} value := &kv{k, v, false}
trie.Update(k, v) trie.MustUpdate(k, v)
vals[string(k)] = value vals[string(k)] = value
if f.exhausted { if f.exhausted {
return nil, nil return nil, nil

View File

@ -175,7 +175,7 @@ func (f *fuzzer) fuzz() int {
} }
keys[string(k)] = struct{}{} keys[string(k)] = struct{}{}
vals = append(vals, kv{k: k, v: v}) vals = append(vals, kv{k: k, v: v})
trieA.Update(k, v) trieA.MustUpdate(k, v)
useful = true useful = true
} }
if !useful { if !useful {
@ -195,7 +195,7 @@ func (f *fuzzer) fuzz() int {
if f.debugging { if f.debugging {
fmt.Printf("{\"%#x\" , \"%#x\"} // stacktrie.Update\n", kv.k, kv.v) fmt.Printf("{\"%#x\" , \"%#x\"} // stacktrie.Update\n", kv.k, kv.v)
} }
trieB.Update(kv.k, kv.v) trieB.MustUpdate(kv.k, kv.v)
} }
rootB := trieB.Hash() rootB := trieB.Hash()
trieB.Commit() trieB.Commit()
@ -223,7 +223,7 @@ func (f *fuzzer) fuzz() int {
checked int checked int
) )
for _, kv := range vals { for _, kv := range vals {
trieC.Update(kv.k, kv.v) trieC.MustUpdate(kv.k, kv.v)
} }
rootC, _ := trieC.Commit() rootC, _ := trieC.Commit()
if rootA != rootC { if rootA != rootC {

View File

@ -147,13 +147,13 @@ func runRandTest(rt randTest) error {
for i, step := range rt { for i, step := range rt {
switch step.op { switch step.op {
case opUpdate: case opUpdate:
tr.Update(step.key, step.value) tr.MustUpdate(step.key, step.value)
values[string(step.key)] = string(step.value) values[string(step.key)] = string(step.value)
case opDelete: case opDelete:
tr.Delete(step.key) tr.MustDelete(step.key)
delete(values, string(step.key)) delete(values, string(step.key))
case opGet: case opGet:
v := tr.Get(step.key) v := tr.MustGet(step.key)
want := values[string(step.key)] want := values[string(step.key)]
if string(v) != want { if string(v) != want {
rt[i].err = fmt.Errorf("mismatch for key %#x, got %#x want %#x", step.key, v, want) rt[i].err = fmt.Errorf("mismatch for key %#x, got %#x want %#x", step.key, v, want)
@ -176,7 +176,7 @@ func runRandTest(rt randTest) error {
checktr := trie.NewEmpty(triedb) checktr := trie.NewEmpty(triedb)
it := trie.NewIterator(tr.NodeIterator(nil)) it := trie.NewIterator(tr.NodeIterator(nil))
for it.Next() { for it.Next() {
checktr.Update(it.Key, it.Value) checktr.MustUpdate(it.Key, it.Value)
} }
if tr.Hash() != checktr.Hash() { if tr.Hash() != checktr.Hash() {
return fmt.Errorf("hash mismatch in opItercheckhash") return fmt.Errorf("hash mismatch in opItercheckhash")

View File

@ -22,7 +22,7 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
) )
// MissingNodeError is returned by the trie functions (TryGet, TryUpdate, TryDelete) // MissingNodeError is returned by the trie functions (Get, Update, Delete)
// in the case where a trie node is not present in the local database. It contains // in the case where a trie node is not present in the local database. It contains
// information necessary for retrieving the missing node. // information necessary for retrieving the missing node.
type MissingNodeError struct { type MissingNodeError struct {

View File

@ -58,7 +58,7 @@ func TestIterator(t *testing.T) {
all := make(map[string]string) all := make(map[string]string)
for _, val := range vals { for _, val := range vals {
all[val.k] = val.v all[val.k] = val.v
trie.Update([]byte(val.k), []byte(val.v)) trie.MustUpdate([]byte(val.k), []byte(val.v))
} }
root, nodes := trie.Commit(false) root, nodes := trie.Commit(false)
db.Update(NewWithNodeSet(nodes)) db.Update(NewWithNodeSet(nodes))
@ -89,8 +89,8 @@ func TestIteratorLargeData(t *testing.T) {
for i := byte(0); i < 255; i++ { for i := byte(0); i < 255; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false} value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false}
trie.Update(value.k, value.v) trie.MustUpdate(value.k, value.v)
trie.Update(value2.k, value2.v) trie.MustUpdate(value2.k, value2.v)
vals[string(value.k)] = value vals[string(value.k)] = value
vals[string(value2.k)] = value2 vals[string(value2.k)] = value2
} }
@ -178,7 +178,7 @@ var testdata2 = []kvs{
func TestIteratorSeek(t *testing.T) { func TestIteratorSeek(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
for _, val := range testdata1 { for _, val := range testdata1 {
trie.Update([]byte(val.k), []byte(val.v)) trie.MustUpdate([]byte(val.k), []byte(val.v))
} }
// Seek to the middle. // Seek to the middle.
@ -220,7 +220,7 @@ func TestDifferenceIterator(t *testing.T) {
dba := NewDatabase(rawdb.NewMemoryDatabase()) dba := NewDatabase(rawdb.NewMemoryDatabase())
triea := NewEmpty(dba) triea := NewEmpty(dba)
for _, val := range testdata1 { for _, val := range testdata1 {
triea.Update([]byte(val.k), []byte(val.v)) triea.MustUpdate([]byte(val.k), []byte(val.v))
} }
rootA, nodesA := triea.Commit(false) rootA, nodesA := triea.Commit(false)
dba.Update(NewWithNodeSet(nodesA)) dba.Update(NewWithNodeSet(nodesA))
@ -229,7 +229,7 @@ func TestDifferenceIterator(t *testing.T) {
dbb := NewDatabase(rawdb.NewMemoryDatabase()) dbb := NewDatabase(rawdb.NewMemoryDatabase())
trieb := NewEmpty(dbb) trieb := NewEmpty(dbb)
for _, val := range testdata2 { for _, val := range testdata2 {
trieb.Update([]byte(val.k), []byte(val.v)) trieb.MustUpdate([]byte(val.k), []byte(val.v))
} }
rootB, nodesB := trieb.Commit(false) rootB, nodesB := trieb.Commit(false)
dbb.Update(NewWithNodeSet(nodesB)) dbb.Update(NewWithNodeSet(nodesB))
@ -262,7 +262,7 @@ func TestUnionIterator(t *testing.T) {
dba := NewDatabase(rawdb.NewMemoryDatabase()) dba := NewDatabase(rawdb.NewMemoryDatabase())
triea := NewEmpty(dba) triea := NewEmpty(dba)
for _, val := range testdata1 { for _, val := range testdata1 {
triea.Update([]byte(val.k), []byte(val.v)) triea.MustUpdate([]byte(val.k), []byte(val.v))
} }
rootA, nodesA := triea.Commit(false) rootA, nodesA := triea.Commit(false)
dba.Update(NewWithNodeSet(nodesA)) dba.Update(NewWithNodeSet(nodesA))
@ -271,7 +271,7 @@ func TestUnionIterator(t *testing.T) {
dbb := NewDatabase(rawdb.NewMemoryDatabase()) dbb := NewDatabase(rawdb.NewMemoryDatabase())
trieb := NewEmpty(dbb) trieb := NewEmpty(dbb)
for _, val := range testdata2 { for _, val := range testdata2 {
trieb.Update([]byte(val.k), []byte(val.v)) trieb.MustUpdate([]byte(val.k), []byte(val.v))
} }
rootB, nodesB := trieb.Commit(false) rootB, nodesB := trieb.Commit(false)
dbb.Update(NewWithNodeSet(nodesB)) dbb.Update(NewWithNodeSet(nodesB))
@ -314,7 +314,7 @@ func TestUnionIterator(t *testing.T) {
func TestIteratorNoDups(t *testing.T) { func TestIteratorNoDups(t *testing.T) {
tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
for _, val := range testdata1 { for _, val := range testdata1 {
tr.Update([]byte(val.k), []byte(val.v)) tr.MustUpdate([]byte(val.k), []byte(val.v))
} }
checkIteratorNoDups(t, tr.NodeIterator(nil), nil) checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
} }
@ -329,7 +329,7 @@ func testIteratorContinueAfterError(t *testing.T, memonly bool) {
tr := NewEmpty(triedb) tr := NewEmpty(triedb)
for _, val := range testdata1 { for _, val := range testdata1 {
tr.Update([]byte(val.k), []byte(val.v)) tr.MustUpdate([]byte(val.k), []byte(val.v))
} }
_, nodes := tr.Commit(false) _, nodes := tr.Commit(false)
triedb.Update(NewWithNodeSet(nodes)) triedb.Update(NewWithNodeSet(nodes))
@ -421,7 +421,7 @@ func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
ctr := NewEmpty(triedb) ctr := NewEmpty(triedb)
for _, val := range testdata1 { for _, val := range testdata1 {
ctr.Update([]byte(val.k), []byte(val.v)) ctr.MustUpdate([]byte(val.k), []byte(val.v))
} }
root, nodes := ctr.Commit(false) root, nodes := ctr.Commit(false)
triedb.Update(NewWithNodeSet(nodes)) triedb.Update(NewWithNodeSet(nodes))
@ -540,7 +540,7 @@ func makeLargeTestTrie() (*Database, *StateTrie, *loggingDb) {
binary.BigEndian.PutUint64(val, uint64(i)) binary.BigEndian.PutUint64(val, uint64(i))
key = crypto.Keccak256(key) key = crypto.Keccak256(key)
val = crypto.Keccak256(val) val = crypto.Keccak256(val)
trie.Update(key, val) trie.MustUpdate(key, val)
} }
_, nodes := trie.Commit(false) _, nodes := trie.Commit(false)
triedb.Update(NewWithNodeSet(nodes)) triedb.Update(NewWithNodeSet(nodes))
@ -580,7 +580,7 @@ func TestIteratorNodeBlob(t *testing.T) {
all := make(map[string]string) all := make(map[string]string)
for _, val := range vals { for _, val := range vals {
all[val.k] = val.v all[val.k] = val.v
trie.Update([]byte(val.k), []byte(val.v)) trie.MustUpdate([]byte(val.k), []byte(val.v))
} }
_, nodes := trie.Commit(false) _, nodes := trie.Commit(false)
triedb.Update(NewWithNodeSet(nodes)) triedb.Update(NewWithNodeSet(nodes))

View File

@ -498,7 +498,7 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key
if proof == nil { if proof == nil {
tr := NewStackTrie(nil) tr := NewStackTrie(nil)
for index, key := range keys { for index, key := range keys {
tr.TryUpdate(key, values[index]) tr.Update(key, values[index])
} }
if have, want := tr.Hash(), rootHash; have != want { if have, want := tr.Hash(), rootHash; have != want {
return false, fmt.Errorf("invalid proof, want hash %x, got %x", want, have) return false, fmt.Errorf("invalid proof, want hash %x, got %x", want, have)
@ -568,7 +568,7 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key
tr.root = nil tr.root = nil
} }
for index, key := range keys { for index, key := range keys {
tr.TryUpdate(key, values[index]) tr.Update(key, values[index])
} }
if tr.Hash() != rootHash { if tr.Hash() != rootHash {
return false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash()) return false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash())

View File

@ -403,7 +403,7 @@ func TestOneElementRangeProof(t *testing.T) {
// Test the mini trie with only a single element. // Test the mini trie with only a single element.
tinyTrie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) tinyTrie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
entry := &kv{randBytes(32), randBytes(20), false} entry := &kv{randBytes(32), randBytes(20), false}
tinyTrie.Update(entry.k, entry.v) tinyTrie.MustUpdate(entry.k, entry.v)
first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes() first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
last = entry.k last = entry.k
@ -477,7 +477,7 @@ func TestSingleSideRangeProof(t *testing.T) {
var entries entrySlice var entries entrySlice
for i := 0; i < 4096; i++ { for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false} value := &kv{randBytes(32), randBytes(20), false}
trie.Update(value.k, value.v) trie.MustUpdate(value.k, value.v)
entries = append(entries, value) entries = append(entries, value)
} }
sort.Sort(entries) sort.Sort(entries)
@ -512,7 +512,7 @@ func TestReverseSingleSideRangeProof(t *testing.T) {
var entries entrySlice var entries entrySlice
for i := 0; i < 4096; i++ { for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false} value := &kv{randBytes(32), randBytes(20), false}
trie.Update(value.k, value.v) trie.MustUpdate(value.k, value.v)
entries = append(entries, value) entries = append(entries, value)
} }
sort.Sort(entries) sort.Sort(entries)
@ -619,7 +619,7 @@ func TestGappedRangeProof(t *testing.T) {
var entries []*kv // Sorted entries var entries []*kv // Sorted entries
for i := byte(0); i < 10; i++ { for i := byte(0); i < 10; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
trie.Update(value.k, value.v) trie.MustUpdate(value.k, value.v)
entries = append(entries, value) entries = append(entries, value)
} }
first, last := 2, 8 first, last := 2, 8
@ -693,7 +693,7 @@ func TestHasRightElement(t *testing.T) {
var entries entrySlice var entries entrySlice
for i := 0; i < 4096; i++ { for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false} value := &kv{randBytes(32), randBytes(20), false}
trie.Update(value.k, value.v) trie.MustUpdate(value.k, value.v)
entries = append(entries, value) entries = append(entries, value)
} }
sort.Sort(entries) sort.Sort(entries)
@ -1047,14 +1047,14 @@ func randomTrie(n int) (*Trie, map[string]*kv) {
for i := byte(0); i < 100; i++ { for i := byte(0); i < 100; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false} value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false}
trie.Update(value.k, value.v) trie.MustUpdate(value.k, value.v)
trie.Update(value2.k, value2.v) trie.MustUpdate(value2.k, value2.v)
vals[string(value.k)] = value vals[string(value.k)] = value
vals[string(value2.k)] = value2 vals[string(value2.k)] = value2
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
value := &kv{randBytes(32), randBytes(20), false} value := &kv{randBytes(32), randBytes(20), false}
trie.Update(value.k, value.v) trie.MustUpdate(value.k, value.v)
vals[string(value.k)] = value vals[string(value.k)] = value
} }
return trie, vals return trie, vals
@ -1071,7 +1071,7 @@ func nonRandomTrie(n int) (*Trie, map[string]*kv) {
binary.LittleEndian.PutUint64(value, i-max) binary.LittleEndian.PutUint64(value, i-max)
//value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} //value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
elem := &kv{key, value, false} elem := &kv{key, value, false}
trie.Update(elem.k, elem.v) trie.MustUpdate(elem.k, elem.v)
vals[string(elem.k)] = elem vals[string(elem.k)] = elem
} }
return trie, vals return trie, vals
@ -1088,7 +1088,7 @@ func TestRangeProofKeysWithSharedPrefix(t *testing.T) {
} }
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
for i, key := range keys { for i, key := range keys {
trie.Update(key, vals[i]) trie.MustUpdate(key, vals[i])
} }
root := trie.Hash() root := trie.Hash()
proof := memorydb.New() proof := memorydb.New()

View File

@ -19,7 +19,6 @@ package trie
import ( import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
@ -72,14 +71,13 @@ func NewStateTrie(id *ID, db *Database) (*StateTrie, error) {
return &StateTrie{trie: *trie, preimages: db.preimages}, nil return &StateTrie{trie: *trie, preimages: db.preimages}, nil
} }
// Get returns the value for key stored in the trie. // MustGet returns the value for key stored in the trie.
// The value bytes must not be modified by the caller. // The value bytes must not be modified by the caller.
func (t *StateTrie) Get(key []byte) []byte { //
res, err := t.GetStorage(common.Address{}, key) // This function will omit any encountered error but just
if err != nil { // print out an error message.
log.Error("Unhandled trie error in StateTrie.Get", "err", err) func (t *StateTrie) MustGet(key []byte) []byte {
} return t.trie.MustGet(t.hashKey(key))
return res
} }
// GetStorage attempts to retrieve a storage slot with provided account address // GetStorage attempts to retrieve a storage slot with provided account address
@ -87,14 +85,14 @@ func (t *StateTrie) Get(key []byte) []byte {
// If the specified storage slot is not in the trie, nil will be returned. // If the specified storage slot is not in the trie, nil will be returned.
// If a trie node is not found in the database, a MissingNodeError is returned. // If a trie node is not found in the database, a MissingNodeError is returned.
func (t *StateTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { func (t *StateTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) {
return t.trie.TryGet(t.hashKey(key)) return t.trie.Get(t.hashKey(key))
} }
// GetAccount attempts to retrieve an account with provided account address. // GetAccount attempts to retrieve an account with provided account address.
// If the specified account is not in the trie, nil will be returned. // If the specified account is not in the trie, nil will be returned.
// If a trie node is not found in the database, a MissingNodeError is returned. // If a trie node is not found in the database, a MissingNodeError is returned.
func (t *StateTrie) GetAccount(address common.Address) (*types.StateAccount, error) { func (t *StateTrie) GetAccount(address common.Address) (*types.StateAccount, error) {
res, err := t.trie.TryGet(t.hashKey(address.Bytes())) res, err := t.trie.Get(t.hashKey(address.Bytes()))
if res == nil || err != nil { if res == nil || err != nil {
return nil, err return nil, err
} }
@ -107,7 +105,7 @@ func (t *StateTrie) GetAccount(address common.Address) (*types.StateAccount, err
// account hash that is the hash of address. This constitutes an abstraction // account hash that is the hash of address. This constitutes an abstraction
// leak, since the client code needs to know the key format. // leak, since the client code needs to know the key format.
func (t *StateTrie) GetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) { func (t *StateTrie) GetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) {
res, err := t.trie.TryGet(addrHash.Bytes()) res, err := t.trie.Get(addrHash.Bytes())
if res == nil || err != nil { if res == nil || err != nil {
return nil, err return nil, err
} }
@ -121,19 +119,22 @@ func (t *StateTrie) GetAccountByHash(addrHash common.Hash) (*types.StateAccount,
// If the specified trie node is not in the trie, nil will be returned. // If the specified trie node is not in the trie, nil will be returned.
// If a trie node is not found in the database, a MissingNodeError is returned. // If a trie node is not found in the database, a MissingNodeError is returned.
func (t *StateTrie) GetNode(path []byte) ([]byte, int, error) { func (t *StateTrie) GetNode(path []byte) ([]byte, int, error) {
return t.trie.TryGetNode(path) return t.trie.GetNode(path)
} }
// Update associates key with value in the trie. Subsequent calls to // MustUpdate associates key with value in the trie. Subsequent calls to
// Get will return value. If value has length zero, any existing value // Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil. // is deleted from the trie and calls to Get will return nil.
// //
// The value bytes must not be modified by the caller while they are // The value bytes must not be modified by the caller while they are
// stored in the trie. // stored in the trie.
func (t *StateTrie) Update(key, value []byte) { //
if err := t.UpdateStorage(common.Address{}, key, value); err != nil { // This function will omit any encountered error but just print out an
log.Error("Unhandled trie error in StateTrie.Update", "err", err) // error message.
} func (t *StateTrie) MustUpdate(key, value []byte) {
hk := t.hashKey(key)
t.trie.MustUpdate(hk, value)
t.getSecKeyCache()[string(hk)] = common.CopyBytes(key)
} }
// UpdateStorage associates key with value in the trie. Subsequent calls to // UpdateStorage associates key with value in the trie. Subsequent calls to
@ -146,7 +147,7 @@ func (t *StateTrie) Update(key, value []byte) {
// If a node is not found in the database, a MissingNodeError is returned. // If a node is not found in the database, a MissingNodeError is returned.
func (t *StateTrie) UpdateStorage(_ common.Address, key, value []byte) error { func (t *StateTrie) UpdateStorage(_ common.Address, key, value []byte) error {
hk := t.hashKey(key) hk := t.hashKey(key)
err := t.trie.TryUpdate(hk, value) err := t.trie.Update(hk, value)
if err != nil { if err != nil {
return err return err
} }
@ -161,18 +162,19 @@ func (t *StateTrie) UpdateAccount(address common.Address, acc *types.StateAccoun
if err != nil { if err != nil {
return err return err
} }
if err := t.trie.TryUpdate(hk, data); err != nil { if err := t.trie.Update(hk, data); err != nil {
return err return err
} }
t.getSecKeyCache()[string(hk)] = address.Bytes() t.getSecKeyCache()[string(hk)] = address.Bytes()
return nil return nil
} }
// Delete removes any existing value for key from the trie. // MustDelete removes any existing value for key from the trie. This function
func (t *StateTrie) Delete(key []byte) { // will omit any encountered error but just print out an error message.
if err := t.DeleteStorage(common.Address{}, key); err != nil { func (t *StateTrie) MustDelete(key []byte) {
log.Error("Unhandled trie error in StateTrie.Delete", "err", err) hk := t.hashKey(key)
} delete(t.getSecKeyCache(), string(hk))
t.trie.MustDelete(hk)
} }
// DeleteStorage removes any existing storage slot from the trie. // DeleteStorage removes any existing storage slot from the trie.
@ -181,14 +183,14 @@ func (t *StateTrie) Delete(key []byte) {
func (t *StateTrie) DeleteStorage(_ common.Address, key []byte) error { func (t *StateTrie) DeleteStorage(_ common.Address, key []byte) error {
hk := t.hashKey(key) hk := t.hashKey(key)
delete(t.getSecKeyCache(), string(hk)) delete(t.getSecKeyCache(), string(hk))
return t.trie.TryDelete(hk) return t.trie.Delete(hk)
} }
// DeleteAccount abstracts an account deletion from the trie. // DeleteAccount abstracts an account deletion from the trie.
func (t *StateTrie) DeleteAccount(address common.Address) error { func (t *StateTrie) DeleteAccount(address common.Address) error {
hk := t.hashKey(address.Bytes()) hk := t.hashKey(address.Bytes())
delete(t.getSecKeyCache(), string(hk)) delete(t.getSecKeyCache(), string(hk))
return t.trie.TryDelete(hk) return t.trie.Delete(hk)
} }
// GetKey returns the sha3 preimage of a hashed key that was // GetKey returns the sha3 preimage of a hashed key that was

View File

@ -45,17 +45,17 @@ func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) {
// Map the same data under multiple keys // Map the same data under multiple keys
key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i} key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i}
content[string(key)] = val content[string(key)] = val
trie.Update(key, val) trie.MustUpdate(key, val)
key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i} key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i}
content[string(key)] = val content[string(key)] = val
trie.Update(key, val) trie.MustUpdate(key, val)
// Add some other data to inflate the trie // Add some other data to inflate the trie
for j := byte(3); j < 13; j++ { for j := byte(3); j < 13; j++ {
key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i} key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i}
content[string(key)] = val content[string(key)] = val
trie.Update(key, val) trie.MustUpdate(key, val)
} }
} }
root, nodes := trie.Commit(false) root, nodes := trie.Commit(false)
@ -81,9 +81,9 @@ func TestSecureDelete(t *testing.T) {
} }
for _, val := range vals { for _, val := range vals {
if val.v != "" { if val.v != "" {
trie.Update([]byte(val.k), []byte(val.v)) trie.MustUpdate([]byte(val.k), []byte(val.v))
} else { } else {
trie.Delete([]byte(val.k)) trie.MustDelete([]byte(val.k))
} }
} }
hash := trie.Hash() hash := trie.Hash()
@ -95,13 +95,13 @@ func TestSecureDelete(t *testing.T) {
func TestSecureGetKey(t *testing.T) { func TestSecureGetKey(t *testing.T) {
trie := newEmptySecure() trie := newEmptySecure()
trie.Update([]byte("foo"), []byte("bar")) trie.MustUpdate([]byte("foo"), []byte("bar"))
key := []byte("foo") key := []byte("foo")
value := []byte("bar") value := []byte("bar")
seckey := crypto.Keccak256(key) seckey := crypto.Keccak256(key)
if !bytes.Equal(trie.Get(key), value) { if !bytes.Equal(trie.MustGet(key), value) {
t.Errorf("Get did not return bar") t.Errorf("Get did not return bar")
} }
if k := trie.GetKey(seckey); !bytes.Equal(k, key) { if k := trie.GetKey(seckey); !bytes.Equal(k, key) {
@ -128,15 +128,15 @@ func TestStateTrieConcurrency(t *testing.T) {
for j := byte(0); j < 255; j++ { for j := byte(0); j < 255; j++ {
// Map the same data under multiple keys // Map the same data under multiple keys
key, val := common.LeftPadBytes([]byte{byte(index), 1, j}, 32), []byte{j} key, val := common.LeftPadBytes([]byte{byte(index), 1, j}, 32), []byte{j}
tries[index].Update(key, val) tries[index].MustUpdate(key, val)
key, val = common.LeftPadBytes([]byte{byte(index), 2, j}, 32), []byte{j} key, val = common.LeftPadBytes([]byte{byte(index), 2, j}, 32), []byte{j}
tries[index].Update(key, val) tries[index].MustUpdate(key, val)
// Add some other data to inflate the trie // Add some other data to inflate the trie
for k := byte(3); k < 13; k++ { for k := byte(3); k < 13; k++ {
key, val = common.LeftPadBytes([]byte{byte(index), k, j}, 32), []byte{k, j} key, val = common.LeftPadBytes([]byte{byte(index), k, j}, 32), []byte{k, j}
tries[index].Update(key, val) tries[index].MustUpdate(key, val)
} }
} }
tries[index].Commit(false) tries[index].Commit(false)

View File

@ -202,8 +202,8 @@ const (
hashedNode hashedNode
) )
// TryUpdate inserts a (key, value) pair into the stack trie // Update inserts a (key, value) pair into the stack trie.
func (st *StackTrie) TryUpdate(key, value []byte) error { func (st *StackTrie) Update(key, value []byte) error {
k := keybytesToHex(key) k := keybytesToHex(key)
if len(value) == 0 { if len(value) == 0 {
panic("deletion not supported") panic("deletion not supported")
@ -212,8 +212,10 @@ func (st *StackTrie) TryUpdate(key, value []byte) error {
return nil return nil
} }
func (st *StackTrie) Update(key, value []byte) { // MustUpdate is a wrapper of Update and will omit any encountered error but
if err := st.TryUpdate(key, value); err != nil { // just print out an error message.
func (st *StackTrie) MustUpdate(key, value []byte) {
if err := st.Update(key, value); err != nil {
log.Error("Unhandled trie error in StackTrie.Update", "err", err) log.Error("Unhandled trie error in StackTrie.Update", "err", err)
} }
} }

View File

@ -174,7 +174,7 @@ func TestStackTrieInsertAndHash(t *testing.T) {
st.Reset() st.Reset()
for j := 0; j < l; j++ { for j := 0; j < l; j++ {
kv := &test[j] kv := &test[j]
if err := st.TryUpdate(common.FromHex(kv.K), []byte(kv.V)); err != nil { if err := st.Update(common.FromHex(kv.K), []byte(kv.V)); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
@ -193,8 +193,8 @@ func TestSizeBug(t *testing.T) {
leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563") leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563")
value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3") value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3")
nt.TryUpdate(leaf, value) nt.Update(leaf, value)
st.TryUpdate(leaf, value) st.Update(leaf, value)
if nt.Hash() != st.Hash() { if nt.Hash() != st.Hash() {
t.Fatalf("error %x != %x", st.Hash(), nt.Hash()) t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
@ -218,8 +218,8 @@ func TestEmptyBug(t *testing.T) {
} }
for _, kv := range kvs { for _, kv := range kvs {
nt.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) nt.Update(common.FromHex(kv.K), common.FromHex(kv.V))
st.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) st.Update(common.FromHex(kv.K), common.FromHex(kv.V))
} }
if nt.Hash() != st.Hash() { if nt.Hash() != st.Hash() {
@ -241,8 +241,8 @@ func TestValLength56(t *testing.T) {
} }
for _, kv := range kvs { for _, kv := range kvs {
nt.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) nt.Update(common.FromHex(kv.K), common.FromHex(kv.V))
st.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) st.Update(common.FromHex(kv.K), common.FromHex(kv.V))
} }
if nt.Hash() != st.Hash() { if nt.Hash() != st.Hash() {
@ -263,8 +263,8 @@ func TestUpdateSmallNodes(t *testing.T) {
{"65", "3000"}, // stacktrie.Update {"65", "3000"}, // stacktrie.Update
} }
for _, kv := range kvs { for _, kv := range kvs {
nt.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) nt.Update(common.FromHex(kv.K), common.FromHex(kv.V))
st.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) st.Update(common.FromHex(kv.K), common.FromHex(kv.V))
} }
if nt.Hash() != st.Hash() { if nt.Hash() != st.Hash() {
t.Fatalf("error %x != %x", st.Hash(), nt.Hash()) t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
@ -291,8 +291,8 @@ func TestUpdateVariableKeys(t *testing.T) {
{"0x3330353463653239356131303167617430", "313131"}, {"0x3330353463653239356131303167617430", "313131"},
} }
for _, kv := range kvs { for _, kv := range kvs {
nt.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) nt.Update(common.FromHex(kv.K), common.FromHex(kv.V))
st.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) st.Update(common.FromHex(kv.K), common.FromHex(kv.V))
} }
if nt.Hash() != st.Hash() { if nt.Hash() != st.Hash() {
t.Fatalf("error %x != %x", st.Hash(), nt.Hash()) t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
@ -309,7 +309,7 @@ func TestStacktrieNotModifyValues(t *testing.T) {
value := make([]byte, 1, 100) value := make([]byte, 1, 100)
value[0] = 0x2 value[0] = 0x2
want := common.CopyBytes(value) want := common.CopyBytes(value)
st.TryUpdate([]byte{0x01}, value) st.Update([]byte{0x01}, value)
st.Hash() st.Hash()
if have := value; !bytes.Equal(have, want) { if have := value; !bytes.Equal(have, want) {
t.Fatalf("tiny trie: have %#x want %#x", have, want) t.Fatalf("tiny trie: have %#x want %#x", have, want)
@ -330,7 +330,7 @@ func TestStacktrieNotModifyValues(t *testing.T) {
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
key := common.BigToHash(keyB) key := common.BigToHash(keyB)
value := getValue(i) value := getValue(i)
st.TryUpdate(key.Bytes(), value) st.Update(key.Bytes(), value)
vals = append(vals, value) vals = append(vals, value)
keyB = keyB.Add(keyB, keyDelta) keyB = keyB.Add(keyB, keyDelta)
keyDelta.Add(keyDelta, common.Big1) keyDelta.Add(keyDelta, common.Big1)
@ -371,7 +371,7 @@ func TestStacktrieSerialization(t *testing.T) {
keyDelta.Add(keyDelta, common.Big1) keyDelta.Add(keyDelta, common.Big1)
} }
for i, k := range keys { for i, k := range keys {
nt.TryUpdate(k, common.CopyBytes(vals[i])) nt.Update(k, common.CopyBytes(vals[i]))
} }
for i, k := range keys { for i, k := range keys {
@ -384,7 +384,7 @@ func TestStacktrieSerialization(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
st = newSt st = newSt
st.TryUpdate(k, common.CopyBytes(vals[i])) st.Update(k, common.CopyBytes(vals[i]))
} }
if have, want := st.Hash(), nt.Hash(); have != want { if have, want := st.Hash(), nt.Hash(); have != want {
t.Fatalf("have %#x want %#x", have, want) t.Fatalf("have %#x want %#x", have, want)

View File

@ -40,17 +40,17 @@ func makeTestTrie() (*Database, *StateTrie, map[string][]byte) {
// Map the same data under multiple keys // Map the same data under multiple keys
key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i} key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i}
content[string(key)] = val content[string(key)] = val
trie.Update(key, val) trie.MustUpdate(key, val)
key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i} key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i}
content[string(key)] = val content[string(key)] = val
trie.Update(key, val) trie.MustUpdate(key, val)
// Add some other data to inflate the trie // Add some other data to inflate the trie
for j := byte(3); j < 13; j++ { for j := byte(3); j < 13; j++ {
key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i} key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i}
content[string(key)] = val content[string(key)] = val
trie.Update(key, val) trie.MustUpdate(key, val)
} }
} }
root, nodes := trie.Commit(false) root, nodes := trie.Commit(false)
@ -74,7 +74,7 @@ func checkTrieContents(t *testing.T, db *Database, root []byte, content map[stri
t.Fatalf("inconsistent trie at %x: %v", root, err) t.Fatalf("inconsistent trie at %x: %v", root, err)
} }
for key, val := range content { for key, val := range content {
if have := trie.Get([]byte(key)); !bytes.Equal(have, val) { if have := trie.MustGet([]byte(key)); !bytes.Equal(have, val) {
t.Errorf("entry %x: content mismatch: have %x, want %x", key, have, val) t.Errorf("entry %x: content mismatch: have %x, want %x", key, have, val)
} }
} }

View File

@ -64,7 +64,7 @@ func testTrieTracer(t *testing.T, vals []struct{ k, v string }) {
// Determine all new nodes are tracked // Determine all new nodes are tracked
for _, val := range vals { for _, val := range vals {
trie.Update([]byte(val.k), []byte(val.v)) trie.MustUpdate([]byte(val.k), []byte(val.v))
} }
insertSet := copySet(trie.tracer.inserts) // copy before commit insertSet := copySet(trie.tracer.inserts) // copy before commit
deleteSet := copySet(trie.tracer.deletes) // copy before commit deleteSet := copySet(trie.tracer.deletes) // copy before commit
@ -82,7 +82,7 @@ func testTrieTracer(t *testing.T, vals []struct{ k, v string }) {
// Determine all deletions are tracked // Determine all deletions are tracked
trie, _ = New(TrieID(root), db) trie, _ = New(TrieID(root), db)
for _, val := range vals { for _, val := range vals {
trie.Delete([]byte(val.k)) trie.MustDelete([]byte(val.k))
} }
insertSet, deleteSet = copySet(trie.tracer.inserts), copySet(trie.tracer.deletes) insertSet, deleteSet = copySet(trie.tracer.inserts), copySet(trie.tracer.deletes)
if !compareSet(insertSet, nil) { if !compareSet(insertSet, nil) {
@ -104,10 +104,10 @@ func TestTrieTracerNoop(t *testing.T) {
func testTrieTracerNoop(t *testing.T, vals []struct{ k, v string }) { func testTrieTracerNoop(t *testing.T, vals []struct{ k, v string }) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
for _, val := range vals { for _, val := range vals {
trie.Update([]byte(val.k), []byte(val.v)) trie.MustUpdate([]byte(val.k), []byte(val.v))
} }
for _, val := range vals { for _, val := range vals {
trie.Delete([]byte(val.k)) trie.MustDelete([]byte(val.k))
} }
if len(trie.tracer.inserts) != 0 { if len(trie.tracer.inserts) != 0 {
t.Fatal("Unexpected insertion set") t.Fatal("Unexpected insertion set")
@ -132,7 +132,7 @@ func testAccessList(t *testing.T, vals []struct{ k, v string }) {
) )
// Create trie from scratch // Create trie from scratch
for _, val := range vals { for _, val := range vals {
trie.Update([]byte(val.k), []byte(val.v)) trie.MustUpdate([]byte(val.k), []byte(val.v))
} }
root, nodes := trie.Commit(false) root, nodes := trie.Commit(false)
db.Update(NewWithNodeSet(nodes)) db.Update(NewWithNodeSet(nodes))
@ -146,7 +146,7 @@ func testAccessList(t *testing.T, vals []struct{ k, v string }) {
trie, _ = New(TrieID(root), db) trie, _ = New(TrieID(root), db)
orig = trie.Copy() orig = trie.Copy()
for _, val := range vals { for _, val := range vals {
trie.Update([]byte(val.k), randBytes(32)) trie.MustUpdate([]byte(val.k), randBytes(32))
} }
root, nodes = trie.Commit(false) root, nodes = trie.Commit(false)
db.Update(NewWithNodeSet(nodes)) db.Update(NewWithNodeSet(nodes))
@ -163,7 +163,7 @@ func testAccessList(t *testing.T, vals []struct{ k, v string }) {
for i := 0; i < 30; i++ { for i := 0; i < 30; i++ {
key := randBytes(32) key := randBytes(32)
keys = append(keys, string(key)) keys = append(keys, string(key))
trie.Update(key, randBytes(32)) trie.MustUpdate(key, randBytes(32))
} }
root, nodes = trie.Commit(false) root, nodes = trie.Commit(false)
db.Update(NewWithNodeSet(nodes)) db.Update(NewWithNodeSet(nodes))
@ -177,7 +177,7 @@ func testAccessList(t *testing.T, vals []struct{ k, v string }) {
trie, _ = New(TrieID(root), db) trie, _ = New(TrieID(root), db)
orig = trie.Copy() orig = trie.Copy()
for _, key := range keys { for _, key := range keys {
trie.Update([]byte(key), nil) trie.MustUpdate([]byte(key), nil)
} }
root, nodes = trie.Commit(false) root, nodes = trie.Commit(false)
db.Update(NewWithNodeSet(nodes)) db.Update(NewWithNodeSet(nodes))
@ -191,7 +191,7 @@ func testAccessList(t *testing.T, vals []struct{ k, v string }) {
trie, _ = New(TrieID(root), db) trie, _ = New(TrieID(root), db)
orig = trie.Copy() orig = trie.Copy()
for _, val := range vals { for _, val := range vals {
trie.Update([]byte(val.k), nil) trie.MustUpdate([]byte(val.k), nil)
} }
root, nodes = trie.Commit(false) root, nodes = trie.Commit(false)
db.Update(NewWithNodeSet(nodes)) db.Update(NewWithNodeSet(nodes))
@ -210,7 +210,7 @@ func TestAccessListLeak(t *testing.T) {
) )
// Create trie from scratch // Create trie from scratch
for _, val := range standard { for _, val := range standard {
trie.Update([]byte(val.k), []byte(val.v)) trie.MustUpdate([]byte(val.k), []byte(val.v))
} }
root, nodes := trie.Commit(false) root, nodes := trie.Commit(false)
db.Update(NewWithNodeSet(nodes)) db.Update(NewWithNodeSet(nodes))
@ -260,7 +260,7 @@ func TestTinyTree(t *testing.T) {
trie = NewEmpty(db) trie = NewEmpty(db)
) )
for _, val := range tiny { for _, val := range tiny {
trie.Update([]byte(val.k), randBytes(32)) trie.MustUpdate([]byte(val.k), randBytes(32))
} }
root, set := trie.Commit(false) root, set := trie.Commit(false)
db.Update(NewWithNodeSet(set)) db.Update(NewWithNodeSet(set))
@ -268,7 +268,7 @@ func TestTinyTree(t *testing.T) {
trie, _ = New(TrieID(root), db) trie, _ = New(TrieID(root), db)
orig := trie.Copy() orig := trie.Copy()
for _, val := range tiny { for _, val := range tiny {
trie.Update([]byte(val.k), []byte(val.v)) trie.MustUpdate([]byte(val.k), []byte(val.v))
} }
root, set = trie.Commit(false) root, set = trie.Commit(false)
db.Update(NewWithNodeSet(set)) db.Update(NewWithNodeSet(set))

View File

@ -105,28 +105,30 @@ func (t *Trie) NodeIterator(start []byte) NodeIterator {
return newNodeIterator(t, start) return newNodeIterator(t, start)
} }
// Get returns the value for key stored in the trie. // MustGet is a wrapper of Get and will omit any encountered error but just
// The value bytes must not be modified by the caller. // print out an error message.
func (t *Trie) Get(key []byte) []byte { func (t *Trie) MustGet(key []byte) []byte {
res, err := t.TryGet(key) res, err := t.Get(key)
if err != nil { if err != nil {
log.Error("Unhandled trie error in Trie.Get", "err", err) log.Error("Unhandled trie error in Trie.Get", "err", err)
} }
return res return res
} }
// TryGet returns the value for key stored in the trie. // Get returns the value for key stored in the trie.
// The value bytes must not be modified by the caller. // The value bytes must not be modified by the caller.
// If a node was not found in the database, a MissingNodeError is returned. //
func (t *Trie) TryGet(key []byte) ([]byte, error) { // If the requested node is not present in trie, no error will be returned.
value, newroot, didResolve, err := t.tryGet(t.root, keybytesToHex(key), 0) // If the trie is corrupted, a MissingNodeError is returned.
func (t *Trie) Get(key []byte) ([]byte, error) {
value, newroot, didResolve, err := t.get(t.root, keybytesToHex(key), 0)
if err == nil && didResolve { if err == nil && didResolve {
t.root = newroot t.root = newroot
} }
return value, err return value, err
} }
func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode node, didResolve bool, err error) { func (t *Trie) get(origNode node, key []byte, pos int) (value []byte, newnode node, didResolve bool, err error) {
switch n := (origNode).(type) { switch n := (origNode).(type) {
case nil: case nil:
return nil, nil, false, nil return nil, nil, false, nil
@ -137,14 +139,14 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
// key not found in trie // key not found in trie
return nil, n, false, nil return nil, n, false, nil
} }
value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key)) value, newnode, didResolve, err = t.get(n.Val, key, pos+len(n.Key))
if err == nil && didResolve { if err == nil && didResolve {
n = n.copy() n = n.copy()
n.Val = newnode n.Val = newnode
} }
return value, n, didResolve, err return value, n, didResolve, err
case *fullNode: case *fullNode:
value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1) value, newnode, didResolve, err = t.get(n.Children[key[pos]], key, pos+1)
if err == nil && didResolve { if err == nil && didResolve {
n = n.copy() n = n.copy()
n.Children[key[pos]] = newnode n.Children[key[pos]] = newnode
@ -155,17 +157,30 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
if err != nil { if err != nil {
return nil, n, true, err return nil, n, true, err
} }
value, newnode, _, err := t.tryGet(child, key, pos) value, newnode, _, err := t.get(child, key, pos)
return value, newnode, true, err return value, newnode, true, err
default: default:
panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode))
} }
} }
// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not // MustGetNode is a wrapper of GetNode and will omit any encountered error but
// possible to use keybyte-encoding as the path might contain odd nibbles. // just print out an error message.
func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) { func (t *Trie) MustGetNode(path []byte) ([]byte, int) {
item, newroot, resolved, err := t.tryGetNode(t.root, compactToHex(path), 0) item, resolved, err := t.GetNode(path)
if err != nil {
log.Error("Unhandled trie error in Trie.GetNode", "err", err)
}
return item, resolved
}
// GetNode retrieves a trie node by compact-encoded path. It is not possible
// to use keybyte-encoding as the path might contain odd nibbles.
//
// If the requested node is not present in trie, no error will be returned.
// If the trie is corrupted, a MissingNodeError is returned.
func (t *Trie) GetNode(path []byte) ([]byte, int, error) {
item, newroot, resolved, err := t.getNode(t.root, compactToHex(path), 0)
if err != nil { if err != nil {
return nil, resolved, err return nil, resolved, err
} }
@ -175,10 +190,10 @@ func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) {
if item == nil { if item == nil {
return nil, resolved, nil return nil, resolved, nil
} }
return item, resolved, err return item, resolved, nil
} }
func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, newnode node, resolved int, err error) { func (t *Trie) getNode(origNode node, path []byte, pos int) (item []byte, newnode node, resolved int, err error) {
// If non-existent path requested, abort // If non-existent path requested, abort
if origNode == nil { if origNode == nil {
return nil, nil, 0, nil return nil, nil, 0, nil
@ -211,7 +226,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
// Path branches off from short node // Path branches off from short node
return nil, n, 0, nil return nil, n, 0, nil
} }
item, newnode, resolved, err = t.tryGetNode(n.Val, path, pos+len(n.Key)) item, newnode, resolved, err = t.getNode(n.Val, path, pos+len(n.Key))
if err == nil && resolved > 0 { if err == nil && resolved > 0 {
n = n.copy() n = n.copy()
n.Val = newnode n.Val = newnode
@ -219,7 +234,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
return item, n, resolved, err return item, n, resolved, err
case *fullNode: case *fullNode:
item, newnode, resolved, err = t.tryGetNode(n.Children[path[pos]], path, pos+1) item, newnode, resolved, err = t.getNode(n.Children[path[pos]], path, pos+1)
if err == nil && resolved > 0 { if err == nil && resolved > 0 {
n = n.copy() n = n.copy()
n.Children[path[pos]] = newnode n.Children[path[pos]] = newnode
@ -231,7 +246,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
if err != nil { if err != nil {
return nil, n, 1, err return nil, n, 1, err
} }
item, newnode, resolved, err := t.tryGetNode(child, path, pos) item, newnode, resolved, err := t.getNode(child, path, pos)
return item, newnode, resolved + 1, err return item, newnode, resolved + 1, err
default: default:
@ -239,33 +254,28 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
} }
} }
// MustUpdate is a wrapper of Update and will omit any encountered error but
// just print out an error message.
func (t *Trie) MustUpdate(key, value []byte) {
if err := t.Update(key, value); err != nil {
log.Error("Unhandled trie error in Trie.Update", "err", err)
}
}
// Update associates key with value in the trie. Subsequent calls to // Update associates key with value in the trie. Subsequent calls to
// Get will return value. If value has length zero, any existing value // Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil. // is deleted from the trie and calls to Get will return nil.
// //
// The value bytes must not be modified by the caller while they are // The value bytes must not be modified by the caller while they are
// stored in the trie. // stored in the trie.
func (t *Trie) Update(key, value []byte) { //
if err := t.TryUpdate(key, value); err != nil { // If the requested node is not present in trie, no error will be returned.
log.Error("Unhandled trie error in Trie.Update", "err", err) // If the trie is corrupted, a MissingNodeError is returned.
} func (t *Trie) Update(key, value []byte) error {
return t.update(key, value)
} }
// TryUpdate associates key with value in the trie. Subsequent calls to func (t *Trie) update(key, value []byte) error {
// Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil.
//
// The value bytes must not be modified by the caller while they are
// stored in the trie.
//
// If a node was not found in the database, a MissingNodeError is returned.
func (t *Trie) TryUpdate(key, value []byte) error {
return t.tryUpdate(key, value)
}
// tryUpdate expects an RLP-encoded value and performs the core function
// for TryUpdate and TryUpdateAccount.
func (t *Trie) tryUpdate(key, value []byte) error {
t.unhashed++ t.unhashed++
k := keybytesToHex(key) k := keybytesToHex(key)
if len(value) != 0 { if len(value) != 0 {
@ -363,16 +373,19 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
} }
} }
// Delete removes any existing value for key from the trie. // MustDelete is a wrapper of Delete and will omit any encountered error but
func (t *Trie) Delete(key []byte) { // just print out an error message.
if err := t.TryDelete(key); err != nil { func (t *Trie) MustDelete(key []byte) {
if err := t.Delete(key); err != nil {
log.Error("Unhandled trie error in Trie.Delete", "err", err) log.Error("Unhandled trie error in Trie.Delete", "err", err)
} }
} }
// TryDelete removes any existing value for key from the trie. // Delete removes any existing value for key from the trie.
// If a node was not found in the database, a MissingNodeError is returned. //
func (t *Trie) TryDelete(key []byte) error { // If the requested node is not present in trie, no error will be returned.
// If the trie is corrupted, a MissingNodeError is returned.
func (t *Trie) Delete(key []byte) error {
t.unhashed++ t.unhashed++
k := keybytesToHex(key) k := keybytesToHex(key)
_, n, err := t.delete(t.root, nil, k) _, n, err := t.delete(t.root, nil, k)

View File

@ -56,8 +56,8 @@ func TestNull(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
key := make([]byte, 32) key := make([]byte, 32)
value := []byte("test") value := []byte("test")
trie.Update(key, value) trie.MustUpdate(key, value)
if !bytes.Equal(trie.Get(key), value) { if !bytes.Equal(trie.MustGet(key), value) {
t.Fatal("wrong value") t.Fatal("wrong value")
} }
} }
@ -90,27 +90,27 @@ func testMissingNode(t *testing.T, memonly bool) {
} }
trie, _ = New(TrieID(root), triedb) trie, _ = New(TrieID(root), triedb)
_, err := trie.TryGet([]byte("120000")) _, err := trie.Get([]byte("120000"))
if err != nil { if err != nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} }
trie, _ = New(TrieID(root), triedb) trie, _ = New(TrieID(root), triedb)
_, err = trie.TryGet([]byte("120099")) _, err = trie.Get([]byte("120099"))
if err != nil { if err != nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} }
trie, _ = New(TrieID(root), triedb) trie, _ = New(TrieID(root), triedb)
_, err = trie.TryGet([]byte("123456")) _, err = trie.Get([]byte("123456"))
if err != nil { if err != nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} }
trie, _ = New(TrieID(root), triedb) trie, _ = New(TrieID(root), triedb)
err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv")) err = trie.Update([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv"))
if err != nil { if err != nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} }
trie, _ = New(TrieID(root), triedb) trie, _ = New(TrieID(root), triedb)
err = trie.TryDelete([]byte("123456")) err = trie.Delete([]byte("123456"))
if err != nil { if err != nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} }
@ -123,27 +123,27 @@ func testMissingNode(t *testing.T, memonly bool) {
} }
trie, _ = New(TrieID(root), triedb) trie, _ = New(TrieID(root), triedb)
_, err = trie.TryGet([]byte("120000")) _, err = trie.Get([]byte("120000"))
if _, ok := err.(*MissingNodeError); !ok { if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err) t.Errorf("Wrong error: %v", err)
} }
trie, _ = New(TrieID(root), triedb) trie, _ = New(TrieID(root), triedb)
_, err = trie.TryGet([]byte("120099")) _, err = trie.Get([]byte("120099"))
if _, ok := err.(*MissingNodeError); !ok { if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err) t.Errorf("Wrong error: %v", err)
} }
trie, _ = New(TrieID(root), triedb) trie, _ = New(TrieID(root), triedb)
_, err = trie.TryGet([]byte("123456")) _, err = trie.Get([]byte("123456"))
if err != nil { if err != nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} }
trie, _ = New(TrieID(root), triedb) trie, _ = New(TrieID(root), triedb)
err = trie.TryUpdate([]byte("120099"), []byte("zxcv")) err = trie.Update([]byte("120099"), []byte("zxcv"))
if _, ok := err.(*MissingNodeError); !ok { if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err) t.Errorf("Wrong error: %v", err)
} }
trie, _ = New(TrieID(root), triedb) trie, _ = New(TrieID(root), triedb)
err = trie.TryDelete([]byte("123456")) err = trie.Delete([]byte("123456"))
if _, ok := err.(*MissingNodeError); !ok { if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err) t.Errorf("Wrong error: %v", err)
} }
@ -311,8 +311,8 @@ func TestReplication(t *testing.T) {
func TestLargeValue(t *testing.T) { func TestLargeValue(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
trie.Update([]byte("key1"), []byte{99, 99, 99, 99}) trie.MustUpdate([]byte("key1"), []byte{99, 99, 99, 99})
trie.Update([]byte("key2"), bytes.Repeat([]byte{1}, 32)) trie.MustUpdate([]byte("key2"), bytes.Repeat([]byte{1}, 32))
trie.Hash() trie.Hash()
} }
@ -460,13 +460,13 @@ func runRandTest(rt randTest) bool {
switch step.op { switch step.op {
case opUpdate: case opUpdate:
tr.Update(step.key, step.value) tr.MustUpdate(step.key, step.value)
values[string(step.key)] = string(step.value) values[string(step.key)] = string(step.value)
case opDelete: case opDelete:
tr.Delete(step.key) tr.MustDelete(step.key)
delete(values, string(step.key)) delete(values, string(step.key))
case opGet: case opGet:
v := tr.Get(step.key) v := tr.MustGet(step.key)
want := values[string(step.key)] want := values[string(step.key)]
if string(v) != want { if string(v) != want {
rt[i].err = fmt.Errorf("mismatch for key %#x, got %#x want %#x", step.key, v, want) rt[i].err = fmt.Errorf("mismatch for key %#x, got %#x want %#x", step.key, v, want)
@ -509,7 +509,7 @@ func runRandTest(rt randTest) bool {
checktr := NewEmpty(triedb) checktr := NewEmpty(triedb)
it := NewIterator(tr.NodeIterator(nil)) it := NewIterator(tr.NodeIterator(nil))
for it.Next() { for it.Next() {
checktr.Update(it.Key, it.Value) checktr.MustUpdate(it.Key, it.Value)
} }
if tr.Hash() != checktr.Hash() { if tr.Hash() != checktr.Hash() {
rt[i].err = fmt.Errorf("hash mismatch in opItercheckhash") rt[i].err = fmt.Errorf("hash mismatch in opItercheckhash")
@ -595,13 +595,13 @@ func benchGet(b *testing.B) {
k := make([]byte, 32) k := make([]byte, 32)
for i := 0; i < benchElemCount; i++ { for i := 0; i < benchElemCount; i++ {
binary.LittleEndian.PutUint64(k, uint64(i)) binary.LittleEndian.PutUint64(k, uint64(i))
trie.Update(k, k) trie.MustUpdate(k, k)
} }
binary.LittleEndian.PutUint64(k, benchElemCount/2) binary.LittleEndian.PutUint64(k, benchElemCount/2)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
trie.Get(k) trie.MustGet(k)
} }
b.StopTimer() b.StopTimer()
} }
@ -612,7 +612,7 @@ func benchUpdate(b *testing.B, e binary.ByteOrder) *Trie {
b.ReportAllocs() b.ReportAllocs()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
e.PutUint64(k, uint64(i)) e.PutUint64(k, uint64(i))
trie.Update(k, k) trie.MustUpdate(k, k)
} }
return trie return trie
} }
@ -640,11 +640,11 @@ func BenchmarkHash(b *testing.B) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
i := 0 i := 0
for ; i < len(addresses)/2; i++ { for ; i < len(addresses)/2; i++ {
trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i])
} }
trie.Hash() trie.Hash()
for ; i < len(addresses); i++ { for ; i < len(addresses); i++ {
trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i])
} }
b.ResetTimer() b.ResetTimer()
b.ReportAllocs() b.ReportAllocs()
@ -670,7 +670,7 @@ func benchmarkCommitAfterHash(b *testing.B, collectLeaf bool) {
addresses, accounts := makeAccounts(b.N) addresses, accounts := makeAccounts(b.N)
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
for i := 0; i < len(addresses); i++ { for i := 0; i < len(addresses); i++ {
trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i])
} }
// Insert the accounts into the trie and hash it // Insert the accounts into the trie and hash it
trie.Hash() trie.Hash()
@ -683,22 +683,22 @@ func TestTinyTrie(t *testing.T) {
// Create a realistic account trie to hash // Create a realistic account trie to hash
_, accounts := makeAccounts(5) _, accounts := makeAccounts(5)
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
trie.Update(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001337"), accounts[3]) trie.MustUpdate(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001337"), accounts[3])
if exp, root := common.HexToHash("8c6a85a4d9fda98feff88450299e574e5378e32391f75a055d470ac0653f1005"), trie.Hash(); exp != root { if exp, root := common.HexToHash("8c6a85a4d9fda98feff88450299e574e5378e32391f75a055d470ac0653f1005"), trie.Hash(); exp != root {
t.Errorf("1: got %x, exp %x", root, exp) t.Errorf("1: got %x, exp %x", root, exp)
} }
trie.Update(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001338"), accounts[4]) trie.MustUpdate(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001338"), accounts[4])
if exp, root := common.HexToHash("ec63b967e98a5720e7f720482151963982890d82c9093c0d486b7eb8883a66b1"), trie.Hash(); exp != root { if exp, root := common.HexToHash("ec63b967e98a5720e7f720482151963982890d82c9093c0d486b7eb8883a66b1"), trie.Hash(); exp != root {
t.Errorf("2: got %x, exp %x", root, exp) t.Errorf("2: got %x, exp %x", root, exp)
} }
trie.Update(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001339"), accounts[4]) trie.MustUpdate(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001339"), accounts[4])
if exp, root := common.HexToHash("0608c1d1dc3905fa22204c7a0e43644831c3b6d3def0f274be623a948197e64a"), trie.Hash(); exp != root { if exp, root := common.HexToHash("0608c1d1dc3905fa22204c7a0e43644831c3b6d3def0f274be623a948197e64a"), trie.Hash(); exp != root {
t.Errorf("3: got %x, exp %x", root, exp) t.Errorf("3: got %x, exp %x", root, exp)
} }
checktr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) checktr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
it := NewIterator(trie.NodeIterator(nil)) it := NewIterator(trie.NodeIterator(nil))
for it.Next() { for it.Next() {
checktr.Update(it.Key, it.Value) checktr.MustUpdate(it.Key, it.Value)
} }
if troot, itroot := trie.Hash(), checktr.Hash(); troot != itroot { if troot, itroot := trie.Hash(), checktr.Hash(); troot != itroot {
t.Fatalf("hash mismatch in opItercheckhash, trie: %x, check: %x", troot, itroot) t.Fatalf("hash mismatch in opItercheckhash, trie: %x, check: %x", troot, itroot)
@ -710,7 +710,7 @@ func TestCommitAfterHash(t *testing.T) {
addresses, accounts := makeAccounts(1000) addresses, accounts := makeAccounts(1000)
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
for i := 0; i < len(addresses); i++ { for i := 0; i < len(addresses); i++ {
trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i])
} }
// Insert the accounts into the trie and hash it // Insert the accounts into the trie and hash it
trie.Hash() trie.Hash()
@ -820,7 +820,7 @@ func TestCommitSequence(t *testing.T) {
trie := NewEmpty(db) trie := NewEmpty(db)
// Fill the trie with elements // Fill the trie with elements
for i := 0; i < tc.count; i++ { for i := 0; i < tc.count; i++ {
trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i])
} }
// Flush trie -> database // Flush trie -> database
root, nodes := trie.Commit(false) root, nodes := trie.Commit(false)
@ -861,7 +861,7 @@ func TestCommitSequenceRandomBlobs(t *testing.T) {
} }
prng.Read(key) prng.Read(key)
prng.Read(val) prng.Read(val)
trie.Update(key, val) trie.MustUpdate(key, val)
} }
// Flush trie -> database // Flush trie -> database
root, nodes := trie.Commit(false) root, nodes := trie.Commit(false)
@ -899,8 +899,8 @@ func TestCommitSequenceStackTrie(t *testing.T) {
val = make([]byte, 1+prng.Intn(1024)) val = make([]byte, 1+prng.Intn(1024))
} }
prng.Read(val) prng.Read(val)
trie.TryUpdate(key, val) trie.Update(key, val)
stTrie.TryUpdate(key, val) stTrie.Update(key, val)
} }
// Flush trie -> database // Flush trie -> database
root, nodes := trie.Commit(false) root, nodes := trie.Commit(false)
@ -948,8 +948,8 @@ func TestCommitSequenceSmallRoot(t *testing.T) {
// Add a single small-element to the trie(s) // Add a single small-element to the trie(s)
key := make([]byte, 5) key := make([]byte, 5)
key[0] = 1 key[0] = 1
trie.TryUpdate(key, []byte{0x1}) trie.Update(key, []byte{0x1})
stTrie.TryUpdate(key, []byte{0x1}) stTrie.Update(key, []byte{0x1})
// Flush trie -> database // Flush trie -> database
root, nodes := trie.Commit(false) root, nodes := trie.Commit(false)
// Flush memdb -> disk (sponge) // Flush memdb -> disk (sponge)
@ -1017,7 +1017,7 @@ func benchmarkHashFixedSize(b *testing.B, addresses [][20]byte, accounts [][]byt
b.ReportAllocs() b.ReportAllocs()
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
for i := 0; i < len(addresses); i++ { for i := 0; i < len(addresses); i++ {
trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i])
} }
// Insert the accounts into the trie and hash it // Insert the accounts into the trie and hash it
b.StartTimer() b.StartTimer()
@ -1068,7 +1068,7 @@ func benchmarkCommitAfterHashFixedSize(b *testing.B, addresses [][20]byte, accou
b.ReportAllocs() b.ReportAllocs()
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
for i := 0; i < len(addresses); i++ { for i := 0; i < len(addresses); i++ {
trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i])
} }
// Insert the accounts into the trie and hash it // Insert the accounts into the trie and hash it
trie.Hash() trie.Hash()
@ -1121,7 +1121,7 @@ func benchmarkDerefRootFixedSize(b *testing.B, addresses [][20]byte, accounts []
triedb := NewDatabase(rawdb.NewMemoryDatabase()) triedb := NewDatabase(rawdb.NewMemoryDatabase())
trie := NewEmpty(triedb) trie := NewEmpty(triedb)
for i := 0; i < len(addresses); i++ { for i := 0; i < len(addresses); i++ {
trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i])
} }
h := trie.Hash() h := trie.Hash()
_, nodes := trie.Commit(false) _, nodes := trie.Commit(false)
@ -1132,15 +1132,15 @@ func benchmarkDerefRootFixedSize(b *testing.B, addresses [][20]byte, accounts []
} }
func getString(trie *Trie, k string) []byte { func getString(trie *Trie, k string) []byte {
return trie.Get([]byte(k)) return trie.MustGet([]byte(k))
} }
func updateString(trie *Trie, k, v string) { func updateString(trie *Trie, k, v string) {
trie.Update([]byte(k), []byte(v)) trie.MustUpdate([]byte(k), []byte(v))
} }
func deleteString(trie *Trie, k string) { func deleteString(trie *Trie, k string) {
trie.Delete([]byte(k)) trie.MustDelete([]byte(k))
} }
func TestDecodeNode(t *testing.T) { func TestDecodeNode(t *testing.T) {