diff --git a/direct_by_leaf/database.go b/direct_by_leaf/database.go
index a5d5058..b1d506b 100644
--- a/direct_by_leaf/database.go
+++ b/direct_by_leaf/database.go
@@ -4,10 +4,10 @@ import (
"context"
"errors"
"fmt"
- "math/big"
"github.com/VictoriaMetrics/fastcache"
lru "github.com/hashicorp/golang-lru"
+ "github.com/holiman/uint256"
"github.com/cerc-io/plugeth-statediff/indexer/ipld"
"github.com/ethereum/go-ethereum/common"
@@ -99,8 +99,10 @@ func (sd *cachingDB) StateAccount(addressHash, blockHash common.Hash) (*types.St
// TODO: check expected behavior for deleted/non existing accounts
return nil, nil
}
- bal := new(big.Int)
- bal.SetString(res.Balance, 10)
+ bal, err := uint256.FromDecimal(res.Balance)
+ if err != nil {
+ return nil, err
+ }
return &types.StateAccount{
Nonce: res.Nonce,
Balance: bal,
diff --git a/direct_by_leaf/journal.go b/direct_by_leaf/journal.go
index 0ea1fb5..4a2bb6d 100644
--- a/direct_by_leaf/journal.go
+++ b/direct_by_leaf/journal.go
@@ -1,9 +1,8 @@
package state
import (
- "math/big"
-
"github.com/ethereum/go-ethereum/common"
+ "github.com/holiman/uint256"
)
// journalEntry is a modification entry in the state change journal that can be
@@ -74,19 +73,26 @@ type (
account *common.Address
}
resetObjectChange struct {
+ account *common.Address
prev *stateObject
prevdestruct bool
+ prevAccount []byte
+ prevStorage map[common.Hash][]byte
+
+ prevAccountOriginExist bool
+ prevAccountOrigin []byte
+ prevStorageOrigin map[common.Hash][]byte
}
- suicideChange struct {
+ selfDestructChange struct {
account *common.Address
- prev bool // whether account had already suicided
- prevbalance *big.Int
+ prev bool // whether account had already self-destructed
+ prevbalance *uint256.Int
}
// Changes to individual accounts.
balanceChange struct {
account *common.Address
- prev *big.Int
+ prev *uint256.Int
}
nonceChange struct {
account *common.Address
@@ -141,21 +147,36 @@ func (ch createObjectChange) dirtied() *common.Address {
func (ch resetObjectChange) revert(s *StateDB) {
s.setStateObject(ch.prev)
+ if !ch.prevdestruct {
+ delete(s.stateObjectsDestruct, ch.prev.address)
+ }
+ if ch.prevAccount != nil {
+ s.accounts[ch.prev.addrHash] = ch.prevAccount
+ }
+ if ch.prevStorage != nil {
+ s.storages[ch.prev.addrHash] = ch.prevStorage
+ }
+ if ch.prevAccountOriginExist {
+ s.accountsOrigin[ch.prev.address] = ch.prevAccountOrigin
+ }
+ if ch.prevStorageOrigin != nil {
+ s.storagesOrigin[ch.prev.address] = ch.prevStorageOrigin
+ }
}
func (ch resetObjectChange) dirtied() *common.Address {
- return nil
+ return ch.account
}
-func (ch suicideChange) revert(s *StateDB) {
+func (ch selfDestructChange) revert(s *StateDB) {
obj := s.getStateObject(*ch.account)
if obj != nil {
- obj.suicided = ch.prev
+ obj.selfDestructed = ch.prev
obj.setBalance(ch.prevbalance)
}
}
-func (ch suicideChange) dirtied() *common.Address {
+func (ch selfDestructChange) dirtied() *common.Address {
return ch.account
}
diff --git a/direct_by_leaf/state_object.go b/direct_by_leaf/state_object.go
index eb42d8c..8c1398c 100644
--- a/direct_by_leaf/state_object.go
+++ b/direct_by_leaf/state_object.go
@@ -3,7 +3,6 @@ package state
import (
"bytes"
"fmt"
- "math/big"
"time"
"github.com/ethereum/go-ethereum/common"
@@ -11,15 +10,7 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/rlp"
-)
-
-var (
- // emptyRoot is the known root hash of an empty trie.
- // this is calculated as: emptyRoot = crypto.Keccak256(rlp.Encode([][]byte{}))
- // that is, the keccak356 hash of the rlp encoding of an empty trie node (empty byte slice array)
- emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
- // emptyCodeHash is the CodeHash for an EOA, for an account without contract code deployed
- emptyCodeHash = crypto.Keccak256(nil)
+ "github.com/holiman/uint256"
)
type Code []byte
@@ -53,72 +44,66 @@ func (s Storage) Copy() Storage {
// First you need to obtain a state object.
// Account values can be accessed and modified through the object.
type stateObject struct {
- address common.Address
- addrHash common.Hash // hash of ethereum address of the account
- blockHash common.Hash // hash of the block this state object exists at or is being applied on top of
- data types.StateAccount
db *StateDB
-
- // DB error.
- // State objects are used by the consensus core and VM which are
- // unable to deal with database-level errors. Any error that occurs
- // during a database read is memoized here and will eventually be returned
- // by StateDB.Commit.
- dbErr error
+ address common.Address
+ addrHash common.Hash // hash of ethereum address of the account
+ blockHash common.Hash // hash of the block this state object exists at or is being applied on top of
+ origin *types.StateAccount // Account original data without any change applied, nil means it was not existent
+ data types.StateAccount
// Write caches.
code Code // contract bytecode, which gets set when code is loaded
- originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction
+ originStorage Storage // Storage cache of original entries to dedup rewrites
pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block
- dirtyStorage Storage // Storage entries that have been modified in the current transaction execution
- fakeStorage Storage // Fake storage which constructed by caller for debugging purpose.
+ dirtyStorage Storage // Storage entries that have been modified in the current transaction execution, reset for every transaction
// Cache flags.
- // When an object is marked suicided it will be delete from the trie
- // during the "update" phase of the state transition.
dirtyCode bool // true if the code was updated
- suicided bool
- deleted bool
+
+ // Flag whether the account was marked as self-destructed. The self-destructed account
+ // is still accessible in the scope of same transaction.
+ selfDestructed bool
+
+ // Flag whether the account was marked as deleted. A self-destructed account
+ // or an account that is considered as empty will be marked as deleted at
+ // the end of transaction and no longer accessible anymore.
+ deleted bool
+
+ // Flag whether the object was created in the current transaction
+ created bool
}
// empty returns whether the account is considered empty.
func (s *stateObject) empty() bool {
- return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, emptyCodeHash)
+ return s.data.Nonce == 0 && s.data.Balance.IsZero() && bytes.Equal(s.data.CodeHash, types.EmptyCodeHash.Bytes())
}
// newObject creates a state object.
-func newObject(db *StateDB, address common.Address, data types.StateAccount, blockHash common.Hash) *stateObject {
- if data.Balance == nil {
- data.Balance = new(big.Int)
- }
- if data.CodeHash == nil {
- data.CodeHash = emptyCodeHash
- }
- if data.Root == (common.Hash{}) {
- data.Root = emptyRoot
+func newObject(db *StateDB, address common.Address, acct *types.StateAccount, blockHash common.Hash) *stateObject {
+ var (
+ origin = acct
+ created = acct == nil // true if the account was not existent
+ )
+ if acct == nil {
+ acct = types.NewEmptyStateAccount()
}
return &stateObject{
db: db,
address: address,
addrHash: crypto.Keccak256Hash(address[:]),
blockHash: blockHash,
- data: data,
+ origin: origin,
+ data: *acct,
originStorage: make(Storage),
pendingStorage: make(Storage),
dirtyStorage: make(Storage),
+ created: created,
}
}
-// setError remembers the first non-nil error it is called with.
-func (s *stateObject) setError(err error) {
- if s.dbErr == nil {
- s.dbErr = err
- }
-}
-
-func (s *stateObject) markSuicided() {
- s.suicided = true
+func (s *stateObject) markSelfdestructed() {
+ s.selfDestructed = true
}
func (s *stateObject) touch() {
@@ -133,46 +118,51 @@ func (s *stateObject) touch() {
}
// GetState retrieves a value from the account storage trie.
-func (s *stateObject) GetState(db StateDatabase, key common.Hash) common.Hash {
- // If the fake storage is set, only lookup the state here(in the debugging mode)
- if s.fakeStorage != nil {
- return s.fakeStorage[key]
- }
+func (s *stateObject) GetState(key common.Hash) common.Hash {
// If we have a dirty value for this state entry, return it
value, dirty := s.dirtyStorage[key]
if dirty {
return value
}
// Otherwise return the entry's original value
- return s.GetCommittedState(db, key)
+ return s.GetCommittedState(key)
}
// GetCommittedState retrieves a value from the committed account storage trie.
-func (s *stateObject) GetCommittedState(db StateDatabase, key common.Hash) common.Hash {
- // If the fake storage is set, only lookup the state here(in the debugging mode)
- if s.fakeStorage != nil {
- return s.fakeStorage[key]
+func (s *stateObject) GetCommittedState(key common.Hash) common.Hash {
+ // If we have a pending write, return that
+ if value, pending := s.pendingStorage[key]; pending {
+ return value
}
// If we have a pending write or clean cached, return that
if value, cached := s.originStorage[key]; cached {
return value
}
+ // If the object was destructed in *this* block (and potentially resurrected),
+ // the storage has been cleared out, and we should *not* consult the previous
+ // database about any storage values. The only possible alternatives are:
+ // 1) resurrect happened, and new slot values were set -- those should
+ // have been handled via pendingStorage above.
+ // 2) we don't have new values, and can deliver empty response back
+ if _, destructed := s.db.stateObjectsDestruct[s.address]; destructed {
+ return common.Hash{}
+ }
// If no live objects are available, load from database
start := time.Now()
keyHash := crypto.Keccak256Hash(key[:])
- enc, err := db.StorageValue(s.addrHash, keyHash, s.blockHash)
+ enc, err := s.db.db.StorageValue(s.addrHash, keyHash, s.blockHash)
if metrics.EnabledExpensive {
s.db.StorageReads += time.Since(start)
}
if err != nil {
- s.setError(err)
+ s.db.setError(err)
return common.Hash{}
}
var value common.Hash
if len(enc) > 0 {
_, content, _, err := rlp.Split(enc)
if err != nil {
- s.setError(err)
+ s.db.setError(err)
}
value.SetBytes(content)
}
@@ -181,14 +171,9 @@ func (s *stateObject) GetCommittedState(db StateDatabase, key common.Hash) commo
}
// SetState updates a value in account storage.
-func (s *stateObject) SetState(db StateDatabase, key, value common.Hash) {
- // If the fake storage is set, put the temporary state update here.
- if s.fakeStorage != nil {
- s.fakeStorage[key] = value
- return
- }
+func (s *stateObject) SetState(key, value common.Hash) {
// If the new value is the same as old, don't set
- prev := s.GetState(db, key)
+ prev := s.GetState(key)
if prev == value {
return
}
@@ -201,63 +186,78 @@ func (s *stateObject) SetState(db StateDatabase, key, value common.Hash) {
s.setState(key, value)
}
-// SetStorage replaces the entire state storage with the given one.
-//
-// After this function is called, all original state will be ignored and state
-// lookup only happens in the fake state storage.
-//
-// Note this function should only be used for debugging purpose.
-func (s *stateObject) SetStorage(storage map[common.Hash]common.Hash) {
- // Allocate fake storage if it's nil.
- if s.fakeStorage == nil {
- s.fakeStorage = make(Storage)
- }
- for key, value := range storage {
- s.fakeStorage[key] = value
- }
- // Don't bother journal since this function should only be used for
- // debugging and the `fake` storage won't be committed to database.
-}
-
func (s *stateObject) setState(key, value common.Hash) {
s.dirtyStorage[key] = value
}
+// finalise moves all dirty storage slots into the pending area to be hashed or
+// committed later. It is invoked at the end of every transaction.
+func (s *stateObject) finalise(prefetch bool) {
+ slotsToPrefetch := make([][]byte, 0, len(s.dirtyStorage))
+ for key, value := range s.dirtyStorage {
+ s.pendingStorage[key] = value
+ if value != s.originStorage[key] {
+ slotsToPrefetch = append(slotsToPrefetch, common.CopyBytes(key[:])) // Copy needed for closure
+ }
+ }
+ if len(s.dirtyStorage) > 0 {
+ s.dirtyStorage = make(Storage)
+ }
+}
+
// AddBalance adds amount to s's balance.
// It is used to add funds to the destination account of a transfer.
-func (s *stateObject) AddBalance(amount *big.Int) {
+func (s *stateObject) AddBalance(amount *uint256.Int) {
// EIP161: We must check emptiness for the objects such that the account
// clearing (0,0,0 objects) can take effect.
- if amount.Sign() == 0 {
+ if amount.IsZero() {
if s.empty() {
s.touch()
}
return
}
- s.SetBalance(new(big.Int).Add(s.Balance(), amount))
+ s.SetBalance(new(uint256.Int).Add(s.Balance(), amount))
}
// SubBalance removes amount from s's balance.
// It is used to remove funds from the origin account of a transfer.
-func (s *stateObject) SubBalance(amount *big.Int) {
- if amount.Sign() == 0 {
+func (s *stateObject) SubBalance(amount *uint256.Int) {
+ if amount.IsZero() {
return
}
- s.SetBalance(new(big.Int).Sub(s.Balance(), amount))
+ s.SetBalance(new(uint256.Int).Sub(s.Balance(), amount))
}
-func (s *stateObject) SetBalance(amount *big.Int) {
+func (s *stateObject) SetBalance(amount *uint256.Int) {
s.db.journal.append(balanceChange{
account: &s.address,
- prev: new(big.Int).Set(s.data.Balance),
+ prev: new(uint256.Int).Set(s.data.Balance),
})
s.setBalance(amount)
}
-func (s *stateObject) setBalance(amount *big.Int) {
+func (s *stateObject) setBalance(amount *uint256.Int) {
s.data.Balance = amount
}
+func (s *stateObject) deepCopy(db *StateDB) *stateObject {
+ obj := &stateObject{
+ db: db,
+ address: s.address,
+ addrHash: s.addrHash,
+ origin: s.origin,
+ data: s.data,
+ }
+ obj.code = s.code
+ obj.dirtyStorage = s.dirtyStorage.Copy()
+ obj.originStorage = s.originStorage.Copy()
+ obj.pendingStorage = s.pendingStorage.Copy()
+ obj.selfDestructed = s.selfDestructed
+ obj.dirtyCode = s.dirtyCode
+ obj.deleted = s.deleted
+ return obj
+}
+
//
// Attribute accessors
//
@@ -268,16 +268,16 @@ func (s *stateObject) Address() common.Address {
}
// Code returns the contract code associated with this object, if any.
-func (s *stateObject) Code(db StateDatabase) []byte {
+func (s *stateObject) Code() []byte {
if s.code != nil {
return s.code
}
- if bytes.Equal(s.CodeHash(), emptyCodeHash) {
+ if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return nil
}
- code, err := db.ContractCode(common.BytesToHash(s.CodeHash()))
+ code, err := s.db.db.ContractCode(common.BytesToHash(s.CodeHash()))
if err != nil {
- s.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err))
+ s.db.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err))
}
s.code = code
return code
@@ -286,22 +286,22 @@ func (s *stateObject) Code(db StateDatabase) []byte {
// CodeSize returns the size of the contract code associated with this object,
// or zero if none. This method is an almost mirror of Code, but uses a cache
// inside the database to avoid loading codes seen recently.
-func (s *stateObject) CodeSize(db StateDatabase) int {
+func (s *stateObject) CodeSize() int {
if s.code != nil {
return len(s.code)
}
- if bytes.Equal(s.CodeHash(), emptyCodeHash) {
+ if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return 0
}
- size, err := db.ContractCodeSize(common.BytesToHash(s.CodeHash()))
+ size, err := s.db.db.ContractCodeSize(common.BytesToHash(s.CodeHash()))
if err != nil {
- s.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err))
+ s.db.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err))
}
return size
}
func (s *stateObject) SetCode(codeHash common.Hash, code []byte) {
- prevcode := s.Code(s.db.db)
+ prevcode := s.Code()
s.db.journal.append(codeChange{
account: &s.address,
prevhash: s.CodeHash(),
@@ -332,7 +332,7 @@ func (s *stateObject) CodeHash() []byte {
return s.data.CodeHash
}
-func (s *stateObject) Balance() *big.Int {
+func (s *stateObject) Balance() *uint256.Int {
return s.data.Balance
}
@@ -340,32 +340,6 @@ func (s *stateObject) Nonce() uint64 {
return s.data.Nonce
}
-// Value is never called, but must be present to allow stateObject to be used
-// as a vm.Account interface that also satisfies the vm.ContractRef
-// interface. Interfaces are awesome.
-func (s *stateObject) Value() *big.Int {
- panic("Value on stateObject should never be called")
-}
-
-// finalise moves all dirty storage slots into the pending area to be hashed or
-// committed later. It is invoked at the end of every transaction.
-func (s *stateObject) finalise(prefetch bool) {
- for key, value := range s.dirtyStorage {
- s.pendingStorage[key] = value
- }
- if len(s.dirtyStorage) > 0 {
- s.dirtyStorage = make(Storage)
- }
-}
-
-func (s *stateObject) deepCopy(db *StateDB) *stateObject {
- stateObject := newObject(db, s.address, s.data, s.blockHash)
- stateObject.code = s.code
- stateObject.dirtyStorage = s.dirtyStorage.Copy()
- stateObject.originStorage = s.originStorage.Copy()
- stateObject.pendingStorage = s.pendingStorage.Copy()
- stateObject.suicided = s.suicided
- stateObject.dirtyCode = s.dirtyCode
- stateObject.deleted = s.deleted
- return stateObject
+func (s *stateObject) Root() common.Hash {
+ return s.data.Root
}
diff --git a/direct_by_leaf/statedb.go b/direct_by_leaf/statedb.go
index e0d0229..2f2897b 100644
--- a/direct_by_leaf/statedb.go
+++ b/direct_by_leaf/statedb.go
@@ -2,7 +2,6 @@ package state
import (
"fmt"
- "math/big"
"sort"
"time"
@@ -12,6 +11,13 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/params"
+ "github.com/holiman/uint256"
+)
+
+const (
+ // storageDeleteLimit denotes the highest permissible memory allocation
+ // employed for contract storage deletion.
+ storageDeleteLimit = 512 * 1024 * 1024
)
/*
@@ -41,20 +47,38 @@ type revision struct {
// StateDB structs within the ethereum protocol are used to store anything
// within the merkle trie. StateDBs take care of caching and storing
// nested states. It's the general query interface to retrieve:
+//
// * Contracts
// * Accounts
+//
+// Once the state is committed, tries cached in stateDB (including account
+// trie, storage tries) will no longer be functional. A new state instance
+// must be created with new root and updated database for accessing post-
+// commit states.
type StateDB struct {
- db StateDatabase
+ db Database
hasher crypto.KeccakState
// originBlockHash is the blockhash for the state we are working on top of
originBlockHash common.Hash
- // This map holds 'live' objects, which will get modified while processing a state transition.
+ // originalRoot is the pre-state root, before any changes were made.
+ // It will be updated when the Commit is called.
+ originalRoot common.Hash
+
+ // These maps hold the state changes (including the corresponding
+ // original value) that occurred in this **block**.
+ accounts map[common.Hash][]byte // The mutated accounts in 'slim RLP' encoding
+ storages map[common.Hash]map[common.Hash][]byte // The mutated slots in prefix-zero trimmed rlp format
+ accountsOrigin map[common.Address][]byte // The original value of mutated accounts in 'slim RLP' encoding
+ storagesOrigin map[common.Address]map[common.Hash][]byte // The original value of mutated slots in prefix-zero trimmed rlp format
+
+ // This map holds 'live' objects, which will get modified while processing
+ // a state transition.
stateObjects map[common.Address]*stateObject
- stateObjectsPending map[common.Address]struct{} // State objects finalized but not yet written to the trie
- stateObjectsDirty map[common.Address]struct{} // State objects modified in the current execution
- stateObjectsDestruct map[common.Address]struct{} // State objects destructed in the block
+ stateObjectsPending map[common.Address]struct{} // State objects finalized but not yet written to the trie
+ stateObjectsDirty map[common.Address]struct{} // State objects modified in the current execution
+ stateObjectsDestruct map[common.Address]*types.StateAccount // State objects destructed in the block along with its previous value
// DB error.
// State objects are used by the consensus core and VM which are
@@ -66,11 +90,13 @@ type StateDB struct {
// The refund counter, also used by state transitioning.
refund uint64
+ // The tx context and all logs that occurred in the scope of the transaction.
thash common.Hash
txIndex int
logs map[common.Hash][]*types.Log
logSize uint
+ // Preimages seen by the VM in the scope of the block.
preimages map[common.Hash][]byte
// Per-transaction access list
@@ -91,14 +117,14 @@ type StateDB struct {
}
// New creates a new StateDB on the state for the provided blockHash
-func New(blockHash common.Hash, db StateDatabase) (*StateDB, error) {
+func New(blockHash common.Hash, db Database) (*StateDB, error) {
sdb := &StateDB{
db: db,
originBlockHash: blockHash,
stateObjects: make(map[common.Address]*stateObject),
stateObjectsPending: make(map[common.Address]struct{}),
stateObjectsDirty: make(map[common.Address]struct{}),
- stateObjectsDestruct: make(map[common.Address]struct{}),
+ stateObjectsDestruct: make(map[common.Address]*types.StateAccount),
logs: make(map[common.Hash][]*types.Log),
preimages: make(map[common.Hash][]byte),
journal: newJournal(),
@@ -153,7 +179,7 @@ func (s *StateDB) SubRefund(gas uint64) {
}
// Exist reports whether the given account address exists in the state.
-// Notably this also returns true for suicided accounts.
+// Notably this also returns true for self-destructed accounts.
func (s *StateDB) Exist(addr common.Address) bool {
return s.getStateObject(addr) != nil
}
@@ -166,14 +192,15 @@ func (s *StateDB) Empty(addr common.Address) bool {
}
// GetBalance retrieves the balance from the given address or 0 if object not found
-func (s *StateDB) GetBalance(addr common.Address) *big.Int {
+func (s *StateDB) GetBalance(addr common.Address) *uint256.Int {
stateObject := s.getStateObject(addr)
if stateObject != nil {
return stateObject.Balance()
}
- return common.Big0
+ return common.U2560
}
+// GetNonce retrieves the nonce from the given address or 0 if object not found
func (s *StateDB) GetNonce(addr common.Address) uint64 {
stateObject := s.getStateObject(addr)
if stateObject != nil {
@@ -183,10 +210,25 @@ func (s *StateDB) GetNonce(addr common.Address) uint64 {
return 0
}
+// GetStorageRoot retrieves the storage root from the given address or empty
+// if object not found.
+func (s *StateDB) GetStorageRoot(addr common.Address) common.Hash {
+ stateObject := s.getStateObject(addr)
+ if stateObject != nil {
+ return stateObject.Root()
+ }
+ return common.Hash{}
+}
+
+// TxIndex returns the current transaction index set by Prepare.
+func (s *StateDB) TxIndex() int {
+ return s.txIndex
+}
+
func (s *StateDB) GetCode(addr common.Address) []byte {
stateObject := s.getStateObject(addr)
if stateObject != nil {
- return stateObject.Code(s.db)
+ return stateObject.Code()
}
return nil
}
@@ -194,24 +236,24 @@ func (s *StateDB) GetCode(addr common.Address) []byte {
func (s *StateDB) GetCodeSize(addr common.Address) int {
stateObject := s.getStateObject(addr)
if stateObject != nil {
- return stateObject.CodeSize(s.db)
+ return stateObject.CodeSize()
}
return 0
}
func (s *StateDB) GetCodeHash(addr common.Address) common.Hash {
stateObject := s.getStateObject(addr)
- if stateObject == nil {
- return common.Hash{}
+ if stateObject != nil {
+ return common.BytesToHash(stateObject.CodeHash())
}
- return common.BytesToHash(stateObject.CodeHash())
+ return common.Hash{}
}
// GetState retrieves a value from the given account's storage trie.
func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash {
stateObject := s.getStateObject(addr)
if stateObject != nil {
- return stateObject.GetState(s.db, hash)
+ return stateObject.GetState(hash)
}
return common.Hash{}
}
@@ -220,15 +262,20 @@ func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash {
func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash {
stateObject := s.getStateObject(addr)
if stateObject != nil {
- return stateObject.GetCommittedState(s.db, hash)
+ return stateObject.GetCommittedState(hash)
}
return common.Hash{}
}
-func (s *StateDB) HasSuicided(addr common.Address) bool {
+// Database retrieves the low level database supporting the lower level trie ops.
+func (s *StateDB) Database() Database {
+ return s.db
+}
+
+func (s *StateDB) HasSelfDestructed(addr common.Address) bool {
stateObject := s.getStateObject(addr)
if stateObject != nil {
- return stateObject.suicided
+ return stateObject.selfDestructed
}
return false
}
@@ -238,7 +285,7 @@ func (s *StateDB) HasSuicided(addr common.Address) bool {
*/
// AddBalance adds amount to the account associated with addr.
-func (s *StateDB) AddBalance(addr common.Address, amount *big.Int) {
+func (s *StateDB) AddBalance(addr common.Address, amount *uint256.Int) {
stateObject := s.getOrNewStateObject(addr)
if stateObject != nil {
stateObject.AddBalance(amount)
@@ -246,14 +293,14 @@ func (s *StateDB) AddBalance(addr common.Address, amount *big.Int) {
}
// SubBalance subtracts amount from the account associated with addr.
-func (s *StateDB) SubBalance(addr common.Address, amount *big.Int) {
+func (s *StateDB) SubBalance(addr common.Address, amount *uint256.Int) {
stateObject := s.getOrNewStateObject(addr)
if stateObject != nil {
stateObject.SubBalance(amount)
}
}
-func (s *StateDB) SetBalance(addr common.Address, amount *big.Int) {
+func (s *StateDB) SetBalance(addr common.Address, amount *uint256.Int) {
stateObject := s.getOrNewStateObject(addr)
if stateObject != nil {
stateObject.SetBalance(amount)
@@ -277,39 +324,59 @@ func (s *StateDB) SetCode(addr common.Address, code []byte) {
func (s *StateDB) SetState(addr common.Address, key, value common.Hash) {
stateObject := s.getOrNewStateObject(addr)
if stateObject != nil {
- stateObject.SetState(s.db, key, value)
+ stateObject.SetState(key, value)
}
}
// SetStorage replaces the entire storage for the specified account with given
-// storage. This function should only be used for debugging.
+// storage. This function should only be used for debugging and the mutations
+// must be discarded afterwards.
func (s *StateDB) SetStorage(addr common.Address, storage map[common.Hash]common.Hash) {
- s.stateObjectsDestruct[addr] = struct{}{}
+ // SetStorage needs to wipe existing storage. We achieve this by pretending
+ // that the account self-destructed earlier in this block, by flagging
+ // it in stateObjectsDestruct. The effect of doing so is that storage lookups
+ // will not hit disk, since it is assumed that the disk-data belongs
+ // to a previous incarnation of the object.
+ //
+ // TODO(rjl493456442) this function should only be supported by 'unwritable'
+ // state and all mutations made should all be discarded afterwards.
+ if _, ok := s.stateObjectsDestruct[addr]; !ok {
+ s.stateObjectsDestruct[addr] = nil
+ }
stateObject := s.getOrNewStateObject(addr)
- if stateObject != nil {
- stateObject.SetStorage(storage)
+ for k, v := range storage {
+ stateObject.SetState(k, v)
}
}
-// Suicide marks the given account as suicided.
+// SelfDestruct marks the given account as selfdestructed.
// This clears the account balance.
//
// The account's state object is still available until the state is committed,
-// getStateObject will return a non-nil account after Suicide.
-func (s *StateDB) Suicide(addr common.Address) bool {
+// getStateObject will return a non-nil account after SelfDestruct.
+func (s *StateDB) SelfDestruct(addr common.Address) {
stateObject := s.getStateObject(addr)
if stateObject == nil {
- return false
+ return
}
- s.journal.append(suicideChange{
+ s.journal.append(selfDestructChange{
account: &addr,
- prev: stateObject.suicided,
- prevbalance: new(big.Int).Set(stateObject.Balance()),
+ prev: stateObject.selfDestructed,
+ prevbalance: new(uint256.Int).Set(stateObject.Balance()),
})
- stateObject.markSuicided()
- stateObject.data.Balance = new(big.Int)
+ stateObject.markSelfdestructed()
+ stateObject.data.Balance = new(uint256.Int)
+}
- return true
+func (s *StateDB) Selfdestruct6780(addr common.Address) {
+ stateObject := s.getStateObject(addr)
+ if stateObject == nil {
+ return
+ }
+
+ if stateObject.created {
+ s.SelfDestruct(addr)
+ }
}
// SetTransientState sets transient storage for a given account. It
@@ -380,7 +447,7 @@ func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject {
return nil
}
// Insert into the live set
- obj := newObject(s, addr, *data, s.originBlockHash)
+ obj := newObject(s, addr, data, s.originBlockHash)
s.setStateObject(obj)
return obj
}
@@ -402,19 +469,36 @@ func (s *StateDB) getOrNewStateObject(addr common.Address) *stateObject {
// the given address, it is overwritten and returned as the second return value.
func (s *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) {
prev = s.getDeletedStateObject(addr) // Note, prev might have been deleted, we need that!
-
- var prevdestruct bool
- if prev != nil {
- _, prevdestruct = s.stateObjectsDestruct[prev.address]
- if !prevdestruct {
- s.stateObjectsDestruct[prev.address] = struct{}{}
- }
- }
- newobj = newObject(s, addr, types.StateAccount{}, s.originBlockHash)
+ newobj = newObject(s, addr, nil, s.originBlockHash)
if prev == nil {
s.journal.append(createObjectChange{account: &addr})
} else {
- s.journal.append(resetObjectChange{prev: prev, prevdestruct: prevdestruct}) // NOTE: prevdestruct used to be set here from snapshot
+ // The original account should be marked as destructed and all cached
+ // account and storage data should be cleared as well. Note, it must
+ // be done here, otherwise the destruction event of "original account"
+ // will be lost.
+ _, prevdestruct := s.stateObjectsDestruct[prev.address]
+ if !prevdestruct {
+ s.stateObjectsDestruct[prev.address] = prev.origin
+ }
+ // There may be some cached account/storage data already since IntermediateRoot
+ // will be called for each transaction before byzantium fork which will always
+ // cache the latest account/storage data.
+ prevAccount, ok := s.accountsOrigin[prev.address]
+ s.journal.append(resetObjectChange{
+ account: &addr,
+ prev: prev,
+ prevdestruct: prevdestruct,
+ prevAccount: s.accounts[prev.addrHash],
+ prevStorage: s.storages[prev.addrHash],
+ prevAccountOriginExist: ok,
+ prevAccountOrigin: prevAccount,
+ prevStorageOrigin: s.storagesOrigin[prev.address],
+ })
+ delete(s.accounts, prev.addrHash)
+ delete(s.storages, prev.addrHash)
+ delete(s.accountsOrigin, prev.address)
+ delete(s.storagesOrigin, prev.address)
}
s.setStateObject(newobj)
if prev != nil && !prev.deleted {
@@ -440,14 +524,6 @@ func (s *StateDB) CreateAccount(addr common.Address) {
}
}
-// ForEachStorage satisfies vm.StateDB but is not implemented
-func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error {
- // NOTE: as far as I can tell this method is only ever used in tests
- // in that case, we can leave it unimplemented
- // or if it needs to be implemented we can use iplfs-ethdb to do normal trie access
- panic("ForEachStorage is not implemented")
-}
-
// Snapshot returns an identifier for the current revision of the state.
func (s *StateDB) Snapshot() int {
id := s.nextRevisionId
@@ -477,6 +553,70 @@ func (s *StateDB) GetRefund() uint64 {
return s.refund
}
+// Finalise finalises the state by removing the destructed objects and clears
+// the journal as well as the refunds. Finalise, however, will not push any updates
+// into the tries just yet. Only IntermediateRoot or Commit will do that.
+func (s *StateDB) Finalise(deleteEmptyObjects bool) {
+ addressesToPrefetch := make([][]byte, 0, len(s.journal.dirties))
+ for addr := range s.journal.dirties {
+ obj, exist := s.stateObjects[addr]
+ if !exist {
+ // ripeMD is 'touched' at block 1714175, in tx 0x1237f737031e40bcde4a8b7e717b2d15e3ecadfe49bb1bbc71ee9deb09c6fcf2
+ // That tx goes out of gas, and although the notion of 'touched' does not exist there, the
+ // touch-event will still be recorded in the journal. Since ripeMD is a special snowflake,
+ // it will persist in the journal even though the journal is reverted. In this special circumstance,
+ // it may exist in `s.journal.dirties` but not in `s.stateObjects`.
+ // Thus, we can safely ignore it here
+ continue
+ }
+ if obj.selfDestructed || (deleteEmptyObjects && obj.empty()) {
+ obj.deleted = true
+
+ // We need to maintain account deletions explicitly (will remain
+ // set indefinitely). Note only the first occurred self-destruct
+ // event is tracked.
+ if _, ok := s.stateObjectsDestruct[obj.address]; !ok {
+ s.stateObjectsDestruct[obj.address] = obj.origin
+ }
+ // Note, we can't do this only at the end of a block because multiple
+ // transactions within the same block might self destruct and then
+ // resurrect an account; but the snapshotter needs both events.
+ delete(s.accounts, obj.addrHash) // Clear out any previously updated account data (may be recreated via a resurrect)
+ delete(s.storages, obj.addrHash) // Clear out any previously updated storage data (may be recreated via a resurrect)
+ delete(s.accountsOrigin, obj.address) // Clear out any previously updated account data (may be recreated via a resurrect)
+ delete(s.storagesOrigin, obj.address) // Clear out any previously updated storage data (may be recreated via a resurrect)
+ } else {
+ obj.finalise(true) // Prefetch slots in the background
+ }
+ obj.created = false
+ s.stateObjectsPending[addr] = struct{}{}
+ s.stateObjectsDirty[addr] = struct{}{}
+
+ // At this point, also ship the address off to the precacher. The precacher
+ // will start loading tries, and when the change is eventually committed,
+ // the commit-phase will be a lot faster
+ addressesToPrefetch = append(addressesToPrefetch, common.CopyBytes(addr[:])) // Copy needed for closure
+ }
+ // Invalidate journal because reverting across transactions is not allowed.
+ s.clearJournalAndRefund()
+}
+
+// SetTxContext sets the current transaction hash and index which are
+// used when the EVM emits new state logs. It should be invoked before
+// transaction execution.
+func (s *StateDB) SetTxContext(thash common.Hash, ti int) {
+ s.thash = thash
+ s.txIndex = ti
+}
+
+func (s *StateDB) clearJournalAndRefund() {
+ if len(s.journal.entries) > 0 {
+ s.journal = newJournal()
+ s.refund = 0
+ }
+ s.validRevisions = s.validRevisions[:0] // Snapshots can be created without journal entries
+}
+
// Prepare handles the preparatory steps for executing a state transition with.
// This method must be invoked before state transition.
//
@@ -553,65 +693,21 @@ func (s *StateDB) SlotInAccessList(addr common.Address, slot common.Hash) (addre
return s.accessList.Contains(addr, slot)
}
-// Finalise finalises the state by removing the destructed objects and clears
-// the journal as well as the refunds. Finalise, however, will not push any updates
-// into the tries just yet. Only IntermediateRoot or Commit will do that.
-func (s *StateDB) Finalise(deleteEmptyObjects bool) {
- for addr := range s.journal.dirties {
- obj, exist := s.stateObjects[addr]
- if !exist {
- // ripeMD is 'touched' at block 1714175, in tx 0x1237f737031e40bcde4a8b7e717b2d15e3ecadfe49bb1bbc71ee9deb09c6fcf2
- // That tx goes out of gas, and although the notion of 'touched' does not exist there, the
- // touch-event will still be recorded in the journal. Since ripeMD is a special snowflake,
- // it will persist in the journal even though the journal is reverted. In this special circumstance,
- // it may exist in `s.journal.dirties` but not in `s.stateObjects`.
- // Thus, we can safely ignore it here
- continue
- }
- if obj.suicided || (deleteEmptyObjects && obj.empty()) {
- obj.deleted = true
-
- // We need to maintain account deletions explicitly (will remain
- // set indefinitely).
- s.stateObjectsDestruct[obj.address] = struct{}{}
- } else {
- obj.finalise(true) // Prefetch slots in the background
- }
- s.stateObjectsPending[addr] = struct{}{}
- s.stateObjectsDirty[addr] = struct{}{}
- }
-
- // Invalidate journal because reverting across transactions is not allowed.
- s.clearJournalAndRefund()
-}
-
-// SetTxContext sets the current transaction hash and index which are
-// used when the EVM emits new state logs. It should be invoked before
-// transaction execution.
-func (s *StateDB) SetTxContext(thash common.Hash, ti int) {
- s.thash = thash
- s.txIndex = ti
-}
-
-func (s *StateDB) clearJournalAndRefund() {
- if len(s.journal.entries) > 0 {
- s.journal = newJournal()
- s.refund = 0
- }
- s.validRevisions = s.validRevisions[:0] // Snapshots can be created without journal entries
-}
-
// Copy creates a deep, independent copy of the state.
// Snapshots of the copied state cannot be applied to the copy.
func (s *StateDB) Copy() *StateDB {
// Copy all the basic fields, initialize the memory ones
state := &StateDB{
db: s.db,
- originBlockHash: s.originBlockHash,
+ originalRoot: s.originalRoot,
+ accounts: make(map[common.Hash][]byte),
+ storages: make(map[common.Hash]map[common.Hash][]byte),
+ accountsOrigin: make(map[common.Address][]byte),
+ storagesOrigin: make(map[common.Address]map[common.Hash][]byte),
stateObjects: make(map[common.Address]*stateObject, len(s.journal.dirties)),
stateObjectsPending: make(map[common.Address]struct{}, len(s.stateObjectsPending)),
stateObjectsDirty: make(map[common.Address]struct{}, len(s.journal.dirties)),
- stateObjectsDestruct: make(map[common.Address]struct{}, len(s.stateObjectsDestruct)),
+ stateObjectsDestruct: make(map[common.Address]*types.StateAccount, len(s.stateObjectsDestruct)),
refund: s.refund,
logs: make(map[common.Hash][]*types.Log, len(s.logs)),
logSize: s.logSize,
@@ -651,10 +747,18 @@ func (s *StateDB) Copy() *StateDB {
}
state.stateObjectsDirty[addr] = struct{}{}
}
- // Deep copy the destruction flag.
- for addr := range s.stateObjectsDestruct {
- state.stateObjectsDestruct[addr] = struct{}{}
+ // Deep copy the destruction markers.
+ for addr, value := range s.stateObjectsDestruct {
+ state.stateObjectsDestruct[addr] = value
}
+ // Deep copy the state changes made in the scope of block
+ // along with their original values.
+ state.accounts = copySet(s.accounts)
+ state.storages = copy2DSet(s.storages)
+	state.accountsOrigin = copySet(s.accountsOrigin)
+	state.storagesOrigin = copy2DSet(s.storagesOrigin)
+
+ // Deep copy the logs occurred in the scope of block
for hash, logs := range s.logs {
cpy := make([]*types.Log, len(logs))
for i, l := range logs {
@@ -663,6 +767,7 @@ func (s *StateDB) Copy() *StateDB {
}
state.logs[hash] = cpy
}
+ // Deep copy the preimages occurred in the scope of block
for hash, preimage := range s.preimages {
state.preimages[hash] = preimage
}
@@ -674,6 +779,26 @@ func (s *StateDB) Copy() *StateDB {
// in the middle of a transaction.
state.accessList = s.accessList.Copy()
state.transientStorage = s.transientStorage.Copy()
-
return state
}
+
+// copySet returns a deep-copied set.
+func copySet[k comparable](set map[k][]byte) map[k][]byte {
+ copied := make(map[k][]byte, len(set))
+ for key, val := range set {
+ copied[key] = common.CopyBytes(val)
+ }
+ return copied
+}
+
+// copy2DSet returns a two-dimensional deep-copied set.
+func copy2DSet[k comparable](set map[k]map[common.Hash][]byte) map[k]map[common.Hash][]byte {
+ copied := make(map[k]map[common.Hash][]byte, len(set))
+ for addr, subset := range set {
+ copied[addr] = make(map[common.Hash][]byte, len(subset))
+ for key, val := range subset {
+ copied[addr][key] = common.CopyBytes(val)
+ }
+ }
+ return copied
+}
diff --git a/direct_by_leaf/statedb_test.go b/direct_by_leaf/statedb_test.go
index 272d79b..b485925 100644
--- a/direct_by_leaf/statedb_test.go
+++ b/direct_by_leaf/statedb_test.go
@@ -5,6 +5,7 @@ import (
"math/big"
"testing"
+ "github.com/holiman/uint256"
"github.com/lib/pq"
"github.com/multiformats/go-multihash"
"github.com/stretchr/testify/require"
@@ -66,7 +67,7 @@ var (
Account = types.StateAccount{
Nonce: uint64(0),
- Balance: big.NewInt(1000),
+ Balance: uint256.NewInt(1000),
CodeHash: AccountCodeHash.Bytes(),
Root: common.Hash{},
}
@@ -112,7 +113,7 @@ func TestPGXSuite(t *testing.T) {
database := sql.NewPGXDriverFromPool(context.Background(), pool)
insertSuiteData(t, database)
- db := state.NewStateDatabase(database)
+ db := state.NewDatabase(database)
require.NoError(t, err)
testSuite(t, db)
}
@@ -137,7 +138,7 @@ func TestSQLXSuite(t *testing.T) {
database := sql.NewSQLXDriverFromPool(context.Background(), pool)
insertSuiteData(t, database)
- db := state.NewStateDatabase(database)
+ db := state.NewDatabase(database)
require.NoError(t, err)
testSuite(t, db)
}
@@ -226,7 +227,7 @@ func insertSuiteData(t *testing.T, database sql.Database) {
require.NoError(t, insertContractCode(database))
}
-func testSuite(t *testing.T, db state.StateDatabase) {
+func testSuite(t *testing.T, db state.Database) {
t.Run("Database", func(t *testing.T) {
size, err := db.ContractCodeSize(AccountCodeHash)
require.NoError(t, err)
@@ -309,14 +310,14 @@ func testSuite(t *testing.T, db state.StateDatabase) {
newStorage := crypto.Keccak256Hash([]byte{5, 4, 3, 2, 1})
newCode := []byte{1, 3, 3, 7}
- sdb.SetBalance(AccountAddress, big.NewInt(300))
- sdb.AddBalance(AccountAddress, big.NewInt(200))
- sdb.SubBalance(AccountAddress, big.NewInt(100))
+ sdb.SetBalance(AccountAddress, uint256.NewInt(300))
+ sdb.AddBalance(AccountAddress, uint256.NewInt(200))
+ sdb.SubBalance(AccountAddress, uint256.NewInt(100))
sdb.SetNonce(AccountAddress, 42)
sdb.SetState(AccountAddress, StorageSlot, newStorage)
sdb.SetCode(AccountAddress, newCode)
- require.Equal(t, big.NewInt(400), sdb.GetBalance(AccountAddress))
+ require.Equal(t, uint256.NewInt(400), sdb.GetBalance(AccountAddress))
require.Equal(t, uint64(42), sdb.GetNonce(AccountAddress))
require.Equal(t, newStorage, sdb.GetState(AccountAddress, StorageSlot))
require.Equal(t, newCode, sdb.GetCode(AccountAddress))
diff --git a/go.mod b/go.mod
index 1d48d44..ec40334 100644
--- a/go.mod
+++ b/go.mod
@@ -116,3 +116,4 @@ replace github.com/cerc-io/plugeth-statediff => git.vdb.to/cerc-io/plugeth-state
// dev
replace github.com/cerc-io/ipfs-ethdb/v5 => git.vdb.to/cerc-io/ipfs-ethdb/v5 v5.0.1-alpha.0.20240403094152-a95b1aea6c5c
+replace github.com/ethereum/go-ethereum => ../go-ethereum
diff --git a/internal/util.go b/internal/util.go
index 634c18a..59fe5a2 100644
--- a/internal/util.go
+++ b/internal/util.go
@@ -5,10 +5,18 @@ import (
"time"
pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0"
+ "github.com/cerc-io/plugeth-statediff/indexer/ipld"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/ethdb"
"github.com/ipfs/go-cid"
"github.com/multiformats/go-multihash"
)
+var (
+ StateTrieCodec uint64 = ipld.MEthStateTrie
+ StorageTrieCodec uint64 = ipld.MEthStorageTrie
+)
+
func Keccak256ToCid(codec uint64, h []byte) (cid.Cid, error) {
buf, err := multihash.Encode(h, multihash.KECCAK_256)
if err != nil {
@@ -25,3 +33,33 @@ func MakeCacheConfig(t testing.TB) pgipfsethdb.CacheConfig {
ExpiryDuration: time.Hour,
}
}
+
+// ReadLegacyTrieNode retrieves the legacy trie node with the given
+// associated node hash.
+func ReadLegacyTrieNode(db ethdb.KeyValueReader, hash common.Hash, codec uint64) ([]byte, error) {
+ cid, err := Keccak256ToCid(codec, hash[:])
+ if err != nil {
+ return nil, err
+ }
+ enc, err := db.Get(cid.Bytes())
+ if err != nil {
+ return nil, err
+ }
+ return enc, nil
+}
+
+func WriteLegacyTrieNode(db ethdb.KeyValueWriter, hash common.Hash, codec uint64, data []byte) error {
+ cid, err := Keccak256ToCid(codec, hash[:])
+ if err != nil {
+ return err
+ }
+ return db.Put(cid.Bytes(), data)
+}
+
+func ReadCode(db ethdb.KeyValueReader, hash common.Hash) ([]byte, error) {
+ return ReadLegacyTrieNode(db, hash, ipld.RawBinary)
+}
+
+func WriteCode(db ethdb.KeyValueWriter, hash common.Hash, code []byte) error {
+ return WriteLegacyTrieNode(db, hash, ipld.RawBinary, code)
+}
diff --git a/trie_by_cid/state/database.go b/trie_by_cid/state/database.go
index c858b70..3657748 100644
--- a/trie_by_cid/state/database.go
+++ b/trie_by_cid/state/database.go
@@ -20,7 +20,7 @@ import (
"errors"
"fmt"
- "github.com/cerc-io/plugeth-statediff/indexer/ipld"
+ "github.com/crate-crypto/go-ipa/banderwagon"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/lru"
"github.com/ethereum/go-ethereum/core/types"
@@ -28,6 +28,9 @@ import (
"github.com/cerc-io/ipld-eth-statedb/internal"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/utils"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/triedb"
)
const (
@@ -36,6 +39,12 @@ const (
// Cache size granted for caching clean code.
codeCacheSize = 64 * 1024 * 1024
+
+ // commitmentSize is the size of commitment stored in cache.
+ commitmentSize = banderwagon.UncompressedSize
+
+ // Cache item granted for caching commitment results.
+ commitmentCacheItems = 64 * 1024 * 1024 / (commitmentSize + common.AddressLength)
)
// Database wraps access to tries and contract code.
@@ -44,22 +53,22 @@ type Database interface {
OpenTrie(root common.Hash) (Trie, error)
// OpenStorageTrie opens the storage trie of an account.
- OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error)
+ OpenStorageTrie(stateRoot, addrHash common.Hash, root common.Hash, trie Trie) (Trie, error)
// CopyTrie returns an independent copy of the given trie.
CopyTrie(Trie) Trie
// ContractCode retrieves a particular contract's code.
- ContractCode(codeHash common.Hash) ([]byte, error)
+ ContractCode(addr common.Address, codeHash common.Hash) ([]byte, error)
// ContractCodeSize retrieves a particular contracts code's size.
- ContractCodeSize(codeHash common.Hash) (int, error)
+ ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error)
// DiskDB returns the underlying key-value disk database.
DiskDB() ethdb.KeyValueStore
- // TrieDB retrieves the low level trie database used for data storage.
- TrieDB() *trie.Database
+ // TrieDB returns the underlying trie database for managing trie nodes.
+ TrieDB() *triedb.Database
}
// Trie is a Ethereum Merkle Patricia trie.
@@ -70,40 +79,40 @@ type Trie interface {
// TODO(fjl): remove this when StateTrie is removed
GetKey([]byte) []byte
- // TryGet returns the value for key stored in the trie. The value bytes must
- // not be modified by the caller. If a node was not found in the database, a
- // trie.MissingNodeError is returned.
- TryGet(key []byte) ([]byte, error)
-
- // TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not
- // possible to use keybyte-encoding as the path might contain odd nibbles.
- TryGetNode(path []byte) ([]byte, int, error)
-
- // TryGetAccount abstracts an account read from the trie. It retrieves the
+ // GetAccount abstracts an account read from the trie. It retrieves the
// account blob from the trie with provided account address and decodes it
// with associated decoding algorithm. If the specified account is not in
// the trie, nil will be returned. If the trie is corrupted(e.g. some nodes
// are missing or the account blob is incorrect for decoding), an error will
// be returned.
- TryGetAccount(address common.Address) (*types.StateAccount, error)
+ GetAccount(address common.Address) (*types.StateAccount, error)
- // TryUpdate associates key with value in the trie. If value has length zero, any
- // existing value is deleted from the trie. The value bytes must not be modified
- // by the caller while they are stored in the trie. If a node was not found in the
- // database, a trie.MissingNodeError is returned.
- TryUpdate(key, value []byte) error
+ // GetStorage returns the value for key stored in the trie. The value bytes
+ // must not be modified by the caller. If a node was not found in the database,
+ // a trie.MissingNodeError is returned.
+ GetStorage(addr common.Address, key []byte) ([]byte, error)
- // TryUpdateAccount abstracts an account write to the trie. It encodes the
+ // UpdateAccount abstracts an account write to the trie. It encodes the
// provided account object with associated algorithm and then updates it
// in the trie with provided address.
- TryUpdateAccount(address common.Address, account *types.StateAccount) error
+ UpdateAccount(address common.Address, account *types.StateAccount) error
- // TryDelete removes any existing value for key from the trie. If a node was not
- // found in the database, a trie.MissingNodeError is returned.
- TryDelete(key []byte) error
+ // UpdateStorage associates key with value in the trie. If value has length zero,
+ // any existing value is deleted from the trie. The value bytes must not be modified
+ // by the caller while they are stored in the trie. If a node was not found in the
+ // database, a trie.MissingNodeError is returned.
+ UpdateStorage(addr common.Address, key, value []byte) error
- // TryDeleteAccount abstracts an account deletion from the trie.
- TryDeleteAccount(address common.Address) error
+ // DeleteAccount abstracts an account deletion from the trie.
+ DeleteAccount(address common.Address) error
+
+ // DeleteStorage removes any existing value for key from the trie. If a node
+ // was not found in the database, a trie.MissingNodeError is returned.
+ DeleteStorage(addr common.Address, key []byte) error
+
+ // UpdateContractCode abstracts code write to the trie. It is expected
+ // to be moved to the stateWriter interface when the latter is ready.
+ UpdateContractCode(address common.Address, codeHash common.Hash, code []byte) error
// Hash returns the root hash of the trie. It does not write to the database and
// can be used even if the trie doesn't have one.
@@ -115,11 +124,12 @@ type Trie interface {
// The returned nodeset can be nil if the trie is clean(nothing to commit).
// Once the trie is committed, it's not usable anymore. A new trie must
// be created with new root and updated trie database for following usage
- Commit(collectLeaf bool) (common.Hash, *trie.NodeSet)
+ Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, error)
// NodeIterator returns an iterator that returns nodes of the trie. Iteration
- // starts at the key after the given start key.
- NodeIterator(startKey []byte) trie.NodeIterator
+ // starts at the key after the given start key. And error will be returned
+ // if fails to create node iterator.
+ NodeIterator(startKey []byte) (trie.NodeIterator, error)
// Prove constructs a Merkle proof for key. The result contains all encoded nodes
// on the path to the value at key. The value itself is also included in the last
@@ -128,7 +138,7 @@ type Trie interface {
// If the trie does not contain a value for key, the returned proof contains all
// nodes of the longest existing prefix of the key (at least the root), ending
// with the node that proves the absence of the key.
- Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error
+ Prove(key []byte, proofDb ethdb.KeyValueWriter) error
}
// NewDatabase creates a backing store for state. The returned database is safe for
@@ -141,17 +151,17 @@ func NewDatabase(db ethdb.Database) Database {
// NewDatabaseWithConfig creates a backing store for state. The returned database
// is safe for concurrent use and retains a lot of collapsed RLP trie nodes in a
// large memory cache.
-func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database {
+func NewDatabaseWithConfig(db ethdb.Database, config *triedb.Config) Database {
return &cachingDB{
disk: db,
codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize),
codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize),
- triedb: trie.NewDatabaseWithConfig(db, config),
+ triedb: triedb.NewDatabase(db, config),
}
}
// NewDatabaseWithNodeDB creates a state database with an already initialized node database.
-func NewDatabaseWithNodeDB(db ethdb.Database, triedb *trie.Database) Database {
+func NewDatabaseWithNodeDB(db ethdb.Database, triedb *triedb.Database) Database {
return &cachingDB{
disk: db,
codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize),
@@ -164,12 +174,15 @@ type cachingDB struct {
disk ethdb.KeyValueStore
codeSizeCache *lru.Cache[common.Hash, int]
codeCache *lru.SizeConstrainedCache[common.Hash, []byte]
- triedb *trie.Database
+ triedb *triedb.Database
}
// OpenTrie opens the main account trie at a specific root hash.
func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
- tr, err := trie.NewStateTrie(trie.StateTrieID(root), db.triedb, trie.StateTrieCodec)
+ if db.triedb.IsVerkle() {
+ return trie.NewVerkleTrie(root, db.triedb, utils.NewPointCache(commitmentCacheItems))
+ }
+ tr, err := trie.NewStateTrie(trie.StateTrieID(root), db.triedb)
if err != nil {
return nil, err
}
@@ -177,8 +190,14 @@ func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
}
// OpenStorageTrie opens the storage trie of an account.
-func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) {
- tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, addrHash, root), db.triedb, trie.StorageTrieCodec)
+func (db *cachingDB) OpenStorageTrie(stateRoot, addrHash common.Hash, root common.Hash, self Trie) (Trie, error) {
+ // In the verkle case, there is only one tree. But the two-tree structure
+ // is hardcoded in the codebase. So we need to return the same trie in this
+ // case.
+ if db.triedb.IsVerkle() {
+ return self, nil
+ }
+ tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, addrHash, root), db.triedb)
if err != nil {
return nil, err
}
@@ -196,16 +215,12 @@ func (db *cachingDB) CopyTrie(t Trie) Trie {
}
// ContractCode retrieves a particular contract's code.
-func (db *cachingDB) ContractCode(codeHash common.Hash) ([]byte, error) {
+func (db *cachingDB) ContractCode(address common.Address, codeHash common.Hash) ([]byte, error) {
code, _ := db.codeCache.Get(codeHash)
if len(code) > 0 {
return code, nil
}
- cid, err := internal.Keccak256ToCid(ipld.RawBinary, codeHash.Bytes())
- if err != nil {
- return nil, err
- }
- code, err = db.disk.Get(cid.Bytes())
+ code, err := internal.ReadCode(db.disk, codeHash)
if err != nil {
return nil, err
}
@@ -217,12 +232,19 @@ func (db *cachingDB) ContractCode(codeHash common.Hash) ([]byte, error) {
return nil, errors.New("not found")
}
+// ContractCodeWithPrefix retrieves a particular contract's code. If the
+// code can't be found in the cache, then check the existence with **new**
+// db scheme.
+func (db *cachingDB) ContractCodeWithPrefix(address common.Address, codeHash common.Hash) ([]byte, error) {
+ return db.ContractCode(address, codeHash)
+}
+
// ContractCodeSize retrieves a particular contracts code's size.
-func (db *cachingDB) ContractCodeSize(codeHash common.Hash) (int, error) {
+func (db *cachingDB) ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error) {
if cached, ok := db.codeSizeCache.Get(codeHash); ok {
return cached, nil
}
- code, err := db.ContractCode(codeHash)
+ code, err := db.ContractCode(addr, codeHash)
return len(code), err
}
@@ -232,6 +254,6 @@ func (db *cachingDB) DiskDB() ethdb.KeyValueStore {
}
// TrieDB retrieves any intermediate trie-node caching layer.
-func (db *cachingDB) TrieDB() *trie.Database {
+func (db *cachingDB) TrieDB() *triedb.Database {
return db.triedb
}
diff --git a/trie_by_cid/state/dump.go b/trie_by_cid/state/dump.go
new file mode 100644
index 0000000..05da3c0
--- /dev/null
+++ b/trie_by_cid/state/dump.go
@@ -0,0 +1,236 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package state
+
+import (
+ "encoding/json"
+ "fmt"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/common/hexutil"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/rlp"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
+)
+
+// DumpConfig is a set of options to control what portions of the state will be
+// iterated and collected.
+type DumpConfig struct {
+ SkipCode bool
+ SkipStorage bool
+ OnlyWithAddresses bool
+ Start []byte
+ Max uint64
+}
+
+// DumpCollector interface which the state trie calls during iteration
+type DumpCollector interface {
+ // OnRoot is called with the state root
+ OnRoot(common.Hash)
+ // OnAccount is called once for each account in the trie
+ OnAccount(*common.Address, DumpAccount)
+}
+
+// DumpAccount represents an account in the state.
+type DumpAccount struct {
+ Balance string `json:"balance"`
+ Nonce uint64 `json:"nonce"`
+ Root hexutil.Bytes `json:"root"`
+ CodeHash hexutil.Bytes `json:"codeHash"`
+ Code hexutil.Bytes `json:"code,omitempty"`
+ Storage map[common.Hash]string `json:"storage,omitempty"`
+ Address *common.Address `json:"address,omitempty"` // Address only present in iterative (line-by-line) mode
+ AddressHash hexutil.Bytes `json:"key,omitempty"` // If we don't have address, we can output the key
+
+}
+
+// Dump represents the full dump in a collected format, as one large map.
+type Dump struct {
+ Root string `json:"root"`
+ Accounts map[string]DumpAccount `json:"accounts"`
+ // Next can be set to represent that this dump is only partial, and Next
+ // is where an iterator should be positioned in order to continue the dump.
+ Next []byte `json:"next,omitempty"` // nil if no more accounts
+}
+
+// OnRoot implements DumpCollector interface
+func (d *Dump) OnRoot(root common.Hash) {
+ d.Root = fmt.Sprintf("%x", root)
+}
+
+// OnAccount implements DumpCollector interface
+func (d *Dump) OnAccount(addr *common.Address, account DumpAccount) {
+ if addr == nil {
+ d.Accounts[fmt.Sprintf("pre(%s)", account.AddressHash)] = account
+ }
+ if addr != nil {
+ d.Accounts[(*addr).String()] = account
+ }
+}
+
+// iterativeDump is a DumpCollector-implementation which dumps output line-by-line iteratively.
+type iterativeDump struct {
+ *json.Encoder
+}
+
+// OnAccount implements DumpCollector interface
+func (d iterativeDump) OnAccount(addr *common.Address, account DumpAccount) {
+ dumpAccount := &DumpAccount{
+ Balance: account.Balance,
+ Nonce: account.Nonce,
+ Root: account.Root,
+ CodeHash: account.CodeHash,
+ Code: account.Code,
+ Storage: account.Storage,
+ AddressHash: account.AddressHash,
+ Address: addr,
+ }
+ d.Encode(dumpAccount)
+}
+
+// OnRoot implements DumpCollector interface
+func (d iterativeDump) OnRoot(root common.Hash) {
+ d.Encode(struct {
+ Root common.Hash `json:"root"`
+ }{root})
+}
+
+// DumpToCollector iterates the state according to the given options and inserts
+// the items into a collector for aggregation or serialization.
+func (s *StateDB) DumpToCollector(c DumpCollector, conf *DumpConfig) (nextKey []byte) {
+ // Sanitize the input to allow nil configs
+ if conf == nil {
+ conf = new(DumpConfig)
+ }
+ var (
+ missingPreimages int
+ accounts uint64
+ start = time.Now()
+ logged = time.Now()
+ )
+ log.Info("Trie dumping started", "root", s.trie.Hash())
+ c.OnRoot(s.trie.Hash())
+
+ trieIt, err := s.trie.NodeIterator(conf.Start)
+ if err != nil {
+ log.Error("Trie dumping error", "err", err)
+ return nil
+ }
+ it := trie.NewIterator(trieIt)
+ for it.Next() {
+ var data types.StateAccount
+ if err := rlp.DecodeBytes(it.Value, &data); err != nil {
+ panic(err)
+ }
+ var (
+ account = DumpAccount{
+ Balance: data.Balance.String(),
+ Nonce: data.Nonce,
+ Root: data.Root[:],
+ CodeHash: data.CodeHash,
+ AddressHash: it.Key,
+ }
+ address *common.Address
+ addr common.Address
+ addrBytes = s.trie.GetKey(it.Key)
+ )
+ if addrBytes == nil {
+ missingPreimages++
+ if conf.OnlyWithAddresses {
+ continue
+ }
+ } else {
+ addr = common.BytesToAddress(addrBytes)
+ address = &addr
+ account.Address = address
+ }
+ obj := newObject(s, addr, &data)
+ if !conf.SkipCode {
+ account.Code = obj.Code()
+ }
+ if !conf.SkipStorage {
+ account.Storage = make(map[common.Hash]string)
+ tr, err := obj.getTrie()
+ if err != nil {
+ log.Error("Failed to load storage trie", "err", err)
+ continue
+ }
+ trieIt, err := tr.NodeIterator(nil)
+ if err != nil {
+ log.Error("Failed to create trie iterator", "err", err)
+ continue
+ }
+ storageIt := trie.NewIterator(trieIt)
+ for storageIt.Next() {
+ _, content, _, err := rlp.Split(storageIt.Value)
+ if err != nil {
+ log.Error("Failed to decode the value returned by iterator", "error", err)
+ continue
+ }
+ account.Storage[common.BytesToHash(s.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(content)
+ }
+ }
+ c.OnAccount(address, account)
+ accounts++
+ if time.Since(logged) > 8*time.Second {
+ log.Info("Trie dumping in progress", "at", it.Key, "accounts", accounts,
+ "elapsed", common.PrettyDuration(time.Since(start)))
+ logged = time.Now()
+ }
+ if conf.Max > 0 && accounts >= conf.Max {
+ if it.Next() {
+ nextKey = it.Key
+ }
+ break
+ }
+ }
+ if missingPreimages > 0 {
+ log.Warn("Dump incomplete due to missing preimages", "missing", missingPreimages)
+ }
+ log.Info("Trie dumping complete", "accounts", accounts,
+ "elapsed", common.PrettyDuration(time.Since(start)))
+
+ return nextKey
+}
+
+// RawDump returns the state. If the processing is aborted e.g. due to options
+// reaching Max, the `Next` key is set on the returned Dump.
+func (s *StateDB) RawDump(opts *DumpConfig) Dump {
+ dump := &Dump{
+ Accounts: make(map[string]DumpAccount),
+ }
+ dump.Next = s.DumpToCollector(dump, opts)
+ return *dump
+}
+
+// Dump returns a JSON string representing the entire state as a single json-object
+func (s *StateDB) Dump(opts *DumpConfig) []byte {
+ dump := s.RawDump(opts)
+ json, err := json.MarshalIndent(dump, "", " ")
+ if err != nil {
+ log.Error("Error dumping state", "err", err)
+ }
+ return json
+}
+
+// IterativeDump dumps out accounts as json-objects, delimited by linebreaks on stdout
+func (s *StateDB) IterativeDump(opts *DumpConfig, output *json.Encoder) {
+ s.DumpToCollector(iterativeDump{output}, opts)
+}
diff --git a/trie_by_cid/state/journal.go b/trie_by_cid/state/journal.go
index 1722fb4..6cdc1fc 100644
--- a/trie_by_cid/state/journal.go
+++ b/trie_by_cid/state/journal.go
@@ -17,9 +17,8 @@
package state
import (
- "math/big"
-
"github.com/ethereum/go-ethereum/common"
+ "github.com/holiman/uint256"
)
// journalEntry is a modification entry in the state change journal that can be
@@ -90,19 +89,26 @@ type (
account *common.Address
}
resetObjectChange struct {
+ account *common.Address
prev *stateObject
prevdestruct bool
+ prevAccount []byte
+ prevStorage map[common.Hash][]byte
+
+ prevAccountOriginExist bool
+ prevAccountOrigin []byte
+ prevStorageOrigin map[common.Hash][]byte
}
- suicideChange struct {
+ selfDestructChange struct {
account *common.Address
- prev bool // whether account had already suicided
- prevbalance *big.Int
+ prev bool // whether account had already self-destructed
+ prevbalance *uint256.Int
}
// Changes to individual accounts.
balanceChange struct {
account *common.Address
- prev *big.Int
+ prev *uint256.Int
}
nonceChange struct {
account *common.Address
@@ -159,21 +165,33 @@ func (ch resetObjectChange) revert(s *StateDB) {
if !ch.prevdestruct {
delete(s.stateObjectsDestruct, ch.prev.address)
}
+ if ch.prevAccount != nil {
+ s.accounts[ch.prev.addrHash] = ch.prevAccount
+ }
+ if ch.prevStorage != nil {
+ s.storages[ch.prev.addrHash] = ch.prevStorage
+ }
+ if ch.prevAccountOriginExist {
+ s.accountsOrigin[ch.prev.address] = ch.prevAccountOrigin
+ }
+ if ch.prevStorageOrigin != nil {
+ s.storagesOrigin[ch.prev.address] = ch.prevStorageOrigin
+ }
}
func (ch resetObjectChange) dirtied() *common.Address {
- return nil
+ return ch.account
}
-func (ch suicideChange) revert(s *StateDB) {
+func (ch selfDestructChange) revert(s *StateDB) {
obj := s.getStateObject(*ch.account)
if obj != nil {
- obj.suicided = ch.prev
+ obj.selfDestructed = ch.prev
obj.setBalance(ch.prevbalance)
}
}
-func (ch suicideChange) dirtied() *common.Address {
+func (ch selfDestructChange) dirtied() *common.Address {
return ch.account
}
diff --git a/trie_by_cid/state/metrics.go b/trie_by_cid/state/metrics.go
index e702ef3..64c6514 100644
--- a/trie_by_cid/state/metrics.go
+++ b/trie_by_cid/state/metrics.go
@@ -27,4 +27,11 @@ var (
storageTriesUpdatedMeter = metrics.NewRegisteredMeter("state/update/storagenodes", nil)
accountTrieDeletedMeter = metrics.NewRegisteredMeter("state/delete/accountnodes", nil)
storageTriesDeletedMeter = metrics.NewRegisteredMeter("state/delete/storagenodes", nil)
+
+ slotDeletionMaxCount = metrics.NewRegisteredGauge("state/delete/storage/max/slot", nil)
+ slotDeletionMaxSize = metrics.NewRegisteredGauge("state/delete/storage/max/size", nil)
+ slotDeletionTimer = metrics.NewRegisteredResettingTimer("state/delete/storage/timer", nil)
+ slotDeletionCount = metrics.NewRegisteredMeter("state/delete/storage/slot", nil)
+ slotDeletionSize = metrics.NewRegisteredMeter("state/delete/storage/size", nil)
+ slotDeletionSkip = metrics.NewRegisteredGauge("state/delete/storage/skip", nil)
)
diff --git a/trie_by_cid/state/state_object.go b/trie_by_cid/state/state_object.go
index e1afcb3..9196757 100644
--- a/trie_by_cid/state/state_object.go
+++ b/trie_by_cid/state/state_object.go
@@ -20,7 +20,6 @@ import (
"bytes"
"fmt"
"io"
- "math/big"
"time"
"github.com/ethereum/go-ethereum/common"
@@ -28,7 +27,9 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/rlp"
- // "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
+ "github.com/holiman/uint256"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
)
type Code []byte
@@ -57,54 +58,64 @@ func (s Storage) Copy() Storage {
// stateObject represents an Ethereum account which is being modified.
//
// The usage pattern is as follows:
-// First you need to obtain a state object.
-// Account values can be accessed and modified through the object.
+// - First you need to obtain a state object.
+// - Account values as well as storages can be accessed and modified through the object.
+// - Finally, call commit to return the changes of storage trie and update account data.
type stateObject struct {
- address common.Address
- addrHash common.Hash // hash of ethereum address of the account
- data types.StateAccount
db *StateDB
+ address common.Address // address of ethereum account
+ addrHash common.Hash // hash of ethereum address of the account
+ origin *types.StateAccount // Account original data without any change applied, nil means it was not existent
+ data types.StateAccount // Account data with all mutations applied in the scope of block
// Write caches.
trie Trie // storage trie, which becomes non-nil on first access
code Code // contract bytecode, which gets set when code is loaded
- originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction
+ originStorage Storage // Storage cache of original entries to dedup rewrites
pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block
- dirtyStorage Storage // Storage entries that have been modified in the current transaction execution
+ dirtyStorage Storage // Storage entries that have been modified in the current transaction execution, reset for every transaction
// Cache flags.
- // When an object is marked suicided it will be deleted from the trie
- // during the "update" phase of the state transition.
dirtyCode bool // true if the code was updated
- suicided bool
- deleted bool
+
+ // Flag whether the account was marked as self-destructed. The self-destructed account
+ // is still accessible in the scope of same transaction.
+ selfDestructed bool
+
+ // Flag whether the account was marked as deleted. A self-destructed account
+ // or an account that is considered as empty will be marked as deleted at
+ // the end of transaction and no longer accessible anymore.
+ deleted bool
+
+ // Flag whether the object was created in the current transaction
+ created bool
}
// empty returns whether the account is considered empty.
func (s *stateObject) empty() bool {
- return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, types.EmptyCodeHash.Bytes())
+ return s.data.Nonce == 0 && s.data.Balance.IsZero() && bytes.Equal(s.data.CodeHash, types.EmptyCodeHash.Bytes())
}
// newObject creates a state object.
-func newObject(db *StateDB, address common.Address, data types.StateAccount) *stateObject {
- if data.Balance == nil {
- data.Balance = new(big.Int)
- }
- if data.CodeHash == nil {
- data.CodeHash = types.EmptyCodeHash.Bytes()
- }
- if data.Root == (common.Hash{}) {
- data.Root = types.EmptyRootHash
+func newObject(db *StateDB, address common.Address, acct *types.StateAccount) *stateObject {
+ var (
+ origin = acct
+ created = acct == nil // true if the account was not existent
+ )
+ if acct == nil {
+ acct = types.NewEmptyStateAccount()
}
return &stateObject{
db: db,
address: address,
addrHash: crypto.Keccak256Hash(address[:]),
- data: data,
+ origin: origin,
+ data: *acct,
originStorage: make(Storage),
pendingStorage: make(Storage),
dirtyStorage: make(Storage),
+ created: created,
}
}
@@ -113,8 +124,8 @@ func (s *stateObject) EncodeRLP(w io.Writer) error {
return rlp.Encode(w, &s.data)
}
-func (s *stateObject) markSuicided() {
- s.suicided = true
+func (s *stateObject) markSelfdestructed() {
+ s.selfDestructed = true
}
func (s *stateObject) touch() {
@@ -131,17 +142,15 @@ func (s *stateObject) touch() {
// getTrie returns the associated storage trie. The trie will be opened
// if it's not loaded previously. An error will be returned if trie can't
// be loaded.
-func (s *stateObject) getTrie(db Database) (Trie, error) {
+func (s *stateObject) getTrie() (Trie, error) {
if s.trie == nil {
// Try fetching from prefetcher first
- // We don't prefetch empty tries
if s.data.Root != types.EmptyRootHash && s.db.prefetcher != nil {
- // When the miner is creating the pending state, there is no
- // prefetcher
+ // When the miner is creating the pending state, there is no prefetcher
s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root)
}
if s.trie == nil {
- tr, err := db.OpenStorageTrie(s.db.originalRoot, s.addrHash, s.data.Root)
+ tr, err := s.db.db.OpenStorageTrie(s.db.originalRoot, s.addrHash, s.data.Root, s.db.trie)
if err != nil {
return nil, err
}
@@ -152,18 +161,18 @@ func (s *stateObject) getTrie(db Database) (Trie, error) {
}
// GetState retrieves a value from the account storage trie.
-func (s *stateObject) GetState(db Database, key common.Hash) common.Hash {
+func (s *stateObject) GetState(key common.Hash) common.Hash {
// If we have a dirty value for this state entry, return it
value, dirty := s.dirtyStorage[key]
if dirty {
return value
}
// Otherwise return the entry's original value
- return s.GetCommittedState(db, key)
+ return s.GetCommittedState(key)
}
// GetCommittedState retrieves a value from the committed account storage trie.
-func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash {
+func (s *stateObject) GetCommittedState(key common.Hash) common.Hash {
// If we have a pending write or clean cached, return that
if value, pending := s.pendingStorage[key]; pending {
return value
@@ -182,8 +191,9 @@ func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Has
}
// If no live objects are available, attempt to use snapshots
var (
- enc []byte
- err error
+ enc []byte
+ err error
+ value common.Hash
)
if s.db.snap != nil {
start := time.Now()
@@ -191,16 +201,23 @@ func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Has
if metrics.EnabledExpensive {
s.db.SnapshotStorageReads += time.Since(start)
}
+ if len(enc) > 0 {
+ _, content, _, err := rlp.Split(enc)
+ if err != nil {
+ s.db.setError(err)
+ }
+ value.SetBytes(content)
+ }
}
// If the snapshot is unavailable or reading from it fails, load from the database.
if s.db.snap == nil || err != nil {
start := time.Now()
- tr, err := s.getTrie(db)
+ tr, err := s.getTrie()
if err != nil {
s.db.setError(err)
return common.Hash{}
}
- enc, err = tr.TryGet(key.Bytes())
+ val, err := tr.GetStorage(s.address, key.Bytes())
if metrics.EnabledExpensive {
s.db.StorageReads += time.Since(start)
}
@@ -208,23 +225,16 @@ func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Has
s.db.setError(err)
return common.Hash{}
}
- }
- var value common.Hash
- if len(enc) > 0 {
- _, content, _, err := rlp.Split(enc)
- if err != nil {
- s.db.setError(err)
- }
- value.SetBytes(content)
+ value.SetBytes(val)
}
s.originStorage[key] = value
return value
}
// SetState updates a value in account storage.
-func (s *stateObject) SetState(db Database, key, value common.Hash) {
+func (s *stateObject) SetState(key, value common.Hash) {
// If the new value is the same as old, don't set
- prev := s.GetState(db, key)
+ prev := s.GetState(key)
if prev == value {
return
}
@@ -252,19 +262,24 @@ func (s *stateObject) finalise(prefetch bool) {
}
}
if s.db.prefetcher != nil && prefetch && len(slotsToPrefetch) > 0 && s.data.Root != types.EmptyRootHash {
- s.db.prefetcher.prefetch(s.addrHash, s.data.Root, slotsToPrefetch)
+ s.db.prefetcher.prefetch(s.addrHash, s.data.Root, s.address, slotsToPrefetch)
}
if len(s.dirtyStorage) > 0 {
s.dirtyStorage = make(Storage)
}
}
-// updateTrie writes cached storage modifications into the object's storage trie.
-// It will return nil if the trie has not been loaded and no changes have been
-// made. An error will be returned if the trie can't be loaded/updated correctly.
-func (s *stateObject) updateTrie(db Database) (Trie, error) {
+// updateTrie is responsible for persisting cached storage changes into the
+// object's storage trie. In case the storage trie is not yet loaded, this
+// function will load the trie automatically. If any issues arise during the
+// loading or updating of the trie, an error will be returned. Furthermore,
+// this function will return the mutated storage trie, or nil if there is no
+// storage change at all.
+func (s *stateObject) updateTrie() (Trie, error) {
// Make sure all dirty slots are finalized into the pending storage area
- s.finalise(false) // Don't prefetch anymore, pull directly if need be
+ s.finalise(false)
+
+ // Short circuit if nothing changed, don't bother with hashing anything
if len(s.pendingStorage) == 0 {
return s.trie, nil
}
@@ -275,69 +290,84 @@ func (s *stateObject) updateTrie(db Database) (Trie, error) {
// The snapshot storage map for the object
var (
storage map[common.Hash][]byte
- hasher = s.db.hasher
+ origin map[common.Hash][]byte
)
- tr, err := s.getTrie(db)
+ tr, err := s.getTrie()
if err != nil {
s.db.setError(err)
return nil, err
}
- // Insert all the pending updates into the trie
+ // Insert all the pending storage updates into the trie
usedStorage := make([][]byte, 0, len(s.pendingStorage))
for key, value := range s.pendingStorage {
// Skip noop changes, persist actual changes
if value == s.originStorage[key] {
continue
}
+ prev := s.originStorage[key]
s.originStorage[key] = value
- var v []byte
+ var encoded []byte // rlp-encoded value to be used by the snapshot
if (value == common.Hash{}) {
- if err := tr.TryDelete(key[:]); err != nil {
+ if err := tr.DeleteStorage(s.address, key[:]); err != nil {
s.db.setError(err)
return nil, err
}
s.db.StorageDeleted += 1
} else {
// Encoding []byte cannot fail, ok to ignore the error.
- v, _ = rlp.EncodeToBytes(common.TrimLeftZeroes(value[:]))
- if err := tr.TryUpdate(key[:], v); err != nil {
+ trimmed := common.TrimLeftZeroes(value[:])
+ encoded, _ = rlp.EncodeToBytes(trimmed)
+ if err := tr.UpdateStorage(s.address, key[:], trimmed); err != nil {
s.db.setError(err)
return nil, err
}
s.db.StorageUpdated += 1
}
- // If state snapshotting is active, cache the data til commit
- if s.db.snap != nil {
- if storage == nil {
- // Retrieve the old storage map, if available, create a new one otherwise
- if storage = s.db.snapStorage[s.addrHash]; storage == nil {
- storage = make(map[common.Hash][]byte)
- s.db.snapStorage[s.addrHash] = storage
- }
+ // Cache the mutated storage slots until commit
+ if storage == nil {
+ if storage = s.db.storages[s.addrHash]; storage == nil {
+ storage = make(map[common.Hash][]byte)
+ s.db.storages[s.addrHash] = storage
}
- storage[crypto.HashData(hasher, key[:])] = v // v will be nil if it's deleted
}
+ khash := crypto.HashData(s.db.hasher, key[:])
+ storage[khash] = encoded // encoded will be nil if it's deleted
+
+ // Cache the original value of mutated storage slots
+ if origin == nil {
+ if origin = s.db.storagesOrigin[s.address]; origin == nil {
+ origin = make(map[common.Hash][]byte)
+ s.db.storagesOrigin[s.address] = origin
+ }
+ }
+ // Track the original value of the slot only if it's mutated for the first time
+ if _, ok := origin[khash]; !ok {
+ if prev == (common.Hash{}) {
+ origin[khash] = nil // nil if it was not present previously
+ } else {
+ // Encoding []byte cannot fail, ok to ignore the error.
+ b, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(prev[:]))
+ origin[khash] = b
+ }
+ }
+ // Cache the items for preloading
usedStorage = append(usedStorage, common.CopyBytes(key[:])) // Copy needed for closure
}
if s.db.prefetcher != nil {
s.db.prefetcher.used(s.addrHash, s.data.Root, usedStorage)
}
- if len(s.pendingStorage) > 0 {
- s.pendingStorage = make(Storage)
- }
+ s.pendingStorage = make(Storage) // reset pending map
return tr, nil
}
-// UpdateRoot sets the trie root to the current root hash of. An error
-// will be returned if trie root hash is not computed correctly.
-func (s *stateObject) updateRoot(db Database) {
- tr, err := s.updateTrie(db)
- if err != nil {
- return
- }
- // If nothing changed, don't bother with hashing anything
- if tr == nil {
+// updateRoot flushes all cached storage mutations to trie, recalculating the
+// new storage trie root.
+func (s *stateObject) updateRoot() {
+ // Flush cached storage mutations into trie, short circuit if any error
+ // occurs or there is no change in the trie.
+ tr, err := s.updateTrie()
+ if err != nil || tr == nil {
return
}
// Track the amount of time wasted on hashing the storage trie
@@ -347,54 +377,87 @@ func (s *stateObject) updateRoot(db Database) {
s.data.Root = tr.Hash()
}
+// commit obtains a set of dirty storage trie nodes and updates the account data.
+// The returned set can be nil if nothing to commit. This function assumes all
+// storage mutations have already been flushed into trie by updateRoot.
+func (s *stateObject) commit() (*trienode.NodeSet, error) {
+ // Short circuit if trie is not even loaded, don't bother with committing anything
+ if s.trie == nil {
+ s.origin = s.data.Copy()
+ return nil, nil
+ }
+ // Track the amount of time wasted on committing the storage trie
+ if metrics.EnabledExpensive {
+ defer func(start time.Time) { s.db.StorageCommits += time.Since(start) }(time.Now())
+ }
+ // The trie is currently in an open state and could potentially contain
+ // cached mutations. Call commit to acquire a set of nodes that have been
+ // modified, the set can be nil if nothing to commit.
+ root, nodes, err := s.trie.Commit(false)
+ if err != nil {
+ return nil, err
+ }
+ s.data.Root = root
+
+ // Update original account data after commit
+ s.origin = s.data.Copy()
+ return nodes, nil
+}
+
// AddBalance adds amount to s's balance.
// It is used to add funds to the destination account of a transfer.
-func (s *stateObject) AddBalance(amount *big.Int) {
+func (s *stateObject) AddBalance(amount *uint256.Int) {
// EIP161: We must check emptiness for the objects such that the account
// clearing (0,0,0 objects) can take effect.
- if amount.Sign() == 0 {
+ if amount.IsZero() {
if s.empty() {
s.touch()
}
return
}
- s.SetBalance(new(big.Int).Add(s.Balance(), amount))
+ s.SetBalance(new(uint256.Int).Add(s.Balance(), amount))
}
// SubBalance removes amount from s's balance.
// It is used to remove funds from the origin account of a transfer.
-func (s *stateObject) SubBalance(amount *big.Int) {
- if amount.Sign() == 0 {
+func (s *stateObject) SubBalance(amount *uint256.Int) {
+ if amount.IsZero() {
return
}
- s.SetBalance(new(big.Int).Sub(s.Balance(), amount))
+ s.SetBalance(new(uint256.Int).Sub(s.Balance(), amount))
}
-func (s *stateObject) SetBalance(amount *big.Int) {
+func (s *stateObject) SetBalance(amount *uint256.Int) {
s.db.journal.append(balanceChange{
account: &s.address,
- prev: new(big.Int).Set(s.data.Balance),
+ prev: new(uint256.Int).Set(s.data.Balance),
})
s.setBalance(amount)
}
-func (s *stateObject) setBalance(amount *big.Int) {
+func (s *stateObject) setBalance(amount *uint256.Int) {
s.data.Balance = amount
}
func (s *stateObject) deepCopy(db *StateDB) *stateObject {
- stateObject := newObject(db, s.address, s.data)
- if s.trie != nil {
- stateObject.trie = db.db.CopyTrie(s.trie)
+ obj := &stateObject{
+ db: db,
+ address: s.address,
+ addrHash: s.addrHash,
+ origin: s.origin,
+ data: s.data,
}
- stateObject.code = s.code
- stateObject.dirtyStorage = s.dirtyStorage.Copy()
- stateObject.originStorage = s.originStorage.Copy()
- stateObject.pendingStorage = s.pendingStorage.Copy()
- stateObject.suicided = s.suicided
- stateObject.dirtyCode = s.dirtyCode
- stateObject.deleted = s.deleted
- return stateObject
+ if s.trie != nil {
+ obj.trie = db.db.CopyTrie(s.trie)
+ }
+ obj.code = s.code
+ obj.dirtyStorage = s.dirtyStorage.Copy()
+ obj.originStorage = s.originStorage.Copy()
+ obj.pendingStorage = s.pendingStorage.Copy()
+ obj.selfDestructed = s.selfDestructed
+ obj.dirtyCode = s.dirtyCode
+ obj.deleted = s.deleted
+ return obj
}
//
@@ -407,14 +470,14 @@ func (s *stateObject) Address() common.Address {
}
// Code returns the contract code associated with this object, if any.
-func (s *stateObject) Code(db Database) []byte {
+func (s *stateObject) Code() []byte {
if s.code != nil {
return s.code
}
if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return nil
}
- code, err := db.ContractCode(common.BytesToHash(s.CodeHash()))
+ code, err := s.db.db.ContractCode(s.address, common.BytesToHash(s.CodeHash()))
if err != nil {
s.db.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err))
}
@@ -425,14 +488,14 @@ func (s *stateObject) Code(db Database) []byte {
// CodeSize returns the size of the contract code associated with this object,
// or zero if none. This method is an almost mirror of Code, but uses a cache
// inside the database to avoid loading codes seen recently.
-func (s *stateObject) CodeSize(db Database) int {
+func (s *stateObject) CodeSize() int {
if s.code != nil {
return len(s.code)
}
if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return 0
}
- size, err := db.ContractCodeSize(common.BytesToHash(s.CodeHash()))
+ size, err := s.db.db.ContractCodeSize(s.address, common.BytesToHash(s.CodeHash()))
if err != nil {
s.db.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err))
}
@@ -440,7 +503,7 @@ func (s *stateObject) CodeSize(db Database) int {
}
func (s *stateObject) SetCode(codeHash common.Hash, code []byte) {
- prevcode := s.Code(s.db.db)
+ prevcode := s.Code()
s.db.journal.append(codeChange{
account: &s.address,
prevhash: s.CodeHash(),
@@ -471,10 +534,14 @@ func (s *stateObject) CodeHash() []byte {
return s.data.CodeHash
}
-func (s *stateObject) Balance() *big.Int {
+func (s *stateObject) Balance() *uint256.Int {
return s.data.Balance
}
func (s *stateObject) Nonce() uint64 {
return s.data.Nonce
}
+
+func (s *stateObject) Root() common.Hash {
+ return s.data.Root
+}
diff --git a/trie_by_cid/state/state_test.go b/trie_by_cid/state/state_test.go
index f3e5c81..03756e5 100644
--- a/trie_by_cid/state/state_test.go
+++ b/trie_by_cid/state/state_test.go
@@ -19,15 +19,15 @@ package state
import (
"bytes"
"context"
- "math/big"
"testing"
pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0"
"github.com/cerc-io/plugeth-statediff/indexer/database/sql/postgres"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
+ "github.com/holiman/uint256"
"github.com/cerc-io/ipld-eth-statedb/internal"
)
@@ -38,33 +38,38 @@ var (
teardownStatements = []string{`TRUNCATE ipld.blocks`}
)
-type stateTest struct {
+type stateEnv struct {
db ethdb.Database
state *StateDB
}
-func newStateTest(t *testing.T) *stateTest {
+func newStateEnv(t *testing.T) *stateEnv {
+ db := newPgIpfsEthdb(t)
+ sdb, err := New(types.EmptyRootHash, NewDatabase(db), nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return &stateEnv{db: db, state: sdb}
+}
+
+func newPgIpfsEthdb(t *testing.T) ethdb.Database {
pool, err := postgres.ConnectSQLX(testCtx, testConfig)
if err != nil {
t.Fatal(err)
}
db := pgipfsethdb.NewDatabase(pool, internal.MakeCacheConfig(t))
- sdb, err := New(common.Hash{}, NewDatabase(db), nil)
- if err != nil {
- t.Fatal(err)
- }
- return &stateTest{db: db, state: sdb}
+ return db
}
func TestNull(t *testing.T) {
- s := newStateTest(t)
+ s := newStateEnv(t)
address := common.HexToAddress("0x823140710bf13990e4500136726d8b55")
s.state.CreateAccount(address)
//value := common.FromHex("0x823140710bf13990e4500136726d8b55")
var value common.Hash
s.state.SetState(address, common.Hash{}, value)
- // s.state.Commit(false)
+ // s.state.Commit(0, false)
if value := s.state.GetState(address, common.Hash{}); value != (common.Hash{}) {
t.Errorf("expected empty current value, got %x", value)
@@ -79,7 +84,7 @@ func TestSnapshot(t *testing.T) {
var storageaddr common.Hash
data1 := common.BytesToHash([]byte{42})
data2 := common.BytesToHash([]byte{43})
- s := newStateTest(t)
+ s := newStateEnv(t)
// snapshot the genesis state
genesis := s.state.Snapshot()
@@ -110,12 +115,12 @@ func TestSnapshot(t *testing.T) {
}
func TestSnapshotEmpty(t *testing.T) {
- s := newStateTest(t)
+ s := newStateEnv(t)
s.state.RevertToSnapshot(s.state.Snapshot())
}
func TestSnapshot2(t *testing.T) {
- state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
+ state, _ := New(types.EmptyRootHash, NewDatabase(newPgIpfsEthdb(t)), nil)
stateobjaddr0 := common.BytesToAddress([]byte("so0"))
stateobjaddr1 := common.BytesToAddress([]byte("so1"))
@@ -129,22 +134,22 @@ func TestSnapshot2(t *testing.T) {
// db, trie are already non-empty values
so0 := state.getStateObject(stateobjaddr0)
- so0.SetBalance(big.NewInt(42))
+ so0.SetBalance(uint256.NewInt(42))
so0.SetNonce(43)
so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'})
- so0.suicided = false
+ so0.selfDestructed = false
so0.deleted = false
state.setStateObject(so0)
- // root, _ := state.Commit(false)
+ // root, _ := state.Commit(0, false)
// state, _ = New(root, state.db, state.snaps)
// and one with deleted == true
so1 := state.getStateObject(stateobjaddr1)
- so1.SetBalance(big.NewInt(52))
+ so1.SetBalance(uint256.NewInt(52))
so1.SetNonce(53)
so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'})
- so1.suicided = true
+ so1.selfDestructed = true
so1.deleted = true
state.setStateObject(so1)
@@ -158,8 +163,8 @@ func TestSnapshot2(t *testing.T) {
so0Restored := state.getStateObject(stateobjaddr0)
// Update lazily-loaded values before comparing.
- so0Restored.GetState(state.db, storageaddr)
- so0Restored.Code(state.db)
+ so0Restored.GetState(storageaddr)
+ so0Restored.Code()
// non-deleted is equal (restored)
compareStateObjects(so0Restored, so0, t)
diff --git a/trie_by_cid/state/statedb.go b/trie_by_cid/state/statedb.go
index b4c4369..e70f17c 100644
--- a/trie_by_cid/state/statedb.go
+++ b/trie_by_cid/state/statedb.go
@@ -18,21 +18,29 @@
package state
import (
- "errors"
"fmt"
- "math/big"
"sort"
"time"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state/snapshot"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/params"
- "github.com/ethereum/go-ethereum/rlp"
+ "github.com/holiman/uint256"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
+)
+
+const (
+ // storageDeleteLimit denotes the highest permissible memory allocation
+ // employed for contract storage deletion.
+ storageDeleteLimit = 512 * 1024 * 1024
)
type revision struct {
@@ -40,42 +48,42 @@ type revision struct {
journalIndex int
}
-type proofList [][]byte
-
-func (n *proofList) Put(key []byte, value []byte) error {
- *n = append(*n, value)
- return nil
-}
-
-func (n *proofList) Delete(key []byte) error {
- panic("not supported")
-}
-
// StateDB structs within the ethereum protocol are used to store anything
// within the merkle trie. StateDBs take care of caching and storing
// nested states. It's the general query interface to retrieve:
+//
// * Contracts
// * Accounts
+//
+// Once the state is committed, tries cached in stateDB (including account
+// trie, storage tries) will no longer be functional. A new state instance
+// must be created with new root and updated database for accessing post-
+// commit states.
type StateDB struct {
db Database
prefetcher *triePrefetcher
trie Trie
hasher crypto.KeccakState
+ snaps *snapshot.Tree // Nil if snapshot is not available
+ snap snapshot.Snapshot // Nil if snapshot is not available
// originalRoot is the pre-state root, before any changes were made.
// It will be updated when the Commit is called.
originalRoot common.Hash
- snaps *snapshot.Tree
- snap snapshot.Snapshot
- snapAccounts map[common.Hash][]byte
- snapStorage map[common.Hash]map[common.Hash][]byte
+ // These maps hold the state changes (including the corresponding
+ // original value) that occurred in this **block**.
+ accounts map[common.Hash][]byte // The mutated accounts in 'slim RLP' encoding
+ storages map[common.Hash]map[common.Hash][]byte // The mutated slots in prefix-zero trimmed rlp format
+ accountsOrigin map[common.Address][]byte // The original value of mutated accounts in 'slim RLP' encoding
+ storagesOrigin map[common.Address]map[common.Hash][]byte // The original value of mutated slots in prefix-zero trimmed rlp format
- // This map holds 'live' objects, which will get modified while processing a state transition.
+ // This map holds 'live' objects, which will get modified while processing
+ // a state transition.
stateObjects map[common.Address]*stateObject
- stateObjectsPending map[common.Address]struct{} // State objects finalized but not yet written to the trie
- stateObjectsDirty map[common.Address]struct{} // State objects modified in the current execution
- stateObjectsDestruct map[common.Address]struct{} // State objects destructed in the block
+ stateObjectsPending map[common.Address]struct{} // State objects finalized but not yet written to the trie
+ stateObjectsDirty map[common.Address]struct{} // State objects modified in the current execution
+ stateObjectsDestruct map[common.Address]*types.StateAccount // State objects destructed in the block along with its previous value
// DB error.
// State objects are used by the consensus core and VM which are
@@ -89,11 +97,13 @@ type StateDB struct {
// The refund counter, also used by state transitioning.
refund uint64
+ // The tx context and all occurred logs in the scope of transaction.
thash common.Hash
txIndex int
logs map[common.Hash][]*types.Log
logSize uint
+ // Preimages occurred seen by VM in the scope of block.
preimages map[common.Hash][]byte
// Per-transaction access list
@@ -112,16 +122,23 @@ type StateDB struct {
AccountReads time.Duration
AccountHashes time.Duration
AccountUpdates time.Duration
+ AccountCommits time.Duration
StorageReads time.Duration
StorageHashes time.Duration
StorageUpdates time.Duration
+ StorageCommits time.Duration
SnapshotAccountReads time.Duration
SnapshotStorageReads time.Duration
+ SnapshotCommits time.Duration
+ TrieDBCommits time.Duration
AccountUpdated int
StorageUpdated int
AccountDeleted int
StorageDeleted int
+
+ // Testing hooks
+ onCommit func(states *triestate.Set) // Hook invoked when commit is performed
}
// New creates a new state from a given trie.
@@ -135,10 +152,14 @@ func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error)
trie: tr,
originalRoot: root,
snaps: snaps,
+ accounts: make(map[common.Hash][]byte),
+ storages: make(map[common.Hash]map[common.Hash][]byte),
+ accountsOrigin: make(map[common.Address][]byte),
+ storagesOrigin: make(map[common.Address]map[common.Hash][]byte),
stateObjects: make(map[common.Address]*stateObject),
stateObjectsPending: make(map[common.Address]struct{}),
stateObjectsDirty: make(map[common.Address]struct{}),
- stateObjectsDestruct: make(map[common.Address]struct{}),
+ stateObjectsDestruct: make(map[common.Address]*types.StateAccount),
logs: make(map[common.Hash][]*types.Log),
preimages: make(map[common.Hash][]byte),
journal: newJournal(),
@@ -147,10 +168,7 @@ func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error)
hasher: crypto.NewKeccakState(),
}
if sdb.snaps != nil {
- if sdb.snap = sdb.snaps.Snapshot(root); sdb.snap != nil {
- sdb.snapAccounts = make(map[common.Hash][]byte)
- sdb.snapStorage = make(map[common.Hash]map[common.Hash][]byte)
- }
+ sdb.snap = sdb.snaps.Snapshot(root)
}
return sdb, nil
}
@@ -250,7 +268,7 @@ func (s *StateDB) SubRefund(gas uint64) {
}
// Exist reports whether the given account address exists in the state.
-// Notably this also returns true for suicided accounts.
+// Notably this also returns true for self-destructed accounts.
func (s *StateDB) Exist(addr common.Address) bool {
return s.getStateObject(addr) != nil
}
@@ -263,14 +281,15 @@ func (s *StateDB) Empty(addr common.Address) bool {
}
// GetBalance retrieves the balance from the given address or 0 if object not found
-func (s *StateDB) GetBalance(addr common.Address) *big.Int {
+func (s *StateDB) GetBalance(addr common.Address) *uint256.Int {
stateObject := s.getStateObject(addr)
if stateObject != nil {
return stateObject.Balance()
}
- return common.Big0
+ return common.U2560
}
+// GetNonce retrieves the nonce from the given address or 0 if object not found
func (s *StateDB) GetNonce(addr common.Address) uint64 {
stateObject := s.getStateObject(addr)
if stateObject != nil {
@@ -280,6 +299,16 @@ func (s *StateDB) GetNonce(addr common.Address) uint64 {
return 0
}
+// GetStorageRoot retrieves the storage root from the given address or empty
+// if object not found.
+func (s *StateDB) GetStorageRoot(addr common.Address) common.Hash {
+ stateObject := s.getStateObject(addr)
+ if stateObject != nil {
+ return stateObject.Root()
+ }
+ return common.Hash{}
+}
+
// TxIndex returns the current transaction index set by Prepare.
func (s *StateDB) TxIndex() int {
return s.txIndex
@@ -288,7 +317,7 @@ func (s *StateDB) TxIndex() int {
func (s *StateDB) GetCode(addr common.Address) []byte {
stateObject := s.getStateObject(addr)
if stateObject != nil {
- return stateObject.Code(s.db)
+ return stateObject.Code()
}
return nil
}
@@ -296,62 +325,33 @@ func (s *StateDB) GetCode(addr common.Address) []byte {
func (s *StateDB) GetCodeSize(addr common.Address) int {
stateObject := s.getStateObject(addr)
if stateObject != nil {
- return stateObject.CodeSize(s.db)
+ return stateObject.CodeSize()
}
return 0
}
func (s *StateDB) GetCodeHash(addr common.Address) common.Hash {
stateObject := s.getStateObject(addr)
- if stateObject == nil {
- return common.Hash{}
+ if stateObject != nil {
+ return common.BytesToHash(stateObject.CodeHash())
}
- return common.BytesToHash(stateObject.CodeHash())
+ return common.Hash{}
}
// GetState retrieves a value from the given account's storage trie.
func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash {
stateObject := s.getStateObject(addr)
if stateObject != nil {
- return stateObject.GetState(s.db, hash)
+ return stateObject.GetState(hash)
}
return common.Hash{}
}
-// GetProof returns the Merkle proof for a given account.
-func (s *StateDB) GetProof(addr common.Address) ([][]byte, error) {
- return s.GetProofByHash(crypto.Keccak256Hash(addr.Bytes()))
-}
-
-// GetProofByHash returns the Merkle proof for a given account.
-func (s *StateDB) GetProofByHash(addrHash common.Hash) ([][]byte, error) {
- var proof proofList
- err := s.trie.Prove(addrHash[:], 0, &proof)
- return proof, err
-}
-
-// GetStorageProof returns the Merkle proof for given storage slot.
-func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, error) {
- trie, err := s.StorageTrie(a)
- if err != nil {
- return nil, err
- }
- if trie == nil {
- return nil, errors.New("storage trie for requested address does not exist")
- }
- var proof proofList
- err = trie.Prove(crypto.Keccak256(key.Bytes()), 0, &proof)
- if err != nil {
- return nil, err
- }
- return proof, nil
-}
-
// GetCommittedState retrieves a value from the given account's committed storage trie.
func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash {
stateObject := s.getStateObject(addr)
if stateObject != nil {
- return stateObject.GetCommittedState(s.db, hash)
+ return stateObject.GetCommittedState(hash)
}
return common.Hash{}
}
@@ -361,25 +361,10 @@ func (s *StateDB) Database() Database {
return s.db
}
-// StorageTrie returns the storage trie of an account. The return value is a copy
-// and is nil for non-existent accounts. An error will be returned if storage trie
-// is existent but can't be loaded correctly.
-func (s *StateDB) StorageTrie(addr common.Address) (Trie, error) {
- stateObject := s.getStateObject(addr)
- if stateObject == nil {
- return nil, nil
- }
- cpy := stateObject.deepCopy(s)
- if _, err := cpy.updateTrie(s.db); err != nil {
- return nil, err
- }
- return cpy.getTrie(s.db)
-}
-
-func (s *StateDB) HasSuicided(addr common.Address) bool {
+func (s *StateDB) HasSelfDestructed(addr common.Address) bool {
stateObject := s.getStateObject(addr)
if stateObject != nil {
- return stateObject.suicided
+ return stateObject.selfDestructed
}
return false
}
@@ -389,83 +374,98 @@ func (s *StateDB) HasSuicided(addr common.Address) bool {
*/
// AddBalance adds amount to the account associated with addr.
-func (s *StateDB) AddBalance(addr common.Address, amount *big.Int) {
- stateObject := s.GetOrNewStateObject(addr)
+func (s *StateDB) AddBalance(addr common.Address, amount *uint256.Int) {
+ stateObject := s.getOrNewStateObject(addr)
if stateObject != nil {
stateObject.AddBalance(amount)
}
}
// SubBalance subtracts amount from the account associated with addr.
-func (s *StateDB) SubBalance(addr common.Address, amount *big.Int) {
- stateObject := s.GetOrNewStateObject(addr)
+func (s *StateDB) SubBalance(addr common.Address, amount *uint256.Int) {
+ stateObject := s.getOrNewStateObject(addr)
if stateObject != nil {
stateObject.SubBalance(amount)
}
}
-func (s *StateDB) SetBalance(addr common.Address, amount *big.Int) {
- stateObject := s.GetOrNewStateObject(addr)
+func (s *StateDB) SetBalance(addr common.Address, amount *uint256.Int) {
+ stateObject := s.getOrNewStateObject(addr)
if stateObject != nil {
stateObject.SetBalance(amount)
}
}
func (s *StateDB) SetNonce(addr common.Address, nonce uint64) {
- stateObject := s.GetOrNewStateObject(addr)
+ stateObject := s.getOrNewStateObject(addr)
if stateObject != nil {
stateObject.SetNonce(nonce)
}
}
func (s *StateDB) SetCode(addr common.Address, code []byte) {
- stateObject := s.GetOrNewStateObject(addr)
+ stateObject := s.getOrNewStateObject(addr)
if stateObject != nil {
stateObject.SetCode(crypto.Keccak256Hash(code), code)
}
}
func (s *StateDB) SetState(addr common.Address, key, value common.Hash) {
- stateObject := s.GetOrNewStateObject(addr)
+ stateObject := s.getOrNewStateObject(addr)
if stateObject != nil {
- stateObject.SetState(s.db, key, value)
+ stateObject.SetState(key, value)
}
}
// SetStorage replaces the entire storage for the specified account with given
-// storage. This function should only be used for debugging.
+// storage. This function should only be used for debugging and the mutations
+// must be discarded afterwards.
func (s *StateDB) SetStorage(addr common.Address, storage map[common.Hash]common.Hash) {
// SetStorage needs to wipe existing storage. We achieve this by pretending
// that the account self-destructed earlier in this block, by flagging
// it in stateObjectsDestruct. The effect of doing so is that storage lookups
// will not hit disk, since it is assumed that the disk-data is belonging
// to a previous incarnation of the object.
- s.stateObjectsDestruct[addr] = struct{}{}
- stateObject := s.GetOrNewStateObject(addr)
+ //
+ // TODO(rjl493456442) this function should only be supported by 'unwritable'
+ // state and all mutations made should all be discarded afterwards.
+ if _, ok := s.stateObjectsDestruct[addr]; !ok {
+ s.stateObjectsDestruct[addr] = nil
+ }
+ stateObject := s.getOrNewStateObject(addr)
for k, v := range storage {
- stateObject.SetState(s.db, k, v)
+ stateObject.SetState(k, v)
}
}
-// Suicide marks the given account as suicided.
+// SelfDestruct marks the given account as selfdestructed.
// This clears the account balance.
//
// The account's state object is still available until the state is committed,
-// getStateObject will return a non-nil account after Suicide.
-func (s *StateDB) Suicide(addr common.Address) bool {
+// getStateObject will return a non-nil account after SelfDestruct.
+func (s *StateDB) SelfDestruct(addr common.Address) {
stateObject := s.getStateObject(addr)
if stateObject == nil {
- return false
+ return
}
- s.journal.append(suicideChange{
+ s.journal.append(selfDestructChange{
account: &addr,
- prev: stateObject.suicided,
- prevbalance: new(big.Int).Set(stateObject.Balance()),
+ prev: stateObject.selfDestructed,
+ prevbalance: new(uint256.Int).Set(stateObject.Balance()),
})
- stateObject.markSuicided()
- stateObject.data.Balance = new(big.Int)
+ stateObject.markSelfdestructed()
+ stateObject.data.Balance = new(uint256.Int)
+}
- return true
+func (s *StateDB) Selfdestruct6780(addr common.Address) {
+ stateObject := s.getStateObject(addr)
+ if stateObject == nil {
+ return
+ }
+
+ if stateObject.created {
+ s.SelfDestruct(addr)
+ }
}
// SetTransientState sets transient storage for a given account. It
@@ -507,16 +507,27 @@ func (s *StateDB) updateStateObject(obj *stateObject) {
}
// Encode the account and update the account trie
addr := obj.Address()
- if err := s.trie.TryUpdateAccount(addr, &obj.data); err != nil {
+ if err := s.trie.UpdateAccount(addr, &obj.data); err != nil {
s.setError(fmt.Errorf("updateStateObject (%x) error: %v", addr[:], err))
}
+ if obj.dirtyCode {
+ s.trie.UpdateContractCode(obj.Address(), common.BytesToHash(obj.CodeHash()), obj.code)
+ }
+ // Cache the data until commit. Note, this update mechanism is not symmetric
+ // to the deletion, because whereas it is enough to track account updates
+ // at commit time, deletions need tracking at transaction boundary level to
+ // ensure we capture state clearing.
+ s.accounts[obj.addrHash] = types.SlimAccountRLP(obj.data)
- // If state snapshotting is active, cache the data til commit. Note, this
- // update mechanism is not symmetric to the deletion, because whereas it is
- // enough to track account updates at commit time, deletions need tracking
- // at transaction boundary level to ensure we capture state clearing.
- if s.snap != nil {
- s.snapAccounts[obj.addrHash] = snapshot.SlimAccountRLP(obj.data.Nonce, obj.data.Balance, obj.data.Root, obj.data.CodeHash)
+ // Track the original value of mutated account, nil means it was not present.
+ // Skip if it has been tracked (because updateStateObject may be called
+ // multiple times in a block).
+ if _, ok := s.accountsOrigin[obj.address]; !ok {
+ if obj.origin == nil {
+ s.accountsOrigin[obj.address] = nil
+ } else {
+ s.accountsOrigin[obj.address] = types.SlimAccountRLP(*obj.origin)
+ }
}
}
@@ -528,7 +539,7 @@ func (s *StateDB) deleteStateObject(obj *stateObject) {
}
// Delete the account from the trie
addr := obj.Address()
- if err := s.trie.TryDeleteAccount(addr); err != nil {
+ if err := s.trie.DeleteAccount(addr); err != nil {
s.setError(fmt.Errorf("deleteStateObject (%x) error: %v", addr[:], err))
}
}
@@ -582,7 +593,7 @@ func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject {
if data == nil {
start := time.Now()
var err error
- data, err = s.trie.TryGetAccount(addr)
+ data, err = s.trie.GetAccount(addr)
if metrics.EnabledExpensive {
s.AccountReads += time.Since(start)
}
@@ -595,7 +606,7 @@ func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject {
}
}
// Insert into the live set
- obj := newObject(s, addr, *data)
+ obj := newObject(s, addr, data)
s.setStateObject(obj)
return obj
}
@@ -604,8 +615,8 @@ func (s *StateDB) setStateObject(object *stateObject) {
s.stateObjects[object.Address()] = object
}
-// GetOrNewStateObject retrieves a state object or create a new state object if nil.
-func (s *StateDB) GetOrNewStateObject(addr common.Address) *stateObject {
+// getOrNewStateObject retrieves a state object or create a new state object if nil.
+func (s *StateDB) getOrNewStateObject(addr common.Address) *stateObject {
stateObject := s.getStateObject(addr)
if stateObject == nil {
stateObject, _ = s.createObject(addr)
@@ -617,19 +628,36 @@ func (s *StateDB) GetOrNewStateObject(addr common.Address) *stateObject {
// the given address, it is overwritten and returned as the second return value.
func (s *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) {
prev = s.getDeletedStateObject(addr) // Note, prev might have been deleted, we need that!
-
- var prevdestruct bool
- if prev != nil {
- _, prevdestruct = s.stateObjectsDestruct[prev.address]
- if !prevdestruct {
- s.stateObjectsDestruct[prev.address] = struct{}{}
- }
- }
- newobj = newObject(s, addr, types.StateAccount{})
+ newobj = newObject(s, addr, nil)
if prev == nil {
s.journal.append(createObjectChange{account: &addr})
} else {
- s.journal.append(resetObjectChange{prev: prev, prevdestruct: prevdestruct})
+ // The original account should be marked as destructed and all cached
+ // account and storage data should be cleared as well. Note, it must
+ // be done here, otherwise the destruction event of "original account"
+ // will be lost.
+ _, prevdestruct := s.stateObjectsDestruct[prev.address]
+ if !prevdestruct {
+ s.stateObjectsDestruct[prev.address] = prev.origin
+ }
+ // There may be some cached account/storage data already since IntermediateRoot
+ // will be called for each transaction before byzantium fork which will always
+ // cache the latest account/storage data.
+ prevAccount, ok := s.accountsOrigin[prev.address]
+ s.journal.append(resetObjectChange{
+ account: &addr,
+ prev: prev,
+ prevdestruct: prevdestruct,
+ prevAccount: s.accounts[prev.addrHash],
+ prevStorage: s.storages[prev.addrHash],
+ prevAccountOriginExist: ok,
+ prevAccountOrigin: prevAccount,
+ prevStorageOrigin: s.storagesOrigin[prev.address],
+ })
+ delete(s.accounts, prev.addrHash)
+ delete(s.storages, prev.addrHash)
+ delete(s.accountsOrigin, prev.address)
+ delete(s.storagesOrigin, prev.address)
}
s.setStateObject(newobj)
if prev != nil && !prev.deleted {
@@ -655,39 +683,6 @@ func (s *StateDB) CreateAccount(addr common.Address) {
}
}
-func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error {
- so := db.getStateObject(addr)
- if so == nil {
- return nil
- }
- tr, err := so.getTrie(db.db)
- if err != nil {
- return err
- }
- it := trie.NewIterator(tr.NodeIterator(nil))
-
- for it.Next() {
- key := common.BytesToHash(db.trie.GetKey(it.Key))
- if value, dirty := so.dirtyStorage[key]; dirty {
- if !cb(key, value) {
- return nil
- }
- continue
- }
-
- if len(it.Value) > 0 {
- _, content, _, err := rlp.Split(it.Value)
- if err != nil {
- return err
- }
- if !cb(key, common.BytesToHash(content)) {
- return nil
- }
- }
- }
- return nil
-}
-
// Copy creates a deep, independent copy of the state.
// Snapshots of the copied state cannot be applied to the copy.
func (s *StateDB) Copy() *StateDB {
@@ -696,16 +691,27 @@ func (s *StateDB) Copy() *StateDB {
db: s.db,
trie: s.db.CopyTrie(s.trie),
originalRoot: s.originalRoot,
+ accounts: make(map[common.Hash][]byte),
+ storages: make(map[common.Hash]map[common.Hash][]byte),
+ accountsOrigin: make(map[common.Address][]byte),
+ storagesOrigin: make(map[common.Address]map[common.Hash][]byte),
stateObjects: make(map[common.Address]*stateObject, len(s.journal.dirties)),
stateObjectsPending: make(map[common.Address]struct{}, len(s.stateObjectsPending)),
stateObjectsDirty: make(map[common.Address]struct{}, len(s.journal.dirties)),
- stateObjectsDestruct: make(map[common.Address]struct{}, len(s.stateObjectsDestruct)),
+ stateObjectsDestruct: make(map[common.Address]*types.StateAccount, len(s.stateObjectsDestruct)),
refund: s.refund,
logs: make(map[common.Hash][]*types.Log, len(s.logs)),
logSize: s.logSize,
preimages: make(map[common.Hash][]byte, len(s.preimages)),
journal: newJournal(),
hasher: crypto.NewKeccakState(),
+
+ // In order for the block producer to be able to use and make additions
+ // to the snapshot tree, we need to copy that as well. Otherwise, any
+ // block mined by ourselves will cause gaps in the tree, and force the
+ // miner to operate trie-backed only.
+ snaps: s.snaps,
+ snap: s.snap,
}
// Copy the dirty states, logs, and preimages
for addr := range s.journal.dirties {
@@ -739,10 +745,18 @@ func (s *StateDB) Copy() *StateDB {
}
state.stateObjectsDirty[addr] = struct{}{}
}
- // Deep copy the destruction flag.
- for addr := range s.stateObjectsDestruct {
- state.stateObjectsDestruct[addr] = struct{}{}
+ // Deep copy the destruction markers.
+ for addr, value := range s.stateObjectsDestruct {
+ state.stateObjectsDestruct[addr] = value
}
+ // Deep copy the state changes made in the scope of block along
+ // with their original values (copy from s, not the new empty maps).
+ state.accounts = copySet(s.accounts)
+ state.storages = copy2DSet(s.storages)
+ state.accountsOrigin = copySet(s.accountsOrigin)
+ state.storagesOrigin = copy2DSet(s.storagesOrigin)
+
+ // Deep copy the logs that occurred in the scope of the block
for hash, logs := range s.logs {
cpy := make([]*types.Log, len(logs))
for i, l := range logs {
@@ -751,6 +765,7 @@ func (s *StateDB) Copy() *StateDB {
}
state.logs[hash] = cpy
}
+ // Deep copy the preimages that occurred in the scope of the block
for hash, preimage := range s.preimages {
state.preimages[hash] = preimage
}
@@ -769,28 +784,6 @@ func (s *StateDB) Copy() *StateDB {
if s.prefetcher != nil {
state.prefetcher = s.prefetcher.copy()
}
- if s.snaps != nil {
- // In order for the miner to be able to use and make additions
- // to the snapshot tree, we need to copy that as well.
- // Otherwise, any block mined by ourselves will cause gaps in the tree,
- // and force the miner to operate trie-backed only
- state.snaps = s.snaps
- state.snap = s.snap
-
- // deep copy needed
- state.snapAccounts = make(map[common.Hash][]byte)
- for k, v := range s.snapAccounts {
- state.snapAccounts[k] = v
- }
- state.snapStorage = make(map[common.Hash]map[common.Hash][]byte)
- for k, v := range s.snapStorage {
- temp := make(map[common.Hash][]byte)
- for kk, vv := range v {
- temp[kk] = vv
- }
- state.snapStorage[k] = temp
- }
- }
return state
}
@@ -839,24 +832,26 @@ func (s *StateDB) Finalise(deleteEmptyObjects bool) {
// Thus, we can safely ignore it here
continue
}
- if obj.suicided || (deleteEmptyObjects && obj.empty()) {
+ if obj.selfDestructed || (deleteEmptyObjects && obj.empty()) {
obj.deleted = true
// We need to maintain account deletions explicitly (will remain
- // set indefinitely).
- s.stateObjectsDestruct[obj.address] = struct{}{}
-
- // If state snapshotting is active, also mark the destruction there.
+ // set indefinitely). Note that only the first self-destruct
+ // event that occurred is tracked.
+ if _, ok := s.stateObjectsDestruct[obj.address]; !ok {
+ s.stateObjectsDestruct[obj.address] = obj.origin
+ }
// Note, we can't do this only at the end of a block because multiple
// transactions within the same block might self destruct and then
// resurrect an account; but the snapshotter needs both events.
- if s.snap != nil {
- delete(s.snapAccounts, obj.addrHash) // Clear out any previously updated account data (may be recreated via a resurrect)
- delete(s.snapStorage, obj.addrHash) // Clear out any previously updated storage data (may be recreated via a resurrect)
- }
+ delete(s.accounts, obj.addrHash) // Clear out any previously updated account data (may be recreated via a resurrect)
+ delete(s.storages, obj.addrHash) // Clear out any previously updated storage data (may be recreated via a resurrect)
+ delete(s.accountsOrigin, obj.address) // Clear out any previously updated account data (may be recreated via a resurrect)
+ delete(s.storagesOrigin, obj.address) // Clear out any previously updated storage data (may be recreated via a resurrect)
} else {
obj.finalise(true) // Prefetch slots in the background
}
+ obj.created = false
s.stateObjectsPending[addr] = struct{}{}
s.stateObjectsDirty[addr] = struct{}{}
@@ -866,7 +861,7 @@ func (s *StateDB) Finalise(deleteEmptyObjects bool) {
addressesToPrefetch = append(addressesToPrefetch, common.CopyBytes(addr[:])) // Copy needed for closure
}
if s.prefetcher != nil && len(addressesToPrefetch) > 0 {
- s.prefetcher.prefetch(common.Hash{}, s.originalRoot, addressesToPrefetch)
+ s.prefetcher.prefetch(common.Hash{}, s.originalRoot, common.Address{}, addressesToPrefetch)
}
// Invalidate journal because reverting across transactions is not allowed.
s.clearJournalAndRefund()
@@ -900,7 +895,7 @@ func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash {
// to pull useful data from disk.
for addr := range s.stateObjectsPending {
if obj := s.stateObjects[addr]; !obj.deleted {
- obj.updateRoot(s.db)
+ obj.updateRoot()
}
}
// Now we're about to start to write changes to the trie. The trie is so far
@@ -951,6 +946,359 @@ func (s *StateDB) clearJournalAndRefund() {
s.validRevisions = s.validRevisions[:0] // Snapshots can be created without journal entries
}
+// fastDeleteStorage is the function that efficiently deletes the storage trie
+// of a specific account. It leverages the associated state snapshot for fast
+// storage iteration and constructs trie node deletion markers by creating
+// stack trie with iterated slots.
+func (s *StateDB) fastDeleteStorage(addrHash common.Hash, root common.Hash) (bool, common.StorageSize, map[common.Hash][]byte, *trienode.NodeSet, error) {
+ iter, err := s.snaps.StorageIterator(s.originalRoot, addrHash, common.Hash{})
+ if err != nil {
+ return false, 0, nil, nil, err
+ }
+ defer iter.Release()
+
+ var (
+ size common.StorageSize
+ nodes = trienode.NewNodeSet(addrHash)
+ slots = make(map[common.Hash][]byte)
+ )
+ options := trie.NewStackTrieOptions()
+ options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
+ nodes.AddNode(path, trienode.NewDeleted())
+ size += common.StorageSize(len(path))
+ })
+ stack := trie.NewStackTrie(options)
+ for iter.Next() {
+ if size > storageDeleteLimit {
+ return true, size, nil, nil, nil
+ }
+ slot := common.CopyBytes(iter.Slot())
+ if err := iter.Error(); err != nil { // error might occur after Slot function
+ return false, 0, nil, nil, err
+ }
+ size += common.StorageSize(common.HashLength + len(slot))
+ slots[iter.Hash()] = slot
+
+ if err := stack.Update(iter.Hash().Bytes(), slot); err != nil {
+ return false, 0, nil, nil, err
+ }
+ }
+ if err := iter.Error(); err != nil { // error might occur during iteration
+ return false, 0, nil, nil, err
+ }
+ if stack.Hash() != root {
+ return false, 0, nil, nil, fmt.Errorf("snapshot is not matched, exp %x, got %x", root, stack.Hash())
+ }
+ return false, size, slots, nodes, nil
+}
+
+// slowDeleteStorage serves as a less-efficient alternative to "fastDeleteStorage",
+// employed when the associated state snapshot is not available. It iterates the
+// storage slots along with all internal trie nodes via the trie directly.
+func (s *StateDB) slowDeleteStorage(addr common.Address, addrHash common.Hash, root common.Hash) (bool, common.StorageSize, map[common.Hash][]byte, *trienode.NodeSet, error) {
+ tr, err := s.db.OpenStorageTrie(s.originalRoot, addrHash, root, s.trie)
+ if err != nil {
+ return false, 0, nil, nil, fmt.Errorf("failed to open storage trie, err: %w", err)
+ }
+ it, err := tr.NodeIterator(nil)
+ if err != nil {
+ return false, 0, nil, nil, fmt.Errorf("failed to open storage iterator, err: %w", err)
+ }
+ var (
+ size common.StorageSize
+ nodes = trienode.NewNodeSet(addrHash)
+ slots = make(map[common.Hash][]byte)
+ )
+ for it.Next(true) {
+ if size > storageDeleteLimit {
+ return true, size, nil, nil, nil
+ }
+ if it.Leaf() {
+ slots[common.BytesToHash(it.LeafKey())] = common.CopyBytes(it.LeafBlob())
+ size += common.StorageSize(common.HashLength + len(it.LeafBlob()))
+ continue
+ }
+ if it.Hash() == (common.Hash{}) {
+ continue
+ }
+ size += common.StorageSize(len(it.Path()))
+ nodes.AddNode(it.Path(), trienode.NewDeleted())
+ }
+ if err := it.Error(); err != nil {
+ return false, 0, nil, nil, err
+ }
+ return false, size, slots, nodes, nil
+}
+
+// deleteStorage is designed to delete the storage trie of a designated account.
+// It could potentially be terminated if the storage size is excessively large,
+// potentially leading to an out-of-memory panic. The function will make an attempt
+// to utilize an efficient strategy if the associated state snapshot is reachable;
+// otherwise, it will resort to a less-efficient approach.
+func (s *StateDB) deleteStorage(addr common.Address, addrHash common.Hash, root common.Hash) (bool, map[common.Hash][]byte, *trienode.NodeSet, error) {
+ var (
+ start = time.Now()
+ err error
+ aborted bool
+ size common.StorageSize
+ slots map[common.Hash][]byte
+ nodes *trienode.NodeSet
+ )
+ // The fast approach can fail if the snapshot is not fully
+ // generated, or if it is internally corrupted. Fall back to the
+ // slow one just in case.
+ if s.snap != nil {
+ aborted, size, slots, nodes, err = s.fastDeleteStorage(addrHash, root)
+ }
+ if s.snap == nil || err != nil {
+ aborted, size, slots, nodes, err = s.slowDeleteStorage(addr, addrHash, root)
+ }
+ if err != nil {
+ return false, nil, nil, err
+ }
+ if metrics.EnabledExpensive {
+ if aborted {
+ slotDeletionSkip.Inc(1)
+ }
+ n := int64(len(slots))
+
+ slotDeletionMaxCount.UpdateIfGt(int64(len(slots)))
+ slotDeletionMaxSize.UpdateIfGt(int64(size))
+
+ slotDeletionTimer.UpdateSince(start)
+ slotDeletionCount.Mark(n)
+ slotDeletionSize.Mark(int64(size))
+ }
+ return aborted, slots, nodes, nil
+}
+
+// handleDestruction processes all destruction markers and deletes the account
+// and associated storage slots if necessary. There are four possible situations
+// here:
+//
+// - the account was not existent and be marked as destructed
+//
+// - the account was not existent and be marked as destructed,
+// however, it's resurrected later in the same block.
+//
+// - the account was existent and be marked as destructed
+//
+// - the account was existent and be marked as destructed,
+// however it's resurrected later in the same block.
+//
+// In case (a), nothing needs to be deleted; the nil-to-nil transition can be ignored.
+//
+// In case (b), nothing needs to be deleted; nil is used as the original value for
+// the newly created account and storages.
+//
+// In case (c), the **original** account along with its storages should be deleted,
+// with their values tracked as the original values.
+//
+// In case (d), the **original** account along with its storages should be deleted,
+// with their values tracked as the original values.
+func (s *StateDB) handleDestruction(nodes *trienode.MergedNodeSet) (map[common.Address]struct{}, error) {
+ // Short circuit if geth is running with hash mode. This procedure can consume
+ // considerable time and storage deletion isn't supported in hash mode, thus
+ // preemptively avoiding unnecessary expenses.
+ incomplete := make(map[common.Address]struct{})
+ if s.db.TrieDB().Scheme() == rawdb.HashScheme {
+ return incomplete, nil
+ }
+ for addr, prev := range s.stateObjectsDestruct {
+ // The original account was non-existing, and it's marked as destructed
+ // in the scope of block. It can be case (a) or (b).
+ // - for (a), skip it without doing anything.
+ // - for (b), track account's original value as nil. It may overwrite
+ // the data cached in s.accountsOrigin set by 'updateStateObject'.
+ addrHash := crypto.Keccak256Hash(addr[:])
+ if prev == nil {
+ if _, ok := s.accounts[addrHash]; ok {
+ s.accountsOrigin[addr] = nil // case (b)
+ }
+ continue
+ }
+ // It can overwrite the data in s.accountsOrigin set by 'updateStateObject'.
+ s.accountsOrigin[addr] = types.SlimAccountRLP(*prev) // case (c) or (d)
+
+ // Short circuit if the storage was empty.
+ if prev.Root == types.EmptyRootHash {
+ continue
+ }
+ // Remove storage slots belong to the account.
+ aborted, slots, set, err := s.deleteStorage(addr, addrHash, prev.Root)
+ if err != nil {
+ return nil, fmt.Errorf("failed to delete storage, err: %w", err)
+ }
+ // The storage is too huge to handle, skip it but mark as incomplete.
+ // For case (d), the account might be resurrected with a few slots
+ // created. In this case, wipe the entire storage state diff because
+ // of the aborted deletion.
+ if aborted {
+ incomplete[addr] = struct{}{}
+ delete(s.storagesOrigin, addr)
+ continue
+ }
+ if s.storagesOrigin[addr] == nil {
+ s.storagesOrigin[addr] = slots
+ } else {
+ // It can overwrite the data in s.storagesOrigin[addrHash] set by
+ // 'object.updateTrie'.
+ for key, val := range slots {
+ s.storagesOrigin[addr][key] = val
+ }
+ }
+ if err := nodes.Merge(set); err != nil {
+ return nil, err
+ }
+ }
+ return incomplete, nil
+}
+
+// Commit writes the state to the underlying in-memory trie database.
+// Once the state is committed, tries cached in stateDB (including account
+// trie, storage tries) will no longer be functional. A new state instance
+// must be created with new root and updated database for accessing post-
+// commit states.
+//
+// The associated block number of the state transition is also provided
+// for more chain context.
+func (s *StateDB) Commit(block uint64, deleteEmptyObjects bool) (common.Hash, error) {
+ // Short circuit in case any database failure occurred earlier.
+ if s.dbErr != nil {
+ return common.Hash{}, fmt.Errorf("commit aborted due to earlier error: %v", s.dbErr)
+ }
+ // Finalize any pending changes and merge everything into the tries
+ s.IntermediateRoot(deleteEmptyObjects)
+
+ // Commit objects to the trie, measuring the elapsed time
+ var (
+ accountTrieNodesUpdated int
+ accountTrieNodesDeleted int
+ storageTrieNodesUpdated int
+ storageTrieNodesDeleted int
+ nodes = trienode.NewMergedNodeSet()
+ codeWriter = s.db.DiskDB().NewBatch()
+ )
+ // Handle all state deletions first
+ incomplete, err := s.handleDestruction(nodes)
+ if err != nil {
+ return common.Hash{}, err
+ }
+ // Handle all state updates afterwards
+ for addr := range s.stateObjectsDirty {
+ obj := s.stateObjects[addr]
+ if obj.deleted {
+ continue
+ }
+ // Write any contract code associated with the state object
+ if obj.code != nil && obj.dirtyCode {
+ rawdb.WriteCode(codeWriter, common.BytesToHash(obj.CodeHash()), obj.code)
+ obj.dirtyCode = false
+ }
+ // Write any storage changes in the state object to its storage trie
+ set, err := obj.commit()
+ if err != nil {
+ return common.Hash{}, err
+ }
+ // Merge the dirty nodes of storage trie into global set. It is possible
+ // that the account was destructed and then resurrected in the same block.
+ // In this case, the node set is shared by both accounts.
+ if set != nil {
+ if err := nodes.Merge(set); err != nil {
+ return common.Hash{}, err
+ }
+ updates, deleted := set.Size()
+ storageTrieNodesUpdated += updates
+ storageTrieNodesDeleted += deleted
+ }
+ }
+ if codeWriter.ValueSize() > 0 {
+ if err := codeWriter.Write(); err != nil {
+ log.Crit("Failed to commit dirty codes", "error", err)
+ }
+ }
+ // Write the account trie changes, measuring the amount of wasted time
+ var start time.Time
+ if metrics.EnabledExpensive {
+ start = time.Now()
+ }
+ root, set, err := s.trie.Commit(true)
+ if err != nil {
+ return common.Hash{}, err
+ }
+ // Merge the dirty nodes of account trie into global set
+ if set != nil {
+ if err := nodes.Merge(set); err != nil {
+ return common.Hash{}, err
+ }
+ accountTrieNodesUpdated, accountTrieNodesDeleted = set.Size()
+ }
+ if metrics.EnabledExpensive {
+ s.AccountCommits += time.Since(start)
+
+ accountUpdatedMeter.Mark(int64(s.AccountUpdated))
+ storageUpdatedMeter.Mark(int64(s.StorageUpdated))
+ accountDeletedMeter.Mark(int64(s.AccountDeleted))
+ storageDeletedMeter.Mark(int64(s.StorageDeleted))
+ accountTrieUpdatedMeter.Mark(int64(accountTrieNodesUpdated))
+ accountTrieDeletedMeter.Mark(int64(accountTrieNodesDeleted))
+ storageTriesUpdatedMeter.Mark(int64(storageTrieNodesUpdated))
+ storageTriesDeletedMeter.Mark(int64(storageTrieNodesDeleted))
+ s.AccountUpdated, s.AccountDeleted = 0, 0
+ s.StorageUpdated, s.StorageDeleted = 0, 0
+ }
+ // If snapshotting is enabled, update the snapshot tree with this new version
+ if s.snap != nil {
+ start := time.Now()
+ // Only update if there's a state transition (skip empty Clique blocks)
+ if parent := s.snap.Root(); parent != root {
+ if err := s.snaps.Update(root, parent, s.convertAccountSet(s.stateObjectsDestruct), s.accounts, s.storages); err != nil {
+ log.Warn("Failed to update snapshot tree", "from", parent, "to", root, "err", err)
+ }
+ // Keep 128 diff layers in the memory, persistent layer is 129th.
+ // - head layer is paired with HEAD state
+ // - head-1 layer is paired with HEAD-1 state
+ // - head-127 layer(bottom-most diff layer) is paired with HEAD-127 state
+ if err := s.snaps.Cap(root, 128); err != nil {
+ log.Warn("Failed to cap snapshot tree", "root", root, "layers", 128, "err", err)
+ }
+ }
+ if metrics.EnabledExpensive {
+ s.SnapshotCommits += time.Since(start)
+ }
+ s.snap = nil
+ }
+ if root == (common.Hash{}) {
+ root = types.EmptyRootHash
+ }
+ origin := s.originalRoot
+ if origin == (common.Hash{}) {
+ origin = types.EmptyRootHash
+ }
+ if root != origin {
+ start := time.Now()
+ set := triestate.New(s.accountsOrigin, s.storagesOrigin, incomplete)
+ if err := s.db.TrieDB().Update(root, origin, block, nodes, set); err != nil {
+ return common.Hash{}, err
+ }
+ s.originalRoot = root
+ if metrics.EnabledExpensive {
+ s.TrieDBCommits += time.Since(start)
+ }
+ if s.onCommit != nil {
+ s.onCommit(set)
+ }
+ }
+ // Clear all internal flags at the end of commit operation.
+ s.accounts = make(map[common.Hash][]byte)
+ s.storages = make(map[common.Hash]map[common.Hash][]byte)
+ s.accountsOrigin = make(map[common.Address][]byte)
+ s.storagesOrigin = make(map[common.Address]map[common.Hash][]byte)
+ s.stateObjectsDirty = make(map[common.Address]struct{})
+ s.stateObjectsDestruct = make(map[common.Address]*types.StateAccount)
+ return root, nil
+}
+
// Prepare handles the preparatory steps for executing a state transition with.
// This method must be invoked before state transition.
//
@@ -1028,8 +1376,8 @@ func (s *StateDB) SlotInAccessList(addr common.Address, slot common.Hash) (addre
}
// convertAccountSet converts a provided account set from address keyed to hash keyed.
-func (s *StateDB) convertAccountSet(set map[common.Address]struct{}) map[common.Hash]struct{} {
- ret := make(map[common.Hash]struct{})
+func (s *StateDB) convertAccountSet(set map[common.Address]*types.StateAccount) map[common.Hash]struct{} {
+ ret := make(map[common.Hash]struct{}, len(set))
for addr := range set {
obj, exist := s.stateObjects[addr]
if !exist {
@@ -1040,3 +1388,24 @@ func (s *StateDB) convertAccountSet(set map[common.Address]struct{}) map[common.
}
return ret
}
+
+// copySet returns a deep-copied set.
+func copySet[k comparable](set map[k][]byte) map[k][]byte {
+ copied := make(map[k][]byte, len(set))
+ for key, val := range set {
+ copied[key] = common.CopyBytes(val)
+ }
+ return copied
+}
+
+// copy2DSet returns a two-dimensional deep-copied set.
+func copy2DSet[k comparable](set map[k]map[common.Hash][]byte) map[k]map[common.Hash][]byte {
+ copied := make(map[k]map[common.Hash][]byte, len(set))
+ for addr, subset := range set {
+ copied[addr] = make(map[common.Hash][]byte, len(subset))
+ for key, val := range subset {
+ copied[addr][key] = common.CopyBytes(val)
+ }
+ }
+ return copied
+}
diff --git a/trie_by_cid/state/statedb_test.go b/trie_by_cid/state/statedb_test.go
index 60cd66a..c82537d 100644
--- a/trie_by_cid/state/statedb_test.go
+++ b/trie_by_cid/state/statedb_test.go
@@ -21,7 +21,6 @@ import (
"encoding/binary"
"fmt"
"math"
- "math/big"
"math/rand"
"reflect"
"strings"
@@ -30,8 +29,11 @@ import (
"testing/quick"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/rlp"
+ "github.com/holiman/uint256"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
)
// TestCopy tests that copying a StateDB object indeed makes the original and
@@ -39,11 +41,13 @@ import (
// https://github.com/ethereum/go-ethereum/pull/15549.
func TestCopy(t *testing.T) {
// Create a random state test to copy and modify "independently"
- orig, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
+ db, cleanup := newPgIpfsEthdb(t)
+ t.Cleanup(cleanup)
+ orig, _ := New(types.EmptyRootHash, NewDatabase(db), nil)
for i := byte(0); i < 255; i++ {
- obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
- obj.AddBalance(big.NewInt(int64(i)))
+ obj := orig.getOrNewStateObject(common.BytesToAddress([]byte{i}))
+ obj.AddBalance(uint256.NewInt(uint64(i)))
orig.updateStateObject(obj)
}
orig.Finalise(false)
@@ -56,13 +60,13 @@ func TestCopy(t *testing.T) {
// modify all in memory
for i := byte(0); i < 255; i++ {
- origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
- copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
- ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
+ origObj := orig.getOrNewStateObject(common.BytesToAddress([]byte{i}))
+ copyObj := copy.getOrNewStateObject(common.BytesToAddress([]byte{i}))
+ ccopyObj := ccopy.getOrNewStateObject(common.BytesToAddress([]byte{i}))
- origObj.AddBalance(big.NewInt(2 * int64(i)))
- copyObj.AddBalance(big.NewInt(3 * int64(i)))
- ccopyObj.AddBalance(big.NewInt(4 * int64(i)))
+ origObj.AddBalance(uint256.NewInt(2 * uint64(i)))
+ copyObj.AddBalance(uint256.NewInt(3 * uint64(i)))
+ ccopyObj.AddBalance(uint256.NewInt(4 * uint64(i)))
orig.updateStateObject(origObj)
copy.updateStateObject(copyObj)
@@ -84,25 +88,34 @@ func TestCopy(t *testing.T) {
// Verify that the three states have been updated independently
for i := byte(0); i < 255; i++ {
- origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
- copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
- ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
+ origObj := orig.getOrNewStateObject(common.BytesToAddress([]byte{i}))
+ copyObj := copy.getOrNewStateObject(common.BytesToAddress([]byte{i}))
+ ccopyObj := ccopy.getOrNewStateObject(common.BytesToAddress([]byte{i}))
- if want := big.NewInt(3 * int64(i)); origObj.Balance().Cmp(want) != 0 {
+ if want := uint256.NewInt(3 * uint64(i)); origObj.Balance().Cmp(want) != 0 {
t.Errorf("orig obj %d: balance mismatch: have %v, want %v", i, origObj.Balance(), want)
}
- if want := big.NewInt(4 * int64(i)); copyObj.Balance().Cmp(want) != 0 {
+ if want := uint256.NewInt(4 * uint64(i)); copyObj.Balance().Cmp(want) != 0 {
t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, copyObj.Balance(), want)
}
- if want := big.NewInt(5 * int64(i)); ccopyObj.Balance().Cmp(want) != 0 {
+ if want := uint256.NewInt(5 * uint64(i)); ccopyObj.Balance().Cmp(want) != 0 {
t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, ccopyObj.Balance(), want)
}
}
}
func TestSnapshotRandom(t *testing.T) {
- config := &quick.Config{MaxCount: 1000}
- err := quick.Check((*snapshotTest).run, config)
+ config := &quick.Config{MaxCount: 10}
+ i := 0
+ run := func(test *snapshotTest) bool {
+ var res bool
+ t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
+ res = test.run(t)
+ })
+ i++
+ return res
+ }
+ err := quick.Check(run, config)
if cerr, ok := err.(*quick.CheckError); ok {
test := cerr.In[0].(*snapshotTest)
t.Errorf("%v:\n%s", test.err, test)
@@ -142,14 +155,14 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction {
{
name: "SetBalance",
fn: func(a testAction, s *StateDB) {
- s.SetBalance(addr, big.NewInt(a.args[0]))
+ s.SetBalance(addr, uint256.NewInt(uint64(a.args[0])))
},
args: make([]int64, 1),
},
{
name: "AddBalance",
fn: func(a testAction, s *StateDB) {
- s.AddBalance(addr, big.NewInt(a.args[0]))
+ s.AddBalance(addr, uint256.NewInt(uint64(a.args[0])))
},
args: make([]int64, 1),
},
@@ -187,9 +200,9 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction {
},
},
{
- name: "Suicide",
+ name: "SelfDestruct",
fn: func(a testAction, s *StateDB) {
- s.Suicide(addr)
+ s.SelfDestruct(addr)
},
},
{
@@ -296,16 +309,20 @@ func (test *snapshotTest) String() string {
return out.String()
}
-func (test *snapshotTest) run() bool {
+func (test *snapshotTest) run(t *testing.T) bool {
// Run all actions and create snapshots.
+ db, cleanup := newPgIpfsEthdb(t)
+ t.Cleanup(cleanup)
var (
- state, _ = New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
+ state, _ = New(types.EmptyRootHash, NewDatabase(db), nil)
snapshotRevs = make([]int, len(test.snapshots))
sindex = 0
+ checkstates = make([]*StateDB, len(test.snapshots))
)
for i, action := range test.actions {
if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
snapshotRevs[sindex] = state.Snapshot()
+ checkstates[sindex] = state.Copy()
sindex++
}
action.fn(action, state)
@@ -313,12 +330,8 @@ func (test *snapshotTest) run() bool {
// Revert all snapshots in reverse order. Each revert must yield a state
// that is equivalent to fresh state with all actions up the snapshot applied.
for sindex--; sindex >= 0; sindex-- {
- checkstate, _ := New(common.Hash{}, state.Database(), nil)
- for _, action := range test.actions[:test.snapshots[sindex]] {
- action.fn(action, checkstate)
- }
state.RevertToSnapshot(snapshotRevs[sindex])
- if err := test.checkEqual(state, checkstate); err != nil {
+ if err := test.checkEqual(state, checkstates[sindex]); err != nil {
test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
return false
}
@@ -326,6 +339,43 @@ func (test *snapshotTest) run() bool {
return true
}
+func forEachStorage(s *StateDB, addr common.Address, cb func(key, value common.Hash) bool) error {
+ so := s.getStateObject(addr)
+ if so == nil {
+ return nil
+ }
+ tr, err := so.getTrie()
+ if err != nil {
+ return err
+ }
+ trieIt, err := tr.NodeIterator(nil)
+ if err != nil {
+ return err
+ }
+ it := trie.NewIterator(trieIt)
+
+ for it.Next() {
+ key := common.BytesToHash(s.trie.GetKey(it.Key))
+ if value, dirty := so.dirtyStorage[key]; dirty {
+ if !cb(key, value) {
+ return nil
+ }
+ continue
+ }
+
+ if len(it.Value) > 0 {
+ _, content, _, err := rlp.Split(it.Value)
+ if err != nil {
+ return err
+ }
+ if !cb(key, common.BytesToHash(content)) {
+ return nil
+ }
+ }
+ }
+ return nil
+}
+
// checkEqual checks that methods of state and checkstate return the same values.
func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
for _, addr := range test.addrs {
@@ -339,7 +389,7 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
}
// Check basic accessor methods.
checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
- checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr))
+ checkeq("HasSelfdestructed", state.HasSelfDestructed(addr), checkstate.HasSelfDestructed(addr))
checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
@@ -347,10 +397,10 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
// Check storage.
if obj := state.getStateObject(addr); obj != nil {
- state.ForEachStorage(addr, func(key, value common.Hash) bool {
+ forEachStorage(state, addr, func(key, value common.Hash) bool {
return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
})
- checkstate.ForEachStorage(addr, func(key, value common.Hash) bool {
+ forEachStorage(checkstate, addr, func(key, value common.Hash) bool {
return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
})
}
@@ -373,9 +423,11 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
// TestCopyOfCopy tests that modified objects are carried over to the copy, and the copy of the copy.
// See https://github.com/ethereum/go-ethereum/pull/15225#issuecomment-380191512
func TestCopyOfCopy(t *testing.T) {
- state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
+ db, cleanup := newPgIpfsEthdb(t)
+ t.Cleanup(cleanup)
+ state, _ := New(types.EmptyRootHash, NewDatabase(db), nil)
addr := common.HexToAddress("aaaa")
- state.SetBalance(addr, big.NewInt(42))
+ state.SetBalance(addr, uint256.NewInt(42))
if got := state.Copy().GetBalance(addr).Uint64(); got != 42 {
t.Fatalf("1st copy fail, expected 42, got %v", got)
@@ -394,9 +446,10 @@ func TestStateDBAccessList(t *testing.T) {
return common.HexToHash(a)
}
- memDb := rawdb.NewMemoryDatabase()
- db := NewDatabase(memDb)
- state, _ := New(common.Hash{}, db, nil)
+ pgdb, cleanup := newPgIpfsEthdb(t)
+ t.Cleanup(cleanup)
+ db := NewDatabase(pgdb)
+ state, _ := New(types.EmptyRootHash, db, nil)
state.accessList = newAccessList()
verifyAddrs := func(astrings ...string) {
@@ -560,9 +613,10 @@ func TestStateDBAccessList(t *testing.T) {
}
func TestStateDBTransientStorage(t *testing.T) {
- memDb := rawdb.NewMemoryDatabase()
- db := NewDatabase(memDb)
- state, _ := New(common.Hash{}, db, nil)
+ pgdb, cleanup := newPgIpfsEthdb(t)
+ t.Cleanup(cleanup)
+ db := NewDatabase(pgdb)
+ state, _ := New(types.EmptyRootHash, db, nil)
key := common.Hash{0x01}
value := common.Hash{0x02}
diff --git a/trie_by_cid/state/trie_prefetcher.go b/trie_by_cid/state/trie_prefetcher.go
index 5dd1b5b..df9e141 100644
--- a/trie_by_cid/state/trie_prefetcher.go
+++ b/trie_by_cid/state/trie_prefetcher.go
@@ -20,8 +20,8 @@ import (
"sync"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
- log "github.com/sirupsen/logrus"
)
var (
@@ -37,7 +37,7 @@ var (
type triePrefetcher struct {
db Database // Database to fetch trie nodes through
root common.Hash // Root hash of the account trie for metrics
- fetches map[string]Trie // Partially or fully fetcher tries
+ fetches map[string]Trie // Partially or fully fetched tries. Only populated for inactive copies.
fetchers map[string]*subfetcher // Subfetchers for each trie
deliveryMissMeter metrics.Meter
@@ -141,7 +141,7 @@ func (p *triePrefetcher) copy() *triePrefetcher {
}
// prefetch schedules a batch of trie items to prefetch.
-func (p *triePrefetcher) prefetch(owner common.Hash, root common.Hash, keys [][]byte) {
+func (p *triePrefetcher) prefetch(owner common.Hash, root common.Hash, addr common.Address, keys [][]byte) {
// If the prefetcher is an inactive one, bail out
if p.fetches != nil {
return
@@ -150,7 +150,7 @@ func (p *triePrefetcher) prefetch(owner common.Hash, root common.Hash, keys [][]
id := p.trieID(owner, root)
fetcher := p.fetchers[id]
if fetcher == nil {
- fetcher = newSubfetcher(p.db, p.root, owner, root)
+ fetcher = newSubfetcher(p.db, p.root, owner, root, addr)
p.fetchers[id] = fetcher
}
fetcher.schedule(keys)
@@ -197,7 +197,10 @@ func (p *triePrefetcher) used(owner common.Hash, root common.Hash, used [][]byte
// trieID returns an unique trie identifier consists the trie owner and root hash.
func (p *triePrefetcher) trieID(owner common.Hash, root common.Hash) string {
- return string(append(owner.Bytes(), root.Bytes()...))
+ trieID := make([]byte, common.HashLength*2)
+ copy(trieID, owner.Bytes())
+ copy(trieID[common.HashLength:], root.Bytes())
+ return string(trieID)
}
// subfetcher is a trie fetcher goroutine responsible for pulling entries for a
@@ -205,11 +208,12 @@ func (p *triePrefetcher) trieID(owner common.Hash, root common.Hash) string {
// main prefetcher is paused and either all requested items are processed or if
// the trie being worked on is retrieved from the prefetcher.
type subfetcher struct {
- db Database // Database to load trie nodes through
- state common.Hash // Root hash of the state to prefetch
- owner common.Hash // Owner of the trie, usually account hash
- root common.Hash // Root hash of the trie to prefetch
- trie Trie // Trie being populated with nodes
+ db Database // Database to load trie nodes through
+ state common.Hash // Root hash of the state to prefetch
+ owner common.Hash // Owner of the trie, usually account hash
+ root common.Hash // Root hash of the trie to prefetch
+ addr common.Address // Address of the account that the trie belongs to
+ trie Trie // Trie being populated with nodes
tasks [][]byte // Items queued up for retrieval
lock sync.Mutex // Lock protecting the task queue
@@ -226,12 +230,13 @@ type subfetcher struct {
// newSubfetcher creates a goroutine to prefetch state items belonging to a
// particular root hash.
-func newSubfetcher(db Database, state common.Hash, owner common.Hash, root common.Hash) *subfetcher {
+func newSubfetcher(db Database, state common.Hash, owner common.Hash, root common.Hash, addr common.Address) *subfetcher {
sf := &subfetcher{
db: db,
state: state,
owner: owner,
root: root,
+ addr: addr,
wake: make(chan struct{}, 1),
stop: make(chan struct{}),
term: make(chan struct{}),
@@ -300,7 +305,9 @@ func (sf *subfetcher) loop() {
}
sf.trie = trie
} else {
- trie, err := sf.db.OpenStorageTrie(sf.state, sf.owner, sf.root)
+ // The trie argument can be nil as verkle doesn't support prefetching
+ // yet. TODO FIX IT(rjl493456442), otherwise code will panic here.
+ trie, err := sf.db.OpenStorageTrie(sf.state, sf.owner, sf.root, nil)
if err != nil {
log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err)
return
@@ -336,7 +343,11 @@ func (sf *subfetcher) loop() {
if _, ok := sf.seen[string(task)]; ok {
sf.dups++
} else {
- sf.trie.TryGet(task)
+ if len(task) == common.AddressLength {
+ sf.trie.GetAccount(common.BytesToAddress(task))
+ } else {
+ sf.trie.GetStorage(sf.addr, task)
+ }
sf.seen[string(task)] = struct{}{}
}
}
diff --git a/trie_by_cid/state/trie_prefetcher_test.go b/trie_by_cid/state/trie_prefetcher_test.go
index cb0b67d..c49d912 100644
--- a/trie_by_cid/state/trie_prefetcher_test.go
+++ b/trie_by_cid/state/trie_prefetcher_test.go
@@ -22,20 +22,23 @@ import (
"time"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/holiman/uint256"
)
-func filledStateDB() *StateDB {
- state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
+func filledStateDB(t *testing.T) *StateDB {
+ db, cleanup := newPgIpfsEthdb(t)
+ t.Cleanup(cleanup)
+ state, _ := New(types.EmptyRootHash, NewDatabase(db), nil)
// Create an account and check if the retrieved balance is correct
addr := common.HexToAddress("0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe")
skey := common.HexToHash("aaa")
sval := common.HexToHash("bbb")
- state.SetBalance(addr, big.NewInt(42)) // Change the account trie
- state.SetCode(addr, []byte("hello")) // Change an external metadata
- state.SetState(addr, skey, sval) // Change the storage trie
+ state.SetBalance(addr, uint256.NewInt(42)) // Change the account trie
+ state.SetCode(addr, []byte("hello")) // Change an external metadata
+ state.SetState(addr, skey, sval) // Change the storage trie
for i := 0; i < 100; i++ {
sk := common.BigToHash(big.NewInt(int64(i)))
state.SetState(addr, sk, sk) // Change the storage trie
@@ -44,22 +47,22 @@ func filledStateDB() *StateDB {
}
func TestCopyAndClose(t *testing.T) {
- db := filledStateDB()
+ db := filledStateDB(t)
prefetcher := newTriePrefetcher(db.db, db.originalRoot, "")
skey := common.HexToHash("aaa")
- prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
- prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
+ prefetcher.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
+ prefetcher.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
time.Sleep(1 * time.Second)
a := prefetcher.trie(common.Hash{}, db.originalRoot)
- prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
+ prefetcher.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
b := prefetcher.trie(common.Hash{}, db.originalRoot)
cpy := prefetcher.copy()
- cpy.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
- cpy.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
+ cpy.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
+ cpy.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
c := cpy.trie(common.Hash{}, db.originalRoot)
prefetcher.close()
cpy2 := cpy.copy()
- cpy2.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
+ cpy2.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
d := cpy2.trie(common.Hash{}, db.originalRoot)
cpy.close()
cpy2.close()
@@ -69,10 +72,10 @@ func TestCopyAndClose(t *testing.T) {
}
func TestUseAfterClose(t *testing.T) {
- db := filledStateDB()
+ db := filledStateDB(t)
prefetcher := newTriePrefetcher(db.db, db.originalRoot, "")
skey := common.HexToHash("aaa")
- prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
+ prefetcher.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
a := prefetcher.trie(common.Hash{}, db.originalRoot)
prefetcher.close()
b := prefetcher.trie(common.Hash{}, db.originalRoot)
@@ -85,10 +88,10 @@ func TestUseAfterClose(t *testing.T) {
}
func TestCopyClose(t *testing.T) {
- db := filledStateDB()
+ db := filledStateDB(t)
prefetcher := newTriePrefetcher(db.db, db.originalRoot, "")
skey := common.HexToHash("aaa")
- prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
+ prefetcher.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
cpy := prefetcher.copy()
a := prefetcher.trie(common.Hash{}, db.originalRoot)
b := cpy.trie(common.Hash{}, db.originalRoot)
diff --git a/trie_by_cid/trie/committer.go b/trie_by_cid/trie/committer.go
index 9f97887..c20f207 100644
--- a/trie_by_cid/trie/committer.go
+++ b/trie_by_cid/trie/committer.go
@@ -20,26 +20,24 @@ import (
"fmt"
"github.com/ethereum/go-ethereum/common"
-)
-// leaf represents a trie leaf node
-type leaf struct {
- blob []byte // raw blob of leaf
- parent common.Hash // the hash of parent node
-}
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
+)
// committer is the tool used for the trie Commit operation. The committer will
// capture all dirty nodes during the commit process and keep them cached in
// insertion order.
type committer struct {
- nodes *NodeSet
+ nodes *trienode.NodeSet
+ tracer *tracer
collectLeaf bool
}
// newCommitter creates a new committer or picks one from the pool.
-func newCommitter(nodeset *NodeSet, collectLeaf bool) *committer {
+func newCommitter(nodeset *trienode.NodeSet, tracer *tracer, collectLeaf bool) *committer {
return &committer{
nodes: nodeset,
+ tracer: tracer,
collectLeaf: collectLeaf,
}
}
@@ -134,24 +132,15 @@ func (c *committer) store(path []byte, n node) node {
// The node is embedded in its parent, in other words, this node
// will not be stored in the database independently, mark it as
// deleted only if the node was existent in database before.
- if _, ok := c.nodes.accessList[string(path)]; ok {
- c.nodes.markDeleted(path)
+ _, ok := c.tracer.accessList[string(path)]
+ if ok {
+ c.nodes.AddNode(path, trienode.NewDeleted())
}
return n
}
- // We have the hash already, estimate the RLP encoding-size of the node.
- // The size is used for mem tracking, does not need to be exact
- var (
- size = estimateSize(n)
- nhash = common.BytesToHash(hash)
- mnode = &memoryNode{
- hash: nhash,
- node: simplifyNode(n),
- size: uint16(size),
- }
- )
// Collect the dirty node to nodeset for return.
- c.nodes.markUpdated(path, mnode)
+ nhash := common.BytesToHash(hash)
+ c.nodes.AddNode(path, trienode.New(nhash, nodeToBytes(n)))
// Collect the corresponding leaf node if it's required. We don't check
// full node since it's impossible to store value in fullNode. The key
@@ -159,38 +148,36 @@ func (c *committer) store(path []byte, n node) node {
if c.collectLeaf {
if sn, ok := n.(*shortNode); ok {
if val, ok := sn.Val.(valueNode); ok {
- c.nodes.addLeaf(&leaf{blob: val, parent: nhash})
+ c.nodes.AddLeaf(nhash, val)
}
}
}
return hash
}
-// estimateSize estimates the size of an rlp-encoded node, without actually
-// rlp-encoding it (zero allocs). This method has been experimentally tried, and with a trie
-// with 1000 leaves, the only errors above 1% are on small shortnodes, where this
-// method overestimates by 2 or 3 bytes (e.g. 37 instead of 35)
-func estimateSize(n node) int {
+// MerkleResolver the children resolver in merkle-patricia-tree.
+type MerkleResolver struct{}
+
+// ForEach implements childResolver, decodes the provided node and
+// traverses the children inside.
+func (resolver MerkleResolver) ForEach(node []byte, onChild func(common.Hash)) {
+ forGatherChildren(mustDecodeNodeUnsafe(nil, node), onChild)
+}
+
+// forGatherChildren traverses the node hierarchy and invokes the callback
+// for all the hashnode children.
+func forGatherChildren(n node, onChild func(hash common.Hash)) {
switch n := n.(type) {
case *shortNode:
- // A short node contains a compacted key, and a value.
- return 3 + len(n.Key) + estimateSize(n.Val)
+ forGatherChildren(n.Val, onChild)
case *fullNode:
- // A full node contains up to 16 hashes (some nils), and a key
- s := 3
for i := 0; i < 16; i++ {
- if child := n.Children[i]; child != nil {
- s += estimateSize(child)
- } else {
- s++
- }
+ forGatherChildren(n.Children[i], onChild)
}
- return s
- case valueNode:
- return 1 + len(n)
case hashNode:
- return 1 + len(n)
+ onChild(common.BytesToHash(n))
+ case valueNode, nil:
default:
- panic(fmt.Sprintf("node type %T", n))
+ panic(fmt.Sprintf("unknown node type: %T", n))
}
}
diff --git a/trie_by_cid/trie/database.go b/trie_by_cid/trie/database.go
deleted file mode 100644
index ac05f98..0000000
--- a/trie_by_cid/trie/database.go
+++ /dev/null
@@ -1,440 +0,0 @@
-// Copyright 2018 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package trie
-
-import (
- "errors"
- "runtime"
- "sync"
- "time"
-
- "github.com/VictoriaMetrics/fastcache"
- "github.com/cerc-io/ipld-eth-statedb/internal"
- "github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/core/rawdb"
- "github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/ethdb"
- "github.com/ethereum/go-ethereum/rlp"
- "github.com/ethereum/go-ethereum/trie"
- log "github.com/sirupsen/logrus"
-)
-
-// Database is an intermediate write layer between the trie data structures and
-// the disk database. The aim is to accumulate trie writes in-memory and only
-// periodically flush a couple tries to disk, garbage collecting the remainder.
-//
-// Note, the trie Database is **not** thread safe in its mutations, but it **is**
-// thread safe in providing individual, independent node access. The rationale
-// behind this split design is to provide read access to RPC handlers and sync
-// servers even while the trie is executing expensive garbage collection.
-type Database struct {
- diskdb ethdb.Database // Persistent storage for matured trie nodes
-
- cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs
- dirties map[common.Hash]*cachedNode // Data and references relationships of dirty trie nodes
- oldest common.Hash // Oldest tracked node, flush-list head
- newest common.Hash // Newest tracked node, flush-list tail
-
- gctime time.Duration // Time spent on garbage collection since last commit
- gcnodes uint64 // Nodes garbage collected since last commit
- gcsize common.StorageSize // Data storage garbage collected since last commit
-
- flushtime time.Duration // Time spent on data flushing since last commit
- flushnodes uint64 // Nodes flushed since last commit
- flushsize common.StorageSize // Data storage flushed since last commit
-
- dirtiesSize common.StorageSize // Storage size of the dirty node cache (exc. metadata)
- childrenSize common.StorageSize // Storage size of the external children tracking
- preimages *preimageStore // The store for caching preimages
-
- lock sync.RWMutex
-}
-
-// Config defines all necessary options for database.
-// (re-export)
-type Config = trie.Config
-
-// NewDatabase creates a new trie database to store ephemeral trie content before
-// its written out to disk or garbage collected. No read cache is created, so all
-// data retrievals will hit the underlying disk database.
-func NewDatabase(diskdb ethdb.Database) *Database {
- return NewDatabaseWithConfig(diskdb, nil)
-}
-
-// NewDatabaseWithConfig creates a new trie database to store ephemeral trie content
-// before its written out to disk or garbage collected. It also acts as a read cache
-// for nodes loaded from disk.
-func NewDatabaseWithConfig(diskdb ethdb.Database, config *Config) *Database {
- var cleans *fastcache.Cache
- if config != nil && config.Cache > 0 {
- if config.Journal == "" {
- cleans = fastcache.New(config.Cache * 1024 * 1024)
- } else {
- cleans = fastcache.LoadFromFileOrNew(config.Journal, config.Cache*1024*1024)
- }
- }
- var preimage *preimageStore
- if config != nil && config.Preimages {
- preimage = newPreimageStore(diskdb)
- }
- db := &Database{
- diskdb: diskdb,
- cleans: cleans,
- dirties: map[common.Hash]*cachedNode{{}: {
- children: make(map[common.Hash]uint16),
- }},
- preimages: preimage,
- }
- return db
-}
-
-// insert inserts a simplified trie node into the memory database.
-// All nodes inserted by this function will be reference tracked
-// and in theory should only used for **trie nodes** insertion.
-func (db *Database) insert(hash common.Hash, size int, node node) {
- // If the node's already cached, skip
- if _, ok := db.dirties[hash]; ok {
- return
- }
- memcacheDirtyWriteMeter.Mark(int64(size))
-
- // Create the cached entry for this node
- entry := &cachedNode{
- node: node,
- size: uint16(size),
- flushPrev: db.newest,
- }
- entry.forChilds(func(child common.Hash) {
- if c := db.dirties[child]; c != nil {
- c.parents++
- }
- })
- db.dirties[hash] = entry
-
- // Update the flush-list endpoints
- if db.oldest == (common.Hash{}) {
- db.oldest, db.newest = hash, hash
- } else {
- db.dirties[db.newest].flushNext, db.newest = hash, hash
- }
- db.dirtiesSize += common.StorageSize(common.HashLength + entry.size)
-}
-
-// Node retrieves an encoded cached trie node from memory. If it cannot be found
-// cached, the method queries the persistent database for the content.
-func (db *Database) Node(hash common.Hash, codec uint64) ([]byte, error) {
- // It doesn't make sense to retrieve the metaroot
- if hash == (common.Hash{}) {
- return nil, errors.New("not found")
- }
- // Retrieve the node from the clean cache if available
- if db.cleans != nil {
- if enc := db.cleans.Get(nil, hash[:]); enc != nil {
- memcacheCleanHitMeter.Mark(1)
- memcacheCleanReadMeter.Mark(int64(len(enc)))
- return enc, nil
- }
- }
- // Retrieve the node from the dirty cache if available
- db.lock.RLock()
- dirty := db.dirties[hash]
- db.lock.RUnlock()
-
- if dirty != nil {
- memcacheDirtyHitMeter.Mark(1)
- memcacheDirtyReadMeter.Mark(int64(dirty.size))
- return dirty.rlp(), nil
- }
- memcacheDirtyMissMeter.Mark(1)
-
- // Content unavailable in memory, attempt to retrieve from disk
- cid, err := internal.Keccak256ToCid(codec, hash[:])
- if err != nil {
- return nil, err
- }
- enc, err := db.diskdb.Get(cid.Bytes())
- if err != nil {
- return nil, err
- }
- if len(enc) != 0 {
- if db.cleans != nil {
- db.cleans.Set(hash[:], enc)
- memcacheCleanMissMeter.Mark(1)
- memcacheCleanWriteMeter.Mark(int64(len(enc)))
- }
- return enc, nil
- }
- return nil, errors.New("not found")
-}
-
-// Nodes retrieves the hashes of all the nodes cached within the memory database.
-// This method is extremely expensive and should only be used to validate internal
-// states in test code.
-func (db *Database) Nodes() []common.Hash {
- db.lock.RLock()
- defer db.lock.RUnlock()
-
- var hashes = make([]common.Hash, 0, len(db.dirties))
- for hash := range db.dirties {
- if hash != (common.Hash{}) { // Special case for "root" references/nodes
- hashes = append(hashes, hash)
- }
- }
- return hashes
-}
-
-// Reference adds a new reference from a parent node to a child node.
-// This function is used to add reference between internal trie node
-// and external node(e.g. storage trie root), all internal trie nodes
-// are referenced together by database itself.
-func (db *Database) Reference(child common.Hash, parent common.Hash) {
- db.lock.Lock()
- defer db.lock.Unlock()
-
- db.reference(child, parent)
-}
-
-// reference is the private locked version of Reference.
-func (db *Database) reference(child common.Hash, parent common.Hash) {
- // If the node does not exist, it's a node pulled from disk, skip
- node, ok := db.dirties[child]
- if !ok {
- return
- }
- // If the reference already exists, only duplicate for roots
- if db.dirties[parent].children == nil {
- db.dirties[parent].children = make(map[common.Hash]uint16)
- db.childrenSize += cachedNodeChildrenSize
- } else if _, ok = db.dirties[parent].children[child]; ok && parent != (common.Hash{}) {
- return
- }
- node.parents++
- db.dirties[parent].children[child]++
- if db.dirties[parent].children[child] == 1 {
- db.childrenSize += common.HashLength + 2 // uint16 counter
- }
-}
-
-// Dereference removes an existing reference from a root node.
-func (db *Database) Dereference(root common.Hash) {
- // Sanity check to ensure that the meta-root is not removed
- if root == (common.Hash{}) {
- log.Error("Attempted to dereference the trie cache meta root")
- return
- }
- db.lock.Lock()
- defer db.lock.Unlock()
-
- nodes, storage, start := len(db.dirties), db.dirtiesSize, time.Now()
- db.dereference(root, common.Hash{})
-
- db.gcnodes += uint64(nodes - len(db.dirties))
- db.gcsize += storage - db.dirtiesSize
- db.gctime += time.Since(start)
-
- memcacheGCTimeTimer.Update(time.Since(start))
- memcacheGCSizeMeter.Mark(int64(storage - db.dirtiesSize))
- memcacheGCNodesMeter.Mark(int64(nodes - len(db.dirties)))
-
- log.Debug("Dereferenced trie from memory database", "nodes", nodes-len(db.dirties), "size", storage-db.dirtiesSize, "time", time.Since(start),
- "gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.dirties), "livesize", db.dirtiesSize)
-}
-
-// dereference is the private locked version of Dereference.
-func (db *Database) dereference(child common.Hash, parent common.Hash) {
- // Dereference the parent-child
- node := db.dirties[parent]
-
- if node.children != nil && node.children[child] > 0 {
- node.children[child]--
- if node.children[child] == 0 {
- delete(node.children, child)
- db.childrenSize -= (common.HashLength + 2) // uint16 counter
- }
- }
- // If the child does not exist, it's a previously committed node.
- node, ok := db.dirties[child]
- if !ok {
- return
- }
- // If there are no more references to the child, delete it and cascade
- if node.parents > 0 {
- // This is a special cornercase where a node loaded from disk (i.e. not in the
- // memcache any more) gets reinjected as a new node (short node split into full,
- // then reverted into short), causing a cached node to have no parents. That is
- // no problem in itself, but don't make maxint parents out of it.
- node.parents--
- }
- if node.parents == 0 {
- // Remove the node from the flush-list
- switch child {
- case db.oldest:
- db.oldest = node.flushNext
- db.dirties[node.flushNext].flushPrev = common.Hash{}
- case db.newest:
- db.newest = node.flushPrev
- db.dirties[node.flushPrev].flushNext = common.Hash{}
- default:
- db.dirties[node.flushPrev].flushNext = node.flushNext
- db.dirties[node.flushNext].flushPrev = node.flushPrev
- }
- // Dereference all children and delete the node
- node.forChilds(func(hash common.Hash) {
- db.dereference(hash, child)
- })
- delete(db.dirties, child)
- db.dirtiesSize -= common.StorageSize(common.HashLength + int(node.size))
- if node.children != nil {
- db.childrenSize -= cachedNodeChildrenSize
- }
- }
-}
-
-// Update inserts the dirty nodes in provided nodeset into database and
-// link the account trie with multiple storage tries if necessary.
-func (db *Database) Update(nodes *MergedNodeSet) error {
- db.lock.Lock()
- defer db.lock.Unlock()
-
- // Insert dirty nodes into the database. In the same tree, it must be
- // ensured that children are inserted first, then parent so that children
- // can be linked with their parent correctly.
- //
- // Note, the storage tries must be flushed before the account trie to
- // retain the invariant that children go into the dirty cache first.
- var order []common.Hash
- for owner := range nodes.sets {
- if owner == (common.Hash{}) {
- continue
- }
- order = append(order, owner)
- }
- if _, ok := nodes.sets[common.Hash{}]; ok {
- order = append(order, common.Hash{})
- }
- for _, owner := range order {
- subset := nodes.sets[owner]
- subset.forEachWithOrder(func(path string, n *memoryNode) {
- if n.isDeleted() {
- return // ignore deletion
- }
- db.insert(n.hash, int(n.size), n.node)
- })
- }
- // Link up the account trie and storage trie if the node points
- // to an account trie leaf.
- if set, present := nodes.sets[common.Hash{}]; present {
- for _, n := range set.leaves {
- var account types.StateAccount
- if err := rlp.DecodeBytes(n.blob, &account); err != nil {
- return err
- }
- if account.Root != types.EmptyRootHash {
- db.reference(account.Root, n.parent)
- }
- }
- }
- return nil
-}
-
-// Size returns the current storage size of the memory cache in front of the
-// persistent database layer.
-func (db *Database) Size() (common.StorageSize, common.StorageSize) {
- db.lock.RLock()
- defer db.lock.RUnlock()
-
- // db.dirtiesSize only contains the useful data in the cache, but when reporting
- // the total memory consumption, the maintenance metadata is also needed to be
- // counted.
- var metadataSize = common.StorageSize((len(db.dirties) - 1) * cachedNodeSize)
- var metarootRefs = common.StorageSize(len(db.dirties[common.Hash{}].children) * (common.HashLength + 2))
- var preimageSize common.StorageSize
- if db.preimages != nil {
- preimageSize = db.preimages.size()
- }
- return db.dirtiesSize + db.childrenSize + metadataSize - metarootRefs, preimageSize
-}
-
-// GetReader retrieves a node reader belonging to the given state root.
-func (db *Database) GetReader(root common.Hash, codec uint64) Reader {
- return &hashReader{db: db, codec: codec}
-}
-
-// hashReader is reader of hashDatabase which implements the Reader interface.
-type hashReader struct {
- db *Database
- codec uint64
-}
-
-// Node retrieves the trie node with the given node hash.
-func (reader *hashReader) Node(owner common.Hash, path []byte, hash common.Hash) (node, error) {
- blob, err := reader.NodeBlob(owner, path, hash)
- if err != nil {
- return nil, err
- }
- return decodeNodeUnsafe(hash[:], blob)
-}
-
-// NodeBlob retrieves the RLP-encoded trie node blob with the given node hash.
-func (reader *hashReader) NodeBlob(_ common.Hash, _ []byte, hash common.Hash) ([]byte, error) {
- return reader.db.Node(hash, reader.codec)
-}
-
-// saveCache saves clean state cache to given directory path
-// using specified CPU cores.
-func (db *Database) saveCache(dir string, threads int) error {
- if db.cleans == nil {
- return nil
- }
- log.Info("Writing clean trie cache to disk", "path", dir, "threads", threads)
-
- start := time.Now()
- err := db.cleans.SaveToFileConcurrent(dir, threads)
- if err != nil {
- log.Error("Failed to persist clean trie cache", "error", err)
- return err
- }
- log.Info("Persisted the clean trie cache", "path", dir, "elapsed", common.PrettyDuration(time.Since(start)))
- return nil
-}
-
-// SaveCache atomically saves fast cache data to the given dir using all
-// available CPU cores.
-func (db *Database) SaveCache(dir string) error {
- return db.saveCache(dir, runtime.GOMAXPROCS(0))
-}
-
-// SaveCachePeriodically atomically saves fast cache data to the given dir with
-// the specified interval. All dump operation will only use a single CPU core.
-func (db *Database) SaveCachePeriodically(dir string, interval time.Duration, stopCh <-chan struct{}) {
- ticker := time.NewTicker(interval)
- defer ticker.Stop()
-
- for {
- select {
- case <-ticker.C:
- db.saveCache(dir, 1)
- case <-stopCh:
- return
- }
- }
-}
-
-// Scheme returns the node scheme used in the database.
-func (db *Database) Scheme() string {
- return rawdb.HashScheme
-}
diff --git a/trie_by_cid/trie/database_metrics.go b/trie_by_cid/trie/database_metrics.go
deleted file mode 100644
index 55efc55..0000000
--- a/trie_by_cid/trie/database_metrics.go
+++ /dev/null
@@ -1,27 +0,0 @@
-package trie
-
-import "github.com/ethereum/go-ethereum/metrics"
-
-var (
- memcacheCleanHitMeter = metrics.NewRegisteredMeter("trie/memcache/clean/hit", nil)
- memcacheCleanMissMeter = metrics.NewRegisteredMeter("trie/memcache/clean/miss", nil)
- memcacheCleanReadMeter = metrics.NewRegisteredMeter("trie/memcache/clean/read", nil)
- memcacheCleanWriteMeter = metrics.NewRegisteredMeter("trie/memcache/clean/write", nil)
-
- memcacheDirtyHitMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/hit", nil)
- memcacheDirtyMissMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/miss", nil)
- memcacheDirtyReadMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/read", nil)
- memcacheDirtyWriteMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/write", nil)
-
- memcacheFlushTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/flush/time", nil)
- memcacheFlushNodesMeter = metrics.NewRegisteredMeter("trie/memcache/flush/nodes", nil)
- memcacheFlushSizeMeter = metrics.NewRegisteredMeter("trie/memcache/flush/size", nil)
-
- memcacheGCTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/gc/time", nil)
- memcacheGCNodesMeter = metrics.NewRegisteredMeter("trie/memcache/gc/nodes", nil)
- memcacheGCSizeMeter = metrics.NewRegisteredMeter("trie/memcache/gc/size", nil)
-
- memcacheCommitTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/commit/time", nil)
- memcacheCommitNodesMeter = metrics.NewRegisteredMeter("trie/memcache/commit/nodes", nil)
- memcacheCommitSizeMeter = metrics.NewRegisteredMeter("trie/memcache/commit/size", nil)
-)
diff --git a/trie_by_cid/trie/database_node.go b/trie_by_cid/trie/database_node.go
deleted file mode 100644
index 58037ae..0000000
--- a/trie_by_cid/trie/database_node.go
+++ /dev/null
@@ -1,183 +0,0 @@
-package trie
-
-import (
- "fmt"
- "io"
- "reflect"
-
- "github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/rlp"
-)
-
-// rawNode is a simple binary blob used to differentiate between collapsed trie
-// nodes and already encoded RLP binary blobs (while at the same time store them
-// in the same cache fields).
-type rawNode []byte
-
-func (n rawNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
-func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") }
-
-func (n rawNode) EncodeRLP(w io.Writer) error {
- _, err := w.Write(n)
- return err
-}
-
-// rawFullNode represents only the useful data content of a full node, with the
-// caches and flags stripped out to minimize its data storage. This type honors
-// the same RLP encoding as the original parent.
-type rawFullNode [17]node
-
-func (n rawFullNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
-func (n rawFullNode) fstring(ind string) string { panic("this should never end up in a live trie") }
-
-func (n rawFullNode) EncodeRLP(w io.Writer) error {
- eb := rlp.NewEncoderBuffer(w)
- n.encode(eb)
- return eb.Flush()
-}
-
-// rawShortNode represents only the useful data content of a short node, with the
-// caches and flags stripped out to minimize its data storage. This type honors
-// the same RLP encoding as the original parent.
-type rawShortNode struct {
- Key []byte
- Val node
-}
-
-func (n rawShortNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
-func (n rawShortNode) fstring(ind string) string { panic("this should never end up in a live trie") }
-
-// cachedNode is all the information we know about a single cached trie node
-// in the memory database write layer.
-type cachedNode struct {
- node node // Cached collapsed trie node, or raw rlp data
- size uint16 // Byte size of the useful cached data
-
- parents uint32 // Number of live nodes referencing this one
- children map[common.Hash]uint16 // External children referenced by this node
-
- flushPrev common.Hash // Previous node in the flush-list
- flushNext common.Hash // Next node in the flush-list
-}
-
-// cachedNodeSize is the raw size of a cachedNode data structure without any
-// node data included. It's an approximate size, but should be a lot better
-// than not counting them.
-var cachedNodeSize = int(reflect.TypeOf(cachedNode{}).Size())
-
-// cachedNodeChildrenSize is the raw size of an initialized but empty external
-// reference map.
-const cachedNodeChildrenSize = 48
-
-// rlp returns the raw rlp encoded blob of the cached trie node, either directly
-// from the cache, or by regenerating it from the collapsed node.
-func (n *cachedNode) rlp() []byte {
- if node, ok := n.node.(rawNode); ok {
- return node
- }
- return nodeToBytes(n.node)
-}
-
-// obj returns the decoded and expanded trie node, either directly from the cache,
-// or by regenerating it from the rlp encoded blob.
-func (n *cachedNode) obj(hash common.Hash) node {
- if node, ok := n.node.(rawNode); ok {
- // The raw-blob format nodes are loaded either from the
- // clean cache or the database, they are all in their own
- // copy and safe to use unsafe decoder.
- return mustDecodeNodeUnsafe(hash[:], node)
- }
- return expandNode(hash[:], n.node)
-}
-
-// forChilds invokes the callback for all the tracked children of this node,
-// both the implicit ones from inside the node as well as the explicit ones
-// from outside the node.
-func (n *cachedNode) forChilds(onChild func(hash common.Hash)) {
- for child := range n.children {
- onChild(child)
- }
- if _, ok := n.node.(rawNode); !ok {
- forGatherChildren(n.node, onChild)
- }
-}
-
-// forGatherChildren traverses the node hierarchy of a collapsed storage node and
-// invokes the callback for all the hashnode children.
-func forGatherChildren(n node, onChild func(hash common.Hash)) {
- switch n := n.(type) {
- case *rawShortNode:
- forGatherChildren(n.Val, onChild)
- case rawFullNode:
- for i := 0; i < 16; i++ {
- forGatherChildren(n[i], onChild)
- }
- case hashNode:
- onChild(common.BytesToHash(n))
- case valueNode, nil, rawNode:
- default:
- panic(fmt.Sprintf("unknown node type: %T", n))
- }
-}
-
-// simplifyNode traverses the hierarchy of an expanded memory node and discards
-// all the internal caches, returning a node that only contains the raw data.
-func simplifyNode(n node) node {
- switch n := n.(type) {
- case *shortNode:
- // Short nodes discard the flags and cascade
- return &rawShortNode{Key: n.Key, Val: simplifyNode(n.Val)}
-
- case *fullNode:
- // Full nodes discard the flags and cascade
- node := rawFullNode(n.Children)
- for i := 0; i < len(node); i++ {
- if node[i] != nil {
- node[i] = simplifyNode(node[i])
- }
- }
- return node
-
- case valueNode, hashNode, rawNode:
- return n
-
- default:
- panic(fmt.Sprintf("unknown node type: %T", n))
- }
-}
-
-// expandNode traverses the node hierarchy of a collapsed storage node and converts
-// all fields and keys into expanded memory form.
-func expandNode(hash hashNode, n node) node {
- switch n := n.(type) {
- case *rawShortNode:
- // Short nodes need key and child expansion
- return &shortNode{
- Key: compactToHex(n.Key),
- Val: expandNode(nil, n.Val),
- flags: nodeFlag{
- hash: hash,
- },
- }
-
- case rawFullNode:
- // Full nodes need child expansion
- node := &fullNode{
- flags: nodeFlag{
- hash: hash,
- },
- }
- for i := 0; i < len(node.Children); i++ {
- if n[i] != nil {
- node.Children[i] = expandNode(nil, n[i])
- }
- }
- return node
-
- case valueNode, hashNode:
- return n
-
- default:
- panic(fmt.Sprintf("unknown node type: %T", n))
- }
-}
diff --git a/trie_by_cid/trie/database_test.go b/trie_by_cid/trie/database_test.go
index 5a6e36d..663b04f 100644
--- a/trie_by_cid/trie/database_test.go
+++ b/trie_by_cid/trie/database_test.go
@@ -17,17 +17,137 @@
package trie
import (
- "testing"
-
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/ethdb"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/triedb/database"
)
-// Tests that the trie database returns a missing trie node error if attempting
-// to retrieve the meta root.
-func TestDatabaseMetarootFetch(t *testing.T) {
- db := NewDatabase(rawdb.NewMemoryDatabase())
- if _, err := db.Node(common.Hash{}, StateTrieCodec); err == nil {
- t.Fatalf("metaroot retrieval succeeded")
+// testReader implements database.Reader interface, providing function to
+// access trie nodes.
+type testReader struct {
+ db ethdb.Database
+ scheme string
+ nodes []*trienode.MergedNodeSet // sorted from new to old
+}
+
+// Node implements database.Reader interface, retrieving trie node with
+// all available cached layers.
+func (r *testReader) Node(owner common.Hash, path []byte, hash common.Hash) ([]byte, error) {
+ // Check the node presence with the cached layer, from latest to oldest.
+ for _, nodes := range r.nodes {
+ if _, ok := nodes.Sets[owner]; !ok {
+ continue
+ }
+ n, ok := nodes.Sets[owner].Nodes[string(path)]
+ if !ok {
+ continue
+ }
+ if n.IsDeleted() || n.Hash != hash {
+ return nil, &MissingNodeError{Owner: owner, Path: path, NodeHash: hash}
+ }
+ return n.Blob, nil
+ }
+ // Check the node presence in database.
+ return rawdb.ReadTrieNode(r.db, owner, path, hash, r.scheme), nil
+}
+
+// testDb implements database.Database interface, using for testing purpose.
+type testDb struct {
+ disk ethdb.Database
+ root common.Hash
+ scheme string
+ nodes map[common.Hash]*trienode.MergedNodeSet
+ parents map[common.Hash]common.Hash
+}
+
+func newTestDatabase(diskdb ethdb.Database, scheme string) *testDb {
+ return &testDb{
+ disk: diskdb,
+ root: types.EmptyRootHash,
+ scheme: scheme,
+ nodes: make(map[common.Hash]*trienode.MergedNodeSet),
+ parents: make(map[common.Hash]common.Hash),
}
}
+
+func (db *testDb) Reader(stateRoot common.Hash) (database.Reader, error) {
+ nodes, _ := db.dirties(stateRoot, true)
+ return &testReader{db: db.disk, scheme: db.scheme, nodes: nodes}, nil
+}
+
+func (db *testDb) Preimage(hash common.Hash) []byte {
+ return rawdb.ReadPreimage(db.disk, hash)
+}
+
+func (db *testDb) InsertPreimage(preimages map[common.Hash][]byte) {
+ rawdb.WritePreimages(db.disk, preimages)
+}
+
+func (db *testDb) Scheme() string { return db.scheme }
+
+func (db *testDb) Update(root common.Hash, parent common.Hash, nodes *trienode.MergedNodeSet) error {
+ if root == parent {
+ return nil
+ }
+ if _, ok := db.nodes[root]; ok {
+ return nil
+ }
+ db.parents[root] = parent
+ db.nodes[root] = nodes
+ return nil
+}
+
+func (db *testDb) dirties(root common.Hash, topToBottom bool) ([]*trienode.MergedNodeSet, []common.Hash) {
+ var (
+ pending []*trienode.MergedNodeSet
+ roots []common.Hash
+ )
+ for {
+ if root == db.root {
+ break
+ }
+ nodes, ok := db.nodes[root]
+ if !ok {
+ break
+ }
+ if topToBottom {
+ pending = append(pending, nodes)
+ roots = append(roots, root)
+ } else {
+ pending = append([]*trienode.MergedNodeSet{nodes}, pending...)
+ roots = append([]common.Hash{root}, roots...)
+ }
+ root = db.parents[root]
+ }
+ return pending, roots
+}
+
+func (db *testDb) Commit(root common.Hash) error {
+ if root == db.root {
+ return nil
+ }
+ pending, roots := db.dirties(root, false)
+ for i, nodes := range pending {
+ for owner, set := range nodes.Sets {
+ if owner == (common.Hash{}) {
+ continue
+ }
+ set.ForEachWithOrder(func(path string, n *trienode.Node) {
+ rawdb.WriteTrieNode(db.disk, owner, []byte(path), n.Hash, n.Blob, db.scheme)
+ })
+ }
+ nodes.Sets[common.Hash{}].ForEachWithOrder(func(path string, n *trienode.Node) {
+ rawdb.WriteTrieNode(db.disk, common.Hash{}, []byte(path), n.Hash, n.Blob, db.scheme)
+ })
+ db.root = roots[i]
+ }
+ for _, root := range roots {
+ delete(db.nodes, root)
+ delete(db.parents, root)
+ }
+ return nil
+}
diff --git a/trie_by_cid/trie/encoding.go b/trie_by_cid/trie/encoding.go
index ace4570..3284d3f 100644
--- a/trie_by_cid/trie/encoding.go
+++ b/trie_by_cid/trie/encoding.go
@@ -34,11 +34,6 @@ package trie
// in the case of an odd number. All remaining nibbles (now an even number) fit properly
// into the remaining bytes. Compact encoding is used for nodes stored on disk.
-// HexToCompact converts a hex path to the compact encoded format
-func HexToCompact(hex []byte) []byte {
- return hexToCompact(hex)
-}
-
func hexToCompact(hex []byte) []byte {
terminator := byte(0)
if hasTerm(hex) {
@@ -56,9 +51,8 @@ func hexToCompact(hex []byte) []byte {
return buf
}
-// hexToCompactInPlace places the compact key in input buffer, returning the length
-// needed for the representation
-func hexToCompactInPlace(hex []byte) int {
+// hexToCompactInPlace places the compact key in input buffer, returning the compacted key.
+func hexToCompactInPlace(hex []byte) []byte {
var (
hexLen = len(hex) // length of the hex input
firstByte = byte(0)
@@ -82,12 +76,7 @@ func hexToCompactInPlace(hex []byte) int {
hex[bi] = hex[ni]<<4 | hex[ni+1]
}
hex[0] = firstByte
- return binLen
-}
-
-// CompactToHex converts a compact encoded path to hex format
-func CompactToHex(compact []byte) []byte {
- return compactToHex(compact)
+ return hex[:binLen]
}
func compactToHex(compact []byte) []byte {
@@ -115,9 +104,9 @@ func keybytesToHex(str []byte) []byte {
return nibbles
}
-// hexToKeyBytes turns hex nibbles into key bytes.
+// hexToKeybytes turns hex nibbles into key bytes.
// This can only be used for keys of even length.
-func hexToKeyBytes(hex []byte) []byte {
+func hexToKeybytes(hex []byte) []byte {
if hasTerm(hex) {
hex = hex[:len(hex)-1]
}
diff --git a/trie_by_cid/trie/encoding_test.go b/trie_by_cid/trie/encoding_test.go
index abc1e9d..ac50b5d 100644
--- a/trie_by_cid/trie/encoding_test.go
+++ b/trie_by_cid/trie/encoding_test.go
@@ -72,8 +72,8 @@ func TestHexKeybytes(t *testing.T) {
if h := keybytesToHex(test.key); !bytes.Equal(h, test.hexOut) {
t.Errorf("keybytesToHex(%x) -> %x, want %x", test.key, h, test.hexOut)
}
- if k := hexToKeyBytes(test.hexIn); !bytes.Equal(k, test.key) {
- t.Errorf("hexToKeyBytes(%x) -> %x, want %x", test.hexIn, k, test.key)
+ if k := hexToKeybytes(test.hexIn); !bytes.Equal(k, test.key) {
+ t.Errorf("hexToKeybytes(%x) -> %x, want %x", test.hexIn, k, test.key)
}
}
}
@@ -86,8 +86,7 @@ func TestHexToCompactInPlace(t *testing.T) {
} {
hexBytes, _ := hex.DecodeString(key)
exp := hexToCompact(hexBytes)
- sz := hexToCompactInPlace(hexBytes)
- got := hexBytes[:sz]
+ got := hexToCompactInPlace(hexBytes)
if !bytes.Equal(exp, got) {
t.Fatalf("test %d: encoding err\ninp %v\ngot %x\nexp %x\n", i, key, got, exp)
}
@@ -102,8 +101,7 @@ func TestHexToCompactInPlaceRandom(t *testing.T) {
hexBytes := keybytesToHex(key)
hexOrig := []byte(string(hexBytes))
exp := hexToCompact(hexBytes)
- sz := hexToCompactInPlace(hexBytes)
- got := hexBytes[:sz]
+ got := hexToCompactInPlace(hexBytes)
if !bytes.Equal(exp, got) {
t.Fatalf("encoding err \ncpt %x\nhex %x\ngot %x\nexp %x\n",
@@ -119,6 +117,13 @@ func BenchmarkHexToCompact(b *testing.B) {
}
}
+func BenchmarkHexToCompactInPlace(b *testing.B) {
+ testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
+ for i := 0; i < b.N; i++ {
+ hexToCompactInPlace(testBytes)
+ }
+}
+
func BenchmarkCompactToHex(b *testing.B) {
testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
for i := 0; i < b.N; i++ {
@@ -136,6 +141,6 @@ func BenchmarkKeybytesToHex(b *testing.B) {
func BenchmarkHexToKeybytes(b *testing.B) {
testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
for i := 0; i < b.N; i++ {
- hexToKeyBytes(testBytes)
+ hexToKeybytes(testBytes)
}
}
diff --git a/trie_by_cid/trie/errors.go b/trie_by_cid/trie/errors.go
index afe344b..7be7041 100644
--- a/trie_by_cid/trie/errors.go
+++ b/trie_by_cid/trie/errors.go
@@ -17,12 +17,18 @@
package trie
import (
+ "errors"
"fmt"
"github.com/ethereum/go-ethereum/common"
)
-// MissingNodeError is returned by the trie functions (TryGet, TryUpdate, TryDelete)
+// ErrCommitted is returned when a already committed trie is requested for usage.
+// The potential usages can be `Get`, `Update`, `Delete`, `NodeIterator`, `Prove`
+// and so on.
+var ErrCommitted = errors.New("trie is already committed")
+
+// MissingNodeError is returned by the trie functions (Get, Update, Delete)
// in the case where a trie node is not present in the local database. It contains
// information necessary for retrieving the missing node.
type MissingNodeError struct {
diff --git a/trie_by_cid/trie/hasher.go b/trie_by_cid/trie/hasher.go
index e594d6d..1e063d8 100644
--- a/trie_by_cid/trie/hasher.go
+++ b/trie_by_cid/trie/hasher.go
@@ -84,20 +84,19 @@ func (h *hasher) hash(n node, force bool) (hashed node, cached node) {
}
return hashed, cached
default:
- // Value and hash nodes don't have children so they're left as were
+ // Value and hash nodes don't have children, so they're left as were
return n, n
}
}
// hashShortNodeChildren collapses the short node. The returned collapsed node
// holds a live reference to the Key, and must not be modified.
-// The cached
func (h *hasher) hashShortNodeChildren(n *shortNode) (collapsed, cached *shortNode) {
// Hash the short node's child, caching the newly hashed subtree
collapsed, cached = n.copy(), n.copy()
// Previously, we did copy this one. We don't seem to need to actually
// do that, since we don't overwrite/reuse keys
- //cached.Key = common.CopyBytes(n.Key)
+ // cached.Key = common.CopyBytes(n.Key)
collapsed.Key = hexToCompact(n.Key)
// Unless the child is a valuenode or hashnode, hash it
switch n.Val.(type) {
@@ -153,7 +152,7 @@ func (h *hasher) shortnodeToHash(n *shortNode, force bool) node {
return h.hashData(enc)
}
-// shortnodeToHash is used to creates a hashNode from a set of hashNodes, (which
+// fullnodeToHash is used to create a hashNode from a fullNode, (which
// may contain nil values)
func (h *hasher) fullnodeToHash(n *fullNode, force bool) node {
n.encode(h.encbuf)
@@ -203,7 +202,7 @@ func (h *hasher) proofHash(original node) (collapsed, hashed node) {
fn, _ := h.hashFullNodeChildren(n)
return fn, h.fullnodeToHash(fn, false)
default:
- // Value and hash nodes don't have children so they're left as were
+ // Value and hash nodes don't have children, so they're left as were
return n, n
}
}
diff --git a/trie_by_cid/trie/iterator.go b/trie_by_cid/trie/iterator.go
index 07f98e0..f87ecd7 100644
--- a/trie_by_cid/trie/iterator.go
+++ b/trie_by_cid/trie/iterator.go
@@ -26,9 +26,6 @@ import (
gethtrie "github.com/ethereum/go-ethereum/trie"
)
-// NodeIterator is an iterator to traverse the trie pre-order.
-type NodeIterator = gethtrie.NodeIterator
-
// NodeResolver is used for looking up trie nodes before reaching into the real
// persistent layer. This is not mandatory, rather is an optimization for cases
// where trie nodes can be recovered from some external mechanism without reading
@@ -75,6 +72,9 @@ func (it *Iterator) Prove() [][]byte {
return it.nodeIt.LeafProof()
}
+// NodeIterator is an iterator to traverse the trie pre-order.
+type NodeIterator = gethtrie.NodeIterator
+
// nodeIteratorState represents the iteration state at one particular node of the
// trie, which can be resumed at a later invocation.
type nodeIteratorState struct {
@@ -91,7 +91,8 @@ type nodeIterator struct {
path []byte // Path to the current node
err error // Failure set in case of an internal error in the iterator
- resolver NodeResolver // optional node resolver for avoiding disk hits
+ resolver NodeResolver // optional node resolver for avoiding disk hits
+ pool []*nodeIteratorState // local pool for iteratorstates
}
// errIteratorEnd is stored in nodeIterator.err when iteration is done.
@@ -119,6 +120,24 @@ func newNodeIterator(trie *Trie, start []byte) NodeIterator {
return it
}
+func (it *nodeIterator) putInPool(item *nodeIteratorState) {
+ if len(it.pool) < 40 {
+ item.node = nil
+ it.pool = append(it.pool, item)
+ }
+}
+
+func (it *nodeIterator) getFromPool() *nodeIteratorState {
+ idx := len(it.pool) - 1
+ if idx < 0 {
+ return new(nodeIteratorState)
+ }
+ el := it.pool[idx]
+ it.pool[idx] = nil
+ it.pool = it.pool[:idx]
+ return el
+}
+
func (it *nodeIterator) AddResolver(resolver NodeResolver) {
it.resolver = resolver
}
@@ -137,14 +156,6 @@ func (it *nodeIterator) Parent() common.Hash {
return it.stack[len(it.stack)-1].parent
}
-func (it *nodeIterator) ParentPath() []byte {
- if len(it.stack) == 0 {
- return []byte{}
- }
- pathlen := it.stack[len(it.stack)-1].pathlen
- return it.path[:pathlen]
-}
-
func (it *nodeIterator) Leaf() bool {
return hasTerm(it.path)
}
@@ -152,7 +163,7 @@ func (it *nodeIterator) Leaf() bool {
func (it *nodeIterator) LeafKey() []byte {
if len(it.stack) > 0 {
if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
- return hexToKeyBytes(it.path)
+ return hexToKeybytes(it.path)
}
}
panic("not at leaf")
@@ -342,7 +353,14 @@ func (it *nodeIterator) resolveHash(hash hashNode, path []byte) (node, error) {
// loaded blob will be tracked, while it's not required here since
// all loaded nodes won't be linked to trie at all and track nodes
// may lead to out-of-memory issue.
- return it.trie.reader.node(path, common.BytesToHash(hash))
+ blob, err := it.trie.reader.node(path, common.BytesToHash(hash))
+ if err != nil {
+ return nil, err
+ }
+ // The raw-blob format nodes are loaded either from the
+ // clean cache or the database, they are all in their own
+ // copy and safe to use unsafe decoder.
+ return mustDecodeNodeUnsafe(hash, blob), nil
}
func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) {
@@ -356,7 +374,7 @@ func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error)
// loaded blob will be tracked, while it's not required here since
// all loaded nodes won't be linked to trie at all and track nodes
// may lead to out-of-memory issue.
- return it.trie.reader.nodeBlob(path, common.BytesToHash(hash))
+ return it.trie.reader.node(path, common.BytesToHash(hash))
}
func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error {
@@ -371,8 +389,9 @@ func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error {
return nil
}
-func findChild(n *fullNode, index int, path []byte, ancestor common.Hash) (node, *nodeIteratorState, []byte, int) {
+func (it *nodeIterator) findChild(n *fullNode, index int, ancestor common.Hash) (node, *nodeIteratorState, []byte, int) {
var (
+ path = it.path
child node
state *nodeIteratorState
childPath []byte
@@ -381,13 +400,12 @@ func findChild(n *fullNode, index int, path []byte, ancestor common.Hash) (node,
if n.Children[index] != nil {
child = n.Children[index]
hash, _ := child.cache()
- state = &nodeIteratorState{
- hash: common.BytesToHash(hash),
- node: child,
- parent: ancestor,
- index: -1,
- pathlen: len(path),
- }
+ state = it.getFromPool()
+ state.hash = common.BytesToHash(hash)
+ state.node = child
+ state.parent = ancestor
+ state.index = -1
+ state.pathlen = len(path)
childPath = append(childPath, path...)
childPath = append(childPath, byte(index))
return child, state, childPath, index
@@ -400,7 +418,7 @@ func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor common.Has
switch node := parent.node.(type) {
case *fullNode:
// Full node, move to the first non-nil child.
- if child, state, path, index := findChild(node, parent.index+1, it.path, ancestor); child != nil {
+ if child, state, path, index := it.findChild(node, parent.index+1, ancestor); child != nil {
parent.index = index - 1
return state, path, true
}
@@ -408,13 +426,12 @@ func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor common.Has
// Short node, return the pointer singleton child
if parent.index < 0 {
hash, _ := node.Val.cache()
- state := &nodeIteratorState{
- hash: common.BytesToHash(hash),
- node: node.Val,
- parent: ancestor,
- index: -1,
- pathlen: len(it.path),
- }
+ state := it.getFromPool()
+ state.hash = common.BytesToHash(hash)
+ state.node = node.Val
+ state.parent = ancestor
+ state.index = -1
+ state.pathlen = len(it.path)
path := append(it.path, node.Key...)
return state, path, true
}
@@ -428,7 +445,7 @@ func (it *nodeIterator) nextChildAt(parent *nodeIteratorState, ancestor common.H
switch n := parent.node.(type) {
case *fullNode:
// Full node, move to the first non-nil child before the desired key position
- child, state, path, index := findChild(n, parent.index+1, it.path, ancestor)
+ child, state, path, index := it.findChild(n, parent.index+1, ancestor)
if child == nil {
// No more children in this fullnode
return parent, it.path, false
@@ -440,7 +457,7 @@ func (it *nodeIterator) nextChildAt(parent *nodeIteratorState, ancestor common.H
}
// The child is before the seek position. Try advancing
for {
- nextChild, nextState, nextPath, nextIndex := findChild(n, index+1, it.path, ancestor)
+ nextChild, nextState, nextPath, nextIndex := it.findChild(n, index+1, ancestor)
// If we run out of children, or skipped past the target, return the
// previous one
if nextChild == nil || bytes.Compare(nextPath, key) >= 0 {
@@ -454,13 +471,12 @@ func (it *nodeIterator) nextChildAt(parent *nodeIteratorState, ancestor common.H
// Short node, return the pointer singleton child
if parent.index < 0 {
hash, _ := n.Val.cache()
- state := &nodeIteratorState{
- hash: common.BytesToHash(hash),
- node: n.Val,
- parent: ancestor,
- index: -1,
- pathlen: len(it.path),
- }
+ state := it.getFromPool()
+ state.hash = common.BytesToHash(hash)
+ state.node = n.Val
+ state.parent = ancestor
+ state.index = -1
+ state.pathlen = len(it.path)
path := append(it.path, n.Key...)
return state, path, true
}
@@ -481,6 +497,8 @@ func (it *nodeIterator) pop() {
it.path = it.path[:last.pathlen]
it.stack[len(it.stack)-1] = nil
it.stack = it.stack[:len(it.stack)-1]
+ // last is now unused
+ it.putInPool(last)
}
func compareNodes(a, b NodeIterator) int {
diff --git a/trie_by_cid/trie/iterator_test.go b/trie_by_cid/trie/iterator_test.go
index 46b03d7..dcabd10 100644
--- a/trie_by_cid/trie/iterator_test.go
+++ b/trie_by_cid/trie/iterator_test.go
@@ -18,32 +18,82 @@ package trie
import (
"bytes"
- "encoding/binary"
"fmt"
"math/rand"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
- "github.com/ethereum/go-ethereum/ethdb/memorydb"
- geth_trie "github.com/ethereum/go-ethereum/trie"
+ "github.com/ethereum/go-ethereum/ethdb"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
)
-var (
- packableTestData = []kvsi{
- {"one", 1},
- {"two", 2},
- {"three", 3},
- {"four", 4},
- {"five", 5},
- {"ten", 10},
+// makeTestTrie create a sample test trie to test node-wise reconstruction.
+func makeTestTrie(scheme string) (ethdb.Database, *testDb, *StateTrie, map[string][]byte) {
+ // Create an empty trie
+ db := rawdb.NewMemoryDatabase()
+ triedb := newTestDatabase(db, scheme)
+ trie, _ := NewStateTrie(TrieID(types.EmptyRootHash), triedb)
+
+ // Fill it with some arbitrary data
+ content := make(map[string][]byte)
+ for i := byte(0); i < 255; i++ {
+ // Map the same data under multiple keys
+ key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i}
+ content[string(key)] = val
+ trie.MustUpdate(key, val)
+
+ key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i}
+ content[string(key)] = val
+ trie.MustUpdate(key, val)
+
+ // Add some other data to inflate the trie
+ for j := byte(3); j < 13; j++ {
+ key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i}
+ content[string(key)] = val
+ trie.MustUpdate(key, val)
+ }
}
-)
+ root, nodes, _ := trie.Commit(false)
+ if err := triedb.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes)); err != nil {
+ panic(fmt.Errorf("failed to commit db %v", err))
+ }
+ if err := triedb.Commit(root); err != nil {
+ panic(err)
+ }
+ // Re-create the trie based on the new state
+ trie, _ = NewStateTrie(TrieID(root), triedb)
+ return db, triedb, trie, content
+}
+
+// checkTrieConsistency checks that all nodes in a trie are indeed present.
+func checkTrieConsistency(db ethdb.Database, scheme string, root common.Hash, rawTrie bool) error {
+ ndb := newTestDatabase(db, scheme)
+ var it NodeIterator
+ if rawTrie {
+ trie, err := New(TrieID(root), ndb)
+ if err != nil {
+ return nil // Consider a non existent state consistent
+ }
+ it = trie.MustNodeIterator(nil)
+ } else {
+ trie, err := NewStateTrie(TrieID(root), ndb)
+ if err != nil {
+ return nil // Consider a non existent state consistent
+ }
+ it = trie.MustNodeIterator(nil)
+ }
+ for it.Next(true) {
+ }
+ return it.Error()
+}
func TestEmptyIterator(t *testing.T) {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
- iter := trie.NodeIterator(nil)
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+ iter := trie.MustNodeIterator(nil)
seen := make(map[string]struct{})
for iter.Next(true) {
@@ -55,7 +105,7 @@ func TestEmptyIterator(t *testing.T) {
}
func TestIterator(t *testing.T) {
- db := NewDatabase(rawdb.NewMemoryDatabase())
+ db := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
trie := NewEmpty(db)
vals := []struct{ k, v string }{
{"do", "verb"},
@@ -69,14 +119,14 @@ func TestIterator(t *testing.T) {
all := make(map[string]string)
for _, val := range vals {
all[val.k] = val.v
- trie.Update([]byte(val.k), []byte(val.v))
+ trie.MustUpdate([]byte(val.k), []byte(val.v))
}
- root, nodes := trie.Commit(false)
- db.Update(NewWithNodeSet(nodes))
+ root, nodes, _ := trie.Commit(false)
+ db.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
- trie, _ = New(TrieID(root), db, StateTrieCodec)
+ trie, _ = New(TrieID(root), db)
found := make(map[string]string)
- it := NewIterator(trie.NodeIterator(nil))
+ it := NewIterator(trie.MustNodeIterator(nil))
for it.Next() {
found[string(it.Key)] = string(it.Value)
}
@@ -93,20 +143,24 @@ type kv struct {
t bool
}
+func (k *kv) cmp(other *kv) int {
+ return bytes.Compare(k.k, other.k)
+}
+
func TestIteratorLargeData(t *testing.T) {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
vals := make(map[string]*kv)
for i := byte(0); i < 255; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false}
- trie.Update(value.k, value.v)
- trie.Update(value2.k, value2.v)
+ trie.MustUpdate(value.k, value.v)
+ trie.MustUpdate(value2.k, value2.v)
vals[string(value.k)] = value
vals[string(value2.k)] = value2
}
- it := NewIterator(trie.NodeIterator(nil))
+ it := NewIterator(trie.MustNodeIterator(nil))
for it.Next() {
vals[string(it.Key)].t = true
}
@@ -126,39 +180,65 @@ func TestIteratorLargeData(t *testing.T) {
}
}
+type iterationElement struct {
+ hash common.Hash
+ path []byte
+ blob []byte
+}
+
// Tests that the node iterator indeed walks over the entire database contents.
func TestNodeIteratorCoverage(t *testing.T) {
- db, trie, _ := makeTestTrie(t)
+ testNodeIteratorCoverage(t, rawdb.HashScheme)
+ testNodeIteratorCoverage(t, rawdb.PathScheme)
+}
+
+func testNodeIteratorCoverage(t *testing.T, scheme string) {
// Create some arbitrary test trie to iterate
+ db, nodeDb, trie, _ := makeTestTrie(scheme)
// Gather all the node hashes found by the iterator
- hashes := make(map[common.Hash]struct{})
- for it := trie.NodeIterator(nil); it.Next(true); {
+ var elements = make(map[common.Hash]iterationElement)
+ for it := trie.MustNodeIterator(nil); it.Next(true); {
if it.Hash() != (common.Hash{}) {
- hashes[it.Hash()] = struct{}{}
- }
- }
- // Cross check the hashes and the database itself
- for hash := range hashes {
- if _, err := db.Node(hash, StateTrieCodec); err != nil {
- t.Errorf("failed to retrieve reported node %x: %v", hash, err)
- }
- }
- for hash, obj := range db.dirties {
- if obj != nil && hash != (common.Hash{}) {
- if _, ok := hashes[hash]; !ok {
- t.Errorf("state entry not reported %x", hash)
+ elements[it.Hash()] = iterationElement{
+ hash: it.Hash(),
+ path: common.CopyBytes(it.Path()),
+ blob: common.CopyBytes(it.NodeBlob()),
}
}
}
- it := db.diskdb.NewIterator(nil, nil)
+ // Cross check the hashes and the database itself
+ reader, err := nodeDb.Reader(trie.Hash())
+ if err != nil {
+ t.Fatalf("state is not available %x", trie.Hash())
+ }
+ for _, element := range elements {
+ if blob, err := reader.Node(common.Hash{}, element.path, element.hash); err != nil {
+ t.Errorf("failed to retrieve reported node %x: %v", element.hash, err)
+ } else if !bytes.Equal(blob, element.blob) {
+ t.Errorf("node blob is different, want %v got %v", element.blob, blob)
+ }
+ }
+ var (
+ count int
+ it = db.NewIterator(nil, nil)
+ )
for it.Next() {
- key := it.Key()
- if _, ok := hashes[common.BytesToHash(key)]; !ok {
- t.Errorf("state entry not reported %x", key)
+ res, _, _ := isTrieNode(nodeDb.Scheme(), it.Key(), it.Value())
+ if !res {
+ continue
+ }
+ count += 1
+ if elem, ok := elements[crypto.Keccak256Hash(it.Value())]; !ok {
+ t.Error("state entry not reported")
+ } else if !bytes.Equal(it.Value(), elem.blob) {
+ t.Errorf("node blob is different, want %v got %v", elem.blob, it.Value())
}
}
it.Release()
+ if count != len(elements) {
+ t.Errorf("state entry is mismatched %d %d", count, len(elements))
+ }
}
type kvs struct{ k, v string }
@@ -187,25 +267,25 @@ var testdata2 = []kvs{
}
func TestIteratorSeek(t *testing.T) {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
for _, val := range testdata1 {
- trie.Update([]byte(val.k), []byte(val.v))
+ trie.MustUpdate([]byte(val.k), []byte(val.v))
}
// Seek to the middle.
- it := NewIterator(trie.NodeIterator([]byte("fab")))
+ it := NewIterator(trie.MustNodeIterator([]byte("fab")))
if err := checkIteratorOrder(testdata1[4:], it); err != nil {
t.Fatal(err)
}
// Seek to a non-existent key.
- it = NewIterator(trie.NodeIterator([]byte("barc")))
+ it = NewIterator(trie.MustNodeIterator([]byte("barc")))
if err := checkIteratorOrder(testdata1[1:], it); err != nil {
t.Fatal(err)
}
// Seek beyond the end.
- it = NewIterator(trie.NodeIterator([]byte("z")))
+ it = NewIterator(trie.MustNodeIterator([]byte("z")))
if err := checkIteratorOrder(nil, it); err != nil {
t.Fatal(err)
}
@@ -228,26 +308,26 @@ func checkIteratorOrder(want []kvs, it *Iterator) error {
}
func TestDifferenceIterator(t *testing.T) {
- dba := NewDatabase(rawdb.NewMemoryDatabase())
+ dba := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
triea := NewEmpty(dba)
for _, val := range testdata1 {
- triea.Update([]byte(val.k), []byte(val.v))
+ triea.MustUpdate([]byte(val.k), []byte(val.v))
}
- rootA, nodesA := triea.Commit(false)
- dba.Update(NewWithNodeSet(nodesA))
- triea, _ = New(TrieID(rootA), dba, StateTrieCodec)
+ rootA, nodesA, _ := triea.Commit(false)
+ dba.Update(rootA, types.EmptyRootHash, trienode.NewWithNodeSet(nodesA))
+ triea, _ = New(TrieID(rootA), dba)
- dbb := NewDatabase(rawdb.NewMemoryDatabase())
+ dbb := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
trieb := NewEmpty(dbb)
for _, val := range testdata2 {
- trieb.Update([]byte(val.k), []byte(val.v))
+ trieb.MustUpdate([]byte(val.k), []byte(val.v))
}
- rootB, nodesB := trieb.Commit(false)
- dbb.Update(NewWithNodeSet(nodesB))
- trieb, _ = New(TrieID(rootB), dbb, StateTrieCodec)
+ rootB, nodesB, _ := trieb.Commit(false)
+ dbb.Update(rootB, types.EmptyRootHash, trienode.NewWithNodeSet(nodesB))
+ trieb, _ = New(TrieID(rootB), dbb)
found := make(map[string]string)
- di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
+ di, _ := NewDifferenceIterator(triea.MustNodeIterator(nil), trieb.MustNodeIterator(nil))
it := NewIterator(di)
for it.Next() {
found[string(it.Key)] = string(it.Value)
@@ -270,25 +350,25 @@ func TestDifferenceIterator(t *testing.T) {
}
func TestUnionIterator(t *testing.T) {
- dba := NewDatabase(rawdb.NewMemoryDatabase())
+ dba := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
triea := NewEmpty(dba)
for _, val := range testdata1 {
- triea.Update([]byte(val.k), []byte(val.v))
+ triea.MustUpdate([]byte(val.k), []byte(val.v))
}
- rootA, nodesA := triea.Commit(false)
- dba.Update(NewWithNodeSet(nodesA))
- triea, _ = New(TrieID(rootA), dba, StateTrieCodec)
+ rootA, nodesA, _ := triea.Commit(false)
+ dba.Update(rootA, types.EmptyRootHash, trienode.NewWithNodeSet(nodesA))
+ triea, _ = New(TrieID(rootA), dba)
- dbb := NewDatabase(rawdb.NewMemoryDatabase())
+ dbb := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
trieb := NewEmpty(dbb)
for _, val := range testdata2 {
- trieb.Update([]byte(val.k), []byte(val.v))
+ trieb.MustUpdate([]byte(val.k), []byte(val.v))
}
- rootB, nodesB := trieb.Commit(false)
- dbb.Update(NewWithNodeSet(nodesB))
- trieb, _ = New(TrieID(rootB), dbb, StateTrieCodec)
+ rootB, nodesB, _ := trieb.Commit(false)
+ dbb.Update(rootB, types.EmptyRootHash, trienode.NewWithNodeSet(nodesB))
+ trieb, _ = New(TrieID(rootB), dbb)
- di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)})
+ di, _ := NewUnionIterator([]NodeIterator{triea.MustNodeIterator(nil), trieb.MustNodeIterator(nil)})
it := NewIterator(di)
all := []struct{ k, v string }{
@@ -323,86 +403,107 @@ func TestUnionIterator(t *testing.T) {
}
func TestIteratorNoDups(t *testing.T) {
- tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ db := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
+ tr := NewEmpty(db)
for _, val := range testdata1 {
- tr.Update([]byte(val.k), []byte(val.v))
+ tr.MustUpdate([]byte(val.k), []byte(val.v))
}
- checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
+ checkIteratorNoDups(t, tr.MustNodeIterator(nil), nil)
}
// This test checks that nodeIterator.Next can be retried after inserting missing trie nodes.
-func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) }
+func TestIteratorContinueAfterError(t *testing.T) {
+ testIteratorContinueAfterError(t, false, rawdb.HashScheme)
+ testIteratorContinueAfterError(t, true, rawdb.HashScheme)
+ testIteratorContinueAfterError(t, false, rawdb.PathScheme)
+ testIteratorContinueAfterError(t, true, rawdb.PathScheme)
+}
-func testIteratorContinueAfterError(t *testing.T, memonly bool) {
+func testIteratorContinueAfterError(t *testing.T, memonly bool, scheme string) {
diskdb := rawdb.NewMemoryDatabase()
- triedb := NewDatabase(diskdb)
+ tdb := newTestDatabase(diskdb, scheme)
- tr := NewEmpty(triedb)
+ tr := NewEmpty(tdb)
for _, val := range testdata1 {
- tr.Update([]byte(val.k), []byte(val.v))
+ tr.MustUpdate([]byte(val.k), []byte(val.v))
}
- _, nodes := tr.Commit(false)
- triedb.Update(NewWithNodeSet(nodes))
- // if !memonly {
- // triedb.Commit(tr.Hash(), false)
- // }
- wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
+ root, nodes, _ := tr.Commit(false)
+ tdb.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
+ if !memonly {
+ tdb.Commit(root)
+ }
+ tr, _ = New(TrieID(root), tdb)
+ wantNodeCount := checkIteratorNoDups(t, tr.MustNodeIterator(nil), nil)
var (
- diskKeys [][]byte
- memKeys []common.Hash
+ paths [][]byte
+ hashes []common.Hash
)
if memonly {
- memKeys = triedb.Nodes()
+ for path, n := range nodes.Nodes {
+ paths = append(paths, []byte(path))
+ hashes = append(hashes, n.Hash)
+ }
} else {
it := diskdb.NewIterator(nil, nil)
for it.Next() {
- diskKeys = append(diskKeys, it.Key())
+ ok, path, hash := isTrieNode(tdb.Scheme(), it.Key(), it.Value())
+ if !ok {
+ continue
+ }
+ paths = append(paths, path)
+ hashes = append(hashes, hash)
}
it.Release()
}
for i := 0; i < 20; i++ {
// Create trie that will load all nodes from DB.
- tr, _ := New(TrieID(tr.Hash()), triedb, StateTrieCodec)
+ tr, _ := New(TrieID(tr.Hash()), tdb)
// Remove a random node from the database. It can't be the root node
// because that one is already loaded.
var (
- rkey common.Hash
- rval []byte
- robj *cachedNode
+ rval []byte
+ rpath []byte
+ rhash common.Hash
)
for {
if memonly {
- rkey = memKeys[rand.Intn(len(memKeys))]
+ rpath = paths[rand.Intn(len(paths))]
+ n := nodes.Nodes[string(rpath)]
+ if n == nil {
+ continue
+ }
+ rhash = n.Hash
} else {
- copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))])
+ index := rand.Intn(len(paths))
+ rpath = paths[index]
+ rhash = hashes[index]
}
- if rkey != tr.Hash() {
+ if rhash != tr.Hash() {
break
}
}
if memonly {
- robj = triedb.dirties[rkey]
- delete(triedb.dirties, rkey)
+ tr.reader.banned = map[string]struct{}{string(rpath): {}}
} else {
- rval, _ = diskdb.Get(rkey[:])
- diskdb.Delete(rkey[:])
+ rval = rawdb.ReadTrieNode(diskdb, common.Hash{}, rpath, rhash, tdb.Scheme())
+ rawdb.DeleteTrieNode(diskdb, common.Hash{}, rpath, rhash, tdb.Scheme())
}
// Iterate until the error is hit.
seen := make(map[string]bool)
- it := tr.NodeIterator(nil)
+ it := tr.MustNodeIterator(nil)
checkIteratorNoDups(t, it, seen)
missing, ok := it.Error().(*MissingNodeError)
- if !ok || missing.NodeHash != rkey {
+ if !ok || missing.NodeHash != rhash {
t.Fatal("didn't hit missing node, got", it.Error())
}
// Add the node back and continue iteration.
if memonly {
- triedb.dirties[rkey] = robj
+ delete(tr.reader.banned, string(rpath))
} else {
- diskdb.Put(rkey[:], rval)
+ rawdb.WriteTrieNode(diskdb, common.Hash{}, rpath, rhash, rval, tdb.Scheme())
}
checkIteratorNoDups(t, it, seen)
if it.Error() != nil {
@@ -417,40 +518,49 @@ func testIteratorContinueAfterError(t *testing.T, memonly bool) {
// Similar to the test above, this one checks that failure to create nodeIterator at a
// certain key prefix behaves correctly when Next is called. The expectation is that Next
// should retry seeking before returning true for the first time.
-func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) {
- testIteratorContinueAfterSeekError(t, true)
+func TestIteratorContinueAfterSeekError(t *testing.T) {
+ testIteratorContinueAfterSeekError(t, false, rawdb.HashScheme)
+ testIteratorContinueAfterSeekError(t, true, rawdb.HashScheme)
+ testIteratorContinueAfterSeekError(t, false, rawdb.PathScheme)
+ testIteratorContinueAfterSeekError(t, true, rawdb.PathScheme)
}
-func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
+func testIteratorContinueAfterSeekError(t *testing.T, memonly bool, scheme string) {
// Commit test trie to db, then remove the node containing "bars".
+ var (
+ barNodePath []byte
+ barNodeHash = common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
+ )
diskdb := rawdb.NewMemoryDatabase()
- triedb := NewDatabase(diskdb)
-
+ triedb := newTestDatabase(diskdb, scheme)
ctr := NewEmpty(triedb)
for _, val := range testdata1 {
- ctr.Update([]byte(val.k), []byte(val.v))
+ ctr.MustUpdate([]byte(val.k), []byte(val.v))
+ }
+ root, nodes, _ := ctr.Commit(false)
+ for path, n := range nodes.Nodes {
+ if n.Hash == barNodeHash {
+ barNodePath = []byte(path)
+ break
+ }
+ }
+ triedb.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
+ if !memonly {
+ triedb.Commit(root)
}
- root, nodes := ctr.Commit(false)
- triedb.Update(NewWithNodeSet(nodes))
- // if !memonly {
- // triedb.Commit(root, false)
- // }
- barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
var (
barNodeBlob []byte
- barNodeObj *cachedNode
)
+ tr, _ := New(TrieID(root), triedb)
if memonly {
- barNodeObj = triedb.dirties[barNodeHash]
- delete(triedb.dirties, barNodeHash)
+ tr.reader.banned = map[string]struct{}{string(barNodePath): {}}
} else {
- barNodeBlob, _ = diskdb.Get(barNodeHash[:])
- diskdb.Delete(barNodeHash[:])
+ barNodeBlob = rawdb.ReadTrieNode(diskdb, common.Hash{}, barNodePath, barNodeHash, triedb.Scheme())
+ rawdb.DeleteTrieNode(diskdb, common.Hash{}, barNodePath, barNodeHash, triedb.Scheme())
}
// Create a new iterator that seeks to "bars". Seeking can't proceed because
// the node is missing.
- tr, _ := New(TrieID(root), triedb, StateTrieCodec)
- it := tr.NodeIterator([]byte("bars"))
+ it := tr.MustNodeIterator([]byte("bars"))
missing, ok := it.Error().(*MissingNodeError)
if !ok {
t.Fatal("want MissingNodeError, got", it.Error())
@@ -459,9 +569,9 @@ func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
}
// Reinsert the missing node.
if memonly {
- triedb.dirties[barNodeHash] = barNodeObj
+ delete(tr.reader.banned, string(barNodePath))
} else {
- diskdb.Put(barNodeHash[:], barNodeBlob)
+ rawdb.WriteTrieNode(diskdb, common.Hash{}, barNodePath, barNodeHash, barNodeBlob, triedb.Scheme())
}
// Check that iteration produces the right set of values.
if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
@@ -482,96 +592,38 @@ func checkIteratorNoDups(t *testing.T, it NodeIterator, seen map[string]bool) in
return len(seen)
}
-type loggingTrieDb struct {
- *Database
- getCount uint64
-}
-
-// GetReader retrieves a node reader belonging to the given state root.
-func (db *loggingTrieDb) GetReader(root common.Hash, codec uint64) Reader {
- return &loggingNodeReader{db, codec}
-}
-
-// loggingNodeReader is a trie node reader that logs the number of get operations.
-type loggingNodeReader struct {
- db *loggingTrieDb
- codec uint64
-}
-
-// Node retrieves the trie node with the given node hash.
-func (reader *loggingNodeReader) Node(owner common.Hash, path []byte, hash common.Hash) (node, error) {
- blob, err := reader.NodeBlob(owner, path, hash)
- if err != nil {
- return nil, err
- }
- return decodeNodeUnsafe(hash[:], blob)
-}
-
-// NodeBlob retrieves the RLP-encoded trie node blob with the given node hash.
-func (reader *loggingNodeReader) NodeBlob(_ common.Hash, _ []byte, hash common.Hash) ([]byte, error) {
- reader.db.getCount++
- return reader.db.Node(hash, reader.codec)
-}
-
-func newLoggingStateTrie(id *ID, db *Database, codec uint64) (*StateTrie, *loggingTrieDb, error) {
- logdb := &loggingTrieDb{Database: db}
- trie, err := New(id, logdb, codec)
- if err != nil {
- return nil, nil, err
- }
- return &StateTrie{trie: *trie, preimages: db.preimages}, logdb, nil
-}
-
-// makeLargeTestTrie create a sample test trie
-func makeLargeTestTrie(t testing.TB) (*Database, *StateTrie, *loggingTrieDb) {
- // Create an empty trie
- triedb := NewDatabase(rawdb.NewDatabase(memorydb.New()))
- trie, logDb, err := newLoggingStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec)
- if err != nil {
- t.Fatal(err)
- }
-
- // Fill it with some arbitrary data
- for i := 0; i < 10000; i++ {
- key := make([]byte, 32)
- val := make([]byte, 32)
- binary.BigEndian.PutUint64(key, uint64(i))
- binary.BigEndian.PutUint64(val, uint64(i))
- key = crypto.Keccak256(key)
- val = crypto.Keccak256(val)
- trie.Update(key, val)
- }
- _, nodes := trie.Commit(false)
- triedb.Update(NewWithNodeSet(nodes))
- // Return the generated trie
- return triedb, trie, logDb
-}
-
-// Tests that the node iterator indeed walks over the entire database contents.
-func TestNodeIteratorLargeTrie(t *testing.T) {
- // Create some arbitrary test trie to iterate
- _, trie, logDb := makeLargeTestTrie(t)
- // Do a seek operation
- trie.NodeIterator(common.FromHex("0x77667766776677766778855885885885"))
- // master: 24 get operations
- // this pr: 5 get operations
- if have, want := logDb.getCount, uint64(5); have != want {
- t.Fatalf("Wrong number of lookups during seek, have %d want %d", have, want)
- }
-}
-
func TestIteratorNodeBlob(t *testing.T) {
- edb := rawdb.NewMemoryDatabase()
- db := geth_trie.NewDatabase(edb)
- orig := geth_trie.NewEmpty(geth_trie.NewDatabase(rawdb.NewMemoryDatabase()))
- if _, err := updateTrie(orig, packableTestData); err != nil {
- t.Fatal(err)
- }
- root := commitTrie(t, db, orig)
- trie := indexTrie(t, edb, root)
+ testIteratorNodeBlob(t, rawdb.HashScheme)
+ testIteratorNodeBlob(t, rawdb.PathScheme)
+}
- found := make(map[common.Hash][]byte)
- it := trie.NodeIterator(nil)
+func testIteratorNodeBlob(t *testing.T, scheme string) {
+ var (
+ db = rawdb.NewMemoryDatabase()
+ triedb = newTestDatabase(db, scheme)
+ trie = NewEmpty(triedb)
+ )
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"dog", "puppy"},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
+ }
+ all := make(map[string]string)
+ for _, val := range vals {
+ all[val.k] = val.v
+ trie.MustUpdate([]byte(val.k), []byte(val.v))
+ }
+ root, nodes, _ := trie.Commit(false)
+ triedb.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
+ triedb.Commit(root)
+
+ var found = make(map[common.Hash][]byte)
+ trie, _ = New(TrieID(root), triedb)
+ it := trie.MustNodeIterator(nil)
for it.Next(true) {
if it.Hash() == (common.Hash{}) {
continue
@@ -579,14 +631,18 @@ func TestIteratorNodeBlob(t *testing.T) {
found[it.Hash()] = it.NodeBlob()
}
- dbIter := edb.NewIterator(nil, nil)
+ dbIter := db.NewIterator(nil, nil)
defer dbIter.Release()
var count int
for dbIter.Next() {
- got, present := found[common.BytesToHash(dbIter.Key())]
+ ok, _, _ := isTrieNode(triedb.Scheme(), dbIter.Key(), dbIter.Value())
+ if !ok {
+ continue
+ }
+ got, present := found[crypto.Keccak256Hash(dbIter.Value())]
if !present {
- t.Fatalf("Miss trie node %v", dbIter.Key())
+ t.Fatal("Miss trie node")
}
if !bytes.Equal(got, dbIter.Value()) {
t.Fatalf("Unexpected trie node want %v got %v", dbIter.Value(), got)
@@ -594,6 +650,44 @@ func TestIteratorNodeBlob(t *testing.T) {
count += 1
}
if count != len(found) {
- t.Fatalf("Wrong number of trie nodes found, want %d, got %d", len(found), count)
+ t.Fatal("Find extra trie node via iterator")
+ }
+}
+
+// isTrieNode is a helper function which reports if the provided
+// database entry belongs to a trie node or not. Note in tests
+// only single layer trie is used, namely storage trie is not
+// considered at all.
+func isTrieNode(scheme string, key, val []byte) (bool, []byte, common.Hash) {
+ var (
+ path []byte
+ hash common.Hash
+ )
+ if scheme == rawdb.HashScheme {
+ ok := rawdb.IsLegacyTrieNode(key, val)
+ if !ok {
+ return false, nil, common.Hash{}
+ }
+ hash = common.BytesToHash(key)
+ } else {
+ ok, remain := rawdb.ResolveAccountTrieNodeKey(key)
+ if !ok {
+ return false, nil, common.Hash{}
+ }
+ path = common.CopyBytes(remain)
+ hash = crypto.Keccak256Hash(val)
+ }
+ return true, path, hash
+}
+
+func BenchmarkIterator(b *testing.B) {
+ diskDb, srcDb, tr, _ := makeTestTrie(rawdb.HashScheme)
+ root := tr.Hash()
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ if err := checkTrieConsistency(diskDb, srcDb.Scheme(), root, false); err != nil {
+ b.Fatal(err)
+ }
}
}
diff --git a/trie_by_cid/trie/node.go b/trie_by_cid/trie/node.go
index 6ce6551..15bbf62 100644
--- a/trie_by_cid/trie/node.go
+++ b/trie_by_cid/trie/node.go
@@ -99,6 +99,19 @@ func (n valueNode) fstring(ind string) string {
return fmt.Sprintf("%x ", []byte(n))
}
+// rawNode is a simple binary blob used to differentiate between collapsed trie
+// nodes and already encoded RLP binary blobs (while at the same time store them
+// in the same cache fields).
+type rawNode []byte
+
+func (n rawNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
+func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") }
+
+func (n rawNode) EncodeRLP(w io.Writer) error {
+ _, err := w.Write(n)
+ return err
+}
+
// mustDecodeNode is a wrapper of decodeNode and panic if any error is encountered.
func mustDecodeNode(hash, buf []byte) node {
n, err := decodeNode(hash, buf)
diff --git a/trie_by_cid/trie/node_enc.go b/trie_by_cid/trie/node_enc.go
index cade35b..1b2eca6 100644
--- a/trie_by_cid/trie/node_enc.go
+++ b/trie_by_cid/trie/node_enc.go
@@ -59,29 +59,6 @@ func (n valueNode) encode(w rlp.EncoderBuffer) {
w.WriteBytes(n)
}
-func (n rawFullNode) encode(w rlp.EncoderBuffer) {
- offset := w.List()
- for _, c := range n {
- if c != nil {
- c.encode(w)
- } else {
- w.Write(rlp.EmptyString)
- }
- }
- w.ListEnd(offset)
-}
-
-func (n *rawShortNode) encode(w rlp.EncoderBuffer) {
- offset := w.List()
- w.WriteBytes(n.Key)
- if n.Val != nil {
- n.Val.encode(w)
- } else {
- w.Write(rlp.EmptyString)
- }
- w.ListEnd(offset)
-}
-
func (n rawNode) encode(w rlp.EncoderBuffer) {
w.Write(n)
}
diff --git a/trie_by_cid/trie/node_test.go b/trie_by_cid/trie/node_test.go
index 9b8b337..3552957 100644
--- a/trie_by_cid/trie/node_test.go
+++ b/trie_by_cid/trie/node_test.go
@@ -96,7 +96,7 @@ func TestDecodeFullNode(t *testing.T) {
// goos: darwin
// goarch: arm64
-// pkg: github.com/ethereum/go-ethereum/trie
+// pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
// BenchmarkEncodeShortNode
// BenchmarkEncodeShortNode-8 16878850 70.81 ns/op 48 B/op 1 allocs/op
func BenchmarkEncodeShortNode(b *testing.B) {
@@ -114,7 +114,7 @@ func BenchmarkEncodeShortNode(b *testing.B) {
// goos: darwin
// goarch: arm64
-// pkg: github.com/ethereum/go-ethereum/trie
+// pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
// BenchmarkEncodeFullNode
// BenchmarkEncodeFullNode-8 4323273 284.4 ns/op 576 B/op 1 allocs/op
func BenchmarkEncodeFullNode(b *testing.B) {
@@ -132,7 +132,7 @@ func BenchmarkEncodeFullNode(b *testing.B) {
// goos: darwin
// goarch: arm64
-// pkg: github.com/ethereum/go-ethereum/trie
+// pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
// BenchmarkDecodeShortNode
// BenchmarkDecodeShortNode-8 7925638 151.0 ns/op 157 B/op 4 allocs/op
func BenchmarkDecodeShortNode(b *testing.B) {
@@ -153,7 +153,7 @@ func BenchmarkDecodeShortNode(b *testing.B) {
// goos: darwin
// goarch: arm64
-// pkg: github.com/ethereum/go-ethereum/trie
+// pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
// BenchmarkDecodeShortNodeUnsafe
// BenchmarkDecodeShortNodeUnsafe-8 9027476 128.6 ns/op 109 B/op 3 allocs/op
func BenchmarkDecodeShortNodeUnsafe(b *testing.B) {
@@ -174,7 +174,7 @@ func BenchmarkDecodeShortNodeUnsafe(b *testing.B) {
// goos: darwin
// goarch: arm64
-// pkg: github.com/ethereum/go-ethereum/trie
+// pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
// BenchmarkDecodeFullNode
// BenchmarkDecodeFullNode-8 1597462 761.9 ns/op 1280 B/op 18 allocs/op
func BenchmarkDecodeFullNode(b *testing.B) {
@@ -195,7 +195,7 @@ func BenchmarkDecodeFullNode(b *testing.B) {
// goos: darwin
// goarch: arm64
-// pkg: github.com/ethereum/go-ethereum/trie
+// pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
// BenchmarkDecodeFullNodeUnsafe
// BenchmarkDecodeFullNodeUnsafe-8 1789070 687.1 ns/op 704 B/op 17 allocs/op
func BenchmarkDecodeFullNodeUnsafe(b *testing.B) {
diff --git a/trie_by_cid/trie/nodeset.go b/trie_by_cid/trie/nodeset.go
deleted file mode 100644
index 99e4a80..0000000
--- a/trie_by_cid/trie/nodeset.go
+++ /dev/null
@@ -1,218 +0,0 @@
-// Copyright 2022 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package trie
-
-import (
- "fmt"
- "reflect"
- "sort"
- "strings"
-
- "github.com/ethereum/go-ethereum/common"
-)
-
-// memoryNode is all the information we know about a single cached trie node
-// in the memory.
-type memoryNode struct {
- hash common.Hash // Node hash, computed by hashing rlp value, empty for deleted nodes
- size uint16 // Byte size of the useful cached data, 0 for deleted nodes
- node node // Cached collapsed trie node, or raw rlp data, nil for deleted nodes
-}
-
-// memoryNodeSize is the raw size of a memoryNode data structure without any
-// node data included. It's an approximate size, but should be a lot better
-// than not counting them.
-// nolint:unused
-var memoryNodeSize = int(reflect.TypeOf(memoryNode{}).Size())
-
-// memorySize returns the total memory size used by this node.
-// nolint:unused
-func (n *memoryNode) memorySize(pathlen int) int {
- return int(n.size) + memoryNodeSize + pathlen
-}
-
-// rlp returns the raw rlp encoded blob of the cached trie node, either directly
-// from the cache, or by regenerating it from the collapsed node.
-// nolint:unused
-func (n *memoryNode) rlp() []byte {
- if node, ok := n.node.(rawNode); ok {
- return node
- }
- return nodeToBytes(n.node)
-}
-
-// obj returns the decoded and expanded trie node, either directly from the cache,
-// or by regenerating it from the rlp encoded blob.
-// nolint:unused
-func (n *memoryNode) obj() node {
- if node, ok := n.node.(rawNode); ok {
- return mustDecodeNode(n.hash[:], node)
- }
- return expandNode(n.hash[:], n.node)
-}
-
-// isDeleted returns the indicator if the node is marked as deleted.
-func (n *memoryNode) isDeleted() bool {
- return n.hash == (common.Hash{})
-}
-
-// nodeWithPrev wraps the memoryNode with the previous node value.
-// nolint: unused
-type nodeWithPrev struct {
- *memoryNode
- prev []byte // RLP-encoded previous value, nil means it's non-existent
-}
-
-// unwrap returns the internal memoryNode object.
-// nolint:unused
-func (n *nodeWithPrev) unwrap() *memoryNode {
- return n.memoryNode
-}
-
-// memorySize returns the total memory size used by this node. It overloads
-// the function in memoryNode by counting the size of previous value as well.
-// nolint: unused
-func (n *nodeWithPrev) memorySize(pathlen int) int {
- return n.memoryNode.memorySize(pathlen) + len(n.prev)
-}
-
-// NodeSet contains all dirty nodes collected during the commit operation.
-// Each node is keyed by path. It's not thread-safe to use.
-type NodeSet struct {
- owner common.Hash // the identifier of the trie
- nodes map[string]*memoryNode // the set of dirty nodes(inserted, updated, deleted)
- leaves []*leaf // the list of dirty leaves
- updates int // the count of updated and inserted nodes
- deletes int // the count of deleted nodes
-
- // The list of accessed nodes, which records the original node value.
- // The origin value is expected to be nil for newly inserted node
- // and is expected to be non-nil for other types(updated, deleted).
- accessList map[string][]byte
-}
-
-// NewNodeSet initializes an empty node set to be used for tracking dirty nodes
-// from a specific account or storage trie. The owner is zero for the account
-// trie and the owning account address hash for storage tries.
-func NewNodeSet(owner common.Hash, accessList map[string][]byte) *NodeSet {
- return &NodeSet{
- owner: owner,
- nodes: make(map[string]*memoryNode),
- accessList: accessList,
- }
-}
-
-// forEachWithOrder iterates the dirty nodes with the order from bottom to top,
-// right to left, nodes with the longest path will be iterated first.
-func (set *NodeSet) forEachWithOrder(callback func(path string, n *memoryNode)) {
- var paths sort.StringSlice
- for path := range set.nodes {
- paths = append(paths, path)
- }
- // Bottom-up, longest path first
- sort.Sort(sort.Reverse(paths))
- for _, path := range paths {
- callback(path, set.nodes[path])
- }
-}
-
-// markUpdated marks the node as dirty(newly-inserted or updated).
-func (set *NodeSet) markUpdated(path []byte, node *memoryNode) {
- set.nodes[string(path)] = node
- set.updates += 1
-}
-
-// markDeleted marks the node as deleted.
-func (set *NodeSet) markDeleted(path []byte) {
- set.nodes[string(path)] = &memoryNode{}
- set.deletes += 1
-}
-
-// addLeaf collects the provided leaf node into set.
-func (set *NodeSet) addLeaf(node *leaf) {
- set.leaves = append(set.leaves, node)
-}
-
-// Size returns the number of dirty nodes in set.
-func (set *NodeSet) Size() (int, int) {
- return set.updates, set.deletes
-}
-
-// Hashes returns the hashes of all updated nodes. TODO(rjl493456442) how can
-// we get rid of it?
-func (set *NodeSet) Hashes() []common.Hash {
- var ret []common.Hash
- for _, node := range set.nodes {
- ret = append(ret, node.hash)
- }
- return ret
-}
-
-// Summary returns a string-representation of the NodeSet.
-func (set *NodeSet) Summary() string {
- var out = new(strings.Builder)
- fmt.Fprintf(out, "nodeset owner: %v\n", set.owner)
- if set.nodes != nil {
- for path, n := range set.nodes {
- // Deletion
- if n.isDeleted() {
- fmt.Fprintf(out, " [-]: %x prev: %x\n", path, set.accessList[path])
- continue
- }
- // Insertion
- origin, ok := set.accessList[path]
- if !ok {
- fmt.Fprintf(out, " [+]: %x -> %v\n", path, n.hash)
- continue
- }
- // Update
- fmt.Fprintf(out, " [*]: %x -> %v prev: %x\n", path, n.hash, origin)
- }
- }
- for _, n := range set.leaves {
- fmt.Fprintf(out, "[leaf]: %v\n", n)
- }
- return out.String()
-}
-
-// MergedNodeSet represents a merged dirty node set for a group of tries.
-type MergedNodeSet struct {
- sets map[common.Hash]*NodeSet
-}
-
-// NewMergedNodeSet initializes an empty merged set.
-func NewMergedNodeSet() *MergedNodeSet {
- return &MergedNodeSet{sets: make(map[common.Hash]*NodeSet)}
-}
-
-// NewWithNodeSet constructs a merged nodeset with the provided single set.
-func NewWithNodeSet(set *NodeSet) *MergedNodeSet {
- merged := NewMergedNodeSet()
- merged.Merge(set)
- return merged
-}
-
-// Merge merges the provided dirty nodes of a trie into the set. The assumption
-// is held that no duplicated set belonging to the same trie will be merged twice.
-func (set *MergedNodeSet) Merge(other *NodeSet) error {
- _, present := set.sets[other.owner]
- if present {
- return fmt.Errorf("duplicate trie for owner %#x", other.owner)
- }
- set.sets[other.owner] = other
- return nil
-}
diff --git a/trie_by_cid/trie/proof.go b/trie_by_cid/trie/proof.go
index 7315c0d..fd892fb 100644
--- a/trie_by_cid/trie/proof.go
+++ b/trie_by_cid/trie/proof.go
@@ -18,17 +18,14 @@ package trie
import (
"bytes"
+ "errors"
"fmt"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb"
- "github.com/ethereum/go-ethereum/trie"
- log "github.com/sirupsen/logrus"
+ "github.com/ethereum/go-ethereum/log"
)
-var VerifyProof = trie.VerifyProof
-var VerifyRangeProof = trie.VerifyRangeProof
-
// Prove constructs a merkle proof for key. The result contains all encoded nodes
// on the path to the value at key. The value itself is also included in the last
// node and can be retrieved by verifying the proof.
@@ -36,7 +33,11 @@ var VerifyRangeProof = trie.VerifyRangeProof
// If the trie does not contain a value for key, the returned proof contains all
// nodes of the longest existing prefix of the key (at least the root node), ending
// with the node that proves the absence of the key.
-func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error {
+func (t *Trie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error {
+ // Short circuit if the trie is already committed and not usable.
+ if t.committed {
+ return ErrCommitted
+ }
// Collect all nodes on the path to key.
var (
prefix []byte
@@ -67,12 +68,15 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e
// loaded blob will be tracked, while it's not required here since
// all loaded nodes won't be linked to trie at all and track nodes
// may lead to out-of-memory issue.
- var err error
- tn, err = t.reader.node(prefix, common.BytesToHash(n))
+ blob, err := t.reader.node(prefix, common.BytesToHash(n))
if err != nil {
log.Error("Unhandled trie error in Trie.Prove", "err", err)
return err
}
+ // The raw-blob format nodes are loaded either from the
+ // clean cache or the database, they are all in their own
+ // copy and safe to use unsafe decoder.
+ tn = mustDecodeNodeUnsafe(n, blob)
default:
panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
}
@@ -81,10 +85,6 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e
defer returnHasherToPool(hasher)
for i, n := range nodes {
- if fromLevel > 0 {
- fromLevel--
- continue
- }
var hn node
n, hn = hasher.proofHash(n)
if hash, ok := hn.(hashNode); ok || i == 0 {
@@ -107,6 +107,510 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e
// If the trie does not contain a value for key, the returned proof contains all
// nodes of the longest existing prefix of the key (at least the root node), ending
// with the node that proves the absence of the key.
-func (t *StateTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error {
- return t.trie.Prove(key, fromLevel, proofDb)
+func (t *StateTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error {
+ return t.trie.Prove(key, proofDb)
+}
+
+// VerifyProof checks merkle proofs. The given proof must contain the value for
+// key in a trie with the given root hash. VerifyProof returns an error if the
+// proof contains invalid trie nodes or the wrong value.
+func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) (value []byte, err error) {
+ key = keybytesToHex(key)
+ wantHash := rootHash
+ for i := 0; ; i++ {
+ buf, _ := proofDb.Get(wantHash[:])
+ if buf == nil {
+ return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash)
+ }
+ n, err := decodeNode(wantHash[:], buf)
+ if err != nil {
+ return nil, fmt.Errorf("bad proof node %d: %v", i, err)
+ }
+ keyrest, cld := get(n, key, true)
+ switch cld := cld.(type) {
+ case nil:
+ // The trie doesn't contain the key.
+ return nil, nil
+ case hashNode:
+ key = keyrest
+ copy(wantHash[:], cld)
+ case valueNode:
+ return cld, nil
+ }
+ }
+}
+
+// proofToPath converts a merkle proof to trie node path. The main purpose of
+// this function is recovering a node path from the merkle proof stream. All
+// necessary nodes will be resolved and leave the remaining as hashnode.
+//
+// The given edge proof is allowed to be an existent or non-existent proof.
+func proofToPath(rootHash common.Hash, root node, key []byte, proofDb ethdb.KeyValueReader, allowNonExistent bool) (node, []byte, error) {
+ // resolveNode retrieves and resolves trie node from merkle proof stream
+ resolveNode := func(hash common.Hash) (node, error) {
+ buf, _ := proofDb.Get(hash[:])
+ if buf == nil {
+ return nil, fmt.Errorf("proof node (hash %064x) missing", hash)
+ }
+ n, err := decodeNode(hash[:], buf)
+ if err != nil {
+ return nil, fmt.Errorf("bad proof node %v", err)
+ }
+ return n, err
+ }
+ // If the root node is empty, resolve it first.
+ // Root node must be included in the proof.
+ if root == nil {
+ n, err := resolveNode(rootHash)
+ if err != nil {
+ return nil, nil, err
+ }
+ root = n
+ }
+ var (
+ err error
+ child, parent node
+ keyrest []byte
+ valnode []byte
+ )
+ key, parent = keybytesToHex(key), root
+ for {
+ keyrest, child = get(parent, key, false)
+ switch cld := child.(type) {
+ case nil:
+ // The trie doesn't contain the key. It's possible
+ // the proof is a non-existing proof, but at least
+ // we can prove all resolved nodes are correct, it's
+ // enough for us to prove range.
+ if allowNonExistent {
+ return root, nil, nil
+ }
+ return nil, nil, errors.New("the node is not contained in trie")
+ case *shortNode:
+ key, parent = keyrest, child // Already resolved
+ continue
+ case *fullNode:
+ key, parent = keyrest, child // Already resolved
+ continue
+ case hashNode:
+ child, err = resolveNode(common.BytesToHash(cld))
+ if err != nil {
+ return nil, nil, err
+ }
+ case valueNode:
+ valnode = cld
+ }
+ // Link the parent and child.
+ switch pnode := parent.(type) {
+ case *shortNode:
+ pnode.Val = child
+ case *fullNode:
+ pnode.Children[key[0]] = child
+ default:
+ panic(fmt.Sprintf("%T: invalid node: %v", pnode, pnode))
+ }
+ if len(valnode) > 0 {
+ return root, valnode, nil // The whole path is resolved
+ }
+ key, parent = keyrest, child
+ }
+}
+
+// unsetInternal removes all internal node references(hashnode, embedded node).
+// It should be called after a trie is constructed with two edge paths. Also
+// the given boundary keys must be the one used to construct the edge paths.
+//
+// It's the key step for range proof. All visited nodes should be marked dirty
+// since the node content might be modified. Besides it can happen that some
+// fullnodes only have one child which is disallowed. But if the proof is valid,
+// the missing children will be filled, otherwise it will be thrown anyway.
+//
+// Note we have the assumption here the given boundary keys are different
+// and right is larger than left.
+func unsetInternal(n node, left []byte, right []byte) (bool, error) {
+ left, right = keybytesToHex(left), keybytesToHex(right)
+
+ // Step down to the fork point. There are two scenarios can happen:
+ // - the fork point is a shortnode: either the key of left proof or
+ // right proof doesn't match with shortnode's key.
+ // - the fork point is a fullnode: both two edge proofs are allowed
+ // to point to a non-existent key.
+ var (
+ pos = 0
+ parent node
+
+ // fork indicator, 0 means no fork, -1 means proof is less, 1 means proof is greater
+ shortForkLeft, shortForkRight int
+ )
+findFork:
+ for {
+ switch rn := (n).(type) {
+ case *shortNode:
+ rn.flags = nodeFlag{dirty: true}
+
+ // If either the key of left proof or right proof doesn't match with
+ // shortnode, stop here and the forkpoint is the shortnode.
+ if len(left)-pos < len(rn.Key) {
+ shortForkLeft = bytes.Compare(left[pos:], rn.Key)
+ } else {
+ shortForkLeft = bytes.Compare(left[pos:pos+len(rn.Key)], rn.Key)
+ }
+ if len(right)-pos < len(rn.Key) {
+ shortForkRight = bytes.Compare(right[pos:], rn.Key)
+ } else {
+ shortForkRight = bytes.Compare(right[pos:pos+len(rn.Key)], rn.Key)
+ }
+ if shortForkLeft != 0 || shortForkRight != 0 {
+ break findFork
+ }
+ parent = n
+ n, pos = rn.Val, pos+len(rn.Key)
+ case *fullNode:
+ rn.flags = nodeFlag{dirty: true}
+
+ // If either the node pointed by left proof or right proof is nil,
+ // stop here and the forkpoint is the fullnode.
+ leftnode, rightnode := rn.Children[left[pos]], rn.Children[right[pos]]
+ if leftnode == nil || rightnode == nil || leftnode != rightnode {
+ break findFork
+ }
+ parent = n
+ n, pos = rn.Children[left[pos]], pos+1
+ default:
+ panic(fmt.Sprintf("%T: invalid node: %v", n, n))
+ }
+ }
+ switch rn := n.(type) {
+ case *shortNode:
+ // There can have these five scenarios:
+ // - both proofs are less than the trie path => no valid range
+ // - both proofs are greater than the trie path => no valid range
+ // - left proof is less and right proof is greater => valid range, unset the shortnode entirely
+ // - left proof points to the shortnode, but right proof is greater
+ // - right proof points to the shortnode, but left proof is less
+ if shortForkLeft == -1 && shortForkRight == -1 {
+ return false, errors.New("empty range")
+ }
+ if shortForkLeft == 1 && shortForkRight == 1 {
+ return false, errors.New("empty range")
+ }
+ if shortForkLeft != 0 && shortForkRight != 0 {
+ // The fork point is root node, unset the entire trie
+ if parent == nil {
+ return true, nil
+ }
+ parent.(*fullNode).Children[left[pos-1]] = nil
+ return false, nil
+ }
+ // Only one proof points to non-existent key.
+ if shortForkRight != 0 {
+ if _, ok := rn.Val.(valueNode); ok {
+ // The fork point is root node, unset the entire trie
+ if parent == nil {
+ return true, nil
+ }
+ parent.(*fullNode).Children[left[pos-1]] = nil
+ return false, nil
+ }
+ return false, unset(rn, rn.Val, left[pos:], len(rn.Key), false)
+ }
+ if shortForkLeft != 0 {
+ if _, ok := rn.Val.(valueNode); ok {
+ // The fork point is root node, unset the entire trie
+ if parent == nil {
+ return true, nil
+ }
+ parent.(*fullNode).Children[right[pos-1]] = nil
+ return false, nil
+ }
+ return false, unset(rn, rn.Val, right[pos:], len(rn.Key), true)
+ }
+ return false, nil
+ case *fullNode:
+ // unset all internal nodes in the forkpoint
+ for i := left[pos] + 1; i < right[pos]; i++ {
+ rn.Children[i] = nil
+ }
+ if err := unset(rn, rn.Children[left[pos]], left[pos:], 1, false); err != nil {
+ return false, err
+ }
+ if err := unset(rn, rn.Children[right[pos]], right[pos:], 1, true); err != nil {
+ return false, err
+ }
+ return false, nil
+ default:
+ panic(fmt.Sprintf("%T: invalid node: %v", n, n))
+ }
+}
+
+// unset removes all internal node references either the left most or right most.
+// It can meet these scenarios:
+//
+// - The given path is existent in the trie, unset the associated nodes with the
+// specific direction
+// - The given path is non-existent in the trie
+// - the fork point is a fullnode, the corresponding child pointed by path
+// is nil, return
+// - the fork point is a shortnode, the shortnode is included in the range,
+// keep the entire branch and return.
+// - the fork point is a shortnode, the shortnode is excluded in the range,
+// unset the entire branch.
+func unset(parent node, child node, key []byte, pos int, removeLeft bool) error {
+ switch cld := child.(type) {
+ case *fullNode:
+ if removeLeft {
+ for i := 0; i < int(key[pos]); i++ {
+ cld.Children[i] = nil
+ }
+ cld.flags = nodeFlag{dirty: true}
+ } else {
+ for i := key[pos] + 1; i < 16; i++ {
+ cld.Children[i] = nil
+ }
+ cld.flags = nodeFlag{dirty: true}
+ }
+ return unset(cld, cld.Children[key[pos]], key, pos+1, removeLeft)
+ case *shortNode:
+ if len(key[pos:]) < len(cld.Key) || !bytes.Equal(cld.Key, key[pos:pos+len(cld.Key)]) {
+ // Find the fork point, it's an non-existent branch.
+ if removeLeft {
+ if bytes.Compare(cld.Key, key[pos:]) < 0 {
+ // The key of fork shortnode is less than the path
+ // (it belongs to the range), unset the entire
+ // branch. The parent must be a fullnode.
+ fn := parent.(*fullNode)
+ fn.Children[key[pos-1]] = nil
+ }
+ //else {
+ // The key of fork shortnode is greater than the
+ // path(it doesn't belong to the range), keep
+ // it with the cached hash available.
+ //}
+ } else {
+ if bytes.Compare(cld.Key, key[pos:]) > 0 {
+ // The key of fork shortnode is greater than the
+ // path(it belongs to the range), unset the entries
+ // branch. The parent must be a fullnode.
+ fn := parent.(*fullNode)
+ fn.Children[key[pos-1]] = nil
+ }
+ //else {
+ // The key of fork shortnode is less than the
+ // path(it doesn't belong to the range), keep
+ // it with the cached hash available.
+ //}
+ }
+ return nil
+ }
+ if _, ok := cld.Val.(valueNode); ok {
+ fn := parent.(*fullNode)
+ fn.Children[key[pos-1]] = nil
+ return nil
+ }
+ cld.flags = nodeFlag{dirty: true}
+ return unset(cld, cld.Val, key, pos+len(cld.Key), removeLeft)
+ case nil:
+ // If the node is nil, then it's a child of the fork point
+ // fullnode(it's a non-existent branch).
+ return nil
+ default:
+ panic("it shouldn't happen") // hashNode, valueNode
+ }
+}
+
+// hasRightElement returns the indicator whether there exists more elements
+// on the right side of the given path. The given path can point to an existent
+// key or a non-existent one. This function has the assumption that the whole
+// path should already be resolved.
+func hasRightElement(node node, key []byte) bool {
+ pos, key := 0, keybytesToHex(key)
+ for node != nil {
+ switch rn := node.(type) {
+ case *fullNode:
+ for i := key[pos] + 1; i < 16; i++ {
+ if rn.Children[i] != nil {
+ return true
+ }
+ }
+ node, pos = rn.Children[key[pos]], pos+1
+ case *shortNode:
+ if len(key)-pos < len(rn.Key) || !bytes.Equal(rn.Key, key[pos:pos+len(rn.Key)]) {
+ return bytes.Compare(rn.Key, key[pos:]) > 0
+ }
+ node, pos = rn.Val, pos+len(rn.Key)
+ case valueNode:
+ return false // We have resolved the whole path
+ default:
+ panic(fmt.Sprintf("%T: invalid node: %v", node, node)) // hashnode
+ }
+ }
+ return false
+}
+
+// VerifyRangeProof checks whether the given leaf nodes and edge proof
+// can prove the given trie leaves range is matched with the specific root.
+// Besides, the range should be consecutive (no gap inside) and monotonic
+// increasing.
+//
+// Note the given proof actually contains two edge proofs. Both of them can
+// be non-existent proofs. For example the first proof is for a non-existent
+// key 0x03, the last proof is for a non-existent key 0x10. The given batch
+// leaves are [0x04, 0x05, .. 0x09]. It's still feasible to prove the given
+// batch is valid.
+//
+// The firstKey is paired with firstProof, not necessarily the same as keys[0]
+// (unless firstProof is an existent proof). Similarly, lastKey and lastProof
+// are paired.
+//
+// Expect the normal case, this function can also be used to verify the following
+// range proofs:
+//
+// - All elements proof. In this case the proof can be nil, but the range should
+// be all the leaves in the trie.
+//
+// - One element proof. In this case no matter the edge proof is a non-existent
+// proof or not, we can always verify the correctness of the proof.
+//
+// - Zero element proof. In this case a single non-existent proof is enough to prove.
+// Besides, if there are still some other leaves available on the right side, then
+// an error will be returned.
+//
+// Except returning the error to indicate the proof is valid or not, the function will
+// also return a flag to indicate whether there exists more accounts/slots in the trie.
+//
+// Note: This method does not verify that the proof is of minimal form. If the input
+// proofs are 'bloated' with neighbour leaves or random data, aside from the 'useful'
+// data, then the proof will still be accepted.
+func VerifyRangeProof(rootHash common.Hash, firstKey []byte, keys [][]byte, values [][]byte, proof ethdb.KeyValueReader) (bool, error) {
+ if len(keys) != len(values) {
+ return false, fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values))
+ }
+ // Ensure the received batch is monotonic increasing and contains no deletions
+ for i := 0; i < len(keys)-1; i++ {
+ if bytes.Compare(keys[i], keys[i+1]) >= 0 {
+ return false, errors.New("range is not monotonically increasing")
+ }
+ }
+ for _, value := range values {
+ if len(value) == 0 {
+ return false, errors.New("range contains deletion")
+ }
+ }
+ // Special case, there is no edge proof at all. The given range is expected
+ // to be the whole leaf-set in the trie.
+ if proof == nil {
+ tr := NewStackTrie(nil)
+ for index, key := range keys {
+ tr.Update(key, values[index])
+ }
+ if have, want := tr.Hash(), rootHash; have != want {
+ return false, fmt.Errorf("invalid proof, want hash %x, got %x", want, have)
+ }
+ return false, nil // No more elements
+ }
+ // Special case, there is a provided edge proof but zero key/value
+ // pairs, ensure there are no more accounts / slots in the trie.
+ if len(keys) == 0 {
+ root, val, err := proofToPath(rootHash, nil, firstKey, proof, true)
+ if err != nil {
+ return false, err
+ }
+ if val != nil || hasRightElement(root, firstKey) {
+ return false, errors.New("more entries available")
+ }
+ return false, nil
+ }
+ var lastKey = keys[len(keys)-1]
+ // Special case, there is only one element and two edge keys are same.
+ // In this case, we can't construct two edge paths. So handle it here.
+ if len(keys) == 1 && bytes.Equal(firstKey, lastKey) {
+ root, val, err := proofToPath(rootHash, nil, firstKey, proof, false)
+ if err != nil {
+ return false, err
+ }
+ if !bytes.Equal(firstKey, keys[0]) {
+ return false, errors.New("correct proof but invalid key")
+ }
+ if !bytes.Equal(val, values[0]) {
+ return false, errors.New("correct proof but invalid data")
+ }
+ return hasRightElement(root, firstKey), nil
+ }
+ // Ok, in all other cases, we require two edge paths available.
+ // First check the validity of edge keys.
+ if bytes.Compare(firstKey, lastKey) >= 0 {
+ return false, errors.New("invalid edge keys")
+ }
+ // todo(rjl493456442) different length edge keys should be supported
+ if len(firstKey) != len(lastKey) {
+ return false, errors.New("inconsistent edge keys")
+ }
+ // Convert the edge proofs to edge trie paths. Then we can
+ // have the same tree architecture with the original one.
+ // For the first edge proof, non-existent proof is allowed.
+ root, _, err := proofToPath(rootHash, nil, firstKey, proof, true)
+ if err != nil {
+ return false, err
+ }
+ // Pass the root node here, the second path will be merged
+ // with the first one. For the last edge proof, non-existent
+ // proof is also allowed.
+ root, _, err = proofToPath(rootHash, root, lastKey, proof, true)
+ if err != nil {
+ return false, err
+ }
+ // Remove all internal references. All the removed parts should
+ // be re-filled(or re-constructed) by the given leaves range.
+ empty, err := unsetInternal(root, firstKey, lastKey)
+ if err != nil {
+ return false, err
+ }
+ // Rebuild the trie with the leaf stream, the shape of trie
+ // should be same with the original one.
+ tr := &Trie{root: root, reader: newEmptyReader(), tracer: newTracer()}
+ if empty {
+ tr.root = nil
+ }
+ for index, key := range keys {
+ tr.Update(key, values[index])
+ }
+ if tr.Hash() != rootHash {
+ return false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash())
+ }
+ return hasRightElement(tr.root, keys[len(keys)-1]), nil
+}
+
+// get returns the child of the given node. Return nil if the
+// node with specified key doesn't exist at all.
+//
+// There is an additional flag `skipResolved`. If it's set then
+// all resolved nodes won't be returned.
+func get(tn node, key []byte, skipResolved bool) ([]byte, node) {
+ for {
+ switch n := tn.(type) {
+ case *shortNode:
+ if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
+ return nil, nil
+ }
+ tn = n.Val
+ key = key[len(n.Key):]
+ if !skipResolved {
+ return key, tn
+ }
+ case *fullNode:
+ tn = n.Children[key[0]]
+ key = key[1:]
+ if !skipResolved {
+ return key, tn
+ }
+ case hashNode:
+ return key, n
+ case nil:
+ return key, nil
+ case valueNode:
+ return nil, n
+ default:
+ panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
+ }
+ }
}
diff --git a/trie_by_cid/trie/proof_test.go b/trie_by_cid/trie/proof_test.go
index 6b23bcd..5471d0e 100644
--- a/trie_by_cid/trie/proof_test.go
+++ b/trie_by_cid/trie/proof_test.go
@@ -22,13 +22,13 @@ import (
"encoding/binary"
"fmt"
mrand "math/rand"
- "sort"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
+ "golang.org/x/exp/slices"
)
// Prng is a pseudo random number generator seeded by strong randomness.
@@ -57,13 +57,13 @@ func makeProvers(trie *Trie) []func(key []byte) *memorydb.Database {
// Create a direct trie based Merkle prover
provers = append(provers, func(key []byte) *memorydb.Database {
proof := memorydb.New()
- trie.Prove(key, 0, proof)
+ trie.Prove(key, proof)
return proof
})
// Create a leaf iterator based Merkle prover
provers = append(provers, func(key []byte) *memorydb.Database {
proof := memorydb.New()
- if it := NewIterator(trie.NodeIterator(key)); it.Next() && bytes.Equal(key, it.Key) {
+ if it := NewIterator(trie.MustNodeIterator(key)); it.Next() && bytes.Equal(key, it.Key) {
for _, p := range it.Prove() {
proof.Put(crypto.Keccak256(p), p)
}
@@ -94,7 +94,7 @@ func TestProof(t *testing.T) {
}
func TestOneElementProof(t *testing.T) {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
updateString(trie, "k", "v")
for i, prover := range makeProvers(trie) {
proof := prover([]byte("k"))
@@ -145,12 +145,12 @@ func TestBadProof(t *testing.T) {
// Tests that missing keys can also be proven. The test explicitly uses a single
// entry trie and checks for missing keys both before and after the single entry.
func TestMissingKeyProof(t *testing.T) {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
updateString(trie, "k", "v")
for i, key := range []string{"a", "j", "l", "z"} {
proof := memorydb.New()
- trie.Prove([]byte(key), 0, proof)
+ trie.Prove([]byte(key), proof)
if proof.Len() != 1 {
t.Errorf("test %d: proof should have one element", i)
@@ -165,30 +165,24 @@ func TestMissingKeyProof(t *testing.T) {
}
}
-type entrySlice []*kv
-
-func (p entrySlice) Len() int { return len(p) }
-func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 }
-func (p entrySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
-
// TestRangeProof tests normal range proof with both edge proofs
// as the existent proof. The test cases are generated randomly.
func TestRangeProof(t *testing.T) {
trie, vals := randomTrie(4096)
- var entries entrySlice
+ var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
- sort.Sort(entries)
+ slices.SortFunc(entries, (*kv).cmp)
for i := 0; i < 500; i++ {
start := mrand.Intn(len(entries))
end := mrand.Intn(len(entries)-start) + start + 1
proof := memorydb.New()
- if err := trie.Prove(entries[start].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[start].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- if err := trie.Prove(entries[end-1].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[end-1].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
var keys [][]byte
@@ -197,7 +191,7 @@ func TestRangeProof(t *testing.T) {
keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v)
}
- _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
+ _, err := VerifyRangeProof(trie.Hash(), keys[0], keys, vals, proof)
if err != nil {
t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
}
@@ -208,11 +202,11 @@ func TestRangeProof(t *testing.T) {
// The test cases are generated randomly.
func TestRangeProofWithNonExistentProof(t *testing.T) {
trie, vals := randomTrie(4096)
- var entries entrySlice
+ var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
- sort.Sort(entries)
+ slices.SortFunc(entries, (*kv).cmp)
for i := 0; i < 500; i++ {
start := mrand.Intn(len(entries))
end := mrand.Intn(len(entries)-start) + start + 1
@@ -227,19 +221,10 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
if bytes.Compare(first, entries[start].k) > 0 {
continue
}
- // Short circuit if the increased key is same with the next key
- last := increaseKey(common.CopyBytes(entries[end-1].k))
- if end != len(entries) && bytes.Equal(last, entries[end].k) {
- continue
- }
- // Short circuit if the increased key is overflow
- if bytes.Compare(last, entries[end-1].k) < 0 {
- continue
- }
- if err := trie.Prove(first, 0, proof); err != nil {
+ if err := trie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- if err := trie.Prove(last, 0, proof); err != nil {
+ if err := trie.Prove(entries[end-1].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
var keys [][]byte
@@ -248,53 +233,32 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v)
}
- _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
+ _, err := VerifyRangeProof(trie.Hash(), first, keys, vals, proof)
if err != nil {
t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
}
}
- // Special case, two edge proofs for two edge key.
- proof := memorydb.New()
- first := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
- last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes()
- if err := trie.Prove(first, 0, proof); err != nil {
- t.Fatalf("Failed to prove the first node %v", err)
- }
- if err := trie.Prove(last, 0, proof); err != nil {
- t.Fatalf("Failed to prove the last node %v", err)
- }
- var k [][]byte
- var v [][]byte
- for i := 0; i < len(entries); i++ {
- k = append(k, entries[i].k)
- v = append(v, entries[i].v)
- }
- _, err := VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
- if err != nil {
- t.Fatal("Failed to verify whole rang with non-existent edges")
- }
}
// TestRangeProofWithInvalidNonExistentProof tests such scenarios:
// - There exists a gap between the first element and the left edge proof
-// - There exists a gap between the last element and the right edge proof
func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
trie, vals := randomTrie(4096)
- var entries entrySlice
+ var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
- sort.Sort(entries)
+ slices.SortFunc(entries, (*kv).cmp)
// Case 1
start, end := 100, 200
first := decreaseKey(common.CopyBytes(entries[start].k))
proof := memorydb.New()
- if err := trie.Prove(first, 0, proof); err != nil {
+ if err := trie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- if err := trie.Prove(entries[end-1].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[end-1].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
start = 105 // Gap created
@@ -304,29 +268,7 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
k = append(k, entries[i].k)
v = append(v, entries[i].v)
}
- _, err := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof)
- if err == nil {
- t.Fatalf("Expected to detect the error, got nil")
- }
-
- // Case 2
- start, end = 100, 200
- last := increaseKey(common.CopyBytes(entries[end-1].k))
- proof = memorydb.New()
- if err := trie.Prove(entries[start].k, 0, proof); err != nil {
- t.Fatalf("Failed to prove the first node %v", err)
- }
- if err := trie.Prove(last, 0, proof); err != nil {
- t.Fatalf("Failed to prove the last node %v", err)
- }
- end = 195 // Capped slice
- k = make([][]byte, 0)
- v = make([][]byte, 0)
- for i := start; i < end; i++ {
- k = append(k, entries[i].k)
- v = append(v, entries[i].v)
- }
- _, err = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof)
+ _, err := VerifyRangeProof(trie.Hash(), first, k, v, proof)
if err == nil {
t.Fatalf("Expected to detect the error, got nil")
}
@@ -337,20 +279,20 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
// non-existent one.
func TestOneElementRangeProof(t *testing.T) {
trie, vals := randomTrie(4096)
- var entries entrySlice
+ var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
- sort.Sort(entries)
+ slices.SortFunc(entries, (*kv).cmp)
// One element with existent edge proof, both edge proofs
// point to the SAME key.
start := 1000
proof := memorydb.New()
- if err := trie.Prove(entries[start].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[start].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- _, err := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+ _, err := VerifyRangeProof(trie.Hash(), entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -359,13 +301,13 @@ func TestOneElementRangeProof(t *testing.T) {
start = 1000
first := decreaseKey(common.CopyBytes(entries[start].k))
proof = memorydb.New()
- if err := trie.Prove(first, 0, proof); err != nil {
+ if err := trie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- if err := trie.Prove(entries[start].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[start].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
- _, err = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+ _, err = VerifyRangeProof(trie.Hash(), first, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -374,13 +316,13 @@ func TestOneElementRangeProof(t *testing.T) {
start = 1000
last := increaseKey(common.CopyBytes(entries[start].k))
proof = memorydb.New()
- if err := trie.Prove(entries[start].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[start].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- if err := trie.Prove(last, 0, proof); err != nil {
+ if err := trie.Prove(last, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
- _, err = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+ _, err = VerifyRangeProof(trie.Hash(), entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -389,32 +331,32 @@ func TestOneElementRangeProof(t *testing.T) {
start = 1000
first, last = decreaseKey(common.CopyBytes(entries[start].k)), increaseKey(common.CopyBytes(entries[start].k))
proof = memorydb.New()
- if err := trie.Prove(first, 0, proof); err != nil {
+ if err := trie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- if err := trie.Prove(last, 0, proof); err != nil {
+ if err := trie.Prove(last, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
- _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+ _, err = VerifyRangeProof(trie.Hash(), first, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
// Test the mini trie with only a single element.
- tinyTrie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ tinyTrie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
entry := &kv{randBytes(32), randBytes(20), false}
- tinyTrie.Update(entry.k, entry.v)
+ tinyTrie.MustUpdate(entry.k, entry.v)
first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
last = entry.k
proof = memorydb.New()
- if err := tinyTrie.Prove(first, 0, proof); err != nil {
+ if err := tinyTrie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- if err := tinyTrie.Prove(last, 0, proof); err != nil {
+ if err := tinyTrie.Prove(last, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
- _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof)
+ _, err = VerifyRangeProof(tinyTrie.Hash(), first, [][]byte{entry.k}, [][]byte{entry.v}, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -424,11 +366,11 @@ func TestOneElementRangeProof(t *testing.T) {
// The edge proofs can be nil.
func TestAllElementsProof(t *testing.T) {
trie, vals := randomTrie(4096)
- var entries entrySlice
+ var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
- sort.Sort(entries)
+ slices.SortFunc(entries, (*kv).cmp)
var k [][]byte
var v [][]byte
@@ -436,20 +378,20 @@ func TestAllElementsProof(t *testing.T) {
k = append(k, entries[i].k)
v = append(v, entries[i].v)
}
- _, err := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil)
+ _, err := VerifyRangeProof(trie.Hash(), nil, k, v, nil)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
// With edge proofs, it should still work.
proof := memorydb.New()
- if err := trie.Prove(entries[0].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[0].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- if err := trie.Prove(entries[len(entries)-1].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[len(entries)-1].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
- _, err = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof)
+ _, err = VerifyRangeProof(trie.Hash(), k[0], k, v, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -457,14 +399,13 @@ func TestAllElementsProof(t *testing.T) {
// Even with non-existent edge proofs, it should still work.
proof = memorydb.New()
first := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
- last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes()
- if err := trie.Prove(first, 0, proof); err != nil {
+ if err := trie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- if err := trie.Prove(last, 0, proof); err != nil {
+ if err := trie.Prove(entries[len(entries)-1].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
- _, err = VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
+ _, err = VerifyRangeProof(trie.Hash(), first, k, v, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -473,22 +414,22 @@ func TestAllElementsProof(t *testing.T) {
// TestSingleSideRangeProof tests the range starts from zero.
func TestSingleSideRangeProof(t *testing.T) {
for i := 0; i < 64; i++ {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
- var entries entrySlice
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+ var entries []*kv
for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false}
- trie.Update(value.k, value.v)
+ trie.MustUpdate(value.k, value.v)
entries = append(entries, value)
}
- sort.Sort(entries)
+ slices.SortFunc(entries, (*kv).cmp)
var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
for _, pos := range cases {
proof := memorydb.New()
- if err := trie.Prove(common.Hash{}.Bytes(), 0, proof); err != nil {
+ if err := trie.Prove(common.Hash{}.Bytes(), proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- if err := trie.Prove(entries[pos].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[pos].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
k := make([][]byte, 0)
@@ -497,43 +438,7 @@ func TestSingleSideRangeProof(t *testing.T) {
k = append(k, entries[i].k)
v = append(v, entries[i].v)
}
- _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof)
- if err != nil {
- t.Fatalf("Expected no error, got %v", err)
- }
- }
- }
-}
-
-// TestReverseSingleSideRangeProof tests the range ends with 0xffff...fff.
-func TestReverseSingleSideRangeProof(t *testing.T) {
- for i := 0; i < 64; i++ {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
- var entries entrySlice
- for i := 0; i < 4096; i++ {
- value := &kv{randBytes(32), randBytes(20), false}
- trie.Update(value.k, value.v)
- entries = append(entries, value)
- }
- sort.Sort(entries)
-
- var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
- for _, pos := range cases {
- proof := memorydb.New()
- if err := trie.Prove(entries[pos].k, 0, proof); err != nil {
- t.Fatalf("Failed to prove the first node %v", err)
- }
- last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
- if err := trie.Prove(last.Bytes(), 0, proof); err != nil {
- t.Fatalf("Failed to prove the last node %v", err)
- }
- k := make([][]byte, 0)
- v := make([][]byte, 0)
- for i := pos; i < len(entries); i++ {
- k = append(k, entries[i].k)
- v = append(v, entries[i].v)
- }
- _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof)
+ _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k, v, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -545,20 +450,20 @@ func TestReverseSingleSideRangeProof(t *testing.T) {
// The prover is expected to detect the error.
func TestBadRangeProof(t *testing.T) {
trie, vals := randomTrie(4096)
- var entries entrySlice
+ var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
- sort.Sort(entries)
+ slices.SortFunc(entries, (*kv).cmp)
for i := 0; i < 500; i++ {
start := mrand.Intn(len(entries))
end := mrand.Intn(len(entries)-start) + start + 1
proof := memorydb.New()
- if err := trie.Prove(entries[start].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[start].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- if err := trie.Prove(entries[end-1].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[end-1].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
var keys [][]byte
@@ -567,7 +472,7 @@ func TestBadRangeProof(t *testing.T) {
keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v)
}
- var first, last = keys[0], keys[len(keys)-1]
+ var first = keys[0]
testcase := mrand.Intn(6)
var index int
switch testcase {
@@ -582,7 +487,7 @@ func TestBadRangeProof(t *testing.T) {
case 2:
// Gapped entry slice
index = mrand.Intn(end - start)
- if (index == 0 && start < 100) || (index == end-start-1 && end <= 100) {
+ if (index == 0 && start < 100) || (index == end-start-1) {
continue
}
keys = append(keys[:index], keys[index+1:]...)
@@ -605,7 +510,7 @@ func TestBadRangeProof(t *testing.T) {
index = mrand.Intn(end - start)
vals[index] = nil
}
- _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
+ _, err := VerifyRangeProof(trie.Hash(), first, keys, vals, proof)
if err == nil {
t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1)
}
@@ -615,19 +520,19 @@ func TestBadRangeProof(t *testing.T) {
// TestGappedRangeProof focuses on the small trie with embedded nodes.
// If the gapped node is embedded in the trie, it should be detected too.
func TestGappedRangeProof(t *testing.T) {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
var entries []*kv // Sorted entries
for i := byte(0); i < 10; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
- trie.Update(value.k, value.v)
+ trie.MustUpdate(value.k, value.v)
entries = append(entries, value)
}
first, last := 2, 8
proof := memorydb.New()
- if err := trie.Prove(entries[first].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[first].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- if err := trie.Prove(entries[last-1].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[last-1].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
var keys [][]byte
@@ -639,7 +544,7 @@ func TestGappedRangeProof(t *testing.T) {
keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v)
}
- _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
+ _, err := VerifyRangeProof(trie.Hash(), keys[0], keys, vals, proof)
if err == nil {
t.Fatal("expect error, got nil")
}
@@ -648,55 +553,53 @@ func TestGappedRangeProof(t *testing.T) {
// TestSameSideProofs tests the element is not in the range covered by proofs
func TestSameSideProofs(t *testing.T) {
trie, vals := randomTrie(4096)
- var entries entrySlice
+ var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
- sort.Sort(entries)
+ slices.SortFunc(entries, (*kv).cmp)
pos := 1000
- first := decreaseKey(common.CopyBytes(entries[pos].k))
- first = decreaseKey(first)
- last := decreaseKey(common.CopyBytes(entries[pos].k))
+ first := common.CopyBytes(entries[0].k)
proof := memorydb.New()
- if err := trie.Prove(first, 0, proof); err != nil {
+ if err := trie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- if err := trie.Prove(last, 0, proof); err != nil {
- t.Fatalf("Failed to prove the last node %v", err)
+ if err := trie.Prove(entries[2000].k, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
}
- _, err := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
+ _, err := VerifyRangeProof(trie.Hash(), first, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
if err == nil {
t.Fatalf("Expected error, got nil")
}
first = increaseKey(common.CopyBytes(entries[pos].k))
- last = increaseKey(common.CopyBytes(entries[pos].k))
+ last := increaseKey(common.CopyBytes(entries[pos].k))
last = increaseKey(last)
proof = memorydb.New()
- if err := trie.Prove(first, 0, proof); err != nil {
+ if err := trie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- if err := trie.Prove(last, 0, proof); err != nil {
+ if err := trie.Prove(last, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
- _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
+ _, err = VerifyRangeProof(trie.Hash(), first, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
if err == nil {
t.Fatalf("Expected error, got nil")
}
}
func TestHasRightElement(t *testing.T) {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
- var entries entrySlice
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+ var entries []*kv
for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false}
- trie.Update(value.k, value.v)
+ trie.MustUpdate(value.k, value.v)
entries = append(entries, value)
}
- sort.Sort(entries)
+ slices.SortFunc(entries, (*kv).cmp)
var cases = []struct {
start int
@@ -709,40 +612,29 @@ func TestHasRightElement(t *testing.T) {
{50, 100, true},
{50, len(entries), false}, // No more element expected
{len(entries) - 1, len(entries), false}, // Single last element with two existent proofs(point to same key)
- {len(entries) - 1, -1, false}, // Single last element with non-existent right proof
{0, len(entries), false}, // The whole set with existent left proof
{-1, len(entries), false}, // The whole set with non-existent left proof
- {-1, -1, false}, // The whole set with non-existent left/right proof
}
for _, c := range cases {
var (
firstKey []byte
- lastKey []byte
start = c.start
end = c.end
proof = memorydb.New()
)
if c.start == -1 {
firstKey, start = common.Hash{}.Bytes(), 0
- if err := trie.Prove(firstKey, 0, proof); err != nil {
+ if err := trie.Prove(firstKey, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
} else {
firstKey = entries[c.start].k
- if err := trie.Prove(entries[c.start].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[c.start].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
}
- if c.end == -1 {
- lastKey, end = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes(), len(entries)
- if err := trie.Prove(lastKey, 0, proof); err != nil {
- t.Fatalf("Failed to prove the first node %v", err)
- }
- } else {
- lastKey = entries[c.end-1].k
- if err := trie.Prove(entries[c.end-1].k, 0, proof); err != nil {
- t.Fatalf("Failed to prove the first node %v", err)
- }
+ if err := trie.Prove(entries[c.end-1].k, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
}
k := make([][]byte, 0)
v := make([][]byte, 0)
@@ -750,7 +642,7 @@ func TestHasRightElement(t *testing.T) {
k = append(k, entries[i].k)
v = append(v, entries[i].v)
}
- hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof)
+ hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, k, v, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
@@ -764,11 +656,11 @@ func TestHasRightElement(t *testing.T) {
// The first edge proof must be a non-existent proof.
func TestEmptyRangeProof(t *testing.T) {
trie, vals := randomTrie(4096)
- var entries entrySlice
+ var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
- sort.Sort(entries)
+ slices.SortFunc(entries, (*kv).cmp)
var cases = []struct {
pos int
@@ -780,10 +672,10 @@ func TestEmptyRangeProof(t *testing.T) {
for _, c := range cases {
proof := memorydb.New()
first := increaseKey(common.CopyBytes(entries[c.pos].k))
- if err := trie.Prove(first, 0, proof); err != nil {
+ if err := trie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- _, err := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof)
+ _, err := VerifyRangeProof(trie.Hash(), first, nil, nil, proof)
if c.err && err == nil {
t.Fatalf("Expected error, got nil")
}
@@ -799,11 +691,11 @@ func TestEmptyRangeProof(t *testing.T) {
func TestBloatedProof(t *testing.T) {
// Use a small trie
trie, kvs := nonRandomTrie(100)
- var entries entrySlice
+ var entries []*kv
for _, kv := range kvs {
entries = append(entries, kv)
}
- sort.Sort(entries)
+ slices.SortFunc(entries, (*kv).cmp)
var keys [][]byte
var vals [][]byte
@@ -811,7 +703,7 @@ func TestBloatedProof(t *testing.T) {
// In the 'malicious' case, we add proofs for every single item
// (but only one key/value pair used as leaf)
for i, entry := range entries {
- trie.Prove(entry.k, 0, proof)
+ trie.Prove(entry.k, proof)
if i == 50 {
keys = append(keys, entry.k)
vals = append(vals, entry.v)
@@ -820,10 +712,10 @@ func TestBloatedProof(t *testing.T) {
// For reference, we use the same function, but _only_ prove the first
// and last element
want := memorydb.New()
- trie.Prove(keys[0], 0, want)
- trie.Prove(keys[len(keys)-1], 0, want)
+ trie.Prove(keys[0], want)
+ trie.Prove(keys[len(keys)-1], want)
- if _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof); err != nil {
+ if _, err := VerifyRangeProof(trie.Hash(), keys[0], keys, vals, proof); err != nil {
t.Fatalf("expected bloated proof to succeed, got %v", err)
}
}
@@ -833,11 +725,11 @@ func TestBloatedProof(t *testing.T) {
// noop technically, but practically should be rejected.
func TestEmptyValueRangeProof(t *testing.T) {
trie, values := randomTrie(512)
- var entries entrySlice
+ var entries []*kv
for _, kv := range values {
entries = append(entries, kv)
}
- sort.Sort(entries)
+ slices.SortFunc(entries, (*kv).cmp)
// Create a new entry with a slightly modified key
mid := len(entries) / 2
@@ -854,10 +746,10 @@ func TestEmptyValueRangeProof(t *testing.T) {
start, end := 1, len(entries)-1
proof := memorydb.New()
- if err := trie.Prove(entries[start].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[start].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
- if err := trie.Prove(entries[end-1].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[end-1].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
var keys [][]byte
@@ -866,7 +758,7 @@ func TestEmptyValueRangeProof(t *testing.T) {
keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v)
}
- _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
+ _, err := VerifyRangeProof(trie.Hash(), keys[0], keys, vals, proof)
if err == nil {
t.Fatalf("Expected failure on noop entry")
}
@@ -877,11 +769,11 @@ func TestEmptyValueRangeProof(t *testing.T) {
// practically should be rejected.
func TestAllElementsEmptyValueRangeProof(t *testing.T) {
trie, values := randomTrie(512)
- var entries entrySlice
+ var entries []*kv
for _, kv := range values {
entries = append(entries, kv)
}
- sort.Sort(entries)
+ slices.SortFunc(entries, (*kv).cmp)
// Create a new entry with a slightly modified key
mid := len(entries) / 2
@@ -901,7 +793,7 @@ func TestAllElementsEmptyValueRangeProof(t *testing.T) {
keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v)
}
- _, err := VerifyRangeProof(trie.Hash(), nil, nil, keys, vals, nil)
+ _, err := VerifyRangeProof(trie.Hash(), nil, keys, vals, nil)
if err == nil {
t.Fatalf("Expected failure on noop entry")
}
@@ -949,7 +841,7 @@ func BenchmarkProve(b *testing.B) {
for i := 0; i < b.N; i++ {
kv := vals[keys[i%len(keys)]]
proofs := memorydb.New()
- if trie.Prove(kv.k, 0, proofs); proofs.Len() == 0 {
+ if trie.Prove(kv.k, proofs); proofs.Len() == 0 {
b.Fatalf("zero length proof for %x", kv.k)
}
}
@@ -963,7 +855,7 @@ func BenchmarkVerifyProof(b *testing.B) {
for k := range vals {
keys = append(keys, k)
proof := memorydb.New()
- trie.Prove([]byte(k), 0, proof)
+ trie.Prove([]byte(k), proof)
proofs = append(proofs, proof)
}
@@ -983,19 +875,19 @@ func BenchmarkVerifyRangeProof5000(b *testing.B) { benchmarkVerifyRangeProof(b,
func benchmarkVerifyRangeProof(b *testing.B, size int) {
trie, vals := randomTrie(8192)
- var entries entrySlice
+ var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
- sort.Sort(entries)
+ slices.SortFunc(entries, (*kv).cmp)
start := 2
end := start + size
proof := memorydb.New()
- if err := trie.Prove(entries[start].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[start].k, proof); err != nil {
b.Fatalf("Failed to prove the first node %v", err)
}
- if err := trie.Prove(entries[end-1].k, 0, proof); err != nil {
+ if err := trie.Prove(entries[end-1].k, proof); err != nil {
b.Fatalf("Failed to prove the last node %v", err)
}
var keys [][]byte
@@ -1007,7 +899,7 @@ func benchmarkVerifyRangeProof(b *testing.B, size int) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
- _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof)
+ _, err := VerifyRangeProof(trie.Hash(), keys[0], keys, values, proof)
if err != nil {
b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
}
@@ -1020,11 +912,11 @@ func BenchmarkVerifyRangeNoProof1000(b *testing.B) { benchmarkVerifyRangeNoProof
func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
trie, vals := randomTrie(size)
- var entries entrySlice
+ var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
- sort.Sort(entries)
+ slices.SortFunc(entries, (*kv).cmp)
var keys [][]byte
var values [][]byte
@@ -1034,7 +926,7 @@ func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
- _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, nil)
+ _, err := VerifyRangeProof(trie.Hash(), keys[0], keys, values, nil)
if err != nil {
b.Fatalf("Expected no error, got %v", err)
}
@@ -1042,26 +934,26 @@ func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
}
func randomTrie(n int) (*Trie, map[string]*kv) {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
vals := make(map[string]*kv)
for i := byte(0); i < 100; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false}
- trie.Update(value.k, value.v)
- trie.Update(value2.k, value2.v)
+ trie.MustUpdate(value.k, value.v)
+ trie.MustUpdate(value2.k, value2.v)
vals[string(value.k)] = value
vals[string(value2.k)] = value2
}
for i := 0; i < n; i++ {
value := &kv{randBytes(32), randBytes(20), false}
- trie.Update(value.k, value.v)
+ trie.MustUpdate(value.k, value.v)
vals[string(value.k)] = value
}
return trie, vals
}
func nonRandomTrie(n int) (*Trie, map[string]*kv) {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
vals := make(map[string]*kv)
max := uint64(0xffffffffffffffff)
for i := uint64(0); i < uint64(n); i++ {
@@ -1071,7 +963,7 @@ func nonRandomTrie(n int) (*Trie, map[string]*kv) {
binary.LittleEndian.PutUint64(value, i-max)
//value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
elem := &kv{key, value, false}
- trie.Update(elem.k, elem.v)
+ trie.MustUpdate(elem.k, elem.v)
vals[string(elem.k)] = elem
}
return trie, vals
@@ -1086,22 +978,21 @@ func TestRangeProofKeysWithSharedPrefix(t *testing.T) {
common.Hex2Bytes("02"),
common.Hex2Bytes("03"),
}
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
for i, key := range keys {
- trie.Update(key, vals[i])
+ trie.MustUpdate(key, vals[i])
}
root := trie.Hash()
proof := memorydb.New()
start := common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000")
- end := common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
- if err := trie.Prove(start, 0, proof); err != nil {
+ if err := trie.Prove(start, proof); err != nil {
t.Fatalf("failed to prove start: %v", err)
}
- if err := trie.Prove(end, 0, proof); err != nil {
+ if err := trie.Prove(keys[len(keys)-1], proof); err != nil {
t.Fatalf("failed to prove end: %v", err)
}
- more, err := VerifyRangeProof(root, start, end, keys, vals, proof)
+ more, err := VerifyRangeProof(root, start, keys, vals, proof)
if err != nil {
t.Fatalf("failed to verify range proof: %v", err)
}
diff --git a/trie_by_cid/trie/secure_trie.go b/trie_by_cid/trie/secure_trie.go
index e748f74..77bbfae 100644
--- a/trie_by_cid/trie/secure_trie.go
+++ b/trie_by_cid/trie/secure_trie.go
@@ -20,9 +20,26 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/rlp"
- log "github.com/sirupsen/logrus"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/triedb/database"
)
+// SecureTrie is the old name of StateTrie.
+// Deprecated: use StateTrie.
+type SecureTrie = StateTrie
+
+// NewSecure creates a new StateTrie.
+// Deprecated: use NewStateTrie.
+func NewSecure(stateRoot common.Hash, owner common.Hash, root common.Hash, db database.Database) (*SecureTrie, error) {
+ id := &ID{
+ StateRoot: stateRoot,
+ Owner: owner,
+ Root: root,
+ }
+ return NewStateTrie(id, db)
+}
+
// StateTrie wraps a trie with key hashing. In a stateTrie trie, all
// access operations hash the key using keccak256. This prevents
// calling code from creating long chains of nodes that
@@ -35,7 +52,7 @@ import (
// StateTrie is not safe for concurrent use.
type StateTrie struct {
trie Trie
- preimages *preimageStore
+ db database.Database
hashKeyBuf [common.HashLength]byte
secKeyCache map[string][]byte
secKeyCacheOwner *StateTrie // Pointer to self, replace the key cache on mismatch
@@ -46,41 +63,44 @@ type StateTrie struct {
// If root is the zero hash or the sha3 hash of an empty string, the
// trie is initially empty. Otherwise, New will panic if db is nil
// and returns MissingNodeError if the root node cannot be found.
-func NewStateTrie(id *ID, db *Database, codec uint64) (*StateTrie, error) {
- // TODO: codec can be derived based on whether Owner is the zero hash
+func NewStateTrie(id *ID, db database.Database) (*StateTrie, error) {
if db == nil {
panic("trie.NewStateTrie called without a database")
}
- trie, err := New(id, db, codec)
+ trie, err := New(id, db)
if err != nil {
return nil, err
}
- return &StateTrie{trie: *trie, preimages: db.preimages}, nil
+ return &StateTrie{trie: *trie, db: db}, nil
}
-// Get returns the value for key stored in the trie.
+// MustGet returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
-func (t *StateTrie) Get(key []byte) []byte {
- res, err := t.TryGet(key)
- if err != nil {
- log.Error("Unhandled trie error in StateTrie.Get", "err", err)
- }
- return res
+//
+// This function will omit any encountered error but just
+// print out an error message.
+func (t *StateTrie) MustGet(key []byte) []byte {
+ return t.trie.MustGet(t.hashKey(key))
}
-// TryGet returns the value for key stored in the trie.
-// The value bytes must not be modified by the caller.
-// If the specified node is not in the trie, nil will be returned.
+// GetStorage attempts to retrieve a storage slot with provided account address
+// and slot key. The value bytes must not be modified by the caller.
+// If the specified storage slot is not in the trie, nil will be returned.
// If a trie node is not found in the database, a MissingNodeError is returned.
-func (t *StateTrie) TryGet(key []byte) ([]byte, error) {
- return t.trie.TryGet(t.hashKey(key))
+func (t *StateTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) {
+ enc, err := t.trie.Get(t.hashKey(key))
+ if err != nil || len(enc) == 0 {
+ return nil, err
+ }
+ _, content, _, err := rlp.Split(enc)
+ return content, err
}
-// TryGetAccount attempts to retrieve an account with provided account address.
+// GetAccount attempts to retrieve an account with provided account address.
// If the specified account is not in the trie, nil will be returned.
// If a trie node is not found in the database, a MissingNodeError is returned.
-func (t *StateTrie) TryGetAccount(address common.Address) (*types.StateAccount, error) {
- res, err := t.trie.TryGet(t.hashKey(address.Bytes()))
+func (t *StateTrie) GetAccount(address common.Address) (*types.StateAccount, error) {
+ res, err := t.trie.Get(t.hashKey(address.Bytes()))
if res == nil || err != nil {
return nil, err
}
@@ -89,11 +109,11 @@ func (t *StateTrie) TryGetAccount(address common.Address) (*types.StateAccount,
return ret, err
}
-// TryGetAccountByHash does the same thing as TryGetAccount, however
-// it expects an account hash that is the hash of address. This constitutes an
-// abstraction leak, since the client code needs to know the key format.
-func (t *StateTrie) TryGetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) {
- res, err := t.trie.TryGet(addrHash.Bytes())
+// GetAccountByHash does the same thing as GetAccount, however it expects an
+// account hash that is the hash of address. This constitutes an abstraction
+// leak, since the client code needs to know the key format.
+func (t *StateTrie) GetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) {
+ res, err := t.trie.Get(addrHash.Bytes())
if res == nil || err != nil {
return nil, err
}
@@ -102,27 +122,30 @@ func (t *StateTrie) TryGetAccountByHash(addrHash common.Hash) (*types.StateAccou
return ret, err
}
-// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not
+// GetNode attempts to retrieve a trie node by compact-encoded path. It is not
// possible to use keybyte-encoding as the path might contain odd nibbles.
// If the specified trie node is not in the trie, nil will be returned.
// If a trie node is not found in the database, a MissingNodeError is returned.
-func (t *StateTrie) TryGetNode(path []byte) ([]byte, int, error) {
- return t.trie.TryGetNode(path)
+func (t *StateTrie) GetNode(path []byte) ([]byte, int, error) {
+ return t.trie.GetNode(path)
}
-// Update associates key with value in the trie. Subsequent calls to
+// MustUpdate associates key with value in the trie. Subsequent calls to
// Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil.
//
// The value bytes must not be modified by the caller while they are
// stored in the trie.
-func (t *StateTrie) Update(key, value []byte) {
- if err := t.TryUpdate(key, value); err != nil {
- log.Error("Unhandled trie error in StateTrie.Update", "err", err)
- }
+//
+// This function will omit any encountered error but just print out an
+// error message.
+func (t *StateTrie) MustUpdate(key, value []byte) {
+ hk := t.hashKey(key)
+ t.trie.MustUpdate(hk, value)
+ t.getSecKeyCache()[string(hk)] = common.CopyBytes(key)
}
-// TryUpdate associates key with value in the trie. Subsequent calls to
+// UpdateStorage associates key with value in the trie. Subsequent calls to
// Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil.
//
@@ -130,9 +153,10 @@ func (t *StateTrie) Update(key, value []byte) {
// stored in the trie.
//
// If a node is not found in the database, a MissingNodeError is returned.
-func (t *StateTrie) TryUpdate(key, value []byte) error {
+func (t *StateTrie) UpdateStorage(_ common.Address, key, value []byte) error {
hk := t.hashKey(key)
- err := t.trie.TryUpdate(hk, value)
+ v, _ := rlp.EncodeToBytes(value)
+ err := t.trie.Update(hk, v)
if err != nil {
return err
}
@@ -140,42 +164,46 @@ func (t *StateTrie) TryUpdate(key, value []byte) error {
return nil
}
-// TryUpdateAccount account will abstract the write of an account to the
-// secure trie.
-func (t *StateTrie) TryUpdateAccount(address common.Address, acc *types.StateAccount) error {
+// UpdateAccount will abstract the write of an account to the secure trie.
+func (t *StateTrie) UpdateAccount(address common.Address, acc *types.StateAccount) error {
hk := t.hashKey(address.Bytes())
data, err := rlp.EncodeToBytes(acc)
if err != nil {
return err
}
- if err := t.trie.TryUpdate(hk, data); err != nil {
+ if err := t.trie.Update(hk, data); err != nil {
return err
}
t.getSecKeyCache()[string(hk)] = address.Bytes()
return nil
}
-// Delete removes any existing value for key from the trie.
-func (t *StateTrie) Delete(key []byte) {
- if err := t.TryDelete(key); err != nil {
- log.Error("Unhandled trie error in StateTrie.Delete", "err", err)
- }
+func (t *StateTrie) UpdateContractCode(_ common.Address, _ common.Hash, _ []byte) error {
+ return nil
}
-// TryDelete removes any existing value for key from the trie.
-// If the specified trie node is not in the trie, nothing will be changed.
-// If a node is not found in the database, a MissingNodeError is returned.
-func (t *StateTrie) TryDelete(key []byte) error {
+// MustDelete removes any existing value for key from the trie. This function
+// will omit any encountered error but just print out an error message.
+func (t *StateTrie) MustDelete(key []byte) {
hk := t.hashKey(key)
delete(t.getSecKeyCache(), string(hk))
- return t.trie.TryDelete(hk)
+ t.trie.MustDelete(hk)
}
-// TryDeleteAccount abstracts an account deletion from the trie.
-func (t *StateTrie) TryDeleteAccount(address common.Address) error {
+// DeleteStorage removes any existing storage slot from the trie.
+// If the specified trie node is not in the trie, nothing will be changed.
+// If a node is not found in the database, a MissingNodeError is returned.
+func (t *StateTrie) DeleteStorage(_ common.Address, key []byte) error {
+ hk := t.hashKey(key)
+ delete(t.getSecKeyCache(), string(hk))
+ return t.trie.Delete(hk)
+}
+
+// DeleteAccount abstracts an account deletion from the trie.
+func (t *StateTrie) DeleteAccount(address common.Address) error {
hk := t.hashKey(address.Bytes())
delete(t.getSecKeyCache(), string(hk))
- return t.trie.TryDelete(hk)
+ return t.trie.Delete(hk)
}
// GetKey returns the sha3 preimage of a hashed key that was
@@ -184,10 +212,7 @@ func (t *StateTrie) GetKey(shaKey []byte) []byte {
if key, ok := t.getSecKeyCache()[string(shaKey)]; ok {
return key
}
- if t.preimages == nil {
- return nil
- }
- return t.preimages.preimage(common.BytesToHash(shaKey))
+ return t.db.Preimage(common.BytesToHash(shaKey))
}
// Commit collects all dirty nodes in the trie and replaces them with the
@@ -197,16 +222,14 @@ func (t *StateTrie) GetKey(shaKey []byte) []byte {
// All cached preimages will be also flushed if preimages recording is enabled.
// Once the trie is committed, it's not usable anymore. A new trie must
// be created with new root and updated trie database for following usage
-func (t *StateTrie) Commit(collectLeaf bool) (common.Hash, *NodeSet) {
+func (t *StateTrie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, error) {
// Write all the pre-images to the actual disk database
if len(t.getSecKeyCache()) > 0 {
- if t.preimages != nil {
- preimages := make(map[common.Hash][]byte)
- for hk, key := range t.secKeyCache {
- preimages[common.BytesToHash([]byte(hk))] = key
- }
- t.preimages.insertPreimage(preimages)
+ preimages := make(map[common.Hash][]byte)
+ for hk, key := range t.secKeyCache {
+ preimages[common.BytesToHash([]byte(hk))] = key
}
+ t.db.InsertPreimage(preimages)
t.secKeyCache = make(map[string][]byte)
}
// Commit the trie and return its modified nodeset.
@@ -223,17 +246,23 @@ func (t *StateTrie) Hash() common.Hash {
func (t *StateTrie) Copy() *StateTrie {
return &StateTrie{
trie: *t.trie.Copy(),
- preimages: t.preimages,
+ db: t.db,
secKeyCache: t.secKeyCache,
}
}
-// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration
-// starts at the key after the given start key.
-func (t *StateTrie) NodeIterator(start []byte) NodeIterator {
+// NodeIterator returns an iterator that returns nodes of the underlying trie.
+// Iteration starts at the key after the given start key.
+func (t *StateTrie) NodeIterator(start []byte) (NodeIterator, error) {
return t.trie.NodeIterator(start)
}
+// MustNodeIterator is a wrapper of NodeIterator and will omit any encountered
+// error but just print out an error message.
+func (t *StateTrie) MustNodeIterator(start []byte) NodeIterator {
+ return t.trie.MustNodeIterator(start)
+}
+
// hashKey returns the hash of key as an ephemeral buffer.
// The caller must not hold onto the return value because it will become
// invalid on the next call to hashKey or secKey.
diff --git a/trie_by_cid/trie/secure_trie_test.go b/trie_by_cid/trie/secure_trie_test.go
index 41edbc3..c5e775c 100644
--- a/trie_by_cid/trie/secure_trie_test.go
+++ b/trie_by_cid/trie/secure_trie_test.go
@@ -25,14 +25,22 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
)
+func newEmptySecure() *StateTrie {
+ trie, _ := NewStateTrie(TrieID(types.EmptyRootHash), newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+ return trie
+}
+
// makeTestStateTrie creates a large enough secure trie for testing.
-func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) {
+func makeTestStateTrie() (*testDb, *StateTrie, map[string][]byte) {
// Create an empty trie
- triedb := NewDatabase(rawdb.NewMemoryDatabase())
- trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec)
+ triedb := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
+ trie, _ := NewStateTrie(TrieID(types.EmptyRootHash), triedb)
// Fill it with some arbitrary data
content := make(map[string][]byte)
@@ -40,33 +48,30 @@ func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) {
// Map the same data under multiple keys
key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i}
content[string(key)] = val
- trie.Update(key, val)
+ trie.MustUpdate(key, val)
key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i}
content[string(key)] = val
- trie.Update(key, val)
+ trie.MustUpdate(key, val)
// Add some other data to inflate the trie
for j := byte(3); j < 13; j++ {
key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i}
content[string(key)] = val
- trie.Update(key, val)
+ trie.MustUpdate(key, val)
}
}
- root, nodes := trie.Commit(false)
- if err := triedb.Update(NewWithNodeSet(nodes)); err != nil {
+ root, nodes, _ := trie.Commit(false)
+ if err := triedb.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes)); err != nil {
panic(fmt.Errorf("failed to commit db %v", err))
}
// Re-create the trie based on the new state
- trie, _ = NewStateTrie(TrieID(root), triedb, StateTrieCodec)
+ trie, _ = NewStateTrie(TrieID(root), triedb)
return triedb, trie, content
}
func TestSecureDelete(t *testing.T) {
- trie, err := NewStateTrie(TrieID(common.Hash{}), NewDatabase(rawdb.NewMemoryDatabase()), StateTrieCodec)
- if err != nil {
- t.Fatal(err)
- }
+ trie := newEmptySecure()
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
@@ -79,9 +84,9 @@ func TestSecureDelete(t *testing.T) {
}
for _, val := range vals {
if val.v != "" {
- trie.Update([]byte(val.k), []byte(val.v))
+ trie.MustUpdate([]byte(val.k), []byte(val.v))
} else {
- trie.Delete([]byte(val.k))
+ trie.MustDelete([]byte(val.k))
}
}
hash := trie.Hash()
@@ -92,17 +97,14 @@ func TestSecureDelete(t *testing.T) {
}
func TestSecureGetKey(t *testing.T) {
- trie, err := NewStateTrie(TrieID(common.Hash{}), NewDatabase(rawdb.NewMemoryDatabase()), StateTrieCodec)
- if err != nil {
- t.Fatal(err)
- }
- trie.Update([]byte("foo"), []byte("bar"))
+ trie := newEmptySecure()
+ trie.MustUpdate([]byte("foo"), []byte("bar"))
key := []byte("foo")
value := []byte("bar")
seckey := crypto.Keccak256(key)
- if !bytes.Equal(trie.Get(key), value) {
+ if !bytes.Equal(trie.MustGet(key), value) {
t.Errorf("Get did not return bar")
}
if k := trie.GetKey(seckey); !bytes.Equal(k, key) {
@@ -129,15 +131,15 @@ func TestStateTrieConcurrency(t *testing.T) {
for j := byte(0); j < 255; j++ {
// Map the same data under multiple keys
key, val := common.LeftPadBytes([]byte{byte(index), 1, j}, 32), []byte{j}
- tries[index].Update(key, val)
+ tries[index].MustUpdate(key, val)
key, val = common.LeftPadBytes([]byte{byte(index), 2, j}, 32), []byte{j}
- tries[index].Update(key, val)
+ tries[index].MustUpdate(key, val)
// Add some other data to inflate the trie
for k := byte(3); k < 13; k++ {
key, val = common.LeftPadBytes([]byte{byte(index), k, j}, 32), []byte{k, j}
- tries[index].Update(key, val)
+ tries[index].MustUpdate(key, val)
}
}
tries[index].Commit(false)
diff --git a/trie_by_cid/trie/tracer.go b/trie_by_cid/trie/tracer.go
index a27e371..5786af4 100644
--- a/trie_by_cid/trie/tracer.go
+++ b/trie_by_cid/trie/tracer.go
@@ -16,7 +16,9 @@
package trie
-import "github.com/ethereum/go-ethereum/common"
+import (
+ "github.com/ethereum/go-ethereum/common"
+)
// tracer tracks the changes of trie nodes. During the trie operations,
// some nodes can be deleted from the trie, while these deleted nodes
@@ -111,15 +113,18 @@ func (t *tracer) copy() *tracer {
}
}
-// markDeletions puts all tracked deletions into the provided nodeset.
-func (t *tracer) markDeletions(set *NodeSet) {
+// deletedNodes returns a list of node paths which are deleted from the trie.
+func (t *tracer) deletedNodes() []string {
+ var paths []string
for path := range t.deletes {
// It's possible a few deleted nodes were embedded
// in their parent before, the deletions can be no
// effect by deleting nothing, filter them out.
- if _, ok := set.accessList[path]; !ok {
+ _, ok := t.accessList[path]
+ if !ok {
continue
}
- set.markDeleted([]byte(path))
+ paths = append(paths, path)
}
+ return paths
}
diff --git a/trie_by_cid/trie/trie.go b/trie_by_cid/trie/trie.go
index 0ab997c..314709a 100644
--- a/trie_by_cid/trie/trie.go
+++ b/trie_by_cid/trie/trie.go
@@ -1,3 +1,19 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
// Package trie implements Merkle Patricia Tries.
package trie
@@ -8,14 +24,10 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
- log "github.com/sirupsen/logrus"
+ "github.com/ethereum/go-ethereum/log"
- "github.com/cerc-io/plugeth-statediff/indexer/ipld"
-)
-
-var (
- StateTrieCodec uint64 = ipld.MEthStateTrie
- StorageTrieCodec uint64 = ipld.MEthStorageTrie
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/triedb/database"
)
// Trie is a Merkle Patricia Trie. Use New to create a trie that sits on
@@ -29,6 +41,10 @@ type Trie struct {
root node
owner common.Hash
+ // Flag whether the commit operation is already performed. If so the
+ // trie is not usable(latest states is invisible).
+ committed bool
+
// Keep track of the number leaves which have been inserted since the last
// hashing operation. This number will not directly map to the number of
// actually unhashed nodes.
@@ -50,24 +66,23 @@ func (t *Trie) newFlag() nodeFlag {
// Copy returns a copy of Trie.
func (t *Trie) Copy() *Trie {
return &Trie{
- root: t.root,
- owner: t.owner,
- unhashed: t.unhashed,
- reader: t.reader,
- tracer: t.tracer.copy(),
+ root: t.root,
+ owner: t.owner,
+ committed: t.committed,
+ unhashed: t.unhashed,
+ reader: t.reader,
+ tracer: t.tracer.copy(),
}
}
-// New creates a trie instance with the provided trie id and the read-only
+// New creates the trie instance with provided trie id and the read-only
// database. The state specified by trie id must be available, otherwise
// an error will be returned. The trie root specified by trie id can be
// zero hash or the sha3 hash of an empty string, then trie is initially
// empty, otherwise, the root node must be present in database or returns
// a MissingNodeError if not.
-// The passed codec specifies whether to read state or storage nodes from the
-// trie.
-func New(id *ID, db NodeReader, codec uint64) (*Trie, error) {
- reader, err := newTrieReader(id.StateRoot, id.Owner, db, codec)
+func New(id *ID, db database.Database) (*Trie, error) {
+ reader, err := newTrieReader(id.StateRoot, id.Owner, db)
if err != nil {
return nil, err
}
@@ -87,42 +102,59 @@ func New(id *ID, db NodeReader, codec uint64) (*Trie, error) {
}
// NewEmpty is a shortcut to create empty tree. It's mostly used in tests.
-func NewEmpty(db *Database) *Trie {
- tr, err := New(TrieID(common.Hash{}), db, StateTrieCodec)
- if err != nil {
- panic(err)
- }
+func NewEmpty(db database.Database) *Trie {
+ tr, _ := New(TrieID(types.EmptyRootHash), db)
return tr
}
+// MustNodeIterator is a wrapper of NodeIterator and will omit any encountered
+// error but just print out an error message.
+func (t *Trie) MustNodeIterator(start []byte) NodeIterator {
+ it, err := t.NodeIterator(start)
+ if err != nil {
+ log.Error("Unhandled trie error in Trie.NodeIterator", "err", err)
+ }
+ return it
+}
+
// NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at
// the key after the given start key.
-func (t *Trie) NodeIterator(start []byte) NodeIterator {
- return newNodeIterator(t, start)
+func (t *Trie) NodeIterator(start []byte) (NodeIterator, error) {
+ // Short circuit if the trie is already committed and not usable.
+ if t.committed {
+ return nil, ErrCommitted
+ }
+ return newNodeIterator(t, start), nil
}
-// Get returns the value for key stored in the trie.
-// The value bytes must not be modified by the caller.
-func (t *Trie) Get(key []byte) []byte {
- res, err := t.TryGet(key)
+// MustGet is a wrapper of Get and will omit any encountered error but just
+// print out an error message.
+func (t *Trie) MustGet(key []byte) []byte {
+ res, err := t.Get(key)
if err != nil {
log.Error("Unhandled trie error in Trie.Get", "err", err)
}
return res
}
-// TryGet returns the value for key stored in the trie.
+// Get returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
-// If a node was not found in the database, a MissingNodeError is returned.
-func (t *Trie) TryGet(key []byte) ([]byte, error) {
- value, newroot, didResolve, err := t.tryGet(t.root, keybytesToHex(key), 0)
+//
+// If the requested node is not present in trie, no error will be returned.
+// If the trie is corrupted, a MissingNodeError is returned.
+func (t *Trie) Get(key []byte) ([]byte, error) {
+ // Short circuit if the trie is already committed and not usable.
+ if t.committed {
+ return nil, ErrCommitted
+ }
+ value, newroot, didResolve, err := t.get(t.root, keybytesToHex(key), 0)
if err == nil && didResolve {
t.root = newroot
}
return value, err
}
-func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode node, didResolve bool, err error) {
+func (t *Trie) get(origNode node, key []byte, pos int) (value []byte, newnode node, didResolve bool, err error) {
switch n := (origNode).(type) {
case nil:
return nil, nil, false, nil
@@ -133,14 +165,14 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
// key not found in trie
return nil, n, false, nil
}
- value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key))
+ value, newnode, didResolve, err = t.get(n.Val, key, pos+len(n.Key))
if err == nil && didResolve {
n = n.copy()
n.Val = newnode
}
return value, n, didResolve, err
case *fullNode:
- value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1)
+ value, newnode, didResolve, err = t.get(n.Children[key[pos]], key, pos+1)
if err == nil && didResolve {
n = n.copy()
n.Children[key[pos]] = newnode
@@ -151,17 +183,34 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
if err != nil {
return nil, n, true, err
}
- value, newnode, _, err := t.tryGet(child, key, pos)
+ value, newnode, _, err := t.get(child, key, pos)
return value, newnode, true, err
default:
panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode))
}
}
-// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not
-// possible to use keybyte-encoding as the path might contain odd nibbles.
-func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) {
- item, newroot, resolved, err := t.tryGetNode(t.root, compactToHex(path), 0)
+// MustGetNode is a wrapper of GetNode and will omit any encountered error but
+// just print out an error message.
+func (t *Trie) MustGetNode(path []byte) ([]byte, int) {
+ item, resolved, err := t.GetNode(path)
+ if err != nil {
+ log.Error("Unhandled trie error in Trie.GetNode", "err", err)
+ }
+ return item, resolved
+}
+
+// GetNode retrieves a trie node by compact-encoded path. It is not possible
+// to use keybyte-encoding as the path might contain odd nibbles.
+//
+// If the requested node is not present in trie, no error will be returned.
+// If the trie is corrupted, a MissingNodeError is returned.
+func (t *Trie) GetNode(path []byte) ([]byte, int, error) {
+ // Short circuit if the trie is already committed and not usable.
+ if t.committed {
+ return nil, 0, ErrCommitted
+ }
+ item, newroot, resolved, err := t.getNode(t.root, compactToHex(path), 0)
if err != nil {
return nil, resolved, err
}
@@ -171,10 +220,10 @@ func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) {
if item == nil {
return nil, resolved, nil
}
- return item, resolved, err
+ return item, resolved, nil
}
-func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, newnode node, resolved int, err error) {
+func (t *Trie) getNode(origNode node, path []byte, pos int) (item []byte, newnode node, resolved int, err error) {
// If non-existent path requested, abort
if origNode == nil {
return nil, nil, 0, nil
@@ -193,7 +242,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
if hash == nil {
return nil, origNode, 0, errors.New("non-consensus node")
}
- blob, err := t.reader.nodeBlob(path, common.BytesToHash(hash))
+ blob, err := t.reader.node(path, common.BytesToHash(hash))
return blob, origNode, 1, err
}
// Path still needs to be traversed, descend into children
@@ -207,7 +256,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
// Path branches off from short node
return nil, n, 0, nil
}
- item, newnode, resolved, err = t.tryGetNode(n.Val, path, pos+len(n.Key))
+ item, newnode, resolved, err = t.getNode(n.Val, path, pos+len(n.Key))
if err == nil && resolved > 0 {
n = n.copy()
n.Val = newnode
@@ -215,7 +264,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
return item, n, resolved, err
case *fullNode:
- item, newnode, resolved, err = t.tryGetNode(n.Children[path[pos]], path, pos+1)
+ item, newnode, resolved, err = t.getNode(n.Children[path[pos]], path, pos+1)
if err == nil && resolved > 0 {
n = n.copy()
n.Children[path[pos]] = newnode
@@ -227,7 +276,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
if err != nil {
return nil, n, 1, err
}
- item, newnode, resolved, err := t.tryGetNode(child, path, pos)
+ item, newnode, resolved, err := t.getNode(child, path, pos)
return item, newnode, resolved + 1, err
default:
@@ -235,33 +284,32 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
}
}
+// MustUpdate is a wrapper of Update and will omit any encountered error but
+// just print out an error message.
+func (t *Trie) MustUpdate(key, value []byte) {
+ if err := t.Update(key, value); err != nil {
+ log.Error("Unhandled trie error in Trie.Update", "err", err)
+ }
+}
+
// Update associates key with value in the trie. Subsequent calls to
// Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil.
//
// The value bytes must not be modified by the caller while they are
// stored in the trie.
-func (t *Trie) Update(key, value []byte) {
- if err := t.TryUpdate(key, value); err != nil {
- log.Error("Unhandled trie error in Trie.Update", "err", err)
+//
+// If the requested node is not present in trie, no error will be returned.
+// If the trie is corrupted, a MissingNodeError is returned.
+func (t *Trie) Update(key, value []byte) error {
+ // Short circuit if the trie is already committed and not usable.
+ if t.committed {
+ return ErrCommitted
}
+ return t.update(key, value)
}
-// TryUpdate associates key with value in the trie. Subsequent calls to
-// Get will return value. If value has length zero, any existing value
-// is deleted from the trie and calls to Get will return nil.
-//
-// The value bytes must not be modified by the caller while they are
-// stored in the trie.
-//
-// If a node was not found in the database, a MissingNodeError is returned.
-func (t *Trie) TryUpdate(key, value []byte) error {
- return t.tryUpdate(key, value)
-}
-
-// tryUpdate expects an RLP-encoded value and performs the core function
-// for TryUpdate and TryUpdateAccount.
-func (t *Trie) tryUpdate(key, value []byte) error {
+func (t *Trie) update(key, value []byte) error {
t.unhashed++
k := keybytesToHex(key)
if len(value) != 0 {
@@ -359,16 +407,23 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
}
}
-// Delete removes any existing value for key from the trie.
-func (t *Trie) Delete(key []byte) {
- if err := t.TryDelete(key); err != nil {
+// MustDelete is a wrapper of Delete and will omit any encountered error but
+// just print out an error message.
+func (t *Trie) MustDelete(key []byte) {
+ if err := t.Delete(key); err != nil {
log.Error("Unhandled trie error in Trie.Delete", "err", err)
}
}
-// TryDelete removes any existing value for key from the trie.
-// If a node was not found in the database, a MissingNodeError is returned.
-func (t *Trie) TryDelete(key []byte) error {
+// Delete removes any existing value for key from the trie.
+//
+// If the requested node is not present in trie, no error will be returned.
+// If the trie is corrupted, a MissingNodeError is returned.
+func (t *Trie) Delete(key []byte) error {
+ // Short circuit if the trie is already committed and not usable.
+ if t.committed {
+ return ErrCommitted
+ }
t.unhashed++
k := keybytesToHex(key)
_, n, err := t.delete(t.root, nil, k)
@@ -532,7 +587,7 @@ func (t *Trie) resolve(n node, prefix []byte) (node, error) {
// node's original value. The rlp-encoded blob is preferred to be loaded from
// database because it's easy to decode node while complex to encode node to blob.
func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) {
- blob, err := t.reader.nodeBlob(prefix, common.BytesToHash(n))
+ blob, err := t.reader.node(prefix, common.BytesToHash(n))
if err != nil {
return nil, err
}
@@ -554,17 +609,25 @@ func (t *Trie) Hash() common.Hash {
// The returned nodeset can be nil if the trie is clean (nothing to commit).
// Once the trie is committed, it's not usable anymore. A new trie must
// be created with new root and updated trie database for following usage
-func (t *Trie) Commit(collectLeaf bool) (common.Hash, *NodeSet) {
+func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, error) {
defer t.tracer.reset()
-
- nodes := NewNodeSet(t.owner, t.tracer.accessList)
- t.tracer.markDeletions(nodes)
-
+ defer func() {
+ t.committed = true
+ }()
// Trie is empty and can be classified into two types of situations:
- // - The trie was empty and no update happens
- // - The trie was non-empty and all nodes are dropped
+ // (a) The trie was empty and no update happens => return nil
+ // (b) The trie was non-empty and all nodes are dropped => return
+ // the node set includes all deleted nodes
if t.root == nil {
- return types.EmptyRootHash, nodes
+ paths := t.tracer.deletedNodes()
+ if len(paths) == 0 {
+ return types.EmptyRootHash, nil, nil // case (a)
+ }
+ nodes := trienode.NewNodeSet(t.owner)
+ for _, path := range paths {
+ nodes.AddNode([]byte(path), trienode.NewDeleted())
+ }
+ return types.EmptyRootHash, nodes, nil // case (b)
}
// Derive the hash for all dirty nodes first. We hold the assumption
// in the following procedure that all nodes are hashed.
@@ -576,10 +639,14 @@ func (t *Trie) Commit(collectLeaf bool) (common.Hash, *NodeSet) {
// Replace the root node with the origin hash in order to
// ensure all resolved nodes are dropped after the commit.
t.root = hashedNode
- return rootHash, nil
+ return rootHash, nil, nil
}
- t.root = newCommitter(nodes, collectLeaf).Commit(t.root)
- return rootHash, nodes
+ nodes := trienode.NewNodeSet(t.owner)
+ for _, path := range t.tracer.deletedNodes() {
+ nodes.AddNode([]byte(path), trienode.NewDeleted())
+ }
+ t.root = newCommitter(nodes, t.tracer, collectLeaf).Commit(t.root)
+ return rootHash, nodes, nil
}
// hashRoot calculates the root hash of the given trie
@@ -603,4 +670,5 @@ func (t *Trie) Reset() {
t.owner = common.Hash{}
t.unhashed = 0
t.tracer.reset()
+ t.committed = false
}
diff --git a/trie_by_cid/trie/trie_reader.go b/trie_by_cid/trie/trie_reader.go
index b0a7fdd..091b0a1 100644
--- a/trie_by_cid/trie/trie_reader.go
+++ b/trie_by_cid/trie/trie_reader.go
@@ -17,44 +17,33 @@
package trie
import (
- "fmt"
-
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/log"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/triedb/database"
)
-// Reader wraps the Node and NodeBlob method of a backing trie store.
-type Reader interface {
- // Node retrieves the trie node with the provided trie identifier, hexary
- // node path and the corresponding node hash.
- // No error will be returned if the node is not found.
- Node(owner common.Hash, path []byte, hash common.Hash) (node, error)
-
- // NodeBlob retrieves the RLP-encoded trie node blob with the provided trie
- // identifier, hexary node path and the corresponding node hash.
- // No error will be returned if the node is not found.
- NodeBlob(owner common.Hash, path []byte, hash common.Hash) ([]byte, error)
-}
-
-// NodeReader wraps all the necessary functions for accessing trie node.
-type NodeReader interface {
- // GetReader returns a reader for accessing all trie nodes with provided
- // state root. Nil is returned in case the state is not available.
- GetReader(root common.Hash, codec uint64) Reader
-}
-
// trieReader is a wrapper of the underlying node reader. It's not safe
// for concurrent usage.
type trieReader struct {
owner common.Hash
- reader Reader
+ reader database.Reader
banned map[string]struct{} // Marker to prevent node from being accessed, for tests
}
// newTrieReader initializes the trie reader with the given node reader.
-func newTrieReader(stateRoot, owner common.Hash, db NodeReader, codec uint64) (*trieReader, error) {
- reader := db.GetReader(stateRoot, codec)
- if reader == nil {
- return nil, fmt.Errorf("state not found #%x", stateRoot)
+func newTrieReader(stateRoot, owner common.Hash, db database.Database) (*trieReader, error) {
+ if stateRoot == (common.Hash{}) || stateRoot == types.EmptyRootHash {
+ if stateRoot == (common.Hash{}) {
+ log.Error("Zero state root hash!")
+ }
+ return &trieReader{owner: owner}, nil
+ }
+ reader, err := db.Reader(stateRoot)
+ if err != nil {
+ return nil, &MissingNodeError{Owner: owner, NodeHash: stateRoot, err: err}
}
return &trieReader{owner: owner, reader: reader}, nil
}
@@ -65,30 +54,10 @@ func newEmptyReader() *trieReader {
return &trieReader{}
}
-// node retrieves the trie node with the provided trie node information.
-// An MissingNodeError will be returned in case the node is not found or
-// any error is encountered.
-func (r *trieReader) node(path []byte, hash common.Hash) (node, error) {
- // Perform the logics in tests for preventing trie node access.
- if r.banned != nil {
- if _, ok := r.banned[string(path)]; ok {
- return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
- }
- }
- if r.reader == nil {
- return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
- }
- node, err := r.reader.Node(r.owner, path, hash)
- if err != nil || node == nil {
- return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err}
- }
- return node, nil
-}
-
// node retrieves the rlp-encoded trie node with the provided trie node
// information. An MissingNodeError will be returned in case the node is
// not found or any error is encountered.
-func (r *trieReader) nodeBlob(path []byte, hash common.Hash) ([]byte, error) {
+func (r *trieReader) node(path []byte, hash common.Hash) ([]byte, error) {
// Perform the logics in tests for preventing trie node access.
if r.banned != nil {
if _, ok := r.banned[string(path)]; ok {
@@ -98,9 +67,29 @@ func (r *trieReader) nodeBlob(path []byte, hash common.Hash) ([]byte, error) {
if r.reader == nil {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
}
- blob, err := r.reader.NodeBlob(r.owner, path, hash)
+ blob, err := r.reader.Node(r.owner, path, hash)
if err != nil || len(blob) == 0 {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err}
}
return blob, nil
}
+
+// MerkleLoader implements triestate.TrieLoader for constructing tries.
+type MerkleLoader struct {
+ db database.Database
+}
+
+// NewMerkleLoader creates the merkle trie loader.
+func NewMerkleLoader(db database.Database) *MerkleLoader {
+ return &MerkleLoader{db: db}
+}
+
+// OpenTrie opens the main account trie.
+func (l *MerkleLoader) OpenTrie(root common.Hash) (triestate.Trie, error) {
+ return New(TrieID(root), l.db)
+}
+
+// OpenStorageTrie opens the storage trie of an account.
+func (l *MerkleLoader) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (triestate.Trie, error) {
+ return New(StorageTrieID(stateRoot, addrHash, root), l.db)
+}
diff --git a/trie_by_cid/trie/trie_test.go b/trie_by_cid/trie/trie_test.go
index 76aaf72..75eb5ef 100644
--- a/trie_by_cid/trie/trie_test.go
+++ b/trie_by_cid/trie/trie_test.go
@@ -21,9 +21,11 @@ import (
"encoding/binary"
"errors"
"fmt"
- "math/big"
+ "hash"
+ "io"
"math/rand"
"reflect"
+ "sort"
"testing"
"testing/quick"
@@ -32,7 +34,12 @@ import (
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/rlp"
+ "github.com/holiman/uint256"
+ "golang.org/x/crypto/sha3"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
)
func init() {
@@ -41,7 +48,7 @@ func init() {
}
func TestEmptyTrie(t *testing.T) {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
res := trie.Hash()
exp := types.EmptyRootHash
if res != exp {
@@ -50,18 +57,23 @@ func TestEmptyTrie(t *testing.T) {
}
func TestNull(t *testing.T) {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
key := make([]byte, 32)
value := []byte("test")
- trie.Update(key, value)
- if !bytes.Equal(trie.Get(key), value) {
+ trie.MustUpdate(key, value)
+ if !bytes.Equal(trie.MustGet(key), value) {
t.Fatal("wrong value")
}
}
func TestMissingRoot(t *testing.T) {
+ testMissingRoot(t, rawdb.HashScheme)
+ testMissingRoot(t, rawdb.PathScheme)
+}
+
+func testMissingRoot(t *testing.T, scheme string) {
root := common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33")
- trie, err := NewAccountTrie(TrieID(root), NewDatabase(rawdb.NewMemoryDatabase()))
+ trie, err := New(TrieID(root), newTestDatabase(rawdb.NewMemoryDatabase(), scheme))
if trie != nil {
t.Error("New returned non-nil trie for invalid root")
}
@@ -70,80 +82,94 @@ func TestMissingRoot(t *testing.T) {
}
}
-func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) }
+func TestMissingNode(t *testing.T) {
+ testMissingNode(t, false, rawdb.HashScheme)
+ testMissingNode(t, false, rawdb.PathScheme)
+ testMissingNode(t, true, rawdb.HashScheme)
+ testMissingNode(t, true, rawdb.PathScheme)
+}
-func testMissingNode(t *testing.T, memonly bool) {
+func testMissingNode(t *testing.T, memonly bool, scheme string) {
diskdb := rawdb.NewMemoryDatabase()
- triedb := NewDatabase(diskdb)
+ triedb := newTestDatabase(diskdb, scheme)
trie := NewEmpty(triedb)
updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer")
updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf")
- root, nodes := trie.Commit(false)
- triedb.Update(NewWithNodeSet(nodes))
+ root, nodes, _ := trie.Commit(false)
+ triedb.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
- trie, _ = NewAccountTrie(TrieID(root), triedb)
- _, err := trie.TryGet([]byte("120000"))
+ if !memonly {
+ triedb.Commit(root)
+ }
+
+ trie, _ = New(TrieID(root), triedb)
+ _, err := trie.Get([]byte("120000"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
- trie, _ = NewAccountTrie(TrieID(root), triedb)
- _, err = trie.TryGet([]byte("120099"))
+ trie, _ = New(TrieID(root), triedb)
+ _, err = trie.Get([]byte("120099"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
- trie, _ = NewAccountTrie(TrieID(root), triedb)
- _, err = trie.TryGet([]byte("123456"))
+ trie, _ = New(TrieID(root), triedb)
+ _, err = trie.Get([]byte("123456"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
- trie, _ = NewAccountTrie(TrieID(root), triedb)
- err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv"))
+ trie, _ = New(TrieID(root), triedb)
+ err = trie.Update([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
- trie, _ = NewAccountTrie(TrieID(root), triedb)
- err = trie.TryDelete([]byte("123456"))
+ trie, _ = New(TrieID(root), triedb)
+ err = trie.Delete([]byte("123456"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
- hash := common.HexToHash("0xe1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9")
+ var (
+ path []byte
+ hash = common.HexToHash("0xe1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9")
+ )
+ for p, n := range nodes.Nodes {
+ if n.Hash == hash {
+ path = common.CopyBytes([]byte(p))
+ break
+ }
+ }
+ trie, _ = New(TrieID(root), triedb)
if memonly {
- delete(triedb.dirties, hash)
+ trie.reader.banned = map[string]struct{}{string(path): {}}
} else {
- diskdb.Delete(hash[:])
+ rawdb.DeleteTrieNode(diskdb, common.Hash{}, path, hash, scheme)
}
- trie, _ = NewAccountTrie(TrieID(root), triedb)
- _, err = trie.TryGet([]byte("120000"))
+ _, err = trie.Get([]byte("120000"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
- trie, _ = NewAccountTrie(TrieID(root), triedb)
- _, err = trie.TryGet([]byte("120099"))
+ _, err = trie.Get([]byte("120099"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
- trie, _ = NewAccountTrie(TrieID(root), triedb)
- _, err = trie.TryGet([]byte("123456"))
+ _, err = trie.Get([]byte("123456"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
- trie, _ = NewAccountTrie(TrieID(root), triedb)
- err = trie.TryUpdate([]byte("120099"), []byte("zxcv"))
+ err = trie.Update([]byte("120099"), []byte("zxcv"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
- trie, _ = NewAccountTrie(TrieID(root), triedb)
- err = trie.TryDelete([]byte("123456"))
+ err = trie.Delete([]byte("123456"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
}
func TestInsert(t *testing.T) {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
updateString(trie, "doe", "reindeer")
updateString(trie, "dog", "puppy")
@@ -155,18 +181,18 @@ func TestInsert(t *testing.T) {
t.Errorf("case 1: exp %x got %x", exp, root)
}
- trie = NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ trie = NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
updateString(trie, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
exp = common.HexToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab")
- root, _ = trie.Commit(false)
+ root, _, _ = trie.Commit(false)
if root != exp {
t.Errorf("case 2: exp %x got %x", exp, root)
}
}
func TestGet(t *testing.T) {
- db := NewDatabase(rawdb.NewMemoryDatabase())
+ db := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
trie := NewEmpty(db)
updateString(trie, "doe", "reindeer")
updateString(trie, "dog", "puppy")
@@ -184,14 +210,15 @@ func TestGet(t *testing.T) {
if i == 1 {
return
}
- root, nodes := trie.Commit(false)
- db.Update(NewWithNodeSet(nodes))
- trie, _ = NewAccountTrie(TrieID(root), db)
+ root, nodes, _ := trie.Commit(false)
+ db.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
+ trie, _ = New(TrieID(root), db)
}
}
func TestDelete(t *testing.T) {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ db := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
+ trie := NewEmpty(db)
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
@@ -218,7 +245,7 @@ func TestDelete(t *testing.T) {
}
func TestEmptyValues(t *testing.T) {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
vals := []struct{ k, v string }{
{"do", "verb"},
@@ -242,8 +269,8 @@ func TestEmptyValues(t *testing.T) {
}
func TestReplication(t *testing.T) {
- triedb := NewDatabase(rawdb.NewMemoryDatabase())
- trie := NewEmpty(triedb)
+ db := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
+ trie := NewEmpty(db)
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
@@ -256,31 +283,31 @@ func TestReplication(t *testing.T) {
for _, val := range vals {
updateString(trie, val.k, val.v)
}
- exp, nodes := trie.Commit(false)
- triedb.Update(NewWithNodeSet(nodes))
+ root, nodes, _ := trie.Commit(false)
+ db.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
// create a new trie on top of the database and check that lookups work.
- trie2, err := NewAccountTrie(TrieID(exp), triedb)
+ trie2, err := New(TrieID(root), db)
if err != nil {
- t.Fatalf("can't recreate trie at %x: %v", exp, err)
+ t.Fatalf("can't recreate trie at %x: %v", root, err)
}
for _, kv := range vals {
if string(getString(trie2, kv.k)) != kv.v {
t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v)
}
}
- hash, nodes := trie2.Commit(false)
- if hash != exp {
- t.Errorf("root failure. expected %x got %x", exp, hash)
+ hash, nodes, _ := trie2.Commit(false)
+ if hash != root {
+ t.Errorf("root failure. expected %x got %x", root, hash)
}
// recreate the trie after commit
if nodes != nil {
- triedb.Update(NewWithNodeSet(nodes))
+ db.Update(hash, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
}
- trie2, err = NewAccountTrie(TrieID(hash), triedb)
+ trie2, err = New(TrieID(hash), db)
if err != nil {
- t.Fatalf("can't recreate trie at %x: %v", exp, err)
+ t.Fatalf("can't recreate trie at %x: %v", hash, err)
}
// perform some insertions on the new trie.
vals2 := []struct{ k, v string }{
@@ -297,15 +324,15 @@ func TestReplication(t *testing.T) {
for _, val := range vals2 {
updateString(trie2, val.k, val.v)
}
- if hash := trie2.Hash(); hash != exp {
- t.Errorf("root failure. expected %x got %x", exp, hash)
+ if trie2.Hash() != hash {
+ t.Errorf("root failure. expected %x got %x", hash, trie2.Hash())
}
}
func TestLargeValue(t *testing.T) {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
- trie.Update([]byte("key1"), []byte{99, 99, 99, 99})
- trie.Update([]byte("key2"), bytes.Repeat([]byte{1}, 32))
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+ trie.MustUpdate([]byte("key1"), []byte{99, 99, 99, 99})
+ trie.MustUpdate([]byte("key2"), bytes.Repeat([]byte{1}, 32))
trie.Hash()
}
@@ -339,13 +366,18 @@ func TestRandomCases(t *testing.T) {
{op: 1, key: common.Hex2Bytes("980c393656413a15c8da01978ed9f89feb80b502f58f2d640e3a2f5f7a99a7018f1b573befd92053ac6f78fca4a87268"), value: common.Hex2Bytes("")}, // step 24
{op: 1, key: common.Hex2Bytes("fd"), value: common.Hex2Bytes("")}, // step 25
}
- runRandTest(rt)
+ if err := runRandTest(rt); err != nil {
+ t.Fatal(err)
+ }
}
// randTest performs random trie operations.
// Instances of this test are created by Generate.
type randTest []randTestStep
+// compile-time interface check
+var _ quick.Generator = (randTest)(nil)
+
type randTestStep struct {
op int
key []byte // for opUpdate, opDelete, opGet
@@ -366,83 +398,102 @@ const (
)
func (randTest) Generate(r *rand.Rand, size int) reflect.Value {
+ var finishedFn = func() bool {
+ size--
+ return size == 0
+ }
+ return reflect.ValueOf(generateSteps(finishedFn, r))
+}
+
+func generateSteps(finished func() bool, r io.Reader) randTest {
var allKeys [][]byte
+ var one = []byte{0}
genKey := func() []byte {
- if len(allKeys) < 2 || r.Intn(100) < 10 {
+ r.Read(one)
+ if len(allKeys) < 2 || one[0]%100 > 90 {
// new key
- key := make([]byte, r.Intn(50))
+ size := one[0] % 50
+ key := make([]byte, size)
r.Read(key)
allKeys = append(allKeys, key)
return key
}
// use existing key
- return allKeys[r.Intn(len(allKeys))]
+ idx := int(one[0]) % len(allKeys)
+ return allKeys[idx]
}
-
var steps randTest
- for i := 0; i < size; i++ {
- step := randTestStep{op: r.Intn(opMax)}
+ for !finished() {
+ r.Read(one)
+ step := randTestStep{op: int(one[0]) % opMax}
switch step.op {
case opUpdate:
step.key = genKey()
step.value = make([]byte, 8)
- binary.BigEndian.PutUint64(step.value, uint64(i))
+ binary.BigEndian.PutUint64(step.value, uint64(len(steps)))
case opGet, opDelete, opProve:
step.key = genKey()
}
steps = append(steps, step)
}
- return reflect.ValueOf(steps)
+ return steps
}
-func verifyAccessList(old *Trie, new *Trie, set *NodeSet) error {
+func verifyAccessList(old *Trie, new *Trie, set *trienode.NodeSet) error {
deletes, inserts, updates := diffTries(old, new)
// Check insertion set
for path := range inserts {
- n, ok := set.nodes[path]
- if !ok || n.isDeleted() {
+ n, ok := set.Nodes[path]
+ if !ok || n.IsDeleted() {
return errors.New("expect new node")
}
- _, ok = set.accessList[path]
- if ok {
- return errors.New("unexpected origin value")
- }
+ //if len(n.Prev) > 0 {
+ // return errors.New("unexpected origin value")
+ //}
}
// Check deletion set
- for path, blob := range deletes {
- n, ok := set.nodes[path]
- if !ok || !n.isDeleted() {
+ for path := range deletes {
+ n, ok := set.Nodes[path]
+ if !ok || !n.IsDeleted() {
return errors.New("expect deleted node")
}
- v, ok := set.accessList[path]
- if !ok {
- return errors.New("expect origin value")
- }
- if !bytes.Equal(v, blob) {
- return errors.New("invalid origin value")
- }
+ //if len(n.Prev) == 0 {
+ // return errors.New("expect origin value")
+ //}
+ //if !bytes.Equal(n.Prev, blob) {
+ // return errors.New("invalid origin value")
+ //}
}
// Check update set
- for path, blob := range updates {
- n, ok := set.nodes[path]
- if !ok || n.isDeleted() {
+ for path := range updates {
+ n, ok := set.Nodes[path]
+ if !ok || n.IsDeleted() {
return errors.New("expect updated node")
}
- v, ok := set.accessList[path]
- if !ok {
- return errors.New("expect origin value")
- }
- if !bytes.Equal(v, blob) {
- return errors.New("invalid origin value")
- }
+ //if len(n.Prev) == 0 {
+ // return errors.New("expect origin value")
+ //}
+ //if !bytes.Equal(n.Prev, blob) {
+ // return errors.New("invalid origin value")
+ //}
}
return nil
}
-func runRandTest(rt randTest) bool {
+// runRandTestBool coerces error to boolean, for use in quick.Check
+func runRandTestBool(rt randTest) bool {
+ return runRandTest(rt) == nil
+}
+
+func runRandTest(rt randTest) error {
+ var scheme = rawdb.HashScheme
+ if rand.Intn(2) == 0 {
+ scheme = rawdb.PathScheme
+ }
var (
- triedb = NewDatabase(rawdb.NewMemoryDatabase())
+ origin = types.EmptyRootHash
+ triedb = newTestDatabase(rawdb.NewMemoryDatabase(), scheme)
tr = NewEmpty(triedb)
values = make(map[string]string) // tracks content of the trie
origTrie = NewEmpty(triedb)
@@ -453,13 +504,13 @@ func runRandTest(rt randTest) bool {
switch step.op {
case opUpdate:
- tr.Update(step.key, step.value)
+ tr.MustUpdate(step.key, step.value)
values[string(step.key)] = string(step.value)
case opDelete:
- tr.Delete(step.key)
+ tr.MustDelete(step.key)
delete(values, string(step.key))
case opGet:
- v := tr.Get(step.key)
+ v := tr.MustGet(step.key)
want := values[string(step.key)]
if string(v) != want {
rt[i].err = fmt.Errorf("mismatch for key %#x, got %#x want %#x", step.key, v, want)
@@ -470,7 +521,7 @@ func runRandTest(rt randTest) bool {
continue
}
proofDb := rawdb.NewMemoryDatabase()
- err := tr.Prove(step.key, 0, proofDb)
+ err := tr.Prove(step.key, proofDb)
if err != nil {
rt[i].err = fmt.Errorf("failed for proving key %#x, %v", step.key, err)
}
@@ -481,36 +532,37 @@ func runRandTest(rt randTest) bool {
case opHash:
tr.Hash()
case opCommit:
- root, nodes := tr.Commit(true)
+ root, nodes, _ := tr.Commit(true)
if nodes != nil {
- triedb.Update(NewWithNodeSet(nodes))
+ triedb.Update(root, origin, trienode.NewWithNodeSet(nodes))
}
- newtr, err := NewAccountTrie(TrieID(root), triedb)
+ newtr, err := New(TrieID(root), triedb)
if err != nil {
rt[i].err = err
- return false
+ return err
}
if nodes != nil {
if err := verifyAccessList(origTrie, newtr, nodes); err != nil {
rt[i].err = err
- return false
+ return err
}
}
tr = newtr
origTrie = tr.Copy()
+ origin = root
case opItercheckhash:
checktr := NewEmpty(triedb)
- it := NewIterator(tr.NodeIterator(nil))
+ it := NewIterator(tr.MustNodeIterator(nil))
for it.Next() {
- checktr.Update(it.Key, it.Value)
+ checktr.MustUpdate(it.Key, it.Value)
}
if tr.Hash() != checktr.Hash() {
rt[i].err = fmt.Errorf("hash mismatch in opItercheckhash")
}
case opNodeDiff:
var (
- origIter = origTrie.NodeIterator(nil)
- curIter = tr.NodeIterator(nil)
+ origIter = origTrie.MustNodeIterator(nil)
+ curIter = tr.MustNodeIterator(nil)
origSeen = make(map[string]struct{})
curSeen = make(map[string]struct{})
)
@@ -561,14 +613,14 @@ func runRandTest(rt randTest) bool {
}
// Abort the test on error.
if rt[i].err != nil {
- return false
+ return rt[i].err
}
}
- return true
+ return nil
}
func TestRandom(t *testing.T) {
- if err := quick.Check(runRandTest, nil); err != nil {
+ if err := quick.Check(runRandTestBool, nil); err != nil {
if cerr, ok := err.(*quick.CheckError); ok {
t.Fatalf("random test iteration %d failed: %s", cerr.Count, spew.Sdump(cerr.In))
}
@@ -576,42 +628,126 @@ func TestRandom(t *testing.T) {
}
}
+func BenchmarkGet(b *testing.B) { benchGet(b) }
+func BenchmarkUpdateBE(b *testing.B) { benchUpdate(b, binary.BigEndian) }
+func BenchmarkUpdateLE(b *testing.B) { benchUpdate(b, binary.LittleEndian) }
+
+const benchElemCount = 20000
+
+func benchGet(b *testing.B) {
+ triedb := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
+ trie := NewEmpty(triedb)
+ k := make([]byte, 32)
+ for i := 0; i < benchElemCount; i++ {
+ binary.LittleEndian.PutUint64(k, uint64(i))
+ v := make([]byte, 32)
+ binary.LittleEndian.PutUint64(v, uint64(i))
+ trie.MustUpdate(k, v)
+ }
+ binary.LittleEndian.PutUint64(k, benchElemCount/2)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ trie.MustGet(k)
+ }
+ b.StopTimer()
+}
+
+func benchUpdate(b *testing.B, e binary.ByteOrder) *Trie {
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+ k := make([]byte, 32)
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ v := make([]byte, 32)
+ e.PutUint64(k, uint64(i))
+ e.PutUint64(v, uint64(i))
+ trie.MustUpdate(k, v)
+ }
+ return trie
+}
+
+// Benchmarks the trie hashing. Since the trie caches the result of any operation,
+// we cannot use b.N as the number of hashing rounds, since all rounds apart from
+// the first one will be NOOP. As such, we'll use b.N as the number of account to
+// insert into the trie before measuring the hashing.
+// BenchmarkHash-6 288680 4561 ns/op 682 B/op 9 allocs/op
+// BenchmarkHash-6 275095 4800 ns/op 685 B/op 9 allocs/op
+// pure hasher:
+// BenchmarkHash-6 319362 4230 ns/op 675 B/op 9 allocs/op
+// BenchmarkHash-6 257460 4674 ns/op 689 B/op 9 allocs/op
+// With hashing in-between and pure hasher:
+// BenchmarkHash-6 225417 7150 ns/op 982 B/op 12 allocs/op
+// BenchmarkHash-6 220378 6197 ns/op 983 B/op 12 allocs/op
+// same with old hasher
+// BenchmarkHash-6 229758 6437 ns/op 981 B/op 12 allocs/op
+// BenchmarkHash-6 212610 7137 ns/op 986 B/op 12 allocs/op
+func BenchmarkHash(b *testing.B) {
+ // Create a realistic account trie to hash. We're first adding and hashing N
+ // entries, then adding N more.
+ addresses, accounts := makeAccounts(2 * b.N)
+ // Insert the accounts into the trie and hash it
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+ i := 0
+ for ; i < len(addresses)/2; i++ {
+ trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i])
+ }
+ trie.Hash()
+ for ; i < len(addresses); i++ {
+ trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i])
+ }
+ b.ResetTimer()
+ b.ReportAllocs()
+ //trie.hashRoot(nil, nil)
+ trie.Hash()
+}
+
+// Benchmarks the trie Commit following a Hash. Since the trie caches the result of any operation,
+// we cannot use b.N as the number of hashing rounds, since all rounds apart from
+// the first one will be NOOP. As such, we'll use b.N as the number of account to
+// insert into the trie before measuring the hashing.
+func BenchmarkCommitAfterHash(b *testing.B) {
+ b.Run("no-onleaf", func(b *testing.B) {
+ benchmarkCommitAfterHash(b, false)
+ })
+ b.Run("with-onleaf", func(b *testing.B) {
+ benchmarkCommitAfterHash(b, true)
+ })
+}
+
+func benchmarkCommitAfterHash(b *testing.B, collectLeaf bool) {
+ // Make the random benchmark deterministic
+ addresses, accounts := makeAccounts(b.N)
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+ for i := 0; i < len(addresses); i++ {
+ trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i])
+ }
+ // Insert the accounts into the trie and hash it
+ trie.Hash()
+ b.ResetTimer()
+ b.ReportAllocs()
+ trie.Commit(collectLeaf)
+}
+
func TestTinyTrie(t *testing.T) {
// Create a realistic account trie to hash
_, accounts := makeAccounts(5)
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
-
- type testCase struct {
- key, account []byte
- root common.Hash
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+ trie.MustUpdate(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001337"), accounts[3])
+ if exp, root := common.HexToHash("8c6a85a4d9fda98feff88450299e574e5378e32391f75a055d470ac0653f1005"), trie.Hash(); exp != root {
+ t.Errorf("1: got %x, exp %x", root, exp)
}
-
- cases := []testCase{
- {
- common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001337"),
- accounts[3],
- common.HexToHash("8c6a85a4d9fda98feff88450299e574e5378e32391f75a055d470ac0653f1005"),
- }, {
- common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001338"),
- accounts[4],
- common.HexToHash("ec63b967e98a5720e7f720482151963982890d82c9093c0d486b7eb8883a66b1"),
- }, {
- common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001339"),
- accounts[4],
- common.HexToHash("0608c1d1dc3905fa22204c7a0e43644831c3b6d3def0f274be623a948197e64a"),
- },
+ trie.MustUpdate(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001338"), accounts[4])
+ if exp, root := common.HexToHash("ec63b967e98a5720e7f720482151963982890d82c9093c0d486b7eb8883a66b1"), trie.Hash(); exp != root {
+ t.Errorf("2: got %x, exp %x", root, exp)
}
- for i, c := range cases {
- trie.Update(c.key, c.account)
- root := trie.Hash()
- if root != c.root {
- t.Errorf("case %d: got %x, exp %x", i, root, c.root)
- }
+ trie.MustUpdate(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001339"), accounts[4])
+ if exp, root := common.HexToHash("0608c1d1dc3905fa22204c7a0e43644831c3b6d3def0f274be623a948197e64a"), trie.Hash(); exp != root {
+ t.Errorf("3: got %x, exp %x", root, exp)
}
- checktr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
- it := NewIterator(trie.NodeIterator(nil))
+ checktr := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+ it := NewIterator(trie.MustNodeIterator(nil))
for it.Next() {
- checktr.Update(it.Key, it.Value)
+ checktr.MustUpdate(it.Key, it.Value)
}
if troot, itroot := trie.Hash(), checktr.Hash(); troot != itroot {
t.Fatalf("hash mismatch in opItercheckhash, trie: %x, check: %x", troot, itroot)
@@ -621,9 +757,9 @@ func TestTinyTrie(t *testing.T) {
func TestCommitAfterHash(t *testing.T) {
// Create a realistic account trie to hash
addresses, accounts := makeAccounts(1000)
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
for i := 0; i < len(addresses); i++ {
- trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i])
+ trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i])
}
// Insert the accounts into the trie and hash it
trie.Hash()
@@ -633,7 +769,7 @@ func TestCommitAfterHash(t *testing.T) {
if exp != root {
t.Errorf("got %x, exp %x", root, exp)
}
- root, _ = trie.Commit(false)
+ root, _, _ = trie.Commit(false)
if exp != root {
t.Errorf("got %x, exp %x", root, exp)
}
@@ -663,23 +799,391 @@ func makeAccounts(size int) (addresses [][20]byte, accounts [][]byte) {
numBytes := random.Uint32() % 33 // [0, 32] bytes
balanceBytes := make([]byte, numBytes)
random.Read(balanceBytes)
- balance := new(big.Int).SetBytes(balanceBytes)
+ balance := new(uint256.Int).SetBytes(balanceBytes)
data, _ := rlp.EncodeToBytes(&types.StateAccount{Nonce: nonce, Balance: balance, Root: root, CodeHash: code})
accounts[i] = data
}
return addresses, accounts
}
+// spongeDb is a dummy db backend which accumulates writes in a sponge
+type spongeDb struct {
+ sponge hash.Hash
+ id string
+ journal []string
+ keys []string
+ values map[string]string
+}
+
+func (s *spongeDb) Has(key []byte) (bool, error) { panic("implement me") }
+func (s *spongeDb) Get(key []byte) ([]byte, error) { return nil, errors.New("no such elem") }
+func (s *spongeDb) Delete(key []byte) error { panic("implement me") }
+func (s *spongeDb) NewBatch() ethdb.Batch { return &spongeBatch{s} }
+func (s *spongeDb) NewBatchWithSize(size int) ethdb.Batch { return &spongeBatch{s} }
+func (s *spongeDb) NewSnapshot() (ethdb.Snapshot, error) { panic("implement me") }
+func (s *spongeDb) Stat(property string) (string, error) { panic("implement me") }
+func (s *spongeDb) Compact(start []byte, limit []byte) error { panic("implement me") }
+func (s *spongeDb) Close() error { return nil }
+func (s *spongeDb) Put(key []byte, value []byte) error {
+ var (
+ keybrief = key
+ valbrief = value
+ )
+ if len(keybrief) > 8 {
+ keybrief = keybrief[:8]
+ }
+ if len(valbrief) > 8 {
+ valbrief = valbrief[:8]
+ }
+ s.journal = append(s.journal, fmt.Sprintf("%v: PUT([%x...], [%d bytes] %x...)\n", s.id, keybrief, len(value), valbrief))
+
+ if s.values == nil {
+ s.sponge.Write(key)
+ s.sponge.Write(value)
+ } else {
+ s.keys = append(s.keys, string(key))
+ s.values[string(key)] = string(value)
+ }
+ return nil
+}
+func (s *spongeDb) NewIterator(prefix []byte, start []byte) ethdb.Iterator { panic("implement me") }
+
+func (s *spongeDb) Flush() {
+ // Bottom-up, the longest path first
+ sort.Sort(sort.Reverse(sort.StringSlice(s.keys)))
+ for _, key := range s.keys {
+ s.sponge.Write([]byte(key))
+ s.sponge.Write([]byte(s.values[key]))
+ }
+}
+
+// spongeBatch is a dummy batch which immediately writes to the underlying spongedb
+type spongeBatch struct {
+ db *spongeDb
+}
+
+func (b *spongeBatch) Put(key, value []byte) error {
+ b.db.Put(key, value)
+ return nil
+}
+func (b *spongeBatch) Delete(key []byte) error { panic("implement me") }
+func (b *spongeBatch) ValueSize() int { return 100 }
+func (b *spongeBatch) Write() error { return nil }
+func (b *spongeBatch) Reset() {}
+func (b *spongeBatch) Replay(w ethdb.KeyValueWriter) error { return nil }
+
+// TestCommitSequence tests that the trie.Commit operation writes the elements of the trie
+// in the expected order.
+// The test data was based on the 'master' code, and is basically random. It can be used
+// to check whether changes to the trie modifies the write order or data in any way.
+func TestCommitSequence(t *testing.T) {
+ for i, tc := range []struct {
+ count int
+ expWriteSeqHash []byte
+ }{
+ {20, common.FromHex("330b0afae2853d96b9f015791fbe0fb7f239bf65f335f16dfc04b76c7536276d")},
+ {200, common.FromHex("5162b3735c06b5d606b043a3ee8adbdbbb408543f4966bca9dcc63da82684eeb")},
+ {2000, common.FromHex("4574cd8e6b17f3fe8ad89140d1d0bf4f1bd7a87a8ac3fb623b33550544c77635")},
+ } {
+ addresses, accounts := makeAccounts(tc.count)
+ // This spongeDb is used to check the sequence of disk-db-writes
+ s := &spongeDb{sponge: sha3.NewLegacyKeccak256()}
+ db := newTestDatabase(rawdb.NewDatabase(s), rawdb.HashScheme)
+ trie := NewEmpty(db)
+ // Fill the trie with elements
+ for i := 0; i < tc.count; i++ {
+ trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i])
+ }
+ // Flush trie -> database
+ root, nodes, _ := trie.Commit(false)
+ db.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
+ // Flush memdb -> disk (sponge)
+ db.Commit(root)
+ if got, exp := s.sponge.Sum(nil), tc.expWriteSeqHash; !bytes.Equal(got, exp) {
+ t.Errorf("test %d, disk write sequence wrong:\ngot %x exp %x\n", i, got, exp)
+ }
+ }
+}
+
+// TestCommitSequenceRandomBlobs is identical to TestCommitSequence
+// but uses random blobs instead of 'accounts'
+func TestCommitSequenceRandomBlobs(t *testing.T) {
+ for i, tc := range []struct {
+ count int
+ expWriteSeqHash []byte
+ }{
+ {20, common.FromHex("8016650c7a50cf88485fd06cde52d634a89711051107f00d21fae98234f2f13d")},
+ {200, common.FromHex("dde92ca9812e068e6982d04b40846dc65a61a9fd4996fc0f55f2fde172a8e13c")},
+ {2000, common.FromHex("ab553a7f9aff82e3929c382908e30ef7dd17a332933e92ba3fe873fc661ef382")},
+ } {
+ prng := rand.New(rand.NewSource(int64(i)))
+ // This spongeDb is used to check the sequence of disk-db-writes
+ s := &spongeDb{sponge: sha3.NewLegacyKeccak256()}
+ db := newTestDatabase(rawdb.NewDatabase(s), rawdb.HashScheme)
+ trie := NewEmpty(db)
+ // Fill the trie with elements
+ for i := 0; i < tc.count; i++ {
+ key := make([]byte, 32)
+ var val []byte
+ // 50% short elements, 50% large elements
+ if prng.Intn(2) == 0 {
+ val = make([]byte, 1+prng.Intn(32))
+ } else {
+ val = make([]byte, 1+prng.Intn(4096))
+ }
+ prng.Read(key)
+ prng.Read(val)
+ trie.MustUpdate(key, val)
+ }
+ // Flush trie -> database
+ root, nodes, _ := trie.Commit(false)
+ db.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
+ // Flush memdb -> disk (sponge)
+ db.Commit(root)
+ if got, exp := s.sponge.Sum(nil), tc.expWriteSeqHash; !bytes.Equal(got, exp) {
+ t.Fatalf("test %d, disk write sequence wrong:\ngot %x exp %x\n", i, got, exp)
+ }
+ }
+}
+
+func TestCommitSequenceStackTrie(t *testing.T) {
+ for count := 1; count < 200; count++ {
+ prng := rand.New(rand.NewSource(int64(count)))
+ // This spongeDb is used to check the sequence of disk-db-writes
+ s := &spongeDb{
+ sponge: sha3.NewLegacyKeccak256(),
+ id: "a",
+ values: make(map[string]string),
+ }
+ db := newTestDatabase(rawdb.NewDatabase(s), rawdb.HashScheme)
+ trie := NewEmpty(db)
+
+ // Another sponge is used for the stacktrie commits
+ stackTrieSponge := &spongeDb{
+ sponge: sha3.NewLegacyKeccak256(),
+ id: "b",
+ values: make(map[string]string),
+ }
+ options := NewStackTrieOptions()
+ options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
+ rawdb.WriteTrieNode(stackTrieSponge, common.Hash{}, path, hash, blob, db.Scheme())
+ })
+ stTrie := NewStackTrie(options)
+
+ // Fill the trie with elements
+ for i := 0; i < count; i++ {
+ // For the stack trie, we need to do inserts in proper order
+ key := make([]byte, 32)
+ binary.BigEndian.PutUint64(key, uint64(i))
+ var val []byte
+ // 50% short elements, 50% large elements
+ if prng.Intn(2) == 0 {
+ val = make([]byte, 1+prng.Intn(32))
+ } else {
+ val = make([]byte, 1+prng.Intn(1024))
+ }
+ prng.Read(val)
+ trie.Update(key, val)
+ stTrie.Update(key, val)
+ }
+ // Flush trie -> database
+ root, nodes, _ := trie.Commit(false)
+ // Flush memdb -> disk (sponge)
+ db.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
+ db.Commit(root)
+ s.Flush()
+
+ // And flush stacktrie -> disk
+ stRoot := stTrie.Commit()
+ if stRoot != root {
+ t.Fatalf("root wrong, got %x exp %x", stRoot, root)
+ }
+ stackTrieSponge.Flush()
+ if got, exp := stackTrieSponge.sponge.Sum(nil), s.sponge.Sum(nil); !bytes.Equal(got, exp) {
+ // Show the journal
+ t.Logf("Expected:")
+ for i, v := range s.journal {
+ t.Logf("op %d: %v", i, v)
+ }
+ t.Logf("Stacktrie:")
+ for i, v := range stackTrieSponge.journal {
+ t.Logf("op %d: %v", i, v)
+ }
+ t.Fatalf("test %d, disk write sequence wrong:\ngot %x exp %x\n", count, got, exp)
+ }
+ }
+}
+
+// TestCommitSequenceSmallRoot tests that a trie which is essentially only a
+// small (<32 byte) shortnode with an included value is properly committed to a
+// database.
+// This case might not matter, since in practice, all keys are 32 bytes, which means
+// that even a small trie which contains a leaf will have an extension making it
+// not fit into 32 bytes, rlp-encoded. However, it's still the correct thing to do.
+func TestCommitSequenceSmallRoot(t *testing.T) {
+ s := &spongeDb{
+ sponge: sha3.NewLegacyKeccak256(),
+ id: "a",
+ values: make(map[string]string),
+ }
+ db := newTestDatabase(rawdb.NewDatabase(s), rawdb.HashScheme)
+ trie := NewEmpty(db)
+
+ // Another sponge is used for the stacktrie commits
+ stackTrieSponge := &spongeDb{
+ sponge: sha3.NewLegacyKeccak256(),
+ id: "b",
+ values: make(map[string]string),
+ }
+ options := NewStackTrieOptions()
+ options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
+ rawdb.WriteTrieNode(stackTrieSponge, common.Hash{}, path, hash, blob, db.Scheme())
+ })
+ stTrie := NewStackTrie(options)
+
+ // Add a single small-element to the trie(s)
+ key := make([]byte, 5)
+ key[0] = 1
+ trie.Update(key, []byte{0x1})
+ stTrie.Update(key, []byte{0x1})
+
+ // Flush trie -> database
+ root, nodes, _ := trie.Commit(false)
+ // Flush memdb -> disk (sponge)
+ db.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
+ db.Commit(root)
+
+ // And flush stacktrie -> disk
+ stRoot := stTrie.Commit()
+ if stRoot != root {
+ t.Fatalf("root wrong, got %x exp %x", stRoot, root)
+ }
+ t.Logf("root: %x\n", stRoot)
+
+ s.Flush()
+ stackTrieSponge.Flush()
+ if got, exp := stackTrieSponge.sponge.Sum(nil), s.sponge.Sum(nil); !bytes.Equal(got, exp) {
+ t.Fatalf("test, disk write sequence wrong:\ngot %x exp %x\n", got, exp)
+ }
+}
+
+// BenchmarkHashFixedSize benchmarks the Hash (without Commit) of a fixed number of updates to a trie.
+// This benchmark is meant to capture the difference on efficiency of small versus large changes. Typically,
+// storage tries are small (a couple of entries), whereas the full post-block account trie update is large (a couple
+// of thousand entries)
+func BenchmarkHashFixedSize(b *testing.B) {
+ b.Run("10", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(20)
+ for i := 0; i < b.N; i++ {
+ benchmarkHashFixedSize(b, acc, add)
+ }
+ })
+ b.Run("100", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(100)
+ for i := 0; i < b.N; i++ {
+ benchmarkHashFixedSize(b, acc, add)
+ }
+ })
+
+ b.Run("1K", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(1000)
+ for i := 0; i < b.N; i++ {
+ benchmarkHashFixedSize(b, acc, add)
+ }
+ })
+ b.Run("10K", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(10000)
+ for i := 0; i < b.N; i++ {
+ benchmarkHashFixedSize(b, acc, add)
+ }
+ })
+ b.Run("100K", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(100000)
+ for i := 0; i < b.N; i++ {
+ benchmarkHashFixedSize(b, acc, add)
+ }
+ })
+}
+
+func benchmarkHashFixedSize(b *testing.B, addresses [][20]byte, accounts [][]byte) {
+ b.ReportAllocs()
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+ for i := 0; i < len(addresses); i++ {
+ trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i])
+ }
+ // Insert the accounts into the trie and hash it
+ b.StartTimer()
+ trie.Hash()
+ b.StopTimer()
+}
+
+func BenchmarkCommitAfterHashFixedSize(b *testing.B) {
+ b.Run("10", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(20)
+ for i := 0; i < b.N; i++ {
+ benchmarkCommitAfterHashFixedSize(b, acc, add)
+ }
+ })
+ b.Run("100", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(100)
+ for i := 0; i < b.N; i++ {
+ benchmarkCommitAfterHashFixedSize(b, acc, add)
+ }
+ })
+
+ b.Run("1K", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(1000)
+ for i := 0; i < b.N; i++ {
+ benchmarkCommitAfterHashFixedSize(b, acc, add)
+ }
+ })
+ b.Run("10K", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(10000)
+ for i := 0; i < b.N; i++ {
+ benchmarkCommitAfterHashFixedSize(b, acc, add)
+ }
+ })
+ b.Run("100K", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(100000)
+ for i := 0; i < b.N; i++ {
+ benchmarkCommitAfterHashFixedSize(b, acc, add)
+ }
+ })
+}
+
+func benchmarkCommitAfterHashFixedSize(b *testing.B, addresses [][20]byte, accounts [][]byte) {
+ b.ReportAllocs()
+ trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+ for i := 0; i < len(addresses); i++ {
+ trie.MustUpdate(crypto.Keccak256(addresses[i][:]), accounts[i])
+ }
+ // Insert the accounts into the trie and hash it
+ trie.Hash()
+ b.StartTimer()
+ trie.Commit(false)
+ b.StopTimer()
+}
+
func getString(trie *Trie, k string) []byte {
- return trie.Get([]byte(k))
+ return trie.MustGet([]byte(k))
}
func updateString(trie *Trie, k, v string) {
- trie.Update([]byte(k), []byte(v))
+ trie.MustUpdate([]byte(k), []byte(v))
}
func deleteString(trie *Trie, k string) {
- trie.Delete([]byte(k))
+ trie.MustDelete([]byte(k))
}
func TestDecodeNode(t *testing.T) {
@@ -695,3 +1199,17 @@ func TestDecodeNode(t *testing.T) {
decodeNode(hash, elems)
}
}
+
+func FuzzTrie(f *testing.F) {
+ f.Fuzz(func(t *testing.T, data []byte) {
+ var steps = 500
+ var input = bytes.NewReader(data)
+ var finishedFn = func() bool {
+ steps--
+ return steps < 0 || input.Len() == 0
+ }
+ if err := runRandTest(generateSteps(finishedFn, input)); err != nil {
+ t.Fatal(err)
+ }
+ })
+}
diff --git a/trie_by_cid/trie/util_test.go b/trie_by_cid/trie/util_test.go
deleted file mode 100644
index 7bebec8..0000000
--- a/trie_by_cid/trie/util_test.go
+++ /dev/null
@@ -1,241 +0,0 @@
-package trie
-
-import (
- "bytes"
- "context"
- "fmt"
- "math/big"
- "math/rand"
- "testing"
-
- "github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/core/rawdb"
- gethstate "github.com/ethereum/go-ethereum/core/state"
- "github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/ethdb"
- "github.com/ethereum/go-ethereum/rlp"
- gethtrie "github.com/ethereum/go-ethereum/trie"
- "github.com/jmoiron/sqlx"
-
- pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0"
- "github.com/cerc-io/plugeth-statediff/indexer/database/sql/postgres"
- "github.com/cerc-io/plugeth-statediff/test_helpers"
-
- "github.com/cerc-io/ipld-eth-statedb/internal"
- "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/helper"
-)
-
-var (
- dbConfig, _ = postgres.TestConfig.WithEnv()
- trieConfig = Config{Cache: 256}
-)
-
-type kvi struct {
- k []byte
- v int64
-}
-
-type kvMap map[string]*kvi
-
-type kvsi struct {
- k string
- v int64
-}
-
-// NewAccountTrie is a shortcut to create a trie using the StateTrieCodec (ie. IPLD MEthStateTrie codec).
-func NewAccountTrie(id *ID, db NodeReader) (*Trie, error) {
- return New(id, db, StateTrieCodec)
-}
-
-// makeTestTrie create a sample test trie to test node-wise reconstruction.
-func makeTestTrie(t testing.TB) (*Database, *StateTrie, map[string][]byte) {
- // Create an empty trie
- triedb := NewDatabase(rawdb.NewMemoryDatabase())
- trie, err := NewStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec)
- if err != nil {
- t.Fatal(err)
- }
-
- // Fill it with some arbitrary data
- content := make(map[string][]byte)
- for i := byte(0); i < 255; i++ {
- // Map the same data under multiple keys
- key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i}
- content[string(key)] = val
- trie.Update(key, val)
-
- key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i}
- content[string(key)] = val
- trie.Update(key, val)
-
- // Add some other data to inflate the trie
- for j := byte(3); j < 13; j++ {
- key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i}
- content[string(key)] = val
- trie.Update(key, val)
- }
- }
- root, nodes := trie.Commit(false)
- if err := triedb.Update(NewWithNodeSet(nodes)); err != nil {
- panic(fmt.Errorf("failed to commit db %v", err))
- }
- // Re-create the trie based on the new state
- trie, err = NewStateTrie(TrieID(root), triedb, StateTrieCodec)
- if err != nil {
- t.Fatal(err)
- }
- return triedb, trie, content
-}
-
-func forHashedNodes(tr *Trie) map[string][]byte {
- var (
- it = tr.NodeIterator(nil)
- nodes = make(map[string][]byte)
- )
- for it.Next(true) {
- if it.Hash() == (common.Hash{}) {
- continue
- }
- nodes[string(it.Path())] = common.CopyBytes(it.NodeBlob())
- }
- return nodes
-}
-
-func diffTries(trieA, trieB *Trie) (map[string][]byte, map[string][]byte, map[string][]byte) {
- var (
- nodesA = forHashedNodes(trieA)
- nodesB = forHashedNodes(trieB)
- inA = make(map[string][]byte) // hashed nodes in trie a but not b
- inB = make(map[string][]byte) // hashed nodes in trie b but not a
- both = make(map[string][]byte) // hashed nodes in both tries but different value
- )
- for path, blobA := range nodesA {
- if blobB, ok := nodesB[path]; ok {
- if bytes.Equal(blobA, blobB) {
- continue
- }
- both[path] = blobA
- continue
- }
- inA[path] = blobA
- }
- for path, blobB := range nodesB {
- if _, ok := nodesA[path]; ok {
- continue
- }
- inB[path] = blobB
- }
- return inA, inB, both
-}
-
-func packValue(val int64) []byte {
- acct := &types.StateAccount{
- Balance: big.NewInt(val),
- CodeHash: test_helpers.NullCodeHash.Bytes(),
- Root: test_helpers.EmptyContractRoot,
- }
- acct_rlp, err := rlp.EncodeToBytes(acct)
- if err != nil {
- panic(err)
- }
- return acct_rlp
-}
-
-func updateTrie(tr *gethtrie.Trie, vals []kvsi) (kvMap, error) {
- all := kvMap{}
- for _, val := range vals {
- all[string(val.k)] = &kvi{[]byte(val.k), val.v}
- tr.Update([]byte(val.k), packValue(val.v))
- }
- return all, nil
-}
-
-func commitTrie(t testing.TB, db *gethtrie.Database, tr *gethtrie.Trie) common.Hash {
- t.Helper()
- root, nodes := tr.Commit(false)
- if err := db.Update(gethtrie.NewWithNodeSet(nodes)); err != nil {
- t.Fatal(err)
- }
- if err := db.Commit(root, false); err != nil {
- t.Fatal(err)
- }
- return root
-}
-
-func makePgIpfsEthDB(t testing.TB) ethdb.Database {
- pg_db, err := postgres.ConnectSQLX(context.Background(), dbConfig)
- if err != nil {
- t.Fatal(err)
- }
- t.Cleanup(func() {
- if err := TearDownDB(pg_db); err != nil {
- t.Fatal(err)
- }
- })
- return pgipfsethdb.NewDatabase(pg_db, internal.MakeCacheConfig(t))
-}
-
-// commit a LevelDB state trie, index to IPLD and return new trie
-func indexTrie(t testing.TB, edb ethdb.Database, root common.Hash) *Trie {
- t.Helper()
- dbConfig.Driver = postgres.PGX
- err := helper.IndexStateDiff(dbConfig, gethstate.NewDatabase(edb), common.Hash{}, root)
- if err != nil {
- t.Fatal(err)
- }
-
- ipfs_db := makePgIpfsEthDB(t)
- tr, err := New(TrieID(root), NewDatabase(ipfs_db), StateTrieCodec)
- if err != nil {
- t.Fatal(err)
- }
- return tr
-}
-
-// generates a random Geth LevelDB trie of n key-value pairs and corresponding value map
-func randomGethTrie(n int, db *gethtrie.Database) (*gethtrie.Trie, kvMap) {
- trie := gethtrie.NewEmpty(db)
- var vals []*kvi
- for i := byte(0); i < 100; i++ {
- e := &kvi{common.LeftPadBytes([]byte{i}, 32), int64(i)}
- e2 := &kvi{common.LeftPadBytes([]byte{i + 10}, 32), int64(i)}
- vals = append(vals, e, e2)
- }
- for i := 0; i < n; i++ {
- k := randBytes(32)
- v := rand.Int63()
- vals = append(vals, &kvi{k, v})
- }
- all := kvMap{}
- for _, val := range vals {
- all[string(val.k)] = &kvi{[]byte(val.k), val.v}
- trie.Update([]byte(val.k), packValue(val.v))
- }
- return trie, all
-}
-
-// TearDownDB is used to tear down the watcher dbs after tests
-func TearDownDB(db *sqlx.DB) error {
- tx, err := db.Beginx()
- if err != nil {
- return err
- }
- statements := []string{
- `DELETE FROM nodes`,
- `DELETE FROM ipld.blocks`,
- `DELETE FROM eth.header_cids`,
- `DELETE FROM eth.uncle_cids`,
- `DELETE FROM eth.transaction_cids`,
- `DELETE FROM eth.receipt_cids`,
- `DELETE FROM eth.state_cids`,
- `DELETE FROM eth.storage_cids`,
- `DELETE FROM eth.log_cids`,
- `DELETE FROM eth_meta.watched_addresses`,
- }
- for _, stm := range statements {
- if _, err = tx.Exec(stm); err != nil {
- return fmt.Errorf("error executing `%s`: %w", stm, err)
- }
- }
- return tx.Commit()
-}
diff --git a/trie_by_cid/triedb/database.go b/trie_by_cid/triedb/database.go
new file mode 100644
index 0000000..fc3a519
--- /dev/null
+++ b/trie_by_cid/triedb/database.go
@@ -0,0 +1,339 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package triedb
+
+import (
+ "errors"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/log"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/triedb/database"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/triedb/hashdb"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/triedb/pathdb"
+)
+
+// Config defines all necessary options for database.
+type Config struct {
+ Preimages bool // Flag whether the preimage of node key is recorded
+ IsVerkle bool // Flag whether the db is holding a verkle tree
+ HashDB *hashdb.Config // Configs for hash-based scheme
+ PathDB *pathdb.Config // Configs for experimental path-based scheme
+}
+
+// HashDefaults represents a config for using hash-based scheme with
+// default settings.
+var HashDefaults = &Config{
+ Preimages: false,
+ HashDB: hashdb.Defaults,
+}
+
+// backend defines the methods needed to access/update trie nodes in different
+// state scheme.
+type backend interface {
+ // Scheme returns the identifier of used storage scheme.
+ Scheme() string
+
+ // Initialized returns an indicator if the state data is already initialized
+ // according to the state scheme.
+ Initialized(genesisRoot common.Hash) bool
+
+ // Size returns the current storage size of the diff layers on top of the
+ // disk layer and the storage size of the nodes cached in the disk layer.
+ //
+ // For hash scheme, there is no differentiation between diff layer nodes
+ // and dirty disk layer nodes, so both are merged into the second return.
+ Size() (common.StorageSize, common.StorageSize)
+
+ // Update performs a state transition by committing dirty nodes contained
+ // in the given set in order to update state from the specified parent to
+ // the specified root.
+ //
+ // The passed in maps(nodes, states) will be retained to avoid copying
+ // everything. Therefore, these maps must not be changed afterwards.
+ Update(root common.Hash, parent common.Hash, block uint64, nodes *trienode.MergedNodeSet, states *triestate.Set) error
+
+ // Commit writes all relevant trie nodes belonging to the specified state
+ // to disk. Report specifies whether logs will be displayed in info level.
+ Commit(root common.Hash, report bool) error
+
+ // Close closes the trie database backend and releases all held resources.
+ Close() error
+}
+
+// Database is the wrapper of the underlying backend which is shared by different
+// types of node backend as an entrypoint. It's responsible for all interactions
+// relevant with trie nodes and node preimages.
+type Database struct {
+ config *Config // Configuration for trie database
+ diskdb ethdb.Database // Persistent database to store the snapshot
+ preimages *preimageStore // The store for caching preimages
+ backend backend // The backend for managing trie nodes
+}
+
+// NewDatabase initializes the trie database with default settings, note
+// the legacy hash-based scheme is used by default.
+func NewDatabase(diskdb ethdb.Database, config *Config) *Database {
+ // Sanitize the config and use the default one if it's not specified.
+ if config == nil {
+ config = HashDefaults
+ }
+ var preimages *preimageStore
+ if config.Preimages {
+ preimages = newPreimageStore(diskdb)
+ }
+ db := &Database{
+ config: config,
+ diskdb: diskdb,
+ preimages: preimages,
+ }
+ if config.HashDB != nil && config.PathDB != nil {
+ log.Crit("Both 'hash' and 'path' mode are configured")
+ }
+ if config.PathDB != nil {
+ db.backend = pathdb.New(diskdb, config.PathDB)
+ } else {
+ var resolver hashdb.ChildResolver
+ if config.IsVerkle {
+ // TODO define verkle resolver
+ log.Crit("Verkle node resolver is not defined")
+ } else {
+ resolver = trie.MerkleResolver{}
+ }
+ db.backend = hashdb.New(diskdb, config.HashDB, resolver)
+ }
+ return db
+}
+
+// Reader returns a reader for accessing all trie nodes with provided state root.
+// An error will be returned if the requested state is not available.
+func (db *Database) Reader(blockRoot common.Hash) (database.Reader, error) {
+ switch b := db.backend.(type) {
+ case *hashdb.Database:
+ return b.Reader(blockRoot)
+ case *pathdb.Database:
+ return b.Reader(blockRoot)
+ }
+ return nil, errors.New("unknown backend")
+}
+
+// Update performs a state transition by committing dirty nodes contained in the
+// given set in order to update state from the specified parent to the specified
+// root. The held pre-images accumulated up to this point will be flushed in case
+// the size exceeds the threshold.
+//
+// The passed in maps(nodes, states) will be retained to avoid copying everything.
+// Therefore, these maps must not be changed afterwards.
+func (db *Database) Update(root common.Hash, parent common.Hash, block uint64, nodes *trienode.MergedNodeSet, states *triestate.Set) error {
+ if db.preimages != nil {
+ db.preimages.commit(false)
+ }
+ return db.backend.Update(root, parent, block, nodes, states)
+}
+
+// Commit iterates over all the children of a particular node, writes them out
+// to disk. As a side effect, all pre-images accumulated up to this point are
+// also written.
+func (db *Database) Commit(root common.Hash, report bool) error {
+ if db.preimages != nil {
+ db.preimages.commit(true)
+ }
+ return db.backend.Commit(root, report)
+}
+
+// Size returns the storage size of diff layer nodes above the persistent disk
+// layer, the dirty nodes buffered within the disk layer, and the size of cached
+// preimages.
+func (db *Database) Size() (common.StorageSize, common.StorageSize, common.StorageSize) {
+ var (
+ diffs, nodes common.StorageSize
+ preimages common.StorageSize
+ )
+ diffs, nodes = db.backend.Size()
+ if db.preimages != nil {
+ preimages = db.preimages.size()
+ }
+ return diffs, nodes, preimages
+}
+
+// Initialized returns an indicator if the state data is already initialized
+// according to the state scheme.
+func (db *Database) Initialized(genesisRoot common.Hash) bool {
+ return db.backend.Initialized(genesisRoot)
+}
+
+// Scheme returns the node scheme used in the database.
+func (db *Database) Scheme() string {
+ return db.backend.Scheme()
+}
+
+// Close flushes the dangling preimages to disk and closes the trie database.
+// It is meant to be called when closing the blockchain object, so that all
+// resources held can be released correctly.
+func (db *Database) Close() error {
+ db.WritePreimages()
+ return db.backend.Close()
+}
+
+// WritePreimages flushes all accumulated preimages to disk forcibly.
+func (db *Database) WritePreimages() {
+ if db.preimages != nil {
+ // db.preimages.commit(true)
+ }
+}
+
+// Preimage retrieves a cached trie node pre-image from preimage store.
+func (db *Database) Preimage(hash common.Hash) []byte {
+ if db.preimages == nil {
+ return nil
+ }
+ return db.preimages.preimage(hash)
+}
+
+// InsertPreimage writes pre-images of trie node to the preimage store.
+func (db *Database) InsertPreimage(preimages map[common.Hash][]byte) {
+ if db.preimages == nil {
+ return
+ }
+ db.preimages.insertPreimage(preimages)
+}
+
+// Cap iteratively flushes old but still referenced trie nodes until the total
+// memory usage goes below the given threshold. The held pre-images accumulated
+// up to this point will be flushed in case the size exceeds the threshold.
+//
+// It's only supported by hash-based database and will return an error for others.
+func (db *Database) Cap(limit common.StorageSize) error {
+ hdb, ok := db.backend.(*hashdb.Database)
+ if !ok {
+ return errors.New("not supported")
+ }
+ if db.preimages != nil {
+ // db.preimages.commit(false)
+ }
+ return hdb.Cap(limit)
+}
+
+// Reference adds a new reference from a parent node to a child node. This function
+// is used to add reference between internal trie node and external node(e.g. storage
+// trie root), all internal trie nodes are referenced together by database itself.
+//
+// It's only supported by hash-based database and will return an error for others.
+func (db *Database) Reference(root common.Hash, parent common.Hash) error {
+ hdb, ok := db.backend.(*hashdb.Database)
+ if !ok {
+ return errors.New("not supported")
+ }
+ hdb.Reference(root, parent)
+ return nil
+}
+
+// Dereference removes an existing reference from a root node. It's only
+// supported by hash-based database and will return an error for others.
+func (db *Database) Dereference(root common.Hash) error {
+ hdb, ok := db.backend.(*hashdb.Database)
+ if !ok {
+ return errors.New("not supported")
+ }
+ hdb.Dereference(root)
+ return nil
+}
+
+// Recover rollbacks the database to a specified historical point. The state is
+// supported as the rollback destination only if it's canonical state and the
+// corresponding trie histories are existent. It's only supported by path-based
+// database and will return an error for others.
+func (db *Database) Recover(target common.Hash) error {
+ pdb, ok := db.backend.(*pathdb.Database)
+ if !ok {
+ return errors.New("not supported")
+ }
+ var loader triestate.TrieLoader
+ if db.config.IsVerkle {
+ // TODO define verkle loader
+ log.Crit("Verkle loader is not defined")
+ } else {
+ loader = trie.NewMerkleLoader(db)
+ }
+ return pdb.Recover(target, loader)
+}
+
+// Recoverable returns the indicator if the specified state is enabled to be
+// recovered. It's only supported by path-based database and will return an
+// error for others.
+func (db *Database) Recoverable(root common.Hash) (bool, error) {
+ pdb, ok := db.backend.(*pathdb.Database)
+ if !ok {
+ return false, errors.New("not supported")
+ }
+ return pdb.Recoverable(root), nil
+}
+
+// Disable deactivates the database and invalidates all available state layers
+// as stale to prevent access to the persistent state, which is in the syncing
+// stage.
+//
+// It's only supported by path-based database and will return an error for others.
+func (db *Database) Disable() error {
+ pdb, ok := db.backend.(*pathdb.Database)
+ if !ok {
+ return errors.New("not supported")
+ }
+ return pdb.Disable()
+}
+
+// Enable activates database and resets the state tree with the provided persistent
+// state root once the state sync is finished.
+func (db *Database) Enable(root common.Hash) error {
+ pdb, ok := db.backend.(*pathdb.Database)
+ if !ok {
+ return errors.New("not supported")
+ }
+ return pdb.Enable(root)
+}
+
+// Journal commits an entire diff hierarchy to disk into a single journal entry.
+// This is meant to be used during shutdown to persist the snapshot without
+// flattening everything down (bad for reorgs). It's only supported by path-based
+// database and will return an error for others.
+func (db *Database) Journal(root common.Hash) error {
+ pdb, ok := db.backend.(*pathdb.Database)
+ if !ok {
+ return errors.New("not supported")
+ }
+ return pdb.Journal(root)
+}
+
+// SetBufferSize sets the node buffer size to the provided value(in bytes).
+// It's only supported by path-based database and will return an error for
+// others.
+func (db *Database) SetBufferSize(size int) error {
+ pdb, ok := db.backend.(*pathdb.Database)
+ if !ok {
+ return errors.New("not supported")
+ }
+ return pdb.SetBufferSize(size)
+}
+
+// IsVerkle returns the indicator if the database is holding a verkle tree.
+func (db *Database) IsVerkle() bool {
+ return db.config.IsVerkle
+}
diff --git a/trie_by_cid/triedb/database/database.go b/trie_by_cid/triedb/database/database.go
new file mode 100644
index 0000000..18a8f45
--- /dev/null
+++ b/trie_by_cid/triedb/database/database.go
@@ -0,0 +1,48 @@
+// Copyright 2024 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package database
+
+import (
+ "github.com/ethereum/go-ethereum/common"
+)
+
+// Reader wraps the Node method of a backing trie reader.
+type Reader interface {
+ // Node retrieves the trie node blob with the provided trie identifier,
+ // node path and the corresponding node hash. No error will be returned
+ // if the node is not found.
+ Node(owner common.Hash, path []byte, hash common.Hash) ([]byte, error)
+}
+
+// PreimageStore wraps the methods of a backing store for reading and writing
+// trie node preimages.
+type PreimageStore interface {
+ // Preimage retrieves the preimage of the specified hash.
+ Preimage(hash common.Hash) []byte
+
+ // InsertPreimage commits a set of preimages along with their hashes.
+ InsertPreimage(preimages map[common.Hash][]byte)
+}
+
+// Database wraps the methods of a backing trie store.
+type Database interface {
+ PreimageStore
+
+ // Reader returns a node reader associated with the specific state.
+ // An error will be returned if the specified state is not available.
+ Reader(stateRoot common.Hash) (Reader, error)
+}
diff --git a/trie_by_cid/triedb/hashdb/database.go b/trie_by_cid/triedb/hashdb/database.go
new file mode 100644
index 0000000..bd3a191
--- /dev/null
+++ b/trie_by_cid/triedb/hashdb/database.go
@@ -0,0 +1,665 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package hashdb
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+ "sync"
+ "time"
+
+ "github.com/VictoriaMetrics/fastcache"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/metrics"
+ "github.com/ethereum/go-ethereum/rlp"
+
+ "github.com/cerc-io/ipld-eth-statedb/internal"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
+)
+
+// Meters and timers tracking clean/dirty cache hit rates and the cost of
+// flush, GC and commit operations on the hash database.
+var (
+ memcacheCleanHitMeter = metrics.NewRegisteredMeter("hashdb/memcache/clean/hit", nil)
+ memcacheCleanMissMeter = metrics.NewRegisteredMeter("hashdb/memcache/clean/miss", nil)
+ memcacheCleanReadMeter = metrics.NewRegisteredMeter("hashdb/memcache/clean/read", nil)
+ memcacheCleanWriteMeter = metrics.NewRegisteredMeter("hashdb/memcache/clean/write", nil)
+
+ memcacheDirtyHitMeter = metrics.NewRegisteredMeter("hashdb/memcache/dirty/hit", nil)
+ memcacheDirtyMissMeter = metrics.NewRegisteredMeter("hashdb/memcache/dirty/miss", nil)
+ memcacheDirtyReadMeter = metrics.NewRegisteredMeter("hashdb/memcache/dirty/read", nil)
+ memcacheDirtyWriteMeter = metrics.NewRegisteredMeter("hashdb/memcache/dirty/write", nil)
+
+ memcacheFlushTimeTimer = metrics.NewRegisteredResettingTimer("hashdb/memcache/flush/time", nil)
+ memcacheFlushNodesMeter = metrics.NewRegisteredMeter("hashdb/memcache/flush/nodes", nil)
+ memcacheFlushBytesMeter = metrics.NewRegisteredMeter("hashdb/memcache/flush/bytes", nil)
+
+ memcacheGCTimeTimer = metrics.NewRegisteredResettingTimer("hashdb/memcache/gc/time", nil)
+ memcacheGCNodesMeter = metrics.NewRegisteredMeter("hashdb/memcache/gc/nodes", nil)
+ memcacheGCBytesMeter = metrics.NewRegisteredMeter("hashdb/memcache/gc/bytes", nil)
+
+ memcacheCommitTimeTimer = metrics.NewRegisteredResettingTimer("hashdb/memcache/commit/time", nil)
+ memcacheCommitNodesMeter = metrics.NewRegisteredMeter("hashdb/memcache/commit/nodes", nil)
+ memcacheCommitBytesMeter = metrics.NewRegisteredMeter("hashdb/memcache/commit/bytes", nil)
+)
+
+// ChildResolver defines the required method to decode the provided
+// trie node and iterate the children on top.
+type ChildResolver interface {
+ // ForEach invokes onChild for every child hash embedded in the node blob.
+ ForEach(node []byte, onChild func(common.Hash))
+}
+
+// Config contains the settings for database.
+type Config struct {
+ CleanCacheSize int // Maximum memory allowance (in bytes) for caching clean nodes
+}
+
+// Defaults is the default setting for database if it's not specified.
+// Notably, clean cache is disabled explicitly.
+var Defaults = &Config{
+ // Explicitly set clean cache size to 0 to avoid creating fastcache,
+ // otherwise database must be closed when it's no longer needed to
+ // prevent memory leak.
+ CleanCacheSize: 0,
+}
+
+// Database is an intermediate write layer between the trie data structures and
+// the disk database. The aim is to accumulate trie writes in-memory and only
+// periodically flush a couple tries to disk, garbage collecting the remainder.
+type Database struct {
+ diskdb ethdb.Database // Persistent storage for matured trie nodes
+ resolver ChildResolver // The handler to resolve children of nodes
+
+ cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs
+ dirties map[common.Hash]*cachedNode // Data and references relationships of dirty trie nodes
+ oldest common.Hash // Oldest tracked node, flush-list head
+ newest common.Hash // Newest tracked node, flush-list tail
+
+ gctime time.Duration // Time spent on garbage collection since last commit
+ gcnodes uint64 // Nodes garbage collected since last commit
+ gcsize common.StorageSize // Data storage garbage collected since last commit
+
+ flushtime time.Duration // Time spent on data flushing since last commit
+ flushnodes uint64 // Nodes flushed since last commit
+ flushsize common.StorageSize // Data storage flushed since last commit
+
+ dirtiesSize common.StorageSize // Storage size of the dirty node cache (exc. metadata)
+ childrenSize common.StorageSize // Storage size of the external children tracking
+
+ lock sync.RWMutex // Guards dirties, flush-list pointers and size counters
+}
+
+// cachedNode is all the information we know about a single cached trie node
+// in the memory database write layer.
+type cachedNode struct {
+ node []byte // Encoded node blob, immutable
+ parents uint32 // Number of live nodes referencing this one
+ external map[common.Hash]struct{} // The set of external children
+ flushPrev common.Hash // Previous node in the flush-list
+ flushNext common.Hash // Next node in the flush-list
+}
+
+// cachedNodeSize is the raw size of a cachedNode data structure without any
+// node data included. It's an approximate size, but should be a lot better
+// than not counting them.
+var cachedNodeSize = int(reflect.TypeOf(cachedNode{}).Size())
+
+// forChildren invokes the callback for all the tracked children of this node,
+// both the implicit ones from inside the node as well as the explicit ones
+// from outside the node.
+func (n *cachedNode) forChildren(resolver ChildResolver, onChild func(hash common.Hash)) {
+ for child := range n.external {
+ onChild(child)
+ }
+ resolver.ForEach(n.node, onChild)
+}
+
+// New initializes the hash-based node database. A nil config falls back to
+// Defaults (no clean cache); the clean cache is only created when a positive
+// size is configured.
+func New(diskdb ethdb.Database, config *Config, resolver ChildResolver) *Database {
+ if config == nil {
+ config = Defaults
+ }
+ var cleans *fastcache.Cache
+ if config.CleanCacheSize > 0 {
+ cleans = fastcache.New(config.CleanCacheSize)
+ }
+ return &Database{
+ diskdb: diskdb,
+ resolver: resolver,
+ cleans: cleans,
+ dirties: make(map[common.Hash]*cachedNode),
+ }
+}
+
+// insert inserts a trie node into the memory database. All nodes inserted by
+// this function will be reference tracked. This function assumes the lock is
+// already held.
+func (db *Database) insert(hash common.Hash, node []byte) {
+ // If the node's already cached, skip
+ if _, ok := db.dirties[hash]; ok {
+ return
+ }
+ memcacheDirtyWriteMeter.Mark(int64(len(node)))
+
+ // Create the cached entry for this node
+ entry := &cachedNode{
+ node: node,
+ flushPrev: db.newest,
+ }
+ // Bump the parent counter of any child already in the dirty set.
+ entry.forChildren(db.resolver, func(child common.Hash) {
+ if c := db.dirties[child]; c != nil {
+ c.parents++
+ }
+ })
+ db.dirties[hash] = entry
+
+ // Update the flush-list endpoints: append to the tail, or start a new
+ // list if this is the first dirty node.
+ if db.oldest == (common.Hash{}) {
+ db.oldest, db.newest = hash, hash
+ } else {
+ db.dirties[db.newest].flushNext, db.newest = hash, hash
+ }
+ db.dirtiesSize += common.StorageSize(common.HashLength + len(node))
+}
+
+// node retrieves an encoded cached trie node from memory. If it cannot be found
+// cached, the method queries the persistent database for the content. Unlike
+// upstream go-ethereum, disk lookups are keyed by the CID derived from the
+// keccak hash and the given codec (IPLD-backed store).
+func (db *Database) node(hash common.Hash, codec uint64) ([]byte, error) {
+ // It doesn't make sense to retrieve the metaroot
+ if hash == (common.Hash{}) {
+ return nil, errors.New("not found")
+ }
+ // Retrieve the node from the clean cache if available
+ if db.cleans != nil {
+ if enc := db.cleans.Get(nil, hash[:]); enc != nil {
+ memcacheCleanHitMeter.Mark(1)
+ memcacheCleanReadMeter.Mark(int64(len(enc)))
+ return enc, nil
+ }
+ }
+ // Retrieve the node from the dirty cache if available.
+ db.lock.RLock()
+ dirty := db.dirties[hash]
+ db.lock.RUnlock()
+
+ // Return the cached node if it's found in the dirty set.
+ // The dirty.node field is immutable and safe to read it
+ // even without lock guard.
+ if dirty != nil {
+ memcacheDirtyHitMeter.Mark(1)
+ memcacheDirtyReadMeter.Mark(int64(len(dirty.node)))
+ return dirty.node, nil
+ }
+ memcacheDirtyMissMeter.Mark(1)
+
+ // Content unavailable in memory, attempt to retrieve from disk
+ cid, err := internal.Keccak256ToCid(codec, hash[:])
+ if err != nil {
+ return nil, err
+ }
+ enc, err := db.diskdb.Get(cid.Bytes())
+ if err != nil {
+ return nil, err
+ }
+ if len(enc) != 0 {
+ if db.cleans != nil {
+ db.cleans.Set(hash[:], enc)
+ memcacheCleanMissMeter.Mark(1)
+ memcacheCleanWriteMeter.Mark(int64(len(enc)))
+ }
+ return enc, nil
+ }
+ return nil, errors.New("not found")
+}
+
+// Reference adds a new reference from a parent node to a child node.
+// This function is used to add reference between internal trie node
+// and external node(e.g. storage trie root), all internal trie nodes
+// are referenced together by database itself.
+func (db *Database) Reference(child common.Hash, parent common.Hash) {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ db.reference(child, parent)
+}
+
+// reference is the private locked version of Reference.
+// NOTE(review): a non-zero parent is assumed to already be present in the
+// dirty set; db.dirties[parent] is dereferenced without a nil check — confirm
+// callers uphold this invariant (matches upstream go-ethereum behavior).
+func (db *Database) reference(child common.Hash, parent common.Hash) {
+ // If the node does not exist, it's a node pulled from disk, skip
+ node, ok := db.dirties[child]
+ if !ok {
+ return
+ }
+ // The reference is for state root, increase the reference counter.
+ if parent == (common.Hash{}) {
+ node.parents += 1
+ return
+ }
+ // The reference is for external storage trie, don't duplicate if
+ // the reference is already existent.
+ if db.dirties[parent].external == nil {
+ db.dirties[parent].external = make(map[common.Hash]struct{})
+ }
+ if _, ok := db.dirties[parent].external[child]; ok {
+ return
+ }
+ node.parents++
+ db.dirties[parent].external[child] = struct{}{}
+ db.childrenSize += common.HashLength
+}
+
+// Dereference removes an existing reference from a root node, garbage
+// collecting any dirty nodes whose reference count drops to zero and
+// updating the GC statistics/metrics accordingly.
+func (db *Database) Dereference(root common.Hash) {
+ // Sanity check to ensure that the meta-root is not removed
+ if root == (common.Hash{}) {
+ log.Error("Attempted to dereference the trie cache meta root")
+ return
+ }
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ nodes, storage, start := len(db.dirties), db.dirtiesSize, time.Now()
+ db.dereference(root)
+
+ db.gcnodes += uint64(nodes - len(db.dirties))
+ db.gcsize += storage - db.dirtiesSize
+ db.gctime += time.Since(start)
+
+ memcacheGCTimeTimer.Update(time.Since(start))
+ memcacheGCBytesMeter.Mark(int64(storage - db.dirtiesSize))
+ memcacheGCNodesMeter.Mark(int64(nodes - len(db.dirties)))
+
+ log.Debug("Dereferenced trie from memory database", "nodes", nodes-len(db.dirties), "size", storage-db.dirtiesSize, "time", time.Since(start),
+ "gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.dirties), "livesize", db.dirtiesSize)
+}
+
+// dereference is the private locked version of Dereference. It decrements the
+// node's reference count and, on reaching zero, unlinks it from the flush-list,
+// recursively dereferences its children and drops it from the dirty cache.
+func (db *Database) dereference(hash common.Hash) {
+ // If the node does not exist, it's a previously committed node.
+ node, ok := db.dirties[hash]
+ if !ok {
+ return
+ }
+ // If there are no more references to the node, delete it and cascade
+ if node.parents > 0 {
+ // This is a special cornercase where a node loaded from disk (i.e. not in the
+ // memcache any more) gets reinjected as a new node (short node split into full,
+ // then reverted into short), causing a cached node to have no parents. That is
+ // no problem in itself, but don't make maxint parents out of it.
+ node.parents--
+ }
+ if node.parents == 0 {
+ // Remove the node from the flush-list
+ switch hash {
+ case db.oldest:
+ db.oldest = node.flushNext
+ if node.flushNext != (common.Hash{}) {
+ db.dirties[node.flushNext].flushPrev = common.Hash{}
+ }
+ case db.newest:
+ db.newest = node.flushPrev
+ if node.flushPrev != (common.Hash{}) {
+ db.dirties[node.flushPrev].flushNext = common.Hash{}
+ }
+ default:
+ db.dirties[node.flushPrev].flushNext = node.flushNext
+ db.dirties[node.flushNext].flushPrev = node.flushPrev
+ }
+ // Dereference all children and delete the node
+ node.forChildren(db.resolver, func(child common.Hash) {
+ db.dereference(child)
+ })
+ delete(db.dirties, hash)
+ db.dirtiesSize -= common.StorageSize(common.HashLength + len(node.node))
+ if node.external != nil {
+ db.childrenSize -= common.StorageSize(len(node.external) * common.HashLength)
+ }
+ }
+}
+
+// Cap iteratively flushes old but still referenced trie nodes until the total
+// memory usage goes below the given threshold.
+//
+// NOTE(review): flushed nodes are persisted under the legacy hash key
+// (rawdb.WriteLegacyTrieNode) while node() reads from disk by CID — confirm
+// flushed entries remain readable through this database, or that Cap is never
+// exercised in this read-oriented fork.
+func (db *Database) Cap(limit common.StorageSize) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ // Create a database batch to flush persistent data out. It is important that
+ // outside code doesn't see an inconsistent state (referenced data removed from
+ // memory cache during commit but not yet in persistent storage). This is ensured
+ // by only uncaching existing data when the database write finalizes.
+ batch := db.diskdb.NewBatch()
+ nodes, storage, start := len(db.dirties), db.dirtiesSize, time.Now()
+
+ // db.dirtiesSize only contains the useful data in the cache, but when reporting
+ // the total memory consumption, the maintenance metadata is also needed to be
+ // counted.
+ size := db.dirtiesSize + common.StorageSize(len(db.dirties)*cachedNodeSize)
+ size += db.childrenSize
+
+ // Keep committing nodes from the flush-list until we're below allowance
+ oldest := db.oldest
+ for size > limit && oldest != (common.Hash{}) {
+ // Fetch the oldest referenced node and push into the batch
+ node := db.dirties[oldest]
+ rawdb.WriteLegacyTrieNode(batch, oldest, node.node)
+
+ // If we exceeded the ideal batch size, commit and reset
+ if batch.ValueSize() >= ethdb.IdealBatchSize {
+ if err := batch.Write(); err != nil {
+ log.Error("Failed to write flush list to disk", "err", err)
+ return err
+ }
+ batch.Reset()
+ }
+ // Iterate to the next flush item, or abort if the size cap was achieved. Size
+ // is the total size, including the useful cached data (hash -> blob), the
+ // cache item metadata, as well as external children mappings.
+ size -= common.StorageSize(common.HashLength + len(node.node) + cachedNodeSize)
+ if node.external != nil {
+ size -= common.StorageSize(len(node.external) * common.HashLength)
+ }
+ oldest = node.flushNext
+ }
+ // Flush out any remainder data from the last batch
+ if err := batch.Write(); err != nil {
+ log.Error("Failed to write flush list to disk", "err", err)
+ return err
+ }
+ // Write successful, clear out the flushed data
+ for db.oldest != oldest {
+ node := db.dirties[db.oldest]
+ delete(db.dirties, db.oldest)
+ db.oldest = node.flushNext
+
+ db.dirtiesSize -= common.StorageSize(common.HashLength + len(node.node))
+ if node.external != nil {
+ db.childrenSize -= common.StorageSize(len(node.external) * common.HashLength)
+ }
+ }
+ if db.oldest != (common.Hash{}) {
+ db.dirties[db.oldest].flushPrev = common.Hash{}
+ }
+ db.flushnodes += uint64(nodes - len(db.dirties))
+ db.flushsize += storage - db.dirtiesSize
+ db.flushtime += time.Since(start)
+
+ memcacheFlushTimeTimer.Update(time.Since(start))
+ memcacheFlushBytesMeter.Mark(int64(storage - db.dirtiesSize))
+ memcacheFlushNodesMeter.Mark(int64(nodes - len(db.dirties)))
+
+ log.Debug("Persisted nodes from memory database", "nodes", nodes-len(db.dirties), "size", storage-db.dirtiesSize, "time", time.Since(start),
+ "flushnodes", db.flushnodes, "flushsize", db.flushsize, "flushtime", db.flushtime, "livenodes", len(db.dirties), "livesize", db.dirtiesSize)
+
+ return nil
+}
+
+// Commit iterates over all the children of a particular node, writes them out
+// to disk, forcefully tearing down all references in both directions. As a side
+// effect, all pre-images accumulated up to this point are also written.
+func (db *Database) Commit(node common.Hash, report bool) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ // Create a database batch to flush persistent data out. It is important that
+ // outside code doesn't see an inconsistent state (referenced data removed from
+ // memory cache during commit but not yet in persistent storage). This is ensured
+ // by only uncaching existing data when the database write finalizes.
+ start := time.Now()
+ batch := db.diskdb.NewBatch()
+
+ // Move the trie itself into the batch, flushing if enough data is accumulated
+ nodes, storage := len(db.dirties), db.dirtiesSize
+
+ uncacher := &cleaner{db}
+ if err := db.commit(node, batch, uncacher); err != nil {
+ log.Error("Failed to commit trie from trie database", "err", err)
+ return err
+ }
+ // Trie mostly committed to disk, flush any batch leftovers
+ if err := batch.Write(); err != nil {
+ log.Error("Failed to write trie to disk", "err", err)
+ return err
+ }
+ // Uncache any leftovers in the last batch
+ if err := batch.Replay(uncacher); err != nil {
+ return err
+ }
+ batch.Reset()
+
+ // Reset the storage counters and bumped metrics
+ memcacheCommitTimeTimer.Update(time.Since(start))
+ memcacheCommitBytesMeter.Mark(int64(storage - db.dirtiesSize))
+ memcacheCommitNodesMeter.Mark(int64(nodes - len(db.dirties)))
+
+ logger := log.Info
+ if !report {
+ logger = log.Debug
+ }
+ logger("Persisted trie from memory database", "nodes", nodes-len(db.dirties)+int(db.flushnodes), "size", storage-db.dirtiesSize+db.flushsize, "time", time.Since(start)+db.flushtime,
+ "gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.dirties), "livesize", db.dirtiesSize)
+
+ // Reset the garbage collection statistics
+ db.gcnodes, db.gcsize, db.gctime = 0, 0, 0
+ db.flushnodes, db.flushsize, db.flushtime = 0, 0, 0
+
+ return nil
+}
+
+// commit is the private locked version of Commit. Children are committed
+// depth-first before the node itself, so the on-disk trie is always
+// internally consistent.
+func (db *Database) commit(hash common.Hash, batch ethdb.Batch, uncacher *cleaner) error {
+ // If the node does not exist, it's a previously committed node
+ node, ok := db.dirties[hash]
+ if !ok {
+ return nil
+ }
+ var err error
+
+ // Dereference all children and delete the node
+ node.forChildren(db.resolver, func(child common.Hash) {
+ if err == nil {
+ err = db.commit(child, batch, uncacher)
+ }
+ })
+ if err != nil {
+ return err
+ }
+ // If we've reached an optimal batch size, commit and start over
+ rawdb.WriteLegacyTrieNode(batch, hash, node.node)
+ if batch.ValueSize() >= ethdb.IdealBatchSize {
+ if err := batch.Write(); err != nil {
+ return err
+ }
+ err := batch.Replay(uncacher)
+ if err != nil {
+ return err
+ }
+ batch.Reset()
+ }
+ return nil
+}
+
+// cleaner is a database batch replayer that takes a batch of write operations
+// and cleans up the trie database from anything written to disk.
+type cleaner struct {
+ db *Database // Owning database whose dirty cache gets pruned on replay
+}
+
+// Put reacts to database writes and implements dirty data uncaching. This is the
+// post-processing step of a commit operation where the already persisted trie is
+// removed from the dirty cache and moved into the clean cache. The reason behind
+// the two-phase commit is to ensure data availability while moving from memory
+// to disk.
+func (c *cleaner) Put(key []byte, rlp []byte) error {
+ hash := common.BytesToHash(key)
+
+ // If the node does not exist, we're done on this path
+ node, ok := c.db.dirties[hash]
+ if !ok {
+ return nil
+ }
+ // Node still exists, remove it from the flush-list
+ switch hash {
+ case c.db.oldest:
+ c.db.oldest = node.flushNext
+ if node.flushNext != (common.Hash{}) {
+ c.db.dirties[node.flushNext].flushPrev = common.Hash{}
+ }
+ case c.db.newest:
+ c.db.newest = node.flushPrev
+ if node.flushPrev != (common.Hash{}) {
+ c.db.dirties[node.flushPrev].flushNext = common.Hash{}
+ }
+ default:
+ c.db.dirties[node.flushPrev].flushNext = node.flushNext
+ c.db.dirties[node.flushNext].flushPrev = node.flushPrev
+ }
+ // Remove the node from the dirty cache
+ delete(c.db.dirties, hash)
+ c.db.dirtiesSize -= common.StorageSize(common.HashLength + len(node.node))
+ if node.external != nil {
+ c.db.childrenSize -= common.StorageSize(len(node.external) * common.HashLength)
+ }
+ // Move the flushed node into the clean cache to prevent insta-reloads
+ if c.db.cleans != nil {
+ c.db.cleans.Set(hash[:], rlp)
+ memcacheCleanWriteMeter.Mark(int64(len(rlp)))
+ }
+ return nil
+}
+
+// Delete is unsupported: commit batches only ever contain Put operations,
+// so a replayed Delete indicates a programming error.
+func (c *cleaner) Delete(key []byte) error {
+ panic("not implemented")
+}
+
+// Initialized returns an indicator if state data is already initialized
+// in hash-based scheme by checking the presence of genesis state.
+// NOTE(review): this probes the legacy hash-keyed schema, whereas node()
+// reads by CID — confirm which schema the backing store actually uses here.
+func (db *Database) Initialized(genesisRoot common.Hash) bool {
+ return rawdb.HasLegacyTrieNode(db.diskdb, genesisRoot)
+}
+
+// Update inserts the dirty nodes in provided nodeset into database and link the
+// account trie with multiple storage tries if necessary.
+func (db *Database) Update(root common.Hash, parent common.Hash, block uint64, nodes *trienode.MergedNodeSet, states *triestate.Set) error {
+ // Ensure the parent state is present and signal a warning if not.
+ if parent != types.EmptyRootHash {
+ if blob, _ := db.node(parent, internal.StateTrieCodec); len(blob) == 0 {
+ log.Error("parent state is not present")
+ }
+ }
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ // Insert dirty nodes into the database. In the same tree, it must be
+ // ensured that children are inserted first, then parent so that children
+ // can be linked with their parent correctly.
+ //
+ // Note, the storage tries must be flushed before the account trie to
+ // retain the invariant that children go into the dirty cache first.
+ var order []common.Hash
+ for owner := range nodes.Sets {
+ if owner == (common.Hash{}) {
+ continue
+ }
+ order = append(order, owner)
+ }
+ // The account trie (zero owner hash) is processed last, after all
+ // storage tries.
+ if _, ok := nodes.Sets[common.Hash{}]; ok {
+ order = append(order, common.Hash{})
+ }
+ for _, owner := range order {
+ subset := nodes.Sets[owner]
+ subset.ForEachWithOrder(func(path string, n *trienode.Node) {
+ if n.IsDeleted() {
+ return // ignore deletion
+ }
+ db.insert(n.Hash, n.Blob)
+ })
+ }
+ // Link up the account trie and storage trie if the node points
+ // to an account trie leaf.
+ if set, present := nodes.Sets[common.Hash{}]; present {
+ for _, n := range set.Leaves {
+ var account types.StateAccount
+ if err := rlp.DecodeBytes(n.Blob, &account); err != nil {
+ return err
+ }
+ if account.Root != types.EmptyRootHash {
+ db.reference(account.Root, n.Parent)
+ }
+ }
+ }
+ return nil
+}
+
+// Size returns the current storage size of the memory cache in front of the
+// persistent database layer.
+//
+// The first return will always be 0, representing the memory stored in unbounded
+// diff layers above the dirty cache. This is only available in pathdb.
+func (db *Database) Size() (common.StorageSize, common.StorageSize) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ // db.dirtiesSize only contains the useful data in the cache, but when reporting
+ // the total memory consumption, the maintenance metadata is also needed to be
+ // counted.
+ var metadataSize = common.StorageSize(len(db.dirties) * cachedNodeSize)
+ return 0, db.dirtiesSize + db.childrenSize + metadataSize
+}
+
+// Close closes the trie database and releases all held resources. The clean
+// cache is reset and dropped so fastcache memory can be reclaimed.
+func (db *Database) Close() error {
+ if db.cleans != nil {
+ db.cleans.Reset()
+ db.cleans = nil
+ }
+ return nil
+}
+
+// Scheme returns the node scheme used in the database.
+func (db *Database) Scheme() string {
+ return rawdb.HashScheme
+}
+
+// Reader retrieves a node reader belonging to the given state root.
+// An error will be returned if the requested state is not available
+// (i.e. the root node cannot be resolved via the state trie codec).
+func (db *Database) Reader(root common.Hash) (*reader, error) {
+ if _, err := db.node(root, internal.StateTrieCodec); err != nil {
+ return nil, fmt.Errorf("state %#x is not available, %v", root, err)
+ }
+ return &reader{db: db}, nil
+}
+
+// reader is a state reader of Database which implements the Reader interface.
+type reader struct {
+ db *Database // Backing hash database to resolve nodes from
+}
+
+// Node retrieves the trie node with the given node hash. No error will be
+// returned if the node is not found; lookup failures yield a nil blob.
+func (reader *reader) Node(owner common.Hash, path []byte, hash common.Hash) ([]byte, error) {
+ // this is an account node iff the owner hash is zero
+ codec := internal.StateTrieCodec
+ if owner != (common.Hash{}) {
+ codec = internal.StorageTrieCodec
+ }
+ blob, _ := reader.db.node(hash, codec)
+ return blob, nil
+}
diff --git a/trie_by_cid/triedb/pathdb/database.go b/trie_by_cid/triedb/pathdb/database.go
new file mode 100644
index 0000000..2c9dc05
--- /dev/null
+++ b/trie_by_cid/triedb/pathdb/database.go
@@ -0,0 +1,485 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package pathdb
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "sync"
+ "time"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/params"
+)
+
+const (
+ // maxDiffLayers is the maximum diff layers allowed in the layer tree.
+ maxDiffLayers = 128
+
+ // defaultCleanSize is the default memory allowance of clean cache.
+ defaultCleanSize = 16 * 1024 * 1024
+
+ // maxBufferSize is the maximum memory allowance of node buffer.
+ // Too large nodebuffer will cause the system to pause for a long
+ // time when write happens. Also, the largest batch that pebble can
+ // support is 4GB, node will panic if batch size exceeds this limit.
+ maxBufferSize = 256 * 1024 * 1024
+
+ // DefaultBufferSize is the default memory allowance of node buffer
+ // that aggregates the writes from above until it's flushed into the
+ // disk. It's meant to be used once the initial sync is finished.
+ // Do not increase the buffer size arbitrarily, otherwise the system
+ // pause time will increase when the database writes happen.
+ DefaultBufferSize = 64 * 1024 * 1024
+)
+
+// layer is the interface implemented by all state layers which includes some
+// public methods and some additional methods for internal usage.
+type layer interface {
+ // Node retrieves the trie node with the node info. An error will be returned
+ // if the read operation exits abnormally. For example, if the layer is already
+ // stale, or the associated state is regarded as corrupted. Notably, no error
+ // will be returned if the requested node is not found in database.
+ Node(owner common.Hash, path []byte, hash common.Hash) ([]byte, error)
+
+ // rootHash returns the root hash for which this layer was made.
+ rootHash() common.Hash
+
+ // stateID returns the associated state id of layer.
+ stateID() uint64
+
+ // parentLayer returns the subsequent layer of it, or nil if the disk was reached.
+ parentLayer() layer
+
+ // update creates a new layer on top of the existing layer diff tree with
+ // the provided dirty trie nodes along with the state change set.
+ //
+ // Note, the maps are retained by the method to avoid copying everything.
+ update(root common.Hash, id uint64, block uint64, nodes map[common.Hash]map[string]*trienode.Node, states *triestate.Set) *diffLayer
+
+ // journal commits an entire diff hierarchy to disk into a single journal entry.
+ // This is meant to be used during shutdown to persist the layer without
+ // flattening everything down (bad for reorgs).
+ journal(w io.Writer) error
+}
+
+// Config contains the settings for database.
+type Config struct {
+ StateHistory uint64 // Number of recent blocks to maintain state history for
+ CleanCacheSize int // Maximum memory allowance (in bytes) for caching clean nodes
+ DirtyCacheSize int // Maximum memory allowance (in bytes) for caching dirty nodes
+ ReadOnly bool // Flag whether the database is opened in read only mode.
+}
+
+// sanitize checks the provided user configurations and changes anything that's
+// unreasonable or unworkable.
+func (c *Config) sanitize() *Config {
+ conf := *c
+ if conf.DirtyCacheSize > maxBufferSize {
+ log.Warn("Sanitizing invalid node buffer size", "provided", common.StorageSize(conf.DirtyCacheSize), "updated", common.StorageSize(maxBufferSize))
+ conf.DirtyCacheSize = maxBufferSize
+ }
+ return &conf
+}
+
+// Defaults contains default settings for Ethereum mainnet.
+var Defaults = &Config{
+ StateHistory: params.FullImmutabilityThreshold,
+ CleanCacheSize: defaultCleanSize,
+ DirtyCacheSize: DefaultBufferSize,
+}
+
+// ReadOnly is the config in order to open database in read only mode.
+var ReadOnly = &Config{ReadOnly: true}
+
+// Database is a multiple-layered structure for maintaining in-memory trie nodes.
+// It consists of one persistent base layer backed by a key-value store, on top
+// of which arbitrarily many in-memory diff layers are stacked. The memory diffs
+// can form a tree with branching, but the disk layer is singleton and common to
+// all. If a reorg goes deeper than the disk layer, a batch of reverse diffs can
+// be applied to rollback. The deepest reorg that can be handled depends on the
+// amount of state histories tracked in the disk.
+//
+// At most one readable and writable database can be opened at the same time in
+// the whole system which ensures that only one database writer can operate disk
+// state. Unexpected open operations can cause the system to panic.
+type Database struct {
+ // readOnly is the flag whether the mutation is allowed to be applied.
+ // It will be set automatically when the database is journaled during
+ // the shutdown to reject all following unexpected mutations.
+ readOnly bool // Flag if database is opened in read only mode
+ waitSync bool // Flag if database is deactivated due to initial state sync
+ bufferSize int // Memory allowance (in bytes) for caching dirty nodes
+ config *Config // Configuration for database
+ diskdb ethdb.Database // Persistent storage for matured trie nodes
+ tree *layerTree // The group for all known layers
+ freezer *rawdb.ResettableFreezer // Freezer for storing trie histories, nil possible in tests
+ lock sync.RWMutex // Lock to prevent mutations from happening at the same time
+}
+
+// New attempts to load an already existing layer from a persistent key-value
+// store (with a number of memory layers from a journal). If the journal is not
+// matched with the base persistent layer, all the recorded diff layers are discarded.
+func New(diskdb ethdb.Database, config *Config) *Database {
+ if config == nil {
+ config = Defaults
+ }
+ config = config.sanitize()
+
+ db := &Database{
+ readOnly: config.ReadOnly,
+ bufferSize: config.DirtyCacheSize,
+ config: config,
+ diskdb: diskdb,
+ }
+ // Construct the layer tree by resolving the in-disk singleton state
+ // and in-memory layer journal.
+ db.tree = newLayerTree(db.loadLayers())
+
+ // Open the freezer for state history if the passed database contains an
+ // ancient store. Otherwise, all the relevant functionalities are disabled.
+ //
+ // Because the freezer can only be opened once at the same time, this
+ // mechanism also ensures that at most one **non-readOnly** database
+ // is opened at the same time to prevent accidental mutation.
+ if ancient, err := diskdb.AncientDatadir(); err == nil && ancient != "" && !db.readOnly {
+ freezer, err := rawdb.NewStateFreezer(ancient, false)
+ if err != nil {
+ log.Crit("Failed to open state history freezer", "err", err)
+ }
+ db.freezer = freezer
+
+ diskLayerID := db.tree.bottom().stateID()
+ if diskLayerID == 0 {
+ // Reset the entire state histories in case the trie database is
+ // not initialized yet, as these state histories are not expected.
+ frozen, err := db.freezer.Ancients()
+ if err != nil {
+ log.Crit("Failed to retrieve head of state history", "err", err)
+ }
+ if frozen != 0 {
+ err := db.freezer.Reset()
+ if err != nil {
+ log.Crit("Failed to reset state histories", "err", err)
+ }
+ log.Info("Truncated extraneous state history")
+ }
+ } else {
+ // Truncate the extra state histories above in freezer in case
+ // it's not aligned with the disk layer.
+ pruned, err := truncateFromHead(db.diskdb, freezer, diskLayerID)
+ if err != nil {
+ log.Crit("Failed to truncate extra state histories", "err", err)
+ }
+ if pruned != 0 {
+ log.Warn("Truncated extra state histories", "number", pruned)
+ }
+ }
+ }
+ // Disable database in case node is still in the initial state sync stage.
+ if rawdb.ReadSnapSyncStatusFlag(diskdb) == rawdb.StateSyncRunning && !db.readOnly {
+ if err := db.Disable(); err != nil {
+ log.Crit("Failed to disable database", "err", err) // impossible to happen
+ }
+ }
+ log.Warn("Path-based state scheme is an experimental feature")
+ return db
+}
+
+// Reader retrieves a layer belonging to the given state root.
+func (db *Database) Reader(root common.Hash) (layer, error) {
+ l := db.tree.get(root)
+ if l == nil {
+ return nil, fmt.Errorf("state %#x is not available", root)
+ }
+ return l, nil
+}
+
+// Update adds a new layer into the tree, if that can be linked to an existing
+// old parent. It is disallowed to insert a disk layer (the origin of all). Apart
+// from that this function will flatten the extra diff layers at bottom into disk
+// to only keep 128 diff layers in memory by default.
+//
+// The passed in maps(nodes, states) will be retained to avoid copying everything.
+// Therefore, these maps must not be changed afterwards.
+func (db *Database) Update(root common.Hash, parentRoot common.Hash, block uint64, nodes *trienode.MergedNodeSet, states *triestate.Set) error {
+ // Hold the lock to prevent concurrent mutations.
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ // Short circuit if the mutation is not allowed.
+ if err := db.modifyAllowed(); err != nil {
+ return err
+ }
+ if err := db.tree.add(root, parentRoot, block, nodes, states); err != nil {
+ return err
+ }
+ // Keep 128 diff layers in the memory, persistent layer is 129th.
+ // - head layer is paired with HEAD state
+ // - head-1 layer is paired with HEAD-1 state
+ // - head-127 layer(bottom-most diff layer) is paired with HEAD-127 state
+ // - head-128 layer(disk layer) is paired with HEAD-128 state
+ return db.tree.cap(root, maxDiffLayers)
+}
+
+// Commit traverses downwards the layer tree from a specified layer with the
+// provided state root and all the layers below are flattened downwards. It
+// can be used alone and mostly for test purposes.
+func (db *Database) Commit(root common.Hash, report bool) error {
+ // Hold the lock to prevent concurrent mutations.
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ // Short circuit if the mutation is not allowed.
+ if err := db.modifyAllowed(); err != nil {
+ return err
+ }
+ return db.tree.cap(root, 0)
+}
+
+// Disable deactivates the database and invalidates all available state layers
+// as stale to prevent access to the persistent state, which is in the syncing
+// stage.
+func (db *Database) Disable() error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ // Short circuit if the database is in read only mode.
+ if db.readOnly {
+ return errDatabaseReadOnly
+ }
+ // Prevent duplicated disable operation.
+ if db.waitSync {
+ log.Error("Reject duplicated disable operation")
+ return nil
+ }
+ db.waitSync = true
+
+ // Mark the disk layer as stale to prevent access to persistent state.
+ db.tree.bottom().markStale()
+
+ // Write the initial sync flag to persist it across restarts.
+ rawdb.WriteSnapSyncStatusFlag(db.diskdb, rawdb.StateSyncRunning)
+ log.Info("Disabled trie database due to state sync")
+ return nil
+}
+
+// Enable activates database and resets the state tree with the provided persistent
+// state root once the state sync is finished.
+func (db *Database) Enable(root common.Hash) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ // Short circuit if the database is in read only mode.
+ if db.readOnly {
+ return errDatabaseReadOnly
+ }
+ // Ensure the provided state root matches the stored one.
+ root = types.TrieRootHash(root)
+ _, stored := rawdb.ReadAccountTrieNode(db.diskdb, nil)
+ if stored != root {
+ return fmt.Errorf("state root mismatch: stored %x, synced %x", stored, root)
+ }
+ // Drop the stale state journal in persistent database and
+ // reset the persistent state id back to zero.
+ batch := db.diskdb.NewBatch()
+ rawdb.DeleteTrieJournal(batch)
+ rawdb.WritePersistentStateID(batch, 0)
+ if err := batch.Write(); err != nil {
+ return err
+ }
+ // Clean up all state histories in freezer. Theoretically
+ // all root->id mappings should be removed as well. Since
+ // mappings can be huge and might take a while to clear
+ // them, just leave them in disk and wait for overwriting.
+ if db.freezer != nil {
+ if err := db.freezer.Reset(); err != nil {
+ return err
+ }
+ }
+ // Re-construct a new disk layer backed by persistent state
+ // with **empty clean cache and node buffer**.
+ db.tree.reset(newDiskLayer(root, 0, db, nil, newNodeBuffer(db.bufferSize, nil, 0)))
+
+ // Re-enable the database as the final step.
+ db.waitSync = false
+ rawdb.WriteSnapSyncStatusFlag(db.diskdb, rawdb.StateSyncFinished)
+ log.Info("Rebuilt trie database", "root", root)
+ return nil
+}
+
+// Recover rolls back the database to a specified historical point.
+// The state is supported as the rollback destination only if it's
+// canonical state and the corresponding trie histories are existent.
+func (db *Database) Recover(root common.Hash, loader triestate.TrieLoader) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ // Short circuit if rollback operation is not supported.
+ if err := db.modifyAllowed(); err != nil {
+ return err
+ }
+ if db.freezer == nil {
+ return errors.New("state rollback is non-supported")
+ }
+ // Short circuit if the target state is not recoverable.
+ root = types.TrieRootHash(root)
+ if !db.Recoverable(root) {
+ return errStateUnrecoverable
+ }
+ // Apply the state histories upon the disk layer in order.
+ var (
+ start = time.Now()
+ dl = db.tree.bottom()
+ )
+ for dl.rootHash() != root {
+ h, err := readHistory(db.freezer, dl.stateID())
+ if err != nil {
+ return err
+ }
+ dl, err = dl.revert(h, loader)
+ if err != nil {
+ return err
+ }
+ // reset layer with newly created disk layer. It must be
+ // done after each revert operation, otherwise the new
+ // disk layer won't be accessible from outside.
+ db.tree.reset(dl)
+ }
+ rawdb.DeleteTrieJournal(db.diskdb)
+ _, err := truncateFromHead(db.diskdb, db.freezer, dl.stateID())
+ if err != nil {
+ return err
+ }
+ log.Debug("Recovered state", "root", root, "elapsed", common.PrettyDuration(time.Since(start)))
+ return nil
+}
+
+// Recoverable returns the indicator if the specified state is recoverable.
+func (db *Database) Recoverable(root common.Hash) bool {
+ // Ensure the requested state is a known state.
+ root = types.TrieRootHash(root)
+ id := rawdb.ReadStateID(db.diskdb, root)
+ if id == nil {
+ return false
+ }
+ // Recoverable state must be below the disk layer. The recoverable
+ // state only refers to the state that is currently not available,
+ // but can be restored by applying state history.
+ dl := db.tree.bottom()
+ if *id >= dl.stateID() {
+ return false
+ }
+ // Ensure the requested state is a canonical state and all state
+ // histories in range [id+1, disklayer.ID] are present and complete.
+ parent := root
+ return checkHistories(db.freezer, *id+1, dl.stateID()-*id, func(m *meta) error {
+ if m.parent != parent {
+ return errors.New("unexpected state history")
+ }
+ if len(m.incomplete) > 0 {
+ return errors.New("incomplete state history")
+ }
+ parent = m.root
+ return nil
+ }) == nil
+}
+
+// Close closes the trie database and the held freezer.
+func (db *Database) Close() error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ // Set the database to read-only mode to prevent all
+ // following mutations.
+ db.readOnly = true
+
+ // Release the memory held by clean cache.
+ db.tree.bottom().resetCache()
+
+ // Close the attached state history freezer.
+ if db.freezer == nil {
+ return nil
+ }
+ return db.freezer.Close()
+}
+
+// Size returns the current storage size of the memory cache in front of the
+// persistent database layer.
+func (db *Database) Size() (diffs common.StorageSize, nodes common.StorageSize) {
+ db.tree.forEach(func(layer layer) {
+ if diff, ok := layer.(*diffLayer); ok {
+ diffs += common.StorageSize(diff.memory)
+ }
+ if disk, ok := layer.(*diskLayer); ok {
+ nodes += disk.size()
+ }
+ })
+ return diffs, nodes
+}
+
+// Initialized returns an indicator if the state data is already
+// initialized in path-based scheme.
+func (db *Database) Initialized(genesisRoot common.Hash) bool {
+ var inited bool
+ db.tree.forEach(func(layer layer) {
+ if layer.rootHash() != types.EmptyRootHash {
+ inited = true
+ }
+ })
+ if !inited {
+ inited = rawdb.ReadSnapSyncStatusFlag(db.diskdb) != rawdb.StateSyncUnknown
+ }
+ return inited
+}
+
+// SetBufferSize sets the node buffer size to the provided value(in bytes).
+func (db *Database) SetBufferSize(size int) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ if size > maxBufferSize {
+ log.Info("Capped node buffer size", "provided", common.StorageSize(size), "adjusted", common.StorageSize(maxBufferSize))
+ size = maxBufferSize
+ }
+ db.bufferSize = size
+ return db.tree.bottom().setBufferSize(db.bufferSize)
+}
+
+// Scheme returns the node scheme used in the database.
+func (db *Database) Scheme() string {
+ return rawdb.PathScheme
+}
+
+// modifyAllowed returns the indicator if mutation is allowed. This function
+// assumes the db.lock is already held.
+func (db *Database) modifyAllowed() error {
+ if db.readOnly {
+ return errDatabaseReadOnly
+ }
+ if db.waitSync {
+ return errDatabaseWaitSync
+ }
+ return nil
+}
diff --git a/trie_by_cid/triedb/pathdb/database_test.go b/trie_by_cid/triedb/pathdb/database_test.go
new file mode 100644
index 0000000..78068d9
--- /dev/null
+++ b/trie_by_cid/triedb/pathdb/database_test.go
@@ -0,0 +1,608 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package pathdb
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "math/rand"
+ "testing"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/testutil"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/rlp"
+ "github.com/holiman/uint256"
+)
+
+func updateTrie(addrHash common.Hash, root common.Hash, dirties, cleans map[common.Hash][]byte) (common.Hash, *trienode.NodeSet) {
+ h, err := newTestHasher(addrHash, root, cleans)
+ if err != nil {
+ panic(fmt.Errorf("failed to create hasher, err: %w", err))
+ }
+ for key, val := range dirties {
+ if len(val) == 0 {
+ h.Delete(key.Bytes())
+ } else {
+ h.Update(key.Bytes(), val)
+ }
+ }
+ root, nodes, _ := h.Commit(false)
+ return root, nodes
+}
+
+func generateAccount(storageRoot common.Hash) types.StateAccount {
+ return types.StateAccount{
+ Nonce: uint64(rand.Intn(100)),
+ Balance: uint256.NewInt(rand.Uint64()),
+ CodeHash: testutil.RandBytes(32),
+ Root: storageRoot,
+ }
+}
+
+const (
+ createAccountOp int = iota
+ modifyAccountOp
+ deleteAccountOp
+ opLen
+)
+
+type genctx struct {
+ accounts map[common.Hash][]byte
+ storages map[common.Hash]map[common.Hash][]byte
+ accountOrigin map[common.Address][]byte
+ storageOrigin map[common.Address]map[common.Hash][]byte
+ nodes *trienode.MergedNodeSet
+}
+
+func newCtx() *genctx {
+ return &genctx{
+ accounts: make(map[common.Hash][]byte),
+ storages: make(map[common.Hash]map[common.Hash][]byte),
+ accountOrigin: make(map[common.Address][]byte),
+ storageOrigin: make(map[common.Address]map[common.Hash][]byte),
+ nodes: trienode.NewMergedNodeSet(),
+ }
+}
+
+type tester struct {
+ db *Database
+ roots []common.Hash
+ preimages map[common.Hash]common.Address
+ accounts map[common.Hash][]byte
+ storages map[common.Hash]map[common.Hash][]byte
+
+ // state snapshots
+ snapAccounts map[common.Hash]map[common.Hash][]byte
+ snapStorages map[common.Hash]map[common.Hash]map[common.Hash][]byte
+}
+
+func newTester(t *testing.T, historyLimit uint64) *tester {
+ var (
+ disk, _ = rawdb.NewDatabaseWithFreezer(rawdb.NewMemoryDatabase(), t.TempDir(), "", false)
+ db = New(disk, &Config{
+ StateHistory: historyLimit,
+ CleanCacheSize: 256 * 1024,
+ DirtyCacheSize: 256 * 1024,
+ })
+ obj = &tester{
+ db: db,
+ preimages: make(map[common.Hash]common.Address),
+ accounts: make(map[common.Hash][]byte),
+ storages: make(map[common.Hash]map[common.Hash][]byte),
+ snapAccounts: make(map[common.Hash]map[common.Hash][]byte),
+ snapStorages: make(map[common.Hash]map[common.Hash]map[common.Hash][]byte),
+ }
+ )
+ for i := 0; i < 2*128; i++ {
+ var parent = types.EmptyRootHash
+ if len(obj.roots) != 0 {
+ parent = obj.roots[len(obj.roots)-1]
+ }
+ root, nodes, states := obj.generate(parent)
+ if err := db.Update(root, parent, uint64(i), nodes, states); err != nil {
+ panic(fmt.Errorf("failed to update state changes, err: %w", err))
+ }
+ obj.roots = append(obj.roots, root)
+ }
+ return obj
+}
+
+func (t *tester) release() {
+ t.db.Close()
+ t.db.diskdb.Close()
+}
+
+func (t *tester) randAccount() (common.Address, []byte) {
+ for addrHash, account := range t.accounts {
+ return t.preimages[addrHash], account
+ }
+ return common.Address{}, nil
+}
+
+func (t *tester) generateStorage(ctx *genctx, addr common.Address) common.Hash {
+ var (
+ addrHash = crypto.Keccak256Hash(addr.Bytes())
+ storage = make(map[common.Hash][]byte)
+ origin = make(map[common.Hash][]byte)
+ )
+ for i := 0; i < 10; i++ {
+ v, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(testutil.RandBytes(32)))
+ hash := testutil.RandomHash()
+
+ storage[hash] = v
+ origin[hash] = nil
+ }
+ root, set := updateTrie(addrHash, types.EmptyRootHash, storage, nil)
+
+ ctx.storages[addrHash] = storage
+ ctx.storageOrigin[addr] = origin
+ ctx.nodes.Merge(set)
+ return root
+}
+
+func (t *tester) mutateStorage(ctx *genctx, addr common.Address, root common.Hash) common.Hash {
+ var (
+ addrHash = crypto.Keccak256Hash(addr.Bytes())
+ storage = make(map[common.Hash][]byte)
+ origin = make(map[common.Hash][]byte)
+ )
+ for hash, val := range t.storages[addrHash] {
+ origin[hash] = val
+ storage[hash] = nil
+
+ if len(origin) == 3 {
+ break
+ }
+ }
+ for i := 0; i < 3; i++ {
+ v, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(testutil.RandBytes(32)))
+ hash := testutil.RandomHash()
+
+ storage[hash] = v
+ origin[hash] = nil
+ }
+ root, set := updateTrie(crypto.Keccak256Hash(addr.Bytes()), root, storage, t.storages[addrHash])
+
+ ctx.storages[addrHash] = storage
+ ctx.storageOrigin[addr] = origin
+ ctx.nodes.Merge(set)
+ return root
+}
+
+func (t *tester) clearStorage(ctx *genctx, addr common.Address, root common.Hash) common.Hash {
+ var (
+ addrHash = crypto.Keccak256Hash(addr.Bytes())
+ storage = make(map[common.Hash][]byte)
+ origin = make(map[common.Hash][]byte)
+ )
+ for hash, val := range t.storages[addrHash] {
+ origin[hash] = val
+ storage[hash] = nil
+ }
+ root, set := updateTrie(addrHash, root, storage, t.storages[addrHash])
+ if root != types.EmptyRootHash {
+ panic("failed to clear storage trie")
+ }
+ ctx.storages[addrHash] = storage
+ ctx.storageOrigin[addr] = origin
+ ctx.nodes.Merge(set)
+ return root
+}
+
+func (t *tester) generate(parent common.Hash) (common.Hash, *trienode.MergedNodeSet, *triestate.Set) {
+ var (
+ ctx = newCtx()
+ dirties = make(map[common.Hash]struct{})
+ )
+ for i := 0; i < 20; i++ {
+ switch rand.Intn(opLen) {
+ case createAccountOp:
+ // account creation
+ addr := testutil.RandomAddress()
+ addrHash := crypto.Keccak256Hash(addr.Bytes())
+ if _, ok := t.accounts[addrHash]; ok {
+ continue
+ }
+ if _, ok := dirties[addrHash]; ok {
+ continue
+ }
+ dirties[addrHash] = struct{}{}
+
+ root := t.generateStorage(ctx, addr)
+ ctx.accounts[addrHash] = types.SlimAccountRLP(generateAccount(root))
+ ctx.accountOrigin[addr] = nil
+ t.preimages[addrHash] = addr
+
+ case modifyAccountOp:
+ // account mutation
+ addr, account := t.randAccount()
+ if addr == (common.Address{}) {
+ continue
+ }
+ addrHash := crypto.Keccak256Hash(addr.Bytes())
+ if _, ok := dirties[addrHash]; ok {
+ continue
+ }
+ dirties[addrHash] = struct{}{}
+
+ acct, _ := types.FullAccount(account)
+ stRoot := t.mutateStorage(ctx, addr, acct.Root)
+ newAccount := types.SlimAccountRLP(generateAccount(stRoot))
+
+ ctx.accounts[addrHash] = newAccount
+ ctx.accountOrigin[addr] = account
+
+ case deleteAccountOp:
+ // account deletion
+ addr, account := t.randAccount()
+ if addr == (common.Address{}) {
+ continue
+ }
+ addrHash := crypto.Keccak256Hash(addr.Bytes())
+ if _, ok := dirties[addrHash]; ok {
+ continue
+ }
+ dirties[addrHash] = struct{}{}
+
+ acct, _ := types.FullAccount(account)
+ if acct.Root != types.EmptyRootHash {
+ t.clearStorage(ctx, addr, acct.Root)
+ }
+ ctx.accounts[addrHash] = nil
+ ctx.accountOrigin[addr] = account
+ }
+ }
+ root, set := updateTrie(common.Hash{}, parent, ctx.accounts, t.accounts)
+ ctx.nodes.Merge(set)
+
+ // Save state snapshot before commit
+ t.snapAccounts[parent] = copyAccounts(t.accounts)
+ t.snapStorages[parent] = copyStorages(t.storages)
+
+ // Commit all changes to live state set
+ for addrHash, account := range ctx.accounts {
+ if len(account) == 0 {
+ delete(t.accounts, addrHash)
+ } else {
+ t.accounts[addrHash] = account
+ }
+ }
+ for addrHash, slots := range ctx.storages {
+ if _, ok := t.storages[addrHash]; !ok {
+ t.storages[addrHash] = make(map[common.Hash][]byte)
+ }
+ for sHash, slot := range slots {
+ if len(slot) == 0 {
+ delete(t.storages[addrHash], sHash)
+ } else {
+ t.storages[addrHash][sHash] = slot
+ }
+ }
+ }
+ return root, ctx.nodes, triestate.New(ctx.accountOrigin, ctx.storageOrigin, nil)
+}
+
+// lastHash returns the latest root hash, or empty if nothing is cached.
+func (t *tester) lastHash() common.Hash {
+ if len(t.roots) == 0 {
+ return common.Hash{}
+ }
+ return t.roots[len(t.roots)-1]
+}
+
+func (t *tester) verifyState(root common.Hash) error {
+ reader, err := t.db.Reader(root)
+ if err != nil {
+ return err
+ }
+ _, err = reader.Node(common.Hash{}, nil, root)
+ if err != nil {
+ return errors.New("root node is not available")
+ }
+ for addrHash, account := range t.snapAccounts[root] {
+ blob, err := reader.Node(common.Hash{}, addrHash.Bytes(), crypto.Keccak256Hash(account))
+ if err != nil || !bytes.Equal(blob, account) {
+ return fmt.Errorf("account is mismatched: %w", err)
+ }
+ }
+ for addrHash, slots := range t.snapStorages[root] {
+ for hash, slot := range slots {
+ blob, err := reader.Node(addrHash, hash.Bytes(), crypto.Keccak256Hash(slot))
+ if err != nil || !bytes.Equal(blob, slot) {
+ return fmt.Errorf("slot is mismatched: %w", err)
+ }
+ }
+ }
+ return nil
+}
+
+func (t *tester) verifyHistory() error {
+ bottom := t.bottomIndex()
+ for i, root := range t.roots {
+ // The state history related to the state above disk layer should not exist.
+ if i > bottom {
+ _, err := readHistory(t.db.freezer, uint64(i+1))
+ if err == nil {
+ return errors.New("unexpected state history")
+ }
+ continue
+ }
+ // The state history related to the state below or equal to the disk layer
+ // should exist.
+ obj, err := readHistory(t.db.freezer, uint64(i+1))
+ if err != nil {
+ return err
+ }
+ parent := types.EmptyRootHash
+ if i != 0 {
+ parent = t.roots[i-1]
+ }
+ if obj.meta.parent != parent {
+ return fmt.Errorf("unexpected parent, want: %x, got: %x", parent, obj.meta.parent)
+ }
+ if obj.meta.root != root {
+ return fmt.Errorf("unexpected root, want: %x, got: %x", root, obj.meta.root)
+ }
+ }
+ return nil
+}
+
+// bottomIndex returns the index of current disk layer.
+func (t *tester) bottomIndex() int {
+ bottom := t.db.tree.bottom()
+ for i := 0; i < len(t.roots); i++ {
+ if t.roots[i] == bottom.rootHash() {
+ return i
+ }
+ }
+ return -1
+}
+
+func TestDatabaseRollback(t *testing.T) {
+ // Verify state histories
+ tester := newTester(t, 0)
+ defer tester.release()
+
+ if err := tester.verifyHistory(); err != nil {
+ t.Fatalf("Invalid state history, err: %v", err)
+ }
+ // Revert database from top to bottom
+ for i := tester.bottomIndex(); i >= 0; i-- {
+ root := tester.roots[i]
+ parent := types.EmptyRootHash
+ if i > 0 {
+ parent = tester.roots[i-1]
+ }
+ loader := newHashLoader(tester.snapAccounts[root], tester.snapStorages[root])
+ if err := tester.db.Recover(parent, loader); err != nil {
+ t.Fatalf("Failed to revert db, err: %v", err)
+ }
+ tester.verifyState(parent)
+ }
+ if tester.db.tree.len() != 1 {
+ t.Fatal("Only disk layer is expected")
+ }
+}
+
+func TestDatabaseRecoverable(t *testing.T) {
+ var (
+ tester = newTester(t, 0)
+ index = tester.bottomIndex()
+ )
+ defer tester.release()
+
+ var cases = []struct {
+ root common.Hash
+ expect bool
+ }{
+ // Unknown state should be unrecoverable
+ {common.Hash{0x1}, false},
+
+ // Initial state should be recoverable
+ {types.EmptyRootHash, true},
+
+ // Initial state should be recoverable
+ {common.Hash{}, true},
+
+ // Layers below current disk layer are recoverable
+ {tester.roots[index-1], true},
+
+ // Disklayer itself is not recoverable, since it's
+ // available for accessing.
+ {tester.roots[index], false},
+
+ // Layers above current disk layer are not recoverable
+ // since they are available for accessing.
+ {tester.roots[index+1], false},
+ }
+ for i, c := range cases {
+ result := tester.db.Recoverable(c.root)
+ if result != c.expect {
+ t.Fatalf("case: %d, unexpected result, want %t, got %t", i, c.expect, result)
+ }
+ }
+}
+
+func TestDisable(t *testing.T) {
+ tester := newTester(t, 0)
+ defer tester.release()
+
+ _, stored := rawdb.ReadAccountTrieNode(tester.db.diskdb, nil)
+ if err := tester.db.Disable(); err != nil {
+ t.Fatal("Failed to deactivate database")
+ }
+ if err := tester.db.Enable(types.EmptyRootHash); err == nil {
+ t.Fatalf("Invalid activation should be rejected")
+ }
+ if err := tester.db.Enable(stored); err != nil {
+ t.Fatal("Failed to activate database")
+ }
+
+ // Ensure journal is deleted from disk
+ if blob := rawdb.ReadTrieJournal(tester.db.diskdb); len(blob) != 0 {
+ t.Fatal("Failed to clean journal")
+ }
+ // Ensure all trie histories are removed
+ n, err := tester.db.freezer.Ancients()
+ if err != nil {
+ t.Fatal("Failed to clean state history")
+ }
+ if n != 0 {
+ t.Fatal("Failed to clean state history")
+ }
+ // Verify layer tree structure, single disk layer is expected
+ if tester.db.tree.len() != 1 {
+ t.Fatalf("Extra layer kept %d", tester.db.tree.len())
+ }
+ if tester.db.tree.bottom().rootHash() != stored {
+ t.Fatalf("Root hash is not matched exp %x got %x", stored, tester.db.tree.bottom().rootHash())
+ }
+}
+
+func TestCommit(t *testing.T) {
+ tester := newTester(t, 0)
+ defer tester.release()
+
+ if err := tester.db.Commit(tester.lastHash(), false); err != nil {
+ t.Fatalf("Failed to cap database, err: %v", err)
+ }
+ // Verify layer tree structure, single disk layer is expected
+ if tester.db.tree.len() != 1 {
+ t.Fatal("Layer tree structure is invalid")
+ }
+ if tester.db.tree.bottom().rootHash() != tester.lastHash() {
+ t.Fatal("Layer tree structure is invalid")
+ }
+ // Verify states
+ if err := tester.verifyState(tester.lastHash()); err != nil {
+ t.Fatalf("State is invalid, err: %v", err)
+ }
+ // Verify state histories
+ if err := tester.verifyHistory(); err != nil {
+ t.Fatalf("State history is invalid, err: %v", err)
+ }
+}
+
+func TestJournal(t *testing.T) {
+ tester := newTester(t, 0)
+ defer tester.release()
+
+ if err := tester.db.Journal(tester.lastHash()); err != nil {
+ t.Errorf("Failed to journal, err: %v", err)
+ }
+ tester.db.Close()
+ tester.db = New(tester.db.diskdb, nil)
+
+ // Verify states including disk layer and all diff on top.
+ for i := 0; i < len(tester.roots); i++ {
+ if i >= tester.bottomIndex() {
+ if err := tester.verifyState(tester.roots[i]); err != nil {
+ t.Fatalf("Invalid state, err: %v", err)
+ }
+ continue
+ }
+ if err := tester.verifyState(tester.roots[i]); err == nil {
+ t.Fatal("Unexpected state")
+ }
+ }
+}
+
+func TestCorruptedJournal(t *testing.T) {
+ tester := newTester(t, 0)
+ defer tester.release()
+
+ if err := tester.db.Journal(tester.lastHash()); err != nil {
+ t.Errorf("Failed to journal, err: %v", err)
+ }
+ tester.db.Close()
+ _, root := rawdb.ReadAccountTrieNode(tester.db.diskdb, nil)
+
+ // Mutate the journal in disk, it should be regarded as invalid
+ blob := rawdb.ReadTrieJournal(tester.db.diskdb)
+ blob[0] = 1
+ rawdb.WriteTrieJournal(tester.db.diskdb, blob)
+
+ // Verify states, all not-yet-written states should be discarded
+ tester.db = New(tester.db.diskdb, nil)
+ for i := 0; i < len(tester.roots); i++ {
+ if tester.roots[i] == root {
+ if err := tester.verifyState(root); err != nil {
+ t.Fatalf("Disk state is corrupted, err: %v", err)
+ }
+ continue
+ }
+ if err := tester.verifyState(tester.roots[i]); err == nil {
+ t.Fatal("Unexpected state")
+ }
+ }
+}
+
+// TestTailTruncateHistory function is designed to test a specific edge case where,
+// when history objects are removed from the end, it should trigger a state flush
+// if the ID of the new tail object is even higher than the persisted state ID.
+//
+// For example, let's say the ID of the persistent state is 10, and the current
+// history objects range from ID(5) to ID(15). As we accumulate six more objects,
+// the history will expand to cover ID(11) to ID(21). ID(11) then becomes the
+// oldest history object, and its ID is even higher than the stored state.
+//
+// In this scenario, it is mandatory to update the persistent state before
+// truncating the tail histories. This ensures that the ID of the persistent state
+// always falls within the range of [oldest-history-id, latest-history-id].
+func TestTailTruncateHistory(t *testing.T) {
+ tester := newTester(t, 10)
+ defer tester.release()
+
+ tester.db.Close()
+ tester.db = New(tester.db.diskdb, &Config{StateHistory: 10})
+
+ head, err := tester.db.freezer.Ancients()
+ if err != nil {
+ t.Fatalf("Failed to obtain freezer head")
+ }
+ stored := rawdb.ReadPersistentStateID(tester.db.diskdb)
+ if head != stored {
+ t.Fatalf("Failed to truncate excess history object above, stored: %d, head: %d", stored, head)
+ }
+}
+
+// copyAccounts returns a deep-copied account set of the provided one.
+func copyAccounts(set map[common.Hash][]byte) map[common.Hash][]byte {
+ copied := make(map[common.Hash][]byte, len(set))
+ for key, val := range set {
+ copied[key] = common.CopyBytes(val)
+ }
+ return copied
+}
+
+// copyStorages returns a deep-copied storage set of the provided one.
+func copyStorages(set map[common.Hash]map[common.Hash][]byte) map[common.Hash]map[common.Hash][]byte {
+ copied := make(map[common.Hash]map[common.Hash][]byte, len(set))
+ for addrHash, subset := range set {
+ copied[addrHash] = make(map[common.Hash][]byte, len(subset))
+ for key, val := range subset {
+ copied[addrHash][key] = common.CopyBytes(val)
+ }
+ }
+ return copied
+}
diff --git a/trie_by_cid/triedb/pathdb/difflayer.go b/trie_by_cid/triedb/pathdb/difflayer.go
new file mode 100644
index 0000000..88e9e98
--- /dev/null
+++ b/trie_by_cid/triedb/pathdb/difflayer.go
@@ -0,0 +1,174 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package pathdb
+
+import (
+ "fmt"
+ "sync"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/log"
+)
+
+// diffLayer represents a collection of modifications made to the in-memory tries
+// along with associated state changes after running a block on top.
+//
+// The goal of a diff layer is to act as a journal, tracking recent modifications
+// made to the state, that have not yet graduated into a semi-immutable state.
+type diffLayer struct {
+ // Immutables
+ root common.Hash // Root hash to which this layer diff belongs to
+ id uint64 // Corresponding state id
+ block uint64 // Associated block number
+ nodes map[common.Hash]map[string]*trienode.Node // Cached trie nodes indexed by owner and path
+ states *triestate.Set // Associated state change set for building history
+ memory uint64 // Approximate guess as to how much memory we use
+
+ parent layer // Parent layer modified by this one, never nil, **can be changed**
+ lock sync.RWMutex // Lock used to protect parent
+}
+
+// newDiffLayer creates a new diff layer on top of an existing layer.
+func newDiffLayer(parent layer, root common.Hash, id uint64, block uint64, nodes map[common.Hash]map[string]*trienode.Node, states *triestate.Set) *diffLayer {
+ var (
+ size int64
+ count int
+ )
+ dl := &diffLayer{
+ root: root,
+ id: id,
+ block: block,
+ nodes: nodes,
+ states: states,
+ parent: parent,
+ }
+ for _, subset := range nodes {
+ for path, n := range subset {
+ dl.memory += uint64(n.Size() + len(path))
+ size += int64(len(n.Blob) + len(path))
+ }
+ count += len(subset)
+ }
+ if states != nil {
+ dl.memory += uint64(states.Size())
+ }
+ dirtyWriteMeter.Mark(size)
+ diffLayerNodesMeter.Mark(int64(count))
+ diffLayerBytesMeter.Mark(int64(dl.memory))
+ log.Debug("Created new diff layer", "id", id, "block", block, "nodes", count, "size", common.StorageSize(dl.memory))
+ return dl
+}
+
+// rootHash implements the layer interface, returning the root hash of
+// corresponding state.
+func (dl *diffLayer) rootHash() common.Hash {
+ return dl.root
+}
+
+// stateID implements the layer interface, returning the state id of the layer.
+func (dl *diffLayer) stateID() uint64 {
+ return dl.id
+}
+
+// parentLayer implements the layer interface, returning the subsequent
+// layer of the diff layer.
+func (dl *diffLayer) parentLayer() layer {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ return dl.parent
+}
+
+// node retrieves the node with provided node information. It's the internal
+// version of Node function with additional accessed layer tracked. No error
+// will be returned if node is not found.
+func (dl *diffLayer) node(owner common.Hash, path []byte, hash common.Hash, depth int) ([]byte, error) {
+ // Hold the lock, ensure the parent won't be changed during the
+ // state accessing.
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ // If the trie node is known locally, return it
+ subset, ok := dl.nodes[owner]
+ if ok {
+ n, ok := subset[string(path)]
+ if ok {
+ // If the trie node is not hash matched, or marked as removed,
+ // bubble up an error here. It shouldn't happen at all.
+ if n.Hash != hash {
+ dirtyFalseMeter.Mark(1)
+ log.Error("Unexpected trie node in diff layer", "owner", owner, "path", path, "expect", hash, "got", n.Hash)
+ return nil, newUnexpectedNodeError("diff", hash, n.Hash, owner, path, n.Blob)
+ }
+ dirtyHitMeter.Mark(1)
+ dirtyNodeHitDepthHist.Update(int64(depth))
+ dirtyReadMeter.Mark(int64(len(n.Blob)))
+ return n.Blob, nil
+ }
+ }
+ // Trie node unknown to this layer, resolve from parent
+ if diff, ok := dl.parent.(*diffLayer); ok {
+ return diff.node(owner, path, hash, depth+1)
+ }
+ // Failed to resolve through diff layers, fallback to disk layer
+ return dl.parent.Node(owner, path, hash)
+}
+
+// Node implements the layer interface, retrieving the trie node blob with the
+// provided node information. No error will be returned if the node is not found.
+func (dl *diffLayer) Node(owner common.Hash, path []byte, hash common.Hash) ([]byte, error) {
+ return dl.node(owner, path, hash, 0)
+}
+
+// update implements the layer interface, creating a new layer on top of the
+// existing layer tree with the specified data items.
+func (dl *diffLayer) update(root common.Hash, id uint64, block uint64, nodes map[common.Hash]map[string]*trienode.Node, states *triestate.Set) *diffLayer {
+ return newDiffLayer(dl, root, id, block, nodes, states)
+}
+
+// persist flushes the diff layer and all its parent layers to disk layer.
+func (dl *diffLayer) persist(force bool) (layer, error) {
+ if parent, ok := dl.parentLayer().(*diffLayer); ok {
+ // Hold the lock to prevent any read operation until the new
+ // parent is linked correctly.
+ dl.lock.Lock()
+
+ // The merging of diff layers starts at the bottom-most layer,
+ // therefore we recurse down here, flattening on the way up
+ // (diffToDisk).
+ result, err := parent.persist(force)
+ if err != nil {
+ dl.lock.Unlock()
+ return nil, err
+ }
+ dl.parent = result
+ dl.lock.Unlock()
+ }
+ return diffToDisk(dl, force)
+}
+
+// diffToDisk merges a bottom-most diff into the persistent disk layer underneath
+// it. The method will panic if called onto a non-bottom-most diff layer.
+func diffToDisk(layer *diffLayer, force bool) (layer, error) {
+ disk, ok := layer.parentLayer().(*diskLayer)
+ if !ok {
+ panic(fmt.Sprintf("unknown layer type: %T", layer.parentLayer()))
+ }
+ return disk.commit(layer, force)
+}
diff --git a/trie_by_cid/triedb/pathdb/difflayer_test.go b/trie_by_cid/triedb/pathdb/difflayer_test.go
new file mode 100644
index 0000000..087f490
--- /dev/null
+++ b/trie_by_cid/triedb/pathdb/difflayer_test.go
@@ -0,0 +1,170 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package pathdb
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/testutil"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+)
+
+func emptyLayer() *diskLayer {
+ return &diskLayer{
+ db: New(rawdb.NewMemoryDatabase(), nil),
+ buffer: newNodeBuffer(DefaultBufferSize, nil, 0),
+ }
+}
+
+// goos: darwin
+// goarch: arm64
+// pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
+// BenchmarkSearch128Layers
+// BenchmarkSearch128Layers-8 243826 4755 ns/op
+func BenchmarkSearch128Layers(b *testing.B) { benchmarkSearch(b, 0, 128) }
+
+// goos: darwin
+// goarch: arm64
+// pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
+// BenchmarkSearch512Layers
+// BenchmarkSearch512Layers-8 49686 24256 ns/op
+func BenchmarkSearch512Layers(b *testing.B) { benchmarkSearch(b, 0, 512) }
+
+// goos: darwin
+// goarch: arm64
+// pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
+// BenchmarkSearch1Layer
+// BenchmarkSearch1Layer-8 14062725 88.40 ns/op
+func BenchmarkSearch1Layer(b *testing.B) { benchmarkSearch(b, 127, 128) }
+
+func benchmarkSearch(b *testing.B, depth int, total int) {
+ var (
+ npath []byte
+ nhash common.Hash
+ nblob []byte
+ )
+ // First, we set up 128 diff layers, with 3K items each
+ fill := func(parent layer, index int) *diffLayer {
+ nodes := make(map[common.Hash]map[string]*trienode.Node)
+ nodes[common.Hash{}] = make(map[string]*trienode.Node)
+ for i := 0; i < 3000; i++ {
+ var (
+ path = testutil.RandBytes(32)
+ node = testutil.RandomNode()
+ )
+ nodes[common.Hash{}][string(path)] = trienode.New(node.Hash, node.Blob)
+ if npath == nil && depth == index {
+ npath = common.CopyBytes(path)
+ nblob = common.CopyBytes(node.Blob)
+ nhash = node.Hash
+ }
+ }
+ return newDiffLayer(parent, common.Hash{}, 0, 0, nodes, nil)
+ }
+ var layer layer
+ layer = emptyLayer()
+ for i := 0; i < total; i++ {
+ layer = fill(layer, i)
+ }
+ b.ResetTimer()
+
+ var (
+ have []byte
+ err error
+ )
+ for i := 0; i < b.N; i++ {
+ have, err = layer.Node(common.Hash{}, npath, nhash)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+ if !bytes.Equal(have, nblob) {
+ b.Fatalf("have %x want %x", have, nblob)
+ }
+}
+
+// goos: darwin
+// goarch: arm64
+// pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
+// BenchmarkPersist
+// BenchmarkPersist-8 10 111252975 ns/op
+func BenchmarkPersist(b *testing.B) {
+ // First, we set up 128 diff layers, with 3K items each
+ fill := func(parent layer) *diffLayer {
+ nodes := make(map[common.Hash]map[string]*trienode.Node)
+ nodes[common.Hash{}] = make(map[string]*trienode.Node)
+ for i := 0; i < 3000; i++ {
+ var (
+ path = testutil.RandBytes(32)
+ node = testutil.RandomNode()
+ )
+ nodes[common.Hash{}][string(path)] = trienode.New(node.Hash, node.Blob)
+ }
+ return newDiffLayer(parent, common.Hash{}, 0, 0, nodes, nil)
+ }
+ for i := 0; i < b.N; i++ {
+ b.StopTimer()
+ var layer layer
+ layer = emptyLayer()
+ for i := 1; i < 128; i++ {
+ layer = fill(layer)
+ }
+ b.StartTimer()
+
+ dl, ok := layer.(*diffLayer)
+ if !ok {
+ break
+ }
+ dl.persist(false)
+ }
+}
+
+// BenchmarkJournal benchmarks the performance for journaling the layers.
+//
+// BenchmarkJournal
+// BenchmarkJournal-8 10 110969279 ns/op
+func BenchmarkJournal(b *testing.B) {
+ b.SkipNow()
+
+ // First, we set up 128 diff layers, with 3K items each
+ fill := func(parent layer) *diffLayer {
+ nodes := make(map[common.Hash]map[string]*trienode.Node)
+ nodes[common.Hash{}] = make(map[string]*trienode.Node)
+ for i := 0; i < 3000; i++ {
+ var (
+ path = testutil.RandBytes(32)
+ node = testutil.RandomNode()
+ )
+ nodes[common.Hash{}][string(path)] = trienode.New(node.Hash, node.Blob)
+ }
+ // TODO(rjl493456442) a non-nil state set is expected.
+ return newDiffLayer(parent, common.Hash{}, 0, 0, nodes, nil)
+ }
+ var layer layer
+ layer = emptyLayer()
+ for i := 0; i < 128; i++ {
+ layer = fill(layer)
+ }
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ layer.journal(new(bytes.Buffer))
+ }
+}
diff --git a/trie_by_cid/triedb/pathdb/disklayer.go b/trie_by_cid/triedb/pathdb/disklayer.go
new file mode 100644
index 0000000..e016d6d
--- /dev/null
+++ b/trie_by_cid/triedb/pathdb/disklayer.go
@@ -0,0 +1,338 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package pathdb
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+
+ "github.com/VictoriaMetrics/fastcache"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/log"
+ "golang.org/x/crypto/sha3"
+)
+
+// diskLayer is a low level persistent layer built on top of a key-value store.
+type diskLayer struct {
+ root common.Hash // Immutable, root hash to which this layer was made for
+ id uint64 // Immutable, corresponding state id
+ db *Database // Path-based trie database
+ cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs
+ buffer *nodebuffer // Node buffer to aggregate writes
+ stale bool // Signals that the layer became stale (state progressed)
+ lock sync.RWMutex // Lock used to protect stale flag
+}
+
+// newDiskLayer creates a new disk layer based on the passing arguments.
+func newDiskLayer(root common.Hash, id uint64, db *Database, cleans *fastcache.Cache, buffer *nodebuffer) *diskLayer {
+ // Initialize a clean cache if the memory allowance is not zero
+ // or reuse the provided cache if it is not nil (inherited from
+ // the original disk layer).
+ if cleans == nil && db.config.CleanCacheSize != 0 {
+ cleans = fastcache.New(db.config.CleanCacheSize)
+ }
+ return &diskLayer{
+ root: root,
+ id: id,
+ db: db,
+ cleans: cleans,
+ buffer: buffer,
+ }
+}
+
+// rootHash implements the layer interface, returning root hash of corresponding state.
+func (dl *diskLayer) rootHash() common.Hash {
+ return dl.root
+}
+
+// stateID implements the layer interface, returning the state id of disk layer.
+func (dl *diskLayer) stateID() uint64 {
+ return dl.id
+}
+
+// parentLayer implements the layer interface, returning nil as there's no layer
+// below the disk.
+func (dl *diskLayer) parentLayer() layer {
+ return nil
+}
+
+// isStale return whether this layer has become stale (was flattened across) or if
+// it's still live.
+func (dl *diskLayer) isStale() bool {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ return dl.stale
+}
+
+// markStale sets the stale flag as true.
+func (dl *diskLayer) markStale() {
+ dl.lock.Lock()
+ defer dl.lock.Unlock()
+
+ if dl.stale {
+ panic("triedb disk layer is stale") // we've committed into the same base from two children, boom
+ }
+ dl.stale = true
+}
+
+// Node implements the layer interface, retrieving the trie node with the
+// provided node info. No error will be returned if the node is not found.
+func (dl *diskLayer) Node(owner common.Hash, path []byte, hash common.Hash) ([]byte, error) {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ if dl.stale {
+ return nil, errSnapshotStale
+ }
+ // Try to retrieve the trie node from the not-yet-written
+ // node buffer first. Note the buffer is lock free since
+ // it's impossible to mutate the buffer before tagging the
+ // layer as stale.
+ n, err := dl.buffer.node(owner, path, hash)
+ if err != nil {
+ return nil, err
+ }
+ if n != nil {
+ dirtyHitMeter.Mark(1)
+ dirtyReadMeter.Mark(int64(len(n.Blob)))
+ return n.Blob, nil
+ }
+ dirtyMissMeter.Mark(1)
+
+ // Try to retrieve the trie node from the clean memory cache
+ key := cacheKey(owner, path)
+ if dl.cleans != nil {
+ if blob := dl.cleans.Get(nil, key); len(blob) > 0 {
+ h := newHasher()
+ defer h.release()
+
+ got := h.hash(blob)
+ if got == hash {
+ cleanHitMeter.Mark(1)
+ cleanReadMeter.Mark(int64(len(blob)))
+ return blob, nil
+ }
+ cleanFalseMeter.Mark(1)
+ log.Error("Unexpected trie node in clean cache", "owner", owner, "path", path, "expect", hash, "got", got)
+ }
+ cleanMissMeter.Mark(1)
+ }
+ // Try to retrieve the trie node from the disk.
+ var (
+ nBlob []byte
+ nHash common.Hash
+ )
+ if owner == (common.Hash{}) {
+ nBlob, nHash = rawdb.ReadAccountTrieNode(dl.db.diskdb, path)
+ } else {
+ nBlob, nHash = rawdb.ReadStorageTrieNode(dl.db.diskdb, owner, path)
+ }
+ if nHash != hash {
+ diskFalseMeter.Mark(1)
+ log.Error("Unexpected trie node in disk", "owner", owner, "path", path, "expect", hash, "got", nHash)
+ return nil, newUnexpectedNodeError("disk", hash, nHash, owner, path, nBlob)
+ }
+ if dl.cleans != nil && len(nBlob) > 0 {
+ dl.cleans.Set(key, nBlob)
+ cleanWriteMeter.Mark(int64(len(nBlob)))
+ }
+ return nBlob, nil
+}
+
+// update implements the layer interface, returning a new diff layer on top
+// with the given state set.
+func (dl *diskLayer) update(root common.Hash, id uint64, block uint64, nodes map[common.Hash]map[string]*trienode.Node, states *triestate.Set) *diffLayer {
+ return newDiffLayer(dl, root, id, block, nodes, states)
+}
+
+// commit merges the given bottom-most diff layer into the node buffer
+// and returns a newly constructed disk layer. Note the current disk
+// layer must be tagged as stale first to prevent re-access.
+func (dl *diskLayer) commit(bottom *diffLayer, force bool) (*diskLayer, error) {
+ dl.lock.Lock()
+ defer dl.lock.Unlock()
+
+ // Construct and store the state history first. If crash happens after storing
+ // the state history but without flushing the corresponding states(journal),
+ // the stored state history will be truncated from head in the next restart.
+ var (
+ overflow bool
+ oldest uint64
+ )
+ if dl.db.freezer != nil {
+ err := writeHistory(dl.db.freezer, bottom)
+ if err != nil {
+ return nil, err
+ }
+ // Determine if the persisted history object has exceeded the configured
+ // limitation, set the overflow as true if so.
+ tail, err := dl.db.freezer.Tail()
+ if err != nil {
+ return nil, err
+ }
+ limit := dl.db.config.StateHistory
+ if limit != 0 && bottom.stateID()-tail > limit {
+ overflow = true
+ oldest = bottom.stateID() - limit + 1 // track the id of history **after truncation**
+ }
+ }
+ // Mark the diskLayer as stale before applying any mutations on top.
+ dl.stale = true
+
+ // Store the root->id lookup afterwards. All stored lookups are identified
+ // by the **unique** state root. It's impossible that in the same chain
+ // blocks are not adjacent but have the same root.
+ if dl.id == 0 {
+ rawdb.WriteStateID(dl.db.diskdb, dl.root, 0)
+ }
+ rawdb.WriteStateID(dl.db.diskdb, bottom.rootHash(), bottom.stateID())
+
+ // Construct a new disk layer by merging the nodes from the provided diff
+ // layer, and flush the content in disk layer if there are too many nodes
+ // cached. The clean cache is inherited from the original disk layer.
+ ndl := newDiskLayer(bottom.root, bottom.stateID(), dl.db, dl.cleans, dl.buffer.commit(bottom.nodes))
+
+ // In a unique scenario where the ID of the oldest history object (after tail
+ // truncation) surpasses the persisted state ID, we take the necessary action
+ // of forcibly committing the cached dirty nodes to ensure that the persisted
+ // state ID remains higher.
+ if !force && rawdb.ReadPersistentStateID(dl.db.diskdb) < oldest {
+ force = true
+ }
+ if err := ndl.buffer.flush(ndl.db.diskdb, ndl.cleans, ndl.id, force); err != nil {
+ return nil, err
+ }
+ // To remove outdated history objects from the end, we set the 'tail' parameter
+ // to 'oldest-1' due to the offset between the freezer index and the history ID.
+ if overflow {
+ pruned, err := truncateFromTail(ndl.db.diskdb, ndl.db.freezer, oldest-1)
+ if err != nil {
+ return nil, err
+ }
+ log.Debug("Pruned state history", "items", pruned, "tailid", oldest)
+ }
+ return ndl, nil
+}
+
+// revert applies the given state history and return a reverted disk layer.
+func (dl *diskLayer) revert(h *history, loader triestate.TrieLoader) (*diskLayer, error) {
+ if h.meta.root != dl.rootHash() {
+ return nil, errUnexpectedHistory
+ }
+ // Reject if the provided state history is incomplete. It's due to
+ // a large construct SELF-DESTRUCT which can't be handled because
+ // of memory limitation.
+ if len(h.meta.incomplete) > 0 {
+ return nil, errors.New("incomplete state history")
+ }
+ if dl.id == 0 {
+ return nil, fmt.Errorf("%w: zero state id", errStateUnrecoverable)
+ }
+ // Apply the reverse state changes upon the current state. This must
+ // be done before holding the lock in order to access state in "this"
+ // layer.
+ nodes, err := triestate.Apply(h.meta.parent, h.meta.root, h.accounts, h.storages, loader)
+ if err != nil {
+ return nil, err
+ }
+ // Mark the diskLayer as stale before applying any mutations on top.
+ dl.lock.Lock()
+ defer dl.lock.Unlock()
+
+ dl.stale = true
+
+ // State change may be applied to node buffer, or the persistent
+ // state, depends on if node buffer is empty or not. If the node
+ // buffer is not empty, it means that the state transition that
+ // needs to be reverted is not yet flushed and cached in node
+ // buffer, otherwise, manipulate persistent state directly.
+ if !dl.buffer.empty() {
+ err := dl.buffer.revert(dl.db.diskdb, nodes)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ batch := dl.db.diskdb.NewBatch()
+ writeNodes(batch, nodes, dl.cleans)
+ rawdb.WritePersistentStateID(batch, dl.id-1)
+ if err := batch.Write(); err != nil {
+ log.Crit("Failed to write states", "err", err)
+ }
+ }
+ return newDiskLayer(h.meta.parent, dl.id-1, dl.db, dl.cleans, dl.buffer), nil
+}
+
+// setBufferSize sets the node buffer size to the provided value.
+func (dl *diskLayer) setBufferSize(size int) error {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ if dl.stale {
+ return errSnapshotStale
+ }
+ return dl.buffer.setSize(size, dl.db.diskdb, dl.cleans, dl.id)
+}
+
+// size returns the approximate size of cached nodes in the disk layer.
+func (dl *diskLayer) size() common.StorageSize {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ if dl.stale {
+ return 0
+ }
+ return common.StorageSize(dl.buffer.size)
+}
+
+// resetCache releases the memory held by clean cache to prevent memory leak.
+func (dl *diskLayer) resetCache() {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ // Stale disk layer loses the ownership of clean cache.
+ if dl.stale {
+ return
+ }
+ if dl.cleans != nil {
+ dl.cleans.Reset()
+ }
+}
+
+// hasher is used to compute the sha256 hash of the provided data.
+type hasher struct{ sha crypto.KeccakState }
+
+var hasherPool = sync.Pool{
+ New: func() interface{} { return &hasher{sha: sha3.NewLegacyKeccak256().(crypto.KeccakState)} },
+}
+
+func newHasher() *hasher {
+ return hasherPool.Get().(*hasher)
+}
+
+func (h *hasher) hash(data []byte) common.Hash {
+ return crypto.HashData(h.sha, data)
+}
+
+func (h *hasher) release() {
+ hasherPool.Put(h)
+}
diff --git a/trie_by_cid/triedb/pathdb/errors.go b/trie_by_cid/triedb/pathdb/errors.go
new file mode 100644
index 0000000..78ee445
--- /dev/null
+++ b/trie_by_cid/triedb/pathdb/errors.go
@@ -0,0 +1,60 @@
+// Copyright 2023 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package pathdb
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/common/hexutil"
+)
+
+var (
+ // errDatabaseReadOnly is returned if the database is opened in read only mode
+ // to prevent any mutation.
+ errDatabaseReadOnly = errors.New("read only")
+
+ // errDatabaseWaitSync is returned if the initial state sync is not completed
+ // yet and database is disabled to prevent accessing state.
+ errDatabaseWaitSync = errors.New("waiting for sync")
+
+ // errSnapshotStale is returned from data accessors if the underlying layer
+ // layer had been invalidated due to the chain progressing forward far enough
+ // to not maintain the layer's original state.
+ errSnapshotStale = errors.New("layer stale")
+
+ // errUnexpectedHistory is returned if an unmatched state history is applied
+ // to the database for state rollback.
+ errUnexpectedHistory = errors.New("unexpected state history")
+
+ // errStateUnrecoverable is returned if state is required to be reverted to
+ // a destination without associated state history available.
+ errStateUnrecoverable = errors.New("state is unrecoverable")
+
+ // errUnexpectedNode is returned if the requested node with specified path is
+ // not hash matched with expectation.
+ errUnexpectedNode = errors.New("unexpected node")
+)
+
+func newUnexpectedNodeError(loc string, expHash common.Hash, gotHash common.Hash, owner common.Hash, path []byte, blob []byte) error {
+ blobHex := "nil"
+ if len(blob) > 0 {
+ blobHex = hexutil.Encode(blob)
+ }
+ return fmt.Errorf("%w, loc: %s, node: (%x %v), %x!=%x, blob: %s", errUnexpectedNode, loc, owner, path, expHash, gotHash, blobHex)
+}
diff --git a/trie_by_cid/triedb/pathdb/history.go b/trie_by_cid/triedb/pathdb/history.go
new file mode 100644
index 0000000..e24fa34
--- /dev/null
+++ b/trie_by_cid/triedb/pathdb/history.go
@@ -0,0 +1,649 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package pathdb
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "time"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/log"
+ "golang.org/x/exp/slices"
+)
+
+// State history records the state changes involved in executing a block. The
+// state can be reverted to the previous version by applying the associated
+// history object (state reverse diff). State history objects are kept to
+// guarantee that the system can perform state rollbacks in case of deep reorg.
+//
+// Each state transition will generate a state history object. Note that not
+// every block has a corresponding state history object. If a block performs
+// no state changes whatsoever, no state is created for it. Each state history
+// will have a sequentially increasing number acting as its unique identifier.
+//
+// The state history is written to disk (ancient store) when the corresponding
+// diff layer is merged into the disk layer. At the same time, system can prune
+// the oldest histories according to config.
+//
+// Disk State
+// ^
+// |
+// +------------+ +---------+ +---------+ +---------+
+// | Init State |---->| State 1 |---->| ... |---->| State n |
+// +------------+ +---------+ +---------+ +---------+
+//
+// +-----------+ +------+ +-----------+
+// | History 1 |----> | ... |---->| History n |
+// +-----------+ +------+ +-----------+
+//
+// # Rollback
+//
+// If the system wants to roll back to a previous state n, it needs to ensure
+// all history objects from n+1 up to the current disk layer are existent. The
+// history objects are applied to the state in reverse order, starting from the
+// current disk layer.
+
+const (
+ accountIndexSize = common.AddressLength + 13 // The length of encoded account index
+ slotIndexSize = common.HashLength + 5 // The length of encoded slot index
+ historyMetaSize = 9 + 2*common.HashLength // The length of fixed size part of meta object
+
+ stateHistoryVersion = uint8(0) // initial version of state history structure.
+)
+
+// Each state history entry consists of five elements:
+//
+// # metadata
+// This object contains a few meta fields, such as the associated state root,
+// block number, version tag and so on. This object may contain an extra
+// accountHash list which means the storage changes belonging to these accounts
+// are not complete due to large contract destruction. The incomplete history
+// can not be used for rollback and serving archive state request.
+//
+// # account index
+// This object contains some index information of account. For example, offset
+// and length indicate the location of the data belonging to the account. Besides,
+// storageOffset and storageSlots indicate the storage modification location
+// belonging to the account.
+//
+// The size of each account index is *fixed*, and all indexes are sorted
+// lexicographically. Thus binary search can be performed to quickly locate a
+// specific account.
+//
+// # account data
+// Account data is a concatenated byte stream composed of all account data.
+// The account data can be solved by the offset and length info indicated
+// by corresponding account index.
+//
+// fixed size
+// ^ ^
+// / \
+// +-----------------+-----------------+----------------+-----------------+
+// | Account index 1 | Account index 2 | ... | Account index N |
+// +-----------------+-----------------+----------------+-----------------+
+// |
+// | length
+// offset |----------------+
+// v v
+// +----------------+----------------+----------------+----------------+
+// | Account data 1 | Account data 2 | ... | Account data N |
+// +----------------+----------------+----------------+----------------+
+//
+// # storage index
+// This object is similar with account index. It's also fixed size and contains
+// the location info of storage slot data.
+//
+// # storage data
+// Storage data is a concatenated byte stream composed of all storage slot data.
+// The storage slot data can be solved by the location info indicated by
+// corresponding account index and storage slot index.
+//
+// fixed size
+// ^ ^
+// / \
+// +-----------------+-----------------+----------------+-----------------+
+// | Account index 1 | Account index 2 | ... | Account index N |
+// +-----------------+-----------------+----------------+-----------------+
+// |
+// | storage slots
+// storage offset |-----------------------------------------------------+
+// v v
+// +-----------------+-----------------+-----------------+
+// | storage index 1 | storage index 2 | storage index 3 |
+// +-----------------+-----------------+-----------------+
+// | length
+// offset |-------------+
+// v v
+// +-------------+
+// | slot data 1 |
+// +-------------+
+
+// accountIndex describes the metadata belonging to an account.
+type accountIndex struct {
+ address common.Address // The address of account
+ length uint8 // The length of account data, size limited by 255
+ offset uint32 // The offset of item in account data table
+ storageOffset uint32 // The offset of storage index in storage index table
+ storageSlots uint32 // The number of mutated storage slots belonging to the account
+}
+
+// encode packs account index into byte stream.
+func (i *accountIndex) encode() []byte {
+ var buf [accountIndexSize]byte
+ copy(buf[:], i.address.Bytes())
+ buf[common.AddressLength] = i.length
+ binary.BigEndian.PutUint32(buf[common.AddressLength+1:], i.offset)
+ binary.BigEndian.PutUint32(buf[common.AddressLength+5:], i.storageOffset)
+ binary.BigEndian.PutUint32(buf[common.AddressLength+9:], i.storageSlots)
+ return buf[:]
+}
+
+// decode unpacks account index from byte stream.
+func (i *accountIndex) decode(blob []byte) {
+ i.address = common.BytesToAddress(blob[:common.AddressLength])
+ i.length = blob[common.AddressLength]
+ i.offset = binary.BigEndian.Uint32(blob[common.AddressLength+1:])
+ i.storageOffset = binary.BigEndian.Uint32(blob[common.AddressLength+5:])
+ i.storageSlots = binary.BigEndian.Uint32(blob[common.AddressLength+9:])
+}
+
+// slotIndex describes the metadata belonging to a storage slot.
+type slotIndex struct {
+ hash common.Hash // The hash of slot key
+ length uint8 // The length of storage slot, up to 32 bytes defined in protocol
+ offset uint32 // The offset of item in storage slot data table
+}
+
+// encode packs slot index into byte stream.
+func (i *slotIndex) encode() []byte {
+ var buf [slotIndexSize]byte
+ copy(buf[:common.HashLength], i.hash.Bytes())
+ buf[common.HashLength] = i.length
+ binary.BigEndian.PutUint32(buf[common.HashLength+1:], i.offset)
+ return buf[:]
+}
+
+// decode unpacks the slot index from the byte stream.
+func (i *slotIndex) decode(blob []byte) {
+ i.hash = common.BytesToHash(blob[:common.HashLength])
+ i.length = blob[common.HashLength]
+ i.offset = binary.BigEndian.Uint32(blob[common.HashLength+1:])
+}
+
+// meta describes the meta data of state history object.
+type meta struct {
+ version uint8 // version tag of history object
+ parent common.Hash // prev-state root before the state transition
+ root common.Hash // post-state root after the state transition
+ block uint64 // associated block number
+ incomplete []common.Address // list of address whose storage set is incomplete
+}
+
+// encode packs the meta object into byte stream.
+func (m *meta) encode() []byte {
+ buf := make([]byte, historyMetaSize+len(m.incomplete)*common.AddressLength)
+ buf[0] = m.version
+ copy(buf[1:1+common.HashLength], m.parent.Bytes())
+ copy(buf[1+common.HashLength:1+2*common.HashLength], m.root.Bytes())
+ binary.BigEndian.PutUint64(buf[1+2*common.HashLength:historyMetaSize], m.block)
+ for i, h := range m.incomplete {
+ copy(buf[i*common.AddressLength+historyMetaSize:], h.Bytes())
+ }
+ return buf[:]
+}
+
+// decode unpacks the meta object from byte stream.
+func (m *meta) decode(blob []byte) error {
+ if len(blob) < 1 {
+ return fmt.Errorf("no version tag")
+ }
+ switch blob[0] {
+ case stateHistoryVersion:
+ if len(blob) < historyMetaSize {
+ return fmt.Errorf("invalid state history meta, len: %d", len(blob))
+ }
+ if (len(blob)-historyMetaSize)%common.AddressLength != 0 {
+ return fmt.Errorf("corrupted state history meta, len: %d", len(blob))
+ }
+ m.version = blob[0]
+ m.parent = common.BytesToHash(blob[1 : 1+common.HashLength])
+ m.root = common.BytesToHash(blob[1+common.HashLength : 1+2*common.HashLength])
+ m.block = binary.BigEndian.Uint64(blob[1+2*common.HashLength : historyMetaSize])
+ for pos := historyMetaSize; pos < len(blob); {
+ m.incomplete = append(m.incomplete, common.BytesToAddress(blob[pos:pos+common.AddressLength]))
+ pos += common.AddressLength
+ }
+ return nil
+ default:
+ return fmt.Errorf("unknown version %d", blob[0])
+ }
+}
+
+// history represents a set of state changes belong to a block along with
+// the metadata including the state roots involved in the state transition.
+// State history objects in disk are linked with each other by a unique id
+// (8-bytes integer), the oldest state history object can be pruned on demand
+// in order to control the storage size.
+type history struct {
+ meta *meta // Meta data of history
+ accounts map[common.Address][]byte // Account data keyed by its address hash
+ accountList []common.Address // Sorted account hash list
+ storages map[common.Address]map[common.Hash][]byte // Storage data keyed by its address hash and slot hash
+ storageList map[common.Address][]common.Hash // Sorted slot hash list
+}
+
+// newHistory constructs the state history object with provided state change set.
+func newHistory(root common.Hash, parent common.Hash, block uint64, states *triestate.Set) *history {
+ var (
+ accountList []common.Address
+ storageList = make(map[common.Address][]common.Hash)
+ incomplete []common.Address
+ )
+ for addr := range states.Accounts {
+ accountList = append(accountList, addr)
+ }
+ slices.SortFunc(accountList, common.Address.Cmp)
+
+ for addr, slots := range states.Storages {
+ slist := make([]common.Hash, 0, len(slots))
+ for slotHash := range slots {
+ slist = append(slist, slotHash)
+ }
+ slices.SortFunc(slist, common.Hash.Cmp)
+ storageList[addr] = slist
+ }
+ for addr := range states.Incomplete {
+ incomplete = append(incomplete, addr)
+ }
+ slices.SortFunc(incomplete, common.Address.Cmp)
+
+ return &history{
+ meta: &meta{
+ version: stateHistoryVersion,
+ parent: parent,
+ root: root,
+ block: block,
+ incomplete: incomplete,
+ },
+ accounts: states.Accounts,
+ accountList: accountList,
+ storages: states.Storages,
+ storageList: storageList,
+ }
+}
+
+// encode serializes the state history and returns four byte streams represent
+// concatenated account/storage data, account/storage indexes respectively.
+func (h *history) encode() ([]byte, []byte, []byte, []byte) {
+ var (
+ slotNumber uint32 // the number of processed slots
+ accountData []byte // the buffer for concatenated account data
+ storageData []byte // the buffer for concatenated storage data
+ accountIndexes []byte // the buffer for concatenated account index
+ storageIndexes []byte // the buffer for concatenated storage index
+ )
+ for _, addr := range h.accountList {
+ accIndex := accountIndex{
+ address: addr,
+ length: uint8(len(h.accounts[addr])),
+ offset: uint32(len(accountData)),
+ }
+ slots, exist := h.storages[addr]
+ if exist {
+ // Encode storage slots in order
+ for _, slotHash := range h.storageList[addr] {
+ sIndex := slotIndex{
+ hash: slotHash,
+ length: uint8(len(slots[slotHash])),
+ offset: uint32(len(storageData)),
+ }
+ storageData = append(storageData, slots[slotHash]...)
+ storageIndexes = append(storageIndexes, sIndex.encode()...)
+ }
+ // Fill up the storage meta in account index
+ accIndex.storageOffset = slotNumber
+ accIndex.storageSlots = uint32(len(slots))
+ slotNumber += uint32(len(slots))
+ }
+ accountData = append(accountData, h.accounts[addr]...)
+ accountIndexes = append(accountIndexes, accIndex.encode()...)
+ }
+ return accountData, storageData, accountIndexes, storageIndexes
+}
+
+// decoder wraps the byte streams for decoding with extra meta fields.
+type decoder struct {
+ accountData []byte // the buffer for concatenated account data
+ storageData []byte // the buffer for concatenated storage data
+ accountIndexes []byte // the buffer for concatenated account index
+ storageIndexes []byte // the buffer for concatenated storage index
+
+ lastAccount *common.Address // the address of last resolved account
+ lastAccountRead uint32 // the read-cursor position of account data
+ lastSlotIndexRead uint32 // the read-cursor position of storage slot index
+ lastSlotDataRead uint32 // the read-cursor position of storage slot data
+}
+
+// verify validates the provided byte streams for decoding state history. A few
+// checks will be performed to quickly detect data corruption. The byte stream
+// is regarded as corrupted if:
+//
+// - account indexes buffer is empty(empty state set is invalid)
+// - account indexes/storage indexer buffer is not aligned
+//
+// note, these situations are allowed:
+//
+// - empty account data: all accounts were not present
+// - empty storage set: no slots are modified
+func (r *decoder) verify() error {
+ if len(r.accountIndexes)%accountIndexSize != 0 || len(r.accountIndexes) == 0 {
+ return fmt.Errorf("invalid account index, len: %d", len(r.accountIndexes))
+ }
+ if len(r.storageIndexes)%slotIndexSize != 0 {
+ return fmt.Errorf("invalid storage index, len: %d", len(r.storageIndexes))
+ }
+ return nil
+}
+
+// readAccount parses the account from the byte stream with specified position.
+func (r *decoder) readAccount(pos int) (accountIndex, []byte, error) {
+ // Decode account index from the index byte stream.
+ var index accountIndex
+ if (pos+1)*accountIndexSize > len(r.accountIndexes) {
+ return accountIndex{}, nil, errors.New("account data buffer is corrupted")
+ }
+ index.decode(r.accountIndexes[pos*accountIndexSize : (pos+1)*accountIndexSize])
+
+ // Perform validation before parsing account data, ensure
+ // - account is sorted in order in byte stream
+ // - account data is strictly encoded with no gap inside
+ // - account data is not out-of-slice
+ if r.lastAccount != nil { // zero address is possible
+ if bytes.Compare(r.lastAccount.Bytes(), index.address.Bytes()) >= 0 {
+ return accountIndex{}, nil, errors.New("account is not in order")
+ }
+ }
+ if index.offset != r.lastAccountRead {
+ return accountIndex{}, nil, errors.New("account data buffer is gaped")
+ }
+ last := index.offset + uint32(index.length)
+ if uint32(len(r.accountData)) < last {
+ return accountIndex{}, nil, errors.New("account data buffer is corrupted")
+ }
+ data := r.accountData[index.offset:last]
+
+ r.lastAccount = &index.address
+ r.lastAccountRead = last
+
+ return index, data, nil
+}
+
+// readStorage parses the storage slots from the byte stream with specified account.
+func (r *decoder) readStorage(accIndex accountIndex) ([]common.Hash, map[common.Hash][]byte, error) {
+ var (
+ last common.Hash
+ list []common.Hash
+ storage = make(map[common.Hash][]byte)
+ )
+ for j := 0; j < int(accIndex.storageSlots); j++ {
+ var (
+ index slotIndex
+ start = (accIndex.storageOffset + uint32(j)) * uint32(slotIndexSize)
+ end = (accIndex.storageOffset + uint32(j+1)) * uint32(slotIndexSize)
+ )
+ // Perform validation before parsing storage slot data, ensure
+ // - slot index is not out-of-slice
+ // - slot data is not out-of-slice
+ // - slot is sorted in order in byte stream
+ // - slot indexes is strictly encoded with no gap inside
+ // - slot data is strictly encoded with no gap inside
+ if start != r.lastSlotIndexRead {
+ return nil, nil, errors.New("storage index buffer is gapped")
+ }
+ if uint32(len(r.storageIndexes)) < end {
+ return nil, nil, errors.New("storage index buffer is corrupted")
+ }
+ index.decode(r.storageIndexes[start:end])
+
+ if bytes.Compare(last.Bytes(), index.hash.Bytes()) >= 0 {
+ return nil, nil, errors.New("storage slot is not in order")
+ }
+ if index.offset != r.lastSlotDataRead {
+ return nil, nil, errors.New("storage data buffer is gapped")
+ }
+ sEnd := index.offset + uint32(index.length)
+ if uint32(len(r.storageData)) < sEnd {
+ return nil, nil, errors.New("storage data buffer is corrupted")
+ }
+ storage[index.hash] = r.storageData[r.lastSlotDataRead:sEnd]
+ list = append(list, index.hash)
+
+ last = index.hash
+ r.lastSlotIndexRead = end
+ r.lastSlotDataRead = sEnd
+ }
+ return list, storage, nil
+}
+
+// decode deserializes the account and storage data from the provided byte stream.
+func (h *history) decode(accountData, storageData, accountIndexes, storageIndexes []byte) error {
+ var (
+ accounts = make(map[common.Address][]byte)
+ storages = make(map[common.Address]map[common.Hash][]byte)
+ accountList []common.Address
+ storageList = make(map[common.Address][]common.Hash)
+
+ r = &decoder{
+ accountData: accountData,
+ storageData: storageData,
+ accountIndexes: accountIndexes,
+ storageIndexes: storageIndexes,
+ }
+ )
+ if err := r.verify(); err != nil {
+ return err
+ }
+ for i := 0; i < len(accountIndexes)/accountIndexSize; i++ {
+ // Resolve account first
+ accIndex, accData, err := r.readAccount(i)
+ if err != nil {
+ return err
+ }
+ accounts[accIndex.address] = accData
+ accountList = append(accountList, accIndex.address)
+
+ // Resolve storage slots
+ slotList, slotData, err := r.readStorage(accIndex)
+ if err != nil {
+ return err
+ }
+ if len(slotList) > 0 {
+ storageList[accIndex.address] = slotList
+ storages[accIndex.address] = slotData
+ }
+ }
+ h.accounts = accounts
+ h.accountList = accountList
+ h.storages = storages
+ h.storageList = storageList
+ return nil
+}
+
+// readHistory reads and decodes the state history object by the given id.
+func readHistory(freezer *rawdb.ResettableFreezer, id uint64) (*history, error) {
+ blob := rawdb.ReadStateHistoryMeta(freezer, id)
+ if len(blob) == 0 {
+ return nil, fmt.Errorf("state history not found %d", id)
+ }
+ var m meta
+ if err := m.decode(blob); err != nil {
+ return nil, err
+ }
+ var (
+ dec = history{meta: &m}
+ accountData = rawdb.ReadStateAccountHistory(freezer, id)
+ storageData = rawdb.ReadStateStorageHistory(freezer, id)
+ accountIndexes = rawdb.ReadStateAccountIndex(freezer, id)
+ storageIndexes = rawdb.ReadStateStorageIndex(freezer, id)
+ )
+ if err := dec.decode(accountData, storageData, accountIndexes, storageIndexes); err != nil {
+ return nil, err
+ }
+ return &dec, nil
+}
+
+// writeHistory persists the state history with the provided state set.
+func writeHistory(freezer *rawdb.ResettableFreezer, dl *diffLayer) error {
+ // Short circuit if state set is not available.
+ if dl.states == nil {
+ return errors.New("state change set is not available")
+ }
+ var (
+ start = time.Now()
+ history = newHistory(dl.rootHash(), dl.parentLayer().rootHash(), dl.block, dl.states)
+ )
+ accountData, storageData, accountIndex, storageIndex := history.encode()
+ dataSize := common.StorageSize(len(accountData) + len(storageData))
+ indexSize := common.StorageSize(len(accountIndex) + len(storageIndex))
+
+ // Write history data into five freezer table respectively.
+ rawdb.WriteStateHistory(freezer, dl.stateID(), history.meta.encode(), accountIndex, storageIndex, accountData, storageData)
+
+ historyDataBytesMeter.Mark(int64(dataSize))
+ historyIndexBytesMeter.Mark(int64(indexSize))
+ historyBuildTimeMeter.UpdateSince(start)
+ log.Debug("Stored state history", "id", dl.stateID(), "block", dl.block, "data", dataSize, "index", indexSize, "elapsed", common.PrettyDuration(time.Since(start)))
+
+ return nil
+}
+
+// checkHistories retrieves a batch of meta objects with the specified range
+// and performs the callback on each item.
+func checkHistories(freezer *rawdb.ResettableFreezer, start, count uint64, check func(*meta) error) error {
+ for count > 0 {
+ number := count
+ if number > 10000 {
+ number = 10000 // split the big read into small chunks
+ }
+ blobs, err := rawdb.ReadStateHistoryMetaList(freezer, start, number)
+ if err != nil {
+ return err
+ }
+ for _, blob := range blobs {
+ var dec meta
+ if err := dec.decode(blob); err != nil {
+ return err
+ }
+ if err := check(&dec); err != nil {
+ return err
+ }
+ }
+ count -= uint64(len(blobs))
+ start += uint64(len(blobs))
+ }
+ return nil
+}
+
+// truncateFromHead removes the extra state histories from the head with the given
+// parameters. It returns the number of items removed from the head.
+func truncateFromHead(db ethdb.Batcher, freezer *rawdb.ResettableFreezer, nhead uint64) (int, error) {
+ ohead, err := freezer.Ancients()
+ if err != nil {
+ return 0, err
+ }
+ otail, err := freezer.Tail()
+ if err != nil {
+ return 0, err
+ }
+ // Ensure that the truncation target falls within the specified range.
+ if ohead < nhead || nhead < otail {
+ return 0, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", otail, ohead, nhead)
+ }
+ // Short circuit if nothing to truncate.
+ if ohead == nhead {
+ return 0, nil
+ }
+ // Load the meta objects in range [nhead+1, ohead]
+ blobs, err := rawdb.ReadStateHistoryMetaList(freezer, nhead+1, ohead-nhead)
+ if err != nil {
+ return 0, err
+ }
+ batch := db.NewBatch()
+ for _, blob := range blobs {
+ var m meta
+ if err := m.decode(blob); err != nil {
+ return 0, err
+ }
+ rawdb.DeleteStateID(batch, m.root)
+ }
+ if err := batch.Write(); err != nil {
+ return 0, err
+ }
+ ohead, err = freezer.TruncateHead(nhead)
+ if err != nil {
+ return 0, err
+ }
+ return int(ohead - nhead), nil
+}
+
+// truncateFromTail removes the extra state histories from the tail with the given
+// parameters. It returns the number of items removed from the tail.
+func truncateFromTail(db ethdb.Batcher, freezer *rawdb.ResettableFreezer, ntail uint64) (int, error) {
+ ohead, err := freezer.Ancients()
+ if err != nil {
+ return 0, err
+ }
+ otail, err := freezer.Tail()
+ if err != nil {
+ return 0, err
+ }
+ // Ensure that the truncation target falls within the specified range.
+ if otail > ntail || ntail > ohead {
+ return 0, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", otail, ohead, ntail)
+ }
+ // Short circuit if nothing to truncate.
+ if otail == ntail {
+ return 0, nil
+ }
+ // Load the meta objects in range [otail+1, ntail]
+ blobs, err := rawdb.ReadStateHistoryMetaList(freezer, otail+1, ntail-otail)
+ if err != nil {
+ return 0, err
+ }
+ batch := db.NewBatch()
+ for _, blob := range blobs {
+ var m meta
+ if err := m.decode(blob); err != nil {
+ return 0, err
+ }
+ rawdb.DeleteStateID(batch, m.root)
+ }
+ if err := batch.Write(); err != nil {
+ return 0, err
+ }
+ otail, err = freezer.TruncateTail(ntail)
+ if err != nil {
+ return 0, err
+ }
+ return int(ntail - otail), nil
+}
diff --git a/trie_by_cid/triedb/pathdb/history_test.go b/trie_by_cid/triedb/pathdb/history_test.go
new file mode 100644
index 0000000..70e0088
--- /dev/null
+++ b/trie_by_cid/triedb/pathdb/history_test.go
@@ -0,0 +1,334 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package pathdb
+
+import (
+ "bytes"
+ "fmt"
+ "reflect"
+ "testing"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/testutil"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+// randomStateSet generates a random state change set.
+func randomStateSet(n int) *triestate.Set {
+ var (
+ accounts = make(map[common.Address][]byte)
+ storages = make(map[common.Address]map[common.Hash][]byte)
+ )
+ for i := 0; i < n; i++ {
+ addr := testutil.RandomAddress()
+ storages[addr] = make(map[common.Hash][]byte)
+ for j := 0; j < 3; j++ {
+ v, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(testutil.RandBytes(32)))
+ storages[addr][testutil.RandomHash()] = v
+ }
+ account := generateAccount(types.EmptyRootHash)
+ accounts[addr] = types.SlimAccountRLP(account)
+ }
+ return triestate.New(accounts, storages, nil)
+}
+
+func makeHistory() *history {
+ return newHistory(testutil.RandomHash(), types.EmptyRootHash, 0, randomStateSet(3))
+}
+
+func makeHistories(n int) []*history {
+ var (
+ parent = types.EmptyRootHash
+ result []*history
+ )
+ for i := 0; i < n; i++ {
+ root := testutil.RandomHash()
+ h := newHistory(root, parent, uint64(i), randomStateSet(3))
+ parent = root
+ result = append(result, h)
+ }
+ return result
+}
+
+func TestEncodeDecodeHistory(t *testing.T) {
+ var (
+ m meta
+ dec history
+ obj = makeHistory()
+ )
+ // check if meta data can be correctly encode/decode
+ blob := obj.meta.encode()
+ if err := m.decode(blob); err != nil {
+ t.Fatalf("Failed to decode %v", err)
+ }
+ if !reflect.DeepEqual(&m, obj.meta) {
+ t.Fatal("meta is mismatched")
+ }
+
+ // check if account/storage data can be correctly encode/decode
+ accountData, storageData, accountIndexes, storageIndexes := obj.encode()
+ if err := dec.decode(accountData, storageData, accountIndexes, storageIndexes); err != nil {
+ t.Fatalf("Failed to decode, err: %v", err)
+ }
+ if !compareSet(dec.accounts, obj.accounts) {
+ t.Fatal("account data is mismatched")
+ }
+ if !compareStorages(dec.storages, obj.storages) {
+ t.Fatal("storage data is mismatched")
+ }
+ if !compareList(dec.accountList, obj.accountList) {
+ t.Fatal("account list is mismatched")
+ }
+ if !compareStorageList(dec.storageList, obj.storageList) {
+ t.Fatal("storage list is mismatched")
+ }
+}
+
+func checkHistory(t *testing.T, db ethdb.KeyValueReader, freezer *rawdb.ResettableFreezer, id uint64, root common.Hash, exist bool) {
+ blob := rawdb.ReadStateHistoryMeta(freezer, id)
+ if exist && len(blob) == 0 {
+ t.Fatalf("Failed to load trie history, %d", id)
+ }
+ if !exist && len(blob) != 0 {
+ t.Fatalf("Unexpected trie history, %d", id)
+ }
+ if exist && rawdb.ReadStateID(db, root) == nil {
+ t.Fatalf("Root->ID mapping is not found, %d", id)
+ }
+ if !exist && rawdb.ReadStateID(db, root) != nil {
+ t.Fatalf("Unexpected root->ID mapping, %d", id)
+ }
+}
+
+func checkHistoriesInRange(t *testing.T, db ethdb.KeyValueReader, freezer *rawdb.ResettableFreezer, from, to uint64, roots []common.Hash, exist bool) {
+ for i, j := from, 0; i <= to; i, j = i+1, j+1 {
+ checkHistory(t, db, freezer, i, roots[j], exist)
+ }
+}
+
+func TestTruncateHeadHistory(t *testing.T) {
+ var (
+ roots []common.Hash
+ hs = makeHistories(10)
+ db = rawdb.NewMemoryDatabase()
+ freezer, _ = openFreezer(t.TempDir(), false)
+ )
+ defer freezer.Close()
+
+ for i := 0; i < len(hs); i++ {
+ accountData, storageData, accountIndex, storageIndex := hs[i].encode()
+ rawdb.WriteStateHistory(freezer, uint64(i+1), hs[i].meta.encode(), accountIndex, storageIndex, accountData, storageData)
+ rawdb.WriteStateID(db, hs[i].meta.root, uint64(i+1))
+ roots = append(roots, hs[i].meta.root)
+ }
+ for size := len(hs); size > 0; size-- {
+ pruned, err := truncateFromHead(db, freezer, uint64(size-1))
+ if err != nil {
+ t.Fatalf("Failed to truncate from head %v", err)
+ }
+ if pruned != 1 {
+ t.Error("Unexpected pruned items", "want", 1, "got", pruned)
+ }
+ checkHistoriesInRange(t, db, freezer, uint64(size), uint64(10), roots[size-1:], false)
+ checkHistoriesInRange(t, db, freezer, uint64(1), uint64(size-1), roots[:size-1], true)
+ }
+}
+
+func TestTruncateTailHistory(t *testing.T) {
+ var (
+ roots []common.Hash
+ hs = makeHistories(10)
+ db = rawdb.NewMemoryDatabase()
+ freezer, _ = openFreezer(t.TempDir(), false)
+ )
+ defer freezer.Close()
+
+ for i := 0; i < len(hs); i++ {
+ accountData, storageData, accountIndex, storageIndex := hs[i].encode()
+ rawdb.WriteStateHistory(freezer, uint64(i+1), hs[i].meta.encode(), accountIndex, storageIndex, accountData, storageData)
+ rawdb.WriteStateID(db, hs[i].meta.root, uint64(i+1))
+ roots = append(roots, hs[i].meta.root)
+ }
+ for newTail := 1; newTail < len(hs); newTail++ {
+ pruned, _ := truncateFromTail(db, freezer, uint64(newTail))
+ if pruned != 1 {
+ t.Error("Unexpected pruned items", "want", 1, "got", pruned)
+ }
+ checkHistoriesInRange(t, db, freezer, uint64(1), uint64(newTail), roots[:newTail], false)
+ checkHistoriesInRange(t, db, freezer, uint64(newTail+1), uint64(10), roots[newTail:], true)
+ }
+}
+
+func TestTruncateTailHistories(t *testing.T) {
+ var cases = []struct {
+ limit uint64
+ expPruned int
+ maxPruned uint64
+ minUnpruned uint64
+ empty bool
+ }{
+ {
+ 1, 9, 9, 10, false,
+ },
+ {
+ 0, 10, 10, 0 /* no meaning */, true,
+ },
+ {
+ 10, 0, 0, 1, false,
+ },
+ }
+ for i, c := range cases {
+ var (
+ roots []common.Hash
+ hs = makeHistories(10)
+ db = rawdb.NewMemoryDatabase()
+ freezer, _ = openFreezer(t.TempDir()+fmt.Sprintf("%d", i), false)
+ )
+ defer freezer.Close()
+
+ for i := 0; i < len(hs); i++ {
+ accountData, storageData, accountIndex, storageIndex := hs[i].encode()
+ rawdb.WriteStateHistory(freezer, uint64(i+1), hs[i].meta.encode(), accountIndex, storageIndex, accountData, storageData)
+ rawdb.WriteStateID(db, hs[i].meta.root, uint64(i+1))
+ roots = append(roots, hs[i].meta.root)
+ }
+ pruned, _ := truncateFromTail(db, freezer, uint64(10)-c.limit)
+ if pruned != c.expPruned {
+ t.Error("Unexpected pruned items", "want", c.expPruned, "got", pruned)
+ }
+ if c.empty {
+ checkHistoriesInRange(t, db, freezer, uint64(1), uint64(10), roots, false)
+ } else {
+ tail := 10 - int(c.limit)
+ checkHistoriesInRange(t, db, freezer, uint64(1), c.maxPruned, roots[:tail], false)
+ checkHistoriesInRange(t, db, freezer, c.minUnpruned, uint64(10), roots[tail:], true)
+ }
+ }
+}
+
+func TestTruncateOutOfRange(t *testing.T) {
+ var (
+ hs = makeHistories(10)
+ db = rawdb.NewMemoryDatabase()
+ freezer, _ = openFreezer(t.TempDir(), false)
+ )
+ defer freezer.Close()
+
+ for i := 0; i < len(hs); i++ {
+ accountData, storageData, accountIndex, storageIndex := hs[i].encode()
+ rawdb.WriteStateHistory(freezer, uint64(i+1), hs[i].meta.encode(), accountIndex, storageIndex, accountData, storageData)
+ rawdb.WriteStateID(db, hs[i].meta.root, uint64(i+1))
+ }
+ truncateFromTail(db, freezer, uint64(len(hs)/2))
+
+ // Ensure out-of-range truncations are rejected correctly.
+ head, _ := freezer.Ancients()
+ tail, _ := freezer.Tail()
+
+ cases := []struct {
+ mode int
+ target uint64
+ expErr error
+ }{
+ {0, head, nil}, // nothing to delete
+ {0, head + 1, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", tail, head, head+1)},
+ {0, tail - 1, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", tail, head, tail-1)},
+ {1, tail, nil}, // nothing to delete
+ {1, head + 1, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", tail, head, head+1)},
+ {1, tail - 1, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", tail, head, tail-1)},
+ }
+ for _, c := range cases {
+ var gotErr error
+ if c.mode == 0 {
+ _, gotErr = truncateFromHead(db, freezer, c.target)
+ } else {
+ _, gotErr = truncateFromTail(db, freezer, c.target)
+ }
+ if !reflect.DeepEqual(gotErr, c.expErr) {
+ t.Errorf("Unexpected error, want: %v, got: %v", c.expErr, gotErr)
+ }
+ }
+}
+
+// openFreezer initializes the freezer instance for storing state histories.
+func openFreezer(datadir string, readOnly bool) (*rawdb.ResettableFreezer, error) {
+ return rawdb.NewStateFreezer(datadir, readOnly)
+}
+
+func compareSet[k comparable](a, b map[k][]byte) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for key, valA := range a {
+ valB, ok := b[key]
+ if !ok {
+ return false
+ }
+ if !bytes.Equal(valA, valB) {
+ return false
+ }
+ }
+ return true
+}
+
+func compareList[k comparable](a, b []k) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for i := 0; i < len(a); i++ {
+ if a[i] != b[i] {
+ return false
+ }
+ }
+ return true
+}
+
+func compareStorages(a, b map[common.Address]map[common.Hash][]byte) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for h, subA := range a {
+ subB, ok := b[h]
+ if !ok {
+ return false
+ }
+ if !compareSet(subA, subB) {
+ return false
+ }
+ }
+ return true
+}
+
+func compareStorageList(a, b map[common.Address][]common.Hash) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for h, la := range a {
+ lb, ok := b[h]
+ if !ok {
+ return false
+ }
+ if !compareList(la, lb) {
+ return false
+ }
+ }
+ return true
+}
diff --git a/trie_by_cid/triedb/pathdb/journal.go b/trie_by_cid/triedb/pathdb/journal.go
new file mode 100644
index 0000000..3e37857
--- /dev/null
+++ b/trie_by_cid/triedb/pathdb/journal.go
@@ -0,0 +1,387 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package pathdb
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "time"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+var (
+ errMissJournal = errors.New("journal not found")
+ errMissVersion = errors.New("version not found")
+ errUnexpectedVersion = errors.New("unexpected journal version")
+ errMissDiskRoot = errors.New("disk layer root not found")
+ errUnmatchedJournal = errors.New("unmatched journal")
+)
+
+const journalVersion uint64 = 0
+
+// journalNode represents a trie node persisted in the journal.
+type journalNode struct {
+	Path []byte // Path of the node in the trie
+	Blob []byte // RLP-encoded trie node blob, nil means the node is deleted
+}
+
+// journalNodes represents the list of trie nodes belonging to a single account
+// trie or to the main account trie (Owner == common.Hash{}).
+type journalNodes struct {
+	Owner common.Hash
+	Nodes []journalNode
+}
+
+// journalAccounts represents the list of accounts belonging to a layer.
+// Addresses and Accounts are parallel slices (index i of one corresponds
+// to index i of the other).
+type journalAccounts struct {
+	Addresses []common.Address
+	Accounts  [][]byte
+}
+
+// journalStorage represents a list of storage slots belonging to an account.
+// Hashes and Slots are parallel slices; an empty slot blob marks a deletion.
+type journalStorage struct {
+	Incomplete bool
+	Account    common.Address
+	Hashes     []common.Hash
+	Slots      [][]byte
+}
+
+// loadJournal tries to parse the layer journal from the disk. It returns the
+// topmost layer of the reconstructed tree, or an error (errMissJournal,
+// errMissVersion, errUnexpectedVersion, errMissDiskRoot, errUnmatchedJournal,
+// or a decoding error) if the journal is absent or inconsistent.
+func (db *Database) loadJournal(diskRoot common.Hash) (layer, error) {
+	journal := rawdb.ReadTrieJournal(db.diskdb)
+	if len(journal) == 0 {
+		return nil, errMissJournal
+	}
+	r := rlp.NewStream(bytes.NewReader(journal), 0)
+
+	// Firstly, resolve the first element as the journal version
+	version, err := r.Uint64()
+	if err != nil {
+		return nil, errMissVersion
+	}
+	if version != journalVersion {
+		return nil, fmt.Errorf("%w want %d got %d", errUnexpectedVersion, journalVersion, version)
+	}
+	// Secondly, resolve the disk layer root, ensure it's continuous
+	// with disk layer. Note now we can ensure it's the layer journal
+	// correct version, so we expect everything can be resolved properly.
+	var root common.Hash
+	if err := r.Decode(&root); err != nil {
+		return nil, errMissDiskRoot
+	}
+	// The journal is not matched with persistent state, discard them.
+	// It can happen that geth crashes without persisting the journal.
+	if !bytes.Equal(root.Bytes(), diskRoot.Bytes()) {
+		return nil, fmt.Errorf("%w want %x got %x", errUnmatchedJournal, root, diskRoot)
+	}
+	// Load the disk layer from the journal
+	base, err := db.loadDiskLayer(r)
+	if err != nil {
+		return nil, err
+	}
+	// Load all the diff layers from the journal, stacked on the disk layer
+	head, err := db.loadDiffLayer(base, r)
+	if err != nil {
+		return nil, err
+	}
+	log.Debug("Loaded layer journal", "diskroot", diskRoot, "diffhead", head.rootHash())
+	return head, nil
+}
+
+// loadLayers loads a pre-existing state layer backed by a key-value store.
+// If no usable journal exists, it falls back to a single disk layer built
+// from the persistent state (never returns nil).
+func (db *Database) loadLayers() layer {
+	// Retrieve the root node of persistent state.
+	_, root := rawdb.ReadAccountTrieNode(db.diskdb, nil)
+	root = types.TrieRootHash(root)
+
+	// Load the layers by resolving the journal
+	head, err := db.loadJournal(root)
+	if err == nil {
+		return head
+	}
+	// journal is not matched(or missing) with the persistent state, discard
+	// it. Display log for discarding journal, but try to avoid showing
+	// useless information when the db is created from scratch.
+	if !(root == types.EmptyRootHash && errors.Is(err, errMissJournal)) {
+		log.Info("Failed to load journal, discard it", "err", err)
+	}
+	// Return single layer with persistent state.
+	return newDiskLayer(root, rawdb.ReadPersistentStateID(db.diskdb), db, nil, newNodeBuffer(db.bufferSize, nil, 0))
+}
+
+// loadDiskLayer reads the binary blob from the layer journal, reconstructing
+// a new disk layer on it. The stream must be positioned right after the
+// journal header written by loadJournal.
+func (db *Database) loadDiskLayer(r *rlp.Stream) (layer, error) {
+	// Resolve disk layer root
+	var root common.Hash
+	if err := r.Decode(&root); err != nil {
+		return nil, fmt.Errorf("load disk root: %v", err)
+	}
+	// Resolve the state id of disk layer, it can be different
+	// with the persistent id tracked in disk, the id distance
+	// is the number of transitions aggregated in disk layer.
+	var id uint64
+	if err := r.Decode(&id); err != nil {
+		return nil, fmt.Errorf("load state id: %v", err)
+	}
+	stored := rawdb.ReadPersistentStateID(db.diskdb)
+	if stored > id {
+		// The journaled id must never lag behind what is already persisted.
+		return nil, fmt.Errorf("invalid state id: stored %d resolved %d", stored, id)
+	}
+	// Resolve nodes cached in node buffer
+	var encoded []journalNodes
+	if err := r.Decode(&encoded); err != nil {
+		return nil, fmt.Errorf("load disk nodes: %v", err)
+	}
+	nodes := make(map[common.Hash]map[string]*trienode.Node)
+	for _, entry := range encoded {
+		subset := make(map[string]*trienode.Node)
+		for _, n := range entry.Nodes {
+			// An empty blob marks a deleted node (see journalNode.Blob).
+			if len(n.Blob) > 0 {
+				subset[string(n.Path)] = trienode.New(crypto.Keccak256Hash(n.Blob), n.Blob)
+			} else {
+				subset[string(n.Path)] = trienode.NewDeleted()
+			}
+		}
+		nodes[entry.Owner] = subset
+	}
+	// Calculate the internal state transitions by id difference.
+	base := newDiskLayer(root, id, db, nil, newNodeBuffer(db.bufferSize, nodes, id-stored))
+	return base, nil
+}
+
+// loadDiffLayer reads the next sections of a layer journal, reconstructing a new
+// diff and verifying that it can be linked to the requested parent. It recurses
+// until io.EOF, returning the topmost diff layer (or the parent unchanged if
+// the journal ends immediately).
+func (db *Database) loadDiffLayer(parent layer, r *rlp.Stream) (layer, error) {
+	// Read the next diff journal entry
+	var root common.Hash
+	if err := r.Decode(&root); err != nil {
+		// The first read may fail with EOF, marking the end of the journal
+		if err == io.EOF {
+			return parent, nil
+		}
+		return nil, fmt.Errorf("load diff root: %v", err)
+	}
+	var block uint64
+	if err := r.Decode(&block); err != nil {
+		return nil, fmt.Errorf("load block number: %v", err)
+	}
+	// Read in-memory trie nodes from journal
+	var encoded []journalNodes
+	if err := r.Decode(&encoded); err != nil {
+		return nil, fmt.Errorf("load diff nodes: %v", err)
+	}
+	nodes := make(map[common.Hash]map[string]*trienode.Node)
+	for _, entry := range encoded {
+		subset := make(map[string]*trienode.Node)
+		for _, n := range entry.Nodes {
+			// An empty blob marks a deleted node (see journalNode.Blob).
+			if len(n.Blob) > 0 {
+				subset[string(n.Path)] = trienode.New(crypto.Keccak256Hash(n.Blob), n.Blob)
+			} else {
+				subset[string(n.Path)] = trienode.NewDeleted()
+			}
+		}
+		nodes[entry.Owner] = subset
+	}
+	// Read state changes from journal
+	var (
+		jaccounts  journalAccounts
+		jstorages  []journalStorage
+		accounts   = make(map[common.Address][]byte)
+		storages   = make(map[common.Address]map[common.Hash][]byte)
+		incomplete = make(map[common.Address]struct{})
+	)
+	if err := r.Decode(&jaccounts); err != nil {
+		return nil, fmt.Errorf("load diff accounts: %v", err)
+	}
+	for i, addr := range jaccounts.Addresses {
+		accounts[addr] = jaccounts.Accounts[i]
+	}
+	if err := r.Decode(&jstorages); err != nil {
+		return nil, fmt.Errorf("load diff storages: %v", err)
+	}
+	for _, entry := range jstorages {
+		set := make(map[common.Hash][]byte)
+		for i, h := range entry.Hashes {
+			// An empty slot blob encodes a deleted slot; normalize to nil.
+			if len(entry.Slots[i]) > 0 {
+				set[h] = entry.Slots[i]
+			} else {
+				set[h] = nil
+			}
+		}
+		if entry.Incomplete {
+			incomplete[entry.Account] = struct{}{}
+		}
+		storages[entry.Account] = set
+	}
+	// Link the decoded diff on top of the parent and continue recursively.
+	return db.loadDiffLayer(newDiffLayer(parent, root, parent.stateID()+1, block, nodes, triestate.New(accounts, storages, incomplete)), r)
+}
+
+// journal implements the layer interface, marshaling the un-flushed trie nodes
+// along with layer meta data into provided byte buffer. The write order (root,
+// state id, node set) must match the read order in loadDiskLayer.
+func (dl *diskLayer) journal(w io.Writer) error {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	// Ensure the layer didn't get stale
+	if dl.stale {
+		return errSnapshotStale
+	}
+	// Step one, write the disk root into the journal.
+	if err := rlp.Encode(w, dl.root); err != nil {
+		return err
+	}
+	// Step two, write the corresponding state id into the journal
+	if err := rlp.Encode(w, dl.id); err != nil {
+		return err
+	}
+	// Step three, write all unwritten nodes into the journal
+	nodes := make([]journalNodes, 0, len(dl.buffer.nodes))
+	for owner, subset := range dl.buffer.nodes {
+		entry := journalNodes{Owner: owner}
+		for path, node := range subset {
+			entry.Nodes = append(entry.Nodes, journalNode{Path: []byte(path), Blob: node.Blob})
+		}
+		nodes = append(nodes, entry)
+	}
+	if err := rlp.Encode(w, nodes); err != nil {
+		return err
+	}
+	log.Debug("Journaled pathdb disk layer", "root", dl.root, "nodes", len(dl.buffer.nodes))
+	return nil
+}
+
+// journal implements the layer interface, writing the memory layer contents
+// into a buffer to be stored in the database as the layer journal. Parents are
+// journaled first so layers appear bottom-up, matching loadDiffLayer's reads.
+func (dl *diffLayer) journal(w io.Writer) error {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	// journal the parent first
+	if err := dl.parent.journal(w); err != nil {
+		return err
+	}
+	// Everything below was journaled, persist this layer too
+	if err := rlp.Encode(w, dl.root); err != nil {
+		return err
+	}
+	if err := rlp.Encode(w, dl.block); err != nil {
+		return err
+	}
+	// Write the accumulated trie nodes into buffer
+	nodes := make([]journalNodes, 0, len(dl.nodes))
+	for owner, subset := range dl.nodes {
+		entry := journalNodes{Owner: owner}
+		for path, node := range subset {
+			entry.Nodes = append(entry.Nodes, journalNode{Path: []byte(path), Blob: node.Blob})
+		}
+		nodes = append(nodes, entry)
+	}
+	if err := rlp.Encode(w, nodes); err != nil {
+		return err
+	}
+	// Write the accumulated state changes into buffer (parallel slices,
+	// see journalAccounts).
+	var jacct journalAccounts
+	for addr, account := range dl.states.Accounts {
+		jacct.Addresses = append(jacct.Addresses, addr)
+		jacct.Accounts = append(jacct.Accounts, account)
+	}
+	if err := rlp.Encode(w, jacct); err != nil {
+		return err
+	}
+	storage := make([]journalStorage, 0, len(dl.states.Storages))
+	for addr, slots := range dl.states.Storages {
+		entry := journalStorage{Account: addr}
+		if _, ok := dl.states.Incomplete[addr]; ok {
+			entry.Incomplete = true
+		}
+		for slotHash, slot := range slots {
+			entry.Hashes = append(entry.Hashes, slotHash)
+			entry.Slots = append(entry.Slots, slot)
+		}
+		storage = append(storage, entry)
+	}
+	if err := rlp.Encode(w, storage); err != nil {
+		return err
+	}
+	log.Debug("Journaled pathdb diff layer", "root", dl.root, "parent", dl.parent.rootHash(), "id", dl.stateID(), "block", dl.block, "nodes", len(dl.nodes))
+	return nil
+}
+
+// Journal commits an entire diff hierarchy to disk into a single journal entry.
+// This is meant to be used during shutdown to persist the layer without
+// flattening everything down (bad for reorgs). And this function will mark the
+// database as read-only to prevent all following mutation to disk.
+func (db *Database) Journal(root common.Hash) error {
+	// Retrieve the head layer to journal from.
+	l := db.tree.get(root)
+	if l == nil {
+		return fmt.Errorf("triedb layer [%#x] missing", root)
+	}
+	disk := db.tree.bottom()
+	if l, ok := l.(*diffLayer); ok {
+		log.Info("Persisting dirty state to disk", "head", l.block, "root", root, "layers", l.id-disk.id+disk.buffer.layers)
+	} else { // disk layer only on noop runs (likely) or deep reorgs (unlikely)
+		log.Info("Persisting dirty state to disk", "root", root, "layers", disk.buffer.layers)
+	}
+	start := time.Now()
+
+	// Run the journaling
+	db.lock.Lock()
+	defer db.lock.Unlock()
+
+	// Short circuit if the database is in read only mode.
+	if db.readOnly {
+		return errDatabaseReadOnly
+	}
+	// Firstly write out the metadata of journal
+	journal := new(bytes.Buffer)
+	if err := rlp.Encode(journal, journalVersion); err != nil {
+		return err
+	}
+	// The stored state in disk might be empty, convert the
+	// root to emptyRoot in this case.
+	_, diskroot := rawdb.ReadAccountTrieNode(db.diskdb, nil)
+	diskroot = types.TrieRootHash(diskroot)
+
+	// Secondly write out the state root in disk, ensure all layers
+	// on top are continuous with disk.
+	if err := rlp.Encode(journal, diskroot); err != nil {
+		return err
+	}
+	// Finally write out the journal of each layer in reverse order.
+	if err := l.journal(journal); err != nil {
+		return err
+	}
+	// Store the journal into the database and return
+	rawdb.WriteTrieJournal(db.diskdb, journal.Bytes())
+
+	// Set the db in read only mode to reject all following mutations
+	db.readOnly = true
+	log.Info("Persisted dirty state to disk", "size", common.StorageSize(journal.Len()), "elapsed", common.PrettyDuration(time.Since(start)))
+	return nil
+}
diff --git a/trie_by_cid/triedb/pathdb/layertree.go b/trie_by_cid/triedb/pathdb/layertree.go
new file mode 100644
index 0000000..722caff
--- /dev/null
+++ b/trie_by_cid/triedb/pathdb/layertree.go
@@ -0,0 +1,214 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package pathdb
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/types"
+)
+
+// layerTree is a group of state layers identified by the state root.
+// This structure defines a few basic operations for manipulating
+// state layers linked with each other in a tree structure. It's
+// thread-safe to use. However, callers need to ensure the thread-safety
+// of the referenced layer by themselves.
+type layerTree struct {
+	lock   sync.RWMutex          // protects layers map
+	layers map[common.Hash]layer // all known layers keyed by state root
+}
+
+// newLayerTree constructs the layerTree with the given head layer, linking
+// in all of head's ancestors via reset.
+func newLayerTree(head layer) *layerTree {
+	tree := new(layerTree)
+	tree.reset(head)
+	return tree
+}
+
+// reset initializes the layerTree by the given head layer.
+// All the ancestors will be iterated out and linked in the tree,
+// replacing any previously tracked layers.
+func (tree *layerTree) reset(head layer) {
+	tree.lock.Lock()
+	defer tree.lock.Unlock()
+
+	var layers = make(map[common.Hash]layer)
+	for head != nil {
+		layers[head.rootHash()] = head
+		head = head.parentLayer()
+	}
+	tree.layers = layers
+}
+
+// get retrieves a layer belonging to the given state root, or nil if no
+// such layer is tracked. The root is normalized via types.TrieRootHash.
+func (tree *layerTree) get(root common.Hash) layer {
+	tree.lock.RLock()
+	defer tree.lock.RUnlock()
+
+	return tree.layers[types.TrieRootHash(root)]
+}
+
+// forEach iterates the stored layers inside and applies the
+// given callback on them. Iteration order is unspecified (map order).
+func (tree *layerTree) forEach(onLayer func(layer)) {
+	tree.lock.RLock()
+	defer tree.lock.RUnlock()
+
+	for _, layer := range tree.layers {
+		onLayer(layer)
+	}
+}
+
+// len returns the number of layers cached, including the disk layer.
+func (tree *layerTree) len() int {
+	tree.lock.RLock()
+	defer tree.lock.RUnlock()
+
+	return len(tree.layers)
+}
+
+// add inserts a new layer into the tree if it can be linked to an existing old parent.
+// The new layer's state id is parent.stateID()+1.
+func (tree *layerTree) add(root common.Hash, parentRoot common.Hash, block uint64, nodes *trienode.MergedNodeSet, states *triestate.Set) error {
+	// Reject noop updates to avoid self-loops. This is a special case that can
+	// happen for clique networks and proof-of-stake networks where empty blocks
+	// don't modify the state (0 block subsidy).
+	//
+	// Although we could silently ignore this internally, it should be the caller's
+	// responsibility to avoid even attempting to insert such a layer.
+	root, parentRoot = types.TrieRootHash(root), types.TrieRootHash(parentRoot)
+	if root == parentRoot {
+		return errors.New("layer cycle")
+	}
+	parent := tree.get(parentRoot)
+	if parent == nil {
+		return fmt.Errorf("triedb parent [%#x] layer missing", parentRoot)
+	}
+	l := parent.update(root, parent.stateID()+1, block, nodes.Flatten(), states)
+
+	tree.lock.Lock()
+	tree.layers[l.rootHash()] = l
+	tree.lock.Unlock()
+	return nil
+}
+
+// cap traverses downwards the diff tree until the number of allowed diff layers
+// are crossed. All diffs beyond the permitted number are flattened downwards.
+// With layers == 0 the entire tree is collapsed into a single disk layer.
+func (tree *layerTree) cap(root common.Hash, layers int) error {
+	// Retrieve the head layer to cap from
+	root = types.TrieRootHash(root)
+	l := tree.get(root)
+	if l == nil {
+		return fmt.Errorf("triedb layer [%#x] missing", root)
+	}
+	diff, ok := l.(*diffLayer)
+	if !ok {
+		return fmt.Errorf("triedb layer [%#x] is disk layer", root)
+	}
+	tree.lock.Lock()
+	defer tree.lock.Unlock()
+
+	// If full commit was requested, flatten the diffs and merge onto disk
+	if layers == 0 {
+		base, err := diff.persist(true)
+		if err != nil {
+			return err
+		}
+		// Replace the entire layer tree with the flat base
+		tree.layers = map[common.Hash]layer{base.rootHash(): base}
+		return nil
+	}
+	// Dive until we run out of layers or reach the persistent database
+	for i := 0; i < layers-1; i++ {
+		// If we still have diff layers below, continue down
+		if parent, ok := diff.parentLayer().(*diffLayer); ok {
+			diff = parent
+		} else {
+			// Diff stack too shallow, return without modifications
+			return nil
+		}
+	}
+	// We're out of layers, flatten anything below, stopping if it's the disk or if
+	// the memory limit is not yet exceeded.
+	switch parent := diff.parentLayer().(type) {
+	case *diskLayer:
+		return nil
+
+	case *diffLayer:
+		// Hold the lock to prevent any read operations until the new
+		// parent is linked correctly.
+		diff.lock.Lock()
+
+		base, err := parent.persist(false)
+		if err != nil {
+			diff.lock.Unlock()
+			return err
+		}
+		tree.layers[base.rootHash()] = base
+		diff.parent = base
+
+		diff.lock.Unlock()
+
+	default:
+		panic(fmt.Sprintf("unknown data layer in triedb: %T", parent))
+	}
+	// Remove any layer that is stale or links into a stale layer.
+	// First build a child index, then recursively drop the subtree
+	// rooted at every stale disk layer.
+	children := make(map[common.Hash][]common.Hash)
+	for root, layer := range tree.layers {
+		if dl, ok := layer.(*diffLayer); ok {
+			parent := dl.parentLayer().rootHash()
+			children[parent] = append(children[parent], root)
+		}
+	}
+	var remove func(root common.Hash)
+	remove = func(root common.Hash) {
+		delete(tree.layers, root)
+		for _, child := range children[root] {
+			remove(child)
+		}
+		delete(children, root)
+	}
+	for root, layer := range tree.layers {
+		if dl, ok := layer.(*diskLayer); ok && dl.isStale() {
+			remove(root)
+		}
+	}
+	return nil
+}
+
+// bottom returns the bottom-most disk layer in this tree, found by walking
+// parent links from any entry point (all paths end at the same disk layer).
+func (tree *layerTree) bottom() *diskLayer {
+	tree.lock.RLock()
+	defer tree.lock.RUnlock()
+
+	if len(tree.layers) == 0 {
+		return nil // Shouldn't happen, empty tree
+	}
+	// pick a random one as the entry point
+	var current layer
+	for _, layer := range tree.layers {
+		current = layer
+		break
+	}
+	for current.parentLayer() != nil {
+		current = current.parentLayer()
+	}
+	return current.(*diskLayer)
+}
diff --git a/trie_by_cid/triedb/pathdb/metrics.go b/trie_by_cid/triedb/pathdb/metrics.go
new file mode 100644
index 0000000..9e2b1dc
--- /dev/null
+++ b/trie_by_cid/triedb/pathdb/metrics.go
@@ -0,0 +1,50 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package pathdb
+
+import "github.com/ethereum/go-ethereum/metrics"
+
+// Metrics instrumenting the pathdb clean/dirty caches, commit/flush path,
+// diff layers and state history. Registered under the "pathdb/" namespace.
+var (
+	cleanHitMeter   = metrics.NewRegisteredMeter("pathdb/clean/hit", nil)
+	cleanMissMeter  = metrics.NewRegisteredMeter("pathdb/clean/miss", nil)
+	cleanReadMeter  = metrics.NewRegisteredMeter("pathdb/clean/read", nil)
+	cleanWriteMeter = metrics.NewRegisteredMeter("pathdb/clean/write", nil)
+
+	dirtyHitMeter         = metrics.NewRegisteredMeter("pathdb/dirty/hit", nil)
+	dirtyMissMeter        = metrics.NewRegisteredMeter("pathdb/dirty/miss", nil)
+	dirtyReadMeter        = metrics.NewRegisteredMeter("pathdb/dirty/read", nil)
+	dirtyWriteMeter       = metrics.NewRegisteredMeter("pathdb/dirty/write", nil)
+	dirtyNodeHitDepthHist = metrics.NewRegisteredHistogram("pathdb/dirty/depth", nil, metrics.NewExpDecaySample(1028, 0.015))
+
+	// "false" meters count unexpected-node (hash mismatch) events per tier.
+	cleanFalseMeter = metrics.NewRegisteredMeter("pathdb/clean/false", nil)
+	dirtyFalseMeter = metrics.NewRegisteredMeter("pathdb/dirty/false", nil)
+	diskFalseMeter  = metrics.NewRegisteredMeter("pathdb/disk/false", nil)
+
+	commitTimeTimer  = metrics.NewRegisteredTimer("pathdb/commit/time", nil)
+	commitNodesMeter = metrics.NewRegisteredMeter("pathdb/commit/nodes", nil)
+	commitBytesMeter = metrics.NewRegisteredMeter("pathdb/commit/bytes", nil)
+
+	gcNodesMeter = metrics.NewRegisteredMeter("pathdb/gc/nodes", nil)
+	gcBytesMeter = metrics.NewRegisteredMeter("pathdb/gc/bytes", nil)
+
+	diffLayerBytesMeter = metrics.NewRegisteredMeter("pathdb/diff/bytes", nil)
+	diffLayerNodesMeter = metrics.NewRegisteredMeter("pathdb/diff/nodes", nil)
+
+	historyBuildTimeMeter  = metrics.NewRegisteredTimer("pathdb/history/time", nil)
+	historyDataBytesMeter  = metrics.NewRegisteredMeter("pathdb/history/bytes/data", nil)
+	historyIndexBytesMeter = metrics.NewRegisteredMeter("pathdb/history/bytes/index", nil)
+)
diff --git a/trie_by_cid/triedb/pathdb/nodebuffer.go b/trie_by_cid/triedb/pathdb/nodebuffer.go
new file mode 100644
index 0000000..39234e1
--- /dev/null
+++ b/trie_by_cid/triedb/pathdb/nodebuffer.go
@@ -0,0 +1,275 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package pathdb
+
+import (
+ "fmt"
+ "time"
+
+ "github.com/VictoriaMetrics/fastcache"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/log"
+)
+
+// nodebuffer is a collection of modified trie nodes to aggregate the disk
+// write. The content of the nodebuffer must be checked before diving into
+// disk (since it basically is not-yet-written data).
+type nodebuffer struct {
+	layers uint64                                    // The number of diff layers aggregated inside
+	size   uint64                                    // The size of aggregated writes (blob + path bytes)
+	limit  uint64                                    // The maximum memory allowance in bytes
+	nodes  map[common.Hash]map[string]*trienode.Node // The dirty node set, mapped by owner and path
+}
+
+// newNodeBuffer initializes the node buffer with the provided nodes. A nil
+// nodes map is replaced by an empty one; the initial size is computed from
+// the blob and path lengths of the supplied nodes.
+func newNodeBuffer(limit int, nodes map[common.Hash]map[string]*trienode.Node, layers uint64) *nodebuffer {
+	if nodes == nil {
+		nodes = make(map[common.Hash]map[string]*trienode.Node)
+	}
+	var size uint64
+	for _, subset := range nodes {
+		for path, n := range subset {
+			size += uint64(len(n.Blob) + len(path))
+		}
+	}
+	return &nodebuffer{
+		layers: layers,
+		nodes:  nodes,
+		size:   size,
+		limit:  uint64(limit),
+	}
+}
+
+// node retrieves the trie node with given node info. It returns (nil, nil)
+// on a buffer miss, and an error if a node is found at the path but its
+// hash does not match the expected one.
+func (b *nodebuffer) node(owner common.Hash, path []byte, hash common.Hash) (*trienode.Node, error) {
+	subset, ok := b.nodes[owner]
+	if !ok {
+		return nil, nil
+	}
+	n, ok := subset[string(path)]
+	if !ok {
+		return nil, nil
+	}
+	if n.Hash != hash {
+		dirtyFalseMeter.Mark(1)
+		log.Error("Unexpected trie node in node buffer", "owner", owner, "path", path, "expect", hash, "got", n.Hash)
+		return nil, newUnexpectedNodeError("dirty", hash, n.Hash, owner, path, n.Blob)
+	}
+	return n, nil
+}
+
+// commit merges the dirty nodes into the nodebuffer. This operation won't take
+// the ownership of the nodes map which belongs to the bottom-most diff layer.
+// It will just hold the node references from the given map which are safe to
+// copy. It returns the receiver for call chaining.
+func (b *nodebuffer) commit(nodes map[common.Hash]map[string]*trienode.Node) *nodebuffer {
+	var (
+		delta         int64 // net size change of the buffer
+		overwrite     int64 // number of nodes replaced in place
+		overwriteSize int64 // bytes released by those replacements
+	)
+	for owner, subset := range nodes {
+		current, exist := b.nodes[owner]
+		if !exist {
+			// Allocate a new map for the subset instead of claiming it directly
+			// from the passed map to avoid potential concurrent map read/write.
+			// The nodes belong to original diff layer are still accessible even
+			// after merging, thus the ownership of nodes map should still belong
+			// to original layer and any mutation on it should be prevented.
+			current = make(map[string]*trienode.Node)
+			for path, n := range subset {
+				current[path] = n
+				delta += int64(len(n.Blob) + len(path))
+			}
+			b.nodes[owner] = current
+			continue
+		}
+		for path, n := range subset {
+			if orig, exist := current[path]; !exist {
+				delta += int64(len(n.Blob) + len(path))
+			} else {
+				delta += int64(len(n.Blob) - len(orig.Blob))
+				overwrite++
+				overwriteSize += int64(len(orig.Blob) + len(path))
+			}
+			current[path] = n
+		}
+		b.nodes[owner] = current
+	}
+	b.updateSize(delta)
+	b.layers++
+	gcNodesMeter.Mark(overwrite)
+	gcBytesMeter.Mark(overwriteSize)
+	return b
+}
+
+// revert is the reverse operation of commit. It also merges the provided nodes
+// into the nodebuffer, the difference is that the provided node set should
+// revert the changes made by the last state transition. Returns
+// errStateUnrecoverable when no transition is left to revert.
+func (b *nodebuffer) revert(db ethdb.KeyValueReader, nodes map[common.Hash]map[string]*trienode.Node) error {
+	// Short circuit if no embedded state transition to revert.
+	if b.layers == 0 {
+		return errStateUnrecoverable
+	}
+	b.layers--
+
+	// Reset the entire buffer if only a single transition left.
+	if b.layers == 0 {
+		b.reset()
+		return nil
+	}
+	var delta int64
+	for owner, subset := range nodes {
+		current, ok := b.nodes[owner]
+		if !ok {
+			panic(fmt.Sprintf("non-existent subset (%x)", owner))
+		}
+		for path, n := range subset {
+			orig, ok := current[path]
+			if !ok {
+				// There is a special case in MPT that one child is removed from
+				// a fullNode which only has two children, and then a new child
+				// with different position is immediately inserted into the fullNode.
+				// In this case, the clean child of the fullNode will also be
+				// marked as dirty because of node collapse and expansion.
+				//
+				// In case of database rollback, don't panic if this "clean"
+				// node occurs which is not present in buffer.
+				var nhash common.Hash
+				if owner == (common.Hash{}) {
+					_, nhash = rawdb.ReadAccountTrieNode(db, []byte(path))
+				} else {
+					_, nhash = rawdb.ReadStorageTrieNode(db, owner, []byte(path))
+				}
+				// Ignore the clean node in the case described above.
+				if nhash == n.Hash {
+					continue
+				}
+				panic(fmt.Sprintf("non-existent node (%x %v) blob: %v", owner, path, crypto.Keccak256Hash(n.Blob).Hex()))
+			}
+			current[path] = n
+			delta += int64(len(n.Blob)) - int64(len(orig.Blob))
+		}
+	}
+	b.updateSize(delta)
+	return nil
+}
+
+// updateSize updates the total cache size by the given delta. A delta that
+// would drive the size negative indicates a bookkeeping bug; the size is
+// clamped to zero and the anomaly logged instead of underflowing the uint64.
+func (b *nodebuffer) updateSize(delta int64) {
+	size := int64(b.size) + delta
+	if size >= 0 {
+		b.size = uint64(size)
+		return
+	}
+	s := b.size
+	b.size = 0
+	log.Error("Invalid pathdb buffer size", "prev", common.StorageSize(s), "delta", common.StorageSize(delta))
+}
+
+// reset cleans up the disk cache, dropping all aggregated layers and nodes.
+func (b *nodebuffer) reset() {
+	b.layers = 0
+	b.size = 0
+	b.nodes = make(map[common.Hash]map[string]*trienode.Node)
+}
+
+// empty returns an indicator if nodebuffer contains any state transition inside.
+func (b *nodebuffer) empty() bool {
+	return b.layers == 0
+}
+
+// setSize sets the buffer size to the provided number, and invokes a flush
+// operation if the current memory usage exceeds the new limit.
+func (b *nodebuffer) setSize(size int, db ethdb.KeyValueStore, clean *fastcache.Cache, id uint64) error {
+	b.limit = uint64(size)
+	// flush is a no-op unless the existing content exceeds the new limit.
+	return b.flush(db, clean, id, false)
+}
+
+// flush persists the in-memory dirty trie node into the disk if the configured
+// memory threshold is reached (or force is set). Note, all data must be
+// written atomically; the buffer is reset afterwards.
+func (b *nodebuffer) flush(db ethdb.KeyValueStore, clean *fastcache.Cache, id uint64, force bool) error {
+	if b.size <= b.limit && !force {
+		return nil
+	}
+	// Ensure the target state id is aligned with the internal counter.
+	head := rawdb.ReadPersistentStateID(db)
+	if head+b.layers != id {
+		return fmt.Errorf("buffer layers (%d) cannot be applied on top of persisted state id (%d) to reach requested state id (%d)", b.layers, head, id)
+	}
+	var (
+		start = time.Now()
+		batch = db.NewBatchWithSize(int(b.size))
+	)
+	nodes := writeNodes(batch, b.nodes, clean)
+	rawdb.WritePersistentStateID(batch, id)
+
+	// Flush all mutations in a single batch
+	size := batch.ValueSize()
+	if err := batch.Write(); err != nil {
+		return err
+	}
+	commitBytesMeter.Mark(int64(size))
+	commitNodesMeter.Mark(int64(nodes))
+	commitTimeTimer.UpdateSince(start)
+	log.Debug("Persisted pathdb nodes", "nodes", len(b.nodes), "bytes", common.StorageSize(size), "elapsed", common.PrettyDuration(time.Since(start)))
+	b.reset()
+	return nil
+}
+
+// writeNodes writes the trie nodes into the provided database batch.
+// Note this function will also inject all the newly written nodes
+// into clean cache (and evict deleted ones). It returns the total
+// number of nodes processed.
+func writeNodes(batch ethdb.Batch, nodes map[common.Hash]map[string]*trienode.Node, clean *fastcache.Cache) (total int) {
+	for owner, subset := range nodes {
+		for path, n := range subset {
+			if n.IsDeleted() {
+				// Zero owner denotes the main account trie.
+				if owner == (common.Hash{}) {
+					rawdb.DeleteAccountTrieNode(batch, []byte(path))
+				} else {
+					rawdb.DeleteStorageTrieNode(batch, owner, []byte(path))
+				}
+				if clean != nil {
+					clean.Del(cacheKey(owner, []byte(path)))
+				}
+			} else {
+				if owner == (common.Hash{}) {
+					rawdb.WriteAccountTrieNode(batch, []byte(path), n.Blob)
+				} else {
+					rawdb.WriteStorageTrieNode(batch, owner, []byte(path), n.Blob)
+				}
+				if clean != nil {
+					clean.Set(cacheKey(owner, []byte(path)), n.Blob)
+				}
+			}
+		}
+		total += len(subset)
+	}
+	return total
+}
+
+// cacheKey constructs the unique key of clean cache: the bare path for the
+// account trie (zero owner), owner-bytes prefixed path for storage tries.
+func cacheKey(owner common.Hash, path []byte) []byte {
+	if owner == (common.Hash{}) {
+		return path
+	}
+	return append(owner.Bytes(), path...)
+}
diff --git a/trie_by_cid/triedb/pathdb/testutils.go b/trie_by_cid/triedb/pathdb/testutils.go
new file mode 100644
index 0000000..d27b4a9
--- /dev/null
+++ b/trie_by_cid/triedb/pathdb/testutils.go
@@ -0,0 +1,157 @@
+// Copyright 2023 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package pathdb
+
+import (
+ "bytes"
+ "fmt"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/crypto"
+ "golang.org/x/exp/slices"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
+)
+
// testHasher is a test utility for computing root hash of a batch of state
// elements. The hash algorithm is to sort all the elements in lexicographical
// order, concat the key and value in turn, and perform hash calculation on
// the concatenated bytes. Except the root hash, a nodeset will be returned
// once Commit is called, which contains all the changes made to hasher.
type testHasher struct {
	owner   common.Hash            // owner identifier (empty hash for the account trie)
	root    common.Hash            // original root the clean states hash to
	dirties map[common.Hash][]byte // dirty states; nil value marks a deletion
	cleans  map[common.Hash][]byte // clean states as supplied at construction
}
+
+// newTestHasher constructs a hasher object with provided states.
+func newTestHasher(owner common.Hash, root common.Hash, cleans map[common.Hash][]byte) (*testHasher, error) {
+ if cleans == nil {
+ cleans = make(map[common.Hash][]byte)
+ }
+ if got, _ := hash(cleans); got != root {
+ return nil, fmt.Errorf("state root mismatched, want: %x, got: %x", root, got)
+ }
+ return &testHasher{
+ owner: owner,
+ root: root,
+ dirties: make(map[common.Hash][]byte),
+ cleans: cleans,
+ }, nil
+}
+
+// Get returns the value for key stored in the trie.
+func (h *testHasher) Get(key []byte) ([]byte, error) {
+ hash := common.BytesToHash(key)
+ val, ok := h.dirties[hash]
+ if ok {
+ return val, nil
+ }
+ return h.cleans[hash], nil
+}
+
+// Update associates key with value in the trie.
+func (h *testHasher) Update(key, value []byte) error {
+ h.dirties[common.BytesToHash(key)] = common.CopyBytes(value)
+ return nil
+}
+
+// Delete removes any existing value for key from the trie.
+func (h *testHasher) Delete(key []byte) error {
+ h.dirties[common.BytesToHash(key)] = nil
+ return nil
+}
+
+// Commit computes the new hash of the states and returns the set with all
+// state changes.
+func (h *testHasher) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, error) {
+ var (
+ nodes = make(map[common.Hash][]byte)
+ set = trienode.NewNodeSet(h.owner)
+ )
+ for hash, val := range h.cleans {
+ nodes[hash] = val
+ }
+ for hash, val := range h.dirties {
+ nodes[hash] = val
+ if bytes.Equal(val, h.cleans[hash]) {
+ continue
+ }
+ if len(val) == 0 {
+ set.AddNode(hash.Bytes(), trienode.NewDeleted())
+ } else {
+ set.AddNode(hash.Bytes(), trienode.New(crypto.Keccak256Hash(val), val))
+ }
+ }
+ root, blob := hash(nodes)
+
+ // Include the dirty root node as well.
+ if root != types.EmptyRootHash && root != h.root {
+ set.AddNode(nil, trienode.New(root, blob))
+ }
+ if root == types.EmptyRootHash && h.root != types.EmptyRootHash {
+ set.AddNode(nil, trienode.NewDeleted())
+ }
+ return root, set, nil
+}
+
+// hash performs the hash computation upon the provided states.
+func hash(states map[common.Hash][]byte) (common.Hash, []byte) {
+ var hs []common.Hash
+ for hash := range states {
+ hs = append(hs, hash)
+ }
+ slices.SortFunc(hs, common.Hash.Cmp)
+
+ var input []byte
+ for _, hash := range hs {
+ if len(states[hash]) == 0 {
+ continue
+ }
+ input = append(input, hash.Bytes()...)
+ input = append(input, states[hash]...)
+ }
+ if len(input) == 0 {
+ return types.EmptyRootHash, nil
+ }
+ return crypto.Keccak256Hash(input), input
+}
+
// hashLoader is a triestate.TrieLoader backed by in-memory state maps,
// used in tests to open testHasher "tries" for the account trie and
// per-account storage tries.
type hashLoader struct {
	accounts map[common.Hash][]byte                 // account hash -> account state blob
	storages map[common.Hash]map[common.Hash][]byte // account hash -> (slot hash -> value)
}
+
+func newHashLoader(accounts map[common.Hash][]byte, storages map[common.Hash]map[common.Hash][]byte) *hashLoader {
+ return &hashLoader{
+ accounts: accounts,
+ storages: storages,
+ }
+}
+
+// OpenTrie opens the main account trie.
+func (l *hashLoader) OpenTrie(root common.Hash) (triestate.Trie, error) {
+ return newTestHasher(common.Hash{}, root, l.accounts)
+}
+
+// OpenStorageTrie opens the storage trie of an account.
+func (l *hashLoader) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (triestate.Trie, error) {
+ return newTestHasher(addrHash, root, l.storages[addrHash])
+}
diff --git a/trie_by_cid/trie/preimages.go b/trie_by_cid/triedb/preimages.go
similarity index 85%
rename from trie_by_cid/trie/preimages.go
rename to trie_by_cid/triedb/preimages.go
index a6359ca..a538491 100644
--- a/trie_by_cid/trie/preimages.go
+++ b/trie_by_cid/triedb/preimages.go
@@ -14,7 +14,7 @@
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-package trie
+package triedb
import (
"sync"
@@ -69,6 +69,23 @@ func (store *preimageStore) preimage(hash common.Hash) []byte {
return rawdb.ReadPreimage(store.disk, hash)
}
+// commit flushes the cached preimages into the disk.
+func (store *preimageStore) commit(force bool) error {
+ store.lock.Lock()
+ defer store.lock.Unlock()
+
+ if store.preimagesSize <= 4*1024*1024 && !force {
+ return nil
+ }
+ batch := store.disk.NewBatch()
+ rawdb.WritePreimages(batch, store.preimages)
+ if err := batch.Write(); err != nil {
+ return err
+ }
+ store.preimages, store.preimagesSize = make(map[common.Hash][]byte), 0
+ return nil
+}
+
// size returns the current storage size of accumulated preimages.
func (store *preimageStore) size() common.StorageSize {
store.lock.RLock()