diff --git a/core/rawdb/accessors_indexes_test.go b/core/rawdb/accessors_indexes_test.go index 4734e986e..c5447b16d 100644 --- a/core/rawdb/accessors_indexes_test.go +++ b/core/rawdb/accessors_indexes_test.go @@ -45,9 +45,10 @@ func (h *testHasher) Reset() { h.hasher.Reset() } -func (h *testHasher) Update(key, val []byte) { +func (h *testHasher) Update(key, val []byte) error { h.hasher.Write(key) h.hasher.Write(val) + return nil } func (h *testHasher) Hash() common.Hash { diff --git a/core/state/snapshot/conversion.go b/core/state/snapshot/conversion.go index ed7cb963a..0e583e626 100644 --- a/core/state/snapshot/conversion.go +++ b/core/state/snapshot/conversion.go @@ -371,7 +371,7 @@ func stackTrieGenerate(db ethdb.KeyValueWriter, scheme string, owner common.Hash } t := trie.NewStackTrieWithOwner(nodeWriter, owner) for leaf := range in { - t.TryUpdate(leaf.key[:], leaf.value) + t.Update(leaf.key[:], leaf.value) } var root common.Hash if db == nil { diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go index d46705d31..68c2f574b 100644 --- a/core/state/snapshot/generate.go +++ b/core/state/snapshot/generate.go @@ -230,7 +230,7 @@ func (dl *diskLayer) proveRange(ctx *generatorContext, trieId *trie.ID, prefix [ if origin == nil && !diskMore { stackTr := trie.NewStackTrie(nil) for i, key := range keys { - stackTr.TryUpdate(key, vals[i]) + stackTr.Update(key, vals[i]) } if gotRoot := stackTr.Hash(); gotRoot != root { return &proofResult{ diff --git a/core/state/snapshot/generate_test.go b/core/state/snapshot/generate_test.go index 1bac4fd56..546132e7d 100644 --- a/core/state/snapshot/generate_test.go +++ b/core/state/snapshot/generate_test.go @@ -161,7 +161,7 @@ func newHelper() *testHelper { func (t *testHelper) addTrieAccount(acckey string, acc *Account) { val, _ := rlp.EncodeToBytes(acc) - t.accTrie.Update([]byte(acckey), val) + t.accTrie.MustUpdate([]byte(acckey), val) } func (t *testHelper) addSnapAccount(acckey string, acc *Account) { @@ -186,7 +186,7 @@ func (t *testHelper) makeStorageTrie(stateRoot, owner common.Hash, keys []string id := trie.StorageTrieID(stateRoot, owner, common.Hash{}) stTrie, _ := trie.NewStateTrie(id, t.triedb) for i, k := range keys { - stTrie.Update([]byte(k), []byte(vals[i])) + stTrie.MustUpdate([]byte(k), []byte(vals[i])) } if !commit { return stTrie.Hash().Bytes() @@ -491,7 +491,7 @@ func TestGenerateWithExtraAccounts(t *testing.T) { ) acc := &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()} val, _ := rlp.EncodeToBytes(acc) - helper.accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e + helper.accTrie.MustUpdate([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e // Identical in the snap key := hashData([]byte("acc-1")) @@ -562,7 +562,7 @@ func TestGenerateWithManyExtraAccounts(t *testing.T) { ) acc := &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()} val, _ := rlp.EncodeToBytes(acc) - helper.accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e + helper.accTrie.MustUpdate([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e // Identical in the snap key := hashData([]byte("acc-1")) @@ -613,8 +613,8 @@ func TestGenerateWithExtraBeforeAndAfter(t *testing.T) { { acc := &Account{Balance: big.NewInt(1), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()} val, _ := rlp.EncodeToBytes(acc) - helper.accTrie.Update(common.HexToHash("0x03").Bytes(), val) - helper.accTrie.Update(common.HexToHash("0x07").Bytes(), val) + helper.accTrie.MustUpdate(common.HexToHash("0x03").Bytes(), val) + helper.accTrie.MustUpdate(common.HexToHash("0x07").Bytes(), val) rawdb.WriteAccountSnapshot(helper.diskdb, common.HexToHash("0x01"), val) rawdb.WriteAccountSnapshot(helper.diskdb, common.HexToHash("0x02"), val) @@ -650,7 +650,7 @@ func TestGenerateWithMalformedSnapdata(t *testing.T) { { acc := &Account{Balance: big.NewInt(1), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()} val, _ := rlp.EncodeToBytes(acc) - helper.accTrie.Update(common.HexToHash("0x03").Bytes(), val) + helper.accTrie.MustUpdate(common.HexToHash("0x03").Bytes(), val) junk := make([]byte, 100) copy(junk, []byte{0xde, 0xad}) diff --git a/core/state/sync_test.go b/core/state/sync_test.go index aff91268a..090d55e47 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -213,14 +213,14 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) { for i, node := range nodeElements { if bypath { if len(node.syncPath) == 1 { - data, _, err := srcTrie.TryGetNode(node.syncPath[0]) + data, _, err := srcTrie.GetNode(node.syncPath[0]) if err != nil { t.Fatalf("failed to retrieve node data for path %x: %v", node.syncPath[0], err) } nodeResults[i] = trie.NodeSyncResult{Path: node.path, Data: data} } else { var acc types.StateAccount - if err := rlp.DecodeBytes(srcTrie.Get(node.syncPath[0]), &acc); err != nil { + if err := rlp.DecodeBytes(srcTrie.MustGet(node.syncPath[0]), &acc); err != nil { t.Fatalf("failed to decode account on path %x: %v", node.syncPath[0], err) } id := trie.StorageTrieID(srcRoot, common.BytesToHash(node.syncPath[0]), acc.Root) @@ -228,7 +228,7 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) { if err != nil { t.Fatalf("failed to retriev storage trie for path %x: %v", node.syncPath[1], err) } - data, _, err := stTrie.TryGetNode(node.syncPath[1]) + data, _, err := stTrie.GetNode(node.syncPath[1]) if err != nil { t.Fatalf("failed to retrieve node data for path %x: %v", node.syncPath[1], err) } diff --git a/core/types/block_test.go b/core/types/block_test.go index 49197c923..966015eb0 100644 --- a/core/types/block_test.go +++ b/core/types/block_test.go @@ -232,9 +232,10 @@ func (h *testHasher) Reset() { h.hasher.Reset() } -func (h *testHasher) Update(key, val []byte) { +func (h *testHasher) Update(key, val []byte) error { h.hasher.Write(key) h.hasher.Write(val) + return nil } func (h *testHasher) Hash() common.Hash { diff --git a/core/types/hashing.go b/core/types/hashing.go index 71295c107..fbdeaf0d0 100644 --- a/core/types/hashing.go +++ b/core/types/hashing.go @@ -62,7 +62,7 @@ func prefixedRlpHash(prefix byte, x interface{}) (h common.Hash) { // This is internal, do not use. type TrieHasher interface { Reset() - Update([]byte, []byte) + Update([]byte, []byte) error Hash() common.Hash } @@ -93,6 +93,9 @@ func DeriveSha(list DerivableList, hasher TrieHasher) common.Hash { // StackTrie requires values to be inserted in increasing hash order, which is not the // order that `list` provides hashes in. This insertion sequence ensures that the // order is correct. + // + // The error returned by hasher is omitted because hasher will produce an incorrect + // hash in case any error occurs. var indexBuf []byte for i := 1; i < list.Len() && i <= 0x7f; i++ { indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i)) diff --git a/core/types/hashing_test.go b/core/types/hashing_test.go index 294a3977d..c5b9f690d 100644 --- a/core/types/hashing_test.go +++ b/core/types/hashing_test.go @@ -219,9 +219,10 @@ func (d *hashToHumanReadable) Reset() { d.data = make([]byte, 0) } -func (d *hashToHumanReadable) Update(i []byte, i2 []byte) { +func (d *hashToHumanReadable) Update(i []byte, i2 []byte) error { l := fmt.Sprintf("%x %x\n", i, i2) d.data = append(d.data, []byte(l)...) + return nil } func (d *hashToHumanReadable) Hash() common.Hash { diff --git a/eth/protocols/snap/sync_test.go b/eth/protocols/snap/sync_test.go index 0a6117972..6a3b482d5 100644 --- a/eth/protocols/snap/sync_test.go +++ b/eth/protocols/snap/sync_test.go @@ -216,7 +216,7 @@ func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, for _, pathset := range paths { switch len(pathset) { case 1: - blob, _, err := t.accountTrie.TryGetNode(pathset[0]) + blob, _, err := t.accountTrie.GetNode(pathset[0]) if err != nil { t.logger.Info("Error handling req", "error", err) break @@ -225,7 +225,7 @@ func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, default: account := t.storageTries[(common.BytesToHash(pathset[0]))] for _, path := range pathset[1:] { - blob, _, err := account.TryGetNode(path) + blob, _, err := account.GetNode(path) if err != nil { t.logger.Info("Error handling req", "error", err) break @@ -1381,7 +1381,7 @@ func makeAccountTrieNoStorage(n int) (string, *trie.Trie, entrySlice) { }) key := key32(i) elem := &kv{key, value} - accTrie.Update(elem.k, elem.v) + accTrie.MustUpdate(elem.k, elem.v) entries = append(entries, elem) } sort.Sort(entries) @@ -1431,7 +1431,7 @@ func makeBoundaryAccountTrie(n int) (string, *trie.Trie, entrySlice) { CodeHash: getCodeHash(uint64(i)), }) elem := &kv{boundaries[i].Bytes(), value} - accTrie.Update(elem.k, elem.v) + accTrie.MustUpdate(elem.k, elem.v) entries = append(entries, elem) } // Fill other accounts if required @@ -1443,7 +1443,7 @@ func makeBoundaryAccountTrie(n int) (string, *trie.Trie, entrySlice) { CodeHash: getCodeHash(i), }) elem := &kv{key32(i), value} - accTrie.Update(elem.k, elem.v) + accTrie.MustUpdate(elem.k, elem.v) entries = append(entries, elem) } sort.Sort(entries) @@ -1487,7 +1487,7 @@ func makeAccountTrieWithStorageWithUniqueStorage(accounts, slots int, code bool) CodeHash: codehash, }) elem := &kv{key, value} - accTrie.Update(elem.k, elem.v) + accTrie.MustUpdate(elem.k, elem.v) entries = append(entries, elem) storageRoots[common.BytesToHash(key)] = stRoot @@ -1551,7 +1551,7 @@ func makeAccountTrieWithStorage(accounts, slots int, code, boundary bool) (strin CodeHash: codehash, }) elem := &kv{key, value} - accTrie.Update(elem.k, elem.v) + accTrie.MustUpdate(elem.k, elem.v) entries = append(entries, elem) // we reuse the same one for all accounts @@ -1599,7 +1599,7 @@ func makeStorageTrieWithSeed(owner common.Hash, n, seed uint64, db *trie.Databas key := crypto.Keccak256Hash(slotKey[:]) elem := &kv{key[:], rlpSlotValue} - trie.Update(elem.k, elem.v) + trie.MustUpdate(elem.k, elem.v) entries = append(entries, elem) } sort.Sort(entries) @@ -1638,7 +1638,7 @@ func makeBoundaryStorageTrie(owner common.Hash, n int, db *trie.Database) (commo val := []byte{0xde, 0xad, 0xbe, 0xef} elem := &kv{key[:], val} - trie.Update(elem.k, elem.v) + trie.MustUpdate(elem.k, elem.v) entries = append(entries, elem) } // Fill other slots if required @@ -1650,7 +1650,7 @@ func makeBoundaryStorageTrie(owner common.Hash, n int, db *trie.Database) (commo rlpSlotValue, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(slotValue[:])) elem := &kv{key[:], rlpSlotValue} - trie.Update(elem.k, elem.v) + trie.MustUpdate(elem.k, elem.v) entries = append(entries, elem) } sort.Sort(entries) diff --git a/les/server_handler.go b/les/server_handler.go index 2ea496ac2..39c7ace1c 100644 --- a/les/server_handler.go +++ b/les/server_handler.go @@ -364,7 +364,7 @@ func getAccount(triedb *trie.Database, root, hash common.Hash) (types.StateAccou if err != nil { return types.StateAccount{}, err } - blob, err := trie.TryGet(hash[:]) + blob, err := trie.Get(hash[:]) if err != nil { return types.StateAccount{}, err } diff --git a/light/postprocess.go b/light/postprocess.go index e800a1f0f..763ba2752 100644 --- a/light/postprocess.go +++ b/light/postprocess.go @@ -206,8 +206,7 @@ func (c *ChtIndexerBackend) Process(ctx context.Context, header *types.Header) e var encNumber [8]byte binary.BigEndian.PutUint64(encNumber[:], num) data, _ := rlp.EncodeToBytes(ChtNode{hash, td}) - c.trie.Update(encNumber[:], data) - return nil + return c.trie.Update(encNumber[:], data) } // Commit implements core.ChainIndexerBackend @@ -450,10 +449,15 @@ func (b *BloomTrieIndexerBackend) Commit() error { decompSize += uint64(len(decomp)) compSize += uint64(len(comp)) + + var terr error if len(comp) > 0 { - b.trie.Update(encKey[:], comp) + terr = b.trie.Update(encKey[:], comp) } else { - b.trie.Delete(encKey[:]) + terr = b.trie.Delete(encKey[:]) + } + if terr != nil { + return terr } } root, nodes := b.trie.Commit(false) diff --git a/light/trie.go b/light/trie.go index 2f8f71c61..38dd6b5c2 100644 --- a/light/trie.go +++ b/light/trie.go @@ -109,7 +109,7 @@ func (t *odrTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { key = crypto.Keccak256(key) var res []byte err := t.do(key, func() (err error) { - res, err = t.trie.TryGet(key) + res, err = t.trie.Get(key) return err }) return res, err @@ -119,7 +119,7 @@ func (t *odrTrie) GetAccount(address common.Address) (*types.StateAccount, error var res types.StateAccount key := crypto.Keccak256(address.Bytes()) err := t.do(key, func() (err error) { - value, err := t.trie.TryGet(key) + value, err := t.trie.Get(key) if err != nil { return err } @@ -138,21 +138,21 @@ func (t *odrTrie) UpdateAccount(address common.Address, acc *types.StateAccount) return fmt.Errorf("decoding error in account update: %w", err) } return t.do(key, func() error { - return t.trie.TryUpdate(key, value) + return t.trie.Update(key, value) }) } func (t *odrTrie) UpdateStorage(_ common.Address, key, value []byte) error { key = crypto.Keccak256(key) return t.do(key, func() error { - return t.trie.TryUpdate(key, value) + return t.trie.Update(key, value) }) } func (t *odrTrie) DeleteStorage(_ common.Address, key []byte) error { key = crypto.Keccak256(key) return t.do(key, func() error { - return t.trie.TryDelete(key) + return t.trie.Delete(key) }) } @@ -160,7 +160,7 @@ func (t *odrTrie) DeleteStorage(_ common.Address, key []byte) error { func (t *odrTrie) DeleteAccount(address common.Address) error { key := crypto.Keccak256(address.Bytes()) return t.do(key, func() error { - return t.trie.TryDelete(key) + return t.trie.Delete(key) }) } diff --git a/tests/fuzzers/les/les-fuzzer.go b/tests/fuzzers/les/les-fuzzer.go index 924a749e5..926de0458 100644 --- a/tests/fuzzers/les/les-fuzzer.go +++ b/tests/fuzzers/les/les-fuzzer.go @@ -93,13 +93,13 @@ func makeTries() (chtTrie *trie.Trie, bloomTrie *trie.Trie, chtKeys, bloomKeys [ // The element in CHT is -> key := make([]byte, 8) binary.BigEndian.PutUint64(key, uint64(i+1)) - chtTrie.Update(key, []byte{0x1, 0xf}) + chtTrie.MustUpdate(key, []byte{0x1, 0xf}) chtKeys = append(chtKeys, key) // The element in Bloom trie is <2 byte bit index> + -> bloom key2 := make([]byte, 10) binary.BigEndian.PutUint64(key2[2:], uint64(i+1)) - bloomTrie.Update(key2, []byte{0x2, 0xe}) + bloomTrie.MustUpdate(key2, []byte{0x2, 0xe}) bloomKeys = append(bloomKeys, key2) } return diff --git a/tests/fuzzers/rangeproof/rangeproof-fuzzer.go b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go index bca93bbe1..2881c7a7c 100644 --- a/tests/fuzzers/rangeproof/rangeproof-fuzzer.go +++ b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go @@ -69,8 +69,8 @@ func (f *fuzzer) randomTrie(n int) (*trie.Trie, map[string]*kv) { for i := byte(0); i < byte(size); i++ { value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false} - trie.Update(value.k, value.v) - trie.Update(value2.k, value2.v) + trie.MustUpdate(value.k, value.v) + trie.MustUpdate(value2.k, value2.v) vals[string(value.k)] = value vals[string(value2.k)] = value2 } @@ -82,7 +82,7 @@ func (f *fuzzer) randomTrie(n int) (*trie.Trie, map[string]*kv) { k := f.randBytes(32) v := f.randBytes(20) value := &kv{k, v, false} - trie.Update(k, v) + trie.MustUpdate(k, v) vals[string(k)] = value if f.exhausted { return nil, nil diff --git a/tests/fuzzers/stacktrie/trie_fuzzer.go b/tests/fuzzers/stacktrie/trie_fuzzer.go index 435aa3a47..809dba8ce 100644 --- a/tests/fuzzers/stacktrie/trie_fuzzer.go +++ b/tests/fuzzers/stacktrie/trie_fuzzer.go @@ -175,7 +175,7 @@ func (f *fuzzer) fuzz() int { } keys[string(k)] = struct{}{} vals = append(vals, kv{k: k, v: v}) - trieA.Update(k, v) + trieA.MustUpdate(k, v) useful = true } if !useful { @@ -195,7 +195,7 @@ func (f *fuzzer) fuzz() int { if f.debugging { fmt.Printf("{\"%#x\" , \"%#x\"} // stacktrie.Update\n", kv.k, kv.v) } - trieB.Update(kv.k, kv.v) + trieB.MustUpdate(kv.k, kv.v) } rootB := trieB.Hash() trieB.Commit() @@ -223,7 +223,7 @@ func (f *fuzzer) fuzz() int { checked int ) for _, kv := range vals { - trieC.Update(kv.k, kv.v) + trieC.MustUpdate(kv.k, kv.v) } rootC, _ := trieC.Commit() if rootA != rootC { diff --git a/tests/fuzzers/trie/trie-fuzzer.go b/tests/fuzzers/trie/trie-fuzzer.go index 40dec76b8..c0cbceff3 100644 --- a/tests/fuzzers/trie/trie-fuzzer.go +++ b/tests/fuzzers/trie/trie-fuzzer.go @@ -147,13 +147,13 @@ func runRandTest(rt randTest) error { for i, step := range rt { switch step.op { case opUpdate: - tr.Update(step.key, step.value) + tr.MustUpdate(step.key, step.value) values[string(step.key)] = string(step.value) case opDelete: - tr.Delete(step.key) + tr.MustDelete(step.key) delete(values, string(step.key)) case opGet: - v := tr.Get(step.key) + v := tr.MustGet(step.key) want := values[string(step.key)] if string(v) != want { rt[i].err = fmt.Errorf("mismatch for key %#x, got %#x want %#x", step.key, v, want) @@ -176,7 +176,7 @@ func runRandTest(rt randTest) error { checktr := trie.NewEmpty(triedb) it := trie.NewIterator(tr.NodeIterator(nil)) for it.Next() { - checktr.Update(it.Key, it.Value) + checktr.MustUpdate(it.Key, it.Value) } if tr.Hash() != checktr.Hash() { return fmt.Errorf("hash mismatch in opItercheckhash") diff --git a/trie/errors.go b/trie/errors.go index afe344bed..bd82b950a 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -22,7 +22,7 @@ import ( "github.com/ethereum/go-ethereum/common" ) -// MissingNodeError is returned by the trie functions (TryGet, TryUpdate, TryDelete) +// MissingNodeError is returned by the trie functions (Get, Update, Delete) // in the case where a trie node is not present in the local database. It contains // information necessary for retrieving the missing node. type MissingNodeError struct { diff --git a/trie/iterator_test.go b/trie/iterator_test.go index ac2a3d2a4..6dc38db6f 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -58,7 +58,7 @@ func TestIterator(t *testing.T) { all := make(map[string]string) for _, val := range vals { all[val.k] = val.v - trie.Update([]byte(val.k), []byte(val.v)) + trie.MustUpdate([]byte(val.k), []byte(val.v)) } root, nodes := trie.Commit(false) db.Update(NewWithNodeSet(nodes)) @@ -89,8 +89,8 @@ func TestIteratorLargeData(t *testing.T) { for i := byte(0); i < 255; i++ { value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false} - trie.Update(value.k, value.v) - trie.Update(value2.k, value2.v) + trie.MustUpdate(value.k, value.v) + trie.MustUpdate(value2.k, value2.v) vals[string(value.k)] = value vals[string(value2.k)] = value2 } @@ -178,7 +178,7 @@ var testdata2 = []kvs{ func TestIteratorSeek(t *testing.T) { trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) for _, val := range testdata1 { - trie.Update([]byte(val.k), []byte(val.v)) + trie.MustUpdate([]byte(val.k), []byte(val.v)) } // Seek to the middle. @@ -220,7 +220,7 @@ func TestDifferenceIterator(t *testing.T) { dba := NewDatabase(rawdb.NewMemoryDatabase()) triea := NewEmpty(dba) for _, val := range testdata1 { - triea.Update([]byte(val.k), []byte(val.v)) + triea.MustUpdate([]byte(val.k), []byte(val.v)) } rootA, nodesA := triea.Commit(false) dba.Update(NewWithNodeSet(nodesA)) @@ -229,7 +229,7 @@ func TestDifferenceIterator(t *testing.T) { dbb := NewDatabase(rawdb.NewMemoryDatabase()) trieb := NewEmpty(dbb) for _, val := range testdata2 { - trieb.Update([]byte(val.k), []byte(val.v)) + trieb.MustUpdate([]byte(val.k), []byte(val.v)) } rootB, nodesB := trieb.Commit(false) dbb.Update(NewWithNodeSet(nodesB)) @@ -262,7 +262,7 @@ func TestUnionIterator(t *testing.T) { dba := NewDatabase(rawdb.NewMemoryDatabase()) triea := NewEmpty(dba) for _, val := range testdata1 { - triea.Update([]byte(val.k), []byte(val.v)) + triea.MustUpdate([]byte(val.k), []byte(val.v)) } rootA, nodesA := triea.Commit(false) dba.Update(NewWithNodeSet(nodesA)) @@ -271,7 +271,7 @@ func TestUnionIterator(t *testing.T) { dbb := NewDatabase(rawdb.NewMemoryDatabase()) trieb := NewEmpty(dbb) for _, val := range testdata2 { - trieb.Update([]byte(val.k), []byte(val.v)) + trieb.MustUpdate([]byte(val.k), []byte(val.v)) } rootB, nodesB := trieb.Commit(false) dbb.Update(NewWithNodeSet(nodesB)) @@ -314,7 +314,7 @@ func TestUnionIterator(t *testing.T) { func TestIteratorNoDups(t *testing.T) { tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) for _, val := range testdata1 { - tr.Update([]byte(val.k), []byte(val.v)) + tr.MustUpdate([]byte(val.k), []byte(val.v)) } checkIteratorNoDups(t, tr.NodeIterator(nil), nil) } @@ -329,7 +329,7 @@ func testIteratorContinueAfterError(t *testing.T, memonly bool) { tr := NewEmpty(triedb) for _, val := range testdata1 { - tr.Update([]byte(val.k), []byte(val.v)) + tr.MustUpdate([]byte(val.k), []byte(val.v)) } _, nodes := tr.Commit(false) triedb.Update(NewWithNodeSet(nodes)) @@ -421,7 +421,7 @@ func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) { ctr := NewEmpty(triedb) for _, val := range testdata1 { - ctr.Update([]byte(val.k), []byte(val.v)) + ctr.MustUpdate([]byte(val.k), []byte(val.v)) } root, nodes := ctr.Commit(false) triedb.Update(NewWithNodeSet(nodes)) @@ -540,7 +540,7 @@ func makeLargeTestTrie() (*Database, *StateTrie, *loggingDb) { binary.BigEndian.PutUint64(val, uint64(i)) key = crypto.Keccak256(key) val = crypto.Keccak256(val) - trie.Update(key, val) + trie.MustUpdate(key, val) } _, nodes := trie.Commit(false) triedb.Update(NewWithNodeSet(nodes)) @@ -580,7 +580,7 @@ func TestIteratorNodeBlob(t *testing.T) { all := make(map[string]string) for _, val := range vals { all[val.k] = val.v - trie.Update([]byte(val.k), []byte(val.v)) + trie.MustUpdate([]byte(val.k), []byte(val.v)) } _, nodes := trie.Commit(false) triedb.Update(NewWithNodeSet(nodes)) diff --git a/trie/proof.go b/trie/proof.go index f11dfc47a..65df7577b 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -498,7 +498,7 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key if proof == nil { tr := NewStackTrie(nil) for index, key := range keys { - tr.TryUpdate(key, values[index]) + tr.Update(key, values[index]) } if have, want := tr.Hash(), rootHash; have != want { return false, fmt.Errorf("invalid proof, want hash %x, got %x", want, have) @@ -568,7 +568,7 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key tr.root = nil } for index, key := range keys { - tr.TryUpdate(key, values[index]) + tr.Update(key, values[index]) } if tr.Hash() != rootHash { return false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash()) diff --git a/trie/proof_test.go b/trie/proof_test.go index 6b23bcdb2..69e3f8e9c 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -403,7 +403,7 @@ func TestOneElementRangeProof(t *testing.T) { // Test the mini trie with only a single element. tinyTrie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) entry := &kv{randBytes(32), randBytes(20), false} - tinyTrie.Update(entry.k, entry.v) + tinyTrie.MustUpdate(entry.k, entry.v) first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes() last = entry.k @@ -477,7 +477,7 @@ func TestSingleSideRangeProof(t *testing.T) { var entries entrySlice for i := 0; i < 4096; i++ { value := &kv{randBytes(32), randBytes(20), false} - trie.Update(value.k, value.v) + trie.MustUpdate(value.k, value.v) entries = append(entries, value) } sort.Sort(entries) @@ -512,7 +512,7 @@ func TestReverseSingleSideRangeProof(t *testing.T) { var entries entrySlice for i := 0; i < 4096; i++ { value := &kv{randBytes(32), randBytes(20), false} - trie.Update(value.k, value.v) + trie.MustUpdate(value.k, value.v) entries = append(entries, value) } sort.Sort(entries) @@ -619,7 +619,7 @@ func TestGappedRangeProof(t *testing.T) { var entries []*kv // Sorted entries for i := byte(0); i < 10; i++ { value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} - trie.Update(value.k, value.v) + trie.MustUpdate(value.k, value.v) entries = append(entries, value) } first, last := 2, 8 @@ -693,7 +693,7 @@ func TestHasRightElement(t *testing.T) { var entries entrySlice for i := 0; i < 4096; i++ { value := &kv{randBytes(32), randBytes(20), false} - trie.Update(value.k, value.v) + trie.MustUpdate(value.k, value.v) entries = append(entries, value) } sort.Sort(entries) @@ -1047,14 +1047,14 @@ func randomTrie(n int) (*Trie, map[string]*kv) { for i := byte(0); i < 100; i++ { value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false} - trie.Update(value.k, value.v) - trie.Update(value2.k, value2.v) + trie.MustUpdate(value.k, value.v) + trie.MustUpdate(value2.k, value2.v) vals[string(value.k)] = value vals[string(value2.k)] = value2 } for i := 0; i < n; i++ { value := &kv{randBytes(32), randBytes(20), false} - trie.Update(value.k, value.v) + trie.MustUpdate(value.k, value.v) vals[string(value.k)] = value } return trie, vals @@ -1071,7 +1071,7 @@ func nonRandomTrie(n int) (*Trie, map[string]*kv) { binary.LittleEndian.PutUint64(value, i-max) //value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} elem := &kv{key, value, false} - trie.Update(elem.k, elem.v) + trie.MustUpdate(elem.k, elem.v) vals[string(elem.k)] = elem } return trie, vals @@ -1088,7 +1088,7 @@ func TestRangeProofKeysWithSharedPrefix(t *testing.T) { } trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) for i, key := range keys { - trie.Update(key, vals[i]) + trie.MustUpdate(key, vals[i]) } root := trie.Hash() proof := memorydb.New() diff --git a/trie/secure_trie.go b/trie/secure_trie.go index b25bf758c..5bfd24650 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -19,7 +19,6 @@ package trie import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" ) @@ -72,14 +71,13 @@ func NewStateTrie(id *ID, db *Database) (*StateTrie, error) { return &StateTrie{trie: *trie, preimages: db.preimages}, nil } -// Get returns the value for key stored in the trie. +// MustGet returns the value for key stored in the trie. // The value bytes must not be modified by the caller. -func (t *StateTrie) Get(key []byte) []byte { - res, err := t.GetStorage(common.Address{}, key) - if err != nil { - log.Error("Unhandled trie error in StateTrie.Get", "err", err) - } - return res +// +// This function will omit any encountered error but just +// print out an error message. +func (t *StateTrie) MustGet(key []byte) []byte { + return t.trie.MustGet(t.hashKey(key)) } // GetStorage attempts to retrieve a storage slot with provided account address @@ -87,14 +85,14 @@ func (t *StateTrie) Get(key []byte) []byte { // If the specified storage slot is not in the trie, nil will be returned. // If a trie node is not found in the database, a MissingNodeError is returned. func (t *StateTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { - return t.trie.TryGet(t.hashKey(key)) + return t.trie.Get(t.hashKey(key)) } // GetAccount attempts to retrieve an account with provided account address. // If the specified account is not in the trie, nil will be returned. // If a trie node is not found in the database, a MissingNodeError is returned. func (t *StateTrie) GetAccount(address common.Address) (*types.StateAccount, error) { - res, err := t.trie.TryGet(t.hashKey(address.Bytes())) + res, err := t.trie.Get(t.hashKey(address.Bytes())) if res == nil || err != nil { return nil, err } @@ -107,7 +105,7 @@ func (t *StateTrie) GetAccount(address common.Address) (*types.StateAccount, err // account hash that is the hash of address. This constitutes an abstraction // leak, since the client code needs to know the key format. func (t *StateTrie) GetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) { - res, err := t.trie.TryGet(addrHash.Bytes()) + res, err := t.trie.Get(addrHash.Bytes()) if res == nil || err != nil { return nil, err } @@ -121,19 +119,22 @@ func (t *StateTrie) GetAccountByHash(addrHash common.Hash) (*types.StateAccount, // If the specified trie node is not in the trie, nil will be returned. // If a trie node is not found in the database, a MissingNodeError is returned. func (t *StateTrie) GetNode(path []byte) ([]byte, int, error) { - return t.trie.TryGetNode(path) + return t.trie.GetNode(path) } -// Update associates key with value in the trie. Subsequent calls to +// MustUpdate associates key with value in the trie. Subsequent calls to // Get will return value. If value has length zero, any existing value // is deleted from the trie and calls to Get will return nil. // // The value bytes must not be modified by the caller while they are // stored in the trie. -func (t *StateTrie) Update(key, value []byte) { - if err := t.UpdateStorage(common.Address{}, key, value); err != nil { - log.Error("Unhandled trie error in StateTrie.Update", "err", err) - } +// +// This function will omit any encountered error but just print out an +// error message. +func (t *StateTrie) MustUpdate(key, value []byte) { + hk := t.hashKey(key) + t.trie.MustUpdate(hk, value) + t.getSecKeyCache()[string(hk)] = common.CopyBytes(key) } // UpdateStorage associates key with value in the trie. Subsequent calls to @@ -146,7 +147,7 @@ func (t *StateTrie) Update(key, value []byte) { // If a node is not found in the database, a MissingNodeError is returned. func (t *StateTrie) UpdateStorage(_ common.Address, key, value []byte) error { hk := t.hashKey(key) - err := t.trie.TryUpdate(hk, value) + err := t.trie.Update(hk, value) if err != nil { return err } @@ -161,18 +162,19 @@ func (t *StateTrie) UpdateAccount(address common.Address, acc *types.StateAccoun if err != nil { return err } - if err := t.trie.TryUpdate(hk, data); err != nil { + if err := t.trie.Update(hk, data); err != nil { return err } t.getSecKeyCache()[string(hk)] = address.Bytes() return nil } -// Delete removes any existing value for key from the trie. -func (t *StateTrie) Delete(key []byte) { - if err := t.DeleteStorage(common.Address{}, key); err != nil { - log.Error("Unhandled trie error in StateTrie.Delete", "err", err) - } +// MustDelete removes any existing value for key from the trie. This function +// will omit any encountered error but just print out an error message. +func (t *StateTrie) MustDelete(key []byte) { + hk := t.hashKey(key) + delete(t.getSecKeyCache(), string(hk)) + t.trie.MustDelete(hk) } // DeleteStorage removes any existing storage slot from the trie. @@ -181,14 +183,14 @@ func (t *StateTrie) Delete(key []byte) { func (t *StateTrie) DeleteStorage(_ common.Address, key []byte) error { hk := t.hashKey(key) delete(t.getSecKeyCache(), string(hk)) - return t.trie.TryDelete(hk) + return t.trie.Delete(hk) } // DeleteAccount abstracts an account deletion from the trie. func (t *StateTrie) DeleteAccount(address common.Address) error { hk := t.hashKey(address.Bytes()) delete(t.getSecKeyCache(), string(hk)) - return t.trie.TryDelete(hk) + return t.trie.Delete(hk) } // GetKey returns the sha3 preimage of a hashed key that was diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go index d3e6c6706..a55c10a60 100644 --- a/trie/secure_trie_test.go +++ b/trie/secure_trie_test.go @@ -45,17 +45,17 @@ func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) { // Map the same data under multiple keys key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i} content[string(key)] = val - trie.Update(key, val) + trie.MustUpdate(key, val) key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i} content[string(key)] = val - trie.Update(key, val) + trie.MustUpdate(key, val) // Add some other data to inflate the trie for j := byte(3); j < 13; j++ { key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i} content[string(key)] = val - trie.Update(key, val) + trie.MustUpdate(key, val) } } root, nodes := trie.Commit(false) @@ -81,9 +81,9 @@ func TestSecureDelete(t *testing.T) { } for _, val := range vals { if val.v != "" { - trie.Update([]byte(val.k), []byte(val.v)) + trie.MustUpdate([]byte(val.k), []byte(val.v)) } else { - trie.Delete([]byte(val.k)) + trie.MustDelete([]byte(val.k)) } } hash := trie.Hash() @@ -95,13 +95,13 @@ func TestSecureDelete(t *testing.T) { func TestSecureGetKey(t *testing.T) { trie := newEmptySecure() - trie.Update([]byte("foo"), []byte("bar")) + trie.MustUpdate([]byte("foo"), []byte("bar")) key := []byte("foo") value := []byte("bar") seckey := crypto.Keccak256(key) - if !bytes.Equal(trie.Get(key), value) { + if !bytes.Equal(trie.MustGet(key), value) { t.Errorf("Get did not return bar") } if k := trie.GetKey(seckey); !bytes.Equal(k, key) { @@ -128,15 +128,15 @@ func TestStateTrieConcurrency(t *testing.T) { for j := byte(0); j < 255; j++ { // Map the same data under multiple keys key, val := common.LeftPadBytes([]byte{byte(index), 1, j}, 32), []byte{j} - tries[index].Update(key, val) + tries[index].MustUpdate(key, val) key, val = common.LeftPadBytes([]byte{byte(index), 2, j}, 32), []byte{j} - tries[index].Update(key, val) + tries[index].MustUpdate(key, val) // Add some other data to inflate the trie for k := byte(3); k < 13; k++ { key, val = common.LeftPadBytes([]byte{byte(index), k, j}, 32), []byte{k, j} - tries[index].Update(key, val) + tries[index].MustUpdate(key, val) } } tries[index].Commit(false) diff --git a/trie/stacktrie.go b/trie/stacktrie.go index 83034e29a..030d2a596 100644 --- a/trie/stacktrie.go +++ b/trie/stacktrie.go @@ -202,8 +202,8 @@ const ( hashedNode ) -// TryUpdate inserts a (key, value) pair into the stack trie -func (st *StackTrie) TryUpdate(key, value []byte) error { +// Update inserts a (key, value) pair into the stack trie. +func (st *StackTrie) Update(key, value []byte) error { k := keybytesToHex(key) if len(value) == 0 { panic("deletion not supported") @@ -212,8 +212,10 @@ func (st *StackTrie) TryUpdate(key, value []byte) error { return nil } -func (st *StackTrie) Update(key, value []byte) { - if err := st.TryUpdate(key, value); err != nil { +// MustUpdate is a wrapper of Update and will omit any encountered error but +// just print out an error message. +func (st *StackTrie) MustUpdate(key, value []byte) { + if err := st.Update(key, value); err != nil { log.Error("Unhandled trie error in StackTrie.Update", "err", err) } } diff --git a/trie/stacktrie_test.go b/trie/stacktrie_test.go index 3e6cc8cd5..ea3eef788 100644 --- a/trie/stacktrie_test.go +++ b/trie/stacktrie_test.go @@ -174,7 +174,7 @@ func TestStackTrieInsertAndHash(t *testing.T) { st.Reset() for j := 0; j < l; j++ { kv := &test[j] - if err := st.TryUpdate(common.FromHex(kv.K), []byte(kv.V)); err != nil { + if err := st.Update(common.FromHex(kv.K), []byte(kv.V)); err != nil { t.Fatal(err) } } @@ -193,8 +193,8 @@ func TestSizeBug(t *testing.T) { leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563") value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3") - nt.TryUpdate(leaf, value) - st.TryUpdate(leaf, value) + nt.Update(leaf, value) + st.Update(leaf, value) if nt.Hash() != st.Hash() { t.Fatalf("error %x != %x", st.Hash(), nt.Hash()) @@ -218,8 +218,8 @@ func TestEmptyBug(t *testing.T) { } for _, kv := range kvs { - nt.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) - st.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) + nt.Update(common.FromHex(kv.K), common.FromHex(kv.V)) + st.Update(common.FromHex(kv.K), common.FromHex(kv.V)) } if nt.Hash() != st.Hash() { @@ -241,8 +241,8 @@ func TestValLength56(t *testing.T) { } for _, kv := range kvs { - nt.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) - st.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) + nt.Update(common.FromHex(kv.K), common.FromHex(kv.V)) + st.Update(common.FromHex(kv.K), common.FromHex(kv.V)) } if nt.Hash() != st.Hash() { @@ -263,8 +263,8 @@ func TestUpdateSmallNodes(t *testing.T) { {"65", "3000"}, // stacktrie.Update } for _, kv := range kvs { - nt.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) - st.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) + nt.Update(common.FromHex(kv.K), common.FromHex(kv.V)) + st.Update(common.FromHex(kv.K), common.FromHex(kv.V)) } if nt.Hash() != st.Hash() { t.Fatalf("error %x != %x", st.Hash(), nt.Hash()) @@ -291,8 +291,8 @@ func TestUpdateVariableKeys(t *testing.T) { {"0x3330353463653239356131303167617430", "313131"}, } for _, kv := range kvs { - nt.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) - st.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) + nt.Update(common.FromHex(kv.K), common.FromHex(kv.V)) + st.Update(common.FromHex(kv.K), common.FromHex(kv.V)) } if nt.Hash() != st.Hash() { t.Fatalf("error %x != %x", st.Hash(), nt.Hash()) @@ -309,7 +309,7 @@ func TestStacktrieNotModifyValues(t *testing.T) { value := make([]byte, 1, 100) value[0] = 0x2 want := common.CopyBytes(value) - st.TryUpdate([]byte{0x01}, value) + st.Update([]byte{0x01}, value) st.Hash() if have := value; !bytes.Equal(have, want) { t.Fatalf("tiny trie: have %#x want %#x", have, want) @@ -330,7 +330,7 @@ func TestStacktrieNotModifyValues(t *testing.T) { for i := 0; i < 1000; i++ { key := common.BigToHash(keyB) value := getValue(i) - st.TryUpdate(key.Bytes(), value) + st.Update(key.Bytes(), value) vals = append(vals, value) keyB = keyB.Add(keyB, keyDelta) keyDelta.Add(keyDelta, common.Big1) @@ -371,7 +371,7 @@ func TestStacktrieSerialization(t *testing.T) { keyDelta.Add(keyDelta, common.Big1) } for i, k := range keys { - nt.TryUpdate(k, common.CopyBytes(vals[i])) + nt.Update(k, common.CopyBytes(vals[i])) } for i, k := range keys { @@ -384,7 +384,7 @@ func TestStacktrieSerialization(t *testing.T) { t.Fatal(err) } st = newSt - st.TryUpdate(k, common.CopyBytes(vals[i])) + st.Update(k, common.CopyBytes(vals[i])) } if have, want := st.Hash(), nt.Hash(); have != want { t.Fatalf("have %#x want %#x", have, want) diff --git a/trie/sync_test.go b/trie/sync_test.go index f404dc1cd..70898604f 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -40,17 +40,17 @@ func makeTestTrie() (*Database, *StateTrie, map[string][]byte) { // Map the same data under multiple keys key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i} content[string(key)] = val - trie.Update(key, val) + trie.MustUpdate(key, val) key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i} content[string(key)] = val - trie.Update(key, val) + trie.MustUpdate(key, val) // Add some other data to inflate the trie for j := byte(3); j < 13; j++ { key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i} content[string(key)] = val - trie.Update(key, val) + trie.MustUpdate(key, val) } } root, nodes := trie.Commit(false) @@ -74,7 +74,7 @@ func checkTrieContents(t *testing.T, db *Database, root []byte, content map[stri t.Fatalf("inconsistent trie at %x: %v", root, err) } for key, val := range content { - if have := trie.Get([]byte(key)); !bytes.Equal(have, val) { + if have := trie.MustGet([]byte(key)); !bytes.Equal(have, val) { t.Errorf("entry %x: content mismatch: have %x, want %x", key, have, val) } } diff --git a/trie/tracer_test.go b/trie/tracer_test.go index 5e627c89c..1b9f44108 100644 --- a/trie/tracer_test.go +++ b/trie/tracer_test.go @@ -64,7 +64,7 @@ func testTrieTracer(t *testing.T, vals []struct{ k, v string }) { // Determine all new nodes are tracked for _, val := range vals { - trie.Update([]byte(val.k), []byte(val.v)) + trie.MustUpdate([]byte(val.k), []byte(val.v)) } insertSet := copySet(trie.tracer.inserts) // copy before commit deleteSet := copySet(trie.tracer.deletes) // copy before commit @@ -82,7 +82,7 @@ func testTrieTracer(t *testing.T, vals []struct{ k, v string }) { // Determine all deletions are tracked trie, _ = New(TrieID(root), db) for _, val := range vals { - trie.Delete([]byte(val.k)) + trie.MustDelete([]byte(val.k)) } insertSet, deleteSet = copySet(trie.tracer.inserts), copySet(trie.tracer.deletes) if !compareSet(insertSet, nil) { @@ -104,10 +104,10 @@ func TestTrieTracerNoop(t *testing.T) { func testTrieTracerNoop(t *testing.T, vals []struct{ k, v string }) { trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) for _, val := range vals { - trie.Update([]byte(val.k), []byte(val.v)) + trie.MustUpdate([]byte(val.k), []byte(val.v)) } for _, val := range vals { - trie.Delete([]byte(val.k)) + trie.MustDelete([]byte(val.k)) } if len(trie.tracer.inserts) != 0 { t.Fatal("Unexpected insertion set") @@ -132,7 +132,7 @@ func testAccessList(t *testing.T, vals []struct{ k, v string }) { ) // Create trie from scratch for _, val := range vals { - trie.Update([]byte(val.k), []byte(val.v)) + trie.MustUpdate([]byte(val.k), []byte(val.v)) } root, nodes := trie.Commit(false) db.Update(NewWithNodeSet(nodes)) @@ -146,7 +146,7 @@ func testAccessList(t *testing.T, vals []struct{ k, v string }) { trie, _ = New(TrieID(root), db) orig = trie.Copy() for _, val := range vals { - trie.Update([]byte(val.k), randBytes(32)) + trie.MustUpdate([]byte(val.k), randBytes(32)) } root, nodes = trie.Commit(false) db.Update(NewWithNodeSet(nodes)) @@ -163,7 +163,7 @@ func testAccessList(t *testing.T, vals []struct{ k, v string }) { for i := 0; i < 30; i++ { key := randBytes(32) keys = append(keys, string(key)) - trie.Update(key, randBytes(32)) + trie.MustUpdate(key, randBytes(32)) } root, nodes = trie.Commit(false) db.Update(NewWithNodeSet(nodes)) @@ -177,7 +177,7 @@ func testAccessList(t *testing.T, vals []struct{ k, v string }) { trie, _ = New(TrieID(root), db) orig = trie.Copy() for _, key := range keys { - trie.Update([]byte(key), nil) + trie.MustUpdate([]byte(key), nil) } root, nodes = trie.Commit(false) db.Update(NewWithNodeSet(nodes)) @@ -191,7 +191,7 @@ func testAccessList(t *testing.T, vals []struct{ k, v string }) { trie, _ = New(TrieID(root), db) orig = trie.Copy() for _, val := range vals { - trie.Update([]byte(val.k), nil) + trie.MustUpdate([]byte(val.k), nil) } root, nodes = trie.Commit(false) db.Update(NewWithNodeSet(nodes)) @@ -210,7 +210,7 @@ func TestAccessListLeak(t *testing.T) { ) // Create trie from scratch for _, val := range standard { - trie.Update([]byte(val.k), []byte(val.v)) + trie.MustUpdate([]byte(val.k), []byte(val.v)) } root, nodes := trie.Commit(false) db.Update(NewWithNodeSet(nodes)) @@ -260,7 +260,7 @@ func TestTinyTree(t *testing.T) { trie = NewEmpty(db) ) for _, val := range tiny { - trie.Update([]byte(val.k), randBytes(32)) + trie.MustUpdate([]byte(val.k), randBytes(32)) } root, set := trie.Commit(false) db.Update(NewWithNodeSet(set)) @@ -268,7 +268,7 @@ func TestTinyTree(t *testing.T) { trie, _ = New(TrieID(root), db) orig := trie.Copy() for _, val := range tiny { - trie.Update([]byte(val.k), []byte(val.v)) + trie.MustUpdate([]byte(val.k), []byte(val.v)) } root, set = trie.Commit(false) db.Update(NewWithNodeSet(set)) diff --git a/trie/trie.go b/trie/trie.go index 17bacba00..4b6f6a55b 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -105,28 +105,30 @@ func (t *Trie) NodeIterator(start []byte) NodeIterator { return newNodeIterator(t, start) } -// Get returns the value for key stored in the trie. -// The value bytes must not be modified by the caller. -func (t *Trie) Get(key []byte) []byte { - res, err := t.TryGet(key) +// MustGet is a wrapper of Get and will omit any encountered error but just +// print out an error message. +func (t *Trie) MustGet(key []byte) []byte { + res, err := t.Get(key) if err != nil { log.Error("Unhandled trie error in Trie.Get", "err", err) } return res } -// TryGet returns the value for key stored in the trie. +// Get returns the value for key stored in the trie. // The value bytes must not be modified by the caller. -// If a node was not found in the database, a MissingNodeError is returned. -func (t *Trie) TryGet(key []byte) ([]byte, error) { - value, newroot, didResolve, err := t.tryGet(t.root, keybytesToHex(key), 0) +// +// If the requested node is not present in trie, no error will be returned. +// If the trie is corrupted, a MissingNodeError is returned. +func (t *Trie) Get(key []byte) ([]byte, error) { + value, newroot, didResolve, err := t.get(t.root, keybytesToHex(key), 0) if err == nil && didResolve { t.root = newroot } return value, err } -func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode node, didResolve bool, err error) { +func (t *Trie) get(origNode node, key []byte, pos int) (value []byte, newnode node, didResolve bool, err error) { switch n := (origNode).(type) { case nil: return nil, nil, false, nil @@ -137,14 +139,14 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode // key not found in trie return nil, n, false, nil } - value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key)) + value, newnode, didResolve, err = t.get(n.Val, key, pos+len(n.Key)) if err == nil && didResolve { n = n.copy() n.Val = newnode } return value, n, didResolve, err case *fullNode: - value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1) + value, newnode, didResolve, err = t.get(n.Children[key[pos]], key, pos+1) if err == nil && didResolve { n = n.copy() n.Children[key[pos]] = newnode @@ -155,17 +157,30 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode if err != nil { return nil, n, true, err } - value, newnode, _, err := t.tryGet(child, key, pos) + value, newnode, _, err := t.get(child, key, pos) return value, newnode, true, err default: panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) } } -// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not -// possible to use keybyte-encoding as the path might contain odd nibbles. -func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) { - item, newroot, resolved, err := t.tryGetNode(t.root, compactToHex(path), 0) +// MustGetNode is a wrapper of GetNode and will omit any encountered error but +// just print out an error message. +func (t *Trie) MustGetNode(path []byte) ([]byte, int) { + item, resolved, err := t.GetNode(path) + if err != nil { + log.Error("Unhandled trie error in Trie.GetNode", "err", err) + } + return item, resolved +} + +// GetNode retrieves a trie node by compact-encoded path. It is not possible +// to use keybyte-encoding as the path might contain odd nibbles. +// +// If the requested node is not present in trie, no error will be returned. +// If the trie is corrupted, a MissingNodeError is returned. +func (t *Trie) GetNode(path []byte) ([]byte, int, error) { + item, newroot, resolved, err := t.getNode(t.root, compactToHex(path), 0) if err != nil { return nil, resolved, err } @@ -175,10 +190,10 @@ func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) { if item == nil { return nil, resolved, nil } - return item, resolved, err + return item, resolved, nil } -func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, newnode node, resolved int, err error) { +func (t *Trie) getNode(origNode node, path []byte, pos int) (item []byte, newnode node, resolved int, err error) { // If non-existent path requested, abort if origNode == nil { return nil, nil, 0, nil @@ -211,7 +226,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new // Path branches off from short node return nil, n, 0, nil } - item, newnode, resolved, err = t.tryGetNode(n.Val, path, pos+len(n.Key)) + item, newnode, resolved, err = t.getNode(n.Val, path, pos+len(n.Key)) if err == nil && resolved > 0 { n = n.copy() n.Val = newnode @@ -219,7 +234,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new return item, n, resolved, err case *fullNode: - item, newnode, resolved, err = t.tryGetNode(n.Children[path[pos]], path, pos+1) + item, newnode, resolved, err = t.getNode(n.Children[path[pos]], path, pos+1) if err == nil && resolved > 0 { n = n.copy() n.Children[path[pos]] = newnode @@ -231,7 +246,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new if err != nil { return nil, n, 1, err } - item, newnode, resolved, err := t.tryGetNode(child, path, pos) + item, newnode, resolved, err := t.getNode(child, path, pos) return item, newnode, resolved + 1, err default: @@ -239,33 +254,28 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new } } +// MustUpdate is a wrapper of Update and will omit any encountered error but +// just print out an error message. +func (t *Trie) MustUpdate(key, value []byte) { + if err := t.Update(key, value); err != nil { + log.Error("Unhandled trie error in Trie.Update", "err", err) + } +} + // Update associates key with value in the trie. Subsequent calls to // Get will return value. If value has length zero, any existing value // is deleted from the trie and calls to Get will return nil. // // The value bytes must not be modified by the caller while they are // stored in the trie. -func (t *Trie) Update(key, value []byte) { - if err := t.TryUpdate(key, value); err != nil { - log.Error("Unhandled trie error in Trie.Update", "err", err) - } +// +// If the requested node is not present in trie, no error will be returned. +// If the trie is corrupted, a MissingNodeError is returned. +func (t *Trie) Update(key, value []byte) error { + return t.update(key, value) } -// TryUpdate associates key with value in the trie. Subsequent calls to -// Get will return value. If value has length zero, any existing value -// is deleted from the trie and calls to Get will return nil. -// -// The value bytes must not be modified by the caller while they are -// stored in the trie. -// -// If a node was not found in the database, a MissingNodeError is returned. -func (t *Trie) TryUpdate(key, value []byte) error { - return t.tryUpdate(key, value) -} - -// tryUpdate expects an RLP-encoded value and performs the core function -// for TryUpdate and TryUpdateAccount. -func (t *Trie) tryUpdate(key, value []byte) error { +func (t *Trie) update(key, value []byte) error { t.unhashed++ k := keybytesToHex(key) if len(value) != 0 { @@ -363,16 +373,19 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error } } -// Delete removes any existing value for key from the trie. -func (t *Trie) Delete(key []byte) { - if err := t.TryDelete(key); err != nil { +// MustDelete is a wrapper of Delete and will omit any encountered error but +// just print out an error message. +func (t *Trie) MustDelete(key []byte) { + if err := t.Delete(key); err != nil { log.Error("Unhandled trie error in Trie.Delete", "err", err) } } -// TryDelete removes any existing value for key from the trie. -// If a node was not found in the database, a MissingNodeError is returned. -func (t *Trie) TryDelete(key []byte) error { +// Delete removes any existing value for key from the trie. +// +// If the requested node is not present in trie, no error will be returned. +// If the trie is corrupted, a MissingNodeError is returned. +func (t *Trie) Delete(key []byte) error { t.unhashed++ k := keybytesToHex(key) _, n, err := t.delete(t.root, nil, k) diff --git a/trie/trie_test.go b/trie/trie_test.go index 089bb44a9..82ead8b44 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -56,8 +56,8 @@ func TestNull(t *testing.T) { trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) key := make([]byte, 32) value := []byte("test") - trie.Update(key, value) - if !bytes.Equal(trie.Get(key), value) { + trie.MustUpdate(key, value) + if !bytes.Equal(trie.MustGet(key), value) { t.Fatal("wrong value") } } @@ -90,27 +90,27 @@ func testMissingNode(t *testing.T, memonly bool) { } trie, _ = New(TrieID(root), triedb) - _, err := trie.TryGet([]byte("120000")) + _, err := trie.Get([]byte("120000")) if err != nil { t.Errorf("Unexpected error: %v", err) } trie, _ = New(TrieID(root), triedb) - _, err = trie.TryGet([]byte("120099")) + _, err = trie.Get([]byte("120099")) if err != nil { t.Errorf("Unexpected error: %v", err) } trie, _ = New(TrieID(root), triedb) - _, err = trie.TryGet([]byte("123456")) + _, err = trie.Get([]byte("123456")) if err != nil { t.Errorf("Unexpected error: %v", err) } trie, _ = New(TrieID(root), triedb) - err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv")) + err = trie.Update([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv")) if err != nil { t.Errorf("Unexpected error: %v", err) } trie, _ = New(TrieID(root), triedb) - err = trie.TryDelete([]byte("123456")) + err = trie.Delete([]byte("123456")) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -123,27 +123,27 @@ func testMissingNode(t *testing.T, memonly bool) { } trie, _ = New(TrieID(root), triedb) - _, err = trie.TryGet([]byte("120000")) + _, err = trie.Get([]byte("120000")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) } trie, _ = New(TrieID(root), triedb) - _, err = trie.TryGet([]byte("120099")) + _, err = trie.Get([]byte("120099")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) } trie, _ = New(TrieID(root), triedb) - _, err = trie.TryGet([]byte("123456")) + _, err = trie.Get([]byte("123456")) if err != nil { t.Errorf("Unexpected error: %v", err) } trie, _ = New(TrieID(root), triedb) - err = trie.TryUpdate([]byte("120099"), []byte("zxcv")) + err = trie.Update([]byte("120099"), []byte("zxcv")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) } trie, _ = New(TrieID(root), triedb) - err = trie.TryDelete([]byte("123456")) + err = trie.Delete([]byte("123456")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) } @@ -311,8 +311,8 @@ func TestReplication(t *testing.T) { func TestLargeValue(t *testing.T) { trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) - trie.Update([]byte("key1"), []byte{99, 99, 99, 99}) - trie.Update([]byte("key2"), bytes.Repeat([]byte{1}, 32)) + trie.MustUpdate([]byte("key1"), []byte{99, 99, 99, 99}) + trie.MustUpdate([]byte("key2"), bytes.Repeat([]byte{1}, 32)) trie.Hash() } @@ -460,13 +460,13 @@ func runRandTest(rt randTest) bool { switch step.op { case opUpdate: - tr.Update(step.key, step.value) + tr.MustUpdate(step.key, step.value) values[string(step.key)] = string(step.value) case opDelete: - tr.Delete(step.key) + tr.MustDelete(step.key) delete(values, string(step.key)) case opGet: - v := tr.Get(step.key) + v := tr.MustGet(step.key) want := values[string(step.key)] if string(v) != want { rt[i].err = fmt.Errorf("mismatch for key %#x, got %#x want %#x", step.key, v, want) @@ -509,7 +509,7 @@ func runRandTest(rt randTest) bool { checktr := NewEmpty(triedb) it := NewIterator(tr.NodeIterator(nil)) for it.Next() { - checktr.Update(it.Key, it.Value) + checktr.MustUpdate(it.Key, it.Value) } if tr.Hash() != checktr.Hash() { rt[i].err = fmt.Errorf("hash mismatch in opItercheckhash") @@ -595,13 +595,13 @@ func benchGet(b *testing.B) { k := make([]byte, 32) for i := 0; i < benchElemCount; i++ { binary.LittleEndian.PutUint64(k, uint64(i)) - trie.Update(k, k) + trie.MustUpdate(k, k) } binary.LittleEndian.PutUint64(k, benchElemCount/2) b.ResetTimer() for i := 0; i < b.N; i++ { - trie.Get(k) + trie.MustGet(k) } b.StopTimer() } @@ -612,7 +612,7 @@ func benchUpdate(b *testing.B, e binary.ByteOrder) *Trie { b.ReportAllocs() for i := 0; i < b.N; i++ { e.PutUint64(k, uint64(i)) - trie.Update(k, k) + trie.MustUpdate(k, k) } return trie } @@ -640,11 +640,11 @@ func BenchmarkHash(b *testing.B) { trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) i := 0 for ; i < len(addresses)/2; i++ { - trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) + trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i]) } trie.Hash() for ; i < len(addresses); i++ { - trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) + trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i]) } b.ResetTimer() b.ReportAllocs() @@ -670,7 +670,7 @@ func benchmarkCommitAfterHash(b *testing.B, collectLeaf bool) { addresses, accounts := makeAccounts(b.N) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) for i := 0; i < len(addresses); i++ { - trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) + trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i]) } // Insert the accounts into the trie and hash it trie.Hash() @@ -683,22 +683,22 @@ func TestTinyTrie(t *testing.T) { // Create a realistic account trie to hash _, accounts := makeAccounts(5) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) - trie.Update(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001337"), accounts[3]) + trie.MustUpdate(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001337"), accounts[3]) if exp, root := common.HexToHash("8c6a85a4d9fda98feff88450299e574e5378e32391f75a055d470ac0653f1005"), trie.Hash(); exp != root { t.Errorf("1: got %x, exp %x", root, exp) } - trie.Update(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001338"), accounts[4]) + trie.MustUpdate(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001338"), accounts[4]) if exp, root := common.HexToHash("ec63b967e98a5720e7f720482151963982890d82c9093c0d486b7eb8883a66b1"), trie.Hash(); exp != root { t.Errorf("2: got %x, exp %x", root, exp) } - trie.Update(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001339"), accounts[4]) + trie.MustUpdate(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001339"), accounts[4]) if exp, root := common.HexToHash("0608c1d1dc3905fa22204c7a0e43644831c3b6d3def0f274be623a948197e64a"), trie.Hash(); exp != root { t.Errorf("3: got %x, exp %x", root, exp) } checktr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) it := NewIterator(trie.NodeIterator(nil)) for it.Next() { - checktr.Update(it.Key, it.Value) + checktr.MustUpdate(it.Key, it.Value) } if troot, itroot := trie.Hash(), checktr.Hash(); troot != itroot { t.Fatalf("hash mismatch in opItercheckhash, trie: %x, check: %x", troot, itroot) @@ -710,7 +710,7 @@ func TestCommitAfterHash(t *testing.T) { addresses, accounts := makeAccounts(1000) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) for i := 0; i < len(addresses); i++ { - trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) + trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i]) } // Insert the accounts into the trie and hash it trie.Hash() @@ -820,7 +820,7 @@ func TestCommitSequence(t *testing.T) { trie := NewEmpty(db) // Fill the trie with elements for i := 0; i < tc.count; i++ { - trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) + trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i]) } // Flush trie -> database root, nodes := trie.Commit(false) @@ -861,7 +861,7 @@ func TestCommitSequenceRandomBlobs(t *testing.T) { } prng.Read(key) prng.Read(val) - trie.Update(key, val) + trie.MustUpdate(key, val) } // Flush trie -> database root, nodes := trie.Commit(false) @@ -899,8 +899,8 @@ func TestCommitSequenceStackTrie(t *testing.T) { val = make([]byte, 1+prng.Intn(1024)) } prng.Read(val) - trie.TryUpdate(key, val) - stTrie.TryUpdate(key, val) + trie.Update(key, val) + stTrie.Update(key, val) } // Flush trie -> database root, nodes := trie.Commit(false) @@ -948,8 +948,8 @@ func TestCommitSequenceSmallRoot(t *testing.T) { // Add a single small-element to the trie(s) key := make([]byte, 5) key[0] = 1 - trie.TryUpdate(key, []byte{0x1}) - stTrie.TryUpdate(key, []byte{0x1}) + trie.Update(key, []byte{0x1}) + stTrie.Update(key, []byte{0x1}) // Flush trie -> database root, nodes := trie.Commit(false) // Flush memdb -> disk (sponge) @@ -1017,7 +1017,7 @@ func benchmarkHashFixedSize(b *testing.B, addresses [][20]byte, accounts [][]byt b.ReportAllocs() trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) for i := 0; i < len(addresses); i++ { - trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) + trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i]) } // Insert the accounts into the trie and hash it b.StartTimer() @@ -1068,7 +1068,7 @@ func benchmarkCommitAfterHashFixedSize(b *testing.B, addresses [][20]byte, accou b.ReportAllocs() trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) for i := 0; i < len(addresses); i++ { - trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) + trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i]) } // Insert the accounts into the trie and hash it trie.Hash() @@ -1121,7 +1121,7 @@ func benchmarkDerefRootFixedSize(b *testing.B, addresses [][20]byte, accounts [] triedb := NewDatabase(rawdb.NewMemoryDatabase()) trie := NewEmpty(triedb) for i := 0; i < len(addresses); i++ { - trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) + trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i]) } h := trie.Hash() _, nodes := trie.Commit(false) @@ -1132,15 +1132,15 @@ func benchmarkDerefRootFixedSize(b *testing.B, addresses [][20]byte, accounts [] } func getString(trie *Trie, k string) []byte { - return trie.Get([]byte(k)) + return trie.MustGet([]byte(k)) } func updateString(trie *Trie, k, v string) { - trie.Update([]byte(k), []byte(v)) + trie.MustUpdate([]byte(k), []byte(v)) } func deleteString(trie *Trie, k string) { - trie.Delete([]byte(k)) + trie.MustDelete([]byte(k)) } func TestDecodeNode(t *testing.T) {