diff --git a/go.mod b/go.mod index 6c1f715..67d2676 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.18 require ( github.com/VictoriaMetrics/fastcache v1.6.0 github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha + github.com/davecgh/go-spew v1.1.1 github.com/ethereum/go-ethereum v1.11.5 github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d github.com/ipfs/go-cid v0.2.0 @@ -21,13 +22,13 @@ require ( github.com/DataDog/zstd v1.5.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/btcsuite/btcd/btcec/v2 v2.2.0 // indirect + github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cockroachdb/errors v1.9.1 // indirect github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect github.com/cockroachdb/pebble v0.0.0-20230209160836-829675f94811 // indirect github.com/cockroachdb/redact v1.1.3 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/deckarep/golang-set/v2 v2.1.0 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect github.com/edsrzf/mmap-go v1.0.0 // indirect @@ -112,4 +113,4 @@ require ( lukechampine.com/blake3 v1.1.6 // indirect ) -replace github.com/ethereum/go-ethereum v1.11.5 => github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.1-alpha +replace github.com/ethereum/go-ethereum v1.11.5 => github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.4 diff --git a/go.sum b/go.sum index fef91fa..e7e8faa 100644 --- a/go.sum +++ b/go.sum @@ -17,12 +17,13 @@ github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5 github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/btcsuite/btcd v0.22.0-beta 
h1:LTDpDKUM5EeOFBPM8IXpinEcmZ6FWfNZbE3lfrfdnWo= github.com/btcsuite/btcd/btcec/v2 v2.2.0 h1:fzn1qaOt32TuLjFlkzYSsBC35Q3KUjT1SwPxiMSCF5k= github.com/btcsuite/btcd/btcec/v2 v2.2.0/go.mod h1:U7MHm051Al6XmscBQ0BoNydpOTsFAn707034b5nY8zU= +github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2 h1:KdUfX2zKommPRa+PD0sWZUyXe9w277ABlgELO7H04IM= +github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.1-alpha h1:x9muuG0Z2W/UAkwAq+0SXhYG9MCP6SJlGEfFNQa6JPI= -github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.1-alpha/go.mod h1:DIk2wFexjyzvyjuzSOtBEIAPRNZTnLXNbIHEyq1Igek= +github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.3-alpha h1:rhRmK/NeWMnQ07E4DuLb7WSh9FMotlXMPPaOrf8GJwM= +github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.3-alpha/go.mod h1:DIk2wFexjyzvyjuzSOtBEIAPRNZTnLXNbIHEyq1Igek= github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha h1:I1iXTaIjbTH8ehzNXmT2waXcYBifi1yjK6FK3W3a0Pg= github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha/go.mod h1:EGAdV/YewEADFDDVF1k9GNwy8vNWR29Xb87sRHgMIng= github.com/cespare/cp v0.1.0 h1:SE+dxFebS7Iik5LK0tsi1k9ZCxEaFX4AjQmoyA+1dJk= diff --git a/internal/util.go b/internal/util.go index a19df0a..634c18a 100644 --- a/internal/util.go +++ b/internal/util.go @@ -1,6 +1,10 @@ package internal import ( + "testing" + "time" + + pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0" "github.com/ipfs/go-cid" "github.com/multiformats/go-multihash" ) @@ -12,3 +16,12 @@ func Keccak256ToCid(codec uint64, h []byte) (cid.Cid, error) { } return cid.NewCidV1(codec, buf), nil } + +// returns a cache config with unique name (groupcache names are global) +func MakeCacheConfig(t testing.TB) pgipfsethdb.CacheConfig { + return pgipfsethdb.CacheConfig{ + Name: t.Name(), + Size: 3000000, // 3MB + ExpiryDuration: time.Hour, + } +} diff --git 
a/trie_by_cid/doc.go b/trie_by_cid/doc.go new file mode 100644 index 0000000..4b9b440 --- /dev/null +++ b/trie_by_cid/doc.go @@ -0,0 +1,3 @@ +// This package is a near complete copy of go-ethereum/trie and go-ethereum/core/state, modified to use +// a v0 IPFS blockstore as the backing DB, i.e. DB values are indexed by CID rather than hash. +package trie_by_cid diff --git a/trie_by_cid/helper/statediff_helper.go b/trie_by_cid/helper/statediff_helper.go index b03abf6..030c2b3 100644 --- a/trie_by_cid/helper/statediff_helper.go +++ b/trie_by_cid/helper/statediff_helper.go @@ -20,7 +20,10 @@ var ( mockTD = big.NewInt(1) ) -func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, rootB common.Hash) error { +// IndexStateDiff indexes a single statediff. +// - uses TestChainConfig +// - block hash/number are left as zero +func IndexStateDiff(dbConfig postgres.Config, stateCache state.Database, rootA, rootB common.Hash) error { _, indexer, err := indexer.NewStateDiffIndexer( context.Background(), ChainConfig, node.Info{}, dbConfig) if err != nil { @@ -28,11 +31,10 @@ func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, root } defer indexer.Close() // fixme: hangs when using PGX driver - // generating statediff payload for block, and transform the data into Postgres builder := statediff.NewBuilder(stateCache) block := types.NewBlock(&types.Header{Root: rootB}, nil, nil, nil, NewHasher()) - // todo: use dummy block hashes to just produce trie structure for testing + // uses zero block hash/number, we only need the trie structure here args := statediff.Args{ OldStateRoot: rootA, NewStateRoot: rootB, @@ -45,12 +47,7 @@ func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, root if err != nil { return err } - // for _, node := range diff.Nodes { - // err := indexer.PushStateNode(tx, node, block.Hash().String()) - // if err != nil { - // return err - // } - // } + // we don't need to index diff.Nodes since we are 
just interested in the trie for _, ipld := range diff.IPLDs { if err := indexer.PushIPLD(tx, ipld); err != nil { return err diff --git a/trie_by_cid/state/access_list.go b/trie_by_cid/state/access_list.go new file mode 100644 index 0000000..4194691 --- /dev/null +++ b/trie_by_cid/state/access_list.go @@ -0,0 +1,136 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package state + +import ( + "github.com/ethereum/go-ethereum/common" +) + +type accessList struct { + addresses map[common.Address]int + slots []map[common.Hash]struct{} +} + +// ContainsAddress returns true if the address is in the access list. +func (al *accessList) ContainsAddress(address common.Address) bool { + _, ok := al.addresses[address] + return ok +} + +// Contains checks if a slot within an account is present in the access list, returning +// separate flags for the presence of the account and the slot respectively. 
+func (al *accessList) Contains(address common.Address, slot common.Hash) (addressPresent bool, slotPresent bool) { + idx, ok := al.addresses[address] + if !ok { + // no such address (and hence zero slots) + return false, false + } + if idx == -1 { + // address yes, but no slots + return true, false + } + _, slotPresent = al.slots[idx][slot] + return true, slotPresent +} + +// newAccessList creates a new accessList. +func newAccessList() *accessList { + return &accessList{ + addresses: make(map[common.Address]int), + } +} + +// Copy creates an independent copy of an accessList. +func (a *accessList) Copy() *accessList { + cp := newAccessList() + for k, v := range a.addresses { + cp.addresses[k] = v + } + cp.slots = make([]map[common.Hash]struct{}, len(a.slots)) + for i, slotMap := range a.slots { + newSlotmap := make(map[common.Hash]struct{}, len(slotMap)) + for k := range slotMap { + newSlotmap[k] = struct{}{} + } + cp.slots[i] = newSlotmap + } + return cp +} + +// AddAddress adds an address to the access list, and returns 'true' if the operation +// caused a change (addr was not previously in the list). +func (al *accessList) AddAddress(address common.Address) bool { + if _, present := al.addresses[address]; present { + return false + } + al.addresses[address] = -1 + return true +} + +// AddSlot adds the specified (addr, slot) combo to the access list. +// Return values are: +// - address added +// - slot added +// For any 'true' value returned, a corresponding journal entry must be made. 
+func (al *accessList) AddSlot(address common.Address, slot common.Hash) (addrChange bool, slotChange bool) { + idx, addrPresent := al.addresses[address] + if !addrPresent || idx == -1 { + // Address not present, or addr present but no slots there + al.addresses[address] = len(al.slots) + slotmap := map[common.Hash]struct{}{slot: {}} + al.slots = append(al.slots, slotmap) + return !addrPresent, true + } + // There is already an (address,slot) mapping + slotmap := al.slots[idx] + if _, ok := slotmap[slot]; !ok { + slotmap[slot] = struct{}{} + // Journal add slot change + return false, true + } + // No changes required + return false, false +} + +// DeleteSlot removes an (address, slot)-tuple from the access list. +// This operation needs to be performed in the same order as the addition happened. +// This method is meant to be used by the journal, which maintains ordering of +// operations. +func (al *accessList) DeleteSlot(address common.Address, slot common.Hash) { + idx, addrOk := al.addresses[address] + // There are two ways this can fail + if !addrOk { + panic("reverting slot change, address not present in list") + } + slotmap := al.slots[idx] + delete(slotmap, slot) + // If that was the last (first) slot, remove it + // Since additions and rollbacks are always performed in order, + // we can delete the item without worrying about screwing up later indices + if len(slotmap) == 0 { + al.slots = al.slots[:idx] + al.addresses[address] = -1 + } +} + +// DeleteAddress removes an address from the access list. This operation +// needs to be performed in the same order as the addition happened. +// This method is meant to be used by the journal, which maintains ordering of +// operations. 
+func (al *accessList) DeleteAddress(address common.Address) { + delete(al.addresses, address) +} diff --git a/trie_by_cid/state/database.go b/trie_by_cid/state/database.go index 30adbb7..e6920da 100644 --- a/trie_by_cid/state/database.go +++ b/trie_by_cid/state/database.go @@ -1,14 +1,30 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + package state import ( "errors" + "fmt" - "github.com/VictoriaMetrics/fastcache" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/lru" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/statediff/indexer/ipld" - lru "github.com/hashicorp/golang-lru" "github.com/cerc-io/ipld-eth-statedb/internal" "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" @@ -28,7 +44,10 @@ type Database interface { OpenTrie(root common.Hash) (Trie, error) // OpenStorageTrie opens the storage trie of an account. - OpenStorageTrie(addrHash, root common.Hash) (Trie, error) + OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) + + // CopyTrie returns an independent copy of the given trie. + CopyTrie(Trie) Trie // ContractCode retrieves a particular contract's code. 
ContractCode(codeHash common.Hash) ([]byte, error) @@ -36,17 +55,79 @@ type Database interface { // ContractCodeSize retrieves a particular contracts code's size. ContractCodeSize(codeHash common.Hash) (int, error) + // DiskDB returns the underlying key-value disk database. + DiskDB() ethdb.KeyValueStore + // TrieDB retrieves the low level trie database used for data storage. TrieDB() *trie.Database } // Trie is a Ethereum Merkle Patricia trie. type Trie interface { + // GetKey returns the sha3 preimage of a hashed key that was previously used + // to store a value. + // + // TODO(fjl): remove this when StateTrie is removed + GetKey([]byte) []byte + + // TryGet returns the value for key stored in the trie. The value bytes must + // not be modified by the caller. If a node was not found in the database, a + // trie.MissingNodeError is returned. TryGet(key []byte) ([]byte, error) + + // TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not + // possible to use keybyte-encoding as the path might contain odd nibbles. TryGetNode(path []byte) ([]byte, int, error) - TryGetAccount(key []byte) (*types.StateAccount, error) + + // TryGetAccount abstracts an account read from the trie. It retrieves the + // account blob from the trie with provided account address and decodes it + // with associated decoding algorithm. If the specified account is not in + // the trie, nil will be returned. If the trie is corrupted(e.g. some nodes + // are missing or the account blob is incorrect for decoding), an error will + // be returned. + TryGetAccount(address common.Address) (*types.StateAccount, error) + + // TryUpdate associates key with value in the trie. If value has length zero, any + // existing value is deleted from the trie. The value bytes must not be modified + // by the caller while they are stored in the trie. If a node was not found in the + // database, a trie.MissingNodeError is returned. 
+ TryUpdate(key, value []byte) error + + // TryUpdateAccount abstracts an account write to the trie. It encodes the + // provided account object with associated algorithm and then updates it + // in the trie with provided address. + TryUpdateAccount(address common.Address, account *types.StateAccount) error + + // TryDelete removes any existing value for key from the trie. If a node was not + // found in the database, a trie.MissingNodeError is returned. + TryDelete(key []byte) error + + // TryDeleteAccount abstracts an account deletion from the trie. + TryDeleteAccount(address common.Address) error + + // Hash returns the root hash of the trie. It does not write to the database and + // can be used even if the trie doesn't have one. Hash() common.Hash + + // Commit collects all dirty nodes in the trie and replace them with the + // corresponding node hash. All collected nodes(including dirty leaves if + // collectLeaf is true) will be encapsulated into a nodeset for return. + // The returned nodeset can be nil if the trie is clean(nothing to commit). + // Once the trie is committed, it's not usable anymore. A new trie must + // be created with new root and updated trie database for following usage + Commit(collectLeaf bool) (common.Hash, *trie.NodeSet) + + // NodeIterator returns an iterator that returns nodes of the trie. Iteration + // starts at the key after the given start key. NodeIterator(startKey []byte) trie.NodeIterator + + // Prove constructs a Merkle proof for key. The result contains all encoded nodes + // on the path to the value at key. The value itself is also included in the last + // node and can be retrieved by verifying the proof. + // + // If the trie does not contain a value for key, the returned proof contains all + // nodes of the longest existing prefix of the key (at least the root), ending + // with the node that proves the absence of the key. 
Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error } @@ -61,23 +142,34 @@ func NewDatabase(db ethdb.Database) Database { // is safe for concurrent use and retains a lot of collapsed RLP trie nodes in a // large memory cache. func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database { - csc, _ := lru.New(codeSizeCacheSize) return &cachingDB{ - db: trie.NewDatabaseWithConfig(db, config), - codeSizeCache: csc, - codeCache: fastcache.New(codeCacheSize), + disk: db, + codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize), + codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize), + triedb: trie.NewDatabaseWithConfig(db, config), + } +} + +// NewDatabaseWithNodeDB creates a state database with an already initialized node database. +func NewDatabaseWithNodeDB(db ethdb.Database, triedb *trie.Database) Database { + return &cachingDB{ + disk: db, + codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize), + codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize), + triedb: triedb, } } type cachingDB struct { - db *trie.Database - codeSizeCache *lru.Cache - codeCache *fastcache.Cache + disk ethdb.KeyValueStore + codeSizeCache *lru.Cache[common.Hash, int] + codeCache *lru.SizeConstrainedCache[common.Hash, []byte] + triedb *trie.Database } // OpenTrie opens the main account trie at a specific root hash. func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { - tr, err := trie.NewStateTrie(common.Hash{}, root, db.db) + tr, err := trie.NewStateTrie(trie.StateTrieID(root), db.triedb, trie.StateTrieCodec) if err != nil { return nil, err } @@ -85,29 +177,40 @@ func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { } // OpenStorageTrie opens the storage trie of an account. 
-func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) { - tr, err := trie.NewStorageTrie(addrHash, root, db.db) +func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) { + tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, addrHash, root), db.triedb, trie.StorageTrieCodec) if err != nil { return nil, err } return tr, nil } +// CopyTrie returns an independent copy of the given trie. +func (db *cachingDB) CopyTrie(t Trie) Trie { + switch t := t.(type) { + case *trie.StateTrie: + return t.Copy() + default: + panic(fmt.Errorf("unknown trie type %T", t)) + } +} + // ContractCode retrieves a particular contract's code. func (db *cachingDB) ContractCode(codeHash common.Hash) ([]byte, error) { - if code := db.codeCache.Get(nil, codeHash.Bytes()); len(code) > 0 { + code, _ := db.codeCache.Get(codeHash) + if len(code) > 0 { return code, nil } - codeCID, err := internal.Keccak256ToCid(ipld.RawBinary, codeHash.Bytes()) + cid, err := internal.Keccak256ToCid(ipld.RawBinary, codeHash.Bytes()) if err != nil { return nil, err } - code, err := db.db.DiskDB().Get(codeCID.Bytes()) + code, err = db.disk.Get(cid.Bytes()) if err != nil { return nil, err } if len(code) > 0 { - db.codeCache.Set(codeHash.Bytes(), code) + db.codeCache.Add(codeHash, code) db.codeSizeCache.Add(codeHash, len(code)) return code, nil } @@ -117,13 +220,18 @@ func (db *cachingDB) ContractCode(codeHash common.Hash) ([]byte, error) { // ContractCodeSize retrieves a particular contracts code's size. func (db *cachingDB) ContractCodeSize(codeHash common.Hash) (int, error) { if cached, ok := db.codeSizeCache.Get(codeHash); ok { - return cached.(int), nil + return cached, nil } code, err := db.ContractCode(codeHash) return len(code), err } +// DiskDB returns the underlying key-value disk database. +func (db *cachingDB) DiskDB() ethdb.KeyValueStore { + return db.disk +} + // TrieDB retrieves any intermediate trie-node caching layer. 
func (db *cachingDB) TrieDB() *trie.Database { - return db.db + return db.triedb } diff --git a/trie_by_cid/state/journal.go b/trie_by_cid/state/journal.go new file mode 100644 index 0000000..1722fb4 --- /dev/null +++ b/trie_by_cid/state/journal.go @@ -0,0 +1,282 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package state + +import ( + "math/big" + + "github.com/ethereum/go-ethereum/common" +) + +// journalEntry is a modification entry in the state change journal that can be +// reverted on demand. +type journalEntry interface { + // revert undoes the changes introduced by this journal entry. + revert(*StateDB) + + // dirtied returns the Ethereum address modified by this journal entry. + dirtied() *common.Address +} + +// journal contains the list of state modifications applied since the last state +// commit. These are tracked to be able to be reverted in the case of an execution +// exception or request for reversal. +type journal struct { + entries []journalEntry // Current changes tracked by the journal + dirties map[common.Address]int // Dirty accounts and the number of changes +} + +// newJournal creates a new initialized journal. 
+func newJournal() *journal { + return &journal{ + dirties: make(map[common.Address]int), + } +} + +// append inserts a new modification entry to the end of the change journal. +func (j *journal) append(entry journalEntry) { + j.entries = append(j.entries, entry) + if addr := entry.dirtied(); addr != nil { + j.dirties[*addr]++ + } +} + +// revert undoes a batch of journalled modifications along with any reverted +// dirty handling too. +func (j *journal) revert(statedb *StateDB, snapshot int) { + for i := len(j.entries) - 1; i >= snapshot; i-- { + // Undo the changes made by the operation + j.entries[i].revert(statedb) + + // Drop any dirty tracking induced by the change + if addr := j.entries[i].dirtied(); addr != nil { + if j.dirties[*addr]--; j.dirties[*addr] == 0 { + delete(j.dirties, *addr) + } + } + } + j.entries = j.entries[:snapshot] +} + +// dirty explicitly sets an address to dirty, even if the change entries would +// otherwise suggest it as clean. This method is an ugly hack to handle the RIPEMD +// precompile consensus exception. +func (j *journal) dirty(addr common.Address) { + j.dirties[addr]++ +} + +// length returns the current number of entries in the journal. +func (j *journal) length() int { + return len(j.entries) +} + +type ( + // Changes to the account trie. + createObjectChange struct { + account *common.Address + } + resetObjectChange struct { + prev *stateObject + prevdestruct bool + } + suicideChange struct { + account *common.Address + prev bool // whether account had already suicided + prevbalance *big.Int + } + + // Changes to individual accounts. + balanceChange struct { + account *common.Address + prev *big.Int + } + nonceChange struct { + account *common.Address + prev uint64 + } + storageChange struct { + account *common.Address + key, prevalue common.Hash + } + codeChange struct { + account *common.Address + prevcode, prevhash []byte + } + + // Changes to other state values. 
+ refundChange struct { + prev uint64 + } + addLogChange struct { + txhash common.Hash + } + addPreimageChange struct { + hash common.Hash + } + touchChange struct { + account *common.Address + } + // Changes to the access list + accessListAddAccountChange struct { + address *common.Address + } + accessListAddSlotChange struct { + address *common.Address + slot *common.Hash + } + + transientStorageChange struct { + account *common.Address + key, prevalue common.Hash + } +) + +func (ch createObjectChange) revert(s *StateDB) { + delete(s.stateObjects, *ch.account) + delete(s.stateObjectsDirty, *ch.account) +} + +func (ch createObjectChange) dirtied() *common.Address { + return ch.account +} + +func (ch resetObjectChange) revert(s *StateDB) { + s.setStateObject(ch.prev) + if !ch.prevdestruct { + delete(s.stateObjectsDestruct, ch.prev.address) + } +} + +func (ch resetObjectChange) dirtied() *common.Address { + return nil +} + +func (ch suicideChange) revert(s *StateDB) { + obj := s.getStateObject(*ch.account) + if obj != nil { + obj.suicided = ch.prev + obj.setBalance(ch.prevbalance) + } +} + +func (ch suicideChange) dirtied() *common.Address { + return ch.account +} + +var ripemd = common.HexToAddress("0000000000000000000000000000000000000003") + +func (ch touchChange) revert(s *StateDB) { +} + +func (ch touchChange) dirtied() *common.Address { + return ch.account +} + +func (ch balanceChange) revert(s *StateDB) { + s.getStateObject(*ch.account).setBalance(ch.prev) +} + +func (ch balanceChange) dirtied() *common.Address { + return ch.account +} + +func (ch nonceChange) revert(s *StateDB) { + s.getStateObject(*ch.account).setNonce(ch.prev) +} + +func (ch nonceChange) dirtied() *common.Address { + return ch.account +} + +func (ch codeChange) revert(s *StateDB) { + s.getStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode) +} + +func (ch codeChange) dirtied() *common.Address { + return ch.account +} + +func (ch storageChange) revert(s *StateDB) { 
+ s.getStateObject(*ch.account).setState(ch.key, ch.prevalue) +} + +func (ch storageChange) dirtied() *common.Address { + return ch.account +} + +func (ch transientStorageChange) revert(s *StateDB) { + s.setTransientState(*ch.account, ch.key, ch.prevalue) +} + +func (ch transientStorageChange) dirtied() *common.Address { + return nil +} + +func (ch refundChange) revert(s *StateDB) { + s.refund = ch.prev +} + +func (ch refundChange) dirtied() *common.Address { + return nil +} + +func (ch addLogChange) revert(s *StateDB) { + logs := s.logs[ch.txhash] + if len(logs) == 1 { + delete(s.logs, ch.txhash) + } else { + s.logs[ch.txhash] = logs[:len(logs)-1] + } + s.logSize-- +} + +func (ch addLogChange) dirtied() *common.Address { + return nil +} + +func (ch addPreimageChange) revert(s *StateDB) { + delete(s.preimages, ch.hash) +} + +func (ch addPreimageChange) dirtied() *common.Address { + return nil +} + +func (ch accessListAddAccountChange) revert(s *StateDB) { + /* + One important invariant here, is that whenever a (addr, slot) is added, if the + addr is not already present, the add causes two journal entries: + - one for the address, + - one for the (address,slot) + Therefore, when unrolling the change, we can always blindly delete the + (addr) at this point, since no storage adds can remain when come upon + a single (addr) change. + */ + s.accessList.DeleteAddress(*ch.address) +} + +func (ch accessListAddAccountChange) dirtied() *common.Address { + return nil +} + +func (ch accessListAddSlotChange) revert(s *StateDB) { + s.accessList.DeleteSlot(*ch.address, *ch.slot) +} + +func (ch accessListAddSlotChange) dirtied() *common.Address { + return nil +} diff --git a/trie_by_cid/state/metrics.go b/trie_by_cid/state/metrics.go new file mode 100644 index 0000000..e702ef3 --- /dev/null +++ b/trie_by_cid/state/metrics.go @@ -0,0 +1,30 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package state + +import "github.com/ethereum/go-ethereum/metrics" + +var ( + accountUpdatedMeter = metrics.NewRegisteredMeter("state/update/account", nil) + storageUpdatedMeter = metrics.NewRegisteredMeter("state/update/storage", nil) + accountDeletedMeter = metrics.NewRegisteredMeter("state/delete/account", nil) + storageDeletedMeter = metrics.NewRegisteredMeter("state/delete/storage", nil) + accountTrieUpdatedMeter = metrics.NewRegisteredMeter("state/update/accountnodes", nil) + storageTriesUpdatedMeter = metrics.NewRegisteredMeter("state/update/storagenodes", nil) + accountTrieDeletedMeter = metrics.NewRegisteredMeter("state/delete/accountnodes", nil) + storageTriesDeletedMeter = metrics.NewRegisteredMeter("state/delete/storagenodes", nil) +) diff --git a/trie_by_cid/state/state_object.go b/trie_by_cid/state/state_object.go index 6b53c67..e1afcb3 100644 --- a/trie_by_cid/state/state_object.go +++ b/trie_by_cid/state/state_object.go @@ -1,8 +1,25 @@ +// Copyright 2014 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + package state import ( "bytes" "fmt" + "io" "math/big" "time" @@ -11,15 +28,7 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/rlp" -) - -var ( - // emptyRoot is the known root hash of an empty trie. - // this is calculated as: emptyRoot = crypto.Keccak256(rlp.Encode([][]byte{})) - // that is, the keccak356 hash of the rlp encoding of an empty trie node (empty byte slice array) - emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") - // emptyCodeHash is the CodeHash for an EOA, for an account without contract code deployed - emptyCodeHash = crypto.Keccak256(nil) + // "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" ) type Code []byte @@ -34,7 +43,6 @@ func (s Storage) String() (str string) { for key, value := range s { str += fmt.Sprintf("%X : %X\n", key, value) } - return } @@ -43,32 +51,39 @@ func (s Storage) Copy() Storage { for key, value := range s { cpy[key] = value } - return cpy } -// stateObject represents an Ethereum account which is being accessed. +// stateObject represents an Ethereum account which is being modified. // // The usage pattern is as follows: // First you need to obtain a state object. -// Account values can be accessed through the object. 
+// Account values can be accessed and modified through the object. type stateObject struct { address common.Address addrHash common.Hash // hash of ethereum address of the account data types.StateAccount db *StateDB - // Caches. + // Write caches. trie Trie // storage trie, which becomes non-nil on first access code Code // contract bytecode, which gets set when code is loaded - originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction - fakeStorage Storage // Fake storage which constructed by caller for debugging purpose. + originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction + pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block + dirtyStorage Storage // Storage entries that have been modified in the current transaction execution + + // Cache flags. + // When an object is marked suicided it will be deleted from the trie + // during the "update" phase of the state transition. + dirtyCode bool // true if the code was updated + suicided bool + deleted bool } // empty returns whether the account is considered empty. func (s *stateObject) empty() bool { - return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, emptyCodeHash) + return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, types.EmptyCodeHash.Bytes()) } // newObject creates a state object. 
@@ -77,71 +92,128 @@ func newObject(db *StateDB, address common.Address, data types.StateAccount) *st data.Balance = new(big.Int) } if data.CodeHash == nil { - data.CodeHash = emptyCodeHash + data.CodeHash = types.EmptyCodeHash.Bytes() } if data.Root == (common.Hash{}) { - data.Root = emptyRoot + data.Root = types.EmptyRootHash } return &stateObject{ - db: db, - address: address, - addrHash: crypto.Keccak256Hash(address[:]), - data: data, - originStorage: make(Storage), + db: db, + address: address, + addrHash: crypto.Keccak256Hash(address[:]), + data: data, + originStorage: make(Storage), + pendingStorage: make(Storage), + dirtyStorage: make(Storage), } } -// setError remembers the first non-nil error it is called with. -func (s *stateObject) setError(err error) { - s.db.setError(err) +// EncodeRLP implements rlp.Encoder. +func (s *stateObject) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, &s.data) } -func (s *stateObject) getTrie(db Database) Trie { +func (s *stateObject) markSuicided() { + s.suicided = true +} + +func (s *stateObject) touch() { + s.db.journal.append(touchChange{ + account: &s.address, + }) + if s.address == ripemd { + // Explicitly put it in the dirty-cache, which is otherwise generated from + // flattened journals. + s.db.journal.dirty(s.address) + } +} + +// getTrie returns the associated storage trie. The trie will be opened +// if it's not loaded previously. An error will be returned if trie can't +// be loaded. 
+func (s *stateObject) getTrie(db Database) (Trie, error) { if s.trie == nil { - // // Try fetching from prefetcher first - // // We don't prefetch empty tries - // if s.data.Root != emptyRoot && s.db.prefetcher != nil { - // // When the miner is creating the pending state, there is no - // // prefetcher - // s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root) - // } + // Try fetching from prefetcher first + // We don't prefetch empty tries + if s.data.Root != types.EmptyRootHash && s.db.prefetcher != nil { + // When the miner is creating the pending state, there is no + // prefetcher + s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root) + } if s.trie == nil { - var err error - s.trie, err = db.OpenStorageTrie(s.addrHash, s.data.Root) + tr, err := db.OpenStorageTrie(s.db.originalRoot, s.addrHash, s.data.Root) if err != nil { - s.trie, _ = db.OpenStorageTrie(s.addrHash, common.Hash{}) - s.setError(fmt.Errorf("can't create storage trie: %w", err)) + return nil, err } + s.trie = tr } } - return s.trie + return s.trie, nil +} + +// GetState retrieves a value from the account storage trie. +func (s *stateObject) GetState(db Database, key common.Hash) common.Hash { + // If we have a dirty value for this state entry, return it + value, dirty := s.dirtyStorage[key] + if dirty { + return value + } + // Otherwise return the entry's original value + return s.GetCommittedState(db, key) } // GetCommittedState retrieves a value from the committed account storage trie. 
-func (s *stateObject) GetState(db Database, key common.Hash) common.Hash { - // If the fake storage is set, only lookup the state here(in the debugging mode) - if s.fakeStorage != nil { - return s.fakeStorage[key] +func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash { + // If we have a pending write or clean cached, return that + if value, pending := s.pendingStorage[key]; pending { + return value } - // If we have a cached value, return that if value, cached := s.originStorage[key]; cached { return value } - // If no live objects are available, load from the database. - start := time.Now() - enc, err := s.getTrie(db).TryGet(key.Bytes()) - if metrics.EnabledExpensive { - s.db.StorageReads += time.Since(start) - } - if err != nil { - s.setError(err) + // If the object was destructed in *this* block (and potentially resurrected), + // the storage has been cleared out, and we should *not* consult the previous + // database about any storage values. The only possible alternatives are: + // 1) resurrect happened, and new slot values were set -- those should + // have been handled via pendingStorage above. + // 2) we don't have new values, and can deliver empty response back + if _, destructed := s.db.stateObjectsDestruct[s.address]; destructed { return common.Hash{} } + // If no live objects are available, attempt to use snapshots + var ( + enc []byte + err error + ) + if s.db.snap != nil { + start := time.Now() + enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key.Bytes())) + if metrics.EnabledExpensive { + s.db.SnapshotStorageReads += time.Since(start) + } + } + // If the snapshot is unavailable or reading from it fails, load from the database.
+ if s.db.snap == nil || err != nil { + start := time.Now() + tr, err := s.getTrie(db) + if err != nil { + s.db.setError(err) + return common.Hash{} + } + enc, err = tr.TryGet(key.Bytes()) + if metrics.EnabledExpensive { + s.db.StorageReads += time.Since(start) + } + if err != nil { + s.db.setError(err) + return common.Hash{} + } + } var value common.Hash if len(enc) > 0 { _, content, _, err := rlp.Split(enc) if err != nil { - s.setError(err) + s.db.setError(err) } value.SetBytes(content) } @@ -149,6 +221,182 @@ func (s *stateObject) GetState(db Database, key common.Hash) common.Hash { return value } +// SetState updates a value in account storage. +func (s *stateObject) SetState(db Database, key, value common.Hash) { + // If the new value is the same as old, don't set + prev := s.GetState(db, key) + if prev == value { + return + } + // New value is different, update and journal the change + s.db.journal.append(storageChange{ + account: &s.address, + key: key, + prevalue: prev, + }) + s.setState(key, value) +} + +func (s *stateObject) setState(key, value common.Hash) { + s.dirtyStorage[key] = value +} + +// finalise moves all dirty storage slots into the pending area to be hashed or +// committed later. It is invoked at the end of every transaction. +func (s *stateObject) finalise(prefetch bool) { + slotsToPrefetch := make([][]byte, 0, len(s.dirtyStorage)) + for key, value := range s.dirtyStorage { + s.pendingStorage[key] = value + if value != s.originStorage[key] { + slotsToPrefetch = append(slotsToPrefetch, common.CopyBytes(key[:])) // Copy needed for closure + } + } + if s.db.prefetcher != nil && prefetch && len(slotsToPrefetch) > 0 && s.data.Root != types.EmptyRootHash { + s.db.prefetcher.prefetch(s.addrHash, s.data.Root, slotsToPrefetch) + } + if len(s.dirtyStorage) > 0 { + s.dirtyStorage = make(Storage) + } +} + +// updateTrie writes cached storage modifications into the object's storage trie. 
+// It will return nil if the trie has not been loaded and no changes have been +// made. An error will be returned if the trie can't be loaded/updated correctly. +func (s *stateObject) updateTrie(db Database) (Trie, error) { + // Make sure all dirty slots are finalized into the pending storage area + s.finalise(false) // Don't prefetch anymore, pull directly if need be + if len(s.pendingStorage) == 0 { + return s.trie, nil + } + // Track the amount of time wasted on updating the storage trie + if metrics.EnabledExpensive { + defer func(start time.Time) { s.db.StorageUpdates += time.Since(start) }(time.Now()) + } + // The snapshot storage map for the object + var ( + storage map[common.Hash][]byte + hasher = s.db.hasher + ) + tr, err := s.getTrie(db) + if err != nil { + s.db.setError(err) + return nil, err + } + // Insert all the pending updates into the trie + usedStorage := make([][]byte, 0, len(s.pendingStorage)) + for key, value := range s.pendingStorage { + // Skip noop changes, persist actual changes + if value == s.originStorage[key] { + continue + } + s.originStorage[key] = value + + var v []byte + if (value == common.Hash{}) { + if err := tr.TryDelete(key[:]); err != nil { + s.db.setError(err) + return nil, err + } + s.db.StorageDeleted += 1 + } else { + // Encoding []byte cannot fail, ok to ignore the error. 
+ v, _ = rlp.EncodeToBytes(common.TrimLeftZeroes(value[:])) + if err := tr.TryUpdate(key[:], v); err != nil { + s.db.setError(err) + return nil, err + } + s.db.StorageUpdated += 1 + } + // If state snapshotting is active, cache the data til commit + if s.db.snap != nil { + if storage == nil { + // Retrieve the old storage map, if available, create a new one otherwise + if storage = s.db.snapStorage[s.addrHash]; storage == nil { + storage = make(map[common.Hash][]byte) + s.db.snapStorage[s.addrHash] = storage + } + } + storage[crypto.HashData(hasher, key[:])] = v // v will be nil if it's deleted + } + usedStorage = append(usedStorage, common.CopyBytes(key[:])) // Copy needed for closure + } + if s.db.prefetcher != nil { + s.db.prefetcher.used(s.addrHash, s.data.Root, usedStorage) + } + if len(s.pendingStorage) > 0 { + s.pendingStorage = make(Storage) + } + return tr, nil +} + +// updateRoot sets the account's storage root to the current root hash of its storage trie. +// Nothing is returned: if updateTrie fails (it records the error via s.db.setError), the root is left unchanged. +func (s *stateObject) updateRoot(db Database) { + tr, err := s.updateTrie(db) + if err != nil { + return + } + // If nothing changed, don't bother with hashing anything + if tr == nil { + return + } + // Track the amount of time wasted on hashing the storage trie + if metrics.EnabledExpensive { + defer func(start time.Time) { s.db.StorageHashes += time.Since(start) }(time.Now()) + } + s.data.Root = tr.Hash() +} + +// AddBalance adds amount to s's balance. +// It is used to add funds to the destination account of a transfer. +func (s *stateObject) AddBalance(amount *big.Int) { + // EIP161: We must check emptiness for the objects such that the account + // clearing (0,0,0 objects) can take effect. + if amount.Sign() == 0 { + if s.empty() { + s.touch() + } + return + } + s.SetBalance(new(big.Int).Add(s.Balance(), amount)) +} + +// SubBalance removes amount from s's balance. +// It is used to remove funds from the origin account of a transfer.
+func (s *stateObject) SubBalance(amount *big.Int) { + if amount.Sign() == 0 { + return + } + s.SetBalance(new(big.Int).Sub(s.Balance(), amount)) +} + +func (s *stateObject) SetBalance(amount *big.Int) { + s.db.journal.append(balanceChange{ + account: &s.address, + prev: new(big.Int).Set(s.data.Balance), + }) + s.setBalance(amount) +} + +func (s *stateObject) setBalance(amount *big.Int) { + s.data.Balance = amount +} + +func (s *stateObject) deepCopy(db *StateDB) *stateObject { + stateObject := newObject(db, s.address, s.data) + if s.trie != nil { + stateObject.trie = db.db.CopyTrie(s.trie) + } + stateObject.code = s.code + stateObject.dirtyStorage = s.dirtyStorage.Copy() + stateObject.originStorage = s.originStorage.Copy() + stateObject.pendingStorage = s.pendingStorage.Copy() + stateObject.suicided = s.suicided + stateObject.dirtyCode = s.dirtyCode + stateObject.deleted = s.deleted + return stateObject +} + // // Attribute accessors // @@ -163,12 +411,12 @@ func (s *stateObject) Code(db Database) []byte { if s.code != nil { return s.code } - if bytes.Equal(s.CodeHash(), emptyCodeHash) { + if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) { return nil } code, err := db.ContractCode(common.BytesToHash(s.CodeHash())) if err != nil { - s.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err)) + s.db.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err)) } s.code = code return code @@ -181,16 +429,44 @@ func (s *stateObject) CodeSize(db Database) int { if s.code != nil { return len(s.code) } - if bytes.Equal(s.CodeHash(), emptyCodeHash) { + if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) { return 0 } size, err := db.ContractCodeSize(common.BytesToHash(s.CodeHash())) if err != nil { - s.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err)) + s.db.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err)) } return size } +func (s *stateObject) SetCode(codeHash common.Hash, code []byte) { + 
prevcode := s.Code(s.db.db) + s.db.journal.append(codeChange{ + account: &s.address, + prevhash: s.CodeHash(), + prevcode: prevcode, + }) + s.setCode(codeHash, code) +} + +func (s *stateObject) setCode(codeHash common.Hash, code []byte) { + s.code = code + s.data.CodeHash = codeHash[:] + s.dirtyCode = true +} + +func (s *stateObject) SetNonce(nonce uint64) { + s.db.journal.append(nonceChange{ + account: &s.address, + prev: s.data.Nonce, + }) + s.setNonce(nonce) +} + +func (s *stateObject) setNonce(nonce uint64) { + s.data.Nonce = nonce +} + func (s *stateObject) CodeHash() []byte { return s.data.CodeHash } @@ -202,10 +478,3 @@ func (s *stateObject) Balance() *big.Int { func (s *stateObject) Nonce() uint64 { return s.data.Nonce } - -// Never called, but must be present to allow stateObject to be used -// as a vm.Account interface that also satisfies the vm.ContractRef -// interface. Interfaces are awesome. -func (s *stateObject) Value() *big.Int { - panic("Value on stateObject should never be called") -} diff --git a/trie_by_cid/state/state_object_test.go b/trie_by_cid/state/state_object_test.go new file mode 100644 index 0000000..42fd778 --- /dev/null +++ b/trie_by_cid/state/state_object_test.go @@ -0,0 +1,46 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. 
If not, see <http://www.gnu.org/licenses/>. + +package state + +import ( + "bytes" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +func BenchmarkCutOriginal(b *testing.B) { + value := common.HexToHash("0x01") + for i := 0; i < b.N; i++ { + bytes.TrimLeft(value[:], "\x00") + } +} + +func BenchmarkCutsetterFn(b *testing.B) { + value := common.HexToHash("0x01") + cutSetFn := func(r rune) bool { return r == 0 } + for i := 0; i < b.N; i++ { + bytes.TrimLeftFunc(value[:], cutSetFn) + } +} + +func BenchmarkCutCustomTrim(b *testing.B) { + value := common.HexToHash("0x01") + for i := 0; i < b.N; i++ { + common.TrimLeftZeroes(value[:]) + } +} diff --git a/trie_by_cid/state/state_test.go b/trie_by_cid/state/state_test.go new file mode 100644 index 0000000..6b3fa85 --- /dev/null +++ b/trie_by_cid/state/state_test.go @@ -0,0 +1,219 @@ +// Copyright 2014 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+ +package state + +import ( + "bytes" + "context" + "math/big" + "testing" + + pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres" + + "github.com/cerc-io/ipld-eth-statedb/internal" +) + +var ( + testCtx = context.Background() + testConfig, _ = postgres.DefaultConfig.WithEnv() + teardownStatements = []string{`TRUNCATE ipld.blocks`} +) + +type stateTest struct { + db ethdb.Database + state *StateDB +} + +func newStateTest(t *testing.T) *stateTest { + pool, err := postgres.ConnectSQLX(testCtx, testConfig) + if err != nil { + t.Fatal(err) + } + db := pgipfsethdb.NewDatabase(pool, internal.MakeCacheConfig(t)) + sdb, err := New(common.Hash{}, NewDatabase(db), nil) + if err != nil { + t.Fatal(err) + } + return &stateTest{db: db, state: sdb} +} + +func TestNull(t *testing.T) { + s := newStateTest(t) + address := common.HexToAddress("0x823140710bf13990e4500136726d8b55") + s.state.CreateAccount(address) + //value := common.FromHex("0x823140710bf13990e4500136726d8b55") + var value common.Hash + + s.state.SetState(address, common.Hash{}, value) + // s.state.Commit(false) + + if value := s.state.GetState(address, common.Hash{}); value != (common.Hash{}) { + t.Errorf("expected empty current value, got %x", value) + } + if value := s.state.GetCommittedState(address, common.Hash{}); value != (common.Hash{}) { + t.Errorf("expected empty committed value, got %x", value) + } +} + +func TestSnapshot(t *testing.T) { + stateobjaddr := common.BytesToAddress([]byte("aa")) + var storageaddr common.Hash + data1 := common.BytesToHash([]byte{42}) + data2 := common.BytesToHash([]byte{43}) + s := newStateTest(t) + + // snapshot the genesis state + genesis := s.state.Snapshot() + + // set initial state object value + 
s.state.SetState(stateobjaddr, storageaddr, data1) + snapshot := s.state.Snapshot() + + // set a new state object value, revert it and ensure correct content + s.state.SetState(stateobjaddr, storageaddr, data2) + s.state.RevertToSnapshot(snapshot) + + if v := s.state.GetState(stateobjaddr, storageaddr); v != data1 { + t.Errorf("wrong storage value %v, want %v", v, data1) + } + if v := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) { + t.Errorf("wrong committed storage value %v, want %v", v, common.Hash{}) + } + + // revert up to the genesis state and ensure correct content + s.state.RevertToSnapshot(genesis) + if v := s.state.GetState(stateobjaddr, storageaddr); v != (common.Hash{}) { + t.Errorf("wrong storage value %v, want %v", v, common.Hash{}) + } + if v := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) { + t.Errorf("wrong committed storage value %v, want %v", v, common.Hash{}) + } +} + +func TestSnapshotEmpty(t *testing.T) { + s := newStateTest(t) + s.state.RevertToSnapshot(s.state.Snapshot()) +} + +func TestSnapshot2(t *testing.T) { + state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil) + + stateobjaddr0 := common.BytesToAddress([]byte("so0")) + stateobjaddr1 := common.BytesToAddress([]byte("so1")) + var storageaddr common.Hash + + data0 := common.BytesToHash([]byte{17}) + data1 := common.BytesToHash([]byte{18}) + + state.SetState(stateobjaddr0, storageaddr, data0) + state.SetState(stateobjaddr1, storageaddr, data1) + + // db, trie are already non-empty values + so0 := state.getStateObject(stateobjaddr0) + so0.SetBalance(big.NewInt(42)) + so0.SetNonce(43) + so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'}) + so0.suicided = false + so0.deleted = false + state.setStateObject(so0) + + // root, _ := state.Commit(false) + // state, _ = New(root, state.db, state.snaps) + + // and one with deleted == true + so1 := state.getStateObject(stateobjaddr1) 
+ so1.SetBalance(big.NewInt(52)) + so1.SetNonce(53) + so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'}) + so1.suicided = true + so1.deleted = true + state.setStateObject(so1) + + so1 = state.getStateObject(stateobjaddr1) + if so1 != nil { + t.Fatalf("deleted object not nil when getting") + } + + snapshot := state.Snapshot() + state.RevertToSnapshot(snapshot) + + so0Restored := state.getStateObject(stateobjaddr0) + // Update lazily-loaded values before comparing. + so0Restored.GetState(state.db, storageaddr) + so0Restored.Code(state.db) + // non-deleted is equal (restored) + compareStateObjects(so0Restored, so0, t) + + // deleted should be nil, both before and after restore of state copy + so1Restored := state.getStateObject(stateobjaddr1) + if so1Restored != nil { + t.Fatalf("deleted object not nil after restoring snapshot: %+v", so1Restored) + } +} + +func compareStateObjects(so0, so1 *stateObject, t *testing.T) { + if so0.Address() != so1.Address() { + t.Fatalf("Address mismatch: have %v, want %v", so0.address, so1.address) + } + if so0.Balance().Cmp(so1.Balance()) != 0 { + t.Fatalf("Balance mismatch: have %v, want %v", so0.Balance(), so1.Balance()) + } + if so0.Nonce() != so1.Nonce() { + t.Fatalf("Nonce mismatch: have %v, want %v", so0.Nonce(), so1.Nonce()) + } + if so0.data.Root != so1.data.Root { + t.Errorf("Root mismatch: have %x, want %x", so0.data.Root[:], so1.data.Root[:]) + } + if !bytes.Equal(so0.CodeHash(), so1.CodeHash()) { + t.Fatalf("CodeHash mismatch: have %v, want %v", so0.CodeHash(), so1.CodeHash()) + } + if !bytes.Equal(so0.code, so1.code) { + t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code) + } + + if len(so1.dirtyStorage) != len(so0.dirtyStorage) { + t.Errorf("Dirty storage size mismatch: have %d, want %d", len(so1.dirtyStorage), len(so0.dirtyStorage)) + } + for k, v := range so1.dirtyStorage { + if so0.dirtyStorage[k] != v { + t.Errorf("Dirty storage key %x mismatch: have %v, 
want %v", k, so0.dirtyStorage[k], v) + } + } + for k, v := range so0.dirtyStorage { + if so1.dirtyStorage[k] != v { + t.Errorf("Dirty storage key %x mismatch: have %v, want none.", k, v) + } + } + if len(so1.originStorage) != len(so0.originStorage) { + t.Errorf("Origin storage size mismatch: have %d, want %d", len(so1.originStorage), len(so0.originStorage)) + } + for k, v := range so1.originStorage { + if so0.originStorage[k] != v { + t.Errorf("Origin storage key %x mismatch: have %v, want %v", k, so0.originStorage[k], v) + } + } + for k, v := range so0.originStorage { + if so1.originStorage[k] != v { + t.Errorf("Origin storage key %x mismatch: have %v, want none.", k, v) + } + } +} diff --git a/trie_by_cid/state/statedb.go b/trie_by_cid/state/statedb.go index fc0c0b8..b4c4369 100644 --- a/trie_by_cid/state/statedb.go +++ b/trie_by_cid/state/statedb.go @@ -1,17 +1,45 @@ +// Copyright 2014 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +// Package state provides a caching layer atop the Ethereum state trie.
package state import ( "errors" "fmt" "math/big" + "sort" "time" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state/snapshot" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" + + "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" ) +type revision struct { + id int + journalIndex int +} + type proofList [][]byte func (n *proofList) Put(key []byte, value []byte) error { @@ -28,46 +56,127 @@ func (n *proofList) Delete(key []byte) error { // nested states. It's the general query interface to retrieve: // * Contracts // * Accounts -// -// This implementation is read-only and performs no journaling, prefetching, or metrics tracking. type StateDB struct { - db Database - trie Trie - hasher crypto.KeccakState + db Database + prefetcher *triePrefetcher + trie Trie + hasher crypto.KeccakState + + // originalRoot is the pre-state root, before any changes were made. + // It will be updated when the Commit is called. + originalRoot common.Hash + + snaps *snapshot.Tree + snap snapshot.Snapshot + snapAccounts map[common.Hash][]byte + snapStorage map[common.Hash]map[common.Hash][]byte // This map holds 'live' objects, which will get modified while processing a state transition. - stateObjects map[common.Address]*stateObject + stateObjects map[common.Address]*stateObject + stateObjectsPending map[common.Address]struct{} // State objects finalized but not yet written to the trie + stateObjectsDirty map[common.Address]struct{} // State objects modified in the current execution + stateObjectsDestruct map[common.Address]struct{} // State objects destructed in the block // DB error. // State objects are used by the consensus core and VM which are // unable to deal with database-level errors. 
Any error that occurs - // during a database read is memoized here and will eventually be returned - // by StateDB.Commit. + // during a database read is memoized here and will eventually be + // returned by StateDB.Commit. Notably, this error is also shared + // by all cached state objects in case the database failure occurs + // when accessing state of accounts. dbErr error + // The refund counter, also used by state transitioning. + refund uint64 + + thash common.Hash + txIndex int + logs map[common.Hash][]*types.Log + logSize uint + preimages map[common.Hash][]byte + // Per-transaction access list + accessList *accessList + + // Transient storage + transientStorage transientStorage + + // Journal of state modifications. This is the backbone of + // Snapshot and RevertToSnapshot. + journal *journal + validRevisions []revision + nextRevisionId int + // Measurements gathered during execution for debugging purposes - AccountReads time.Duration - StorageReads time.Duration + AccountReads time.Duration + AccountHashes time.Duration + AccountUpdates time.Duration + StorageReads time.Duration + StorageHashes time.Duration + StorageUpdates time.Duration + SnapshotAccountReads time.Duration + SnapshotStorageReads time.Duration + + AccountUpdated int + StorageUpdated int + AccountDeleted int + StorageDeleted int } // New creates a new state from a given trie. 
-func New(root common.Hash, db Database) (*StateDB, error) { +func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) { tr, err := db.OpenTrie(root) if err != nil { return nil, err } sdb := &StateDB{ - db: db, - trie: tr, - stateObjects: make(map[common.Address]*stateObject), - preimages: make(map[common.Hash][]byte), - hasher: crypto.NewKeccakState(), + db: db, + trie: tr, + originalRoot: root, + snaps: snaps, + stateObjects: make(map[common.Address]*stateObject), + stateObjectsPending: make(map[common.Address]struct{}), + stateObjectsDirty: make(map[common.Address]struct{}), + stateObjectsDestruct: make(map[common.Address]struct{}), + logs: make(map[common.Hash][]*types.Log), + preimages: make(map[common.Hash][]byte), + journal: newJournal(), + accessList: newAccessList(), + transientStorage: newTransientStorage(), + hasher: crypto.NewKeccakState(), + } + if sdb.snaps != nil { + if sdb.snap = sdb.snaps.Snapshot(root); sdb.snap != nil { + sdb.snapAccounts = make(map[common.Hash][]byte) + sdb.snapStorage = make(map[common.Hash]map[common.Hash][]byte) + } } return sdb, nil } +// StartPrefetcher initializes a new trie prefetcher to pull in nodes from the +// state trie concurrently while the state is mutated so that when we reach the +// commit phase, most of the needed data is already hot. +func (s *StateDB) StartPrefetcher(namespace string) { + if s.prefetcher != nil { + s.prefetcher.close() + s.prefetcher = nil + } + if s.snap != nil { + s.prefetcher = newTriePrefetcher(s.db, s.originalRoot, namespace) + } +} + +// StopPrefetcher terminates a running prefetcher and reports any leftover stats +// from the gathered metrics. +func (s *StateDB) StopPrefetcher() { + if s.prefetcher != nil { + s.prefetcher.close() + s.prefetcher = nil + } +} + // setError remembers the first non-nil error it is called with. 
func (s *StateDB) setError(err error) { if s.dbErr == nil { @@ -75,17 +184,44 @@ func (s *StateDB) setError(err error) { } } +// Error returns the memorized database failure occurred earlier. func (s *StateDB) Error() error { return s.dbErr } func (s *StateDB) AddLog(log *types.Log) { - panic("unsupported") + s.journal.append(addLogChange{txhash: s.thash}) + + log.TxHash = s.thash + log.TxIndex = uint(s.txIndex) + log.Index = s.logSize + s.logs[s.thash] = append(s.logs[s.thash], log) + s.logSize++ +} + +// GetLogs returns the logs matching the specified transaction hash, and annotates +// them with the given blockNumber and blockHash. +func (s *StateDB) GetLogs(hash common.Hash, blockNumber uint64, blockHash common.Hash) []*types.Log { + logs := s.logs[hash] + for _, l := range logs { + l.BlockNumber = blockNumber + l.BlockHash = blockHash + } + return logs +} + +func (s *StateDB) Logs() []*types.Log { + var logs []*types.Log + for _, lgs := range s.logs { + logs = append(logs, lgs...) + } + return logs } // AddPreimage records a SHA3 preimage seen by the VM. func (s *StateDB) AddPreimage(hash common.Hash, preimage []byte) { if _, ok := s.preimages[hash]; !ok { + s.journal.append(addPreimageChange{hash: hash}) pi := make([]byte, len(preimage)) copy(pi, preimage) s.preimages[hash] = pi @@ -99,13 +235,18 @@ func (s *StateDB) Preimages() map[common.Hash][]byte { // AddRefund adds gas to the refund counter func (s *StateDB) AddRefund(gas uint64) { - panic("unsupported") + s.journal.append(refundChange{prev: s.refund}) + s.refund += gas } // SubRefund removes gas from the refund counter. // This method will panic if the refund counter goes below zero func (s *StateDB) SubRefund(gas uint64) { - panic("unsupported") + s.journal.append(refundChange{prev: s.refund}) + if gas > s.refund { + panic(fmt.Sprintf("Refund counter below zero (gas: %d > refund: %d)", gas, s.refund)) + } + s.refund -= gas } // Exist reports whether the given account address exists in the state. 
@@ -139,6 +280,11 @@ func (s *StateDB) GetNonce(addr common.Address) uint64 { return 0 } +// TxIndex returns the current transaction index set by Prepare. +func (s *StateDB) TxIndex() int { + return s.txIndex +} + func (s *StateDB) GetCode(addr common.Address) []byte { stateObject := s.getStateObject(addr) if stateObject != nil { @@ -186,18 +332,28 @@ func (s *StateDB) GetProofByHash(addrHash common.Hash) ([][]byte, error) { // GetStorageProof returns the Merkle proof for given storage slot. func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, error) { - var proof proofList - trie := s.StorageTrie(a) - if trie == nil { - return proof, errors.New("storage trie for requested address does not exist") + trie, err := s.StorageTrie(a) + if err != nil { + return nil, err } - err := trie.Prove(crypto.Keccak256(key.Bytes()), 0, &proof) - return proof, err + if trie == nil { + return nil, errors.New("storage trie for requested address does not exist") + } + var proof proofList + err = trie.Prove(crypto.Keccak256(key.Bytes()), 0, &proof) + if err != nil { + return nil, err + } + return proof, nil } // GetCommittedState retrieves a value from the given account's committed storage trie. func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash { - return s.GetState(addr, hash) + stateObject := s.getStateObject(addr) + if stateObject != nil { + return stateObject.GetCommittedState(s.db, hash) + } + return common.Hash{} } // Database retrieves the low level database supporting the lower level trie ops. @@ -205,17 +361,26 @@ func (s *StateDB) Database() Database { return s.db } -// StorageTrie returns the storage trie of an account. -// The return value is a copy and is nil for non-existent accounts. -func (s *StateDB) StorageTrie(addr common.Address) Trie { +// StorageTrie returns the storage trie of an account. The return value is a copy +// and is nil for non-existent accounts. 
An error will be returned if storage trie +// is existent but can't be loaded correctly. +func (s *StateDB) StorageTrie(addr common.Address) (Trie, error) { stateObject := s.getStateObject(addr) if stateObject == nil { - return nil + return nil, nil } - return stateObject.getTrie(s.db) + cpy := stateObject.deepCopy(s) + if _, err := cpy.updateTrie(s.db); err != nil { + return nil, err + } + return cpy.getTrie(s.db) } func (s *StateDB) HasSuicided(addr common.Address) bool { + stateObject := s.getStateObject(addr) + if stateObject != nil { + return stateObject.suicided + } return false } @@ -225,34 +390,61 @@ func (s *StateDB) HasSuicided(addr common.Address) bool { // AddBalance adds amount to the account associated with addr. func (s *StateDB) AddBalance(addr common.Address, amount *big.Int) { - panic("unsupported") + stateObject := s.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.AddBalance(amount) + } } // SubBalance subtracts amount from the account associated with addr. 
func (s *StateDB) SubBalance(addr common.Address, amount *big.Int) { - panic("unsupported") + stateObject := s.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.SubBalance(amount) + } } func (s *StateDB) SetBalance(addr common.Address, amount *big.Int) { - panic("unsupported") + stateObject := s.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.SetBalance(amount) + } } func (s *StateDB) SetNonce(addr common.Address, nonce uint64) { - panic("unsupported") + stateObject := s.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.SetNonce(nonce) + } } func (s *StateDB) SetCode(addr common.Address, code []byte) { - panic("unsupported") + stateObject := s.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.SetCode(crypto.Keccak256Hash(code), code) + } } func (s *StateDB) SetState(addr common.Address, key, value common.Hash) { - panic("unsupported") + stateObject := s.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.SetState(s.db, key, value) + } } // SetStorage replaces the entire storage for the specified account with given // storage. This function should only be used for debugging. func (s *StateDB) SetStorage(addr common.Address, storage map[common.Hash]common.Hash) { - panic("unsupported") + // SetStorage needs to wipe existing storage. We achieve this by pretending + // that the account self-destructed earlier in this block, by flagging + // it in stateObjectsDestruct. The effect of doing so is that storage lookups + // will not hit disk, since it is assumed that the disk-data is belonging + // to a previous incarnation of the object. + s.stateObjectsDestruct[addr] = struct{}{} + stateObject := s.GetOrNewStateObject(addr) + for k, v := range storage { + stateObject.SetState(s.db, k, v) + } } // Suicide marks the given account as suicided. 
@@ -261,36 +453,147 @@ func (s *StateDB) SetStorage(addr common.Address, storage map[common.Hash]common // The account's state object is still available until the state is committed, // getStateObject will return a non-nil account after Suicide. func (s *StateDB) Suicide(addr common.Address) bool { - panic("unsupported") - return false + stateObject := s.getStateObject(addr) + if stateObject == nil { + return false + } + s.journal.append(suicideChange{ + account: &addr, + prev: stateObject.suicided, + prevbalance: new(big.Int).Set(stateObject.Balance()), + }) + stateObject.markSuicided() + stateObject.data.Balance = new(big.Int) + + return true +} + +// SetTransientState sets transient storage for a given account. It +// adds the change to the journal so that it can be rolled back +// to its previous value if there is a revert. +func (s *StateDB) SetTransientState(addr common.Address, key, value common.Hash) { + prev := s.GetTransientState(addr, key) + if prev == value { + return + } + s.journal.append(transientStorageChange{ + account: &addr, + key: key, + prevalue: prev, + }) + s.setTransientState(addr, key, value) +} + +// setTransientState is a lower level setter for transient storage. It +// is called during a revert to prevent modifications to the journal. +func (s *StateDB) setTransientState(addr common.Address, key, value common.Hash) { + s.transientStorage.Set(addr, key, value) +} + +// GetTransientState gets transient storage for a given account. +func (s *StateDB) GetTransientState(addr common.Address, key common.Hash) common.Hash { + return s.transientStorage.Get(addr, key) } // // Setting, updating & deleting state object methods. // +// updateStateObject writes the given object to the trie. 
+func (s *StateDB) updateStateObject(obj *stateObject) { + // Track the amount of time wasted on updating the account from the trie + if metrics.EnabledExpensive { + defer func(start time.Time) { s.AccountUpdates += time.Since(start) }(time.Now()) + } + // Encode the account and update the account trie + addr := obj.Address() + if err := s.trie.TryUpdateAccount(addr, &obj.data); err != nil { + s.setError(fmt.Errorf("updateStateObject (%x) error: %v", addr[:], err)) + } + + // If state snapshotting is active, cache the data til commit. Note, this + // update mechanism is not symmetric to the deletion, because whereas it is + // enough to track account updates at commit time, deletions need tracking + // at transaction boundary level to ensure we capture state clearing. + if s.snap != nil { + s.snapAccounts[obj.addrHash] = snapshot.SlimAccountRLP(obj.data.Nonce, obj.data.Balance, obj.data.Root, obj.data.CodeHash) + } +} + +// deleteStateObject removes the given object from the state trie. +func (s *StateDB) deleteStateObject(obj *stateObject) { + // Track the amount of time wasted on deleting the account from the trie + if metrics.EnabledExpensive { + defer func(start time.Time) { s.AccountUpdates += time.Since(start) }(time.Now()) + } + // Delete the account from the trie + addr := obj.Address() + if err := s.trie.TryDeleteAccount(addr); err != nil { + s.setError(fmt.Errorf("deleteStateObject (%x) error: %v", addr[:], err)) + } +} + // getStateObject retrieves a state object given by the address, returning nil if -// the object is not found or was deleted in this execution context. +// the object is not found or was deleted in this execution context. If you need +// to differentiate between non-existent/just-deleted, use getDeletedStateObject. 
func (s *StateDB) getStateObject(addr common.Address) *stateObject { + if obj := s.getDeletedStateObject(addr); obj != nil && !obj.deleted { + return obj + } + return nil +} + +// getDeletedStateObject is similar to getStateObject, but instead of returning +// nil for a deleted state object, it returns the actual object with the deleted +// flag set. This is needed by the state journal to revert to the correct s- +// destructed object instead of wiping all knowledge about the state object. +func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject { // Prefer live objects if any is available if obj := s.stateObjects[addr]; obj != nil { return obj } - // If no live objects are available, load from the database - start := time.Now() - var err error - data, err := s.trie.TryGetAccount(addr.Bytes()) - if metrics.EnabledExpensive { - s.AccountReads += time.Since(start) - } - if err != nil { - s.setError(fmt.Errorf("getStateObject (%x) error: %w", addr.Bytes(), err)) - return nil + // If no live objects are available, attempt to use snapshots + var data *types.StateAccount + if s.snap != nil { + start := time.Now() + acc, err := s.snap.Account(crypto.HashData(s.hasher, addr.Bytes())) + if metrics.EnabledExpensive { + s.SnapshotAccountReads += time.Since(start) + } + if err == nil { + if acc == nil { + return nil + } + data = &types.StateAccount{ + Nonce: acc.Nonce, + Balance: acc.Balance, + CodeHash: acc.CodeHash, + Root: common.BytesToHash(acc.Root), + } + if len(data.CodeHash) == 0 { + data.CodeHash = types.EmptyCodeHash.Bytes() + } + if data.Root == (common.Hash{}) { + data.Root = types.EmptyRootHash + } + } } + // If snapshot unavailable or reading from it failed, load from the database if data == nil { - return nil + start := time.Now() + var err error + data, err = s.trie.TryGetAccount(addr) + if metrics.EnabledExpensive { + s.AccountReads += time.Since(start) + } + if err != nil { + s.setError(fmt.Errorf("getDeleteStateObject (%x) error: %w", 
addr.Bytes(), err)) + return nil + } + if data == nil { + return nil + } } - // Insert into the live set obj := newObject(s, addr, *data) s.setStateObject(obj) @@ -301,69 +604,439 @@ func (s *StateDB) setStateObject(object *stateObject) { s.stateObjects[object.Address()] = object } +// GetOrNewStateObject retrieves a state object or create a new state object if nil. +func (s *StateDB) GetOrNewStateObject(addr common.Address) *stateObject { + stateObject := s.getStateObject(addr) + if stateObject == nil { + stateObject, _ = s.createObject(addr) + } + return stateObject +} + +// createObject creates a new state object. If there is an existing account with +// the given address, it is overwritten and returned as the second return value. +func (s *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) { + prev = s.getDeletedStateObject(addr) // Note, prev might have been deleted, we need that! + + var prevdestruct bool + if prev != nil { + _, prevdestruct = s.stateObjectsDestruct[prev.address] + if !prevdestruct { + s.stateObjectsDestruct[prev.address] = struct{}{} + } + } + newobj = newObject(s, addr, types.StateAccount{}) + if prev == nil { + s.journal.append(createObjectChange{account: &addr}) + } else { + s.journal.append(resetObjectChange{prev: prev, prevdestruct: prevdestruct}) + } + s.setStateObject(newobj) + if prev != nil && !prev.deleted { + return newobj, prev + } + return newobj, nil +} + // CreateAccount explicitly creates a state object. If a state object with the address // already exists the balance is carried over to the new account. // // CreateAccount is called during the EVM CREATE operation. The situation might arise that // a contract does the following: // -// 1. sends funds to sha(account ++ (nonce + 1)) -// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1) +// 1. sends funds to sha(account ++ (nonce + 1)) +// 2. 
tx_create(sha(account ++ nonce)) (note that this gets the address of 1) // // Carrying over the balance ensures that Ether doesn't disappear. func (s *StateDB) CreateAccount(addr common.Address) { - panic("unsupported") + newObj, prev := s.createObject(addr) + if prev != nil { + newObj.setBalance(prev.data.Balance) + } } func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error { + so := db.getStateObject(addr) + if so == nil { + return nil + } + tr, err := so.getTrie(db.db) + if err != nil { + return err + } + it := trie.NewIterator(tr.NodeIterator(nil)) + + for it.Next() { + key := common.BytesToHash(db.trie.GetKey(it.Key)) + if value, dirty := so.dirtyStorage[key]; dirty { + if !cb(key, value) { + return nil + } + continue + } + + if len(it.Value) > 0 { + _, content, _, err := rlp.Split(it.Value) + if err != nil { + return err + } + if !cb(key, common.BytesToHash(content)) { + return nil + } + } + } return nil } +// Copy creates a deep, independent copy of the state. +// Snapshots of the copied state cannot be applied to the copy. 
+func (s *StateDB) Copy() *StateDB { + // Copy all the basic fields, initialize the memory ones + state := &StateDB{ + db: s.db, + trie: s.db.CopyTrie(s.trie), + originalRoot: s.originalRoot, + stateObjects: make(map[common.Address]*stateObject, len(s.journal.dirties)), + stateObjectsPending: make(map[common.Address]struct{}, len(s.stateObjectsPending)), + stateObjectsDirty: make(map[common.Address]struct{}, len(s.journal.dirties)), + stateObjectsDestruct: make(map[common.Address]struct{}, len(s.stateObjectsDestruct)), + refund: s.refund, + logs: make(map[common.Hash][]*types.Log, len(s.logs)), + logSize: s.logSize, + preimages: make(map[common.Hash][]byte, len(s.preimages)), + journal: newJournal(), + hasher: crypto.NewKeccakState(), + } + // Copy the dirty states, logs, and preimages + for addr := range s.journal.dirties { + // As documented [here](https://github.com/ethereum/go-ethereum/pull/16485#issuecomment-380438527), + // and in the Finalise-method, there is a case where an object is in the journal but not + // in the stateObjects: OOG after touch on ripeMD prior to Byzantium. Thus, we need to check for + // nil + if object, exist := s.stateObjects[addr]; exist { + // Even though the original object is dirty, we are not copying the journal, + // so we need to make sure that any side-effect the journal would have caused + // during a commit (or similar op) is already applied to the copy. + state.stateObjects[addr] = object.deepCopy(state) + + state.stateObjectsDirty[addr] = struct{}{} // Mark the copy dirty to force internal (code/state) commits + state.stateObjectsPending[addr] = struct{}{} // Mark the copy pending to force external (account) commits + } + } + // Above, we don't copy the actual journal. This means that if the copy + // is copied, the loop above will be a no-op, since the copy's journal + // is empty. Thus, here we iterate over stateObjects, to enable copies + // of copies. 
+ for addr := range s.stateObjectsPending { + if _, exist := state.stateObjects[addr]; !exist { + state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state) + } + state.stateObjectsPending[addr] = struct{}{} + } + for addr := range s.stateObjectsDirty { + if _, exist := state.stateObjects[addr]; !exist { + state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state) + } + state.stateObjectsDirty[addr] = struct{}{} + } + // Deep copy the destruction flag. + for addr := range s.stateObjectsDestruct { + state.stateObjectsDestruct[addr] = struct{}{} + } + for hash, logs := range s.logs { + cpy := make([]*types.Log, len(logs)) + for i, l := range logs { + cpy[i] = new(types.Log) + *cpy[i] = *l + } + state.logs[hash] = cpy + } + for hash, preimage := range s.preimages { + state.preimages[hash] = preimage + } + // Do we need to copy the access list and transient storage? + // In practice: No. At the start of a transaction, these two lists are empty. + // In practice, we only ever copy state _between_ transactions/blocks, never + // in the middle of a transaction. However, it doesn't cost us much to copy + // empty lists, so we do it anyway to not blow up if we ever decide copy them + // in the middle of a transaction. + state.accessList = s.accessList.Copy() + state.transientStorage = s.transientStorage.Copy() + + // If there's a prefetcher running, make an inactive copy of it that can + // only access data but does not actively preload (since the user will not + // know that they need to explicitly terminate an active copy). + if s.prefetcher != nil { + state.prefetcher = s.prefetcher.copy() + } + if s.snaps != nil { + // In order for the miner to be able to use and make additions + // to the snapshot tree, we need to copy that as well. 
+ // Otherwise, any block mined by ourselves will cause gaps in the tree, + // and force the miner to operate trie-backed only + state.snaps = s.snaps + state.snap = s.snap + + // deep copy needed + state.snapAccounts = make(map[common.Hash][]byte) + for k, v := range s.snapAccounts { + state.snapAccounts[k] = v + } + state.snapStorage = make(map[common.Hash]map[common.Hash][]byte) + for k, v := range s.snapStorage { + temp := make(map[common.Hash][]byte) + for kk, vv := range v { + temp[kk] = vv + } + state.snapStorage[k] = temp + } + } + return state +} + // Snapshot returns an identifier for the current revision of the state. func (s *StateDB) Snapshot() int { - return 0 + id := s.nextRevisionId + s.nextRevisionId++ + s.validRevisions = append(s.validRevisions, revision{id, s.journal.length()}) + return id } // RevertToSnapshot reverts all state changes made since the given revision. func (s *StateDB) RevertToSnapshot(revid int) { - panic("unsupported") + // Find the snapshot in the stack of valid snapshots. + idx := sort.Search(len(s.validRevisions), func(i int) bool { + return s.validRevisions[i].id >= revid + }) + if idx == len(s.validRevisions) || s.validRevisions[idx].id != revid { + panic(fmt.Errorf("revision id %v cannot be reverted", revid)) + } + snapshot := s.validRevisions[idx].journalIndex + + // Replay the journal to undo changes and remove invalidated snapshots + s.journal.revert(s, snapshot) + s.validRevisions = s.validRevisions[:idx] } // GetRefund returns the current value of the refund counter. func (s *StateDB) GetRefund() uint64 { - panic("unsupported") - return 0 + return s.refund } -// PrepareAccessList handles the preparatory steps for executing a state transition with -// regards to both EIP-2929 and EIP-2930: +// Finalise finalises the state by removing the destructed objects and clears +// the journal as well as the refunds. Finalise, however, will not push any updates +// into the tries just yet. 
Only IntermediateRoot or Commit will do that. +func (s *StateDB) Finalise(deleteEmptyObjects bool) { + addressesToPrefetch := make([][]byte, 0, len(s.journal.dirties)) + for addr := range s.journal.dirties { + obj, exist := s.stateObjects[addr] + if !exist { + // ripeMD is 'touched' at block 1714175, in tx 0x1237f737031e40bcde4a8b7e717b2d15e3ecadfe49bb1bbc71ee9deb09c6fcf2 + // That tx goes out of gas, and although the notion of 'touched' does not exist there, the + // touch-event will still be recorded in the journal. Since ripeMD is a special snowflake, + // it will persist in the journal even though the journal is reverted. In this special circumstance, + // it may exist in `s.journal.dirties` but not in `s.stateObjects`. + // Thus, we can safely ignore it here + continue + } + if obj.suicided || (deleteEmptyObjects && obj.empty()) { + obj.deleted = true + + // We need to maintain account deletions explicitly (will remain + // set indefinitely). + s.stateObjectsDestruct[obj.address] = struct{}{} + + // If state snapshotting is active, also mark the destruction there. + // Note, we can't do this only at the end of a block because multiple + // transactions within the same block might self destruct and then + // resurrect an account; but the snapshotter needs both events. + if s.snap != nil { + delete(s.snapAccounts, obj.addrHash) // Clear out any previously updated account data (may be recreated via a resurrect) + delete(s.snapStorage, obj.addrHash) // Clear out any previously updated storage data (may be recreated via a resurrect) + } + } else { + obj.finalise(true) // Prefetch slots in the background + } + s.stateObjectsPending[addr] = struct{}{} + s.stateObjectsDirty[addr] = struct{}{} + + // At this point, also ship the address off to the precacher. 
The precacher + // will start loading tries, and when the change is eventually committed, + // the commit-phase will be a lot faster + addressesToPrefetch = append(addressesToPrefetch, common.CopyBytes(addr[:])) // Copy needed for closure + } + if s.prefetcher != nil && len(addressesToPrefetch) > 0 { + s.prefetcher.prefetch(common.Hash{}, s.originalRoot, addressesToPrefetch) + } + // Invalidate journal because reverting across transactions is not allowed. + s.clearJournalAndRefund() +} + +// IntermediateRoot computes the current root hash of the state trie. +// It is called in between transactions to get the root hash that +// goes into transaction receipts. +func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash { + // Finalise all the dirty storage states and write them into the tries + s.Finalise(deleteEmptyObjects) + + // If there was a trie prefetcher operating, it gets aborted and irrevocably + // modified after we start retrieving tries. Remove it from the statedb after + // this round of use. + // + // This is weird pre-byzantium since the first tx runs with a prefetcher and + // the remainder without, but pre-byzantium even the initial prefetcher is + // useless, so no sleep lost. + prefetcher := s.prefetcher + if s.prefetcher != nil { + defer func() { + s.prefetcher.close() + s.prefetcher = nil + }() + } + // Although naively it makes sense to retrieve the account trie and then do + // the contract storage and account updates sequentially, that short circuits + // the account prefetcher. Instead, let's process all the storage updates + // first, giving the account prefetches just a few more milliseconds of time + // to pull useful data from disk. + for addr := range s.stateObjectsPending { + if obj := s.stateObjects[addr]; !obj.deleted { + obj.updateRoot(s.db) + } + } + // Now we're about to start to write changes to the trie. The trie is so far + // _untouched_. 
We can check with the prefetcher, if it can give us a trie + // which has the same root, but also has some content loaded into it. + if prefetcher != nil { + if trie := prefetcher.trie(common.Hash{}, s.originalRoot); trie != nil { + s.trie = trie + } + } + usedAddrs := make([][]byte, 0, len(s.stateObjectsPending)) + for addr := range s.stateObjectsPending { + if obj := s.stateObjects[addr]; obj.deleted { + s.deleteStateObject(obj) + s.AccountDeleted += 1 + } else { + s.updateStateObject(obj) + s.AccountUpdated += 1 + } + usedAddrs = append(usedAddrs, common.CopyBytes(addr[:])) // Copy needed for closure + } + if prefetcher != nil { + prefetcher.used(common.Hash{}, s.originalRoot, usedAddrs) + } + if len(s.stateObjectsPending) > 0 { + s.stateObjectsPending = make(map[common.Address]struct{}) + } + // Track the amount of time wasted on hashing the account trie + if metrics.EnabledExpensive { + defer func(start time.Time) { s.AccountHashes += time.Since(start) }(time.Now()) + } + return s.trie.Hash() +} + +// SetTxContext sets the current transaction hash and index which are +// used when the EVM emits new state logs. It should be invoked before +// transaction execution. +func (s *StateDB) SetTxContext(thash common.Hash, ti int) { + s.thash = thash + s.txIndex = ti +} + +func (s *StateDB) clearJournalAndRefund() { + if len(s.journal.entries) > 0 { + s.journal = newJournal() + s.refund = 0 + } + s.validRevisions = s.validRevisions[:0] // Snapshots can be created without journal entries +} + +// Prepare handles the preparatory steps for executing a state transition with. +// This method must be invoked before state transition. // +// Berlin fork: // - Add sender to access list (2929) // - Add destination to access list (2929) // - Add precompiles to access list (2929) // - Add the contents of the optional tx access list (2930) // -// This method should only be called if Berlin/2929+2930 is applicable at the current number. 
-func (s *StateDB) PrepareAccessList(sender common.Address, dst *common.Address, precompiles []common.Address, list types.AccessList) { - panic("unsupported") +// Potential EIPs: +// - Reset access list (Berlin) +// - Add coinbase to access list (EIP-3651) +// - Reset transient storage (EIP-1153) +func (s *StateDB) Prepare(rules params.Rules, sender, coinbase common.Address, dst *common.Address, precompiles []common.Address, list types.AccessList) { + if rules.IsBerlin { + // Clear out any leftover from previous executions + al := newAccessList() + s.accessList = al + + al.AddAddress(sender) + if dst != nil { + al.AddAddress(*dst) + // If it's a create-tx, the destination will be added inside evm.create + } + for _, addr := range precompiles { + al.AddAddress(addr) + } + for _, el := range list { + al.AddAddress(el.Address) + for _, key := range el.StorageKeys { + al.AddSlot(el.Address, key) + } + } + if rules.IsShanghai { // EIP-3651: warm coinbase + al.AddAddress(coinbase) + } + } + // Reset transient storage at the beginning of transaction execution + s.transientStorage = newTransientStorage() } // AddAddressToAccessList adds the given address to the access list func (s *StateDB) AddAddressToAccessList(addr common.Address) { - panic("unsupported") + if s.accessList.AddAddress(addr) { + s.journal.append(accessListAddAccountChange{&addr}) + } } // AddSlotToAccessList adds the given (address, slot)-tuple to the access list func (s *StateDB) AddSlotToAccessList(addr common.Address, slot common.Hash) { - panic("unsupported") + addrMod, slotMod := s.accessList.AddSlot(addr, slot) + if addrMod { + // In practice, this should not happen, since there is no way to enter the + // scope of 'address' without having the 'address' become already added + // to the access list (via call-variant, create, etc). 
+ // Better safe than sorry, though + s.journal.append(accessListAddAccountChange{&addr}) + } + if slotMod { + s.journal.append(accessListAddSlotChange{ + address: &addr, + slot: &slot, + }) + } } // AddressInAccessList returns true if the given address is in the access list. func (s *StateDB) AddressInAccessList(addr common.Address) bool { - return false + return s.accessList.ContainsAddress(addr) } // SlotInAccessList returns true if the given (address, slot)-tuple is in the access list. func (s *StateDB) SlotInAccessList(addr common.Address, slot common.Hash) (addressPresent bool, slotPresent bool) { - return + return s.accessList.Contains(addr, slot) +} + +// convertAccountSet converts a provided account set from address keyed to hash keyed. +func (s *StateDB) convertAccountSet(set map[common.Address]struct{}) map[common.Hash]struct{} { + ret := make(map[common.Hash]struct{}) + for addr := range set { + obj, exist := s.stateObjects[addr] + if !exist { + ret[crypto.Keccak256Hash(addr[:])] = struct{}{} + } else { + ret[obj.addrHash] = struct{}{} + } + } + return ret } diff --git a/trie_by_cid/state/statedb_test.go b/trie_by_cid/state/statedb_test.go new file mode 100644 index 0000000..60cd66a --- /dev/null +++ b/trie_by_cid/state/statedb_test.go @@ -0,0 +1,594 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. 
+// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package state + +import ( + "bytes" + "encoding/binary" + "fmt" + "math" + "math/big" + "math/rand" + "reflect" + "strings" + "sync" + "testing" + "testing/quick" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" +) + +// TestCopy tests that copying a StateDB object indeed makes the original and +// the copy independent of each other. This test is a regression test against +// https://github.com/ethereum/go-ethereum/pull/15549. +func TestCopy(t *testing.T) { + // Create a random state test to copy and modify "independently" + orig, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil) + + for i := byte(0); i < 255; i++ { + obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + obj.AddBalance(big.NewInt(int64(i))) + orig.updateStateObject(obj) + } + orig.Finalise(false) + + // Copy the state + copy := orig.Copy() + + // Copy the copy state + ccopy := copy.Copy() + + // modify all in memory + for i := byte(0); i < 255; i++ { + origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + + origObj.AddBalance(big.NewInt(2 * int64(i))) + copyObj.AddBalance(big.NewInt(3 * int64(i))) + ccopyObj.AddBalance(big.NewInt(4 * int64(i))) + + orig.updateStateObject(origObj) + copy.updateStateObject(copyObj) + ccopy.updateStateObject(copyObj) + } + + // Finalise the changes on all concurrently + finalise := func(wg *sync.WaitGroup, db *StateDB) { + defer wg.Done() + db.Finalise(true) + } + + var wg sync.WaitGroup + wg.Add(3) + go finalise(&wg, orig) + go finalise(&wg, copy) + go finalise(&wg, ccopy) + wg.Wait() + + // Verify that the three states have been updated 
independently + for i := byte(0); i < 255; i++ { + origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + + if want := big.NewInt(3 * int64(i)); origObj.Balance().Cmp(want) != 0 { + t.Errorf("orig obj %d: balance mismatch: have %v, want %v", i, origObj.Balance(), want) + } + if want := big.NewInt(4 * int64(i)); copyObj.Balance().Cmp(want) != 0 { + t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, copyObj.Balance(), want) + } + if want := big.NewInt(5 * int64(i)); ccopyObj.Balance().Cmp(want) != 0 { + t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, ccopyObj.Balance(), want) + } + } +} + +func TestSnapshotRandom(t *testing.T) { + config := &quick.Config{MaxCount: 1000} + err := quick.Check((*snapshotTest).run, config) + if cerr, ok := err.(*quick.CheckError); ok { + test := cerr.In[0].(*snapshotTest) + t.Errorf("%v:\n%s", test.err, test) + } else if err != nil { + t.Error(err) + } +} + +// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes +// captured by the snapshot. Instances of this test with pseudorandom content are created +// by Generate. +// +// The test works as follows: +// +// A new state is created and all actions are applied to it. Several snapshots are taken +// in between actions. The test then reverts each snapshot. For each snapshot the actions +// leading up to it are replayed on a fresh, empty state. The behaviour of all public +// accessor methods on the reverted state must match the return value of the equivalent +// methods on the replayed state. 
+type snapshotTest struct { + addrs []common.Address // all account addresses + actions []testAction // modifications to the state + snapshots []int // actions indexes at which snapshot is taken + err error // failure details are reported through this field +} + +type testAction struct { + name string + fn func(testAction, *StateDB) + args []int64 + noAddr bool +} + +// newTestAction creates a random action that changes state. +func newTestAction(addr common.Address, r *rand.Rand) testAction { + actions := []testAction{ + { + name: "SetBalance", + fn: func(a testAction, s *StateDB) { + s.SetBalance(addr, big.NewInt(a.args[0])) + }, + args: make([]int64, 1), + }, + { + name: "AddBalance", + fn: func(a testAction, s *StateDB) { + s.AddBalance(addr, big.NewInt(a.args[0])) + }, + args: make([]int64, 1), + }, + { + name: "SetNonce", + fn: func(a testAction, s *StateDB) { + s.SetNonce(addr, uint64(a.args[0])) + }, + args: make([]int64, 1), + }, + { + name: "SetState", + fn: func(a testAction, s *StateDB) { + var key, val common.Hash + binary.BigEndian.PutUint16(key[:], uint16(a.args[0])) + binary.BigEndian.PutUint16(val[:], uint16(a.args[1])) + s.SetState(addr, key, val) + }, + args: make([]int64, 2), + }, + { + name: "SetCode", + fn: func(a testAction, s *StateDB) { + code := make([]byte, 16) + binary.BigEndian.PutUint64(code, uint64(a.args[0])) + binary.BigEndian.PutUint64(code[8:], uint64(a.args[1])) + s.SetCode(addr, code) + }, + args: make([]int64, 2), + }, + { + name: "CreateAccount", + fn: func(a testAction, s *StateDB) { + s.CreateAccount(addr) + }, + }, + { + name: "Suicide", + fn: func(a testAction, s *StateDB) { + s.Suicide(addr) + }, + }, + { + name: "AddRefund", + fn: func(a testAction, s *StateDB) { + s.AddRefund(uint64(a.args[0])) + }, + args: make([]int64, 1), + noAddr: true, + }, + { + name: "AddLog", + fn: func(a testAction, s *StateDB) { + data := make([]byte, 2) + binary.BigEndian.PutUint16(data, uint16(a.args[0])) + s.AddLog(&types.Log{Address: addr, 
Data: data}) + }, + args: make([]int64, 1), + }, + { + name: "AddPreimage", + fn: func(a testAction, s *StateDB) { + preimage := []byte{1} + hash := common.BytesToHash(preimage) + s.AddPreimage(hash, preimage) + }, + args: make([]int64, 1), + }, + { + name: "AddAddressToAccessList", + fn: func(a testAction, s *StateDB) { + s.AddAddressToAccessList(addr) + }, + }, + { + name: "AddSlotToAccessList", + fn: func(a testAction, s *StateDB) { + s.AddSlotToAccessList(addr, + common.Hash{byte(a.args[0])}) + }, + args: make([]int64, 1), + }, + { + name: "SetTransientState", + fn: func(a testAction, s *StateDB) { + var key, val common.Hash + binary.BigEndian.PutUint16(key[:], uint16(a.args[0])) + binary.BigEndian.PutUint16(val[:], uint16(a.args[1])) + s.SetTransientState(addr, key, val) + }, + args: make([]int64, 2), + }, + } + action := actions[r.Intn(len(actions))] + var nameargs []string + if !action.noAddr { + nameargs = append(nameargs, addr.Hex()) + } + for i := range action.args { + action.args[i] = rand.Int63n(100) + nameargs = append(nameargs, fmt.Sprint(action.args[i])) + } + action.name += strings.Join(nameargs, ", ") + return action +} + +// Generate returns a new snapshot test of the given size. All randomness is +// derived from r. +func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value { + // Generate random actions. + addrs := make([]common.Address, 50) + for i := range addrs { + addrs[i][0] = byte(i) + } + actions := make([]testAction, size) + for i := range actions { + addr := addrs[r.Intn(len(addrs))] + actions[i] = newTestAction(addr, r) + } + // Generate snapshot indexes. + nsnapshots := int(math.Sqrt(float64(size))) + if size > 0 && nsnapshots == 0 { + nsnapshots = 1 + } + snapshots := make([]int, nsnapshots) + snaplen := len(actions) / nsnapshots + for i := range snapshots { + // Try to place the snapshots some number of actions apart from each other. 
+ snapshots[i] = (i * snaplen) + r.Intn(snaplen) + } + return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil}) +} + +func (test *snapshotTest) String() string { + out := new(bytes.Buffer) + sindex := 0 + for i, action := range test.actions { + if len(test.snapshots) > sindex && i == test.snapshots[sindex] { + fmt.Fprintf(out, "---- snapshot %d ----\n", sindex) + sindex++ + } + fmt.Fprintf(out, "%4d: %s\n", i, action.name) + } + return out.String() +} + +func (test *snapshotTest) run() bool { + // Run all actions and create snapshots. + var ( + state, _ = New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil) + snapshotRevs = make([]int, len(test.snapshots)) + sindex = 0 + ) + for i, action := range test.actions { + if len(test.snapshots) > sindex && i == test.snapshots[sindex] { + snapshotRevs[sindex] = state.Snapshot() + sindex++ + } + action.fn(action, state) + } + // Revert all snapshots in reverse order. Each revert must yield a state + // that is equivalent to fresh state with all actions up the snapshot applied. + for sindex--; sindex >= 0; sindex-- { + checkstate, _ := New(common.Hash{}, state.Database(), nil) + for _, action := range test.actions[:test.snapshots[sindex]] { + action.fn(action, checkstate) + } + state.RevertToSnapshot(snapshotRevs[sindex]) + if err := test.checkEqual(state, checkstate); err != nil { + test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err) + return false + } + } + return true +} + +// checkEqual checks that methods of state and checkstate return the same values. +func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { + for _, addr := range test.addrs { + var err error + checkeq := func(op string, a, b interface{}) bool { + if err == nil && !reflect.DeepEqual(a, b) { + err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b) + return false + } + return true + } + // Check basic accessor methods. 
+ checkeq("Exist", state.Exist(addr), checkstate.Exist(addr)) + checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr)) + checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr)) + checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr)) + checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr)) + checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr)) + checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr)) + // Check storage. + if obj := state.getStateObject(addr); obj != nil { + state.ForEachStorage(addr, func(key, value common.Hash) bool { + return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value) + }) + checkstate.ForEachStorage(addr, func(key, value common.Hash) bool { + return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value) + }) + } + if err != nil { + return err + } + } + + if state.GetRefund() != checkstate.GetRefund() { + return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d", + state.GetRefund(), checkstate.GetRefund()) + } + if !reflect.DeepEqual(state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{})) { + return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v", + state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{})) + } + return nil +} + +// TestCopyOfCopy tests that modified objects are carried over to the copy, and the copy of the copy. 
+// See https://github.com/ethereum/go-ethereum/pull/15225#issuecomment-380191512 +func TestCopyOfCopy(t *testing.T) { + state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil) + addr := common.HexToAddress("aaaa") + state.SetBalance(addr, big.NewInt(42)) + + if got := state.Copy().GetBalance(addr).Uint64(); got != 42 { + t.Fatalf("1st copy fail, expected 42, got %v", got) + } + if got := state.Copy().Copy().GetBalance(addr).Uint64(); got != 42 { + t.Fatalf("2nd copy fail, expected 42, got %v", got) + } +} + +func TestStateDBAccessList(t *testing.T) { + // Some helpers + addr := func(a string) common.Address { + return common.HexToAddress(a) + } + slot := func(a string) common.Hash { + return common.HexToHash(a) + } + + memDb := rawdb.NewMemoryDatabase() + db := NewDatabase(memDb) + state, _ := New(common.Hash{}, db, nil) + state.accessList = newAccessList() + + verifyAddrs := func(astrings ...string) { + t.Helper() + // convert to common.Address form + var addresses []common.Address + var addressMap = make(map[common.Address]struct{}) + for _, astring := range astrings { + address := addr(astring) + addresses = append(addresses, address) + addressMap[address] = struct{}{} + } + // Check that the given addresses are in the access list + for _, address := range addresses { + if !state.AddressInAccessList(address) { + t.Fatalf("expected %x to be in access list", address) + } + } + // Check that only the expected addresses are present in the access list + for address := range state.accessList.addresses { + if _, exist := addressMap[address]; !exist { + t.Fatalf("extra address %x in access list", address) + } + } + } + verifySlots := func(addrString string, slotStrings ...string) { + if !state.AddressInAccessList(addr(addrString)) { + t.Fatalf("scope missing address/slots %v", addrString) + } + var address = addr(addrString) + // convert to common.Hash form + var slots []common.Hash + var slotMap = make(map[common.Hash]struct{}) + for _, slotString 
:= range slotStrings { + s := slot(slotString) + slots = append(slots, s) + slotMap[s] = struct{}{} + } + // Check that the expected items are in the access list + for i, s := range slots { + if _, slotPresent := state.SlotInAccessList(address, s); !slotPresent { + t.Fatalf("input %d: scope missing slot %v (address %v)", i, s, addrString) + } + } + // Check that no extra elements are in the access list + index := state.accessList.addresses[address] + if index >= 0 { + stateSlots := state.accessList.slots[index] + for s := range stateSlots { + if _, slotPresent := slotMap[s]; !slotPresent { + t.Fatalf("scope has extra slot %v (address %v)", s, addrString) + } + } + } + } + + state.AddAddressToAccessList(addr("aa")) // 1 + state.AddSlotToAccessList(addr("bb"), slot("01")) // 2,3 + state.AddSlotToAccessList(addr("bb"), slot("02")) // 4 + verifyAddrs("aa", "bb") + verifySlots("bb", "01", "02") + + // Make a copy + stateCopy1 := state.Copy() + if exp, got := 4, state.journal.length(); exp != got { + t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + } + + // same again, should cause no journal entries + state.AddSlotToAccessList(addr("bb"), slot("01")) + state.AddSlotToAccessList(addr("bb"), slot("02")) + state.AddAddressToAccessList(addr("aa")) + if exp, got := 4, state.journal.length(); exp != got { + t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + } + // some new ones + state.AddSlotToAccessList(addr("bb"), slot("03")) // 5 + state.AddSlotToAccessList(addr("aa"), slot("01")) // 6 + state.AddSlotToAccessList(addr("cc"), slot("01")) // 7,8 + state.AddAddressToAccessList(addr("cc")) + if exp, got := 8, state.journal.length(); exp != got { + t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + } + + verifyAddrs("aa", "bb", "cc") + verifySlots("aa", "01") + verifySlots("bb", "01", "02", "03") + verifySlots("cc", "01") + + // now start rolling back changes + state.journal.revert(state, 7) + if _, ok := 
state.SlotInAccessList(addr("cc"), slot("01")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb", "cc") + verifySlots("aa", "01") + verifySlots("bb", "01", "02", "03") + + state.journal.revert(state, 6) + if state.AddressInAccessList(addr("cc")) { + t.Fatalf("addr present, expected missing") + } + verifyAddrs("aa", "bb") + verifySlots("aa", "01") + verifySlots("bb", "01", "02", "03") + + state.journal.revert(state, 5) + if _, ok := state.SlotInAccessList(addr("aa"), slot("01")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb") + verifySlots("bb", "01", "02", "03") + + state.journal.revert(state, 4) + if _, ok := state.SlotInAccessList(addr("bb"), slot("03")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb") + verifySlots("bb", "01", "02") + + state.journal.revert(state, 3) + if _, ok := state.SlotInAccessList(addr("bb"), slot("02")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb") + verifySlots("bb", "01") + + state.journal.revert(state, 2) + if _, ok := state.SlotInAccessList(addr("bb"), slot("01")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb") + + state.journal.revert(state, 1) + if state.AddressInAccessList(addr("bb")) { + t.Fatalf("addr present, expected missing") + } + verifyAddrs("aa") + + state.journal.revert(state, 0) + if state.AddressInAccessList(addr("aa")) { + t.Fatalf("addr present, expected missing") + } + if got, exp := len(state.accessList.addresses), 0; got != exp { + t.Fatalf("expected empty, got %d", got) + } + if got, exp := len(state.accessList.slots), 0; got != exp { + t.Fatalf("expected empty, got %d", got) + } + // Check the copy + // Make a copy + state = stateCopy1 + verifyAddrs("aa", "bb") + verifySlots("bb", "01", "02") + if got, exp := len(state.accessList.addresses), 2; got != exp { + t.Fatalf("expected empty, got %d", got) + } + if got, exp := len(state.accessList.slots), 1; got 
!= exp { + t.Fatalf("expected empty, got %d", got) + } +} + +func TestStateDBTransientStorage(t *testing.T) { + memDb := rawdb.NewMemoryDatabase() + db := NewDatabase(memDb) + state, _ := New(common.Hash{}, db, nil) + + key := common.Hash{0x01} + value := common.Hash{0x02} + addr := common.Address{} + + state.SetTransientState(addr, key, value) + if exp, got := 1, state.journal.length(); exp != got { + t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + } + // the retrieved value should equal what was set + if got := state.GetTransientState(addr, key); got != value { + t.Fatalf("transient storage mismatch: have %x, want %x", got, value) + } + + // revert the transient state being set and then check that the + // value is now the empty hash + state.journal.revert(state, 0) + if got, exp := state.GetTransientState(addr, key), (common.Hash{}); exp != got { + t.Fatalf("transient storage mismatch: have %x, want %x", got, exp) + } + + // set transient state and then copy the statedb and ensure that + // the transient state is copied + state.SetTransientState(addr, key, value) + cpy := state.Copy() + if got := cpy.GetTransientState(addr, key); got != value { + t.Fatalf("transient storage mismatch: have %x, want %x", got, value) + } +} diff --git a/trie_by_cid/state/transient_storage.go b/trie_by_cid/state/transient_storage.go new file mode 100644 index 0000000..66e563e --- /dev/null +++ b/trie_by_cid/state/transient_storage.go @@ -0,0 +1,55 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package state + +import ( + "github.com/ethereum/go-ethereum/common" +) + +// transientStorage is a representation of EIP-1153 "Transient Storage". +type transientStorage map[common.Address]Storage + +// newTransientStorage creates a new instance of a transientStorage. +func newTransientStorage() transientStorage { + return make(transientStorage) +} + +// Set sets the transient-storage `value` for `key` at the given `addr`. +func (t transientStorage) Set(addr common.Address, key, value common.Hash) { + if _, ok := t[addr]; !ok { + t[addr] = make(Storage) + } + t[addr][key] = value +} + +// Get gets the transient storage for `key` at the given `addr`. +func (t transientStorage) Get(addr common.Address, key common.Hash) common.Hash { + val, ok := t[addr] + if !ok { + return common.Hash{} + } + return val[key] +} + +// Copy does a deep copy of the transientStorage +func (t transientStorage) Copy() transientStorage { + storage := make(transientStorage) + for key, value := range t { + storage[key] = value.Copy() + } + return storage +} diff --git a/trie_by_cid/state/trie_prefetcher.go b/trie_by_cid/state/trie_prefetcher.go new file mode 100644 index 0000000..5dd1b5b --- /dev/null +++ b/trie_by_cid/state/trie_prefetcher.go @@ -0,0 +1,354 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package state + +import ( + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/metrics" + log "github.com/sirupsen/logrus" +) + +var ( + // triePrefetchMetricsPrefix is the prefix under which to publish the metrics. + triePrefetchMetricsPrefix = "trie/prefetch/" +) + +// triePrefetcher is an active prefetcher, which receives accounts or storage +// items and does trie-loading of them. The goal is to get as much useful content +// into the caches as possible. +// +// Note, the prefetcher's API is not thread safe. 
+type triePrefetcher struct { + db Database // Database to fetch trie nodes through + root common.Hash // Root hash of the account trie for metrics + fetches map[string]Trie // Partially or fully fetcher tries + fetchers map[string]*subfetcher // Subfetchers for each trie + + deliveryMissMeter metrics.Meter + accountLoadMeter metrics.Meter + accountDupMeter metrics.Meter + accountSkipMeter metrics.Meter + accountWasteMeter metrics.Meter + storageLoadMeter metrics.Meter + storageDupMeter metrics.Meter + storageSkipMeter metrics.Meter + storageWasteMeter metrics.Meter +} + +func newTriePrefetcher(db Database, root common.Hash, namespace string) *triePrefetcher { + prefix := triePrefetchMetricsPrefix + namespace + p := &triePrefetcher{ + db: db, + root: root, + fetchers: make(map[string]*subfetcher), // Active prefetchers use the fetchers map + + deliveryMissMeter: metrics.GetOrRegisterMeter(prefix+"/deliverymiss", nil), + accountLoadMeter: metrics.GetOrRegisterMeter(prefix+"/account/load", nil), + accountDupMeter: metrics.GetOrRegisterMeter(prefix+"/account/dup", nil), + accountSkipMeter: metrics.GetOrRegisterMeter(prefix+"/account/skip", nil), + accountWasteMeter: metrics.GetOrRegisterMeter(prefix+"/account/waste", nil), + storageLoadMeter: metrics.GetOrRegisterMeter(prefix+"/storage/load", nil), + storageDupMeter: metrics.GetOrRegisterMeter(prefix+"/storage/dup", nil), + storageSkipMeter: metrics.GetOrRegisterMeter(prefix+"/storage/skip", nil), + storageWasteMeter: metrics.GetOrRegisterMeter(prefix+"/storage/waste", nil), + } + return p +} + +// close iterates over all the subfetchers, aborts any that were left spinning +// and reports the stats to the metrics subsystem. 
+func (p *triePrefetcher) close() { + for _, fetcher := range p.fetchers { + fetcher.abort() // safe to do multiple times + + if metrics.Enabled { + if fetcher.root == p.root { + p.accountLoadMeter.Mark(int64(len(fetcher.seen))) + p.accountDupMeter.Mark(int64(fetcher.dups)) + p.accountSkipMeter.Mark(int64(len(fetcher.tasks))) + + for _, key := range fetcher.used { + delete(fetcher.seen, string(key)) + } + p.accountWasteMeter.Mark(int64(len(fetcher.seen))) + } else { + p.storageLoadMeter.Mark(int64(len(fetcher.seen))) + p.storageDupMeter.Mark(int64(fetcher.dups)) + p.storageSkipMeter.Mark(int64(len(fetcher.tasks))) + + for _, key := range fetcher.used { + delete(fetcher.seen, string(key)) + } + p.storageWasteMeter.Mark(int64(len(fetcher.seen))) + } + } + } + // Clear out all fetchers (will crash on a second call, deliberate) + p.fetchers = nil +} + +// copy creates a deep-but-inactive copy of the trie prefetcher. Any trie data +// already loaded will be copied over, but no goroutines will be started. This +// is mostly used in the miner which creates a copy of it's actively mutated +// state to be sealed while it may further mutate the state. 
+func (p *triePrefetcher) copy() *triePrefetcher { + copy := &triePrefetcher{ + db: p.db, + root: p.root, + fetches: make(map[string]Trie), // Active prefetchers use the fetches map + + deliveryMissMeter: p.deliveryMissMeter, + accountLoadMeter: p.accountLoadMeter, + accountDupMeter: p.accountDupMeter, + accountSkipMeter: p.accountSkipMeter, + accountWasteMeter: p.accountWasteMeter, + storageLoadMeter: p.storageLoadMeter, + storageDupMeter: p.storageDupMeter, + storageSkipMeter: p.storageSkipMeter, + storageWasteMeter: p.storageWasteMeter, + } + // If the prefetcher is already a copy, duplicate the data + if p.fetches != nil { + for root, fetch := range p.fetches { + if fetch == nil { + continue + } + copy.fetches[root] = p.db.CopyTrie(fetch) + } + return copy + } + // Otherwise we're copying an active fetcher, retrieve the current states + for id, fetcher := range p.fetchers { + copy.fetches[id] = fetcher.peek() + } + return copy +} + +// prefetch schedules a batch of trie items to prefetch. +func (p *triePrefetcher) prefetch(owner common.Hash, root common.Hash, keys [][]byte) { + // If the prefetcher is an inactive one, bail out + if p.fetches != nil { + return + } + // Active fetcher, schedule the retrievals + id := p.trieID(owner, root) + fetcher := p.fetchers[id] + if fetcher == nil { + fetcher = newSubfetcher(p.db, p.root, owner, root) + p.fetchers[id] = fetcher + } + fetcher.schedule(keys) +} + +// trie returns the trie matching the root hash, or nil if the prefetcher doesn't +// have it. 
+func (p *triePrefetcher) trie(owner common.Hash, root common.Hash) Trie { + // If the prefetcher is inactive, return from existing deep copies + id := p.trieID(owner, root) + if p.fetches != nil { + trie := p.fetches[id] + if trie == nil { + p.deliveryMissMeter.Mark(1) + return nil + } + return p.db.CopyTrie(trie) + } + // Otherwise the prefetcher is active, bail if no trie was prefetched for this root + fetcher := p.fetchers[id] + if fetcher == nil { + p.deliveryMissMeter.Mark(1) + return nil + } + // Interrupt the prefetcher if it's by any chance still running and return + // a copy of any pre-loaded trie. + fetcher.abort() // safe to do multiple times + + trie := fetcher.peek() + if trie == nil { + p.deliveryMissMeter.Mark(1) + return nil + } + return trie +} + +// used marks a batch of state items used to allow creating statistics as to +// how useful or wasteful the prefetcher is. +func (p *triePrefetcher) used(owner common.Hash, root common.Hash, used [][]byte) { + if fetcher := p.fetchers[p.trieID(owner, root)]; fetcher != nil { + fetcher.used = used + } +} + +// trieID returns an unique trie identifier consists the trie owner and root hash. +func (p *triePrefetcher) trieID(owner common.Hash, root common.Hash) string { + return string(append(owner.Bytes(), root.Bytes()...)) +} + +// subfetcher is a trie fetcher goroutine responsible for pulling entries for a +// single trie. It is spawned when a new root is encountered and lives until the +// main prefetcher is paused and either all requested items are processed or if +// the trie being worked on is retrieved from the prefetcher. 
+type subfetcher struct { + db Database // Database to load trie nodes through + state common.Hash // Root hash of the state to prefetch + owner common.Hash // Owner of the trie, usually account hash + root common.Hash // Root hash of the trie to prefetch + trie Trie // Trie being populated with nodes + + tasks [][]byte // Items queued up for retrieval + lock sync.Mutex // Lock protecting the task queue + + wake chan struct{} // Wake channel if a new task is scheduled + stop chan struct{} // Channel to interrupt processing + term chan struct{} // Channel to signal interruption + copy chan chan Trie // Channel to request a copy of the current trie + + seen map[string]struct{} // Tracks the entries already loaded + dups int // Number of duplicate preload tasks + used [][]byte // Tracks the entries used in the end +} + +// newSubfetcher creates a goroutine to prefetch state items belonging to a +// particular root hash. +func newSubfetcher(db Database, state common.Hash, owner common.Hash, root common.Hash) *subfetcher { + sf := &subfetcher{ + db: db, + state: state, + owner: owner, + root: root, + wake: make(chan struct{}, 1), + stop: make(chan struct{}), + term: make(chan struct{}), + copy: make(chan chan Trie), + seen: make(map[string]struct{}), + } + go sf.loop() + return sf +} + +// schedule adds a batch of trie keys to the queue to prefetch. +func (sf *subfetcher) schedule(keys [][]byte) { + // Append the tasks to the current queue + sf.lock.Lock() + sf.tasks = append(sf.tasks, keys...) + sf.lock.Unlock() + + // Notify the prefetcher, it's fine if it's already terminated + select { + case sf.wake <- struct{}{}: + default: + } +} + +// peek tries to retrieve a deep copy of the fetcher's trie in whatever form it +// is currently. 
+func (sf *subfetcher) peek() Trie { + ch := make(chan Trie) + select { + case sf.copy <- ch: + // Subfetcher still alive, return copy from it + return <-ch + + case <-sf.term: + // Subfetcher already terminated, return a copy directly + if sf.trie == nil { + return nil + } + return sf.db.CopyTrie(sf.trie) + } +} + +// abort interrupts the subfetcher immediately. It is safe to call abort multiple +// times but it is not thread safe. +func (sf *subfetcher) abort() { + select { + case <-sf.stop: + default: + close(sf.stop) + } + <-sf.term +} + +// loop waits for new tasks to be scheduled and keeps loading them until it runs +// out of tasks or its underlying trie is retrieved for committing. +func (sf *subfetcher) loop() { + // No matter how the loop stops, signal anyone waiting that it's terminated + defer close(sf.term) + + // Start by opening the trie and stop processing if it fails + if sf.owner == (common.Hash{}) { + trie, err := sf.db.OpenTrie(sf.root) + if err != nil { + log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err) + return + } + sf.trie = trie + } else { + trie, err := sf.db.OpenStorageTrie(sf.state, sf.owner, sf.root) + if err != nil { + log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err) + return + } + sf.trie = trie + } + // Trie opened successfully, keep prefetching items + for { + select { + case <-sf.wake: + // Subfetcher was woken up, retrieve any tasks to avoid spinning the lock + sf.lock.Lock() + tasks := sf.tasks + sf.tasks = nil + sf.lock.Unlock() + + // Prefetch any tasks until the loop is interrupted + for i, task := range tasks { + select { + case <-sf.stop: + // If termination is requested, add any leftover back and return + sf.lock.Lock() + sf.tasks = append(sf.tasks, tasks[i:]...) 
+ sf.lock.Unlock() + return + + case ch := <-sf.copy: + // Somebody wants a copy of the current trie, grant them + ch <- sf.db.CopyTrie(sf.trie) + + default: + // No termination request yet, prefetch the next entry + if _, ok := sf.seen[string(task)]; ok { + sf.dups++ + } else { + sf.trie.TryGet(task) + sf.seen[string(task)] = struct{}{} + } + } + } + + case ch := <-sf.copy: + // Somebody wants a copy of the current trie, grant them + ch <- sf.db.CopyTrie(sf.trie) + + case <-sf.stop: + // Termination is requested, abort and leave remaining tasks + return + } + } +} diff --git a/trie_by_cid/state/trie_prefetcher_test.go b/trie_by_cid/state/trie_prefetcher_test.go new file mode 100644 index 0000000..cb0b67d --- /dev/null +++ b/trie_by_cid/state/trie_prefetcher_test.go @@ -0,0 +1,110 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +package state + +import ( + "math/big" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" +) + +func filledStateDB() *StateDB { + state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil) + + // Create an account and check if the retrieved balance is correct + addr := common.HexToAddress("0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe") + skey := common.HexToHash("aaa") + sval := common.HexToHash("bbb") + + state.SetBalance(addr, big.NewInt(42)) // Change the account trie + state.SetCode(addr, []byte("hello")) // Change an external metadata + state.SetState(addr, skey, sval) // Change the storage trie + for i := 0; i < 100; i++ { + sk := common.BigToHash(big.NewInt(int64(i))) + state.SetState(addr, sk, sk) // Change the storage trie + } + return state +} + +func TestCopyAndClose(t *testing.T) { + db := filledStateDB() + prefetcher := newTriePrefetcher(db.db, db.originalRoot, "") + skey := common.HexToHash("aaa") + prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) + prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) + time.Sleep(1 * time.Second) + a := prefetcher.trie(common.Hash{}, db.originalRoot) + prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) + b := prefetcher.trie(common.Hash{}, db.originalRoot) + cpy := prefetcher.copy() + cpy.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) + cpy.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) + c := cpy.trie(common.Hash{}, db.originalRoot) + prefetcher.close() + cpy2 := cpy.copy() + cpy2.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) + d := cpy2.trie(common.Hash{}, db.originalRoot) + cpy.close() + cpy2.close() + if a.Hash() != b.Hash() || a.Hash() != c.Hash() || a.Hash() != d.Hash() { + t.Fatalf("Invalid trie, hashes should be equal: %v %v %v %v", a.Hash(), b.Hash(), c.Hash(), d.Hash()) + } +} + +func 
TestUseAfterClose(t *testing.T) { + db := filledStateDB() + prefetcher := newTriePrefetcher(db.db, db.originalRoot, "") + skey := common.HexToHash("aaa") + prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) + a := prefetcher.trie(common.Hash{}, db.originalRoot) + prefetcher.close() + b := prefetcher.trie(common.Hash{}, db.originalRoot) + if a == nil { + t.Fatal("Prefetching before close should not return nil") + } + if b != nil { + t.Fatal("Trie after close should return nil") + } +} + +func TestCopyClose(t *testing.T) { + db := filledStateDB() + prefetcher := newTriePrefetcher(db.db, db.originalRoot, "") + skey := common.HexToHash("aaa") + prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) + cpy := prefetcher.copy() + a := prefetcher.trie(common.Hash{}, db.originalRoot) + b := cpy.trie(common.Hash{}, db.originalRoot) + prefetcher.close() + c := prefetcher.trie(common.Hash{}, db.originalRoot) + d := cpy.trie(common.Hash{}, db.originalRoot) + if a == nil { + t.Fatal("Prefetching before close should not return nil") + } + if b == nil { + t.Fatal("Copy trie should return nil") + } + if c != nil { + t.Fatal("Trie after close should return nil") + } + if d == nil { + t.Fatal("Copy trie should not return nil") + } +} diff --git a/trie_by_cid/trie/committer.go b/trie_by_cid/trie/committer.go new file mode 100644 index 0000000..9f97887 --- /dev/null +++ b/trie_by_cid/trie/committer.go @@ -0,0 +1,196 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "fmt" + + "github.com/ethereum/go-ethereum/common" +) + +// leaf represents a trie leaf node +type leaf struct { + blob []byte // raw blob of leaf + parent common.Hash // the hash of parent node +} + +// committer is the tool used for the trie Commit operation. The committer will +// capture all dirty nodes during the commit process and keep them cached in +// insertion order. +type committer struct { + nodes *NodeSet + collectLeaf bool +} + +// newCommitter creates a new committer or picks one from the pool. +func newCommitter(nodeset *NodeSet, collectLeaf bool) *committer { + return &committer{ + nodes: nodeset, + collectLeaf: collectLeaf, + } +} + +// Commit collapses a node down into a hash node. +func (c *committer) Commit(n node) hashNode { + return c.commit(nil, n).(hashNode) +} + +// commit collapses a node down into a hash node and returns it. +func (c *committer) commit(path []byte, n node) node { + // if this path is clean, use available cached data + hash, dirty := n.cache() + if hash != nil && !dirty { + return hash + } + // Commit children, then parent, and remove the dirty flag. + switch cn := n.(type) { + case *shortNode: + // Commit child + collapsed := cn.copy() + + // If the child is fullNode, recursively commit, + // otherwise it can only be hashNode or valueNode. + if _, ok := cn.Val.(*fullNode); ok { + collapsed.Val = c.commit(append(path, cn.Key...), cn.Val) + } + // The key needs to be copied, since we're adding it to the + // modified nodeset. 
+ collapsed.Key = hexToCompact(cn.Key) + hashedNode := c.store(path, collapsed) + if hn, ok := hashedNode.(hashNode); ok { + return hn + } + return collapsed + case *fullNode: + hashedKids := c.commitChildren(path, cn) + collapsed := cn.copy() + collapsed.Children = hashedKids + + hashedNode := c.store(path, collapsed) + if hn, ok := hashedNode.(hashNode); ok { + return hn + } + return collapsed + case hashNode: + return cn + default: + // nil, valuenode shouldn't be committed + panic(fmt.Sprintf("%T: invalid node: %v", n, n)) + } +} + +// commitChildren commits the children of the given fullnode +func (c *committer) commitChildren(path []byte, n *fullNode) [17]node { + var children [17]node + for i := 0; i < 16; i++ { + child := n.Children[i] + if child == nil { + continue + } + // If it's the hashed child, save the hash value directly. + // Note: it's impossible that the child in range [0, 15] + // is a valueNode. + if hn, ok := child.(hashNode); ok { + children[i] = hn + continue + } + // Commit the child recursively and store the "hashed" value. + // Note the returned node can be some embedded nodes, so it's + // possible the type is not hashNode. + children[i] = c.commit(append(path, byte(i)), child) + } + // For the 17th child, it's possible the type is valuenode. + if n.Children[16] != nil { + children[16] = n.Children[16] + } + return children +} + +// store hashes the node n and adds it to the modified nodeset. If leaf collection +// is enabled, leaf nodes will be tracked in the modified nodeset as well. +func (c *committer) store(path []byte, n node) node { + // Larger nodes are replaced by their hash and stored in the database. + var hash, _ = n.cache() + + // This was not generated - must be a small node stored in the parent. + // In theory, we should check if the node is leaf here (embedded node + // usually is leaf node). But small value (less than 32bytes) is not + // our target (leaves in account trie only). 
+ if hash == nil { + // The node is embedded in its parent, in other words, this node + // will not be stored in the database independently, mark it as + // deleted only if the node was existent in database before. + if _, ok := c.nodes.accessList[string(path)]; ok { + c.nodes.markDeleted(path) + } + return n + } + // We have the hash already, estimate the RLP encoding-size of the node. + // The size is used for mem tracking, does not need to be exact + var ( + size = estimateSize(n) + nhash = common.BytesToHash(hash) + mnode = &memoryNode{ + hash: nhash, + node: simplifyNode(n), + size: uint16(size), + } + ) + // Collect the dirty node to nodeset for return. + c.nodes.markUpdated(path, mnode) + + // Collect the corresponding leaf node if it's required. We don't check + // full node since it's impossible to store value in fullNode. The key + // length of leaves should be exactly same. + if c.collectLeaf { + if sn, ok := n.(*shortNode); ok { + if val, ok := sn.Val.(valueNode); ok { + c.nodes.addLeaf(&leaf{blob: val, parent: nhash}) + } + } + } + return hash +} + +// estimateSize estimates the size of an rlp-encoded node, without actually +// rlp-encoding it (zero allocs). This method has been experimentally tried, and with a trie +// with 1000 leaves, the only errors above 1% are on small shortnodes, where this +// method overestimates by 2 or 3 bytes (e.g. 37 instead of 35) +func estimateSize(n node) int { + switch n := n.(type) { + case *shortNode: + // A short node contains a compacted key, and a value. 
+ return 3 + len(n.Key) + estimateSize(n.Val) + case *fullNode: + // A full node contains up to 16 hashes (some nils), and a key + s := 3 + for i := 0; i < 16; i++ { + if child := n.Children[i]; child != nil { + s += estimateSize(child) + } else { + s++ + } + } + return s + case valueNode: + return 1 + len(n) + case hashNode: + return 1 + len(n) + default: + panic(fmt.Sprintf("node type %T", n)) + } +} diff --git a/trie_by_cid/trie/database.go b/trie_by_cid/trie/database.go index d13cb49..ac05f98 100644 --- a/trie_by_cid/trie/database.go +++ b/trie_by_cid/trie/database.go @@ -18,25 +18,50 @@ package trie import ( "errors" + "runtime" + "sync" + "time" "github.com/VictoriaMetrics/fastcache" + "github.com/cerc-io/ipld-eth-statedb/internal" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" - "github.com/ipfs/go-cid" + log "github.com/sirupsen/logrus" ) -type CidKey = cid.Cid - -func isEmpty(key CidKey) bool { - return len(key.KeyString()) == 0 -} - -// Database is an intermediate read-only layer between the trie data structures and -// the disk database. This trie Database is thread safe in providing individual, -// independent node access. +// Database is an intermediate write layer between the trie data structures and +// the disk database. The aim is to accumulate trie writes in-memory and only +// periodically flush a couple tries to disk, garbage collecting the remainder. +// +// Note, the trie Database is **not** thread safe in its mutations, but it **is** +// thread safe in providing individual, independent node access. The rationale +// behind this split design is to provide read access to RPC handlers and sync +// servers even while the trie is executing expensive garbage collection. 
type Database struct { - diskdb ethdb.KeyValueStore // Persistent storage for matured trie nodes - cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs + diskdb ethdb.Database // Persistent storage for matured trie nodes + + cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs + dirties map[common.Hash]*cachedNode // Data and references relationships of dirty trie nodes + oldest common.Hash // Oldest tracked node, flush-list head + newest common.Hash // Newest tracked node, flush-list tail + + gctime time.Duration // Time spent on garbage collection since last commit + gcnodes uint64 // Nodes garbage collected since last commit + gcsize common.StorageSize // Data storage garbage collected since last commit + + flushtime time.Duration // Time spent on data flushing since last commit + flushnodes uint64 // Nodes flushed since last commit + flushsize common.StorageSize // Data storage flushed since last commit + + dirtiesSize common.StorageSize // Storage size of the dirty node cache (exc. metadata) + childrenSize common.StorageSize // Storage size of the external children tracking + preimages *preimageStore // The store for caching preimages + + lock sync.RWMutex } // Config defines all necessary options for database. @@ -46,14 +71,14 @@ type Config = trie.Config // NewDatabase creates a new trie database to store ephemeral trie content before // its written out to disk or garbage collected. No read cache is created, so all // data retrievals will hit the underlying disk database. -func NewDatabase(diskdb ethdb.KeyValueStore) *Database { +func NewDatabase(diskdb ethdb.Database) *Database { return NewDatabaseWithConfig(diskdb, nil) } // NewDatabaseWithConfig creates a new trie database to store ephemeral trie content -// before it's written out to disk or garbage collected. It also acts as a read cache +// before its written out to disk or garbage collected. It also acts as a read cache // for nodes loaded from disk. 
-func NewDatabaseWithConfig(diskdb ethdb.KeyValueStore, config *Config) *Database { +func NewDatabaseWithConfig(diskdb ethdb.Database, config *Config) *Database { var cleans *fastcache.Cache if config != nil && config.Cache > 0 { if config.Journal == "" { @@ -62,44 +87,354 @@ func NewDatabaseWithConfig(diskdb ethdb.KeyValueStore, config *Config) *Database cleans = fastcache.LoadFromFileOrNew(config.Journal, config.Cache*1024*1024) } } + var preimage *preimageStore + if config != nil && config.Preimages { + preimage = newPreimageStore(diskdb) + } db := &Database{ diskdb: diskdb, cleans: cleans, + dirties: map[common.Hash]*cachedNode{{}: { + children: make(map[common.Hash]uint16), + }}, + preimages: preimage, } return db } -// DiskDB retrieves the persistent storage backing the trie database. -func (db *Database) DiskDB() ethdb.KeyValueStore { - return db.diskdb +// insert inserts a simplified trie node into the memory database. +// All nodes inserted by this function will be reference tracked +// and in theory should only used for **trie nodes** insertion. +func (db *Database) insert(hash common.Hash, size int, node node) { + // If the node's already cached, skip + if _, ok := db.dirties[hash]; ok { + return + } + memcacheDirtyWriteMeter.Mark(int64(size)) + + // Create the cached entry for this node + entry := &cachedNode{ + node: node, + size: uint16(size), + flushPrev: db.newest, + } + entry.forChilds(func(child common.Hash) { + if c := db.dirties[child]; c != nil { + c.parents++ + } + }) + db.dirties[hash] = entry + + // Update the flush-list endpoints + if db.oldest == (common.Hash{}) { + db.oldest, db.newest = hash, hash + } else { + db.dirties[db.newest].flushNext, db.newest = hash, hash + } + db.dirtiesSize += common.StorageSize(common.HashLength + entry.size) } -// Node retrieves an encoded trie node by CID. If it cannot be found -// cached in memory, it queries the persistent database. 
-func (db *Database) Node(key CidKey) ([]byte, error) { +// Node retrieves an encoded cached trie node from memory. If it cannot be found +// cached, the method queries the persistent database for the content. +func (db *Database) Node(hash common.Hash, codec uint64) ([]byte, error) { // It doesn't make sense to retrieve the metaroot - if isEmpty(key) { + if hash == (common.Hash{}) { return nil, errors.New("not found") } - cidbytes := key.Bytes() // Retrieve the node from the clean cache if available if db.cleans != nil { - if enc := db.cleans.Get(nil, cidbytes); enc != nil { + if enc := db.cleans.Get(nil, hash[:]); enc != nil { + memcacheCleanHitMeter.Mark(1) + memcacheCleanReadMeter.Mark(int64(len(enc))) return enc, nil } } + // Retrieve the node from the dirty cache if available + db.lock.RLock() + dirty := db.dirties[hash] + db.lock.RUnlock() + + if dirty != nil { + memcacheDirtyHitMeter.Mark(1) + memcacheDirtyReadMeter.Mark(int64(dirty.size)) + return dirty.rlp(), nil + } + memcacheDirtyMissMeter.Mark(1) // Content unavailable in memory, attempt to retrieve from disk - enc, err := db.diskdb.Get(cidbytes) + cid, err := internal.Keccak256ToCid(codec, hash[:]) + if err != nil { + return nil, err + } + enc, err := db.diskdb.Get(cid.Bytes()) if err != nil { return nil, err } - if len(enc) != 0 { if db.cleans != nil { - db.cleans.Set(cidbytes, enc) + db.cleans.Set(hash[:], enc) + memcacheCleanMissMeter.Mark(1) + memcacheCleanWriteMeter.Mark(int64(len(enc))) } return enc, nil } return nil, errors.New("not found") } + +// Nodes retrieves the hashes of all the nodes cached within the memory database. +// This method is extremely expensive and should only be used to validate internal +// states in test code. 
+func (db *Database) Nodes() []common.Hash { + db.lock.RLock() + defer db.lock.RUnlock() + + var hashes = make([]common.Hash, 0, len(db.dirties)) + for hash := range db.dirties { + if hash != (common.Hash{}) { // Special case for "root" references/nodes + hashes = append(hashes, hash) + } + } + return hashes +} + +// Reference adds a new reference from a parent node to a child node. +// This function is used to add reference between internal trie node +// and external node(e.g. storage trie root), all internal trie nodes +// are referenced together by database itself. +func (db *Database) Reference(child common.Hash, parent common.Hash) { + db.lock.Lock() + defer db.lock.Unlock() + + db.reference(child, parent) +} + +// reference is the private locked version of Reference. +func (db *Database) reference(child common.Hash, parent common.Hash) { + // If the node does not exist, it's a node pulled from disk, skip + node, ok := db.dirties[child] + if !ok { + return + } + // If the reference already exists, only duplicate for roots + if db.dirties[parent].children == nil { + db.dirties[parent].children = make(map[common.Hash]uint16) + db.childrenSize += cachedNodeChildrenSize + } else if _, ok = db.dirties[parent].children[child]; ok && parent != (common.Hash{}) { + return + } + node.parents++ + db.dirties[parent].children[child]++ + if db.dirties[parent].children[child] == 1 { + db.childrenSize += common.HashLength + 2 // uint16 counter + } +} + +// Dereference removes an existing reference from a root node. 
+func (db *Database) Dereference(root common.Hash) { + // Sanity check to ensure that the meta-root is not removed + if root == (common.Hash{}) { + log.Error("Attempted to dereference the trie cache meta root") + return + } + db.lock.Lock() + defer db.lock.Unlock() + + nodes, storage, start := len(db.dirties), db.dirtiesSize, time.Now() + db.dereference(root, common.Hash{}) + + db.gcnodes += uint64(nodes - len(db.dirties)) + db.gcsize += storage - db.dirtiesSize + db.gctime += time.Since(start) + + memcacheGCTimeTimer.Update(time.Since(start)) + memcacheGCSizeMeter.Mark(int64(storage - db.dirtiesSize)) + memcacheGCNodesMeter.Mark(int64(nodes - len(db.dirties))) + + log.Debug("Dereferenced trie from memory database", "nodes", nodes-len(db.dirties), "size", storage-db.dirtiesSize, "time", time.Since(start), + "gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.dirties), "livesize", db.dirtiesSize) +} + +// dereference is the private locked version of Dereference. +func (db *Database) dereference(child common.Hash, parent common.Hash) { + // Dereference the parent-child + node := db.dirties[parent] + + if node.children != nil && node.children[child] > 0 { + node.children[child]-- + if node.children[child] == 0 { + delete(node.children, child) + db.childrenSize -= (common.HashLength + 2) // uint16 counter + } + } + // If the child does not exist, it's a previously committed node. + node, ok := db.dirties[child] + if !ok { + return + } + // If there are no more references to the child, delete it and cascade + if node.parents > 0 { + // This is a special cornercase where a node loaded from disk (i.e. not in the + // memcache any more) gets reinjected as a new node (short node split into full, + // then reverted into short), causing a cached node to have no parents. That is + // no problem in itself, but don't make maxint parents out of it. 
+ node.parents-- + } + if node.parents == 0 { + // Remove the node from the flush-list + switch child { + case db.oldest: + db.oldest = node.flushNext + db.dirties[node.flushNext].flushPrev = common.Hash{} + case db.newest: + db.newest = node.flushPrev + db.dirties[node.flushPrev].flushNext = common.Hash{} + default: + db.dirties[node.flushPrev].flushNext = node.flushNext + db.dirties[node.flushNext].flushPrev = node.flushPrev + } + // Dereference all children and delete the node + node.forChilds(func(hash common.Hash) { + db.dereference(hash, child) + }) + delete(db.dirties, child) + db.dirtiesSize -= common.StorageSize(common.HashLength + int(node.size)) + if node.children != nil { + db.childrenSize -= cachedNodeChildrenSize + } + } +} + +// Update inserts the dirty nodes in provided nodeset into database and +// link the account trie with multiple storage tries if necessary. +func (db *Database) Update(nodes *MergedNodeSet) error { + db.lock.Lock() + defer db.lock.Unlock() + + // Insert dirty nodes into the database. In the same tree, it must be + // ensured that children are inserted first, then parent so that children + // can be linked with their parent correctly. + // + // Note, the storage tries must be flushed before the account trie to + // retain the invariant that children go into the dirty cache first. + var order []common.Hash + for owner := range nodes.sets { + if owner == (common.Hash{}) { + continue + } + order = append(order, owner) + } + if _, ok := nodes.sets[common.Hash{}]; ok { + order = append(order, common.Hash{}) + } + for _, owner := range order { + subset := nodes.sets[owner] + subset.forEachWithOrder(func(path string, n *memoryNode) { + if n.isDeleted() { + return // ignore deletion + } + db.insert(n.hash, int(n.size), n.node) + }) + } + // Link up the account trie and storage trie if the node points + // to an account trie leaf. 
+ if set, present := nodes.sets[common.Hash{}]; present { + for _, n := range set.leaves { + var account types.StateAccount + if err := rlp.DecodeBytes(n.blob, &account); err != nil { + return err + } + if account.Root != types.EmptyRootHash { + db.reference(account.Root, n.parent) + } + } + } + return nil +} + +// Size returns the current storage size of the memory cache in front of the +// persistent database layer. +func (db *Database) Size() (common.StorageSize, common.StorageSize) { + db.lock.RLock() + defer db.lock.RUnlock() + + // db.dirtiesSize only contains the useful data in the cache, but when reporting + // the total memory consumption, the maintenance metadata is also needed to be + // counted. + var metadataSize = common.StorageSize((len(db.dirties) - 1) * cachedNodeSize) + var metarootRefs = common.StorageSize(len(db.dirties[common.Hash{}].children) * (common.HashLength + 2)) + var preimageSize common.StorageSize + if db.preimages != nil { + preimageSize = db.preimages.size() + } + return db.dirtiesSize + db.childrenSize + metadataSize - metarootRefs, preimageSize +} + +// GetReader retrieves a node reader belonging to the given state root. +func (db *Database) GetReader(root common.Hash, codec uint64) Reader { + return &hashReader{db: db, codec: codec} +} + +// hashReader is reader of hashDatabase which implements the Reader interface. +type hashReader struct { + db *Database + codec uint64 +} + +// Node retrieves the trie node with the given node hash. +func (reader *hashReader) Node(owner common.Hash, path []byte, hash common.Hash) (node, error) { + blob, err := reader.NodeBlob(owner, path, hash) + if err != nil { + return nil, err + } + return decodeNodeUnsafe(hash[:], blob) +} + +// NodeBlob retrieves the RLP-encoded trie node blob with the given node hash. 
+func (reader *hashReader) NodeBlob(_ common.Hash, _ []byte, hash common.Hash) ([]byte, error) { + return reader.db.Node(hash, reader.codec) +} + +// saveCache saves clean state cache to given directory path +// using specified CPU cores. +func (db *Database) saveCache(dir string, threads int) error { + if db.cleans == nil { + return nil + } + log.Info("Writing clean trie cache to disk", "path", dir, "threads", threads) + + start := time.Now() + err := db.cleans.SaveToFileConcurrent(dir, threads) + if err != nil { + log.Error("Failed to persist clean trie cache", "error", err) + return err + } + log.Info("Persisted the clean trie cache", "path", dir, "elapsed", common.PrettyDuration(time.Since(start))) + return nil +} + +// SaveCache atomically saves fast cache data to the given dir using all +// available CPU cores. +func (db *Database) SaveCache(dir string) error { + return db.saveCache(dir, runtime.GOMAXPROCS(0)) +} + +// SaveCachePeriodically atomically saves fast cache data to the given dir with +// the specified interval. All dump operation will only use a single CPU core. +func (db *Database) SaveCachePeriodically(dir string, interval time.Duration, stopCh <-chan struct{}) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + db.saveCache(dir, 1) + case <-stopCh: + return + } + } +} + +// Scheme returns the node scheme used in the database. 
+func (db *Database) Scheme() string { + return rawdb.HashScheme +} diff --git a/trie_by_cid/trie/database_metrics.go b/trie_by_cid/trie/database_metrics.go new file mode 100644 index 0000000..55efc55 --- /dev/null +++ b/trie_by_cid/trie/database_metrics.go @@ -0,0 +1,27 @@ +package trie + +import "github.com/ethereum/go-ethereum/metrics" + +var ( + memcacheCleanHitMeter = metrics.NewRegisteredMeter("trie/memcache/clean/hit", nil) + memcacheCleanMissMeter = metrics.NewRegisteredMeter("trie/memcache/clean/miss", nil) + memcacheCleanReadMeter = metrics.NewRegisteredMeter("trie/memcache/clean/read", nil) + memcacheCleanWriteMeter = metrics.NewRegisteredMeter("trie/memcache/clean/write", nil) + + memcacheDirtyHitMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/hit", nil) + memcacheDirtyMissMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/miss", nil) + memcacheDirtyReadMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/read", nil) + memcacheDirtyWriteMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/write", nil) + + memcacheFlushTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/flush/time", nil) + memcacheFlushNodesMeter = metrics.NewRegisteredMeter("trie/memcache/flush/nodes", nil) + memcacheFlushSizeMeter = metrics.NewRegisteredMeter("trie/memcache/flush/size", nil) + + memcacheGCTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/gc/time", nil) + memcacheGCNodesMeter = metrics.NewRegisteredMeter("trie/memcache/gc/nodes", nil) + memcacheGCSizeMeter = metrics.NewRegisteredMeter("trie/memcache/gc/size", nil) + + memcacheCommitTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/commit/time", nil) + memcacheCommitNodesMeter = metrics.NewRegisteredMeter("trie/memcache/commit/nodes", nil) + memcacheCommitSizeMeter = metrics.NewRegisteredMeter("trie/memcache/commit/size", nil) +) diff --git a/trie_by_cid/trie/database_node.go b/trie_by_cid/trie/database_node.go new file mode 100644 index 0000000..58037ae --- 
/dev/null +++ b/trie_by_cid/trie/database_node.go @@ -0,0 +1,183 @@ +package trie + +import ( + "fmt" + "io" + "reflect" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/rlp" +) + +// rawNode is a simple binary blob used to differentiate between collapsed trie +// nodes and already encoded RLP binary blobs (while at the same time store them +// in the same cache fields). +type rawNode []byte + +func (n rawNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } +func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") } + +func (n rawNode) EncodeRLP(w io.Writer) error { + _, err := w.Write(n) + return err +} + +// rawFullNode represents only the useful data content of a full node, with the +// caches and flags stripped out to minimize its data storage. This type honors +// the same RLP encoding as the original parent. +type rawFullNode [17]node + +func (n rawFullNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } +func (n rawFullNode) fstring(ind string) string { panic("this should never end up in a live trie") } + +func (n rawFullNode) EncodeRLP(w io.Writer) error { + eb := rlp.NewEncoderBuffer(w) + n.encode(eb) + return eb.Flush() +} + +// rawShortNode represents only the useful data content of a short node, with the +// caches and flags stripped out to minimize its data storage. This type honors +// the same RLP encoding as the original parent. +type rawShortNode struct { + Key []byte + Val node +} + +func (n rawShortNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } +func (n rawShortNode) fstring(ind string) string { panic("this should never end up in a live trie") } + +// cachedNode is all the information we know about a single cached trie node +// in the memory database write layer. 
+type cachedNode struct { + node node // Cached collapsed trie node, or raw rlp data + size uint16 // Byte size of the useful cached data + + parents uint32 // Number of live nodes referencing this one + children map[common.Hash]uint16 // External children referenced by this node + + flushPrev common.Hash // Previous node in the flush-list + flushNext common.Hash // Next node in the flush-list +} + +// cachedNodeSize is the raw size of a cachedNode data structure without any +// node data included. It's an approximate size, but should be a lot better +// than not counting them. +var cachedNodeSize = int(reflect.TypeOf(cachedNode{}).Size()) + +// cachedNodeChildrenSize is the raw size of an initialized but empty external +// reference map. +const cachedNodeChildrenSize = 48 + +// rlp returns the raw rlp encoded blob of the cached trie node, either directly +// from the cache, or by regenerating it from the collapsed node. +func (n *cachedNode) rlp() []byte { + if node, ok := n.node.(rawNode); ok { + return node + } + return nodeToBytes(n.node) +} + +// obj returns the decoded and expanded trie node, either directly from the cache, +// or by regenerating it from the rlp encoded blob. +func (n *cachedNode) obj(hash common.Hash) node { + if node, ok := n.node.(rawNode); ok { + // The raw-blob format nodes are loaded either from the + // clean cache or the database, they are all in their own + // copy and safe to use unsafe decoder. + return mustDecodeNodeUnsafe(hash[:], node) + } + return expandNode(hash[:], n.node) +} + +// forChilds invokes the callback for all the tracked children of this node, +// both the implicit ones from inside the node as well as the explicit ones +// from outside the node. 
+func (n *cachedNode) forChilds(onChild func(hash common.Hash)) { + for child := range n.children { + onChild(child) + } + if _, ok := n.node.(rawNode); !ok { + forGatherChildren(n.node, onChild) + } +} + +// forGatherChildren traverses the node hierarchy of a collapsed storage node and +// invokes the callback for all the hashnode children. +func forGatherChildren(n node, onChild func(hash common.Hash)) { + switch n := n.(type) { + case *rawShortNode: + forGatherChildren(n.Val, onChild) + case rawFullNode: + for i := 0; i < 16; i++ { + forGatherChildren(n[i], onChild) + } + case hashNode: + onChild(common.BytesToHash(n)) + case valueNode, nil, rawNode: + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + +// simplifyNode traverses the hierarchy of an expanded memory node and discards +// all the internal caches, returning a node that only contains the raw data. +func simplifyNode(n node) node { + switch n := n.(type) { + case *shortNode: + // Short nodes discard the flags and cascade + return &rawShortNode{Key: n.Key, Val: simplifyNode(n.Val)} + + case *fullNode: + // Full nodes discard the flags and cascade + node := rawFullNode(n.Children) + for i := 0; i < len(node); i++ { + if node[i] != nil { + node[i] = simplifyNode(node[i]) + } + } + return node + + case valueNode, hashNode, rawNode: + return n + + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + +// expandNode traverses the node hierarchy of a collapsed storage node and converts +// all fields and keys into expanded memory form. 
+func expandNode(hash hashNode, n node) node { + switch n := n.(type) { + case *rawShortNode: + // Short nodes need key and child expansion + return &shortNode{ + Key: compactToHex(n.Key), + Val: expandNode(nil, n.Val), + flags: nodeFlag{ + hash: hash, + }, + } + + case rawFullNode: + // Full nodes need child expansion + node := &fullNode{ + flags: nodeFlag{ + hash: hash, + }, + } + for i := 0; i < len(node.Children); i++ { + if n[i] != nil { + node.Children[i] = expandNode(nil, n[i]) + } + } + return node + + case valueNode, hashNode: + return n + + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} diff --git a/trie_by_cid/trie/database_test.go b/trie_by_cid/trie/database_test.go index a800135..5a6e36d 100644 --- a/trie_by_cid/trie/database_test.go +++ b/trie_by_cid/trie/database_test.go @@ -14,21 +14,20 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -package trie_test +package trie import ( "testing" - "github.com/ethereum/go-ethereum/ethdb/memorydb" - - "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" ) // Tests that the trie database returns a missing trie node error if attempting // to retrieve the meta root. func TestDatabaseMetarootFetch(t *testing.T) { - db := trie.NewDatabase(memorydb.New()) - if _, err := db.Node(trie.CidKey{}); err == nil { + db := NewDatabase(rawdb.NewMemoryDatabase()) + if _, err := db.Node(common.Hash{}, StateTrieCodec); err == nil { t.Fatalf("metaroot retrieval succeeded") } } diff --git a/trie_by_cid/trie/encoding.go b/trie_by_cid/trie/encoding.go index 381883a..ace4570 100644 --- a/trie_by_cid/trie/encoding.go +++ b/trie_by_cid/trie/encoding.go @@ -16,8 +16,81 @@ package trie +// Trie keys are dealt with in three distinct encodings: +// +// KEYBYTES encoding contains the actual key and nothing else. 
// hexToCompactInPlace overwrites the front of hex with the compact encoding of
// the hex key it held and returns how many bytes that encoding occupies.
// The input must hold at least one byte: the flag byte is always written to
// hex[0].
func hexToCompactInPlace(hex []byte) int {
	nibbles := len(hex)
	var flags byte
	// A trailing 16 is the terminator: record it in the flag byte and drop it
	// from the nibble count.
	if nibbles > 0 && hex[nibbles-1] == 16 {
		flags = 1 << 5
		nibbles--
	}
	src := 0
	if nibbles%2 == 1 {
		// Odd count: set the odd flag and fold the first nibble into the
		// low half of the flag byte.
		flags |= 1<<4 | hex[0]
		src = 1
	}
	// Pack the remaining nibbles pairwise right after the flag byte. Each
	// write lands on a position whose nibble has already been read, so the
	// packing is safe to do in place.
	for dst := 1; src < nibbles; dst, src = dst+1, src+2 {
		hex[dst] = hex[src]<<4 | hex[src+1]
	}
	hex[0] = flags
	return nibbles/2 + 1
}
+ +package trie + +import ( + "bytes" + crand "crypto/rand" + "encoding/hex" + "math/rand" + "testing" +) + +func TestHexCompact(t *testing.T) { + tests := []struct{ hex, compact []byte }{ + // empty keys, with and without terminator. + {hex: []byte{}, compact: []byte{0x00}}, + {hex: []byte{16}, compact: []byte{0x20}}, + // odd length, no terminator + {hex: []byte{1, 2, 3, 4, 5}, compact: []byte{0x11, 0x23, 0x45}}, + // even length, no terminator + {hex: []byte{0, 1, 2, 3, 4, 5}, compact: []byte{0x00, 0x01, 0x23, 0x45}}, + // odd length, terminator + {hex: []byte{15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x3f, 0x1c, 0xb8}}, + // even length, terminator + {hex: []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x20, 0x0f, 0x1c, 0xb8}}, + } + for _, test := range tests { + if c := hexToCompact(test.hex); !bytes.Equal(c, test.compact) { + t.Errorf("hexToCompact(%x) -> %x, want %x", test.hex, c, test.compact) + } + if h := compactToHex(test.compact); !bytes.Equal(h, test.hex) { + t.Errorf("compactToHex(%x) -> %x, want %x", test.compact, h, test.hex) + } + } +} + +func TestHexKeybytes(t *testing.T) { + tests := []struct{ key, hexIn, hexOut []byte }{ + {key: []byte{}, hexIn: []byte{16}, hexOut: []byte{16}}, + {key: []byte{}, hexIn: []byte{}, hexOut: []byte{16}}, + { + key: []byte{0x12, 0x34, 0x56}, + hexIn: []byte{1, 2, 3, 4, 5, 6, 16}, + hexOut: []byte{1, 2, 3, 4, 5, 6, 16}, + }, + { + key: []byte{0x12, 0x34, 0x5}, + hexIn: []byte{1, 2, 3, 4, 0, 5, 16}, + hexOut: []byte{1, 2, 3, 4, 0, 5, 16}, + }, + { + key: []byte{0x12, 0x34, 0x56}, + hexIn: []byte{1, 2, 3, 4, 5, 6}, + hexOut: []byte{1, 2, 3, 4, 5, 6, 16}, + }, + } + for _, test := range tests { + if h := keybytesToHex(test.key); !bytes.Equal(h, test.hexOut) { + t.Errorf("keybytesToHex(%x) -> %x, want %x", test.key, h, test.hexOut) + } + if k := hexToKeyBytes(test.hexIn); !bytes.Equal(k, test.key) { + t.Errorf("hexToKeyBytes(%x) -> %x, want %x", test.hexIn, k, test.key) + } + } +} + +func 
TestHexToCompactInPlace(t *testing.T) { + for i, key := range []string{ + "00", + "060a040c0f000a090b040803010801010900080d090a0a0d0903000b10", + "10", + } { + hexBytes, _ := hex.DecodeString(key) + exp := hexToCompact(hexBytes) + sz := hexToCompactInPlace(hexBytes) + got := hexBytes[:sz] + if !bytes.Equal(exp, got) { + t.Fatalf("test %d: encoding err\ninp %v\ngot %x\nexp %x\n", i, key, got, exp) + } + } +} + +func TestHexToCompactInPlaceRandom(t *testing.T) { + for i := 0; i < 10000; i++ { + l := rand.Intn(128) + key := make([]byte, l) + crand.Read(key) + hexBytes := keybytesToHex(key) + hexOrig := []byte(string(hexBytes)) + exp := hexToCompact(hexBytes) + sz := hexToCompactInPlace(hexBytes) + got := hexBytes[:sz] + + if !bytes.Equal(exp, got) { + t.Fatalf("encoding err \ncpt %x\nhex %x\ngot %x\nexp %x\n", + key, hexOrig, got, exp) + } + } +} + +func BenchmarkHexToCompact(b *testing.B) { + testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/} + for i := 0; i < b.N; i++ { + hexToCompact(testBytes) + } +} + +func BenchmarkCompactToHex(b *testing.B) { + testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/} + for i := 0; i < b.N; i++ { + compactToHex(testBytes) + } +} + +func BenchmarkKeybytesToHex(b *testing.B) { + testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} + for i := 0; i < b.N; i++ { + keybytesToHex(testBytes) + } +} + +func BenchmarkHexToKeybytes(b *testing.B) { + testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} + for i := 0; i < b.N; i++ { + hexToKeyBytes(testBytes) + } +} diff --git a/trie_by_cid/trie/errors.go b/trie_by_cid/trie/errors.go index 3881487..afe344b 100644 --- a/trie_by_cid/trie/errors.go +++ b/trie_by_cid/trie/errors.go @@ -27,7 +27,7 @@ import ( // information necessary for retrieving the missing node. 
type MissingNodeError struct { Owner common.Hash // owner of the trie if it's 2-layered trie - NodeHash []byte // hash of the missing node + NodeHash common.Hash // hash of the missing node Path []byte // hex-encoded path to the missing node err error // concrete error for missing trie node } diff --git a/trie_by_cid/trie/hasher.go b/trie_by_cid/trie/hasher.go index caa80f0..e594d6d 100644 --- a/trie_by_cid/trie/hasher.go +++ b/trie_by_cid/trie/hasher.go @@ -21,7 +21,6 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" - "github.com/ethereum/go-ethereum/trie" "golang.org/x/crypto/sha3" ) @@ -99,7 +98,7 @@ func (h *hasher) hashShortNodeChildren(n *shortNode) (collapsed, cached *shortNo // Previously, we did copy this one. We don't seem to need to actually // do that, since we don't overwrite/reuse keys //cached.Key = common.CopyBytes(n.Key) - collapsed.Key = trie.HexToCompact(n.Key) + collapsed.Key = hexToCompact(n.Key) // Unless the child is a valuenode or hashnode, hash it switch n.Val.(type) { case *fullNode, *shortNode: @@ -171,8 +170,8 @@ func (h *hasher) fullnodeToHash(n *fullNode, force bool) node { // // All node encoding must be done like this: // -// node.encode(h.encbuf) -// enc := h.encodedBytes() +// node.encode(h.encbuf) +// enc := h.encodedBytes() // // This convention exists because node.encode can only be inlined/escape-analyzed when // called on a concrete receiver type. 
diff --git a/trie_by_cid/trie/iterator.go b/trie_by_cid/trie/iterator.go index 55e7a96..de506eb 100644 --- a/trie_by_cid/trie/iterator.go +++ b/trie_by_cid/trie/iterator.go @@ -18,14 +18,26 @@ package trie import ( "bytes" + "container/heap" "errors" + "time" + + "github.com/ethereum/go-ethereum/statediff/indexer/database/metrics" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/core/types" + gethtrie "github.com/ethereum/go-ethereum/trie" ) -// NodeIterator is a re-export of the go-ethereum interface -type NodeIterator = trie.NodeIterator +// NodeIterator is an iterator to traverse the trie pre-order. +type NodeIterator = gethtrie.NodeIterator + +// NodeResolver is used for looking up trie nodes before reaching into the real +// persistent layer. This is not mandatory, rather is an optimization for cases +// where trie nodes can be recovered from some external mechanism without reading +// from disk. In those cases, this resolver allows short circuiting accesses and +// returning them from memory. +type NodeResolver = gethtrie.NodeResolver // Iterator is a key-value trie iterator that traverses a Trie. type Iterator struct { @@ -82,7 +94,7 @@ type nodeIterator struct { path []byte // Path to the current node err error // Failure set in case of an internal error in the iterator - resolver trie.NodeResolver // Optional intermediate resolver above the disk layer + resolver NodeResolver // optional node resolver for avoiding disk hits } // errIteratorEnd is stored in nodeIterator.err when iteration is done. 
@@ -99,7 +111,7 @@ func (e seekError) Error() string { } func newNodeIterator(trie *Trie, start []byte) NodeIterator { - if trie.Hash() == emptyRoot { + if trie.Hash() == types.EmptyRootHash { return &nodeIterator{ trie: trie, err: errIteratorEnd, @@ -110,7 +122,7 @@ func newNodeIterator(trie *Trie, start []byte) NodeIterator { return it } -func (it *nodeIterator) AddResolver(resolver trie.NodeResolver) { +func (it *nodeIterator) AddResolver(resolver NodeResolver) { it.resolver = resolver } @@ -128,6 +140,14 @@ func (it *nodeIterator) Parent() common.Hash { return it.stack[len(it.stack)-1].parent } +func (it *nodeIterator) ParentPath() []byte { + if len(it.stack) == 0 { + return []byte{} + } + pathlen := it.stack[len(it.stack)-1].pathlen + return it.path[:pathlen] +} + func (it *nodeIterator) Leaf() bool { return hasTerm(it.path) } @@ -241,7 +261,7 @@ func (it *nodeIterator) seek(prefix []byte) error { func (it *nodeIterator) init() (*nodeIteratorState, error) { root := it.trie.Hash() state := &nodeIteratorState{node: it.trie.root, index: -1} - if root != emptyRoot { + if root != types.EmptyRootHash { state.hash = root } return state, state.resolve(it, nil) @@ -320,7 +340,12 @@ func (it *nodeIterator) resolveHash(hash hashNode, path []byte) (node, error) { } } } - return it.trie.resolveHash(hash, path) + // Retrieve the specified node from the underlying node reader. + // it.trie.resolveAndTrack is not used since in that function the + // loaded blob will be tracked, while it's not required here since + // all loaded nodes won't be linked to trie at all and track nodes + // may lead to out-of-memory issue. 
+ return it.trie.reader.node(path, common.BytesToHash(hash)) } func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) { @@ -329,7 +354,12 @@ func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) return blob, nil } } - return it.trie.resolveBlob(hash, path) + // Retrieve the specified node from the underlying node reader. + // it.trie.resolveAndTrack is not used since in that function the + // loaded blob will be tracked, while it's not required here since + // all loaded nodes won't be linked to trie at all and track nodes + // may lead to out-of-memory issue. + return it.trie.reader.nodeBlob(path, common.BytesToHash(hash)) } func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error { @@ -455,3 +485,248 @@ func (it *nodeIterator) pop() { it.stack[len(it.stack)-1] = nil it.stack = it.stack[:len(it.stack)-1] } + +func compareNodes(a, b NodeIterator) int { + if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 { + return cmp + } + if a.Leaf() && !b.Leaf() { + return -1 + } else if b.Leaf() && !a.Leaf() { + return 1 + } + if cmp := bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()); cmp != 0 { + return cmp + } + if a.Leaf() && b.Leaf() { + return bytes.Compare(a.LeafBlob(), b.LeafBlob()) + } + return 0 +} + +type differenceIterator struct { + a, b NodeIterator // Nodes returned are those in b - a. + eof bool // Indicates a has run out of elements + count int // Number of nodes scanned on either trie +} + +// NewDifferenceIterator constructs a NodeIterator that iterates over elements in b that +// are not in a. Returns the iterator, and a pointer to an integer recording the number +// of nodes seen. 
+func NewDifferenceIterator(a, b NodeIterator) (NodeIterator, *int) { + a.Next(true) + it := &differenceIterator{ + a: a, + b: b, + } + return it, &it.count +} + +func (it *differenceIterator) Hash() common.Hash { + return it.b.Hash() +} + +func (it *differenceIterator) Parent() common.Hash { + return it.b.Parent() +} + +func (it *differenceIterator) ParentPath() []byte { + return it.b.ParentPath() +} + +func (it *differenceIterator) Leaf() bool { + return it.b.Leaf() +} + +func (it *differenceIterator) LeafKey() []byte { + return it.b.LeafKey() +} + +func (it *differenceIterator) LeafBlob() []byte { + return it.b.LeafBlob() +} + +func (it *differenceIterator) LeafProof() [][]byte { + return it.b.LeafProof() +} + +func (it *differenceIterator) Path() []byte { + return it.b.Path() +} + +func (it *differenceIterator) NodeBlob() []byte { + return it.b.NodeBlob() +} + +func (it *differenceIterator) AddResolver(resolver NodeResolver) { + panic("not implemented") +} + +func (it *differenceIterator) Next(bool) bool { + defer metrics.UpdateDuration(time.Now(), metrics.IndexerMetrics.DifferenceIteratorNextTimer) + // Invariants: + // - We always advance at least one element in b. + // - At the start of this function, a's path is lexically greater than b's. 
+ if !it.b.Next(true) { + return false + } + it.count++ + + if it.eof { + // a has reached eof, so we just return all elements from b + return true + } + + for { + switch compareNodes(it.a, it.b) { + case -1: + // b jumped past a; advance a + if !it.a.Next(true) { + it.eof = true + return true + } + it.count++ + case 1: + // b is before a + return true + case 0: + // a and b are identical; skip this whole subtree if the nodes have hashes + hasHash := it.a.Hash() == common.Hash{} + if !it.b.Next(hasHash) { + return false + } + it.count++ + if !it.a.Next(hasHash) { + it.eof = true + return true + } + it.count++ + } + } +} + +func (it *differenceIterator) Error() error { + if err := it.a.Error(); err != nil { + return err + } + return it.b.Error() +} + +type nodeIteratorHeap []NodeIterator + +func (h nodeIteratorHeap) Len() int { return len(h) } +func (h nodeIteratorHeap) Less(i, j int) bool { return compareNodes(h[i], h[j]) < 0 } +func (h nodeIteratorHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h *nodeIteratorHeap) Push(x interface{}) { *h = append(*h, x.(NodeIterator)) } +func (h *nodeIteratorHeap) Pop() interface{} { + n := len(*h) + x := (*h)[n-1] + *h = (*h)[0 : n-1] + return x +} + +type unionIterator struct { + items *nodeIteratorHeap // Nodes returned are the union of the ones in these iterators + count int // Number of nodes scanned across all tries +} + +// NewUnionIterator constructs a NodeIterator that iterates over elements in the union +// of the provided NodeIterators. Returns the iterator, and a pointer to an integer +// recording the number of nodes visited. 
+func NewUnionIterator(iters []NodeIterator) (NodeIterator, *int) { + h := make(nodeIteratorHeap, len(iters)) + copy(h, iters) + heap.Init(&h) + + ui := &unionIterator{items: &h} + return ui, &ui.count +} + +func (it *unionIterator) Hash() common.Hash { + return (*it.items)[0].Hash() +} + +func (it *unionIterator) Parent() common.Hash { + return (*it.items)[0].Parent() +} + +func (it *unionIterator) ParentPath() []byte { + return (*it.items)[0].ParentPath() +} + +func (it *unionIterator) Leaf() bool { + return (*it.items)[0].Leaf() +} + +func (it *unionIterator) LeafKey() []byte { + return (*it.items)[0].LeafKey() +} + +func (it *unionIterator) LeafBlob() []byte { + return (*it.items)[0].LeafBlob() +} + +func (it *unionIterator) LeafProof() [][]byte { + return (*it.items)[0].LeafProof() +} + +func (it *unionIterator) Path() []byte { + return (*it.items)[0].Path() +} + +func (it *unionIterator) NodeBlob() []byte { + return (*it.items)[0].NodeBlob() +} + +func (it *unionIterator) AddResolver(resolver NodeResolver) { + panic("not implemented") +} + +// Next returns the next node in the union of tries being iterated over. +// +// It does this by maintaining a heap of iterators, sorted by the iteration +// order of their next elements, with one entry for each source trie. Each +// time Next() is called, it takes the least element from the heap to return, +// advancing any other iterators that also point to that same element. These +// iterators are called with descend=false, since we know that any nodes under +// these nodes will also be duplicates, found in the currently selected iterator. +// Whenever an iterator is advanced, it is pushed back into the heap if it still +// has elements remaining. +// +// In the case that descend=false - eg, we're asked to ignore all subnodes of the +// current node - we also advance any iterators in the heap that have the current +// path as a prefix. 
+func (it *unionIterator) Next(descend bool) bool { + if len(*it.items) == 0 { + return false + } + + // Get the next key from the union + least := heap.Pop(it.items).(NodeIterator) + + // Skip over other nodes as long as they're identical, or, if we're not descending, as + // long as they have the same prefix as the current node. + for len(*it.items) > 0 && ((!descend && bytes.HasPrefix((*it.items)[0].Path(), least.Path())) || compareNodes(least, (*it.items)[0]) == 0) { + skipped := heap.Pop(it.items).(NodeIterator) + // Skip the whole subtree if the nodes have hashes; otherwise just skip this node + if skipped.Next(skipped.Hash() == common.Hash{}) { + it.count++ + // If there are more elements, push the iterator back on the heap + heap.Push(it.items, skipped) + } + } + if least.Next(descend) { + it.count++ + heap.Push(it.items, least) + } + return len(*it.items) > 0 +} + +func (it *unionIterator) Error() error { + for i := 0; i < len(*it.items); i++ { + if err := (*it.items)[i].Error(); err != nil { + return err + } + } + return nil +} diff --git a/trie_by_cid/trie/iterator_test.go b/trie_by_cid/trie/iterator_test.go index f7f0103..8b606ef 100644 --- a/trie_by_cid/trie/iterator_test.go +++ b/trie_by_cid/trie/iterator_test.go @@ -14,28 +14,24 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . 
-package trie_test +package trie import ( "bytes" - "context" + "encoding/binary" "fmt" + "math/rand" "testing" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb/memorydb" geth_trie "github.com/ethereum/go-ethereum/trie" - - "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" ) var ( - dbConfig, _ = postgres.DefaultConfig.WithEnv() - trieConfig = trie.Config{Cache: 256} - ctx = context.Background() - - testdata0 = []kvs{ + packableTestData = []kvsi{ {"one", 1}, {"two", 2}, {"three", 3}, @@ -43,20 +39,10 @@ var ( {"five", 5}, {"ten", 10}, } - testdata1 = []kvs{ - {"barb", 0}, - {"bard", 1}, - {"bars", 2}, - {"bar", 3}, - {"fab", 4}, - {"food", 5}, - {"foos", 6}, - {"foo", 7}, - } ) func TestEmptyIterator(t *testing.T) { - trie := trie.NewEmpty(trie.NewDatabase(rawdb.NewMemoryDatabase())) + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) iter := trie.NodeIterator(nil) seen := make(map[string]struct{}) @@ -69,63 +55,163 @@ func TestEmptyIterator(t *testing.T) { } func TestIterator(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - origtrie := geth_trie.NewEmpty(db) - all, err := updateTrie(origtrie, testdata0) - if err != nil { - t.Fatal(err) + db := NewDatabase(rawdb.NewMemoryDatabase()) + trie := NewEmpty(db) + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"dog", "puppy"}, + {"somethingveryoddindeedthis is", "myothernodedata"}, } - // commit and index data - root := commitTrie(t, db, origtrie) - tr := indexTrie(t, edb, root) + all := make(map[string]string) + for _, val := range vals { + all[val.k] = val.v + trie.Update([]byte(val.k), []byte(val.v)) + } + root, nodes := trie.Commit(false) + db.Update(NewWithNodeSet(nodes)) - 
found := make(map[string]int64) - it := trie.NewIterator(tr.NodeIterator(nil)) + trie, _ = New(TrieID(root), db, StateTrieCodec) + found := make(map[string]string) + it := NewIterator(trie.NodeIterator(nil)) for it.Next() { - found[string(it.Key)] = unpackValue(it.Value) + found[string(it.Key)] = string(it.Value) } - if len(found) != len(all) { - t.Errorf("number of iterated values do not match: want %d, found %d", len(all), len(found)) - } - for k, kv := range all { - if found[k] != kv.v { - t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], kv.v) + for k, v := range all { + if found[k] != v { + t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], v) } } } -func TestIteratorSeek(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(geth_trie.NewDatabase(rawdb.NewMemoryDatabase())) - if _, err := updateTrie(orig, testdata1); err != nil { - t.Fatal(err) +type kv struct { + k, v []byte + t bool +} + +func TestIteratorLargeData(t *testing.T) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + vals := make(map[string]*kv) + + for i := byte(0); i < 255; i++ { + value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false} + trie.Update(value.k, value.v) + trie.Update(value2.k, value2.v) + vals[string(value.k)] = value + vals[string(value2.k)] = value2 + } + + it := NewIterator(trie.NodeIterator(nil)) + for it.Next() { + vals[string(it.Key)].t = true + } + + var untouched []*kv + for _, value := range vals { + if !value.t { + untouched = append(untouched, value) + } + } + + if len(untouched) > 0 { + t.Errorf("Missed %d nodes", len(untouched)) + for _, value := range untouched { + t.Error(value) + } + } +} + +// Tests that the node iterator indeed walks over the entire database contents. 
+func TestNodeIteratorCoverage(t *testing.T) { + db, trie, _ := makeTestTrie(t) + // Create some arbitrary test trie to iterate + + // Gather all the node hashes found by the iterator + hashes := make(map[common.Hash]struct{}) + for it := trie.NodeIterator(nil); it.Next(true); { + if it.Hash() != (common.Hash{}) { + hashes[it.Hash()] = struct{}{} + } + } + // Cross check the hashes and the database itself + for hash := range hashes { + if _, err := db.Node(hash, StateTrieCodec); err != nil { + t.Errorf("failed to retrieve reported node %x: %v", hash, err) + } + } + for hash, obj := range db.dirties { + if obj != nil && hash != (common.Hash{}) { + if _, ok := hashes[hash]; !ok { + t.Errorf("state entry not reported %x", hash) + } + } + } + it := db.diskdb.NewIterator(nil, nil) + for it.Next() { + key := it.Key() + if _, ok := hashes[common.BytesToHash(key)]; !ok { + t.Errorf("state entry not reported %x", key) + } + } + it.Release() +} + +type kvs struct{ k, v string } + +var testdata1 = []kvs{ + {"barb", "ba"}, + {"bard", "bc"}, + {"bars", "bb"}, + {"bar", "b"}, + {"fab", "z"}, + {"food", "ab"}, + {"foos", "aa"}, + {"foo", "a"}, +} + +var testdata2 = []kvs{ + {"aardvark", "c"}, + {"bar", "b"}, + {"barb", "bd"}, + {"bars", "be"}, + {"fab", "z"}, + {"foo", "a"}, + {"foos", "aa"}, + {"food", "ab"}, + {"jars", "d"}, +} + +func TestIteratorSeek(t *testing.T) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + for _, val := range testdata1 { + trie.Update([]byte(val.k), []byte(val.v)) } - root := commitTrie(t, db, orig) - tr := indexTrie(t, edb, root) // Seek to the middle. - it := trie.NewIterator(tr.NodeIterator([]byte("fab"))) + it := NewIterator(trie.NodeIterator([]byte("fab"))) if err := checkIteratorOrder(testdata1[4:], it); err != nil { t.Fatal(err) } // Seek to a non-existent key. 
- it = trie.NewIterator(tr.NodeIterator([]byte("barc"))) + it = NewIterator(trie.NodeIterator([]byte("barc"))) if err := checkIteratorOrder(testdata1[1:], it); err != nil { t.Fatal(err) } // Seek beyond the end. - it = trie.NewIterator(tr.NodeIterator([]byte("z"))) + it = NewIterator(trie.NodeIterator([]byte("z"))) if err := checkIteratorOrder(nil, it); err != nil { t.Fatal(err) } } -func checkIteratorOrder(want []kvs, it *trie.Iterator) error { +func checkIteratorOrder(want []kvs, it *Iterator) error { for it.Next() { if len(want) == 0 { return fmt.Errorf("didn't expect any more values, got key %q", it.Key) @@ -141,11 +227,366 @@ func checkIteratorOrder(want []kvs, it *trie.Iterator) error { return nil } +func TestDifferenceIterator(t *testing.T) { + dba := NewDatabase(rawdb.NewMemoryDatabase()) + triea := NewEmpty(dba) + for _, val := range testdata1 { + triea.Update([]byte(val.k), []byte(val.v)) + } + rootA, nodesA := triea.Commit(false) + dba.Update(NewWithNodeSet(nodesA)) + triea, _ = New(TrieID(rootA), dba, StateTrieCodec) + + dbb := NewDatabase(rawdb.NewMemoryDatabase()) + trieb := NewEmpty(dbb) + for _, val := range testdata2 { + trieb.Update([]byte(val.k), []byte(val.v)) + } + rootB, nodesB := trieb.Commit(false) + dbb.Update(NewWithNodeSet(nodesB)) + trieb, _ = New(TrieID(rootB), dbb, StateTrieCodec) + + found := make(map[string]string) + di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) + it := NewIterator(di) + for it.Next() { + found[string(it.Key)] = string(it.Value) + } + + all := []struct{ k, v string }{ + {"aardvark", "c"}, + {"barb", "bd"}, + {"bars", "be"}, + {"jars", "d"}, + } + for _, item := range all { + if found[item.k] != item.v { + t.Errorf("iterator value mismatch for %s: got %v want %v", item.k, found[item.k], item.v) + } + } + if len(found) != len(all) { + t.Errorf("iterator count mismatch: got %d values, want %d", len(found), len(all)) + } +} + +func TestUnionIterator(t *testing.T) { + dba := 
NewDatabase(rawdb.NewMemoryDatabase()) + triea := NewEmpty(dba) + for _, val := range testdata1 { + triea.Update([]byte(val.k), []byte(val.v)) + } + rootA, nodesA := triea.Commit(false) + dba.Update(NewWithNodeSet(nodesA)) + triea, _ = New(TrieID(rootA), dba, StateTrieCodec) + + dbb := NewDatabase(rawdb.NewMemoryDatabase()) + trieb := NewEmpty(dbb) + for _, val := range testdata2 { + trieb.Update([]byte(val.k), []byte(val.v)) + } + rootB, nodesB := trieb.Commit(false) + dbb.Update(NewWithNodeSet(nodesB)) + trieb, _ = New(TrieID(rootB), dbb, StateTrieCodec) + + di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)}) + it := NewIterator(di) + + all := []struct{ k, v string }{ + {"aardvark", "c"}, + {"barb", "ba"}, + {"barb", "bd"}, + {"bard", "bc"}, + {"bars", "bb"}, + {"bars", "be"}, + {"bar", "b"}, + {"fab", "z"}, + {"food", "ab"}, + {"foos", "aa"}, + {"foo", "a"}, + {"jars", "d"}, + } + + for i, kv := range all { + if !it.Next() { + t.Errorf("Iterator ends prematurely at element %d", i) + } + if kv.k != string(it.Key) { + t.Errorf("iterator value mismatch for element %d: got key %s want %s", i, it.Key, kv.k) + } + if kv.v != string(it.Value) { + t.Errorf("iterator value mismatch for element %d: got value %s want %s", i, it.Value, kv.v) + } + } + if it.Next() { + t.Errorf("Iterator returned extra values.") + } +} + +func TestIteratorNoDups(t *testing.T) { + tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + for _, val := range testdata1 { + tr.Update([]byte(val.k), []byte(val.v)) + } + checkIteratorNoDups(t, tr.NodeIterator(nil), nil) +} + +// This test checks that nodeIterator.Next can be retried after inserting missing trie nodes. 
+func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) } + +func testIteratorContinueAfterError(t *testing.T, memonly bool) { + diskdb := rawdb.NewMemoryDatabase() + triedb := NewDatabase(diskdb) + + tr := NewEmpty(triedb) + for _, val := range testdata1 { + tr.Update([]byte(val.k), []byte(val.v)) + } + _, nodes := tr.Commit(false) + triedb.Update(NewWithNodeSet(nodes)) + // if !memonly { + // triedb.Commit(tr.Hash(), false) + // } + wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil) + + var ( + diskKeys [][]byte + memKeys []common.Hash + ) + if memonly { + memKeys = triedb.Nodes() + } else { + it := diskdb.NewIterator(nil, nil) + for it.Next() { + diskKeys = append(diskKeys, it.Key()) + } + it.Release() + } + for i := 0; i < 20; i++ { + // Create trie that will load all nodes from DB. + tr, _ := New(TrieID(tr.Hash()), triedb, StateTrieCodec) + + // Remove a random node from the database. It can't be the root node + // because that one is already loaded. + var ( + rkey common.Hash + rval []byte + robj *cachedNode + ) + for { + if memonly { + rkey = memKeys[rand.Intn(len(memKeys))] + } else { + copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))]) + } + if rkey != tr.Hash() { + break + } + } + if memonly { + robj = triedb.dirties[rkey] + delete(triedb.dirties, rkey) + } else { + rval, _ = diskdb.Get(rkey[:]) + diskdb.Delete(rkey[:]) + } + // Iterate until the error is hit. + seen := make(map[string]bool) + it := tr.NodeIterator(nil) + checkIteratorNoDups(t, it, seen) + missing, ok := it.Error().(*MissingNodeError) + if !ok || missing.NodeHash != rkey { + t.Fatal("didn't hit missing node, got", it.Error()) + } + + // Add the node back and continue iteration. 
+ if memonly { + triedb.dirties[rkey] = robj + } else { + diskdb.Put(rkey[:], rval) + } + checkIteratorNoDups(t, it, seen) + if it.Error() != nil { + t.Fatal("unexpected error", it.Error()) + } + if len(seen) != wantNodeCount { + t.Fatal("wrong node iteration count, got", len(seen), "want", wantNodeCount) + } + } +} + +// Similar to the test above, this one checks that failure to create nodeIterator at a +// certain key prefix behaves correctly when Next is called. The expectation is that Next +// should retry seeking before returning true for the first time. +func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) { + testIteratorContinueAfterSeekError(t, true) +} + +func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) { + // Commit test trie to db, then remove the node containing "bars". + diskdb := rawdb.NewMemoryDatabase() + triedb := NewDatabase(diskdb) + + ctr := NewEmpty(triedb) + for _, val := range testdata1 { + ctr.Update([]byte(val.k), []byte(val.v)) + } + root, nodes := ctr.Commit(false) + triedb.Update(NewWithNodeSet(nodes)) + // if !memonly { + // triedb.Commit(root, false) + // } + barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e") + var ( + barNodeBlob []byte + barNodeObj *cachedNode + ) + if memonly { + barNodeObj = triedb.dirties[barNodeHash] + delete(triedb.dirties, barNodeHash) + } else { + barNodeBlob, _ = diskdb.Get(barNodeHash[:]) + diskdb.Delete(barNodeHash[:]) + } + // Create a new iterator that seeks to "bars". Seeking can't proceed because + // the node is missing. + tr, _ := New(TrieID(root), triedb, StateTrieCodec) + it := tr.NodeIterator([]byte("bars")) + missing, ok := it.Error().(*MissingNodeError) + if !ok { + t.Fatal("want MissingNodeError, got", it.Error()) + } else if missing.NodeHash != barNodeHash { + t.Fatal("wrong node missing") + } + // Reinsert the missing node. 
+ if memonly { + triedb.dirties[barNodeHash] = barNodeObj + } else { + diskdb.Put(barNodeHash[:], barNodeBlob) + } + // Check that iteration produces the right set of values. + if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil { + t.Fatal(err) + } +} + +func checkIteratorNoDups(t *testing.T, it NodeIterator, seen map[string]bool) int { + if seen == nil { + seen = make(map[string]bool) + } + for it.Next(true) { + if seen[string(it.Path())] { + t.Fatalf("iterator visited node path %x twice", it.Path()) + } + seen[string(it.Path())] = true + } + return len(seen) +} + +type loggingTrieDb struct { + *Database + getCount uint64 +} + +// GetReader retrieves a node reader belonging to the given state root. +func (db *loggingTrieDb) GetReader(root common.Hash, codec uint64) Reader { + return &loggingNodeReader{db, codec} +} + +// hashReader is reader of hashDatabase which implements the Reader interface. +type loggingNodeReader struct { + db *loggingTrieDb + codec uint64 +} + +// Node retrieves the trie node with the given node hash. +func (reader *loggingNodeReader) Node(owner common.Hash, path []byte, hash common.Hash) (node, error) { + blob, err := reader.NodeBlob(owner, path, hash) + if err != nil { + return nil, err + } + return decodeNodeUnsafe(hash[:], blob) +} + +// NodeBlob retrieves the RLP-encoded trie node blob with the given node hash. 
+func (reader *loggingNodeReader) NodeBlob(_ common.Hash, _ []byte, hash common.Hash) ([]byte, error) { + reader.db.getCount++ + return reader.db.Node(hash, reader.codec) +} + +func newLoggingStateTrie(id *ID, db *Database, codec uint64) (*StateTrie, *loggingTrieDb, error) { + logdb := &loggingTrieDb{Database: db} + trie, err := New(id, logdb, codec) + if err != nil { + return nil, nil, err + } + return &StateTrie{trie: *trie, preimages: db.preimages}, logdb, nil +} + +// makeLargeTestTrie create a sample test trie +func makeLargeTestTrie(t testing.TB) (*Database, *StateTrie, *loggingTrieDb) { + // Create an empty trie + triedb := NewDatabase(rawdb.NewDatabase(memorydb.New())) + trie, logDb, err := newLoggingStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec) + if err != nil { + t.Fatal(err) + } + + // Fill it with some arbitrary data + for i := 0; i < 10000; i++ { + key := make([]byte, 32) + val := make([]byte, 32) + binary.BigEndian.PutUint64(key, uint64(i)) + binary.BigEndian.PutUint64(val, uint64(i)) + key = crypto.Keccak256(key) + val = crypto.Keccak256(val) + trie.Update(key, val) + } + _, nodes := trie.Commit(false) + triedb.Update(NewWithNodeSet(nodes)) + // Return the generated trie + return triedb, trie, logDb +} + +// Tests that the node iterator indeed walks over the entire database contents. 
+func TestNodeIteratorLargeTrie(t *testing.T) { + // Create some arbitrary test trie to iterate + _, trie, logDb := makeLargeTestTrie(t) + // Do a seek operation + trie.NodeIterator(common.FromHex("0x77667766776677766778855885885885")) + // master: 24 get operations + // this pr: 5 get operations + if have, want := logDb.getCount, uint64(5); have != want { + t.Fatalf("Wrong number of lookups during seek, have %d want %d", have, want) + } +} + func TestIteratorNodeBlob(t *testing.T) { + // var ( + // db = rawdb.NewMemoryDatabase() + // triedb = NewDatabase(db) + // trie = NewEmpty(triedb) + // ) + // vals := []struct{ k, v string }{ + // {"do", "verb"}, + // {"ether", "wookiedoo"}, + // {"horse", "stallion"}, + // {"shaman", "horse"}, + // {"doge", "coin"}, + // {"dog", "puppy"}, + // {"somethingveryoddindeedthis is", "myothernodedata"}, + // } + // all := make(map[string]string) + // for _, val := range vals { + // all[val.k] = val.v + // trie.Update([]byte(val.k), []byte(val.v)) + // } + // _, nodes := trie.Commit(false) + // triedb.Update(NewWithNodeSet(nodes)) + edb := rawdb.NewMemoryDatabase() db := geth_trie.NewDatabase(edb) orig := geth_trie.NewEmpty(geth_trie.NewDatabase(rawdb.NewMemoryDatabase())) - if _, err := updateTrie(orig, testdata1); err != nil { + if _, err := updateTrie(orig, packableTestData); err != nil { t.Fatal(err) } root := commitTrie(t, db, orig) @@ -167,7 +608,7 @@ func TestIteratorNodeBlob(t *testing.T) { for dbIter.Next() { got, present := found[common.BytesToHash(dbIter.Key())] if !present { - t.Fatalf("Missing trie node %v", dbIter.Key()) + t.Fatalf("Miss trie node %v", dbIter.Key()) } if !bytes.Equal(got, dbIter.Value()) { t.Fatalf("Unexpected trie node want %v got %v", dbIter.Value(), got) @@ -175,6 +616,6 @@ func TestIteratorNodeBlob(t *testing.T) { count += 1 } if count != len(found) { - t.Fatal("Find extra trie node via iterator") + t.Fatalf("Wrong number of trie nodes found, want %d, got %d", len(found), count) } } diff --git 
a/trie_by_cid/trie/node.go b/trie_by_cid/trie/node.go index c995099..6ce6551 100644 --- a/trie_by_cid/trie/node.go +++ b/trie_by_cid/trie/node.go @@ -23,7 +23,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/rlp" - "github.com/ethereum/go-ethereum/trie" ) var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"} @@ -157,7 +156,7 @@ func decodeShort(hash, elems []byte) (node, error) { return nil, err } flag := nodeFlag{hash: hash} - key := trie.CompactToHex(kbuf) + key := compactToHex(kbuf) if hasTerm(key) { // value node val, _, err := rlp.SplitString(rest) diff --git a/trie_by_cid/trie/node_enc.go b/trie_by_cid/trie/node_enc.go index 2d26350..cade35b 100644 --- a/trie_by_cid/trie/node_enc.go +++ b/trie_by_cid/trie/node_enc.go @@ -58,3 +58,30 @@ func (n hashNode) encode(w rlp.EncoderBuffer) { func (n valueNode) encode(w rlp.EncoderBuffer) { w.WriteBytes(n) } + +func (n rawFullNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + for _, c := range n { + if c != nil { + c.encode(w) + } else { + w.Write(rlp.EmptyString) + } + } + w.ListEnd(offset) +} + +func (n *rawShortNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + w.WriteBytes(n.Key) + if n.Val != nil { + n.Val.encode(w) + } else { + w.Write(rlp.EmptyString) + } + w.ListEnd(offset) +} + +func (n rawNode) encode(w rlp.EncoderBuffer) { + w.Write(n) +} diff --git a/trie_by_cid/trie/node_test.go b/trie_by_cid/trie/node_test.go index ac1d8fb..9b8b337 100644 --- a/trie_by_cid/trie/node_test.go +++ b/trie_by_cid/trie/node_test.go @@ -20,6 +20,7 @@ import ( "bytes" "testing" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" ) @@ -92,3 +93,123 @@ func TestDecodeFullNode(t *testing.T) { t.Fatalf("decode full node err: %v", err) } } + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkEncodeShortNode +// BenchmarkEncodeShortNode-8 16878850 70.81 
ns/op 48 B/op 1 allocs/op +func BenchmarkEncodeShortNode(b *testing.B) { + node := &shortNode{ + Key: []byte{0x1, 0x2}, + Val: hashNode(randBytes(32)), + } + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + nodeToBytes(node) + } +} + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkEncodeFullNode +// BenchmarkEncodeFullNode-8 4323273 284.4 ns/op 576 B/op 1 allocs/op +func BenchmarkEncodeFullNode(b *testing.B) { + node := &fullNode{} + for i := 0; i < 16; i++ { + node.Children[i] = hashNode(randBytes(32)) + } + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + nodeToBytes(node) + } +} + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkDecodeShortNode +// BenchmarkDecodeShortNode-8 7925638 151.0 ns/op 157 B/op 4 allocs/op +func BenchmarkDecodeShortNode(b *testing.B) { + node := &shortNode{ + Key: []byte{0x1, 0x2}, + Val: hashNode(randBytes(32)), + } + blob := nodeToBytes(node) + hash := crypto.Keccak256(blob) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + mustDecodeNode(hash, blob) + } +} + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkDecodeShortNodeUnsafe +// BenchmarkDecodeShortNodeUnsafe-8 9027476 128.6 ns/op 109 B/op 3 allocs/op +func BenchmarkDecodeShortNodeUnsafe(b *testing.B) { + node := &shortNode{ + Key: []byte{0x1, 0x2}, + Val: hashNode(randBytes(32)), + } + blob := nodeToBytes(node) + hash := crypto.Keccak256(blob) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + mustDecodeNodeUnsafe(hash, blob) + } +} + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkDecodeFullNode +// BenchmarkDecodeFullNode-8 1597462 761.9 ns/op 1280 B/op 18 allocs/op +func BenchmarkDecodeFullNode(b *testing.B) { + node := &fullNode{} + for i := 0; i < 16; i++ { + node.Children[i] = hashNode(randBytes(32)) + } + blob := 
nodeToBytes(node) + hash := crypto.Keccak256(blob) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + mustDecodeNode(hash, blob) + } +} + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkDecodeFullNodeUnsafe +// BenchmarkDecodeFullNodeUnsafe-8 1789070 687.1 ns/op 704 B/op 17 allocs/op +func BenchmarkDecodeFullNodeUnsafe(b *testing.B) { + node := &fullNode{} + for i := 0; i < 16; i++ { + node.Children[i] = hashNode(randBytes(32)) + } + blob := nodeToBytes(node) + hash := crypto.Keccak256(blob) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + mustDecodeNodeUnsafe(hash, blob) + } +} diff --git a/trie_by_cid/trie/nodeset.go b/trie_by_cid/trie/nodeset.go new file mode 100644 index 0000000..99e4a80 --- /dev/null +++ b/trie_by_cid/trie/nodeset.go @@ -0,0 +1,218 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "fmt" + "reflect" + "sort" + "strings" + + "github.com/ethereum/go-ethereum/common" +) + +// memoryNode is all the information we know about a single cached trie node +// in the memory. 
+type memoryNode struct { + hash common.Hash // Node hash, computed by hashing rlp value, empty for deleted nodes + size uint16 // Byte size of the useful cached data, 0 for deleted nodes + node node // Cached collapsed trie node, or raw rlp data, nil for deleted nodes +} + +// memoryNodeSize is the raw size of a memoryNode data structure without any +// node data included. It's an approximate size, but should be a lot better +// than not counting them. +// nolint:unused +var memoryNodeSize = int(reflect.TypeOf(memoryNode{}).Size()) + +// memorySize returns the total memory size used by this node. +// nolint:unused +func (n *memoryNode) memorySize(pathlen int) int { + return int(n.size) + memoryNodeSize + pathlen +} + +// rlp returns the raw rlp encoded blob of the cached trie node, either directly +// from the cache, or by regenerating it from the collapsed node. +// nolint:unused +func (n *memoryNode) rlp() []byte { + if node, ok := n.node.(rawNode); ok { + return node + } + return nodeToBytes(n.node) +} + +// obj returns the decoded and expanded trie node, either directly from the cache, +// or by regenerating it from the rlp encoded blob. +// nolint:unused +func (n *memoryNode) obj() node { + if node, ok := n.node.(rawNode); ok { + return mustDecodeNode(n.hash[:], node) + } + return expandNode(n.hash[:], n.node) +} + +// isDeleted returns the indicator if the node is marked as deleted. +func (n *memoryNode) isDeleted() bool { + return n.hash == (common.Hash{}) +} + +// nodeWithPrev wraps the memoryNode with the previous node value. +// nolint: unused +type nodeWithPrev struct { + *memoryNode + prev []byte // RLP-encoded previous value, nil means it's non-existent +} + +// unwrap returns the internal memoryNode object. +// nolint:unused +func (n *nodeWithPrev) unwrap() *memoryNode { + return n.memoryNode +} + +// memorySize returns the total memory size used by this node. It overloads +// the function in memoryNode by counting the size of previous value as well. 
+// nolint: unused +func (n *nodeWithPrev) memorySize(pathlen int) int { + return n.memoryNode.memorySize(pathlen) + len(n.prev) +} + +// NodeSet contains all dirty nodes collected during the commit operation. +// Each node is keyed by path. It's not thread-safe to use. +type NodeSet struct { + owner common.Hash // the identifier of the trie + nodes map[string]*memoryNode // the set of dirty nodes(inserted, updated, deleted) + leaves []*leaf // the list of dirty leaves + updates int // the count of updated and inserted nodes + deletes int // the count of deleted nodes + + // The list of accessed nodes, which records the original node value. + // The origin value is expected to be nil for newly inserted node + // and is expected to be non-nil for other types(updated, deleted). + accessList map[string][]byte +} + +// NewNodeSet initializes an empty node set to be used for tracking dirty nodes +// from a specific account or storage trie. The owner is zero for the account +// trie and the owning account address hash for storage tries. +func NewNodeSet(owner common.Hash, accessList map[string][]byte) *NodeSet { + return &NodeSet{ + owner: owner, + nodes: make(map[string]*memoryNode), + accessList: accessList, + } +} + +// forEachWithOrder iterates the dirty nodes with the order from bottom to top, +// right to left, nodes with the longest path will be iterated first. +func (set *NodeSet) forEachWithOrder(callback func(path string, n *memoryNode)) { + var paths sort.StringSlice + for path := range set.nodes { + paths = append(paths, path) + } + // Bottom-up, longest path first + sort.Sort(sort.Reverse(paths)) + for _, path := range paths { + callback(path, set.nodes[path]) + } +} + +// markUpdated marks the node as dirty(newly-inserted or updated). +func (set *NodeSet) markUpdated(path []byte, node *memoryNode) { + set.nodes[string(path)] = node + set.updates += 1 +} + +// markDeleted marks the node as deleted. 
+func (set *NodeSet) markDeleted(path []byte) { + set.nodes[string(path)] = &memoryNode{} + set.deletes += 1 +} + +// addLeaf collects the provided leaf node into set. +func (set *NodeSet) addLeaf(node *leaf) { + set.leaves = append(set.leaves, node) +} + +// Size returns the number of dirty nodes in set. +func (set *NodeSet) Size() (int, int) { + return set.updates, set.deletes +} + +// Hashes returns the hashes of all updated nodes. TODO(rjl493456442) how can +// we get rid of it? +func (set *NodeSet) Hashes() []common.Hash { + var ret []common.Hash + for _, node := range set.nodes { + ret = append(ret, node.hash) + } + return ret +} + +// Summary returns a string-representation of the NodeSet. +func (set *NodeSet) Summary() string { + var out = new(strings.Builder) + fmt.Fprintf(out, "nodeset owner: %v\n", set.owner) + if set.nodes != nil { + for path, n := range set.nodes { + // Deletion + if n.isDeleted() { + fmt.Fprintf(out, " [-]: %x prev: %x\n", path, set.accessList[path]) + continue + } + // Insertion + origin, ok := set.accessList[path] + if !ok { + fmt.Fprintf(out, " [+]: %x -> %v\n", path, n.hash) + continue + } + // Update + fmt.Fprintf(out, " [*]: %x -> %v prev: %x\n", path, n.hash, origin) + } + } + for _, n := range set.leaves { + fmt.Fprintf(out, "[leaf]: %v\n", n) + } + return out.String() +} + +// MergedNodeSet represents a merged dirty node set for a group of tries. +type MergedNodeSet struct { + sets map[common.Hash]*NodeSet +} + +// NewMergedNodeSet initializes an empty merged set. +func NewMergedNodeSet() *MergedNodeSet { + return &MergedNodeSet{sets: make(map[common.Hash]*NodeSet)} +} + +// NewWithNodeSet constructs a merged nodeset with the provided single set. +func NewWithNodeSet(set *NodeSet) *MergedNodeSet { + merged := NewMergedNodeSet() + merged.Merge(set) + return merged +} + +// Merge merges the provided dirty nodes of a trie into the set. 
The assumption +// is held that no duplicated set belonging to the same trie will be merged twice. +func (set *MergedNodeSet) Merge(other *NodeSet) error { + _, present := set.sets[other.owner] + if present { + return fmt.Errorf("duplicate trie for owner %#x", other.owner) + } + set.sets[other.owner] = other + return nil +} diff --git a/trie_by_cid/trie/preimages.go b/trie_by_cid/trie/preimages.go new file mode 100644 index 0000000..a6359ca --- /dev/null +++ b/trie_by_cid/trie/preimages.go @@ -0,0 +1,78 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/ethdb" +) + +// preimageStore is the store for caching preimages of node key. +type preimageStore struct { + lock sync.RWMutex + disk ethdb.KeyValueStore + preimages map[common.Hash][]byte // Preimages of nodes from the secure trie + preimagesSize common.StorageSize // Storage size of the preimages cache +} + +// newPreimageStore initializes the store for caching preimages. 
+func newPreimageStore(disk ethdb.KeyValueStore) *preimageStore { + return &preimageStore{ + disk: disk, + preimages: make(map[common.Hash][]byte), + } +} + +// insertPreimage writes a new trie node pre-image to the memory database if it's +// yet unknown. The method will NOT make a copy of the slice, only use if the +// preimage will NOT be changed later on. +func (store *preimageStore) insertPreimage(preimages map[common.Hash][]byte) { + store.lock.Lock() + defer store.lock.Unlock() + + for hash, preimage := range preimages { + if _, ok := store.preimages[hash]; ok { + continue + } + store.preimages[hash] = preimage + store.preimagesSize += common.StorageSize(common.HashLength + len(preimage)) + } +} + +// preimage retrieves a cached trie node pre-image from memory. If it cannot be +// found cached, the method queries the persistent database for the content. +func (store *preimageStore) preimage(hash common.Hash) []byte { + store.lock.RLock() + preimage := store.preimages[hash] + store.lock.RUnlock() + + if preimage != nil { + return preimage + } + return rawdb.ReadPreimage(store.disk, hash) +} + +// size returns the current storage size of accumulated preimages. +func (store *preimageStore) size() common.StorageSize { + store.lock.RLock() + defer store.lock.RUnlock() + + return store.preimagesSize +} diff --git a/trie_by_cid/trie/proof.go b/trie_by_cid/trie/proof.go index e48eda5..7315c0d 100644 --- a/trie_by_cid/trie/proof.go +++ b/trie_by_cid/trie/proof.go @@ -20,9 +20,10 @@ import ( "bytes" "fmt" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/trie" + log "github.com/sirupsen/logrus" ) var VerifyProof = trie.VerifyProof @@ -61,10 +62,15 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e key = key[1:] nodes = append(nodes, n) case hashNode: + // Retrieve the specified node from the underlying node reader. 
+ // trie.resolveAndTrack is not used since in that function the + // loaded blob will be tracked, while it's not required here since + // all loaded nodes won't be linked to trie at all and track nodes + // may lead to out-of-memory issue. var err error - tn, err = t.resolveHash(n, prefix) + tn, err = t.reader.node(prefix, common.BytesToHash(n)) if err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) + log.Error("Unhandled trie error in Trie.Prove", "err", err) return err } default: diff --git a/trie_by_cid/trie/proof_test.go b/trie_by_cid/trie/proof_test.go index 3fd508a..6b23bcd 100644 --- a/trie_by_cid/trie/proof_test.go +++ b/trie_by_cid/trie/proof_test.go @@ -14,29 +14,39 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -package trie_test +package trie import ( "bytes" + crand "crypto/rand" + "encoding/binary" + "fmt" mrand "math/rand" "sort" "testing" - "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb/memorydb" - geth_trie "github.com/ethereum/go-ethereum/trie" - - . "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" ) -// tweakable trie size -var scaleFactor = 512 +// Prng is a pseudo random number generator seeded by strong randomness. +// The randomness is printed on startup in order to make failures reproducible. 
+var prng = initRnd() -func init() { - mrand.Seed(time.Now().UnixNano()) +func initRnd() *mrand.Rand { + var seed [8]byte + crand.Read(seed[:]) + rnd := mrand.New(mrand.NewSource(int64(binary.LittleEndian.Uint64(seed[:])))) + fmt.Printf("Seed: %x\n", seed) + return rnd +} + +func randBytes(n int) []byte { + r := make([]byte, n) + prng.Read(r) + return r } // makeProvers creates Merkle trie provers based on different implementations to @@ -64,7 +74,7 @@ func makeProvers(trie *Trie) []func(key []byte) *memorydb.Database { } func TestProof(t *testing.T) { - trie, vals := randomTrie(t, scaleFactor) + trie, vals := randomTrie(500) root := trie.Hash() for i, prover := range makeProvers(trie) { for _, kv := range vals { @@ -72,11 +82,11 @@ func TestProof(t *testing.T) { if proof == nil { t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k) } - val, err := geth_trie.VerifyProof(root, kv.k, proof) + val, err := VerifyProof(root, kv.k, proof) if err != nil { t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x", i, kv.k, err, proof) } - if kv.v != unpackValue(val) { + if !bytes.Equal(val, kv.v) { t.Fatalf("prover %d: verified value mismatch for key %x: have %x, want %x", i, kv.k, val, kv.v) } } @@ -84,13 +94,8 @@ func TestProof(t *testing.T) { } func TestOneElementProof(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(db) - orig.Update([]byte("k"), packValue(42)) - root := commitTrie(t, db, orig) - trie := indexTrie(t, edb, root) - + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + updateString(trie, "k", "v") for i, prover := range makeProvers(trie) { proof := prover([]byte("k")) if proof == nil { @@ -99,18 +104,18 @@ func TestOneElementProof(t *testing.T) { if proof.Len() != 1 { t.Errorf("prover %d: proof should have one element", i) } - val, err := geth_trie.VerifyProof(trie.Hash(), []byte("k"), proof) + val, err := VerifyProof(trie.Hash(), []byte("k"), 
proof) if err != nil { t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) } - if 42 != unpackValue(val) { + if !bytes.Equal(val, []byte("v")) { t.Fatalf("prover %d: verified value mismatch: have %x, want 'k'", i, val) } } } func TestBadProof(t *testing.T) { - trie, vals := randomTrie(t, 2*scaleFactor) + trie, vals := randomTrie(800) root := trie.Hash() for i, prover := range makeProvers(trie) { for _, kv := range vals { @@ -140,12 +145,8 @@ func TestBadProof(t *testing.T) { // Tests that missing keys can also be proven. The test explicitly uses a single // entry trie and checks for missing keys both before and after the single entry. func TestMissingKeyProof(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(db) - orig.Update([]byte("k"), packValue(42)) - root := commitTrie(t, db, orig) - trie := indexTrie(t, edb, root) + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + updateString(trie, "k", "v") for i, key := range []string{"a", "j", "l", "z"} { proof := memorydb.New() @@ -164,15 +165,7 @@ func TestMissingKeyProof(t *testing.T) { } } -type entry struct { - k, v []byte -} - -func packEntry(kv *kv) *entry { - return &entry{kv.k, packValue(kv.v)} -} - -type entrySlice []*entry +type entrySlice []*kv func (p entrySlice) Len() int { return len(p) } func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 } @@ -181,10 +174,10 @@ func (p entrySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // TestRangeProof tests normal range proof with both edge proofs // as the existent proof. The test cases are generated randomly. 
func TestRangeProof(t *testing.T) { - trie, vals := randomTrie(t, 8*scaleFactor) + trie, vals := randomTrie(4096) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) for i := 0; i < 500; i++ { @@ -214,10 +207,10 @@ func TestRangeProof(t *testing.T) { // TestRangeProof tests normal range proof with two non-existent proofs. // The test cases are generated randomly. func TestRangeProofWithNonExistentProof(t *testing.T) { - trie, vals := randomTrie(t, 8*scaleFactor) + trie, vals := randomTrie(4096) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) for i := 0; i < 500; i++ { @@ -286,10 +279,10 @@ func TestRangeProofWithNonExistentProof(t *testing.T) { // - There exists a gap between the first element and the left edge proof // - There exists a gap between the last element and the right edge proof func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { - trie, vals := randomTrie(t, 8*scaleFactor) + trie, vals := randomTrie(4096) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -343,10 +336,10 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { // element. The first edge proof can be existent one or // non-existent one. func TestOneElementRangeProof(t *testing.T) { - trie, vals := randomTrie(t, 8*scaleFactor) + trie, vals := randomTrie(4096) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -408,38 +401,32 @@ func TestOneElementRangeProof(t *testing.T) { } // Test the mini trie with only a single element. 
- t.Run("single element", func(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(db) - entry := &entry{randBytes(32), packValue(mrand.Int63())} - orig.Update(entry.k, entry.v) - root := commitTrie(t, db, orig) - tinyTrie := indexTrie(t, edb, root) + tinyTrie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + entry := &kv{randBytes(32), randBytes(20), false} + tinyTrie.Update(entry.k, entry.v) - first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes() - last = entry.k - proof = memorydb.New() - if err := tinyTrie.Prove(first, 0, proof); err != nil { - t.Fatalf("Failed to prove the first node %v", err) - } - if err := tinyTrie.Prove(last, 0, proof); err != nil { - t.Fatalf("Failed to prove the last node %v", err) - } - _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof) - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - }) + first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes() + last = entry.k + proof = memorydb.New() + if err := tinyTrie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := tinyTrie.Prove(last, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } } // TestAllElementsProof tests the range proof with all elements. // The edge proofs can be nil. 
func TestAllElementsProof(t *testing.T) { - trie, vals := randomTrie(t, 8*scaleFactor) + trie, vals := randomTrie(4096) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -483,86 +470,73 @@ func TestAllElementsProof(t *testing.T) { } } -func positionCases(size int) []int { - var cases []int - for _, pos := range []int{0, 1, 50, 100, 1000, 2000, size - 1} { - if pos >= size { - break - } - cases = append(cases, pos) - } - return cases -} - // TestSingleSideRangeProof tests the range starts from zero. func TestSingleSideRangeProof(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(db) - var entries entrySlice - for i := 0; i < 8*scaleFactor; i++ { - value := &entry{randBytes(32), packValue(mrand.Int63())} - orig.Update(value.k, value.v) - entries = append(entries, value) - } - root := commitTrie(t, db, orig) - trie := indexTrie(t, edb, root) - sort.Sort(entries) + for i := 0; i < 64; i++ { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + var entries entrySlice + for i := 0; i < 4096; i++ { + value := &kv{randBytes(32), randBytes(20), false} + trie.Update(value.k, value.v) + entries = append(entries, value) + } + sort.Sort(entries) - for _, pos := range positionCases(len(entries)) { - proof := memorydb.New() - if err := trie.Prove(common.Hash{}.Bytes(), 0, proof); err != nil { - t.Fatalf("Failed to prove the first node %v", err) - } - if err := trie.Prove(entries[pos].k, 0, proof); err != nil { - t.Fatalf("Failed to prove the first node %v", err) - } - k := make([][]byte, 0) - v := make([][]byte, 0) - for i := 0; i <= pos; i++ { - k = append(k, entries[i].k) - v = append(v, entries[i].v) - } - _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof) - if err != nil { - t.Fatalf("Expected no error, got %v", err) + var cases = []int{0, 1, 50, 100, 1000, 2000, 
len(entries) - 1} + for _, pos := range cases { + proof := memorydb.New() + if err := trie.Prove(common.Hash{}.Bytes(), 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(entries[pos].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + k := make([][]byte, 0) + v := make([][]byte, 0) + for i := 0; i <= pos; i++ { + k = append(k, entries[i].k) + v = append(v, entries[i].v) + } + _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } } } } // TestReverseSingleSideRangeProof tests the range ends with 0xffff...fff. func TestReverseSingleSideRangeProof(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(db) - var entries entrySlice - for i := 0; i < 8*scaleFactor; i++ { - value := &entry{randBytes(32), packValue(mrand.Int63())} - orig.Update(value.k, value.v) - entries = append(entries, value) - } - root := commitTrie(t, db, orig) - trie := indexTrie(t, edb, root) - sort.Sort(entries) + for i := 0; i < 64; i++ { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + var entries entrySlice + for i := 0; i < 4096; i++ { + value := &kv{randBytes(32), randBytes(20), false} + trie.Update(value.k, value.v) + entries = append(entries, value) + } + sort.Sort(entries) - for _, pos := range positionCases(len(entries)) { - proof := memorydb.New() - if err := trie.Prove(entries[pos].k, 0, proof); err != nil { - t.Fatalf("Failed to prove the first node %v", err) - } - last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") - if err := trie.Prove(last.Bytes(), 0, proof); err != nil { - t.Fatalf("Failed to prove the last node %v", err) - } - k := make([][]byte, 0) - v := make([][]byte, 0) - for i := pos; i < len(entries); i++ { - k = append(k, entries[i].k) - v = append(v, entries[i].v) - } - _, 
err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof) - if err != nil { - t.Fatalf("Expected no error, got %v", err) + var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1} + for _, pos := range cases { + proof := memorydb.New() + if err := trie.Prove(entries[pos].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + if err := trie.Prove(last.Bytes(), 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + k := make([][]byte, 0) + v := make([][]byte, 0) + for i := pos; i < len(entries); i++ { + k = append(k, entries[i].k) + v = append(v, entries[i].v) + } + _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } } } } @@ -570,10 +544,10 @@ func TestReverseSingleSideRangeProof(t *testing.T) { // TestBadRangeProof tests a few cases which the proof is wrong. // The prover is expected to detect the error. func TestBadRangeProof(t *testing.T) { - trie, vals := randomTrie(t, 8*scaleFactor) + trie, vals := randomTrie(4096) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -641,17 +615,13 @@ func TestBadRangeProof(t *testing.T) { // TestGappedRangeProof focuses on the small trie with embedded nodes. // If the gapped node is embedded in the trie, it should be detected too. 
func TestGappedRangeProof(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(db) - var entries entrySlice + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + var entries []*kv // Sorted entries for i := byte(0); i < 10; i++ { - value := &entry{common.LeftPadBytes([]byte{i}, 32), packValue(int64(i))} - orig.Update(value.k, value.v) + value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + trie.Update(value.k, value.v) entries = append(entries, value) } - root := commitTrie(t, db, orig) - trie := indexTrie(t, edb, root) first, last := 2, 8 proof := memorydb.New() if err := trie.Prove(entries[first].k, 0, proof); err != nil { @@ -677,10 +647,10 @@ func TestGappedRangeProof(t *testing.T) { // TestSameSideProofs tests the element is not in the range covered by proofs func TestSameSideProofs(t *testing.T) { - trie, vals := randomTrie(t, 8*scaleFactor) + trie, vals := randomTrie(4096) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -719,17 +689,13 @@ func TestSameSideProofs(t *testing.T) { } func TestHasRightElement(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(db) + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) var entries entrySlice - for i := 0; i < 8*scaleFactor; i++ { - value := &entry{randBytes(32), packValue(int64(i))} - orig.Update(value.k, value.v) + for i := 0; i < 4096; i++ { + value := &kv{randBytes(32), randBytes(20), false} + trie.Update(value.k, value.v) entries = append(entries, value) } - root := commitTrie(t, db, orig) - trie := indexTrie(t, edb, root) sort.Sort(entries) var cases = []struct { @@ -797,10 +763,10 @@ func TestHasRightElement(t *testing.T) { // TestEmptyRangeProof tests the range proof with "no" element. // The first edge proof must be a non-existent proof. 
func TestEmptyRangeProof(t *testing.T) { - trie, vals := randomTrie(t, 8*scaleFactor) + trie, vals := randomTrie(4096) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -827,14 +793,49 @@ func TestEmptyRangeProof(t *testing.T) { } } +// TestBloatedProof tests a malicious proof, where the proof is more or less the +// whole trie. Previously we didn't accept such packets, but the new APIs do, so +// lets leave this test as a bit weird, but present. +func TestBloatedProof(t *testing.T) { + // Use a small trie + trie, kvs := nonRandomTrie(100) + var entries entrySlice + for _, kv := range kvs { + entries = append(entries, kv) + } + sort.Sort(entries) + var keys [][]byte + var vals [][]byte + + proof := memorydb.New() + // In the 'malicious' case, we add proofs for every single item + // (but only one key/value pair used as leaf) + for i, entry := range entries { + trie.Prove(entry.k, 0, proof) + if i == 50 { + keys = append(keys, entry.k) + vals = append(vals, entry.v) + } + } + // For reference, we use the same function, but _only_ prove the first + // and last element + want := memorydb.New() + trie.Prove(keys[0], 0, want) + trie.Prove(keys[len(keys)-1], 0, want) + + if _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof); err != nil { + t.Fatalf("expected bloated proof to succeed, got %v", err) + } +} + // TestEmptyValueRangeProof tests normal range proof with both edge proofs // as the existent proof, but with an extra empty value included, which is a // noop technically, but practically should be rejected. 
func TestEmptyValueRangeProof(t *testing.T) { - trie, values := randomTrie(t, scaleFactor) + trie, values := randomTrie(512) var entries entrySlice for _, kv := range values { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -847,8 +848,8 @@ func TestEmptyValueRangeProof(t *testing.T) { break } } - noop := &entry{key, []byte{}} - entries = append(append(append(entrySlice{}, entries[:mid]...), noop), entries[mid:]...) + noop := &kv{key, []byte{}, false} + entries = append(append(append([]*kv{}, entries[:mid]...), noop), entries[mid:]...) start, end := 1, len(entries)-1 @@ -875,10 +876,10 @@ func TestEmptyValueRangeProof(t *testing.T) { // but with an extra empty value included, which is a noop technically, but // practically should be rejected. func TestAllElementsEmptyValueRangeProof(t *testing.T) { - trie, values := randomTrie(t, scaleFactor) + trie, values := randomTrie(512) var entries entrySlice for _, kv := range values { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -891,8 +892,8 @@ func TestAllElementsEmptyValueRangeProof(t *testing.T) { break } } - noop := &entry{key, nil} - entries = append(append(append(entrySlice{}, entries[:mid]...), noop), entries[mid:]...) + noop := &kv{key, []byte{}, false} + entries = append(append(append([]*kv{}, entries[:mid]...), noop), entries[mid:]...) 
var keys [][]byte var vals [][]byte @@ -938,7 +939,7 @@ func decreaseKey(key []byte) []byte { } func BenchmarkProve(b *testing.B) { - trie, vals := randomTrie(b, 100) + trie, vals := randomTrie(100) var keys []string for k := range vals { keys = append(keys, k) @@ -955,7 +956,7 @@ func BenchmarkProve(b *testing.B) { } func BenchmarkVerifyProof(b *testing.B) { - trie, vals := randomTrie(b, 100) + trie, vals := randomTrie(100) root := trie.Hash() var keys []string var proofs []*memorydb.Database @@ -981,10 +982,10 @@ func BenchmarkVerifyRangeProof1000(b *testing.B) { benchmarkVerifyRangeProof(b, func BenchmarkVerifyRangeProof5000(b *testing.B) { benchmarkVerifyRangeProof(b, 5000) } func benchmarkVerifyRangeProof(b *testing.B, size int) { - trie, vals := randomTrie(b, 8192) + trie, vals := randomTrie(8192) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -1018,10 +1019,10 @@ func BenchmarkVerifyRangeNoProof500(b *testing.B) { benchmarkVerifyRangeNoProof func BenchmarkVerifyRangeNoProof1000(b *testing.B) { benchmarkVerifyRangeNoProof(b, 1000) } func benchmarkVerifyRangeNoProof(b *testing.B, size int) { - trie, vals := randomTrie(b, size) + trie, vals := randomTrie(size) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -1040,24 +1041,56 @@ func benchmarkVerifyRangeNoProof(b *testing.B, size int) { } } +func randomTrie(n int) (*Trie, map[string]*kv) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + vals := make(map[string]*kv) + for i := byte(0); i < 100; i++ { + value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false} + trie.Update(value.k, value.v) + trie.Update(value2.k, value2.v) + vals[string(value.k)] = value + vals[string(value2.k)] = value2 + } + for i := 0; i < n; i++ { + 
value := &kv{randBytes(32), randBytes(20), false} + trie.Update(value.k, value.v) + vals[string(value.k)] = value + } + return trie, vals +} + +func nonRandomTrie(n int) (*Trie, map[string]*kv) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + vals := make(map[string]*kv) + max := uint64(0xffffffffffffffff) + for i := uint64(0); i < uint64(n); i++ { + value := make([]byte, 32) + key := make([]byte, 32) + binary.LittleEndian.PutUint64(key, i) + binary.LittleEndian.PutUint64(value, i-max) + //value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + elem := &kv{key, value, false} + trie.Update(elem.k, elem.v) + vals[string(elem.k)] = elem + } + return trie, vals +} + func TestRangeProofKeysWithSharedPrefix(t *testing.T) { keys := [][]byte{ common.Hex2Bytes("aa10000000000000000000000000000000000000000000000000000000000000"), common.Hex2Bytes("aa20000000000000000000000000000000000000000000000000000000000000"), } vals := [][]byte{ - packValue(2), - packValue(3), + common.Hex2Bytes("02"), + common.Hex2Bytes("03"), } - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(db) + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) for i, key := range keys { - orig.Update(key, vals[i]) + trie.Update(key, vals[i]) } - root := commitTrie(t, db, orig) - trie := indexTrie(t, edb, root) - + root := trie.Hash() proof := memorydb.New() start := common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000") end := common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") diff --git a/trie_by_cid/trie/secure_trie.go b/trie_by_cid/trie/secure_trie.go index 6e4b4ca..25f1fd7 100644 --- a/trie_by_cid/trie/secure_trie.go +++ b/trie_by_cid/trie/secure_trie.go @@ -17,82 +17,88 @@ package trie import ( - "fmt" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/rlp" - 
"github.com/ethereum/go-ethereum/statediff/indexer/ipld" + log "github.com/sirupsen/logrus" ) -// StateTrie wraps a trie with key hashing. In a secure trie, all +// StateTrie wraps a trie with key hashing. In a stateTrie trie, all // access operations hash the key using keccak256. This prevents // calling code from creating long chains of nodes that // increase the access time. // // Contrary to a regular trie, a StateTrie can only be created with -// New and must have an attached database. +// New and must have an attached database. The database also stores +// the preimage of each key if preimage recording is enabled. // // StateTrie is not safe for concurrent use. type StateTrie struct { - trie Trie - hashKeyBuf [common.HashLength]byte + trie Trie + preimages *preimageStore + hashKeyBuf [common.HashLength]byte + secKeyCache map[string][]byte + secKeyCacheOwner *StateTrie // Pointer to self, replace the key cache on mismatch } -// NewStateTrie creates a trie with an existing root node from a backing database -// and optional intermediate in-memory node pool. +// NewStateTrie creates a trie with an existing root node from a backing database. // // If root is the zero hash or the sha3 hash of an empty string, the // trie is initially empty. Otherwise, New will panic if db is nil // and returns MissingNodeError if the root node cannot be found. -// -// Accessing the trie loads nodes from the database or node pool on demand. -// Loaded nodes are kept around until their 'cache generation' expires. -// A new cache generation is created by each call to Commit. -// cachelimit sets the number of past cache generations to keep. 
-// -// Retrieves IPLD blocks by CID encoded as "eth-state-trie" -func NewStateTrie(owner common.Hash, root common.Hash, db *Database) (*StateTrie, error) { - return newStateTrie(owner, root, db, ipld.MEthStateTrie) -} - -// NewStorageTrie is identical to NewStateTrie, but retrieves IPLD blocks encoded -// as "eth-storage-trie" -func NewStorageTrie(owner common.Hash, root common.Hash, db *Database) (*StateTrie, error) { - return newStateTrie(owner, root, db, ipld.MEthStorageTrie) -} - -func newStateTrie(owner common.Hash, root common.Hash, db *Database, codec uint64) (*StateTrie, error) { +func NewStateTrie(id *ID, db *Database, codec uint64) (*StateTrie, error) { if db == nil { - panic("NewStateTrie called without a database") + panic("trie.NewStateTrie called without a database") } - trie, err := New(owner, root, db, codec) + trie, err := New(id, db, codec) if err != nil { return nil, err } - return &StateTrie{trie: *trie}, nil + return &StateTrie{trie: *trie, preimages: db.preimages}, nil +} + +// Get returns the value for key stored in the trie. +// The value bytes must not be modified by the caller. +func (t *StateTrie) Get(key []byte) []byte { + res, err := t.TryGet(key) + if err != nil { + log.Error("Unhandled trie error in StateTrie.Get", "err", err) + } + return res } // TryGet returns the value for key stored in the trie. // The value bytes must not be modified by the caller. -// If a node was not found in the database, a MissingNodeError is returned. +// If the specified node is not in the trie, nil will be returned. +// If a trie node is not found in the database, a MissingNodeError is returned. 
func (t *StateTrie) TryGet(key []byte) ([]byte, error) { return t.trie.TryGet(t.hashKey(key)) } -func (t *StateTrie) TryGetAccount(key []byte) (*types.StateAccount, error) { - var ret types.StateAccount - res, err := t.TryGet(key) - if err != nil { - // log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) - panic(fmt.Sprintf("Unhandled trie error: %v", err)) - return &ret, err +// TryGetAccount attempts to retrieve an account with provided account address. +// If the specified account is not in the trie, nil will be returned. +// If a trie node is not found in the database, a MissingNodeError is returned. +func (t *StateTrie) TryGetAccount(address common.Address) (*types.StateAccount, error) { + res, err := t.trie.TryGet(t.hashKey(address.Bytes())) + if res == nil || err != nil { + return nil, err } - if res == nil { - return nil, nil + ret := new(types.StateAccount) + err = rlp.DecodeBytes(res, ret) + return ret, err +} + +// TryGetAccountByHash does the same thing as TryGetAccount, however +// it expects an account hash that is the hash of address. This constitutes an +// abstraction leak, since the client code needs to know the key format. +func (t *StateTrie) TryGetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) { + res, err := t.trie.TryGet(addrHash.Bytes()) + if res == nil || err != nil { + return nil, err } - err = rlp.DecodeBytes(res, &ret) - return &ret, err + ret := new(types.StateAccount) + err = rlp.DecodeBytes(res, ret) + return ret, err } // TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not @@ -103,12 +109,124 @@ func (t *StateTrie) TryGetNode(path []byte) ([]byte, int, error) { return t.trie.TryGetNode(path) } +// Update associates key with value in the trie. Subsequent calls to +// Get will return value. If value has length zero, any existing value +// is deleted from the trie and calls to Get will return nil. 
+// +// The value bytes must not be modified by the caller while they are +// stored in the trie. +func (t *StateTrie) Update(key, value []byte) { + if err := t.TryUpdate(key, value); err != nil { + log.Error("Unhandled trie error in StateTrie.Update", "err", err) + } +} + +// TryUpdate associates key with value in the trie. Subsequent calls to +// Get will return value. If value has length zero, any existing value +// is deleted from the trie and calls to Get will return nil. +// +// The value bytes must not be modified by the caller while they are +// stored in the trie. +// +// If a node is not found in the database, a MissingNodeError is returned. +func (t *StateTrie) TryUpdate(key, value []byte) error { + hk := t.hashKey(key) + err := t.trie.TryUpdate(hk, value) + if err != nil { + return err + } + t.getSecKeyCache()[string(hk)] = common.CopyBytes(key) + return nil +} + +// TryUpdateAccount account will abstract the write of an account to the +// secure trie. +func (t *StateTrie) TryUpdateAccount(address common.Address, acc *types.StateAccount) error { + hk := t.hashKey(address.Bytes()) + data, err := rlp.EncodeToBytes(acc) + if err != nil { + return err + } + if err := t.trie.TryUpdate(hk, data); err != nil { + return err + } + t.getSecKeyCache()[string(hk)] = address.Bytes() + return nil +} + +// Delete removes any existing value for key from the trie. +func (t *StateTrie) Delete(key []byte) { + if err := t.TryDelete(key); err != nil { + log.Error("Unhandled trie error in StateTrie.Delete", "err", err) + } +} + +// TryDelete removes any existing value for key from the trie. +// If the specified trie node is not in the trie, nothing will be changed. +// If a node is not found in the database, a MissingNodeError is returned. +func (t *StateTrie) TryDelete(key []byte) error { + hk := t.hashKey(key) + delete(t.getSecKeyCache(), string(hk)) + return t.trie.TryDelete(hk) +} + +// TryDeleteAccount abstracts an account deletion from the trie. 
+func (t *StateTrie) TryDeleteAccount(address common.Address) error { + hk := t.hashKey(address.Bytes()) + delete(t.getSecKeyCache(), string(hk)) + return t.trie.TryDelete(hk) +} + +// GetKey returns the sha3 preimage of a hashed key that was +// previously used to store a value. +func (t *StateTrie) GetKey(shaKey []byte) []byte { + if key, ok := t.getSecKeyCache()[string(shaKey)]; ok { + return key + } + if t.preimages == nil { + return nil + } + return t.preimages.preimage(common.BytesToHash(shaKey)) +} + +// Commit collects all dirty nodes in the trie and replaces them with the +// corresponding node hash. All collected nodes (including dirty leaves if +// collectLeaf is true) will be encapsulated into a nodeset for return. +// The returned nodeset can be nil if the trie is clean (nothing to commit). +// All cached preimages will be also flushed if preimages recording is enabled. +// Once the trie is committed, it's not usable anymore. A new trie must +// be created with new root and updated trie database for following usage +func (t *StateTrie) Commit(collectLeaf bool) (common.Hash, *NodeSet) { + // Write all the pre-images to the actual disk database + if len(t.getSecKeyCache()) > 0 { + if t.preimages != nil { + preimages := make(map[common.Hash][]byte) + for hk, key := range t.secKeyCache { + preimages[common.BytesToHash([]byte(hk))] = key + } + t.preimages.insertPreimage(preimages) + } + t.secKeyCache = make(map[string][]byte) + } + // Commit the trie and return its modified nodeset. + return t.trie.Commit(collectLeaf) +} + // Hash returns the root hash of StateTrie. It does not write to the // database and can be used even if the trie doesn't have one. func (t *StateTrie) Hash() common.Hash { return t.trie.Hash() } +// Copy returns a copy of StateTrie. 
+func (t *StateTrie) Copy() *StateTrie { + return &StateTrie{ + trie: *t.trie.Copy(), + preimages: t.preimages, + secKeyCache: t.secKeyCache, + } +} + // NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration // starts at the key after the given start key. func (t *StateTrie) NodeIterator(start []byte) NodeIterator { @@ -126,3 +244,14 @@ func (t *StateTrie) hashKey(key []byte) []byte { returnHasherToPool(h) return t.hashKeyBuf[:] } + +// getSecKeyCache returns the current secure key cache, creating a new one if +// ownership changed (i.e. the current secure trie is a copy of another owning +// the actual cache). +func (t *StateTrie) getSecKeyCache() map[string][]byte { + if t != t.secKeyCacheOwner { + t.secKeyCacheOwner = t + t.secKeyCache = make(map[string][]byte) + } + return t.secKeyCache +} diff --git a/trie_by_cid/trie/secure_trie_test.go b/trie_by_cid/trie/secure_trie_test.go new file mode 100644 index 0000000..41edbc3 --- /dev/null +++ b/trie_by_cid/trie/secure_trie_test.go @@ -0,0 +1,148 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +package trie + +import ( + "bytes" + "fmt" + "runtime" + "sync" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/crypto" +) + +// makeTestStateTrie creates a large enough secure trie for testing. +func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) { + // Create an empty trie + triedb := NewDatabase(rawdb.NewMemoryDatabase()) + trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec) + + // Fill it with some arbitrary data + content := make(map[string][]byte) + for i := byte(0); i < 255; i++ { + // Map the same data under multiple keys + key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i} + content[string(key)] = val + trie.Update(key, val) + + key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i} + content[string(key)] = val + trie.Update(key, val) + + // Add some other data to inflate the trie + for j := byte(3); j < 13; j++ { + key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i} + content[string(key)] = val + trie.Update(key, val) + } + } + root, nodes := trie.Commit(false) + if err := triedb.Update(NewWithNodeSet(nodes)); err != nil { + panic(fmt.Errorf("failed to commit db %v", err)) + } + // Re-create the trie based on the new state + trie, _ = NewStateTrie(TrieID(root), triedb, StateTrieCodec) + return triedb, trie, content +} + +func TestSecureDelete(t *testing.T) { + trie, err := NewStateTrie(TrieID(common.Hash{}), NewDatabase(rawdb.NewMemoryDatabase()), StateTrieCodec) + if err != nil { + t.Fatal(err) + } + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"ether", ""}, + {"dog", "puppy"}, + {"shaman", ""}, + } + for _, val := range vals { + if val.v != "" { + trie.Update([]byte(val.k), []byte(val.v)) + } else { + trie.Delete([]byte(val.k)) + } + } + hash := trie.Hash() + exp := 
common.HexToHash("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d") + if hash != exp { + t.Errorf("expected %x got %x", exp, hash) + } +} + +func TestSecureGetKey(t *testing.T) { + trie, err := NewStateTrie(TrieID(common.Hash{}), NewDatabase(rawdb.NewMemoryDatabase()), StateTrieCodec) + if err != nil { + t.Fatal(err) + } + trie.Update([]byte("foo"), []byte("bar")) + + key := []byte("foo") + value := []byte("bar") + seckey := crypto.Keccak256(key) + + if !bytes.Equal(trie.Get(key), value) { + t.Errorf("Get did not return bar") + } + if k := trie.GetKey(seckey); !bytes.Equal(k, key) { + t.Errorf("GetKey returned %q, want %q", k, key) + } +} + +func TestStateTrieConcurrency(t *testing.T) { + // Create an initial trie and copy if for concurrent access + _, trie, _ := makeTestStateTrie() + + threads := runtime.NumCPU() + tries := make([]*StateTrie, threads) + for i := 0; i < threads; i++ { + tries[i] = trie.Copy() + } + // Start a batch of goroutines interacting with the trie + pend := new(sync.WaitGroup) + pend.Add(threads) + for i := 0; i < threads; i++ { + go func(index int) { + defer pend.Done() + + for j := byte(0); j < 255; j++ { + // Map the same data under multiple keys + key, val := common.LeftPadBytes([]byte{byte(index), 1, j}, 32), []byte{j} + tries[index].Update(key, val) + + key, val = common.LeftPadBytes([]byte{byte(index), 2, j}, 32), []byte{j} + tries[index].Update(key, val) + + // Add some other data to inflate the trie + for k := byte(3); k < 13; k++ { + key, val = common.LeftPadBytes([]byte{byte(index), k, j}, 32), []byte{k, j} + tries[index].Update(key, val) + } + } + tries[index].Commit(false) + }(i) + } + // Wait for all threads to finish + pend.Wait() +} diff --git a/trie_by_cid/trie/tracer.go b/trie_by_cid/trie/tracer.go new file mode 100644 index 0000000..a27e371 --- /dev/null +++ b/trie_by_cid/trie/tracer.go @@ -0,0 +1,125 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import "github.com/ethereum/go-ethereum/common" + +// tracer tracks the changes of trie nodes. During the trie operations, +// some nodes can be deleted from the trie, while these deleted nodes +// won't be captured by trie.Hasher or trie.Committer. Thus, these deleted +// nodes won't be removed from the disk at all. Tracer is an auxiliary tool +// used to track all insert and delete operations of trie and capture all +// deleted nodes eventually. +// +// The changed nodes can be mainly divided into two categories: the leaf +// node and intermediate node. The former is inserted/deleted by callers +// while the latter is inserted/deleted in order to follow the rule of trie. +// This tool can track all of them no matter the node is embedded in its +// parent or not, but valueNode is never tracked. +// +// Besides, it's also used for recording the original value of the nodes +// when they are resolved from the disk. The pre-value of the nodes will +// be used to construct trie history in the future. +// +// Note tracer is not thread-safe, callers should be responsible for handling +// the concurrency issues by themselves. 
+type tracer struct { + inserts map[string]struct{} + deletes map[string]struct{} + accessList map[string][]byte +} + +// newTracer initializes the tracer for capturing trie changes. +func newTracer() *tracer { + return &tracer{ + inserts: make(map[string]struct{}), + deletes: make(map[string]struct{}), + accessList: make(map[string][]byte), + } +} + +// onRead tracks the newly loaded trie node and caches the rlp-encoded +// blob internally. Don't change the value outside of function since +// it's not deep-copied. +func (t *tracer) onRead(path []byte, val []byte) { + t.accessList[string(path)] = val +} + +// onInsert tracks the newly inserted trie node. If it's already +// in the deletion set (resurrected node), then just wipe it from +// the deletion set as it's "untouched". +func (t *tracer) onInsert(path []byte) { + if _, present := t.deletes[string(path)]; present { + delete(t.deletes, string(path)) + return + } + t.inserts[string(path)] = struct{}{} +} + +// onDelete tracks the newly deleted trie node. If it's already +// in the addition set, then just wipe it from the addition set +// as it's untouched. +func (t *tracer) onDelete(path []byte) { + if _, present := t.inserts[string(path)]; present { + delete(t.inserts, string(path)) + return + } + t.deletes[string(path)] = struct{}{} +} + +// reset clears the content tracked by tracer. +func (t *tracer) reset() { + t.inserts = make(map[string]struct{}) + t.deletes = make(map[string]struct{}) + t.accessList = make(map[string][]byte) +} + +// copy returns a deep copied tracer instance. 
+func (t *tracer) copy() *tracer { + var ( + inserts = make(map[string]struct{}) + deletes = make(map[string]struct{}) + accessList = make(map[string][]byte) + ) + for path := range t.inserts { + inserts[path] = struct{}{} + } + for path := range t.deletes { + deletes[path] = struct{}{} + } + for path, blob := range t.accessList { + accessList[path] = common.CopyBytes(blob) + } + return &tracer{ + inserts: inserts, + deletes: deletes, + accessList: accessList, + } +} + +// markDeletions puts all tracked deletions into the provided nodeset. +func (t *tracer) markDeletions(set *NodeSet) { + for path := range t.deletes { + // It's possible a few deleted nodes were embedded + // in their parent before, the deletions can be no + // effect by deleting nothing, filter them out. + if _, ok := set.accessList[path]; !ok { + continue + } + set.markDeleted([]byte(path)) + } +} diff --git a/trie_by_cid/trie/trie.go b/trie_by_cid/trie/trie.go index b2d8986..5ad22c1 100644 --- a/trie_by_cid/trie/trie.go +++ b/trie_by_cid/trie/trie.go @@ -7,18 +7,15 @@ import ( "fmt" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/core/types" + log "github.com/sirupsen/logrus" - util "github.com/cerc-io/ipld-eth-statedb/internal" "github.com/ethereum/go-ethereum/statediff/indexer/ipld" ) var ( - // emptyRoot is the known root hash of an empty trie. - emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") - - // emptyState is the known hash of an empty state trie entry. - emptyState = crypto.Keccak256Hash(nil) + StateTrieCodec uint64 = ipld.MEthStateTrie + StorageTrieCodec uint64 = ipld.MEthStorageTrie ) // Trie is a Merkle Patricia Trie. Use New to create a trie that sits on @@ -37,29 +34,48 @@ type Trie struct { // actually unhashed nodes. unhashed int - // db is the handler trie can retrieve nodes from. It's - // only for reading purpose and not available for writing. 
- db *Database + // reader is the handler trie can retrieve nodes from. + reader *trieReader - // Multihash codec for key encoding - codec uint64 + // tracer is the tool to track the trie changes. + // It will be reset after each commit operation. + tracer *tracer } -// New creates a trie with an existing root node from db and an assigned -// owner for storage proximity. -// -// If root is the zero hash or the sha3 hash of an empty string, the -// trie is initially empty and does not require a database. Otherwise, -// New will panic if db is nil and returns a MissingNodeError if root does -// not exist in the database. Accessing the trie loads nodes from db on demand. -func New(owner common.Hash, root common.Hash, db *Database, codec uint64) (*Trie, error) { - trie := &Trie{ - owner: owner, - db: db, - codec: codec, +// newFlag returns the cache flag value for a newly created node. +func (t *Trie) newFlag() nodeFlag { + return nodeFlag{dirty: true} +} + +// Copy returns a copy of Trie. +func (t *Trie) Copy() *Trie { + return &Trie{ + root: t.root, + owner: t.owner, + unhashed: t.unhashed, + reader: t.reader, + tracer: t.tracer.copy(), } - if root != (common.Hash{}) && root != emptyRoot { - rootnode, err := trie.resolveHash(root[:], nil) +} + +// New creates a trie instance with the provided trie id and the read-only +// database. The state specified by trie id must be available, otherwise +// an error will be returned. The trie root specified by trie id can be +// zero hash or the sha3 hash of an empty string, then trie is initially +// empty, otherwise, the root node must be present in database or returns +// a MissingNodeError if not. 
+func New(id *ID, db NodeReader, codec uint64) (*Trie, error) { + reader, err := newTrieReader(id.StateRoot, id.Owner, db, codec) + if err != nil { + return nil, err + } + trie := &Trie{ + owner: id.Owner, + reader: reader, + tracer: newTracer(), + } + if id.Root != (common.Hash{}) && id.Root != types.EmptyRootHash { + rootnode, err := trie.resolveAndTrack(id.Root[:], nil) if err != nil { return nil, err } @@ -70,7 +86,10 @@ func New(owner common.Hash, root common.Hash, db *Database, codec uint64) (*Trie // NewEmpty is a shortcut to create empty tree. It's mostly used in tests. func NewEmpty(db *Database) *Trie { - tr, _ := New(common.Hash{}, common.Hash{}, db, ipld.MEthStateTrie) + tr, err := New(TrieID(common.Hash{}), db, StateTrieCodec) + if err != nil { + panic(err) + } return tr } @@ -80,6 +99,16 @@ func (t *Trie) NodeIterator(start []byte) NodeIterator { return newNodeIterator(t, start) } +// Get returns the value for key stored in the trie. +// The value bytes must not be modified by the caller. +func (t *Trie) Get(key []byte) []byte { + res, err := t.TryGet(key) + if err != nil { + log.Error("Unhandled trie error in Trie.Get", "err", err) + } + return res +} + // TryGet returns the value for key stored in the trie. // The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. @@ -116,7 +145,7 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode } return value, n, didResolve, err case hashNode: - child, err := t.resolveHash(n, key[:pos]) + child, err := t.resolveAndTrack(n, key[:pos]) if err != nil { return nil, n, true, err } @@ -130,7 +159,7 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode // TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not // possible to use keybyte-encoding as the path might contain odd nibbles. 
func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) { - item, newroot, resolved, err := t.tryGetNode(t.root, CompactToHex(path), 0) + item, newroot, resolved, err := t.tryGetNode(t.root, compactToHex(path), 0) if err != nil { return nil, resolved, err } @@ -162,11 +191,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new if hash == nil { return nil, origNode, 0, errors.New("non-consensus node") } - cid, err := util.Keccak256ToCid(t.codec, hash) - if err != nil { - return nil, origNode, 0, err - } - blob, err := t.db.Node(cid) + blob, err := t.reader.nodeBlob(path, common.BytesToHash(hash)) return blob, origNode, 1, err } // Path still needs to be traversed, descend into children @@ -196,7 +221,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new return item, n, resolved, err case hashNode: - child, err := t.resolveHash(n, path[:pos]) + child, err := t.resolveAndTrack(n, path[:pos]) if err != nil { return nil, n, 1, err } @@ -208,42 +233,309 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new } } -// resolveHash loads node from the underlying database with the provided -// node hash and path prefix. -func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { - cid, err := util.Keccak256ToCid(t.codec, n) - if err != nil { - return nil, err +// Update associates key with value in the trie. Subsequent calls to +// Get will return value. If value has length zero, any existing value +// is deleted from the trie and calls to Get will return nil. +// +// The value bytes must not be modified by the caller while they are +// stored in the trie. 
+func (t *Trie) Update(key, value []byte) { + if err := t.TryUpdate(key, value); err != nil { + log.Error("Unhandled trie error in Trie.Update", "err", err) } - enc, err := t.db.Node(cid) - if err != nil { - return nil, &MissingNodeError{Owner: t.owner, NodeHash: n, Path: prefix, err: err} - } - node, err := decodeNodeUnsafe(n, enc) - if err != nil { - return nil, err - } - if node != nil { - return node, nil - } - return nil, &MissingNodeError{Owner: t.owner, NodeHash: n, Path: prefix} } -// resolveHash loads rlp-encoded node blob from the underlying database -// with the provided node hash and path prefix. -func (t *Trie) resolveBlob(n hashNode, prefix []byte) ([]byte, error) { - cid, err := util.Keccak256ToCid(t.codec, n) +// TryUpdate associates key with value in the trie. Subsequent calls to +// Get will return value. If value has length zero, any existing value +// is deleted from the trie and calls to Get will return nil. +// +// The value bytes must not be modified by the caller while they are +// stored in the trie. +// +// If a node was not found in the database, a MissingNodeError is returned. +func (t *Trie) TryUpdate(key, value []byte) error { + return t.tryUpdate(key, value) +} + +// tryUpdate expects an RLP-encoded value and performs the core function +// for TryUpdate and TryUpdateAccount. 
+func (t *Trie) tryUpdate(key, value []byte) error { + t.unhashed++ + k := keybytesToHex(key) + if len(value) != 0 { + _, n, err := t.insert(t.root, nil, k, valueNode(value)) + if err != nil { + return err + } + t.root = n + } else { + _, n, err := t.delete(t.root, nil, k) + if err != nil { + return err + } + t.root = n + } + return nil +} + +func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error) { + if len(key) == 0 { + if v, ok := n.(valueNode); ok { + return !bytes.Equal(v, value.(valueNode)), value, nil + } + return true, value, nil + } + switch n := n.(type) { + case *shortNode: + matchlen := prefixLen(key, n.Key) + // If the whole key matches, keep this short node as is + // and only update the value. + if matchlen == len(n.Key) { + dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value) + if !dirty || err != nil { + return false, n, err + } + return true, &shortNode{n.Key, nn, t.newFlag()}, nil + } + // Otherwise branch out at the index where they differ. + branch := &fullNode{flags: t.newFlag()} + var err error + _, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val) + if err != nil { + return false, nil, err + } + _, branch.Children[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value) + if err != nil { + return false, nil, err + } + // Replace this shortNode with the branch if it occurs at index 0. + if matchlen == 0 { + return true, branch, nil + } + // New branch node is created as a child of the original short node. + // Track the newly inserted node in the tracer. The node identifier + // passed is the path from the root node. + t.tracer.onInsert(append(prefix, key[:matchlen]...)) + + // Replace it with a short node leading up to the branch. 
+ return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil + + case *fullNode: + dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value) + if !dirty || err != nil { + return false, n, err + } + n = n.copy() + n.flags = t.newFlag() + n.Children[key[0]] = nn + return true, n, nil + + case nil: + // New short node is created and track it in the tracer. The node identifier + // passed is the path from the root node. Note the valueNode won't be tracked + // since it's always embedded in its parent. + t.tracer.onInsert(prefix) + + return true, &shortNode{key, value, t.newFlag()}, nil + + case hashNode: + // We've hit a part of the trie that isn't loaded yet. Load + // the node and insert into it. This leaves all child nodes on + // the path to the value in the trie. + rn, err := t.resolveAndTrack(n, prefix) + if err != nil { + return false, nil, err + } + dirty, nn, err := t.insert(rn, prefix, key, value) + if !dirty || err != nil { + return false, rn, err + } + return true, nn, nil + + default: + panic(fmt.Sprintf("%T: invalid node: %v", n, n)) + } +} + +// Delete removes any existing value for key from the trie. +func (t *Trie) Delete(key []byte) { + if err := t.TryDelete(key); err != nil { + log.Error("Unhandled trie error in Trie.Delete", "err", err) + } +} + +// TryDelete removes any existing value for key from the trie. +// If a node was not found in the database, a MissingNodeError is returned. +func (t *Trie) TryDelete(key []byte) error { + t.unhashed++ + k := keybytesToHex(key) + _, n, err := t.delete(t.root, nil, k) + if err != nil { + return err + } + t.root = n + return nil +} + +// delete returns the new root of the trie with key deleted. +// It reduces the trie to minimal form by simplifying +// nodes on the way up after deleting recursively. 
+func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { + switch n := n.(type) { + case *shortNode: + matchlen := prefixLen(key, n.Key) + if matchlen < len(n.Key) { + return false, n, nil // don't replace n on mismatch + } + if matchlen == len(key) { + // The matched short node is deleted entirely and track + // it in the deletion set. The same the valueNode doesn't + // need to be tracked at all since it's always embedded. + t.tracer.onDelete(prefix) + + return true, nil, nil // remove n entirely for whole matches + } + // The key is longer than n.Key. Remove the remaining suffix + // from the subtrie. Child can never be nil here since the + // subtrie must contain at least two other values with keys + // longer than n.Key. + dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):]) + if !dirty || err != nil { + return false, n, err + } + switch child := child.(type) { + case *shortNode: + // The child shortNode is merged into its parent, track + // is deleted as well. + t.tracer.onDelete(append(prefix, n.Key...)) + + // Deleting from the subtrie reduced it to another + // short node. Merge the nodes to avoid creating a + // shortNode{..., shortNode{...}}. Use concat (which + // always creates a new slice) instead of append to + // avoid modifying n.Key since it might be shared with + // other nodes. + return true, &shortNode{concat(n.Key, child.Key...), child.Val, t.newFlag()}, nil + default: + return true, &shortNode{n.Key, child, t.newFlag()}, nil + } + + case *fullNode: + dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:]) + if !dirty || err != nil { + return false, n, err + } + n = n.copy() + n.flags = t.newFlag() + n.Children[key[0]] = nn + + // Because n is a full node, it must've contained at least two children + // before the delete operation. 
If the new child value is non-nil, n still + // has at least two children after the deletion, and cannot be reduced to + // a short node. + if nn != nil { + return true, n, nil + } + // Reduction: + // Check how many non-nil entries are left after deleting and + // reduce the full node to a short node if only one entry is + // left. Since n must've contained at least two children + // before deletion (otherwise it would not be a full node) n + // can never be reduced to nil. + // + // When the loop is done, pos contains the index of the single + // value that is left in n or -2 if n contains at least two + // values. + pos := -1 + for i, cld := range &n.Children { + if cld != nil { + if pos == -1 { + pos = i + } else { + pos = -2 + break + } + } + } + if pos >= 0 { + if pos != 16 { + // If the remaining entry is a short node, it replaces + // n and its key gets the missing nibble tacked to the + // front. This avoids creating an invalid + // shortNode{..., shortNode{...}}. Since the entry + // might not be loaded yet, resolve it just for this + // check. + cnode, err := t.resolve(n.Children[pos], append(prefix, byte(pos))) + if err != nil { + return false, nil, err + } + if cnode, ok := cnode.(*shortNode); ok { + // Replace the entire full node with the short node. + // Mark the original short node as deleted since the + // value is embedded into the parent now. + t.tracer.onDelete(append(prefix, byte(pos))) + + k := append([]byte{byte(pos)}, cnode.Key...) + return true, &shortNode{k, cnode.Val, t.newFlag()}, nil + } + } + // Otherwise, n is replaced by a one-nibble short node + // containing the child. + return true, &shortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil + } + // n still contains at least two values and cannot be reduced. + return true, n, nil + + case valueNode: + return true, nil, nil + + case nil: + return false, nil, nil + + case hashNode: + // We've hit a part of the trie that isn't loaded yet. Load + // the node and delete from it. 
This leaves all child nodes on + // the path to the value in the trie. + rn, err := t.resolveAndTrack(n, prefix) + if err != nil { + return false, nil, err + } + dirty, nn, err := t.delete(rn, prefix, key) + if !dirty || err != nil { + return false, rn, err + } + return true, nn, nil + + default: + panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key)) + } +} + +func concat(s1 []byte, s2 ...byte) []byte { + r := make([]byte, len(s1)+len(s2)) + copy(r, s1) + copy(r[len(s1):], s2) + return r +} + +func (t *Trie) resolve(n node, prefix []byte) (node, error) { + if n, ok := n.(hashNode); ok { + return t.resolveAndTrack(n, prefix) + } + return n, nil +} + +// resolveAndTrack loads node from the underlying store with the given node hash +// and path prefix and also tracks the loaded node blob in tracer treated as the +// node's original value. The rlp-encoded blob is preferred to be loaded from +// database because it's easy to decode node while complex to encode node to blob. +func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) { + blob, err := t.reader.nodeBlob(prefix, common.BytesToHash(n)) if err != nil { return nil, err } - blob, err := t.db.Node(cid) - if err != nil { - return nil, err - } - if len(blob) != 0 { - return blob, nil - } - return nil, &MissingNodeError{Owner: t.owner, NodeHash: n, Path: prefix} + t.tracer.onRead(prefix, blob) + return mustDecodeNode(n, blob), nil } // Hash returns the root hash of the trie. It does not write to the @@ -254,15 +546,59 @@ func (t *Trie) Hash() common.Hash { return common.BytesToHash(hash.(hashNode)) } +// Commit collects all dirty nodes in the trie and replaces them with the +// corresponding node hash. All collected nodes (including dirty leaves if +// collectLeaf is true) will be encapsulated into a nodeset for return. +// The returned nodeset can be nil if the trie is clean (nothing to commit). +// Once the trie is committed, it's not usable anymore. 
A new trie must +// be created with new root and updated trie database for following usage +func (t *Trie) Commit(collectLeaf bool) (common.Hash, *NodeSet) { + defer t.tracer.reset() + + nodes := NewNodeSet(t.owner, t.tracer.accessList) + t.tracer.markDeletions(nodes) + + // Trie is empty and can be classified into two types of situations: + // - The trie was empty and no update happens + // - The trie was non-empty and all nodes are dropped + if t.root == nil { + return types.EmptyRootHash, nodes + } + // Derive the hash for all dirty nodes first. We hold the assumption + // in the following procedure that all nodes are hashed. + rootHash := t.Hash() + + // Do a quick check if we really need to commit. This can happen e.g. + // if we load a trie for reading storage values, but don't write to it. + if hashedNode, dirty := t.root.cache(); !dirty { + // Replace the root node with the origin hash in order to + // ensure all resolved nodes are dropped after the commit. + t.root = hashedNode + return rootHash, nil + } + t.root = newCommitter(nodes, collectLeaf).Commit(t.root) + return rootHash, nodes +} + // hashRoot calculates the root hash of the given trie func (t *Trie) hashRoot() (node, node) { if t.root == nil { - return hashNode(emptyRoot.Bytes()), nil + return hashNode(types.EmptyRootHash.Bytes()), nil } // If the number of changes is below 100, we let one thread handle it h := newHasher(t.unhashed >= 100) - defer returnHasherToPool(h) + defer func() { + returnHasherToPool(h) + t.unhashed = 0 + }() hashed, cached := h.hash(t.root, true) - t.unhashed = 0 return hashed, cached } + +// Reset drops the referenced root node and cleans all internal state. 
+func (t *Trie) Reset() { + t.root = nil + t.owner = common.Hash{} + t.unhashed = 0 + t.tracer.reset() +} diff --git a/trie_by_cid/trie/trie_id.go b/trie_by_cid/trie/trie_id.go new file mode 100644 index 0000000..8ab490c --- /dev/null +++ b/trie_by_cid/trie/trie_id.go @@ -0,0 +1,55 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see + +package trie + +import "github.com/ethereum/go-ethereum/common" + +// ID is the identifier for uniquely identifying a trie. +type ID struct { + StateRoot common.Hash // The root of the corresponding state(block.root) + Owner common.Hash // The contract address hash which the trie belongs to + Root common.Hash // The root hash of trie +} + +// StateTrieID constructs an identifier for state trie with the provided state root. +func StateTrieID(root common.Hash) *ID { + return &ID{ + StateRoot: root, + Owner: common.Hash{}, + Root: root, + } +} + +// StorageTrieID constructs an identifier for storage trie which belongs to a certain +// state and contract specified by the stateRoot and owner. 
+func StorageTrieID(stateRoot common.Hash, owner common.Hash, root common.Hash) *ID { + return &ID{ + StateRoot: stateRoot, + Owner: owner, + Root: root, + } +} + +// TrieID constructs an identifier for a standard trie(not a second-layer trie) +// with provided root. It's mostly used in tests and some other tries like CHT trie. +func TrieID(root common.Hash) *ID { + return &ID{ + StateRoot: root, + Owner: common.Hash{}, + Root: root, + } +} diff --git a/trie_by_cid/trie/trie_reader.go b/trie_by_cid/trie/trie_reader.go new file mode 100644 index 0000000..b0a7fdd --- /dev/null +++ b/trie_by_cid/trie/trie_reader.go @@ -0,0 +1,106 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "fmt" + + "github.com/ethereum/go-ethereum/common" +) + +// Reader wraps the Node and NodeBlob method of a backing trie store. +type Reader interface { + // Node retrieves the trie node with the provided trie identifier, hexary + // node path and the corresponding node hash. + // No error will be returned if the node is not found. + Node(owner common.Hash, path []byte, hash common.Hash) (node, error) + + // NodeBlob retrieves the RLP-encoded trie node blob with the provided trie + // identifier, hexary node path and the corresponding node hash. 
+ // No error will be returned if the node is not found. + NodeBlob(owner common.Hash, path []byte, hash common.Hash) ([]byte, error) +} + +// NodeReader wraps all the necessary functions for accessing trie node. +type NodeReader interface { + // GetReader returns a reader for accessing all trie nodes with provided + // state root. Nil is returned in case the state is not available. + GetReader(root common.Hash, codec uint64) Reader +} + +// trieReader is a wrapper of the underlying node reader. It's not safe +// for concurrent usage. +type trieReader struct { + owner common.Hash + reader Reader + banned map[string]struct{} // Marker to prevent node from being accessed, for tests +} + +// newTrieReader initializes the trie reader with the given node reader. +func newTrieReader(stateRoot, owner common.Hash, db NodeReader, codec uint64) (*trieReader, error) { + reader := db.GetReader(stateRoot, codec) + if reader == nil { + return nil, fmt.Errorf("state not found #%x", stateRoot) + } + return &trieReader{owner: owner, reader: reader}, nil +} + +// newEmptyReader initializes the pure in-memory reader. All read operations +// should be forbidden and returns the MissingNodeError. +func newEmptyReader() *trieReader { + return &trieReader{} +} + +// node retrieves the trie node with the provided trie node information. +// An MissingNodeError will be returned in case the node is not found or +// any error is encountered. +func (r *trieReader) node(path []byte, hash common.Hash) (node, error) { + // Perform the logics in tests for preventing trie node access. 
+ if r.banned != nil { + if _, ok := r.banned[string(path)]; ok { + return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path} + } + } + if r.reader == nil { + return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path} + } + node, err := r.reader.Node(r.owner, path, hash) + if err != nil || node == nil { + return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err} + } + return node, nil +} + +// node retrieves the rlp-encoded trie node with the provided trie node +// information. An MissingNodeError will be returned in case the node is +// not found or any error is encountered. +func (r *trieReader) nodeBlob(path []byte, hash common.Hash) ([]byte, error) { + // Perform the logics in tests for preventing trie node access. + if r.banned != nil { + if _, ok := r.banned[string(path)]; ok { + return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path} + } + } + if r.reader == nil { + return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path} + } + blob, err := r.reader.NodeBlob(r.owner, path, hash) + if err != nil || len(blob) == 0 { + return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err} + } + return blob, nil +} diff --git a/trie_by_cid/trie/trie_test.go b/trie_by_cid/trie/trie_test.go index 0418409..76aaf72 100644 --- a/trie_by_cid/trie/trie_test.go +++ b/trie_by_cid/trie/trie_test.go @@ -14,26 +14,34 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . 
-package trie_test +package trie import ( + "bytes" + "encoding/binary" + "errors" "fmt" "math/big" "math/rand" + "reflect" "testing" + "testing/quick" + "github.com/davecgh/go-spew/spew" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" - geth_trie "github.com/ethereum/go-ethereum/trie" - - "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" ) -func TestTrieEmpty(t *testing.T) { - trie := trie.NewEmpty(trie.NewDatabase(rawdb.NewMemoryDatabase())) +func init() { + spew.Config.Indent = " " + spew.Config.DisableMethods = false +} + +func TestEmptyTrie(t *testing.T) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) res := trie.Hash() exp := types.EmptyRootHash if res != exp { @@ -41,42 +49,543 @@ func TestTrieEmpty(t *testing.T) { } } -func TestTrieMissingRoot(t *testing.T) { +func TestNull(t *testing.T) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + key := make([]byte, 32) + value := []byte("test") + trie.Update(key, value) + if !bytes.Equal(trie.Get(key), value) { + t.Fatal("wrong value") + } +} + +func TestMissingRoot(t *testing.T) { root := common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33") - tr, err := newStateTrie(root, trie.NewDatabase(rawdb.NewMemoryDatabase())) - if tr != nil { + trie, err := NewAccountTrie(TrieID(root), NewDatabase(rawdb.NewMemoryDatabase())) + if trie != nil { t.Error("New returned non-nil trie for invalid root") } - if _, ok := err.(*trie.MissingNodeError); !ok { + if _, ok := err.(*MissingNodeError); !ok { t.Errorf("New returned wrong error: %v", err) } } -func TestTrieBasic(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - origtrie := geth_trie.NewEmpty(db) - origtrie.Update([]byte("foo"), packValue(842)) - expected := commitTrie(t, db, origtrie) - tr := indexTrie(t, edb, expected) - got := tr.Hash() - 
if expected != got { - t.Errorf("got %x expected %x", got, expected) +func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) } + +func testMissingNode(t *testing.T, memonly bool) { + diskdb := rawdb.NewMemoryDatabase() + triedb := NewDatabase(diskdb) + + trie := NewEmpty(triedb) + updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer") + updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf") + root, nodes := trie.Commit(false) + triedb.Update(NewWithNodeSet(nodes)) + + trie, _ = NewAccountTrie(TrieID(root), triedb) + _, err := trie.TryGet([]byte("120000")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + trie, _ = NewAccountTrie(TrieID(root), triedb) + _, err = trie.TryGet([]byte("120099")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + trie, _ = NewAccountTrie(TrieID(root), triedb) + _, err = trie.TryGet([]byte("123456")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + trie, _ = NewAccountTrie(TrieID(root), triedb) + err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + trie, _ = NewAccountTrie(TrieID(root), triedb) + err = trie.TryDelete([]byte("123456")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + hash := common.HexToHash("0xe1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9") + if memonly { + delete(triedb.dirties, hash) + } else { + diskdb.Delete(hash[:]) + } + + trie, _ = NewAccountTrie(TrieID(root), triedb) + _, err = trie.TryGet([]byte("120000")) + if _, ok := err.(*MissingNodeError); !ok { + t.Errorf("Wrong error: %v", err) + } + trie, _ = NewAccountTrie(TrieID(root), triedb) + _, err = trie.TryGet([]byte("120099")) + if _, ok := err.(*MissingNodeError); !ok { + t.Errorf("Wrong error: %v", err) + } + trie, _ = NewAccountTrie(TrieID(root), triedb) + _, err = trie.TryGet([]byte("123456")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + 
trie, _ = NewAccountTrie(TrieID(root), triedb) + err = trie.TryUpdate([]byte("120099"), []byte("zxcv")) + if _, ok := err.(*MissingNodeError); !ok { + t.Errorf("Wrong error: %v", err) + } + trie, _ = NewAccountTrie(TrieID(root), triedb) + err = trie.TryDelete([]byte("123456")) + if _, ok := err.(*MissingNodeError); !ok { + t.Errorf("Wrong error: %v", err) } - checkValue(t, tr, []byte("foo")) } -func TestTrieTiny(t *testing.T) { +func TestInsert(t *testing.T) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + + updateString(trie, "doe", "reindeer") + updateString(trie, "dog", "puppy") + updateString(trie, "dogglesworth", "cat") + + exp := common.HexToHash("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3") + root := trie.Hash() + if root != exp { + t.Errorf("case 1: exp %x got %x", exp, root) + } + + trie = NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + updateString(trie, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") + + exp = common.HexToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab") + root, _ = trie.Commit(false) + if root != exp { + t.Errorf("case 2: exp %x got %x", exp, root) + } +} + +func TestGet(t *testing.T) { + db := NewDatabase(rawdb.NewMemoryDatabase()) + trie := NewEmpty(db) + updateString(trie, "doe", "reindeer") + updateString(trie, "dog", "puppy") + updateString(trie, "dogglesworth", "cat") + + for i := 0; i < 2; i++ { + res := getString(trie, "dog") + if !bytes.Equal(res, []byte("puppy")) { + t.Errorf("expected puppy got %x", res) + } + unknown := getString(trie, "unknown") + if unknown != nil { + t.Errorf("expected nil got %x", unknown) + } + if i == 1 { + return + } + root, nodes := trie.Commit(false) + db.Update(NewWithNodeSet(nodes)) + trie, _ = NewAccountTrie(TrieID(root), db) + } +} + +func TestDelete(t *testing.T) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", 
"stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"ether", ""}, + {"dog", "puppy"}, + {"shaman", ""}, + } + for _, val := range vals { + if val.v != "" { + updateString(trie, val.k, val.v) + } else { + deleteString(trie, val.k) + } + } + + hash := trie.Hash() + exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") + if hash != exp { + t.Errorf("expected %x got %x", exp, hash) + } +} + +func TestEmptyValues(t *testing.T) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"ether", ""}, + {"dog", "puppy"}, + {"shaman", ""}, + } + for _, val := range vals { + updateString(trie, val.k, val.v) + } + + hash := trie.Hash() + exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") + if hash != exp { + t.Errorf("expected %x got %x", exp, hash) + } +} + +func TestReplication(t *testing.T) { + triedb := NewDatabase(rawdb.NewMemoryDatabase()) + trie := NewEmpty(triedb) + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"dog", "puppy"}, + {"somethingveryoddindeedthis is", "myothernodedata"}, + } + for _, val := range vals { + updateString(trie, val.k, val.v) + } + exp, nodes := trie.Commit(false) + triedb.Update(NewWithNodeSet(nodes)) + + // create a new trie on top of the database and check that lookups work. + trie2, err := NewAccountTrie(TrieID(exp), triedb) + if err != nil { + t.Fatalf("can't recreate trie at %x: %v", exp, err) + } + for _, kv := range vals { + if string(getString(trie2, kv.k)) != kv.v { + t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v) + } + } + hash, nodes := trie2.Commit(false) + if hash != exp { + t.Errorf("root failure. 
expected %x got %x", exp, hash) + } + + // recreate the trie after commit + if nodes != nil { + triedb.Update(NewWithNodeSet(nodes)) + } + trie2, err = NewAccountTrie(TrieID(hash), triedb) + if err != nil { + t.Fatalf("can't recreate trie at %x: %v", exp, err) + } + // perform some insertions on the new trie. + vals2 := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + // {"shaman", "horse"}, + // {"doge", "coin"}, + // {"ether", ""}, + // {"dog", "puppy"}, + // {"somethingveryoddindeedthis is", "myothernodedata"}, + // {"shaman", ""}, + } + for _, val := range vals2 { + updateString(trie2, val.k, val.v) + } + if hash := trie2.Hash(); hash != exp { + t.Errorf("root failure. expected %x got %x", exp, hash) + } +} + +func TestLargeValue(t *testing.T) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + trie.Update([]byte("key1"), []byte{99, 99, 99, 99}) + trie.Update([]byte("key2"), bytes.Repeat([]byte{1}, 32)) + trie.Hash() +} + +// TestRandomCases tests some cases that were found via random fuzzing +func TestRandomCases(t *testing.T) { + var rt = []randTestStep{ + {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 0 + {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 1 + {op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000002")}, // step 2 + {op: 2, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("")}, // step 3 + {op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 4 + {op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 5 + {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 6 + {op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 7 + {op: 0, key: 
common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("0000000000000008")}, // step 8 + {op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000009")}, // step 9 + {op: 2, key: common.Hex2Bytes("fd"), value: common.Hex2Bytes("")}, // step 10 + {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 11 + {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 12 + {op: 0, key: common.Hex2Bytes("fd"), value: common.Hex2Bytes("000000000000000d")}, // step 13 + {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 14 + {op: 1, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("")}, // step 15 + {op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 16 + {op: 0, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("0000000000000011")}, // step 17 + {op: 5, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 18 + {op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 19 + {op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000014")}, // step 20 + {op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000015")}, // step 21 + {op: 0, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("0000000000000016")}, // step 22 + {op: 5, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 23 + {op: 1, key: common.Hex2Bytes("980c393656413a15c8da01978ed9f89feb80b502f58f2d640e3a2f5f7a99a7018f1b573befd92053ac6f78fca4a87268"), value: 
common.Hex2Bytes("")}, // step 24 + {op: 1, key: common.Hex2Bytes("fd"), value: common.Hex2Bytes("")}, // step 25 + } + runRandTest(rt) +} + +// randTest performs random trie operations. +// Instances of this test are created by Generate. +type randTest []randTestStep + +type randTestStep struct { + op int + key []byte // for opUpdate, opDelete, opGet + value []byte // for opUpdate + err error // for debugging +} + +const ( + opUpdate = iota + opDelete + opGet + opHash + opCommit + opItercheckhash + opNodeDiff + opProve + opMax // boundary value, not an actual op +) + +func (randTest) Generate(r *rand.Rand, size int) reflect.Value { + var allKeys [][]byte + genKey := func() []byte { + if len(allKeys) < 2 || r.Intn(100) < 10 { + // new key + key := make([]byte, r.Intn(50)) + r.Read(key) + allKeys = append(allKeys, key) + return key + } + // use existing key + return allKeys[r.Intn(len(allKeys))] + } + + var steps randTest + for i := 0; i < size; i++ { + step := randTestStep{op: r.Intn(opMax)} + switch step.op { + case opUpdate: + step.key = genKey() + step.value = make([]byte, 8) + binary.BigEndian.PutUint64(step.value, uint64(i)) + case opGet, opDelete, opProve: + step.key = genKey() + } + steps = append(steps, step) + } + return reflect.ValueOf(steps) +} + +func verifyAccessList(old *Trie, new *Trie, set *NodeSet) error { + deletes, inserts, updates := diffTries(old, new) + + // Check insertion set + for path := range inserts { + n, ok := set.nodes[path] + if !ok || n.isDeleted() { + return errors.New("expect new node") + } + _, ok = set.accessList[path] + if ok { + return errors.New("unexpected origin value") + } + } + // Check deletion set + for path, blob := range deletes { + n, ok := set.nodes[path] + if !ok || !n.isDeleted() { + return errors.New("expect deleted node") + } + v, ok := set.accessList[path] + if !ok { + return errors.New("expect origin value") + } + if !bytes.Equal(v, blob) { + return errors.New("invalid origin value") + } + } + // Check update 
set + for path, blob := range updates { + n, ok := set.nodes[path] + if !ok || n.isDeleted() { + return errors.New("expect updated node") + } + v, ok := set.accessList[path] + if !ok { + return errors.New("expect origin value") + } + if !bytes.Equal(v, blob) { + return errors.New("invalid origin value") + } + } + return nil +} + +func runRandTest(rt randTest) bool { + var ( + triedb = NewDatabase(rawdb.NewMemoryDatabase()) + tr = NewEmpty(triedb) + values = make(map[string]string) // tracks content of the trie + origTrie = NewEmpty(triedb) + ) + for i, step := range rt { + // fmt.Printf("{op: %d, key: common.Hex2Bytes(\"%x\"), value: common.Hex2Bytes(\"%x\")}, // step %d\n", + // step.op, step.key, step.value, i) + + switch step.op { + case opUpdate: + tr.Update(step.key, step.value) + values[string(step.key)] = string(step.value) + case opDelete: + tr.Delete(step.key) + delete(values, string(step.key)) + case opGet: + v := tr.Get(step.key) + want := values[string(step.key)] + if string(v) != want { + rt[i].err = fmt.Errorf("mismatch for key %#x, got %#x want %#x", step.key, v, want) + } + case opProve: + hash := tr.Hash() + if hash == types.EmptyRootHash { + continue + } + proofDb := rawdb.NewMemoryDatabase() + err := tr.Prove(step.key, 0, proofDb) + if err != nil { + rt[i].err = fmt.Errorf("failed for proving key %#x, %v", step.key, err) + } + _, err = VerifyProof(hash, step.key, proofDb) + if err != nil { + rt[i].err = fmt.Errorf("failed for verifying key %#x, %v", step.key, err) + } + case opHash: + tr.Hash() + case opCommit: + root, nodes := tr.Commit(true) + if nodes != nil { + triedb.Update(NewWithNodeSet(nodes)) + } + newtr, err := NewAccountTrie(TrieID(root), triedb) + if err != nil { + rt[i].err = err + return false + } + if nodes != nil { + if err := verifyAccessList(origTrie, newtr, nodes); err != nil { + rt[i].err = err + return false + } + } + tr = newtr + origTrie = tr.Copy() + case opItercheckhash: + checktr := NewEmpty(triedb) + it := 
NewIterator(tr.NodeIterator(nil)) + for it.Next() { + checktr.Update(it.Key, it.Value) + } + if tr.Hash() != checktr.Hash() { + rt[i].err = fmt.Errorf("hash mismatch in opItercheckhash") + } + case opNodeDiff: + var ( + origIter = origTrie.NodeIterator(nil) + curIter = tr.NodeIterator(nil) + origSeen = make(map[string]struct{}) + curSeen = make(map[string]struct{}) + ) + for origIter.Next(true) { + if origIter.Leaf() { + continue + } + origSeen[string(origIter.Path())] = struct{}{} + } + for curIter.Next(true) { + if curIter.Leaf() { + continue + } + curSeen[string(curIter.Path())] = struct{}{} + } + var ( + insertExp = make(map[string]struct{}) + deleteExp = make(map[string]struct{}) + ) + for path := range curSeen { + _, present := origSeen[path] + if !present { + insertExp[path] = struct{}{} + } + } + for path := range origSeen { + _, present := curSeen[path] + if !present { + deleteExp[path] = struct{}{} + } + } + if len(insertExp) != len(tr.tracer.inserts) { + rt[i].err = fmt.Errorf("insert set mismatch") + } + if len(deleteExp) != len(tr.tracer.deletes) { + rt[i].err = fmt.Errorf("delete set mismatch") + } + for insert := range tr.tracer.inserts { + if _, present := insertExp[insert]; !present { + rt[i].err = fmt.Errorf("missing inserted node") + } + } + for del := range tr.tracer.deletes { + if _, present := deleteExp[del]; !present { + rt[i].err = fmt.Errorf("missing deleted node") + } + } + } + // Abort the test on error. 
+ if rt[i].err != nil { + return false + } + } + return true +} + +func TestRandom(t *testing.T) { + if err := quick.Check(runRandTest, nil); err != nil { + if cerr, ok := err.(*quick.CheckError); ok { + t.Fatalf("random test iteration %d failed: %s", cerr.Count, spew.Sdump(cerr.In)) + } + t.Fatal(err) + } +} + +func TestTinyTrie(t *testing.T) { // Create a realistic account trie to hash _, accounts := makeAccounts(5) - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - origtrie := geth_trie.NewEmpty(db) + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) type testCase struct { key, account []byte root common.Hash } + cases := []testCase{ { common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001337"), @@ -92,48 +601,48 @@ func TestTrieTiny(t *testing.T) { common.HexToHash("0608c1d1dc3905fa22204c7a0e43644831c3b6d3def0f274be623a948197e64a"), }, } - for i, tc := range cases { - t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) { - origtrie.Update(tc.key, tc.account) - trie := indexTrie(t, edb, commitTrie(t, db, origtrie)) - if exp, root := tc.root, trie.Hash(); exp != root { - t.Errorf("got %x, exp %x", root, exp) - } - checkValue(t, trie, tc.key) - }) + for i, c := range cases { + trie.Update(c.key, c.account) + root := trie.Hash() + if root != c.root { + t.Errorf("case %d: got %x, exp %x", i, root, c.root) + } + } + checktr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + it := NewIterator(trie.NodeIterator(nil)) + for it.Next() { + checktr.Update(it.Key, it.Value) + } + if troot, itroot := trie.Hash(), checktr.Hash(); troot != itroot { + t.Fatalf("hash mismatch in opItercheckhash, trie: %x, check: %x", troot, itroot) } } -func TestTrieMedium(t *testing.T) { +func TestCommitAfterHash(t *testing.T) { // Create a realistic account trie to hash addresses, accounts := makeAccounts(1000) - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - origtrie := geth_trie.NewEmpty(db) - var keys [][]byte + 
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) for i := 0; i < len(addresses); i++ { - key := crypto.Keccak256(addresses[i][:]) - if i%50 == 0 { - keys = append(keys, key) - } - origtrie.Update(key, accounts[i]) + trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) } - tr := indexTrie(t, edb, commitTrie(t, db, origtrie)) - - root := tr.Hash() + // Insert the accounts into the trie and hash it + trie.Hash() + trie.Commit(false) + root := trie.Hash() exp := common.HexToHash("72f9d3f3fe1e1dd7b8936442e7642aef76371472d94319900790053c493f3fe6") if exp != root { t.Errorf("got %x, exp %x", root, exp) } - - for _, key := range keys { - checkValue(t, tr, key) + root, _ = trie.Commit(false) + if exp != root { + t.Errorf("got %x, exp %x", root, exp) } } -// Make deterministically random accounts func makeAccounts(size int) (addresses [][20]byte, accounts [][]byte) { + // Make the random benchmark deterministic random := rand.New(rand.NewSource(0)) + // Create a realistic account trie to hash addresses = make([][20]byte, size) for i := 0; i < len(addresses); i++ { data := make([]byte, 20) @@ -149,25 +658,40 @@ func makeAccounts(size int) (addresses [][20]byte, accounts [][]byte) { ) // The big.Rand function is not deterministic with regards to 64 vs 32 bit systems, // and will consume different amount of data from the rand source. 
- // balance = new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) + //balance = new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) // Therefore, we instead just read via byte buffer numBytes := random.Uint32() % 33 // [0, 32] bytes balanceBytes := make([]byte, numBytes) random.Read(balanceBytes) balance := new(big.Int).SetBytes(balanceBytes) - acct := &types.StateAccount{Nonce: nonce, Balance: balance, Root: root, CodeHash: code} - data, _ := rlp.EncodeToBytes(acct) + data, _ := rlp.EncodeToBytes(&types.StateAccount{Nonce: nonce, Balance: balance, Root: root, CodeHash: code}) accounts[i] = data } return addresses, accounts } -func checkValue(t *testing.T, tr *trie.Trie, key []byte) { - val, err := tr.TryGet(key) - if err != nil { - t.Fatalf("error getting node: %s", err) - } - if len(val) == 0 { - t.Errorf("failed to get value for %x", key) +func getString(trie *Trie, k string) []byte { + return trie.Get([]byte(k)) +} + +func updateString(trie *Trie, k, v string) { + trie.Update([]byte(k), []byte(v)) +} + +func deleteString(trie *Trie, k string) { + trie.Delete([]byte(k)) +} + +func TestDecodeNode(t *testing.T) { + t.Parallel() + + var ( + hash = make([]byte, 20) + elems = make([]byte, 20) + ) + for i := 0; i < 5000000; i++ { + prng.Read(hash) + prng.Read(elems) + decodeNode(hash, elems) } } diff --git a/trie_by_cid/trie/util_test.go b/trie_by_cid/trie/util_test.go index 3272826..a75ac8c 100644 --- a/trie_by_cid/trie/util_test.go +++ b/trie_by_cid/trie/util_test.go @@ -1,43 +1,133 @@ -package trie_test +package trie import ( + "bytes" + "context" "fmt" "math/big" "math/rand" "testing" - "time" - - "github.com/jmoiron/sqlx" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" - geth_state "github.com/ethereum/go-ethereum/core/state" + gethstate "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" 
"github.com/ethereum/go-ethereum/rlp" - geth_trie "github.com/ethereum/go-ethereum/trie" + gethtrie "github.com/ethereum/go-ethereum/trie" + "github.com/jmoiron/sqlx" pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0" - "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/helper" - "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/state" - "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" "github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres" - "github.com/ethereum/go-ethereum/statediff/indexer/ipld" "github.com/ethereum/go-ethereum/statediff/test_helpers" + + "github.com/cerc-io/ipld-eth-statedb/internal" + "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/helper" ) -type kv struct { +var ( + dbConfig, _ = postgres.DefaultConfig.WithEnv() + trieConfig = Config{Cache: 256} +) + +type kvi struct { k []byte v int64 } -type kvMap map[string]*kv +type kvMap map[string]*kvi -type kvs struct { +type kvsi struct { k string v int64 } +// NewAccountTrie is a shortcut to create a trie using the StateTrieCodec (ie. IPLD MEthStateTrie codec). +func NewAccountTrie(id *ID, db NodeReader) (*Trie, error) { + return New(id, db, StateTrieCodec) +} + +// makeTestTrie create a sample test trie to test node-wise reconstruction. 
+func makeTestTrie(t testing.TB) (*Database, *StateTrie, map[string][]byte) { + // Create an empty trie + triedb := NewDatabase(rawdb.NewMemoryDatabase()) + trie, err := NewStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec) + if err != nil { + t.Fatal(err) + } + + // Fill it with some arbitrary data + content := make(map[string][]byte) + for i := byte(0); i < 255; i++ { + // Map the same data under multiple keys + key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i} + content[string(key)] = val + trie.Update(key, val) + + key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i} + content[string(key)] = val + trie.Update(key, val) + + // Add some other data to inflate the trie + for j := byte(3); j < 13; j++ { + key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i} + content[string(key)] = val + trie.Update(key, val) + } + } + root, nodes := trie.Commit(false) + if err := triedb.Update(NewWithNodeSet(nodes)); err != nil { + panic(fmt.Errorf("failed to commit db %v", err)) + } + // Re-create the trie based on the new state + trie, err = NewStateTrie(TrieID(root), triedb, StateTrieCodec) + if err != nil { + t.Fatal(err) + } + return triedb, trie, content +} + +func forHashedNodes(tr *Trie) map[string][]byte { + var ( + it = tr.NodeIterator(nil) + nodes = make(map[string][]byte) + ) + for it.Next(true) { + if it.Hash() == (common.Hash{}) { + continue + } + nodes[string(it.Path())] = common.CopyBytes(it.NodeBlob()) + } + return nodes +} + +func diffTries(trieA, trieB *Trie) (map[string][]byte, map[string][]byte, map[string][]byte) { + var ( + nodesA = forHashedNodes(trieA) + nodesB = forHashedNodes(trieB) + inA = make(map[string][]byte) // hashed nodes in trie a but not b + inB = make(map[string][]byte) // hashed nodes in trie b but not a + both = make(map[string][]byte) // hashed nodes in both tries but different value + ) + for path, blobA := range nodesA { + if blobB, ok := nodesB[path]; ok { + if bytes.Equal(blobA, blobB) { + continue + } + 
both[path] = blobA + continue + } + inA[path] = blobA + } + for path, blobB := range nodesB { + if _, ok := nodesA[path]; ok { + continue + } + inB[path] = blobB + } + return inA, inB, both +} + func packValue(val int64) []byte { acct := &types.StateAccount{ Balance: big.NewInt(val), @@ -51,27 +141,19 @@ func packValue(val int64) []byte { return acct_rlp } -func unpackValue(val []byte) int64 { - var acct types.StateAccount - if err := rlp.DecodeBytes(val, &acct); err != nil { - panic(err) - } - return acct.Balance.Int64() -} - -func updateTrie(tr *geth_trie.Trie, vals []kvs) (kvMap, error) { +func updateTrie(tr *gethtrie.Trie, vals []kvsi) (kvMap, error) { all := kvMap{} for _, val := range vals { - all[string(val.k)] = &kv{[]byte(val.k), val.v} + all[string(val.k)] = &kvi{[]byte(val.k), val.v} tr.Update([]byte(val.k), packValue(val.v)) } return all, nil } -func commitTrie(t testing.TB, db *geth_trie.Database, tr *geth_trie.Trie) common.Hash { +func commitTrie(t testing.TB, db *gethtrie.Database, tr *gethtrie.Trie) common.Hash { t.Helper() root, nodes := tr.Commit(false) - if err := db.Update(geth_trie.NewWithNodeSet(nodes)); err != nil { + if err := db.Update(gethtrie.NewWithNodeSet(nodes)); err != nil { t.Fatal(err) } if err := db.Commit(root, false); err != nil { @@ -80,16 +162,8 @@ func commitTrie(t testing.TB, db *geth_trie.Database, tr *geth_trie.Trie) common return root } -// commit a LevelDB state trie, index to IPLD and return new trie -func indexTrie(t testing.TB, edb ethdb.Database, root common.Hash) *trie.Trie { - t.Helper() - dbConfig.Driver = postgres.PGX - err := helper.IndexChain(dbConfig, geth_state.NewDatabase(edb), common.Hash{}, root) - if err != nil { - t.Fatal(err) - } - - pg_db, err := postgres.ConnectSQLX(ctx, dbConfig) +func makePgIpfsEthDB(t testing.TB) ethdb.Database { + pg_db, err := postgres.ConnectSQLX(context.Background(), dbConfig) if err != nil { t.Fatal(err) } @@ -98,62 +172,48 @@ func indexTrie(t testing.TB, edb ethdb.Database, 
root common.Hash) *trie.Trie { t.Fatal(err) } }) + return pgipfsethdb.NewDatabase(pg_db, internal.MakeCacheConfig(t)) +} - ipfs_db := pgipfsethdb.NewDatabase(pg_db, makeCacheConfig(t)) - sdb_db := state.NewDatabase(ipfs_db) - tr, err := newStateTrie(root, sdb_db.TrieDB()) +// commit a LevelDB state trie, index to IPLD and return new trie +func indexTrie(t testing.TB, edb ethdb.Database, root common.Hash) *Trie { + t.Helper() + dbConfig.Driver = postgres.PGX + err := helper.IndexStateDiff(dbConfig, gethstate.NewDatabase(edb), common.Hash{}, root) + if err != nil { + t.Fatal(err) + } + + ipfs_db := makePgIpfsEthDB(t) + tr, err := New(TrieID(root), NewDatabase(ipfs_db), StateTrieCodec) if err != nil { t.Fatal(err) } return tr } -func newStateTrie(root common.Hash, db *trie.Database) (*trie.Trie, error) { - tr, err := trie.New(common.Hash{}, root, db, ipld.MEthStateTrie) - if err != nil { - return nil, err - } - return tr, nil -} - // generates a random Geth LevelDB trie of n key-value pairs and corresponding value map -func randomGethTrie(n int, db *geth_trie.Database) (*geth_trie.Trie, kvMap) { - trie := geth_trie.NewEmpty(db) - var vals []*kv +func randomGethTrie(n int, db *gethtrie.Database) (*gethtrie.Trie, kvMap) { + trie := gethtrie.NewEmpty(db) + var vals []*kvi for i := byte(0); i < 100; i++ { - e := &kv{common.LeftPadBytes([]byte{i}, 32), int64(i)} - e2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), int64(i)} + e := &kvi{common.LeftPadBytes([]byte{i}, 32), int64(i)} + e2 := &kvi{common.LeftPadBytes([]byte{i + 10}, 32), int64(i)} vals = append(vals, e, e2) } for i := 0; i < n; i++ { k := randBytes(32) v := rand.Int63() - vals = append(vals, &kv{k, v}) + vals = append(vals, &kvi{k, v}) } all := kvMap{} for _, val := range vals { - all[string(val.k)] = &kv{[]byte(val.k), val.v} + all[string(val.k)] = &kvi{[]byte(val.k), val.v} trie.Update([]byte(val.k), packValue(val.v)) } return trie, all } -// generates a random IPLD-indexed trie -func randomTrie(t 
testing.TB, n int) (*trie.Trie, kvMap) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig, vals := randomGethTrie(n, db) - root := commitTrie(t, db, orig) - trie := indexTrie(t, edb, root) - return trie, vals -} - -func randBytes(n int) []byte { - r := make([]byte, n) - rand.Read(r) - return r -} - // TearDownDB is used to tear down the watcher dbs after tests func TearDownDB(db *sqlx.DB) error { tx, err := db.Beginx() @@ -179,12 +239,3 @@ func TearDownDB(db *sqlx.DB) error { } return tx.Commit() } - -// returns a cache config with unique name (groupcache names are global) -func makeCacheConfig(t testing.TB) pgipfsethdb.CacheConfig { - return pgipfsethdb.CacheConfig{ - Name: t.Name(), - Size: 3000000, // 3MB - ExpiryDuration: time.Hour, - } -}