From 8bbb16b70ef4a9e732745aac104525a6f16306d6 Mon Sep 17 00:00:00 2001 From: Guillaume Ballet <3272758+gballet@users.noreply.github.com> Date: Wed, 28 Jun 2023 11:11:02 +0200 Subject: [PATCH] core/state, light, les: make signature of ContractCode hash-independent (#27209) * core/state, light, les: make signature of ContractCode hash-independent * push current state for feedback * les: fix unit test * core, les, light: fix les unittests * core/state, trie, les, light: fix state iterator * core, les: address comments * les: fix lint --------- Co-authored-by: Gary Rong --- core/blockchain_reader.go | 6 ++- core/state/database.go | 19 ++++---- core/state/iterator.go | 17 +++++-- core/state/state_object.go | 6 +-- core/state/sync_test.go | 83 +++++++++++++++++---------------- core/state/trie_prefetcher.go | 2 +- les/benchmark.go | 2 +- les/handler_test.go | 8 ++-- les/odr_requests.go | 20 ++++---- les/request_test.go | 3 +- les/server_handler.go | 15 +++--- les/server_requests.go | 21 +++++---- light/odr.go | 32 ++++++------- light/odr_test.go | 4 +- light/trie.go | 20 ++++---- tests/fuzzers/les/les-fuzzer.go | 41 ++++++++-------- trie/database_wrap.go | 7 ++- 17 files changed, 164 insertions(+), 142 deletions(-) diff --git a/core/blockchain_reader.go b/core/blockchain_reader.go index fd65cb2db..f9d0e3531 100644 --- a/core/blockchain_reader.go +++ b/core/blockchain_reader.go @@ -306,9 +306,11 @@ func (bc *BlockChain) TrieNode(hash common.Hash) ([]byte, error) { // new code scheme. func (bc *BlockChain) ContractCodeWithPrefix(hash common.Hash) ([]byte, error) { type codeReader interface { - ContractCodeWithPrefix(addrHash, codeHash common.Hash) ([]byte, error) + ContractCodeWithPrefix(address common.Address, codeHash common.Hash) ([]byte, error) } - return bc.stateCache.(codeReader).ContractCodeWithPrefix(common.Hash{}, hash) + // TODO(rjl493456442) The associated account address is also required + // in Verkle scheme. Fix it once snap-sync is supported for Verkle. + return bc.stateCache.(codeReader).ContractCodeWithPrefix(common.Address{}, hash) } // State returns a new mutable state based on the current HEAD block. diff --git a/core/state/database.go b/core/state/database.go index 5ca236636..a3b6322ae 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -24,6 +24,7 @@ import ( "github.com/ethereum/go-ethereum/common/lru" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie/trienode" @@ -43,16 +44,16 @@ type Database interface { OpenTrie(root common.Hash) (Trie, error) // OpenStorageTrie opens the storage trie of an account. - OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) + OpenStorageTrie(stateRoot common.Hash, address common.Address, root common.Hash) (Trie, error) // CopyTrie returns an independent copy of the given trie. CopyTrie(Trie) Trie // ContractCode retrieves a particular contract's code. - ContractCode(addrHash, codeHash common.Hash) ([]byte, error) + ContractCode(addr common.Address, codeHash common.Hash) ([]byte, error) // ContractCodeSize retrieves a particular contracts code's size. - ContractCodeSize(addrHash, codeHash common.Hash) (int, error) + ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error) // DiskDB returns the underlying key-value disk database. DiskDB() ethdb.KeyValueStore @@ -177,8 +178,8 @@ func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { } // OpenStorageTrie opens the storage trie of an account. -func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) { - tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, addrHash, root), db.triedb) +func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, address common.Address, root common.Hash) (Trie, error) { + tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, crypto.Keccak256Hash(address.Bytes()), root), db.triedb) if err != nil { return nil, err } @@ -196,7 +197,7 @@ func (db *cachingDB) CopyTrie(t Trie) Trie { } // ContractCode retrieves a particular contract's code. -func (db *cachingDB) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) { +func (db *cachingDB) ContractCode(address common.Address, codeHash common.Hash) ([]byte, error) { code, _ := db.codeCache.Get(codeHash) if len(code) > 0 { return code, nil @@ -213,7 +214,7 @@ func (db *cachingDB) ContractCode(addrHash, codeHash common.Hash) ([]byte, error // ContractCodeWithPrefix retrieves a particular contract's code. If the // code can't be found in the cache, then check the existence with **new** // db scheme. -func (db *cachingDB) ContractCodeWithPrefix(addrHash, codeHash common.Hash) ([]byte, error) { +func (db *cachingDB) ContractCodeWithPrefix(address common.Address, codeHash common.Hash) ([]byte, error) { code, _ := db.codeCache.Get(codeHash) if len(code) > 0 { return code, nil @@ -228,11 +229,11 @@ func (db *cachingDB) ContractCodeWithPrefix(addrHash, codeHash common.Hash) ([]b } // ContractCodeSize retrieves a particular contracts code's size. -func (db *cachingDB) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) { +func (db *cachingDB) ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error) { if cached, ok := db.codeSizeCache.Get(codeHash); ok { return cached, nil } - code, err := db.ContractCode(addrHash, codeHash) + code, err := db.ContractCode(addr, codeHash) return len(code), err } diff --git a/core/state/iterator.go b/core/state/iterator.go index 7802263b7..bb9af0820 100644 --- a/core/state/iterator.go +++ b/core/state/iterator.go @@ -18,6 +18,7 @@ package state import ( "bytes" + "errors" "fmt" "github.com/ethereum/go-ethereum/common" @@ -27,7 +28,8 @@ import ( ) // nodeIterator is an iterator to traverse the entire state trie post-order, -// including all of the contract code and contract state tries. +// including all of the contract code and contract state tries. Preimage is +// required in order to resolve the contract address. type nodeIterator struct { state *StateDB // State being iterated @@ -113,7 +115,15 @@ func (it *nodeIterator) step() error { if err := rlp.Decode(bytes.NewReader(it.stateIt.LeafBlob()), &account); err != nil { return err } - dataTrie, err := it.state.db.OpenStorageTrie(it.state.originalRoot, common.BytesToHash(it.stateIt.LeafKey()), account.Root) + // Lookup the preimage of account hash + preimage := it.state.trie.GetKey(it.stateIt.LeafKey()) + if preimage == nil { + return errors.New("account address is not available") + } + address := common.BytesToAddress(preimage) + + // Traverse the storage slots belong to the account + dataTrie, err := it.state.db.OpenStorageTrie(it.state.originalRoot, address, account.Root) if err != nil { return err } @@ -126,8 +136,7 @@ func (it *nodeIterator) step() error { } if !bytes.Equal(account.CodeHash, types.EmptyCodeHash.Bytes()) { it.codeHash = common.BytesToHash(account.CodeHash) - addrHash := common.BytesToHash(it.stateIt.LeafKey()) - it.code, err = it.state.db.ContractCode(addrHash, common.BytesToHash(account.CodeHash)) + it.code, err = it.state.db.ContractCode(address, common.BytesToHash(account.CodeHash)) if err != nil { return fmt.Errorf("code %x: %v", account.CodeHash, err) } diff --git a/core/state/state_object.go b/core/state/state_object.go index 148136fc2..b443d12f1 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -142,7 +142,7 @@ func (s *stateObject) getTrie(db Database) (Trie, error) { s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root) } if s.trie == nil { - tr, err := db.OpenStorageTrie(s.db.originalRoot, s.addrHash, s.data.Root) + tr, err := db.OpenStorageTrie(s.db.originalRoot, s.address, s.data.Root) if err != nil { return nil, err } @@ -441,7 +441,7 @@ func (s *stateObject) Code(db Database) []byte { if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) { return nil } - code, err := db.ContractCode(s.addrHash, common.BytesToHash(s.CodeHash())) + code, err := db.ContractCode(s.address, common.BytesToHash(s.CodeHash())) if err != nil { s.db.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err)) } @@ -459,7 +459,7 @@ func (s *stateObject) CodeSize(db Database) int { if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) { return 0 } - size, err := db.ContractCodeSize(s.addrHash, common.BytesToHash(s.CodeHash())) + size, err := db.ContractCodeSize(s.address, common.BytesToHash(s.CodeHash())) if err != nil { s.db.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err)) } diff --git a/core/state/sync_test.go b/core/state/sync_test.go index cac0ee6e4..7de69e0aa 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -42,7 +42,7 @@ type testAccount struct { func makeTestState() (ethdb.Database, Database, common.Hash, []*testAccount) { // Create an empty state db := rawdb.NewMemoryDatabase() - sdb := NewDatabase(db) + sdb := NewDatabaseWithConfig(db, &trie.Config{Preimages: true}) state, _ := New(types.EmptyRootHash, sdb, nil) // Fill it with some arbitrary data @@ -100,28 +100,9 @@ func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accou } } -// checkTrieConsistency checks that all nodes in a (sub-)trie are indeed present. -func checkTrieConsistency(db ethdb.Database, root common.Hash) error { - if v, _ := db.Get(root[:]); v == nil { - return nil // Consider a non existent state consistent. - } - trie, err := trie.New(trie.StateTrieID(root), trie.NewDatabase(db)) - if err != nil { - return err - } - it := trie.MustNodeIterator(nil) - for it.Next(true) { - } - return it.Error() -} - // checkStateConsistency checks that all data of a state root is present. func checkStateConsistency(db ethdb.Database, root common.Hash) error { - // Create and iterate a state trie rooted in a sub-node - if _, err := db.Get(root.Bytes()); err != nil { - return nil // Consider a non existent state consistent. - } - state, err := New(root, NewDatabase(db), nil) + state, err := New(root, NewDatabaseWithConfig(db, &trie.Config{Preimages: true}), nil) if err != nil { return err } @@ -171,7 +152,7 @@ type stateElement struct { func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) { // Create a random state to copy - _, srcDb, srcRoot, srcAccounts := makeTestState() + srcDisk, srcDb, srcRoot, srcAccounts := makeTestState() if commit { srcDb.TrieDB().Commit(srcRoot, false) } @@ -204,7 +185,7 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) { codeResults = make([]trie.CodeSyncResult, len(codeElements)) ) for i, element := range codeElements { - data, err := srcDb.ContractCode(common.Hash{}, element.code) + data, err := srcDb.ContractCode(common.Address{}, element.code) if err != nil { t.Fatalf("failed to retrieve contract bytecode for hash %x", element.code) } @@ -274,6 +255,10 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) { }) } } + // Copy the preimages from source db in order to traverse the state. + srcDb.TrieDB().WritePreimages() + copyPreimages(srcDisk, dstDb) + // Cross check that the two states are in sync checkStateAccounts(t, dstDb, srcRoot, srcAccounts) } @@ -282,7 +267,7 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) { // partial results are returned, and the others sent only later. func TestIterativeDelayedStateSync(t *testing.T) { // Create a random state to copy - _, srcDb, srcRoot, srcAccounts := makeTestState() + srcDisk, srcDb, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb := rawdb.NewMemoryDatabase() @@ -312,7 +297,7 @@ func TestIterativeDelayedStateSync(t *testing.T) { if len(codeElements) > 0 { codeResults := make([]trie.CodeSyncResult, len(codeElements)/2+1) for i, element := range codeElements[:len(codeResults)] { - data, err := srcDb.ContractCode(common.Hash{}, element.code) + data, err := srcDb.ContractCode(common.Address{}, element.code) if err != nil { t.Fatalf("failed to retrieve contract bytecode for %x", element.code) } @@ -363,6 +348,10 @@ func TestIterativeDelayedStateSync(t *testing.T) { }) } } + // Copy the preimages from source db in order to traverse the state. + srcDb.TrieDB().WritePreimages() + copyPreimages(srcDisk, dstDb) + // Cross check that the two states are in sync checkStateAccounts(t, dstDb, srcRoot, srcAccounts) } @@ -375,7 +364,7 @@ func TestIterativeRandomStateSyncBatched(t *testing.T) { testIterativeRandomS func testIterativeRandomStateSync(t *testing.T, count int) { // Create a random state to copy - _, srcDb, srcRoot, srcAccounts := makeTestState() + srcDisk, srcDb, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb := rawdb.NewMemoryDatabase() @@ -399,7 +388,7 @@ func testIterativeRandomStateSync(t *testing.T, count int) { if len(codeQueue) > 0 { results := make([]trie.CodeSyncResult, 0, len(codeQueue)) for hash := range codeQueue { - data, err := srcDb.ContractCode(common.Hash{}, hash) + data, err := srcDb.ContractCode(common.Address{}, hash) if err != nil { t.Fatalf("failed to retrieve node data for %x", hash) } @@ -447,6 +436,10 @@ func testIterativeRandomStateSync(t *testing.T, count int) { codeQueue[hash] = struct{}{} } } + // Copy the preimages from source db in order to traverse the state. + srcDb.TrieDB().WritePreimages() + copyPreimages(srcDisk, dstDb) + // Cross check that the two states are in sync checkStateAccounts(t, dstDb, srcRoot, srcAccounts) } @@ -455,7 +448,7 @@ func testIterativeRandomStateSync(t *testing.T, count int) { // partial results are returned (Even those randomly), others sent only later. func TestIterativeRandomDelayedStateSync(t *testing.T) { // Create a random state to copy - _, srcDb, srcRoot, srcAccounts := makeTestState() + srcDisk, srcDb, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb := rawdb.NewMemoryDatabase() @@ -481,7 +474,7 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) { for hash := range codeQueue { delete(codeQueue, hash) - data, err := srcDb.ContractCode(common.Hash{}, hash) + data, err := srcDb.ContractCode(common.Address{}, hash) if err != nil { t.Fatalf("failed to retrieve node data for %x", hash) } @@ -537,6 +530,10 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) { codeQueue[hash] = struct{}{} } } + // Copy the preimages from source db in order to traverse the state. + srcDb.TrieDB().WritePreimages() + copyPreimages(srcDisk, dstDb) + // Cross check that the two states are in sync checkStateAccounts(t, dstDb, srcRoot, srcAccounts) } @@ -555,7 +552,6 @@ func TestIncompleteStateSync(t *testing.T) { } } isCode[types.EmptyCodeHash] = struct{}{} - checkTrieConsistency(db, srcRoot) // Create a destination state and sync with the scheduler dstDb := rawdb.NewMemoryDatabase() @@ -588,7 +584,7 @@ func TestIncompleteStateSync(t *testing.T) { if len(codeQueue) > 0 { results := make([]trie.CodeSyncResult, 0, len(codeQueue)) for hash := range codeQueue { - data, err := srcDb.ContractCode(common.Hash{}, hash) + data, err := srcDb.ContractCode(common.Address{}, hash) if err != nil { t.Fatalf("failed to retrieve node data for %x", hash) } @@ -602,7 +598,6 @@ func TestIncompleteStateSync(t *testing.T) { } } } - var nodehashes []common.Hash if len(nodeQueue) > 0 { results := make([]trie.NodeSyncResult, 0, len(nodeQueue)) for path, element := range nodeQueue { @@ -617,7 +612,6 @@ func TestIncompleteStateSync(t *testing.T) { addedPaths = append(addedPaths, element.path) addedHashes = append(addedHashes, element.hash) } - nodehashes = append(nodehashes, element.hash) } // Process each of the state nodes for _, result := range results { @@ -632,13 +626,6 @@ func TestIncompleteStateSync(t *testing.T) { } batch.Write() - for _, root := range nodehashes { - // Can't use checkStateConsistency here because subtrie keys may have odd - // length and crash in LeafKey. - if err := checkTrieConsistency(dstDb, root); err != nil { - t.Fatalf("state inconsistent: %v", err) - } - } // Fetch the next batch to retrieve nodeQueue = make(map[string]stateElement) codeQueue = make(map[common.Hash]struct{}) @@ -654,6 +641,10 @@ func TestIncompleteStateSync(t *testing.T) { codeQueue[hash] = struct{}{} } } + // Copy the preimages from source db in order to traverse the state. + srcDb.TrieDB().WritePreimages() + copyPreimages(db, dstDb) + // Sanity check that removing any node from the database is detected for _, node := range addedCodes { val := rawdb.ReadCode(dstDb, node) @@ -678,3 +669,15 @@ func TestIncompleteStateSync(t *testing.T) { rawdb.WriteTrieNode(dstDb, owner, inner, hash, val, scheme) } } + +func copyPreimages(srcDb, dstDb ethdb.Database) { + it := srcDb.NewIterator(rawdb.PreimagePrefix, nil) + defer it.Release() + + preimages := make(map[common.Hash][]byte) + for it.Next() { + hash := it.Key()[len(rawdb.PreimagePrefix):] + preimages[common.BytesToHash(hash)] = common.CopyBytes(it.Value()) + } + rawdb.WritePreimages(dstDb, preimages) +} diff --git a/core/state/trie_prefetcher.go b/core/state/trie_prefetcher.go index 844f72fc1..4e8fd1e10 100644 --- a/core/state/trie_prefetcher.go +++ b/core/state/trie_prefetcher.go @@ -302,7 +302,7 @@ func (sf *subfetcher) loop() { } sf.trie = trie } else { - trie, err := sf.db.OpenStorageTrie(sf.state, sf.owner, sf.root) + trie, err := sf.db.OpenStorageTrie(sf.state, sf.addr, sf.root) if err != nil { log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err) return diff --git a/les/benchmark.go b/les/benchmark.go index c63a2564d..ab9351834 100644 --- a/les/benchmark.go +++ b/les/benchmark.go @@ -117,7 +117,7 @@ func (b *benchmarkProofsOrCode) request(peer *serverPeer, index int) error { key := make([]byte, 32) crand.Read(key) if b.code { - return peer.requestCode(0, []CodeReq{{BHash: b.headHash, AccKey: key}}) + return peer.requestCode(0, []CodeReq{{BHash: b.headHash, AccountAddress: key}}) } return peer.requestProofs(0, []ProofReq{{BHash: b.headHash, Key: key}}) } diff --git a/les/handler_test.go b/les/handler_test.go index d0878dcd9..31681fe36 100644 --- a/les/handler_test.go +++ b/les/handler_test.go @@ -295,8 +295,8 @@ func testGetCode(t *testing.T, protocol int) { for i := uint64(0); i <= bc.CurrentBlock().Number.Uint64(); i++ { header := bc.GetHeaderByNumber(i) req := &CodeReq{ - BHash: header.Hash(), - AccKey: crypto.Keccak256(testContractAddr[:]), + BHash: header.Hash(), + AccountAddress: testContractAddr[:], } codereqs = append(codereqs, req) if i >= testContractDeployed { @@ -331,8 +331,8 @@ func testGetStaleCode(t *testing.T, protocol int) { check := func(number uint64, expected [][]byte) { req := &CodeReq{ - BHash: bc.GetHeaderByNumber(number).Hash(), - AccKey: crypto.Keccak256(testContractAddr[:]), + BHash: bc.GetHeaderByNumber(number).Hash(), + AccountAddress: testContractAddr[:], } sendRequest(rawPeer.app, GetCodeMsg, 42, []*CodeReq{req}) if err := expectResponse(rawPeer.app, CodeMsg, 42, testBufLimit, expected); err != nil { diff --git a/les/odr_requests.go b/les/odr_requests.go index d8b094b72..2b23e0540 100644 --- a/les/odr_requests.go +++ b/les/odr_requests.go @@ -183,9 +183,9 @@ func (r *ReceiptsRequest) Validate(db ethdb.Database, msg *Msg) error { } type ProofReq struct { - BHash common.Hash - AccKey, Key []byte - FromLevel uint + BHash common.Hash + AccountAddress, Key []byte + FromLevel uint } // ODR request type for state/storage trie entries, see LesOdrRequest interface @@ -206,9 +206,9 @@ func (r *TrieRequest) CanSend(peer *serverPeer) bool { func (r *TrieRequest) Request(reqID uint64, peer *serverPeer) error { peer.Log().Debug("Requesting trie proof", "root", r.Id.Root, "key", r.Key) req := ProofReq{ - BHash: r.Id.BlockHash, - AccKey: r.Id.AccKey, - Key: r.Key, + BHash: r.Id.BlockHash, + AccountAddress: r.Id.AccountAddress, + Key: r.Key, } return peer.requestProofs(reqID, []ProofReq{req}) } @@ -238,8 +238,8 @@ func (r *TrieRequest) Validate(db ethdb.Database, msg *Msg) error { } type CodeReq struct { - BHash common.Hash - AccKey []byte + BHash common.Hash + AccountAddress []byte } // CodeRequest is the ODR request type for node data (used for retrieving contract code), see LesOdrRequest interface @@ -260,8 +260,8 @@ func (r *CodeRequest) CanSend(peer *serverPeer) bool { func (r *CodeRequest) Request(reqID uint64, peer *serverPeer) error { peer.Log().Debug("Requesting code data", "hash", r.Hash) req := CodeReq{ - BHash: r.Id.BlockHash, - AccKey: r.Id.AccKey, + BHash: r.Id.BlockHash, + AccountAddress: r.Id.AccountAddress, } return peer.requestCode(reqID, []CodeReq{req}) } diff --git a/les/request_test.go b/les/request_test.go index 9b52e6bd8..30ae2c2be 100644 --- a/les/request_test.go +++ b/les/request_test.go @@ -23,6 +23,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/light" @@ -77,7 +78,7 @@ func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrReq return nil } sti := light.StateTrieID(header) - ci := light.StorageTrieID(sti, crypto.Keccak256Hash(testContractAddr[:]), common.Hash{}) + ci := light.StorageTrieID(sti, testContractAddr, types.EmptyRootHash) return &light.CodeRequest{Id: ci, Hash: crypto.Keccak256Hash(testContractCodeDeployed)} } diff --git a/les/server_handler.go b/les/server_handler.go index 418974c37..ad18d3f4e 100644 --- a/les/server_handler.go +++ b/les/server_handler.go @@ -18,6 +18,7 @@ package les import ( "errors" + "fmt" "sync" "time" @@ -34,7 +35,6 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) @@ -358,20 +358,19 @@ func (h *serverHandler) AddTxsSync() bool { } // getAccount retrieves an account from the state based on root. -func getAccount(triedb *trie.Database, root, hash common.Hash) (types.StateAccount, error) { - trie, err := trie.New(trie.StateTrieID(root), triedb) +func getAccount(triedb *trie.Database, root common.Hash, addr common.Address) (types.StateAccount, error) { + trie, err := trie.NewStateTrie(trie.StateTrieID(root), triedb) if err != nil { return types.StateAccount{}, err } - blob, err := trie.Get(hash[:]) + acc, err := trie.GetAccount(addr) if err != nil { return types.StateAccount{}, err } - var acc types.StateAccount - if err = rlp.DecodeBytes(blob, &acc); err != nil { - return types.StateAccount{}, err + if acc == nil { + return types.StateAccount{}, fmt.Errorf("account %#x is not present", addr) } - return acc, nil + return *acc, nil } // GetHelperTrie returns the post-processed trie root for the given trie ID and section index diff --git a/les/server_requests.go b/les/server_requests.go index f6208db44..30ff2cd05 100644 --- a/les/server_requests.go +++ b/les/server_requests.go @@ -304,16 +304,16 @@ func handleGetCode(msg Decoder) (serveRequestFn, uint64, uint64, error) { continue } triedb := bc.StateCache().TrieDB() - - account, err := getAccount(triedb, header.Root, common.BytesToHash(request.AccKey)) + address := common.BytesToAddress(request.AccountAddress) + account, err := getAccount(triedb, header.Root, address) if err != nil { - p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) + p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", address, "err", err) p.bumpInvalid() continue } - code, err := bc.StateCache().ContractCode(common.BytesToHash(request.AccKey), common.BytesToHash(account.CodeHash)) + code, err := bc.StateCache().ContractCode(address, common.BytesToHash(account.CodeHash)) if err != nil { - p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err) + p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", address, "codehash", common.BytesToHash(account.CodeHash), "err", err) continue } // Accumulate the code and abort if enough data was retrieved @@ -413,7 +413,7 @@ func handleGetProofs(msg Decoder) (serveRequestFn, uint64, uint64, error) { statedb := bc.StateCache() var trie state.Trie - switch len(request.AccKey) { + switch len(request.AccountAddress) { case 0: // No account key specified, open an account trie trie, err = statedb.OpenTrie(root) @@ -423,15 +423,16 @@ func handleGetProofs(msg Decoder) (serveRequestFn, uint64, uint64, error) { } default: // Account key specified, open a storage trie - account, err := getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey)) + address := common.BytesToAddress(request.AccountAddress) + account, err := getAccount(statedb.TrieDB(), root, address) if err != nil { - p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) + p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", address, "err", err) p.bumpInvalid() continue } - trie, err = statedb.OpenStorageTrie(root, common.BytesToHash(request.AccKey), account.Root) + trie, err = statedb.OpenStorageTrie(root, address, account.Root) if trie == nil || err != nil { - p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err) + p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", address, "root", account.Root, "err", err) continue } } diff --git a/light/odr.go b/light/odr.go index d331d3c9b..259702743 100644 --- a/light/odr.go +++ b/light/odr.go @@ -54,35 +54,35 @@ type OdrRequest interface { // TrieID identifies a state or account storage trie type TrieID struct { - BlockHash common.Hash - BlockNumber uint64 - StateRoot common.Hash - Root common.Hash - AccKey []byte + BlockHash common.Hash + BlockNumber uint64 + StateRoot common.Hash + Root common.Hash + AccountAddress []byte } // StateTrieID returns a TrieID for a state trie belonging to a certain block // header. func StateTrieID(header *types.Header) *TrieID { return &TrieID{ - BlockHash: header.Hash(), - BlockNumber: header.Number.Uint64(), - StateRoot: header.Root, - Root: header.Root, - AccKey: nil, + BlockHash: header.Hash(), + BlockNumber: header.Number.Uint64(), + StateRoot: header.Root, + Root: header.Root, + AccountAddress: nil, } } // StorageTrieID returns a TrieID for a contract storage trie at a given account // of a given state trie. It also requires the root hash of the trie for // checking Merkle proofs. -func StorageTrieID(state *TrieID, addrHash, root common.Hash) *TrieID { +func StorageTrieID(state *TrieID, address common.Address, root common.Hash) *TrieID { return &TrieID{ - BlockHash: state.BlockHash, - BlockNumber: state.BlockNumber, - StateRoot: state.StateRoot, - AccKey: addrHash[:], - Root: root, + BlockHash: state.BlockHash, + BlockNumber: state.BlockNumber, + StateRoot: state.StateRoot, + AccountAddress: address[:], + Root: root, } } diff --git a/light/odr_test.go b/light/odr_test.go index 90fdea81d..79f901bbd 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -86,8 +86,8 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error { err error t state.Trie ) - if len(req.Id.AccKey) > 0 { - t, err = odr.serverState.OpenStorageTrie(req.Id.StateRoot, common.BytesToHash(req.Id.AccKey), req.Id.Root) + if len(req.Id.AccountAddress) > 0 { + t, err = odr.serverState.OpenStorageTrie(req.Id.StateRoot, common.BytesToAddress(req.Id.AccountAddress), req.Id.Root) } else { t, err = odr.serverState.OpenTrie(req.Id.Root) } diff --git a/light/trie.go b/light/trie.go index b411c299f..43fccd5e3 100644 --- a/light/trie.go +++ b/light/trie.go @@ -55,8 +55,8 @@ func (db *odrDatabase) OpenTrie(root common.Hash) (state.Trie, error) { return &odrTrie{db: db, id: db.id}, nil } -func (db *odrDatabase) OpenStorageTrie(state, addrHash, root common.Hash) (state.Trie, error) { - return &odrTrie{db: db, id: StorageTrieID(db.id, addrHash, root)}, nil +func (db *odrDatabase) OpenStorageTrie(stateRoot common.Hash, address common.Address, root common.Hash) (state.Trie, error) { + return &odrTrie{db: db, id: StorageTrieID(db.id, address, root)}, nil } func (db *odrDatabase) CopyTrie(t state.Trie) state.Trie { @@ -72,7 +72,7 @@ func (db *odrDatabase) CopyTrie(t state.Trie) state.Trie { } } -func (db *odrDatabase) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) { +func (db *odrDatabase) ContractCode(addr common.Address, codeHash common.Hash) ([]byte, error) { if codeHash == sha3Nil { return nil, nil } @@ -81,14 +81,14 @@ func (db *odrDatabase) ContractCode(addrHash, codeHash common.Hash) ([]byte, err return code, nil } id := *db.id - id.AccKey = addrHash[:] + id.AccountAddress = addr[:] req := &CodeRequest{Id: &id, Hash: codeHash} err := db.backend.Retrieve(db.ctx, req) return req.Data, err } -func (db *odrDatabase) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) { - code, err := db.ContractCode(addrHash, codeHash) +func (db *odrDatabase) ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error) { + code, err := db.ContractCode(addr, codeHash) return len(code), err } @@ -207,8 +207,8 @@ func (t *odrTrie) do(key []byte, fn func() error) error { var err error if t.trie == nil { var id *trie.ID - if len(t.id.AccKey) > 0 { - id = trie.StorageTrieID(t.id.StateRoot, common.BytesToHash(t.id.AccKey), t.id.Root) + if len(t.id.AccountAddress) > 0 { + id = trie.StorageTrieID(t.id.StateRoot, crypto.Keccak256Hash(t.id.AccountAddress), t.id.Root) } else { id = trie.StateTrieID(t.id.StateRoot) } @@ -239,8 +239,8 @@ func newNodeIterator(t *odrTrie, startkey []byte) trie.NodeIterator { if t.trie == nil { it.do(func() error { var id *trie.ID - if len(t.id.AccKey) > 0 { - id = trie.StorageTrieID(t.id.StateRoot, common.BytesToHash(t.id.AccKey), t.id.Root) + if len(t.id.AccountAddress) > 0 { + id = trie.StorageTrieID(t.id.StateRoot, crypto.Keccak256Hash(t.id.AccountAddress), t.id.Root) } else { id = trie.StateTrieID(t.id.StateRoot) } diff --git a/tests/fuzzers/les/les-fuzzer.go b/tests/fuzzers/les/les-fuzzer.go index b4599025e..c62253a72 100644 --- a/tests/fuzzers/les/les-fuzzer.go +++ b/tests/fuzzers/les/les-fuzzer.go @@ -45,9 +45,9 @@ var ( testChainLen = 256 testContractCode = common.Hex2Bytes("606060405260cc8060106000396000f360606040526000357c01000000000000000000000000000000000000000000000000000000009004806360cd2685146041578063c16431b914606b57603f565b005b6055600480803590602001909190505060a9565b6040518082815260200191505060405180910390f35b60886004808035906020019091908035906020019091905050608a565b005b80600060005083606481101560025790900160005b50819055505b5050565b6000600060005082606481101560025790900160005b5054905060c7565b91905056") - chain *core.BlockChain - addrHashes []common.Hash - txHashes []common.Hash + chain *core.BlockChain + addresses []common.Address + txHashes []common.Hash chtTrie *trie.Trie bloomTrie *trie.Trie @@ -55,7 +55,7 @@ var ( bloomKeys [][]byte ) -func makechain() (bc *core.BlockChain, addrHashes, txHashes []common.Hash) { +func makechain() (bc *core.BlockChain, addresses []common.Address, txHashes []common.Hash) { gspec := &core.Genesis{ Config: params.TestChainConfig, Alloc: core.GenesisAlloc{bankAddr: {Balance: bankFunds}}, @@ -77,7 +77,7 @@ func makechain() (bc *core.BlockChain, addrHashes, txHashes []common.Hash) { tx, _ = types.SignTx(types.NewTransaction(nonce, addr, big.NewInt(10000), params.TxGas, big.NewInt(params.GWei), nil), signer, bankKey) } gen.AddTx(tx) - addrHashes = append(addrHashes, crypto.Keccak256Hash(addr[:])) + addresses = append(addresses, addr) txHashes = append(txHashes, tx.Hash()) }) bc, _ = core.NewBlockChain(rawdb.NewMemoryDatabase(), nil, gspec, nil, ethash.NewFaker(), vm.Config{}, nil, nil) @@ -107,7 +107,7 @@ func makeTries() (chtTrie *trie.Trie, bloomTrie *trie.Trie, chtKeys, bloomKeys [ } func init() { - chain, addrHashes, txHashes = makechain() + chain, addresses, txHashes = makechain() chtTrie, bloomTrie, chtKeys, bloomKeys = makeTries() } @@ -116,7 +116,8 @@ type fuzzer struct { pool *txpool.TxPool chainLen int - addr, txs []common.Hash + addresses []common.Address + txs []common.Hash nonce uint64 chtKeys [][]byte @@ -135,7 +136,7 @@ func newFuzzer(input []byte) *fuzzer { return &fuzzer{ chain: chain, chainLen: testChainLen, - addr: addrHashes, + addresses: addresses, txs: txHashes, chtTrie: chtTrie, bloomTrie: bloomTrie, @@ -198,12 +199,12 @@ func (f *fuzzer) randomBlockHash() common.Hash { return common.BytesToHash(f.read(common.HashLength)) } -func (f *fuzzer) randomAddrHash() []byte { - i := f.randomInt(3 * len(f.addr)) - if i < len(f.addr) { - return f.addr[i].Bytes() +func (f *fuzzer) randomAddress() []byte { + i := f.randomInt(3 * len(f.addresses)) + if i < len(f.addresses) { + return f.addresses[i].Bytes() } - return f.read(common.HashLength) + return f.read(common.AddressLength) } func (f *fuzzer) randomCHTTrieKey() []byte { @@ -315,8 +316,8 @@ func Fuzz(input []byte) int { req := &l.GetCodePacket{Reqs: make([]l.CodeReq, f.randomInt(l.MaxCodeFetch+1))} for i := range req.Reqs { req.Reqs[i] = l.CodeReq{ - BHash: f.randomBlockHash(), - AccKey: f.randomAddrHash(), + BHash: f.randomBlockHash(), + AccountAddress: f.randomAddress(), } } f.doFuzz(l.GetCodeMsg, req) @@ -333,15 +334,15 @@ func Fuzz(input []byte) int { for i := range req.Reqs { if f.randomBool() { req.Reqs[i] = l.ProofReq{ - BHash: f.randomBlockHash(), - AccKey: f.randomAddrHash(), - Key: f.randomAddrHash(), - FromLevel: uint(f.randomX(3)), + BHash: f.randomBlockHash(), + AccountAddress: f.randomAddress(), + Key: f.randomAddress(), + FromLevel: uint(f.randomX(3)), } } else { req.Reqs[i] = l.ProofReq{ BHash: f.randomBlockHash(), - Key: f.randomAddrHash(), + Key: f.randomAddress(), FromLevel: uint(f.randomX(3)), } } diff --git a/trie/database_wrap.go b/trie/database_wrap.go index 30835b7d2..5ca013b38 100644 --- a/trie/database_wrap.go +++ b/trie/database_wrap.go @@ -168,10 +168,15 @@ func (db *Database) Scheme() string { // It is meant to be called when closing the blockchain object, so that all // resources held can be released correctly. func (db *Database) Close() error { + db.WritePreimages() + return db.backend.Close() +} + +// WritePreimages flushes all accumulated preimages to disk forcibly. +func (db *Database) WritePreimages() { if db.preimages != nil { db.preimages.commit(true) } - return db.backend.Close() } // saveCache saves clean state cache to given directory path