Geth 1.13 (Deneb/Cancun) update #5

Merged
roysc merged 14 commits from update-geth-1.13 into v5 2024-05-29 10:00:13 +00:00
58 changed files with 8857 additions and 2828 deletions
Showing only changes of commit c49947243f - Show all commits

View File

@ -4,10 +4,10 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"math/big"
"github.com/VictoriaMetrics/fastcache" "github.com/VictoriaMetrics/fastcache"
lru "github.com/hashicorp/golang-lru" lru "github.com/hashicorp/golang-lru"
"github.com/holiman/uint256"
"github.com/cerc-io/plugeth-statediff/indexer/ipld" "github.com/cerc-io/plugeth-statediff/indexer/ipld"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
@ -99,8 +99,10 @@ func (sd *cachingDB) StateAccount(addressHash, blockHash common.Hash) (*types.St
// TODO: check expected behavior for deleted/non existing accounts // TODO: check expected behavior for deleted/non existing accounts
return nil, nil return nil, nil
} }
bal := new(big.Int) bal, err := uint256.FromDecimal(res.Balance)
bal.SetString(res.Balance, 10) if err != nil {
return nil, err
}
return &types.StateAccount{ return &types.StateAccount{
Nonce: res.Nonce, Nonce: res.Nonce,
Balance: bal, Balance: bal,

View File

@ -1,9 +1,8 @@
package state package state
import ( import (
"math/big"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/holiman/uint256"
) )
// journalEntry is a modification entry in the state change journal that can be // journalEntry is a modification entry in the state change journal that can be
@ -74,19 +73,26 @@ type (
account *common.Address account *common.Address
} }
resetObjectChange struct { resetObjectChange struct {
account *common.Address
prev *stateObject prev *stateObject
prevdestruct bool prevdestruct bool
prevAccount []byte
prevStorage map[common.Hash][]byte
prevAccountOriginExist bool
prevAccountOrigin []byte
prevStorageOrigin map[common.Hash][]byte
} }
suicideChange struct { selfDestructChange struct {
account *common.Address account *common.Address
prev bool // whether account had already suicided prev bool // whether account had already self-destructed
prevbalance *big.Int prevbalance *uint256.Int
} }
// Changes to individual accounts. // Changes to individual accounts.
balanceChange struct { balanceChange struct {
account *common.Address account *common.Address
prev *big.Int prev *uint256.Int
} }
nonceChange struct { nonceChange struct {
account *common.Address account *common.Address
@ -141,21 +147,36 @@ func (ch createObjectChange) dirtied() *common.Address {
func (ch resetObjectChange) revert(s *StateDB) { func (ch resetObjectChange) revert(s *StateDB) {
s.setStateObject(ch.prev) s.setStateObject(ch.prev)
if !ch.prevdestruct {
delete(s.stateObjectsDestruct, ch.prev.address)
}
if ch.prevAccount != nil {
s.accounts[ch.prev.addrHash] = ch.prevAccount
}
if ch.prevStorage != nil {
s.storages[ch.prev.addrHash] = ch.prevStorage
}
if ch.prevAccountOriginExist {
s.accountsOrigin[ch.prev.address] = ch.prevAccountOrigin
}
if ch.prevStorageOrigin != nil {
s.storagesOrigin[ch.prev.address] = ch.prevStorageOrigin
}
} }
func (ch resetObjectChange) dirtied() *common.Address { func (ch resetObjectChange) dirtied() *common.Address {
return nil return ch.account
} }
func (ch suicideChange) revert(s *StateDB) { func (ch selfDestructChange) revert(s *StateDB) {
obj := s.getStateObject(*ch.account) obj := s.getStateObject(*ch.account)
if obj != nil { if obj != nil {
obj.suicided = ch.prev obj.selfDestructed = ch.prev
obj.setBalance(ch.prevbalance) obj.setBalance(ch.prevbalance)
} }
} }
func (ch suicideChange) dirtied() *common.Address { func (ch selfDestructChange) dirtied() *common.Address {
return ch.account return ch.account
} }

View File

@ -3,7 +3,6 @@ package state
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"math/big"
"time" "time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
@ -11,15 +10,7 @@ import (
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) "github.com/holiman/uint256"
var (
// emptyRoot is the known root hash of an empty trie.
// this is calculated as: emptyRoot = crypto.Keccak256(rlp.Encode([][]byte{}))
// that is, the keccak356 hash of the rlp encoding of an empty trie node (empty byte slice array)
emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
// emptyCodeHash is the CodeHash for an EOA, for an account without contract code deployed
emptyCodeHash = crypto.Keccak256(nil)
) )
type Code []byte type Code []byte
@ -53,72 +44,66 @@ func (s Storage) Copy() Storage {
// First you need to obtain a state object. // First you need to obtain a state object.
// Account values can be accessed and modified through the object. // Account values can be accessed and modified through the object.
type stateObject struct { type stateObject struct {
db *StateDB
address common.Address address common.Address
addrHash common.Hash // hash of ethereum address of the account addrHash common.Hash // hash of ethereum address of the account
blockHash common.Hash // hash of the block this state object exists at or is being applied on top of blockHash common.Hash // hash of the block this state object exists at or is being applied on top of
origin *types.StateAccount // Account original data without any change applied, nil means it was not existent
data types.StateAccount data types.StateAccount
db *StateDB
// DB error.
// State objects are used by the consensus core and VM which are
// unable to deal with database-level errors. Any error that occurs
// during a database read is memoized here and will eventually be returned
// by StateDB.Commit.
dbErr error
// Write caches. // Write caches.
code Code // contract bytecode, which gets set when code is loaded code Code // contract bytecode, which gets set when code is loaded
originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction originStorage Storage // Storage cache of original entries to dedup rewrites
pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block
dirtyStorage Storage // Storage entries that have been modified in the current transaction execution dirtyStorage Storage // Storage entries that have been modified in the current transaction execution, reset for every transaction
fakeStorage Storage // Fake storage which constructed by caller for debugging purpose.
// Cache flags. // Cache flags.
// When an object is marked suicided it will be delete from the trie
// during the "update" phase of the state transition.
dirtyCode bool // true if the code was updated dirtyCode bool // true if the code was updated
suicided bool
// Flag whether the account was marked as self-destructed. The self-destructed account
// is still accessible in the scope of same transaction.
selfDestructed bool
// Flag whether the account was marked as deleted. A self-destructed account
// or an account that is considered as empty will be marked as deleted at
// the end of transaction and no longer accessible anymore.
deleted bool deleted bool
// Flag whether the object was created in the current transaction
created bool
} }
// empty returns whether the account is considered empty. // empty returns whether the account is considered empty.
func (s *stateObject) empty() bool { func (s *stateObject) empty() bool {
return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, emptyCodeHash) return s.data.Nonce == 0 && s.data.Balance.IsZero() && bytes.Equal(s.data.CodeHash, types.EmptyCodeHash.Bytes())
} }
// newObject creates a state object. // newObject creates a state object.
func newObject(db *StateDB, address common.Address, data types.StateAccount, blockHash common.Hash) *stateObject { func newObject(db *StateDB, address common.Address, acct *types.StateAccount, blockHash common.Hash) *stateObject {
if data.Balance == nil { var (
data.Balance = new(big.Int) origin = acct
} created = acct == nil // true if the account was not existent
if data.CodeHash == nil { )
data.CodeHash = emptyCodeHash if acct == nil {
} acct = types.NewEmptyStateAccount()
if data.Root == (common.Hash{}) {
data.Root = emptyRoot
} }
return &stateObject{ return &stateObject{
db: db, db: db,
address: address, address: address,
addrHash: crypto.Keccak256Hash(address[:]), addrHash: crypto.Keccak256Hash(address[:]),
blockHash: blockHash, blockHash: blockHash,
data: data, origin: origin,
data: *acct,
originStorage: make(Storage), originStorage: make(Storage),
pendingStorage: make(Storage), pendingStorage: make(Storage),
dirtyStorage: make(Storage), dirtyStorage: make(Storage),
created: created,
} }
} }
// setError remembers the first non-nil error it is called with. func (s *stateObject) markSelfdestructed() {
func (s *stateObject) setError(err error) { s.selfDestructed = true
if s.dbErr == nil {
s.dbErr = err
}
}
func (s *stateObject) markSuicided() {
s.suicided = true
} }
func (s *stateObject) touch() { func (s *stateObject) touch() {
@ -133,46 +118,51 @@ func (s *stateObject) touch() {
} }
// GetState retrieves a value from the account storage trie. // GetState retrieves a value from the account storage trie.
func (s *stateObject) GetState(db StateDatabase, key common.Hash) common.Hash { func (s *stateObject) GetState(key common.Hash) common.Hash {
// If the fake storage is set, only lookup the state here(in the debugging mode)
if s.fakeStorage != nil {
return s.fakeStorage[key]
}
// If we have a dirty value for this state entry, return it // If we have a dirty value for this state entry, return it
value, dirty := s.dirtyStorage[key] value, dirty := s.dirtyStorage[key]
if dirty { if dirty {
return value return value
} }
// Otherwise return the entry's original value // Otherwise return the entry's original value
return s.GetCommittedState(db, key) return s.GetCommittedState(key)
} }
// GetCommittedState retrieves a value from the committed account storage trie. // GetCommittedState retrieves a value from the committed account storage trie.
func (s *stateObject) GetCommittedState(db StateDatabase, key common.Hash) common.Hash { func (s *stateObject) GetCommittedState(key common.Hash) common.Hash {
// If the fake storage is set, only lookup the state here(in the debugging mode) // If we have a pending write or clean cached, return that
if s.fakeStorage != nil { if value, pending := s.pendingStorage[key]; pending {
return s.fakeStorage[key] return value
} }
// If we have a pending write or clean cached, return that // If we have a pending write or clean cached, return that
if value, cached := s.originStorage[key]; cached { if value, cached := s.originStorage[key]; cached {
return value return value
} }
// If the object was destructed in *this* block (and potentially resurrected),
// the storage has been cleared out, and we should *not* consult the previous
// database about any storage values. The only possible alternatives are:
// 1) resurrect happened, and new slot values were set -- those should
// have been handles via pendingStorage above.
// 2) we don't have new values, and can deliver empty response back
if _, destructed := s.db.stateObjectsDestruct[s.address]; destructed {
return common.Hash{}
}
// If no live objects are available, load from database // If no live objects are available, load from database
start := time.Now() start := time.Now()
keyHash := crypto.Keccak256Hash(key[:]) keyHash := crypto.Keccak256Hash(key[:])
enc, err := db.StorageValue(s.addrHash, keyHash, s.blockHash) enc, err := s.db.db.StorageValue(s.addrHash, keyHash, s.blockHash)
if metrics.EnabledExpensive { if metrics.EnabledExpensive {
s.db.StorageReads += time.Since(start) s.db.StorageReads += time.Since(start)
} }
if err != nil { if err != nil {
s.setError(err) s.db.setError(err)
return common.Hash{} return common.Hash{}
} }
var value common.Hash var value common.Hash
if len(enc) > 0 { if len(enc) > 0 {
_, content, _, err := rlp.Split(enc) _, content, _, err := rlp.Split(enc)
if err != nil { if err != nil {
s.setError(err) s.db.setError(err)
} }
value.SetBytes(content) value.SetBytes(content)
} }
@ -181,14 +171,9 @@ func (s *stateObject) GetCommittedState(db StateDatabase, key common.Hash) commo
} }
// SetState updates a value in account storage. // SetState updates a value in account storage.
func (s *stateObject) SetState(db StateDatabase, key, value common.Hash) { func (s *stateObject) SetState(key, value common.Hash) {
// If the fake storage is set, put the temporary state update here.
if s.fakeStorage != nil {
s.fakeStorage[key] = value
return
}
// If the new value is the same as old, don't set // If the new value is the same as old, don't set
prev := s.GetState(db, key) prev := s.GetState(key)
if prev == value { if prev == value {
return return
} }
@ -201,63 +186,78 @@ func (s *stateObject) SetState(db StateDatabase, key, value common.Hash) {
s.setState(key, value) s.setState(key, value)
} }
// SetStorage replaces the entire state storage with the given one.
//
// After this function is called, all original state will be ignored and state
// lookup only happens in the fake state storage.
//
// Note this function should only be used for debugging purpose.
func (s *stateObject) SetStorage(storage map[common.Hash]common.Hash) {
// Allocate fake storage if it's nil.
if s.fakeStorage == nil {
s.fakeStorage = make(Storage)
}
for key, value := range storage {
s.fakeStorage[key] = value
}
// Don't bother journal since this function should only be used for
// debugging and the `fake` storage won't be committed to database.
}
func (s *stateObject) setState(key, value common.Hash) { func (s *stateObject) setState(key, value common.Hash) {
s.dirtyStorage[key] = value s.dirtyStorage[key] = value
} }
// finalise moves all dirty storage slots into the pending area to be hashed or
// committed later. It is invoked at the end of every transaction.
func (s *stateObject) finalise(prefetch bool) {
slotsToPrefetch := make([][]byte, 0, len(s.dirtyStorage))
for key, value := range s.dirtyStorage {
s.pendingStorage[key] = value
if value != s.originStorage[key] {
slotsToPrefetch = append(slotsToPrefetch, common.CopyBytes(key[:])) // Copy needed for closure
}
}
if len(s.dirtyStorage) > 0 {
s.dirtyStorage = make(Storage)
}
}
// AddBalance adds amount to s's balance. // AddBalance adds amount to s's balance.
// It is used to add funds to the destination account of a transfer. // It is used to add funds to the destination account of a transfer.
func (s *stateObject) AddBalance(amount *big.Int) { func (s *stateObject) AddBalance(amount *uint256.Int) {
// EIP161: We must check emptiness for the objects such that the account // EIP161: We must check emptiness for the objects such that the account
// clearing (0,0,0 objects) can take effect. // clearing (0,0,0 objects) can take effect.
if amount.Sign() == 0 { if amount.IsZero() {
if s.empty() { if s.empty() {
s.touch() s.touch()
} }
return return
} }
s.SetBalance(new(big.Int).Add(s.Balance(), amount)) s.SetBalance(new(uint256.Int).Add(s.Balance(), amount))
} }
// SubBalance removes amount from s's balance. // SubBalance removes amount from s's balance.
// It is used to remove funds from the origin account of a transfer. // It is used to remove funds from the origin account of a transfer.
func (s *stateObject) SubBalance(amount *big.Int) { func (s *stateObject) SubBalance(amount *uint256.Int) {
if amount.Sign() == 0 { if amount.IsZero() {
return return
} }
s.SetBalance(new(big.Int).Sub(s.Balance(), amount)) s.SetBalance(new(uint256.Int).Sub(s.Balance(), amount))
} }
func (s *stateObject) SetBalance(amount *big.Int) { func (s *stateObject) SetBalance(amount *uint256.Int) {
s.db.journal.append(balanceChange{ s.db.journal.append(balanceChange{
account: &s.address, account: &s.address,
prev: new(big.Int).Set(s.data.Balance), prev: new(uint256.Int).Set(s.data.Balance),
}) })
s.setBalance(amount) s.setBalance(amount)
} }
func (s *stateObject) setBalance(amount *big.Int) { func (s *stateObject) setBalance(amount *uint256.Int) {
s.data.Balance = amount s.data.Balance = amount
} }
func (s *stateObject) deepCopy(db *StateDB) *stateObject {
obj := &stateObject{
db: db,
address: s.address,
addrHash: s.addrHash,
origin: s.origin,
data: s.data,
}
obj.code = s.code
obj.dirtyStorage = s.dirtyStorage.Copy()
obj.originStorage = s.originStorage.Copy()
obj.pendingStorage = s.pendingStorage.Copy()
obj.selfDestructed = s.selfDestructed
obj.dirtyCode = s.dirtyCode
obj.deleted = s.deleted
return obj
}
// //
// Attribute accessors // Attribute accessors
// //
@ -268,16 +268,16 @@ func (s *stateObject) Address() common.Address {
} }
// Code returns the contract code associated with this object, if any. // Code returns the contract code associated with this object, if any.
func (s *stateObject) Code(db StateDatabase) []byte { func (s *stateObject) Code() []byte {
if s.code != nil { if s.code != nil {
return s.code return s.code
} }
if bytes.Equal(s.CodeHash(), emptyCodeHash) { if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return nil return nil
} }
code, err := db.ContractCode(common.BytesToHash(s.CodeHash())) code, err := s.db.db.ContractCode(common.BytesToHash(s.CodeHash()))
if err != nil { if err != nil {
s.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err)) s.db.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err))
} }
s.code = code s.code = code
return code return code
@ -286,22 +286,22 @@ func (s *stateObject) Code(db StateDatabase) []byte {
// CodeSize returns the size of the contract code associated with this object, // CodeSize returns the size of the contract code associated with this object,
// or zero if none. This method is an almost mirror of Code, but uses a cache // or zero if none. This method is an almost mirror of Code, but uses a cache
// inside the database to avoid loading codes seen recently. // inside the database to avoid loading codes seen recently.
func (s *stateObject) CodeSize(db StateDatabase) int { func (s *stateObject) CodeSize() int {
if s.code != nil { if s.code != nil {
return len(s.code) return len(s.code)
} }
if bytes.Equal(s.CodeHash(), emptyCodeHash) { if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return 0 return 0
} }
size, err := db.ContractCodeSize(common.BytesToHash(s.CodeHash())) size, err := s.db.db.ContractCodeSize(common.BytesToHash(s.CodeHash()))
if err != nil { if err != nil {
s.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err)) s.db.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err))
} }
return size return size
} }
func (s *stateObject) SetCode(codeHash common.Hash, code []byte) { func (s *stateObject) SetCode(codeHash common.Hash, code []byte) {
prevcode := s.Code(s.db.db) prevcode := s.Code()
s.db.journal.append(codeChange{ s.db.journal.append(codeChange{
account: &s.address, account: &s.address,
prevhash: s.CodeHash(), prevhash: s.CodeHash(),
@ -332,7 +332,7 @@ func (s *stateObject) CodeHash() []byte {
return s.data.CodeHash return s.data.CodeHash
} }
func (s *stateObject) Balance() *big.Int { func (s *stateObject) Balance() *uint256.Int {
return s.data.Balance return s.data.Balance
} }
@ -340,32 +340,6 @@ func (s *stateObject) Nonce() uint64 {
return s.data.Nonce return s.data.Nonce
} }
// Value is never called, but must be present to allow stateObject to be used func (s *stateObject) Root() common.Hash {
// as a vm.Account interface that also satisfies the vm.ContractRef return s.data.Root
// interface. Interfaces are awesome.
func (s *stateObject) Value() *big.Int {
panic("Value on stateObject should never be called")
}
// finalise moves all dirty storage slots into the pending area to be hashed or
// committed later. It is invoked at the end of every transaction.
func (s *stateObject) finalise(prefetch bool) {
for key, value := range s.dirtyStorage {
s.pendingStorage[key] = value
}
if len(s.dirtyStorage) > 0 {
s.dirtyStorage = make(Storage)
}
}
func (s *stateObject) deepCopy(db *StateDB) *stateObject {
stateObject := newObject(db, s.address, s.data, s.blockHash)
stateObject.code = s.code
stateObject.dirtyStorage = s.dirtyStorage.Copy()
stateObject.originStorage = s.originStorage.Copy()
stateObject.pendingStorage = s.pendingStorage.Copy()
stateObject.suicided = s.suicided
stateObject.dirtyCode = s.dirtyCode
stateObject.deleted = s.deleted
return stateObject
} }

View File

@ -2,7 +2,6 @@ package state
import ( import (
"fmt" "fmt"
"math/big"
"sort" "sort"
"time" "time"
@ -12,6 +11,13 @@ import (
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
"github.com/holiman/uint256"
)
const (
// storageDeleteLimit denotes the highest permissible memory allocation
// employed for contract storage deletion.
storageDeleteLimit = 512 * 1024 * 1024
) )
/* /*
@ -41,20 +47,38 @@ type revision struct {
// StateDB structs within the ethereum protocol are used to store anything // StateDB structs within the ethereum protocol are used to store anything
// within the merkle trie. StateDBs take care of caching and storing // within the merkle trie. StateDBs take care of caching and storing
// nested states. It's the general query interface to retrieve: // nested states. It's the general query interface to retrieve:
//
// * Contracts // * Contracts
// * Accounts // * Accounts
//
// Once the state is committed, tries cached in stateDB (including account
// trie, storage tries) will no longer be functional. A new state instance
// must be created with new root and updated database for accessing post-
// commit states.
type StateDB struct { type StateDB struct {
db StateDatabase db Database
hasher crypto.KeccakState hasher crypto.KeccakState
// originBlockHash is the blockhash for the state we are working on top of // originBlockHash is the blockhash for the state we are working on top of
originBlockHash common.Hash originBlockHash common.Hash
// This map holds 'live' objects, which will get modified while processing a state transition. // originalRoot is the pre-state root, before any changes were made.
// It will be updated when the Commit is called.
originalRoot common.Hash
// These maps hold the state changes (including the corresponding
// original value) that occurred in this **block**.
accounts map[common.Hash][]byte // The mutated accounts in 'slim RLP' encoding
storages map[common.Hash]map[common.Hash][]byte // The mutated slots in prefix-zero trimmed rlp format
accountsOrigin map[common.Address][]byte // The original value of mutated accounts in 'slim RLP' encoding
storagesOrigin map[common.Address]map[common.Hash][]byte // The original value of mutated slots in prefix-zero trimmed rlp format
// This map holds 'live' objects, which will get modified while processing
// a state transition.
stateObjects map[common.Address]*stateObject stateObjects map[common.Address]*stateObject
stateObjectsPending map[common.Address]struct{} // State objects finalized but not yet written to the trie stateObjectsPending map[common.Address]struct{} // State objects finalized but not yet written to the trie
stateObjectsDirty map[common.Address]struct{} // State objects modified in the current execution stateObjectsDirty map[common.Address]struct{} // State objects modified in the current execution
stateObjectsDestruct map[common.Address]struct{} // State objects destructed in the block stateObjectsDestruct map[common.Address]*types.StateAccount // State objects destructed in the block along with its previous value
// DB error. // DB error.
// State objects are used by the consensus core and VM which are // State objects are used by the consensus core and VM which are
@ -66,11 +90,13 @@ type StateDB struct {
// The refund counter, also used by state transitioning. // The refund counter, also used by state transitioning.
refund uint64 refund uint64
// The tx context and all occurred logs in the scope of transaction.
thash common.Hash thash common.Hash
txIndex int txIndex int
logs map[common.Hash][]*types.Log logs map[common.Hash][]*types.Log
logSize uint logSize uint
// Preimages occurred seen by VM in the scope of block.
preimages map[common.Hash][]byte preimages map[common.Hash][]byte
// Per-transaction access list // Per-transaction access list
@ -91,14 +117,14 @@ type StateDB struct {
} }
// New creates a new StateDB on the state for the provided blockHash // New creates a new StateDB on the state for the provided blockHash
func New(blockHash common.Hash, db StateDatabase) (*StateDB, error) { func New(blockHash common.Hash, db Database) (*StateDB, error) {
sdb := &StateDB{ sdb := &StateDB{
db: db, db: db,
originBlockHash: blockHash, originBlockHash: blockHash,
stateObjects: make(map[common.Address]*stateObject), stateObjects: make(map[common.Address]*stateObject),
stateObjectsPending: make(map[common.Address]struct{}), stateObjectsPending: make(map[common.Address]struct{}),
stateObjectsDirty: make(map[common.Address]struct{}), stateObjectsDirty: make(map[common.Address]struct{}),
stateObjectsDestruct: make(map[common.Address]struct{}), stateObjectsDestruct: make(map[common.Address]*types.StateAccount),
logs: make(map[common.Hash][]*types.Log), logs: make(map[common.Hash][]*types.Log),
preimages: make(map[common.Hash][]byte), preimages: make(map[common.Hash][]byte),
journal: newJournal(), journal: newJournal(),
@ -153,7 +179,7 @@ func (s *StateDB) SubRefund(gas uint64) {
} }
// Exist reports whether the given account address exists in the state. // Exist reports whether the given account address exists in the state.
// Notably this also returns true for suicided accounts. // Notably this also returns true for self-destructed accounts.
func (s *StateDB) Exist(addr common.Address) bool { func (s *StateDB) Exist(addr common.Address) bool {
return s.getStateObject(addr) != nil return s.getStateObject(addr) != nil
} }
@ -166,14 +192,15 @@ func (s *StateDB) Empty(addr common.Address) bool {
} }
// GetBalance retrieves the balance from the given address or 0 if object not found // GetBalance retrieves the balance from the given address or 0 if object not found
func (s *StateDB) GetBalance(addr common.Address) *big.Int { func (s *StateDB) GetBalance(addr common.Address) *uint256.Int {
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)
if stateObject != nil { if stateObject != nil {
return stateObject.Balance() return stateObject.Balance()
} }
return common.Big0 return common.U2560
} }
// GetNonce retrieves the nonce from the given address or 0 if object not found
func (s *StateDB) GetNonce(addr common.Address) uint64 { func (s *StateDB) GetNonce(addr common.Address) uint64 {
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)
if stateObject != nil { if stateObject != nil {
@ -183,10 +210,25 @@ func (s *StateDB) GetNonce(addr common.Address) uint64 {
return 0 return 0
} }
// GetStorageRoot retrieves the storage root from the given address or empty
// if object not found.
func (s *StateDB) GetStorageRoot(addr common.Address) common.Hash {
stateObject := s.getStateObject(addr)
if stateObject != nil {
return stateObject.Root()
}
return common.Hash{}
}
// TxIndex returns the current transaction index set by Prepare.
func (s *StateDB) TxIndex() int {
return s.txIndex
}
func (s *StateDB) GetCode(addr common.Address) []byte { func (s *StateDB) GetCode(addr common.Address) []byte {
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)
if stateObject != nil { if stateObject != nil {
return stateObject.Code(s.db) return stateObject.Code()
} }
return nil return nil
} }
@ -194,24 +236,24 @@ func (s *StateDB) GetCode(addr common.Address) []byte {
func (s *StateDB) GetCodeSize(addr common.Address) int { func (s *StateDB) GetCodeSize(addr common.Address) int {
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)
if stateObject != nil { if stateObject != nil {
return stateObject.CodeSize(s.db) return stateObject.CodeSize()
} }
return 0 return 0
} }
func (s *StateDB) GetCodeHash(addr common.Address) common.Hash { func (s *StateDB) GetCodeHash(addr common.Address) common.Hash {
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)
if stateObject == nil { if stateObject != nil {
return common.Hash{}
}
return common.BytesToHash(stateObject.CodeHash()) return common.BytesToHash(stateObject.CodeHash())
} }
return common.Hash{}
}
// GetState retrieves a value from the given account's storage trie. // GetState retrieves a value from the given account's storage trie.
func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash { func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash {
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)
if stateObject != nil { if stateObject != nil {
return stateObject.GetState(s.db, hash) return stateObject.GetState(hash)
} }
return common.Hash{} return common.Hash{}
} }
@ -220,15 +262,20 @@ func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash {
func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash { func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash {
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)
if stateObject != nil { if stateObject != nil {
return stateObject.GetCommittedState(s.db, hash) return stateObject.GetCommittedState(hash)
} }
return common.Hash{} return common.Hash{}
} }
func (s *StateDB) HasSuicided(addr common.Address) bool { // Database retrieves the low level database supporting the lower level trie ops.
func (s *StateDB) Database() Database {
return s.db
}
func (s *StateDB) HasSelfDestructed(addr common.Address) bool {
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)
if stateObject != nil { if stateObject != nil {
return stateObject.suicided return stateObject.selfDestructed
} }
return false return false
} }
@ -238,7 +285,7 @@ func (s *StateDB) HasSuicided(addr common.Address) bool {
*/ */
// AddBalance adds amount to the account associated with addr. // AddBalance adds amount to the account associated with addr.
func (s *StateDB) AddBalance(addr common.Address, amount *big.Int) { func (s *StateDB) AddBalance(addr common.Address, amount *uint256.Int) {
stateObject := s.getOrNewStateObject(addr) stateObject := s.getOrNewStateObject(addr)
if stateObject != nil { if stateObject != nil {
stateObject.AddBalance(amount) stateObject.AddBalance(amount)
@ -246,14 +293,14 @@ func (s *StateDB) AddBalance(addr common.Address, amount *big.Int) {
} }
// SubBalance subtracts amount from the account associated with addr. // SubBalance subtracts amount from the account associated with addr.
func (s *StateDB) SubBalance(addr common.Address, amount *big.Int) { func (s *StateDB) SubBalance(addr common.Address, amount *uint256.Int) {
stateObject := s.getOrNewStateObject(addr) stateObject := s.getOrNewStateObject(addr)
if stateObject != nil { if stateObject != nil {
stateObject.SubBalance(amount) stateObject.SubBalance(amount)
} }
} }
func (s *StateDB) SetBalance(addr common.Address, amount *big.Int) { func (s *StateDB) SetBalance(addr common.Address, amount *uint256.Int) {
stateObject := s.getOrNewStateObject(addr) stateObject := s.getOrNewStateObject(addr)
if stateObject != nil { if stateObject != nil {
stateObject.SetBalance(amount) stateObject.SetBalance(amount)
@ -277,39 +324,59 @@ func (s *StateDB) SetCode(addr common.Address, code []byte) {
func (s *StateDB) SetState(addr common.Address, key, value common.Hash) { func (s *StateDB) SetState(addr common.Address, key, value common.Hash) {
stateObject := s.getOrNewStateObject(addr) stateObject := s.getOrNewStateObject(addr)
if stateObject != nil { if stateObject != nil {
stateObject.SetState(s.db, key, value) stateObject.SetState(key, value)
} }
} }
// SetStorage replaces the entire storage for the specified account with given // SetStorage replaces the entire storage for the specified account with given
// storage. This function should only be used for debugging. // storage. This function should only be used for debugging and the mutations
// must be discarded afterwards.
func (s *StateDB) SetStorage(addr common.Address, storage map[common.Hash]common.Hash) { func (s *StateDB) SetStorage(addr common.Address, storage map[common.Hash]common.Hash) {
s.stateObjectsDestruct[addr] = struct{}{} // SetStorage needs to wipe existing storage. We achieve this by pretending
// that the account self-destructed earlier in this block, by flagging
// it in stateObjectsDestruct. The effect of doing so is that storage lookups
// will not hit disk, since it is assumed that the disk-data is belonging
// to a previous incarnation of the object.
//
// TODO(rjl493456442) this function should only be supported by 'unwritable'
// state and all mutations made should all be discarded afterwards.
if _, ok := s.stateObjectsDestruct[addr]; !ok {
s.stateObjectsDestruct[addr] = nil
}
stateObject := s.getOrNewStateObject(addr) stateObject := s.getOrNewStateObject(addr)
if stateObject != nil { for k, v := range storage {
stateObject.SetStorage(storage) stateObject.SetState(k, v)
} }
} }
// Suicide marks the given account as suicided. // SelfDestruct marks the given account as selfdestructed.
// This clears the account balance. // This clears the account balance.
// //
// The account's state object is still available until the state is committed, // The account's state object is still available until the state is committed,
// getStateObject will return a non-nil account after Suicide. // getStateObject will return a non-nil account after SelfDestruct.
func (s *StateDB) Suicide(addr common.Address) bool { func (s *StateDB) SelfDestruct(addr common.Address) {
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)
if stateObject == nil { if stateObject == nil {
return false return
} }
s.journal.append(suicideChange{ s.journal.append(selfDestructChange{
account: &addr, account: &addr,
prev: stateObject.suicided, prev: stateObject.selfDestructed,
prevbalance: new(big.Int).Set(stateObject.Balance()), prevbalance: new(uint256.Int).Set(stateObject.Balance()),
}) })
stateObject.markSuicided() stateObject.markSelfdestructed()
stateObject.data.Balance = new(big.Int) stateObject.data.Balance = new(uint256.Int)
}
return true func (s *StateDB) Selfdestruct6780(addr common.Address) {
stateObject := s.getStateObject(addr)
if stateObject == nil {
return
}
if stateObject.created {
s.SelfDestruct(addr)
}
} }
// SetTransientState sets transient storage for a given account. It // SetTransientState sets transient storage for a given account. It
@ -380,7 +447,7 @@ func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject {
return nil return nil
} }
// Insert into the live set // Insert into the live set
obj := newObject(s, addr, *data, s.originBlockHash) obj := newObject(s, addr, data, s.originBlockHash)
s.setStateObject(obj) s.setStateObject(obj)
return obj return obj
} }
@ -402,19 +469,36 @@ func (s *StateDB) getOrNewStateObject(addr common.Address) *stateObject {
// the given address, it is overwritten and returned as the second return value. // the given address, it is overwritten and returned as the second return value.
func (s *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) { func (s *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) {
prev = s.getDeletedStateObject(addr) // Note, prev might have been deleted, we need that! prev = s.getDeletedStateObject(addr) // Note, prev might have been deleted, we need that!
newobj = newObject(s, addr, nil, s.originBlockHash)
var prevdestruct bool
if prev != nil {
_, prevdestruct = s.stateObjectsDestruct[prev.address]
if !prevdestruct {
s.stateObjectsDestruct[prev.address] = struct{}{}
}
}
newobj = newObject(s, addr, types.StateAccount{}, s.originBlockHash)
if prev == nil { if prev == nil {
s.journal.append(createObjectChange{account: &addr}) s.journal.append(createObjectChange{account: &addr})
} else { } else {
s.journal.append(resetObjectChange{prev: prev, prevdestruct: prevdestruct}) // NOTE: prevdestruct used to be set here from snapshot // The original account should be marked as destructed and all cached
// account and storage data should be cleared as well. Note, it must
// be done here, otherwise the destruction event of "original account"
// will be lost.
_, prevdestruct := s.stateObjectsDestruct[prev.address]
if !prevdestruct {
s.stateObjectsDestruct[prev.address] = prev.origin
}
// There may be some cached account/storage data already since IntermediateRoot
// will be called for each transaction before byzantium fork which will always
// cache the latest account/storage data.
prevAccount, ok := s.accountsOrigin[prev.address]
s.journal.append(resetObjectChange{
account: &addr,
prev: prev,
prevdestruct: prevdestruct,
prevAccount: s.accounts[prev.addrHash],
prevStorage: s.storages[prev.addrHash],
prevAccountOriginExist: ok,
prevAccountOrigin: prevAccount,
prevStorageOrigin: s.storagesOrigin[prev.address],
})
delete(s.accounts, prev.addrHash)
delete(s.storages, prev.addrHash)
delete(s.accountsOrigin, prev.address)
delete(s.storagesOrigin, prev.address)
} }
s.setStateObject(newobj) s.setStateObject(newobj)
if prev != nil && !prev.deleted { if prev != nil && !prev.deleted {
@ -440,14 +524,6 @@ func (s *StateDB) CreateAccount(addr common.Address) {
} }
} }
// ForEachStorage satisfies vm.StateDB but is not implemented
func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error {
// NOTE: as far as I can tell this method is only ever used in tests
// in that case, we can leave it unimplemented
// or if it needs to be implemented we can use iplfs-ethdb to do normal trie access
panic("ForEachStorage is not implemented")
}
// Snapshot returns an identifier for the current revision of the state. // Snapshot returns an identifier for the current revision of the state.
func (s *StateDB) Snapshot() int { func (s *StateDB) Snapshot() int {
id := s.nextRevisionId id := s.nextRevisionId
@ -477,6 +553,70 @@ func (s *StateDB) GetRefund() uint64 {
return s.refund return s.refund
} }
// Finalise finalises the state by removing the destructed objects and clears
// the journal as well as the refunds. Finalise, however, will not push any updates
// into the tries just yet. Only IntermediateRoot or Commit will do that.
func (s *StateDB) Finalise(deleteEmptyObjects bool) {
addressesToPrefetch := make([][]byte, 0, len(s.journal.dirties))
for addr := range s.journal.dirties {
obj, exist := s.stateObjects[addr]
if !exist {
// ripeMD is 'touched' at block 1714175, in tx 0x1237f737031e40bcde4a8b7e717b2d15e3ecadfe49bb1bbc71ee9deb09c6fcf2
// That tx goes out of gas, and although the notion of 'touched' does not exist there, the
// touch-event will still be recorded in the journal. Since ripeMD is a special snowflake,
// it will persist in the journal even though the journal is reverted. In this special circumstance,
// it may exist in `s.journal.dirties` but not in `s.stateObjects`.
// Thus, we can safely ignore it here
continue
}
if obj.selfDestructed || (deleteEmptyObjects && obj.empty()) {
obj.deleted = true
// We need to maintain account deletions explicitly (will remain
// set indefinitely). Note only the first occurred self-destruct
// event is tracked.
if _, ok := s.stateObjectsDestruct[obj.address]; !ok {
s.stateObjectsDestruct[obj.address] = obj.origin
}
// Note, we can't do this only at the end of a block because multiple
// transactions within the same block might self destruct and then
// resurrect an account; but the snapshotter needs both events.
delete(s.accounts, obj.addrHash) // Clear out any previously updated account data (may be recreated via a resurrect)
delete(s.storages, obj.addrHash) // Clear out any previously updated storage data (may be recreated via a resurrect)
delete(s.accountsOrigin, obj.address) // Clear out any previously updated account data (may be recreated via a resurrect)
delete(s.storagesOrigin, obj.address) // Clear out any previously updated storage data (may be recreated via a resurrect)
} else {
obj.finalise(true) // Prefetch slots in the background
}
obj.created = false
s.stateObjectsPending[addr] = struct{}{}
s.stateObjectsDirty[addr] = struct{}{}
// At this point, also ship the address off to the precacher. The precacher
// will start loading tries, and when the change is eventually committed,
// the commit-phase will be a lot faster
addressesToPrefetch = append(addressesToPrefetch, common.CopyBytes(addr[:])) // Copy needed for closure
}
// Invalidate journal because reverting across transactions is not allowed.
s.clearJournalAndRefund()
}
// SetTxContext sets the current transaction hash and index which are
// used when the EVM emits new state logs. It should be invoked before
// transaction execution.
func (s *StateDB) SetTxContext(thash common.Hash, ti int) {
s.thash = thash
s.txIndex = ti
}
func (s *StateDB) clearJournalAndRefund() {
if len(s.journal.entries) > 0 {
s.journal = newJournal()
s.refund = 0
}
s.validRevisions = s.validRevisions[:0] // Snapshots can be created without journal entries
}
// Prepare handles the preparatory steps for executing a state transition with. // Prepare handles the preparatory steps for executing a state transition with.
// This method must be invoked before state transition. // This method must be invoked before state transition.
// //
@ -553,65 +693,21 @@ func (s *StateDB) SlotInAccessList(addr common.Address, slot common.Hash) (addre
return s.accessList.Contains(addr, slot) return s.accessList.Contains(addr, slot)
} }
// Finalise finalises the state by removing the destructed objects and clears
// the journal as well as the refunds. Finalise, however, will not push any updates
// into the tries just yet. Only IntermediateRoot or Commit will do that.
func (s *StateDB) Finalise(deleteEmptyObjects bool) {
for addr := range s.journal.dirties {
obj, exist := s.stateObjects[addr]
if !exist {
// ripeMD is 'touched' at block 1714175, in tx 0x1237f737031e40bcde4a8b7e717b2d15e3ecadfe49bb1bbc71ee9deb09c6fcf2
// That tx goes out of gas, and although the notion of 'touched' does not exist there, the
// touch-event will still be recorded in the journal. Since ripeMD is a special snowflake,
// it will persist in the journal even though the journal is reverted. In this special circumstance,
// it may exist in `s.journal.dirties` but not in `s.stateObjects`.
// Thus, we can safely ignore it here
continue
}
if obj.suicided || (deleteEmptyObjects && obj.empty()) {
obj.deleted = true
// We need to maintain account deletions explicitly (will remain
// set indefinitely).
s.stateObjectsDestruct[obj.address] = struct{}{}
} else {
obj.finalise(true) // Prefetch slots in the background
}
s.stateObjectsPending[addr] = struct{}{}
s.stateObjectsDirty[addr] = struct{}{}
}
// Invalidate journal because reverting across transactions is not allowed.
s.clearJournalAndRefund()
}
// SetTxContext sets the current transaction hash and index which are
// used when the EVM emits new state logs. It should be invoked before
// transaction execution.
func (s *StateDB) SetTxContext(thash common.Hash, ti int) {
s.thash = thash
s.txIndex = ti
}
func (s *StateDB) clearJournalAndRefund() {
if len(s.journal.entries) > 0 {
s.journal = newJournal()
s.refund = 0
}
s.validRevisions = s.validRevisions[:0] // Snapshots can be created without journal entries
}
// Copy creates a deep, independent copy of the state. // Copy creates a deep, independent copy of the state.
// Snapshots of the copied state cannot be applied to the copy. // Snapshots of the copied state cannot be applied to the copy.
func (s *StateDB) Copy() *StateDB { func (s *StateDB) Copy() *StateDB {
// Copy all the basic fields, initialize the memory ones // Copy all the basic fields, initialize the memory ones
state := &StateDB{ state := &StateDB{
db: s.db, db: s.db,
originBlockHash: s.originBlockHash, originalRoot: s.originalRoot,
accounts: make(map[common.Hash][]byte),
storages: make(map[common.Hash]map[common.Hash][]byte),
accountsOrigin: make(map[common.Address][]byte),
storagesOrigin: make(map[common.Address]map[common.Hash][]byte),
stateObjects: make(map[common.Address]*stateObject, len(s.journal.dirties)), stateObjects: make(map[common.Address]*stateObject, len(s.journal.dirties)),
stateObjectsPending: make(map[common.Address]struct{}, len(s.stateObjectsPending)), stateObjectsPending: make(map[common.Address]struct{}, len(s.stateObjectsPending)),
stateObjectsDirty: make(map[common.Address]struct{}, len(s.journal.dirties)), stateObjectsDirty: make(map[common.Address]struct{}, len(s.journal.dirties)),
stateObjectsDestruct: make(map[common.Address]struct{}, len(s.stateObjectsDestruct)), stateObjectsDestruct: make(map[common.Address]*types.StateAccount, len(s.stateObjectsDestruct)),
refund: s.refund, refund: s.refund,
logs: make(map[common.Hash][]*types.Log, len(s.logs)), logs: make(map[common.Hash][]*types.Log, len(s.logs)),
logSize: s.logSize, logSize: s.logSize,
@ -651,10 +747,18 @@ func (s *StateDB) Copy() *StateDB {
} }
state.stateObjectsDirty[addr] = struct{}{} state.stateObjectsDirty[addr] = struct{}{}
} }
// Deep copy the destruction flag. // Deep copy the destruction markers.
for addr := range s.stateObjectsDestruct { for addr, value := range s.stateObjectsDestruct {
state.stateObjectsDestruct[addr] = struct{}{} state.stateObjectsDestruct[addr] = value
} }
// Deep copy the state changes made in the scope of block
// along with their original values.
state.accounts = copySet(s.accounts)
state.storages = copy2DSet(s.storages)
state.accountsOrigin = copySet(state.accountsOrigin)
state.storagesOrigin = copy2DSet(state.storagesOrigin)
// Deep copy the logs occurred in the scope of block
for hash, logs := range s.logs { for hash, logs := range s.logs {
cpy := make([]*types.Log, len(logs)) cpy := make([]*types.Log, len(logs))
for i, l := range logs { for i, l := range logs {
@ -663,6 +767,7 @@ func (s *StateDB) Copy() *StateDB {
} }
state.logs[hash] = cpy state.logs[hash] = cpy
} }
// Deep copy the preimages occurred in the scope of block
for hash, preimage := range s.preimages { for hash, preimage := range s.preimages {
state.preimages[hash] = preimage state.preimages[hash] = preimage
} }
@ -674,6 +779,26 @@ func (s *StateDB) Copy() *StateDB {
// in the middle of a transaction. // in the middle of a transaction.
state.accessList = s.accessList.Copy() state.accessList = s.accessList.Copy()
state.transientStorage = s.transientStorage.Copy() state.transientStorage = s.transientStorage.Copy()
return state return state
} }
// copySet returns a deep-copied set.
func copySet[k comparable](set map[k][]byte) map[k][]byte {
copied := make(map[k][]byte, len(set))
for key, val := range set {
copied[key] = common.CopyBytes(val)
}
return copied
}
// copy2DSet returns a two-dimensional deep-copied set.
func copy2DSet[k comparable](set map[k]map[common.Hash][]byte) map[k]map[common.Hash][]byte {
copied := make(map[k]map[common.Hash][]byte, len(set))
for addr, subset := range set {
copied[addr] = make(map[common.Hash][]byte, len(subset))
for key, val := range subset {
copied[addr][key] = common.CopyBytes(val)
}
}
return copied
}

View File

@ -5,6 +5,7 @@ import (
"math/big" "math/big"
"testing" "testing"
"github.com/holiman/uint256"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/multiformats/go-multihash" "github.com/multiformats/go-multihash"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -66,7 +67,7 @@ var (
Account = types.StateAccount{ Account = types.StateAccount{
Nonce: uint64(0), Nonce: uint64(0),
Balance: big.NewInt(1000), Balance: uint256.NewInt(1000),
CodeHash: AccountCodeHash.Bytes(), CodeHash: AccountCodeHash.Bytes(),
Root: common.Hash{}, Root: common.Hash{},
} }
@ -112,7 +113,7 @@ func TestPGXSuite(t *testing.T) {
database := sql.NewPGXDriverFromPool(context.Background(), pool) database := sql.NewPGXDriverFromPool(context.Background(), pool)
insertSuiteData(t, database) insertSuiteData(t, database)
db := state.NewStateDatabase(database) db := state.NewDatabase(database)
require.NoError(t, err) require.NoError(t, err)
testSuite(t, db) testSuite(t, db)
} }
@ -137,7 +138,7 @@ func TestSQLXSuite(t *testing.T) {
database := sql.NewSQLXDriverFromPool(context.Background(), pool) database := sql.NewSQLXDriverFromPool(context.Background(), pool)
insertSuiteData(t, database) insertSuiteData(t, database)
db := state.NewStateDatabase(database) db := state.NewDatabase(database)
require.NoError(t, err) require.NoError(t, err)
testSuite(t, db) testSuite(t, db)
} }
@ -226,7 +227,7 @@ func insertSuiteData(t *testing.T, database sql.Database) {
require.NoError(t, insertContractCode(database)) require.NoError(t, insertContractCode(database))
} }
func testSuite(t *testing.T, db state.StateDatabase) { func testSuite(t *testing.T, db state.Database) {
t.Run("Database", func(t *testing.T) { t.Run("Database", func(t *testing.T) {
size, err := db.ContractCodeSize(AccountCodeHash) size, err := db.ContractCodeSize(AccountCodeHash)
require.NoError(t, err) require.NoError(t, err)
@ -309,14 +310,14 @@ func testSuite(t *testing.T, db state.StateDatabase) {
newStorage := crypto.Keccak256Hash([]byte{5, 4, 3, 2, 1}) newStorage := crypto.Keccak256Hash([]byte{5, 4, 3, 2, 1})
newCode := []byte{1, 3, 3, 7} newCode := []byte{1, 3, 3, 7}
sdb.SetBalance(AccountAddress, big.NewInt(300)) sdb.SetBalance(AccountAddress, uint256.NewInt(300))
sdb.AddBalance(AccountAddress, big.NewInt(200)) sdb.AddBalance(AccountAddress, uint256.NewInt(200))
sdb.SubBalance(AccountAddress, big.NewInt(100)) sdb.SubBalance(AccountAddress, uint256.NewInt(100))
sdb.SetNonce(AccountAddress, 42) sdb.SetNonce(AccountAddress, 42)
sdb.SetState(AccountAddress, StorageSlot, newStorage) sdb.SetState(AccountAddress, StorageSlot, newStorage)
sdb.SetCode(AccountAddress, newCode) sdb.SetCode(AccountAddress, newCode)
require.Equal(t, big.NewInt(400), sdb.GetBalance(AccountAddress)) require.Equal(t, uint256.NewInt(400), sdb.GetBalance(AccountAddress))
require.Equal(t, uint64(42), sdb.GetNonce(AccountAddress)) require.Equal(t, uint64(42), sdb.GetNonce(AccountAddress))
require.Equal(t, newStorage, sdb.GetState(AccountAddress, StorageSlot)) require.Equal(t, newStorage, sdb.GetState(AccountAddress, StorageSlot))
require.Equal(t, newCode, sdb.GetCode(AccountAddress)) require.Equal(t, newCode, sdb.GetCode(AccountAddress))

1
go.mod
View File

@ -116,3 +116,4 @@ replace github.com/cerc-io/plugeth-statediff => git.vdb.to/cerc-io/plugeth-state
// dev // dev
replace github.com/cerc-io/ipfs-ethdb/v5 => git.vdb.to/cerc-io/ipfs-ethdb/v5 v5.0.1-alpha.0.20240403094152-a95b1aea6c5c replace github.com/cerc-io/ipfs-ethdb/v5 => git.vdb.to/cerc-io/ipfs-ethdb/v5 v5.0.1-alpha.0.20240403094152-a95b1aea6c5c
replace github.com/ethereum/go-ethereum => ../go-ethereum

View File

@ -5,10 +5,18 @@ import (
"time" "time"
pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0" pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0"
"github.com/cerc-io/plugeth-statediff/indexer/ipld"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ipfs/go-cid" "github.com/ipfs/go-cid"
"github.com/multiformats/go-multihash" "github.com/multiformats/go-multihash"
) )
var (
StateTrieCodec uint64 = ipld.MEthStateTrie
StorageTrieCodec uint64 = ipld.MEthStorageTrie
)
func Keccak256ToCid(codec uint64, h []byte) (cid.Cid, error) { func Keccak256ToCid(codec uint64, h []byte) (cid.Cid, error) {
buf, err := multihash.Encode(h, multihash.KECCAK_256) buf, err := multihash.Encode(h, multihash.KECCAK_256)
if err != nil { if err != nil {
@ -25,3 +33,33 @@ func MakeCacheConfig(t testing.TB) pgipfsethdb.CacheConfig {
ExpiryDuration: time.Hour, ExpiryDuration: time.Hour,
} }
} }
// ReadLegacyTrieNode retrieves the legacy trie node with the given
// associated node hash.
func ReadLegacyTrieNode(db ethdb.KeyValueReader, hash common.Hash, codec uint64) ([]byte, error) {
cid, err := Keccak256ToCid(codec, hash[:])
if err != nil {
return nil, err
}
enc, err := db.Get(cid.Bytes())
if err != nil {
return nil, err
}
return enc, nil
}
func WriteLegacyTrieNode(db ethdb.KeyValueWriter, hash common.Hash, codec uint64, data []byte) error {
cid, err := Keccak256ToCid(codec, hash[:])
if err != nil {
return err
}
return db.Put(cid.Bytes(), data)
}
func ReadCode(db ethdb.KeyValueReader, hash common.Hash) ([]byte, error) {
return ReadLegacyTrieNode(db, hash, ipld.RawBinary)
}
func WriteCode(db ethdb.KeyValueWriter, hash common.Hash, code []byte) error {
return WriteLegacyTrieNode(db, hash, ipld.RawBinary, code)
}

View File

@ -20,7 +20,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/cerc-io/plugeth-statediff/indexer/ipld" "github.com/crate-crypto/go-ipa/banderwagon"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/lru" "github.com/ethereum/go-ethereum/common/lru"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
@ -28,6 +28,9 @@ import (
"github.com/cerc-io/ipld-eth-statedb/internal" "github.com/cerc-io/ipld-eth-statedb/internal"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/utils"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/triedb"
) )
const ( const (
@ -36,6 +39,12 @@ const (
// Cache size granted for caching clean code. // Cache size granted for caching clean code.
codeCacheSize = 64 * 1024 * 1024 codeCacheSize = 64 * 1024 * 1024
// commitmentSize is the size of commitment stored in cache.
commitmentSize = banderwagon.UncompressedSize
// Cache item granted for caching commitment results.
commitmentCacheItems = 64 * 1024 * 1024 / (commitmentSize + common.AddressLength)
) )
// Database wraps access to tries and contract code. // Database wraps access to tries and contract code.
@ -44,22 +53,22 @@ type Database interface {
OpenTrie(root common.Hash) (Trie, error) OpenTrie(root common.Hash) (Trie, error)
// OpenStorageTrie opens the storage trie of an account. // OpenStorageTrie opens the storage trie of an account.
OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) OpenStorageTrie(stateRoot, addrHash common.Hash, root common.Hash, trie Trie) (Trie, error)
// CopyTrie returns an independent copy of the given trie. // CopyTrie returns an independent copy of the given trie.
CopyTrie(Trie) Trie CopyTrie(Trie) Trie
// ContractCode retrieves a particular contract's code. // ContractCode retrieves a particular contract's code.
ContractCode(codeHash common.Hash) ([]byte, error) ContractCode(addr common.Address, codeHash common.Hash) ([]byte, error)
// ContractCodeSize retrieves a particular contracts code's size. // ContractCodeSize retrieves a particular contracts code's size.
ContractCodeSize(codeHash common.Hash) (int, error) ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error)
// DiskDB returns the underlying key-value disk database. // DiskDB returns the underlying key-value disk database.
DiskDB() ethdb.KeyValueStore DiskDB() ethdb.KeyValueStore
// TrieDB retrieves the low level trie database used for data storage. // TrieDB returns the underlying trie database for managing trie nodes.
TrieDB() *trie.Database TrieDB() *triedb.Database
} }
// Trie is a Ethereum Merkle Patricia trie. // Trie is a Ethereum Merkle Patricia trie.
@ -70,40 +79,40 @@ type Trie interface {
// TODO(fjl): remove this when StateTrie is removed // TODO(fjl): remove this when StateTrie is removed
GetKey([]byte) []byte GetKey([]byte) []byte
// TryGet returns the value for key stored in the trie. The value bytes must // GetAccount abstracts an account read from the trie. It retrieves the
// not be modified by the caller. If a node was not found in the database, a
// trie.MissingNodeError is returned.
TryGet(key []byte) ([]byte, error)
// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not
// possible to use keybyte-encoding as the path might contain odd nibbles.
TryGetNode(path []byte) ([]byte, int, error)
// TryGetAccount abstracts an account read from the trie. It retrieves the
// account blob from the trie with provided account address and decodes it // account blob from the trie with provided account address and decodes it
// with associated decoding algorithm. If the specified account is not in // with associated decoding algorithm. If the specified account is not in
// the trie, nil will be returned. If the trie is corrupted(e.g. some nodes // the trie, nil will be returned. If the trie is corrupted(e.g. some nodes
// are missing or the account blob is incorrect for decoding), an error will // are missing or the account blob is incorrect for decoding), an error will
// be returned. // be returned.
TryGetAccount(address common.Address) (*types.StateAccount, error) GetAccount(address common.Address) (*types.StateAccount, error)
// TryUpdate associates key with value in the trie. If value has length zero, any // GetStorage returns the value for key stored in the trie. The value bytes
// existing value is deleted from the trie. The value bytes must not be modified // must not be modified by the caller. If a node was not found in the database,
// by the caller while they are stored in the trie. If a node was not found in the // a trie.MissingNodeError is returned.
// database, a trie.MissingNodeError is returned. GetStorage(addr common.Address, key []byte) ([]byte, error)
TryUpdate(key, value []byte) error
// TryUpdateAccount abstracts an account write to the trie. It encodes the // UpdateAccount abstracts an account write to the trie. It encodes the
// provided account object with associated algorithm and then updates it // provided account object with associated algorithm and then updates it
// in the trie with provided address. // in the trie with provided address.
TryUpdateAccount(address common.Address, account *types.StateAccount) error UpdateAccount(address common.Address, account *types.StateAccount) error
// TryDelete removes any existing value for key from the trie. If a node was not // UpdateStorage associates key with value in the trie. If value has length zero,
// found in the database, a trie.MissingNodeError is returned. // any existing value is deleted from the trie. The value bytes must not be modified
TryDelete(key []byte) error // by the caller while they are stored in the trie. If a node was not found in the
// database, a trie.MissingNodeError is returned.
UpdateStorage(addr common.Address, key, value []byte) error
// TryDeleteAccount abstracts an account deletion from the trie. // DeleteAccount abstracts an account deletion from the trie.
TryDeleteAccount(address common.Address) error DeleteAccount(address common.Address) error
// DeleteStorage removes any existing value for key from the trie. If a node
// was not found in the database, a trie.MissingNodeError is returned.
DeleteStorage(addr common.Address, key []byte) error
// UpdateContractCode abstracts code write to the trie. It is expected
// to be moved to the stateWriter interface when the latter is ready.
UpdateContractCode(address common.Address, codeHash common.Hash, code []byte) error
// Hash returns the root hash of the trie. It does not write to the database and // Hash returns the root hash of the trie. It does not write to the database and
// can be used even if the trie doesn't have one. // can be used even if the trie doesn't have one.
@ -115,11 +124,12 @@ type Trie interface {
// The returned nodeset can be nil if the trie is clean(nothing to commit). // The returned nodeset can be nil if the trie is clean(nothing to commit).
// Once the trie is committed, it's not usable anymore. A new trie must // Once the trie is committed, it's not usable anymore. A new trie must
// be created with new root and updated trie database for following usage // be created with new root and updated trie database for following usage
Commit(collectLeaf bool) (common.Hash, *trie.NodeSet) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, error)
// NodeIterator returns an iterator that returns nodes of the trie. Iteration // NodeIterator returns an iterator that returns nodes of the trie. Iteration
// starts at the key after the given start key. // starts at the key after the given start key. And error will be returned
NodeIterator(startKey []byte) trie.NodeIterator // if fails to create node iterator.
NodeIterator(startKey []byte) (trie.NodeIterator, error)
// Prove constructs a Merkle proof for key. The result contains all encoded nodes // Prove constructs a Merkle proof for key. The result contains all encoded nodes
// on the path to the value at key. The value itself is also included in the last // on the path to the value at key. The value itself is also included in the last
@ -128,7 +138,7 @@ type Trie interface {
// If the trie does not contain a value for key, the returned proof contains all // If the trie does not contain a value for key, the returned proof contains all
// nodes of the longest existing prefix of the key (at least the root), ending // nodes of the longest existing prefix of the key (at least the root), ending
// with the node that proves the absence of the key. // with the node that proves the absence of the key.
Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error Prove(key []byte, proofDb ethdb.KeyValueWriter) error
} }
// NewDatabase creates a backing store for state. The returned database is safe for // NewDatabase creates a backing store for state. The returned database is safe for
@ -141,17 +151,17 @@ func NewDatabase(db ethdb.Database) Database {
// NewDatabaseWithConfig creates a backing store for state. The returned database // NewDatabaseWithConfig creates a backing store for state. The returned database
// is safe for concurrent use and retains a lot of collapsed RLP trie nodes in a // is safe for concurrent use and retains a lot of collapsed RLP trie nodes in a
// large memory cache. // large memory cache.
func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database { func NewDatabaseWithConfig(db ethdb.Database, config *triedb.Config) Database {
return &cachingDB{ return &cachingDB{
disk: db, disk: db,
codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize), codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize),
codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize), codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize),
triedb: trie.NewDatabaseWithConfig(db, config), triedb: triedb.NewDatabase(db, config),
} }
} }
// NewDatabaseWithNodeDB creates a state database with an already initialized node database. // NewDatabaseWithNodeDB creates a state database with an already initialized node database.
func NewDatabaseWithNodeDB(db ethdb.Database, triedb *trie.Database) Database { func NewDatabaseWithNodeDB(db ethdb.Database, triedb *triedb.Database) Database {
return &cachingDB{ return &cachingDB{
disk: db, disk: db,
codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize), codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize),
@ -164,12 +174,15 @@ type cachingDB struct {
disk ethdb.KeyValueStore disk ethdb.KeyValueStore
codeSizeCache *lru.Cache[common.Hash, int] codeSizeCache *lru.Cache[common.Hash, int]
codeCache *lru.SizeConstrainedCache[common.Hash, []byte] codeCache *lru.SizeConstrainedCache[common.Hash, []byte]
triedb *trie.Database triedb *triedb.Database
} }
// OpenTrie opens the main account trie at a specific root hash. // OpenTrie opens the main account trie at a specific root hash.
func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
tr, err := trie.NewStateTrie(trie.StateTrieID(root), db.triedb, trie.StateTrieCodec) if db.triedb.IsVerkle() {
return trie.NewVerkleTrie(root, db.triedb, utils.NewPointCache(commitmentCacheItems))
}
tr, err := trie.NewStateTrie(trie.StateTrieID(root), db.triedb)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -177,8 +190,14 @@ func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
} }
// OpenStorageTrie opens the storage trie of an account. // OpenStorageTrie opens the storage trie of an account.
func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) { func (db *cachingDB) OpenStorageTrie(stateRoot, addrHash common.Hash, root common.Hash, self Trie) (Trie, error) {
tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, addrHash, root), db.triedb, trie.StorageTrieCodec) // In the verkle case, there is only one tree. But the two-tree structure
// is hardcoded in the codebase. So we need to return the same trie in this
// case.
if db.triedb.IsVerkle() {
return self, nil
}
tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, addrHash, root), db.triedb)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -196,16 +215,12 @@ func (db *cachingDB) CopyTrie(t Trie) Trie {
} }
// ContractCode retrieves a particular contract's code. // ContractCode retrieves a particular contract's code.
func (db *cachingDB) ContractCode(codeHash common.Hash) ([]byte, error) { func (db *cachingDB) ContractCode(address common.Address, codeHash common.Hash) ([]byte, error) {
code, _ := db.codeCache.Get(codeHash) code, _ := db.codeCache.Get(codeHash)
if len(code) > 0 { if len(code) > 0 {
return code, nil return code, nil
} }
cid, err := internal.Keccak256ToCid(ipld.RawBinary, codeHash.Bytes()) code, err := internal.ReadCode(db.disk, codeHash)
if err != nil {
return nil, err
}
code, err = db.disk.Get(cid.Bytes())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -217,12 +232,19 @@ func (db *cachingDB) ContractCode(codeHash common.Hash) ([]byte, error) {
return nil, errors.New("not found") return nil, errors.New("not found")
} }
// ContractCodeWithPrefix retrieves a particular contract's code. If the
// code can't be found in the cache, then check the existence with **new**
// db scheme.
func (db *cachingDB) ContractCodeWithPrefix(address common.Address, codeHash common.Hash) ([]byte, error) {
return db.ContractCode(address, codeHash)
}
// ContractCodeSize retrieves a particular contracts code's size. // ContractCodeSize retrieves a particular contracts code's size.
func (db *cachingDB) ContractCodeSize(codeHash common.Hash) (int, error) { func (db *cachingDB) ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error) {
if cached, ok := db.codeSizeCache.Get(codeHash); ok { if cached, ok := db.codeSizeCache.Get(codeHash); ok {
return cached, nil return cached, nil
} }
code, err := db.ContractCode(codeHash) code, err := db.ContractCode(addr, codeHash)
return len(code), err return len(code), err
} }
@ -232,6 +254,6 @@ func (db *cachingDB) DiskDB() ethdb.KeyValueStore {
} }
// TrieDB retrieves any intermediate trie-node caching layer. // TrieDB retrieves any intermediate trie-node caching layer.
func (db *cachingDB) TrieDB() *trie.Database { func (db *cachingDB) TrieDB() *triedb.Database {
return db.triedb return db.triedb
} }

236
trie_by_cid/state/dump.go Normal file
View File

@ -0,0 +1,236 @@
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"encoding/json"
"fmt"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
)
// DumpConfig is a set of options to control what portions of the state will be
// iterated and collected.
type DumpConfig struct {
SkipCode bool
SkipStorage bool
OnlyWithAddresses bool
Start []byte
Max uint64
}
// DumpCollector interface which the state trie calls during iteration
type DumpCollector interface {
// OnRoot is called with the state root
OnRoot(common.Hash)
// OnAccount is called once for each account in the trie
OnAccount(*common.Address, DumpAccount)
}
// DumpAccount represents an account in the state.
type DumpAccount struct {
Balance string `json:"balance"`
Nonce uint64 `json:"nonce"`
Root hexutil.Bytes `json:"root"`
CodeHash hexutil.Bytes `json:"codeHash"`
Code hexutil.Bytes `json:"code,omitempty"`
Storage map[common.Hash]string `json:"storage,omitempty"`
Address *common.Address `json:"address,omitempty"` // Address only present in iterative (line-by-line) mode
AddressHash hexutil.Bytes `json:"key,omitempty"` // If we don't have address, we can output the key
}
// Dump represents the full dump in a collected format, as one large map.
type Dump struct {
Root string `json:"root"`
Accounts map[string]DumpAccount `json:"accounts"`
// Next can be set to represent that this dump is only partial, and Next
// is where an iterator should be positioned in order to continue the dump.
Next []byte `json:"next,omitempty"` // nil if no more accounts
}
// OnRoot implements DumpCollector interface
func (d *Dump) OnRoot(root common.Hash) {
d.Root = fmt.Sprintf("%x", root)
}
// OnAccount implements DumpCollector interface
func (d *Dump) OnAccount(addr *common.Address, account DumpAccount) {
if addr == nil {
d.Accounts[fmt.Sprintf("pre(%s)", account.AddressHash)] = account
}
if addr != nil {
d.Accounts[(*addr).String()] = account
}
}
// iterativeDump is a DumpCollector-implementation which dumps output line-by-line iteratively.
type iterativeDump struct {
*json.Encoder
}
// OnAccount implements DumpCollector interface
func (d iterativeDump) OnAccount(addr *common.Address, account DumpAccount) {
dumpAccount := &DumpAccount{
Balance: account.Balance,
Nonce: account.Nonce,
Root: account.Root,
CodeHash: account.CodeHash,
Code: account.Code,
Storage: account.Storage,
AddressHash: account.AddressHash,
Address: addr,
}
d.Encode(dumpAccount)
}
// OnRoot implements DumpCollector interface
func (d iterativeDump) OnRoot(root common.Hash) {
d.Encode(struct {
Root common.Hash `json:"root"`
}{root})
}
// DumpToCollector iterates the state according to the given options and inserts
// the items into a collector for aggregation or serialization.
func (s *StateDB) DumpToCollector(c DumpCollector, conf *DumpConfig) (nextKey []byte) {
// Sanitize the input to allow nil configs
if conf == nil {
conf = new(DumpConfig)
}
var (
missingPreimages int
accounts uint64
start = time.Now()
logged = time.Now()
)
log.Info("Trie dumping started", "root", s.trie.Hash())
c.OnRoot(s.trie.Hash())
trieIt, err := s.trie.NodeIterator(conf.Start)
if err != nil {
log.Error("Trie dumping error", "err", err)
return nil
}
it := trie.NewIterator(trieIt)
for it.Next() {
var data types.StateAccount
if err := rlp.DecodeBytes(it.Value, &data); err != nil {
panic(err)
}
var (
account = DumpAccount{
Balance: data.Balance.String(),
Nonce: data.Nonce,
Root: data.Root[:],
CodeHash: data.CodeHash,
AddressHash: it.Key,
}
address *common.Address
addr common.Address
addrBytes = s.trie.GetKey(it.Key)
)
if addrBytes == nil {
missingPreimages++
if conf.OnlyWithAddresses {
continue
}
} else {
addr = common.BytesToAddress(addrBytes)
address = &addr
account.Address = address
}
obj := newObject(s, addr, &data)
if !conf.SkipCode {
account.Code = obj.Code()
}
if !conf.SkipStorage {
account.Storage = make(map[common.Hash]string)
tr, err := obj.getTrie()
if err != nil {
log.Error("Failed to load storage trie", "err", err)
continue
}
trieIt, err := tr.NodeIterator(nil)
if err != nil {
log.Error("Failed to create trie iterator", "err", err)
continue
}
storageIt := trie.NewIterator(trieIt)
for storageIt.Next() {
_, content, _, err := rlp.Split(storageIt.Value)
if err != nil {
log.Error("Failed to decode the value returned by iterator", "error", err)
continue
}
account.Storage[common.BytesToHash(s.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(content)
}
}
c.OnAccount(address, account)
accounts++
if time.Since(logged) > 8*time.Second {
log.Info("Trie dumping in progress", "at", it.Key, "accounts", accounts,
"elapsed", common.PrettyDuration(time.Since(start)))
logged = time.Now()
}
if conf.Max > 0 && accounts >= conf.Max {
if it.Next() {
nextKey = it.Key
}
break
}
}
if missingPreimages > 0 {
log.Warn("Dump incomplete due to missing preimages", "missing", missingPreimages)
}
log.Info("Trie dumping complete", "accounts", accounts,
"elapsed", common.PrettyDuration(time.Since(start)))
return nextKey
}
// RawDump returns the state. If the processing is aborted e.g. due to options
// reaching Max, the `Next` key is set on the returned Dump.
func (s *StateDB) RawDump(opts *DumpConfig) Dump {
dump := &Dump{
Accounts: make(map[string]DumpAccount),
}
dump.Next = s.DumpToCollector(dump, opts)
return *dump
}
// Dump returns a JSON string representing the entire state as a single json-object
func (s *StateDB) Dump(opts *DumpConfig) []byte {
dump := s.RawDump(opts)
json, err := json.MarshalIndent(dump, "", " ")
if err != nil {
log.Error("Error dumping state", "err", err)
}
return json
}
// IterativeDump dumps out accounts as json-objects, delimited by linebreaks on stdout
func (s *StateDB) IterativeDump(opts *DumpConfig, output *json.Encoder) {
s.DumpToCollector(iterativeDump{output}, opts)
}

View File

@ -17,9 +17,8 @@
package state package state
import ( import (
"math/big"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/holiman/uint256"
) )
// journalEntry is a modification entry in the state change journal that can be // journalEntry is a modification entry in the state change journal that can be
@ -90,19 +89,26 @@ type (
account *common.Address account *common.Address
} }
resetObjectChange struct { resetObjectChange struct {
account *common.Address
prev *stateObject prev *stateObject
prevdestruct bool prevdestruct bool
prevAccount []byte
prevStorage map[common.Hash][]byte
prevAccountOriginExist bool
prevAccountOrigin []byte
prevStorageOrigin map[common.Hash][]byte
} }
suicideChange struct { selfDestructChange struct {
account *common.Address account *common.Address
prev bool // whether account had already suicided prev bool // whether account had already self-destructed
prevbalance *big.Int prevbalance *uint256.Int
} }
// Changes to individual accounts. // Changes to individual accounts.
balanceChange struct { balanceChange struct {
account *common.Address account *common.Address
prev *big.Int prev *uint256.Int
} }
nonceChange struct { nonceChange struct {
account *common.Address account *common.Address
@ -159,21 +165,33 @@ func (ch resetObjectChange) revert(s *StateDB) {
if !ch.prevdestruct { if !ch.prevdestruct {
delete(s.stateObjectsDestruct, ch.prev.address) delete(s.stateObjectsDestruct, ch.prev.address)
} }
if ch.prevAccount != nil {
s.accounts[ch.prev.addrHash] = ch.prevAccount
}
if ch.prevStorage != nil {
s.storages[ch.prev.addrHash] = ch.prevStorage
}
if ch.prevAccountOriginExist {
s.accountsOrigin[ch.prev.address] = ch.prevAccountOrigin
}
if ch.prevStorageOrigin != nil {
s.storagesOrigin[ch.prev.address] = ch.prevStorageOrigin
}
} }
func (ch resetObjectChange) dirtied() *common.Address { func (ch resetObjectChange) dirtied() *common.Address {
return nil return ch.account
} }
func (ch suicideChange) revert(s *StateDB) { func (ch selfDestructChange) revert(s *StateDB) {
obj := s.getStateObject(*ch.account) obj := s.getStateObject(*ch.account)
if obj != nil { if obj != nil {
obj.suicided = ch.prev obj.selfDestructed = ch.prev
obj.setBalance(ch.prevbalance) obj.setBalance(ch.prevbalance)
} }
} }
func (ch suicideChange) dirtied() *common.Address { func (ch selfDestructChange) dirtied() *common.Address {
return ch.account return ch.account
} }

View File

@ -27,4 +27,11 @@ var (
storageTriesUpdatedMeter = metrics.NewRegisteredMeter("state/update/storagenodes", nil) storageTriesUpdatedMeter = metrics.NewRegisteredMeter("state/update/storagenodes", nil)
accountTrieDeletedMeter = metrics.NewRegisteredMeter("state/delete/accountnodes", nil) accountTrieDeletedMeter = metrics.NewRegisteredMeter("state/delete/accountnodes", nil)
storageTriesDeletedMeter = metrics.NewRegisteredMeter("state/delete/storagenodes", nil) storageTriesDeletedMeter = metrics.NewRegisteredMeter("state/delete/storagenodes", nil)
slotDeletionMaxCount = metrics.NewRegisteredGauge("state/delete/storage/max/slot", nil)
slotDeletionMaxSize = metrics.NewRegisteredGauge("state/delete/storage/max/size", nil)
slotDeletionTimer = metrics.NewRegisteredResettingTimer("state/delete/storage/timer", nil)
slotDeletionCount = metrics.NewRegisteredMeter("state/delete/storage/slot", nil)
slotDeletionSize = metrics.NewRegisteredMeter("state/delete/storage/size", nil)
slotDeletionSkip = metrics.NewRegisteredGauge("state/delete/storage/skip", nil)
) )

View File

@ -20,7 +20,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"math/big"
"time" "time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
@ -28,7 +27,9 @@ import (
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
// "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" "github.com/holiman/uint256"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
) )
type Code []byte type Code []byte
@ -57,54 +58,64 @@ func (s Storage) Copy() Storage {
// stateObject represents an Ethereum account which is being modified. // stateObject represents an Ethereum account which is being modified.
// //
// The usage pattern is as follows: // The usage pattern is as follows:
// First you need to obtain a state object. // - First you need to obtain a state object.
// Account values can be accessed and modified through the object. // - Account values as well as storages can be accessed and modified through the object.
// - Finally, call commit to return the changes of storage trie and update account data.
type stateObject struct { type stateObject struct {
address common.Address
addrHash common.Hash // hash of ethereum address of the account
data types.StateAccount
db *StateDB db *StateDB
address common.Address // address of ethereum account
addrHash common.Hash // hash of ethereum address of the account
origin *types.StateAccount // Account original data without any change applied, nil means it was not existent
data types.StateAccount // Account data with all mutations applied in the scope of block
// Write caches. // Write caches.
trie Trie // storage trie, which becomes non-nil on first access trie Trie // storage trie, which becomes non-nil on first access
code Code // contract bytecode, which gets set when code is loaded code Code // contract bytecode, which gets set when code is loaded
originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction originStorage Storage // Storage cache of original entries to dedup rewrites
pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block
dirtyStorage Storage // Storage entries that have been modified in the current transaction execution dirtyStorage Storage // Storage entries that have been modified in the current transaction execution, reset for every transaction
// Cache flags. // Cache flags.
// When an object is marked suicided it will be deleted from the trie
// during the "update" phase of the state transition.
dirtyCode bool // true if the code was updated dirtyCode bool // true if the code was updated
suicided bool
// Flag whether the account was marked as self-destructed. The self-destructed account
// is still accessible in the scope of same transaction.
selfDestructed bool
// Flag whether the account was marked as deleted. A self-destructed account
// or an account that is considered as empty will be marked as deleted at
// the end of transaction and no longer accessible anymore.
deleted bool deleted bool
// Flag whether the object was created in the current transaction
created bool
} }
// empty returns whether the account is considered empty. // empty returns whether the account is considered empty.
func (s *stateObject) empty() bool { func (s *stateObject) empty() bool {
return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, types.EmptyCodeHash.Bytes()) return s.data.Nonce == 0 && s.data.Balance.IsZero() && bytes.Equal(s.data.CodeHash, types.EmptyCodeHash.Bytes())
} }
// newObject creates a state object. // newObject creates a state object.
func newObject(db *StateDB, address common.Address, data types.StateAccount) *stateObject { func newObject(db *StateDB, address common.Address, acct *types.StateAccount) *stateObject {
if data.Balance == nil { var (
data.Balance = new(big.Int) origin = acct
} created = acct == nil // true if the account was not existent
if data.CodeHash == nil { )
data.CodeHash = types.EmptyCodeHash.Bytes() if acct == nil {
} acct = types.NewEmptyStateAccount()
if data.Root == (common.Hash{}) {
data.Root = types.EmptyRootHash
} }
return &stateObject{ return &stateObject{
db: db, db: db,
address: address, address: address,
addrHash: crypto.Keccak256Hash(address[:]), addrHash: crypto.Keccak256Hash(address[:]),
data: data, origin: origin,
data: *acct,
originStorage: make(Storage), originStorage: make(Storage),
pendingStorage: make(Storage), pendingStorage: make(Storage),
dirtyStorage: make(Storage), dirtyStorage: make(Storage),
created: created,
} }
} }
@ -113,8 +124,8 @@ func (s *stateObject) EncodeRLP(w io.Writer) error {
return rlp.Encode(w, &s.data) return rlp.Encode(w, &s.data)
} }
func (s *stateObject) markSuicided() { func (s *stateObject) markSelfdestructed() {
s.suicided = true s.selfDestructed = true
} }
func (s *stateObject) touch() { func (s *stateObject) touch() {
@ -131,17 +142,15 @@ func (s *stateObject) touch() {
// getTrie returns the associated storage trie. The trie will be opened // getTrie returns the associated storage trie. The trie will be opened
// if it's not loaded previously. An error will be returned if trie can't // if it's not loaded previously. An error will be returned if trie can't
// be loaded. // be loaded.
func (s *stateObject) getTrie(db Database) (Trie, error) { func (s *stateObject) getTrie() (Trie, error) {
if s.trie == nil { if s.trie == nil {
// Try fetching from prefetcher first // Try fetching from prefetcher first
// We don't prefetch empty tries
if s.data.Root != types.EmptyRootHash && s.db.prefetcher != nil { if s.data.Root != types.EmptyRootHash && s.db.prefetcher != nil {
// When the miner is creating the pending state, there is no // When the miner is creating the pending state, there is no prefetcher
// prefetcher
s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root) s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root)
} }
if s.trie == nil { if s.trie == nil {
tr, err := db.OpenStorageTrie(s.db.originalRoot, s.addrHash, s.data.Root) tr, err := s.db.db.OpenStorageTrie(s.db.originalRoot, s.addrHash, s.data.Root, s.db.trie)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -152,18 +161,18 @@ func (s *stateObject) getTrie(db Database) (Trie, error) {
} }
// GetState retrieves a value from the account storage trie. // GetState retrieves a value from the account storage trie.
func (s *stateObject) GetState(db Database, key common.Hash) common.Hash { func (s *stateObject) GetState(key common.Hash) common.Hash {
// If we have a dirty value for this state entry, return it // If we have a dirty value for this state entry, return it
value, dirty := s.dirtyStorage[key] value, dirty := s.dirtyStorage[key]
if dirty { if dirty {
return value return value
} }
// Otherwise return the entry's original value // Otherwise return the entry's original value
return s.GetCommittedState(db, key) return s.GetCommittedState(key)
} }
// GetCommittedState retrieves a value from the committed account storage trie. // GetCommittedState retrieves a value from the committed account storage trie.
func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash { func (s *stateObject) GetCommittedState(key common.Hash) common.Hash {
// If we have a pending write or clean cached, return that // If we have a pending write or clean cached, return that
if value, pending := s.pendingStorage[key]; pending { if value, pending := s.pendingStorage[key]; pending {
return value return value
@ -184,6 +193,7 @@ func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Has
var ( var (
enc []byte enc []byte
err error err error
value common.Hash
) )
if s.db.snap != nil { if s.db.snap != nil {
start := time.Now() start := time.Now()
@ -191,25 +201,6 @@ func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Has
if metrics.EnabledExpensive { if metrics.EnabledExpensive {
s.db.SnapshotStorageReads += time.Since(start) s.db.SnapshotStorageReads += time.Since(start)
} }
}
// If the snapshot is unavailable or reading from it fails, load from the database.
if s.db.snap == nil || err != nil {
start := time.Now()
tr, err := s.getTrie(db)
if err != nil {
s.db.setError(err)
return common.Hash{}
}
enc, err = tr.TryGet(key.Bytes())
if metrics.EnabledExpensive {
s.db.StorageReads += time.Since(start)
}
if err != nil {
s.db.setError(err)
return common.Hash{}
}
}
var value common.Hash
if len(enc) > 0 { if len(enc) > 0 {
_, content, _, err := rlp.Split(enc) _, content, _, err := rlp.Split(enc)
if err != nil { if err != nil {
@ -217,14 +208,33 @@ func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Has
} }
value.SetBytes(content) value.SetBytes(content)
} }
}
// If the snapshot is unavailable or reading from it fails, load from the database.
if s.db.snap == nil || err != nil {
start := time.Now()
tr, err := s.getTrie()
if err != nil {
s.db.setError(err)
return common.Hash{}
}
val, err := tr.GetStorage(s.address, key.Bytes())
if metrics.EnabledExpensive {
s.db.StorageReads += time.Since(start)
}
if err != nil {
s.db.setError(err)
return common.Hash{}
}
value.SetBytes(val)
}
s.originStorage[key] = value s.originStorage[key] = value
return value return value
} }
// SetState updates a value in account storage. // SetState updates a value in account storage.
func (s *stateObject) SetState(db Database, key, value common.Hash) { func (s *stateObject) SetState(key, value common.Hash) {
// If the new value is the same as old, don't set // If the new value is the same as old, don't set
prev := s.GetState(db, key) prev := s.GetState(key)
if prev == value { if prev == value {
return return
} }
@ -252,19 +262,24 @@ func (s *stateObject) finalise(prefetch bool) {
} }
} }
if s.db.prefetcher != nil && prefetch && len(slotsToPrefetch) > 0 && s.data.Root != types.EmptyRootHash { if s.db.prefetcher != nil && prefetch && len(slotsToPrefetch) > 0 && s.data.Root != types.EmptyRootHash {
s.db.prefetcher.prefetch(s.addrHash, s.data.Root, slotsToPrefetch) s.db.prefetcher.prefetch(s.addrHash, s.data.Root, s.address, slotsToPrefetch)
} }
if len(s.dirtyStorage) > 0 { if len(s.dirtyStorage) > 0 {
s.dirtyStorage = make(Storage) s.dirtyStorage = make(Storage)
} }
} }
// updateTrie writes cached storage modifications into the object's storage trie. // updateTrie is responsible for persisting cached storage changes into the
// It will return nil if the trie has not been loaded and no changes have been // object's storage trie. In case the storage trie is not yet loaded, this
// made. An error will be returned if the trie can't be loaded/updated correctly. // function will load the trie automatically. If any issues arise during the
func (s *stateObject) updateTrie(db Database) (Trie, error) { // loading or updating of the trie, an error will be returned. Furthermore,
// this function will return the mutated storage trie, or nil if there is no
// storage change at all.
func (s *stateObject) updateTrie() (Trie, error) {
// Make sure all dirty slots are finalized into the pending storage area // Make sure all dirty slots are finalized into the pending storage area
s.finalise(false) // Don't prefetch anymore, pull directly if need be s.finalise(false)
// Short circuit if nothing changed, don't bother with hashing anything
if len(s.pendingStorage) == 0 { if len(s.pendingStorage) == 0 {
return s.trie, nil return s.trie, nil
} }
@ -275,69 +290,84 @@ func (s *stateObject) updateTrie(db Database) (Trie, error) {
// The snapshot storage map for the object // The snapshot storage map for the object
var ( var (
storage map[common.Hash][]byte storage map[common.Hash][]byte
hasher = s.db.hasher origin map[common.Hash][]byte
) )
tr, err := s.getTrie(db) tr, err := s.getTrie()
if err != nil { if err != nil {
s.db.setError(err) s.db.setError(err)
return nil, err return nil, err
} }
// Insert all the pending updates into the trie // Insert all the pending storage updates into the trie
usedStorage := make([][]byte, 0, len(s.pendingStorage)) usedStorage := make([][]byte, 0, len(s.pendingStorage))
for key, value := range s.pendingStorage { for key, value := range s.pendingStorage {
// Skip noop changes, persist actual changes // Skip noop changes, persist actual changes
if value == s.originStorage[key] { if value == s.originStorage[key] {
continue continue
} }
prev := s.originStorage[key]
s.originStorage[key] = value s.originStorage[key] = value
var v []byte var encoded []byte // rlp-encoded value to be used by the snapshot
if (value == common.Hash{}) { if (value == common.Hash{}) {
if err := tr.TryDelete(key[:]); err != nil { if err := tr.DeleteStorage(s.address, key[:]); err != nil {
s.db.setError(err) s.db.setError(err)
return nil, err return nil, err
} }
s.db.StorageDeleted += 1 s.db.StorageDeleted += 1
} else { } else {
// Encoding []byte cannot fail, ok to ignore the error. // Encoding []byte cannot fail, ok to ignore the error.
v, _ = rlp.EncodeToBytes(common.TrimLeftZeroes(value[:])) trimmed := common.TrimLeftZeroes(value[:])
if err := tr.TryUpdate(key[:], v); err != nil { encoded, _ = rlp.EncodeToBytes(trimmed)
if err := tr.UpdateStorage(s.address, key[:], trimmed); err != nil {
s.db.setError(err) s.db.setError(err)
return nil, err return nil, err
} }
s.db.StorageUpdated += 1 s.db.StorageUpdated += 1
} }
// If state snapshotting is active, cache the data til commit // Cache the mutated storage slots until commit
if s.db.snap != nil {
if storage == nil { if storage == nil {
// Retrieve the old storage map, if available, create a new one otherwise if storage = s.db.storages[s.addrHash]; storage == nil {
if storage = s.db.snapStorage[s.addrHash]; storage == nil {
storage = make(map[common.Hash][]byte) storage = make(map[common.Hash][]byte)
s.db.snapStorage[s.addrHash] = storage s.db.storages[s.addrHash] = storage
} }
} }
storage[crypto.HashData(hasher, key[:])] = v // v will be nil if it's deleted khash := crypto.HashData(s.db.hasher, key[:])
storage[khash] = encoded // encoded will be nil if it's deleted
// Cache the original value of mutated storage slots
if origin == nil {
if origin = s.db.storagesOrigin[s.address]; origin == nil {
origin = make(map[common.Hash][]byte)
s.db.storagesOrigin[s.address] = origin
} }
}
// Track the original value of slot only if it's mutated first time
if _, ok := origin[khash]; !ok {
if prev == (common.Hash{}) {
origin[khash] = nil // nil if it was not present previously
} else {
// Encoding []byte cannot fail, ok to ignore the error.
b, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(prev[:]))
origin[khash] = b
}
}
// Cache the items for preloading
usedStorage = append(usedStorage, common.CopyBytes(key[:])) // Copy needed for closure usedStorage = append(usedStorage, common.CopyBytes(key[:])) // Copy needed for closure
} }
if s.db.prefetcher != nil { if s.db.prefetcher != nil {
s.db.prefetcher.used(s.addrHash, s.data.Root, usedStorage) s.db.prefetcher.used(s.addrHash, s.data.Root, usedStorage)
} }
if len(s.pendingStorage) > 0 { s.pendingStorage = make(Storage) // reset pending map
s.pendingStorage = make(Storage)
}
return tr, nil return tr, nil
} }
// UpdateRoot sets the trie root to the current root hash of. An error // updateRoot flushes all cached storage mutations to trie, recalculating the
// will be returned if trie root hash is not computed correctly. // new storage trie root.
func (s *stateObject) updateRoot(db Database) { func (s *stateObject) updateRoot() {
tr, err := s.updateTrie(db) // Flush cached storage mutations into trie, short circuit if any error
if err != nil { // is occurred or there is not change in the trie.
return tr, err := s.updateTrie()
} if err != nil || tr == nil {
// If nothing changed, don't bother with hashing anything
if tr == nil {
return return
} }
// Track the amount of time wasted on hashing the storage trie // Track the amount of time wasted on hashing the storage trie
@ -347,54 +377,87 @@ func (s *stateObject) updateRoot(db Database) {
s.data.Root = tr.Hash() s.data.Root = tr.Hash()
} }
// commit obtains a set of dirty storage trie nodes and updates the account data.
// The returned set can be nil if nothing to commit. This function assumes all
// storage mutations have already been flushed into trie by updateRoot.
func (s *stateObject) commit() (*trienode.NodeSet, error) {
// Short circuit if trie is not even loaded, don't bother with committing anything
if s.trie == nil {
s.origin = s.data.Copy()
return nil, nil
}
// Track the amount of time wasted on committing the storage trie
if metrics.EnabledExpensive {
defer func(start time.Time) { s.db.StorageCommits += time.Since(start) }(time.Now())
}
// The trie is currently in an open state and could potentially contain
// cached mutations. Call commit to acquire a set of nodes that have been
// modified, the set can be nil if nothing to commit.
root, nodes, err := s.trie.Commit(false)
if err != nil {
return nil, err
}
s.data.Root = root
// Update original account data after commit
s.origin = s.data.Copy()
return nodes, nil
}
// AddBalance adds amount to s's balance. // AddBalance adds amount to s's balance.
// It is used to add funds to the destination account of a transfer. // It is used to add funds to the destination account of a transfer.
func (s *stateObject) AddBalance(amount *big.Int) { func (s *stateObject) AddBalance(amount *uint256.Int) {
// EIP161: We must check emptiness for the objects such that the account // EIP161: We must check emptiness for the objects such that the account
// clearing (0,0,0 objects) can take effect. // clearing (0,0,0 objects) can take effect.
if amount.Sign() == 0 { if amount.IsZero() {
if s.empty() { if s.empty() {
s.touch() s.touch()
} }
return return
} }
s.SetBalance(new(big.Int).Add(s.Balance(), amount)) s.SetBalance(new(uint256.Int).Add(s.Balance(), amount))
} }
// SubBalance removes amount from s's balance. // SubBalance removes amount from s's balance.
// It is used to remove funds from the origin account of a transfer. // It is used to remove funds from the origin account of a transfer.
func (s *stateObject) SubBalance(amount *big.Int) { func (s *stateObject) SubBalance(amount *uint256.Int) {
if amount.Sign() == 0 { if amount.IsZero() {
return return
} }
s.SetBalance(new(big.Int).Sub(s.Balance(), amount)) s.SetBalance(new(uint256.Int).Sub(s.Balance(), amount))
} }
func (s *stateObject) SetBalance(amount *big.Int) { func (s *stateObject) SetBalance(amount *uint256.Int) {
s.db.journal.append(balanceChange{ s.db.journal.append(balanceChange{
account: &s.address, account: &s.address,
prev: new(big.Int).Set(s.data.Balance), prev: new(uint256.Int).Set(s.data.Balance),
}) })
s.setBalance(amount) s.setBalance(amount)
} }
func (s *stateObject) setBalance(amount *big.Int) { func (s *stateObject) setBalance(amount *uint256.Int) {
s.data.Balance = amount s.data.Balance = amount
} }
func (s *stateObject) deepCopy(db *StateDB) *stateObject { func (s *stateObject) deepCopy(db *StateDB) *stateObject {
stateObject := newObject(db, s.address, s.data) obj := &stateObject{
if s.trie != nil { db: db,
stateObject.trie = db.db.CopyTrie(s.trie) address: s.address,
addrHash: s.addrHash,
origin: s.origin,
data: s.data,
} }
stateObject.code = s.code if s.trie != nil {
stateObject.dirtyStorage = s.dirtyStorage.Copy() obj.trie = db.db.CopyTrie(s.trie)
stateObject.originStorage = s.originStorage.Copy() }
stateObject.pendingStorage = s.pendingStorage.Copy() obj.code = s.code
stateObject.suicided = s.suicided obj.dirtyStorage = s.dirtyStorage.Copy()
stateObject.dirtyCode = s.dirtyCode obj.originStorage = s.originStorage.Copy()
stateObject.deleted = s.deleted obj.pendingStorage = s.pendingStorage.Copy()
return stateObject obj.selfDestructed = s.selfDestructed
obj.dirtyCode = s.dirtyCode
obj.deleted = s.deleted
return obj
} }
// //
@ -407,14 +470,14 @@ func (s *stateObject) Address() common.Address {
} }
// Code returns the contract code associated with this object, if any. // Code returns the contract code associated with this object, if any.
func (s *stateObject) Code(db Database) []byte { func (s *stateObject) Code() []byte {
if s.code != nil { if s.code != nil {
return s.code return s.code
} }
if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) { if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return nil return nil
} }
code, err := db.ContractCode(common.BytesToHash(s.CodeHash())) code, err := s.db.db.ContractCode(s.address, common.BytesToHash(s.CodeHash()))
if err != nil { if err != nil {
s.db.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err)) s.db.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err))
} }
@ -425,14 +488,14 @@ func (s *stateObject) Code(db Database) []byte {
// CodeSize returns the size of the contract code associated with this object, // CodeSize returns the size of the contract code associated with this object,
// or zero if none. This method is an almost mirror of Code, but uses a cache // or zero if none. This method is an almost mirror of Code, but uses a cache
// inside the database to avoid loading codes seen recently. // inside the database to avoid loading codes seen recently.
func (s *stateObject) CodeSize(db Database) int { func (s *stateObject) CodeSize() int {
if s.code != nil { if s.code != nil {
return len(s.code) return len(s.code)
} }
if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) { if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return 0 return 0
} }
size, err := db.ContractCodeSize(common.BytesToHash(s.CodeHash())) size, err := s.db.db.ContractCodeSize(s.address, common.BytesToHash(s.CodeHash()))
if err != nil { if err != nil {
s.db.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err)) s.db.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err))
} }
@ -440,7 +503,7 @@ func (s *stateObject) CodeSize(db Database) int {
} }
func (s *stateObject) SetCode(codeHash common.Hash, code []byte) { func (s *stateObject) SetCode(codeHash common.Hash, code []byte) {
prevcode := s.Code(s.db.db) prevcode := s.Code()
s.db.journal.append(codeChange{ s.db.journal.append(codeChange{
account: &s.address, account: &s.address,
prevhash: s.CodeHash(), prevhash: s.CodeHash(),
@ -471,10 +534,14 @@ func (s *stateObject) CodeHash() []byte {
return s.data.CodeHash return s.data.CodeHash
} }
func (s *stateObject) Balance() *big.Int { func (s *stateObject) Balance() *uint256.Int {
return s.data.Balance return s.data.Balance
} }
func (s *stateObject) Nonce() uint64 { func (s *stateObject) Nonce() uint64 {
return s.data.Nonce return s.data.Nonce
} }
func (s *stateObject) Root() common.Hash {
return s.data.Root
}

View File

@ -19,15 +19,15 @@ package state
import ( import (
"bytes" "bytes"
"context" "context"
"math/big"
"testing" "testing"
pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0" pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0"
"github.com/cerc-io/plugeth-statediff/indexer/database/sql/postgres" "github.com/cerc-io/plugeth-statediff/indexer/database/sql/postgres"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/holiman/uint256"
"github.com/cerc-io/ipld-eth-statedb/internal" "github.com/cerc-io/ipld-eth-statedb/internal"
) )
@ -38,33 +38,38 @@ var (
teardownStatements = []string{`TRUNCATE ipld.blocks`} teardownStatements = []string{`TRUNCATE ipld.blocks`}
) )
type stateTest struct { type stateEnv struct {
db ethdb.Database db ethdb.Database
state *StateDB state *StateDB
} }
func newStateTest(t *testing.T) *stateTest { func newStateEnv(t *testing.T) *stateEnv {
db := newPgIpfsEthdb(t)
sdb, err := New(types.EmptyRootHash, NewDatabase(db), nil)
if err != nil {
t.Fatal(err)
}
return &stateEnv{db: db, state: sdb}
}
func newPgIpfsEthdb(t *testing.T) ethdb.Database {
pool, err := postgres.ConnectSQLX(testCtx, testConfig) pool, err := postgres.ConnectSQLX(testCtx, testConfig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
db := pgipfsethdb.NewDatabase(pool, internal.MakeCacheConfig(t)) db := pgipfsethdb.NewDatabase(pool, internal.MakeCacheConfig(t))
sdb, err := New(common.Hash{}, NewDatabase(db), nil) return db
if err != nil {
t.Fatal(err)
}
return &stateTest{db: db, state: sdb}
} }
func TestNull(t *testing.T) { func TestNull(t *testing.T) {
s := newStateTest(t) s := newStateEnv(t)
address := common.HexToAddress("0x823140710bf13990e4500136726d8b55") address := common.HexToAddress("0x823140710bf13990e4500136726d8b55")
s.state.CreateAccount(address) s.state.CreateAccount(address)
//value := common.FromHex("0x823140710bf13990e4500136726d8b55") //value := common.FromHex("0x823140710bf13990e4500136726d8b55")
var value common.Hash var value common.Hash
s.state.SetState(address, common.Hash{}, value) s.state.SetState(address, common.Hash{}, value)
// s.state.Commit(false) // s.state.Commit(0, false)
if value := s.state.GetState(address, common.Hash{}); value != (common.Hash{}) { if value := s.state.GetState(address, common.Hash{}); value != (common.Hash{}) {
t.Errorf("expected empty current value, got %x", value) t.Errorf("expected empty current value, got %x", value)
@ -79,7 +84,7 @@ func TestSnapshot(t *testing.T) {
var storageaddr common.Hash var storageaddr common.Hash
data1 := common.BytesToHash([]byte{42}) data1 := common.BytesToHash([]byte{42})
data2 := common.BytesToHash([]byte{43}) data2 := common.BytesToHash([]byte{43})
s := newStateTest(t) s := newStateEnv(t)
// snapshot the genesis state // snapshot the genesis state
genesis := s.state.Snapshot() genesis := s.state.Snapshot()
@ -110,12 +115,12 @@ func TestSnapshot(t *testing.T) {
} }
func TestSnapshotEmpty(t *testing.T) { func TestSnapshotEmpty(t *testing.T) {
s := newStateTest(t) s := newStateEnv(t)
s.state.RevertToSnapshot(s.state.Snapshot()) s.state.RevertToSnapshot(s.state.Snapshot())
} }
func TestSnapshot2(t *testing.T) { func TestSnapshot2(t *testing.T) {
state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil) state, _ := New(types.EmptyRootHash, NewDatabase(newPgIpfsEthdb(t)), nil)
stateobjaddr0 := common.BytesToAddress([]byte("so0")) stateobjaddr0 := common.BytesToAddress([]byte("so0"))
stateobjaddr1 := common.BytesToAddress([]byte("so1")) stateobjaddr1 := common.BytesToAddress([]byte("so1"))
@ -129,22 +134,22 @@ func TestSnapshot2(t *testing.T) {
// db, trie are already non-empty values // db, trie are already non-empty values
so0 := state.getStateObject(stateobjaddr0) so0 := state.getStateObject(stateobjaddr0)
so0.SetBalance(big.NewInt(42)) so0.SetBalance(uint256.NewInt(42))
so0.SetNonce(43) so0.SetNonce(43)
so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'}) so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'})
so0.suicided = false so0.selfDestructed = false
so0.deleted = false so0.deleted = false
state.setStateObject(so0) state.setStateObject(so0)
// root, _ := state.Commit(false) // root, _ := state.Commit(0, false)
// state, _ = New(root, state.db, state.snaps) // state, _ = New(root, state.db, state.snaps)
// and one with deleted == true // and one with deleted == true
so1 := state.getStateObject(stateobjaddr1) so1 := state.getStateObject(stateobjaddr1)
so1.SetBalance(big.NewInt(52)) so1.SetBalance(uint256.NewInt(52))
so1.SetNonce(53) so1.SetNonce(53)
so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'}) so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'})
so1.suicided = true so1.selfDestructed = true
so1.deleted = true so1.deleted = true
state.setStateObject(so1) state.setStateObject(so1)
@ -158,8 +163,8 @@ func TestSnapshot2(t *testing.T) {
so0Restored := state.getStateObject(stateobjaddr0) so0Restored := state.getStateObject(stateobjaddr0)
// Update lazily-loaded values before comparing. // Update lazily-loaded values before comparing.
so0Restored.GetState(state.db, storageaddr) so0Restored.GetState(storageaddr)
so0Restored.Code(state.db) so0Restored.Code()
// non-deleted is equal (restored) // non-deleted is equal (restored)
compareStateObjects(so0Restored, so0, t) compareStateObjects(so0Restored, so0, t)

File diff suppressed because it is too large Load Diff

View File

@ -21,7 +21,6 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"math" "math"
"math/big"
"math/rand" "math/rand"
"reflect" "reflect"
"strings" "strings"
@ -30,8 +29,11 @@ import (
"testing/quick" "testing/quick"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/rlp"
"github.com/holiman/uint256"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
) )
// TestCopy tests that copying a StateDB object indeed makes the original and // TestCopy tests that copying a StateDB object indeed makes the original and
@ -39,11 +41,13 @@ import (
// https://github.com/ethereum/go-ethereum/pull/15549. // https://github.com/ethereum/go-ethereum/pull/15549.
func TestCopy(t *testing.T) { func TestCopy(t *testing.T) {
// Create a random state test to copy and modify "independently" // Create a random state test to copy and modify "independently"
orig, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil) db, cleanup := newPgIpfsEthdb(t)
t.Cleanup(cleanup)
orig, _ := New(types.EmptyRootHash, NewDatabase(db), nil)
for i := byte(0); i < 255; i++ { for i := byte(0); i < 255; i++ {
obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) obj := orig.getOrNewStateObject(common.BytesToAddress([]byte{i}))
obj.AddBalance(big.NewInt(int64(i))) obj.AddBalance(uint256.NewInt(uint64(i)))
orig.updateStateObject(obj) orig.updateStateObject(obj)
} }
orig.Finalise(false) orig.Finalise(false)
@ -56,13 +60,13 @@ func TestCopy(t *testing.T) {
// modify all in memory // modify all in memory
for i := byte(0); i < 255; i++ { for i := byte(0); i < 255; i++ {
origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) origObj := orig.getOrNewStateObject(common.BytesToAddress([]byte{i}))
copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) copyObj := copy.getOrNewStateObject(common.BytesToAddress([]byte{i}))
ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) ccopyObj := ccopy.getOrNewStateObject(common.BytesToAddress([]byte{i}))
origObj.AddBalance(big.NewInt(2 * int64(i))) origObj.AddBalance(uint256.NewInt(2 * uint64(i)))
copyObj.AddBalance(big.NewInt(3 * int64(i))) copyObj.AddBalance(uint256.NewInt(3 * uint64(i)))
ccopyObj.AddBalance(big.NewInt(4 * int64(i))) ccopyObj.AddBalance(uint256.NewInt(4 * uint64(i)))
orig.updateStateObject(origObj) orig.updateStateObject(origObj)
copy.updateStateObject(copyObj) copy.updateStateObject(copyObj)
@ -84,25 +88,34 @@ func TestCopy(t *testing.T) {
// Verify that the three states have been updated independently // Verify that the three states have been updated independently
for i := byte(0); i < 255; i++ { for i := byte(0); i < 255; i++ {
origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) origObj := orig.getOrNewStateObject(common.BytesToAddress([]byte{i}))
copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) copyObj := copy.getOrNewStateObject(common.BytesToAddress([]byte{i}))
ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) ccopyObj := ccopy.getOrNewStateObject(common.BytesToAddress([]byte{i}))
if want := big.NewInt(3 * int64(i)); origObj.Balance().Cmp(want) != 0 { if want := uint256.NewInt(3 * uint64(i)); origObj.Balance().Cmp(want) != 0 {
t.Errorf("orig obj %d: balance mismatch: have %v, want %v", i, origObj.Balance(), want) t.Errorf("orig obj %d: balance mismatch: have %v, want %v", i, origObj.Balance(), want)
} }
if want := big.NewInt(4 * int64(i)); copyObj.Balance().Cmp(want) != 0 { if want := uint256.NewInt(4 * uint64(i)); copyObj.Balance().Cmp(want) != 0 {
t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, copyObj.Balance(), want) t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, copyObj.Balance(), want)
} }
if want := big.NewInt(5 * int64(i)); ccopyObj.Balance().Cmp(want) != 0 { if want := uint256.NewInt(5 * uint64(i)); ccopyObj.Balance().Cmp(want) != 0 {
t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, ccopyObj.Balance(), want) t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, ccopyObj.Balance(), want)
} }
} }
} }
func TestSnapshotRandom(t *testing.T) { func TestSnapshotRandom(t *testing.T) {
config := &quick.Config{MaxCount: 1000} config := &quick.Config{MaxCount: 10}
err := quick.Check((*snapshotTest).run, config) i := 0
run := func(test *snapshotTest) bool {
var res bool
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
res = test.run(t)
})
i++
return res
}
err := quick.Check(run, config)
if cerr, ok := err.(*quick.CheckError); ok { if cerr, ok := err.(*quick.CheckError); ok {
test := cerr.In[0].(*snapshotTest) test := cerr.In[0].(*snapshotTest)
t.Errorf("%v:\n%s", test.err, test) t.Errorf("%v:\n%s", test.err, test)
@ -142,14 +155,14 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction {
{ {
name: "SetBalance", name: "SetBalance",
fn: func(a testAction, s *StateDB) { fn: func(a testAction, s *StateDB) {
s.SetBalance(addr, big.NewInt(a.args[0])) s.SetBalance(addr, uint256.NewInt(uint64(a.args[0])))
}, },
args: make([]int64, 1), args: make([]int64, 1),
}, },
{ {
name: "AddBalance", name: "AddBalance",
fn: func(a testAction, s *StateDB) { fn: func(a testAction, s *StateDB) {
s.AddBalance(addr, big.NewInt(a.args[0])) s.AddBalance(addr, uint256.NewInt(uint64(a.args[0])))
}, },
args: make([]int64, 1), args: make([]int64, 1),
}, },
@ -187,9 +200,9 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction {
}, },
}, },
{ {
name: "Suicide", name: "SelfDestruct",
fn: func(a testAction, s *StateDB) { fn: func(a testAction, s *StateDB) {
s.Suicide(addr) s.SelfDestruct(addr)
}, },
}, },
{ {
@ -296,16 +309,20 @@ func (test *snapshotTest) String() string {
return out.String() return out.String()
} }
func (test *snapshotTest) run() bool { func (test *snapshotTest) run(t *testing.T) bool {
// Run all actions and create snapshots. // Run all actions and create snapshots.
db, cleanup := newPgIpfsEthdb(t)
t.Cleanup(cleanup)
var ( var (
state, _ = New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil) state, _ = New(types.EmptyRootHash, NewDatabase(db), nil)
snapshotRevs = make([]int, len(test.snapshots)) snapshotRevs = make([]int, len(test.snapshots))
sindex = 0 sindex = 0
checkstates = make([]*StateDB, len(test.snapshots))
) )
for i, action := range test.actions { for i, action := range test.actions {
if len(test.snapshots) > sindex && i == test.snapshots[sindex] { if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
snapshotRevs[sindex] = state.Snapshot() snapshotRevs[sindex] = state.Snapshot()
checkstates[sindex] = state.Copy()
sindex++ sindex++
} }
action.fn(action, state) action.fn(action, state)
@ -313,12 +330,8 @@ func (test *snapshotTest) run() bool {
// Revert all snapshots in reverse order. Each revert must yield a state // Revert all snapshots in reverse order. Each revert must yield a state
// that is equivalent to fresh state with all actions up the snapshot applied. // that is equivalent to fresh state with all actions up the snapshot applied.
for sindex--; sindex >= 0; sindex-- { for sindex--; sindex >= 0; sindex-- {
checkstate, _ := New(common.Hash{}, state.Database(), nil)
for _, action := range test.actions[:test.snapshots[sindex]] {
action.fn(action, checkstate)
}
state.RevertToSnapshot(snapshotRevs[sindex]) state.RevertToSnapshot(snapshotRevs[sindex])
if err := test.checkEqual(state, checkstate); err != nil { if err := test.checkEqual(state, checkstates[sindex]); err != nil {
test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err) test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
return false return false
} }
@ -326,6 +339,43 @@ func (test *snapshotTest) run() bool {
return true return true
} }
func forEachStorage(s *StateDB, addr common.Address, cb func(key, value common.Hash) bool) error {
so := s.getStateObject(addr)
if so == nil {
return nil
}
tr, err := so.getTrie()
if err != nil {
return err
}
trieIt, err := tr.NodeIterator(nil)
if err != nil {
return err
}
it := trie.NewIterator(trieIt)
for it.Next() {
key := common.BytesToHash(s.trie.GetKey(it.Key))
if value, dirty := so.dirtyStorage[key]; dirty {
if !cb(key, value) {
return nil
}
continue
}
if len(it.Value) > 0 {
_, content, _, err := rlp.Split(it.Value)
if err != nil {
return err
}
if !cb(key, common.BytesToHash(content)) {
return nil
}
}
}
return nil
}
// checkEqual checks that methods of state and checkstate return the same values. // checkEqual checks that methods of state and checkstate return the same values.
func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
for _, addr := range test.addrs { for _, addr := range test.addrs {
@ -339,7 +389,7 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
} }
// Check basic accessor methods. // Check basic accessor methods.
checkeq("Exist", state.Exist(addr), checkstate.Exist(addr)) checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr)) checkeq("HasSelfdestructed", state.HasSelfDestructed(addr), checkstate.HasSelfDestructed(addr))
checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr)) checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr)) checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr)) checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
@ -347,10 +397,10 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr)) checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
// Check storage. // Check storage.
if obj := state.getStateObject(addr); obj != nil { if obj := state.getStateObject(addr); obj != nil {
state.ForEachStorage(addr, func(key, value common.Hash) bool { forEachStorage(state, addr, func(key, value common.Hash) bool {
return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value) return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
}) })
checkstate.ForEachStorage(addr, func(key, value common.Hash) bool { forEachStorage(checkstate, addr, func(key, value common.Hash) bool {
return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value) return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
}) })
} }
@ -373,9 +423,11 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
// TestCopyOfCopy tests that modified objects are carried over to the copy, and the copy of the copy. // TestCopyOfCopy tests that modified objects are carried over to the copy, and the copy of the copy.
// See https://github.com/ethereum/go-ethereum/pull/15225#issuecomment-380191512 // See https://github.com/ethereum/go-ethereum/pull/15225#issuecomment-380191512
func TestCopyOfCopy(t *testing.T) { func TestCopyOfCopy(t *testing.T) {
state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil) db, cleanup := newPgIpfsEthdb(t)
t.Cleanup(cleanup)
state, _ := New(types.EmptyRootHash, NewDatabase(db), nil)
addr := common.HexToAddress("aaaa") addr := common.HexToAddress("aaaa")
state.SetBalance(addr, big.NewInt(42)) state.SetBalance(addr, uint256.NewInt(42))
if got := state.Copy().GetBalance(addr).Uint64(); got != 42 { if got := state.Copy().GetBalance(addr).Uint64(); got != 42 {
t.Fatalf("1st copy fail, expected 42, got %v", got) t.Fatalf("1st copy fail, expected 42, got %v", got)
@ -394,9 +446,10 @@ func TestStateDBAccessList(t *testing.T) {
return common.HexToHash(a) return common.HexToHash(a)
} }
memDb := rawdb.NewMemoryDatabase() pgdb, cleanup := newPgIpfsEthdb(t)
db := NewDatabase(memDb) t.Cleanup(cleanup)
state, _ := New(common.Hash{}, db, nil) db := NewDatabase(pgdb)
state, _ := New(types.EmptyRootHash, db, nil)
state.accessList = newAccessList() state.accessList = newAccessList()
verifyAddrs := func(astrings ...string) { verifyAddrs := func(astrings ...string) {
@ -560,9 +613,10 @@ func TestStateDBAccessList(t *testing.T) {
} }
func TestStateDBTransientStorage(t *testing.T) { func TestStateDBTransientStorage(t *testing.T) {
memDb := rawdb.NewMemoryDatabase() pgdb, cleanup := newPgIpfsEthdb(t)
db := NewDatabase(memDb) t.Cleanup(cleanup)
state, _ := New(common.Hash{}, db, nil) db := NewDatabase(pgdb)
state, _ := New(types.EmptyRootHash, db, nil)
key := common.Hash{0x01} key := common.Hash{0x01}
value := common.Hash{0x02} value := common.Hash{0x02}

View File

@ -20,8 +20,8 @@ import (
"sync" "sync"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/metrics"
log "github.com/sirupsen/logrus"
) )
var ( var (
@ -37,7 +37,7 @@ var (
type triePrefetcher struct { type triePrefetcher struct {
db Database // Database to fetch trie nodes through db Database // Database to fetch trie nodes through
root common.Hash // Root hash of the account trie for metrics root common.Hash // Root hash of the account trie for metrics
fetches map[string]Trie // Partially or fully fetcher tries fetches map[string]Trie // Partially or fully fetched tries. Only populated for inactive copies.
fetchers map[string]*subfetcher // Subfetchers for each trie fetchers map[string]*subfetcher // Subfetchers for each trie
deliveryMissMeter metrics.Meter deliveryMissMeter metrics.Meter
@ -141,7 +141,7 @@ func (p *triePrefetcher) copy() *triePrefetcher {
} }
// prefetch schedules a batch of trie items to prefetch. // prefetch schedules a batch of trie items to prefetch.
func (p *triePrefetcher) prefetch(owner common.Hash, root common.Hash, keys [][]byte) { func (p *triePrefetcher) prefetch(owner common.Hash, root common.Hash, addr common.Address, keys [][]byte) {
// If the prefetcher is an inactive one, bail out // If the prefetcher is an inactive one, bail out
if p.fetches != nil { if p.fetches != nil {
return return
@ -150,7 +150,7 @@ func (p *triePrefetcher) prefetch(owner common.Hash, root common.Hash, keys [][]
id := p.trieID(owner, root) id := p.trieID(owner, root)
fetcher := p.fetchers[id] fetcher := p.fetchers[id]
if fetcher == nil { if fetcher == nil {
fetcher = newSubfetcher(p.db, p.root, owner, root) fetcher = newSubfetcher(p.db, p.root, owner, root, addr)
p.fetchers[id] = fetcher p.fetchers[id] = fetcher
} }
fetcher.schedule(keys) fetcher.schedule(keys)
@ -197,7 +197,10 @@ func (p *triePrefetcher) used(owner common.Hash, root common.Hash, used [][]byte
// trieID returns an unique trie identifier consists the trie owner and root hash. // trieID returns an unique trie identifier consists the trie owner and root hash.
func (p *triePrefetcher) trieID(owner common.Hash, root common.Hash) string { func (p *triePrefetcher) trieID(owner common.Hash, root common.Hash) string {
return string(append(owner.Bytes(), root.Bytes()...)) trieID := make([]byte, common.HashLength*2)
copy(trieID, owner.Bytes())
copy(trieID[common.HashLength:], root.Bytes())
return string(trieID)
} }
// subfetcher is a trie fetcher goroutine responsible for pulling entries for a // subfetcher is a trie fetcher goroutine responsible for pulling entries for a
@ -209,6 +212,7 @@ type subfetcher struct {
state common.Hash // Root hash of the state to prefetch state common.Hash // Root hash of the state to prefetch
owner common.Hash // Owner of the trie, usually account hash owner common.Hash // Owner of the trie, usually account hash
root common.Hash // Root hash of the trie to prefetch root common.Hash // Root hash of the trie to prefetch
addr common.Address // Address of the account that the trie belongs to
trie Trie // Trie being populated with nodes trie Trie // Trie being populated with nodes
tasks [][]byte // Items queued up for retrieval tasks [][]byte // Items queued up for retrieval
@ -226,12 +230,13 @@ type subfetcher struct {
// newSubfetcher creates a goroutine to prefetch state items belonging to a // newSubfetcher creates a goroutine to prefetch state items belonging to a
// particular root hash. // particular root hash.
func newSubfetcher(db Database, state common.Hash, owner common.Hash, root common.Hash) *subfetcher { func newSubfetcher(db Database, state common.Hash, owner common.Hash, root common.Hash, addr common.Address) *subfetcher {
sf := &subfetcher{ sf := &subfetcher{
db: db, db: db,
state: state, state: state,
owner: owner, owner: owner,
root: root, root: root,
addr: addr,
wake: make(chan struct{}, 1), wake: make(chan struct{}, 1),
stop: make(chan struct{}), stop: make(chan struct{}),
term: make(chan struct{}), term: make(chan struct{}),
@ -300,7 +305,9 @@ func (sf *subfetcher) loop() {
} }
sf.trie = trie sf.trie = trie
} else { } else {
trie, err := sf.db.OpenStorageTrie(sf.state, sf.owner, sf.root) // The trie argument can be nil as verkle doesn't support prefetching
// yet. TODO FIX IT(rjl493456442), otherwise code will panic here.
trie, err := sf.db.OpenStorageTrie(sf.state, sf.owner, sf.root, nil)
if err != nil { if err != nil {
log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err) log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err)
return return
@ -336,7 +343,11 @@ func (sf *subfetcher) loop() {
if _, ok := sf.seen[string(task)]; ok { if _, ok := sf.seen[string(task)]; ok {
sf.dups++ sf.dups++
} else { } else {
sf.trie.TryGet(task) if len(task) == common.AddressLength {
sf.trie.GetAccount(common.BytesToAddress(task))
} else {
sf.trie.GetStorage(sf.addr, task)
}
sf.seen[string(task)] = struct{}{} sf.seen[string(task)] = struct{}{}
} }
} }

View File

@ -22,18 +22,21 @@ import (
"time" "time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types"
"github.com/holiman/uint256"
) )
func filledStateDB() *StateDB { func filledStateDB(t *testing.T) *StateDB {
state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil) db, cleanup := newPgIpfsEthdb(t)
t.Cleanup(cleanup)
state, _ := New(types.EmptyRootHash, NewDatabase(db), nil)
// Create an account and check if the retrieved balance is correct // Create an account and check if the retrieved balance is correct
addr := common.HexToAddress("0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe") addr := common.HexToAddress("0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe")
skey := common.HexToHash("aaa") skey := common.HexToHash("aaa")
sval := common.HexToHash("bbb") sval := common.HexToHash("bbb")
state.SetBalance(addr, big.NewInt(42)) // Change the account trie state.SetBalance(addr, uint256.NewInt(42)) // Change the account trie
state.SetCode(addr, []byte("hello")) // Change an external metadata state.SetCode(addr, []byte("hello")) // Change an external metadata
state.SetState(addr, skey, sval) // Change the storage trie state.SetState(addr, skey, sval) // Change the storage trie
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
@ -44,22 +47,22 @@ func filledStateDB() *StateDB {
} }
func TestCopyAndClose(t *testing.T) { func TestCopyAndClose(t *testing.T) {
db := filledStateDB() db := filledStateDB(t)
prefetcher := newTriePrefetcher(db.db, db.originalRoot, "") prefetcher := newTriePrefetcher(db.db, db.originalRoot, "")
skey := common.HexToHash("aaa") skey := common.HexToHash("aaa")
prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) prefetcher.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) prefetcher.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
a := prefetcher.trie(common.Hash{}, db.originalRoot) a := prefetcher.trie(common.Hash{}, db.originalRoot)
prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) prefetcher.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
b := prefetcher.trie(common.Hash{}, db.originalRoot) b := prefetcher.trie(common.Hash{}, db.originalRoot)
cpy := prefetcher.copy() cpy := prefetcher.copy()
cpy.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) cpy.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
cpy.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) cpy.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
c := cpy.trie(common.Hash{}, db.originalRoot) c := cpy.trie(common.Hash{}, db.originalRoot)
prefetcher.close() prefetcher.close()
cpy2 := cpy.copy() cpy2 := cpy.copy()
cpy2.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) cpy2.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
d := cpy2.trie(common.Hash{}, db.originalRoot) d := cpy2.trie(common.Hash{}, db.originalRoot)
cpy.close() cpy.close()
cpy2.close() cpy2.close()
@ -69,10 +72,10 @@ func TestCopyAndClose(t *testing.T) {
} }
func TestUseAfterClose(t *testing.T) { func TestUseAfterClose(t *testing.T) {
db := filledStateDB() db := filledStateDB(t)
prefetcher := newTriePrefetcher(db.db, db.originalRoot, "") prefetcher := newTriePrefetcher(db.db, db.originalRoot, "")
skey := common.HexToHash("aaa") skey := common.HexToHash("aaa")
prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) prefetcher.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
a := prefetcher.trie(common.Hash{}, db.originalRoot) a := prefetcher.trie(common.Hash{}, db.originalRoot)
prefetcher.close() prefetcher.close()
b := prefetcher.trie(common.Hash{}, db.originalRoot) b := prefetcher.trie(common.Hash{}, db.originalRoot)
@ -85,10 +88,10 @@ func TestUseAfterClose(t *testing.T) {
} }
func TestCopyClose(t *testing.T) { func TestCopyClose(t *testing.T) {
db := filledStateDB() db := filledStateDB(t)
prefetcher := newTriePrefetcher(db.db, db.originalRoot, "") prefetcher := newTriePrefetcher(db.db, db.originalRoot, "")
skey := common.HexToHash("aaa") skey := common.HexToHash("aaa")
prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) prefetcher.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
cpy := prefetcher.copy() cpy := prefetcher.copy()
a := prefetcher.trie(common.Hash{}, db.originalRoot) a := prefetcher.trie(common.Hash{}, db.originalRoot)
b := cpy.trie(common.Hash{}, db.originalRoot) b := cpy.trie(common.Hash{}, db.originalRoot)

View File

@ -20,26 +20,24 @@ import (
"fmt" "fmt"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
)
// leaf represents a trie leaf node "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
type leaf struct { )
blob []byte // raw blob of leaf
parent common.Hash // the hash of parent node
}
// committer is the tool used for the trie Commit operation. The committer will // committer is the tool used for the trie Commit operation. The committer will
// capture all dirty nodes during the commit process and keep them cached in // capture all dirty nodes during the commit process and keep them cached in
// insertion order. // insertion order.
type committer struct { type committer struct {
nodes *NodeSet nodes *trienode.NodeSet
tracer *tracer
collectLeaf bool collectLeaf bool
} }
// newCommitter creates a new committer or picks one from the pool. // newCommitter creates a new committer or picks one from the pool.
func newCommitter(nodeset *NodeSet, collectLeaf bool) *committer { func newCommitter(nodeset *trienode.NodeSet, tracer *tracer, collectLeaf bool) *committer {
return &committer{ return &committer{
nodes: nodeset, nodes: nodeset,
tracer: tracer,
collectLeaf: collectLeaf, collectLeaf: collectLeaf,
} }
} }
@ -134,24 +132,15 @@ func (c *committer) store(path []byte, n node) node {
// The node is embedded in its parent, in other words, this node // The node is embedded in its parent, in other words, this node
// will not be stored in the database independently, mark it as // will not be stored in the database independently, mark it as
// deleted only if the node was existent in database before. // deleted only if the node was existent in database before.
if _, ok := c.nodes.accessList[string(path)]; ok { _, ok := c.tracer.accessList[string(path)]
c.nodes.markDeleted(path) if ok {
c.nodes.AddNode(path, trienode.NewDeleted())
} }
return n return n
} }
// We have the hash already, estimate the RLP encoding-size of the node.
// The size is used for mem tracking, does not need to be exact
var (
size = estimateSize(n)
nhash = common.BytesToHash(hash)
mnode = &memoryNode{
hash: nhash,
node: simplifyNode(n),
size: uint16(size),
}
)
// Collect the dirty node to nodeset for return. // Collect the dirty node to nodeset for return.
c.nodes.markUpdated(path, mnode) nhash := common.BytesToHash(hash)
c.nodes.AddNode(path, trienode.New(nhash, nodeToBytes(n)))
// Collect the corresponding leaf node if it's required. We don't check // Collect the corresponding leaf node if it's required. We don't check
// full node since it's impossible to store value in fullNode. The key // full node since it's impossible to store value in fullNode. The key
@ -159,38 +148,36 @@ func (c *committer) store(path []byte, n node) node {
if c.collectLeaf { if c.collectLeaf {
if sn, ok := n.(*shortNode); ok { if sn, ok := n.(*shortNode); ok {
if val, ok := sn.Val.(valueNode); ok { if val, ok := sn.Val.(valueNode); ok {
c.nodes.addLeaf(&leaf{blob: val, parent: nhash}) c.nodes.AddLeaf(nhash, val)
} }
} }
} }
return hash return hash
} }
// estimateSize estimates the size of an rlp-encoded node, without actually // MerkleResolver the children resolver in merkle-patricia-tree.
// rlp-encoding it (zero allocs). This method has been experimentally tried, and with a trie type MerkleResolver struct{}
// with 1000 leaves, the only errors above 1% are on small shortnodes, where this
// method overestimates by 2 or 3 bytes (e.g. 37 instead of 35) // ForEach implements childResolver, decodes the provided node and
func estimateSize(n node) int { // traverses the children inside.
func (resolver MerkleResolver) ForEach(node []byte, onChild func(common.Hash)) {
forGatherChildren(mustDecodeNodeUnsafe(nil, node), onChild)
}
// forGatherChildren traverses the node hierarchy and invokes the callback
// for all the hashnode children.
func forGatherChildren(n node, onChild func(hash common.Hash)) {
switch n := n.(type) { switch n := n.(type) {
case *shortNode: case *shortNode:
// A short node contains a compacted key, and a value. forGatherChildren(n.Val, onChild)
return 3 + len(n.Key) + estimateSize(n.Val)
case *fullNode: case *fullNode:
// A full node contains up to 16 hashes (some nils), and a key
s := 3
for i := 0; i < 16; i++ { for i := 0; i < 16; i++ {
if child := n.Children[i]; child != nil { forGatherChildren(n.Children[i], onChild)
s += estimateSize(child)
} else {
s++
} }
}
return s
case valueNode:
return 1 + len(n)
case hashNode: case hashNode:
return 1 + len(n) onChild(common.BytesToHash(n))
case valueNode, nil:
default: default:
panic(fmt.Sprintf("node type %T", n)) panic(fmt.Sprintf("unknown node type: %T", n))
} }
} }

View File

@ -1,440 +0,0 @@
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"errors"
"runtime"
"sync"
"time"
"github.com/VictoriaMetrics/fastcache"
"github.com/cerc-io/ipld-eth-statedb/internal"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
log "github.com/sirupsen/logrus"
)
// Database is an intermediate write layer between the trie data structures and
// the disk database. The aim is to accumulate trie writes in-memory and only
// periodically flush a couple tries to disk, garbage collecting the remainder.
//
// Note, the trie Database is **not** thread safe in its mutations, but it **is**
// thread safe in providing individual, independent node access. The rationale
// behind this split design is to provide read access to RPC handlers and sync
// servers even while the trie is executing expensive garbage collection.
type Database struct {
diskdb ethdb.Database // Persistent storage for matured trie nodes
cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs
dirties map[common.Hash]*cachedNode // Data and references relationships of dirty trie nodes
oldest common.Hash // Oldest tracked node, flush-list head
newest common.Hash // Newest tracked node, flush-list tail
gctime time.Duration // Time spent on garbage collection since last commit
gcnodes uint64 // Nodes garbage collected since last commit
gcsize common.StorageSize // Data storage garbage collected since last commit
flushtime time.Duration // Time spent on data flushing since last commit
flushnodes uint64 // Nodes flushed since last commit
flushsize common.StorageSize // Data storage flushed since last commit
dirtiesSize common.StorageSize // Storage size of the dirty node cache (exc. metadata)
childrenSize common.StorageSize // Storage size of the external children tracking
preimages *preimageStore // The store for caching preimages
lock sync.RWMutex
}
// Config defines all necessary options for database.
// (re-export)
type Config = trie.Config
// NewDatabase creates a new trie database to store ephemeral trie content before
// its written out to disk or garbage collected. No read cache is created, so all
// data retrievals will hit the underlying disk database.
func NewDatabase(diskdb ethdb.Database) *Database {
return NewDatabaseWithConfig(diskdb, nil)
}
// NewDatabaseWithConfig creates a new trie database to store ephemeral trie content
// before its written out to disk or garbage collected. It also acts as a read cache
// for nodes loaded from disk.
func NewDatabaseWithConfig(diskdb ethdb.Database, config *Config) *Database {
var cleans *fastcache.Cache
if config != nil && config.Cache > 0 {
if config.Journal == "" {
cleans = fastcache.New(config.Cache * 1024 * 1024)
} else {
cleans = fastcache.LoadFromFileOrNew(config.Journal, config.Cache*1024*1024)
}
}
var preimage *preimageStore
if config != nil && config.Preimages {
preimage = newPreimageStore(diskdb)
}
db := &Database{
diskdb: diskdb,
cleans: cleans,
dirties: map[common.Hash]*cachedNode{{}: {
children: make(map[common.Hash]uint16),
}},
preimages: preimage,
}
return db
}
// insert inserts a simplified trie node into the memory database.
// All nodes inserted by this function will be reference tracked
// and in theory should only used for **trie nodes** insertion.
func (db *Database) insert(hash common.Hash, size int, node node) {
// If the node's already cached, skip
if _, ok := db.dirties[hash]; ok {
return
}
memcacheDirtyWriteMeter.Mark(int64(size))
// Create the cached entry for this node
entry := &cachedNode{
node: node,
size: uint16(size),
flushPrev: db.newest,
}
entry.forChilds(func(child common.Hash) {
if c := db.dirties[child]; c != nil {
c.parents++
}
})
db.dirties[hash] = entry
// Update the flush-list endpoints
if db.oldest == (common.Hash{}) {
db.oldest, db.newest = hash, hash
} else {
db.dirties[db.newest].flushNext, db.newest = hash, hash
}
db.dirtiesSize += common.StorageSize(common.HashLength + entry.size)
}
// Node retrieves an encoded cached trie node from memory. If it cannot be found
// cached, the method queries the persistent database for the content.
func (db *Database) Node(hash common.Hash, codec uint64) ([]byte, error) {
// It doesn't make sense to retrieve the metaroot
if hash == (common.Hash{}) {
return nil, errors.New("not found")
}
// Retrieve the node from the clean cache if available
if db.cleans != nil {
if enc := db.cleans.Get(nil, hash[:]); enc != nil {
memcacheCleanHitMeter.Mark(1)
memcacheCleanReadMeter.Mark(int64(len(enc)))
return enc, nil
}
}
// Retrieve the node from the dirty cache if available
db.lock.RLock()
dirty := db.dirties[hash]
db.lock.RUnlock()
if dirty != nil {
memcacheDirtyHitMeter.Mark(1)
memcacheDirtyReadMeter.Mark(int64(dirty.size))
return dirty.rlp(), nil
}
memcacheDirtyMissMeter.Mark(1)
// Content unavailable in memory, attempt to retrieve from disk
cid, err := internal.Keccak256ToCid(codec, hash[:])
if err != nil {
return nil, err
}
enc, err := db.diskdb.Get(cid.Bytes())
if err != nil {
return nil, err
}
if len(enc) != 0 {
if db.cleans != nil {
db.cleans.Set(hash[:], enc)
memcacheCleanMissMeter.Mark(1)
memcacheCleanWriteMeter.Mark(int64(len(enc)))
}
return enc, nil
}
return nil, errors.New("not found")
}
// Nodes retrieves the hashes of all the nodes cached within the memory database.
// This method is extremely expensive and should only be used to validate internal
// states in test code.
func (db *Database) Nodes() []common.Hash {
db.lock.RLock()
defer db.lock.RUnlock()
var hashes = make([]common.Hash, 0, len(db.dirties))
for hash := range db.dirties {
if hash != (common.Hash{}) { // Special case for "root" references/nodes
hashes = append(hashes, hash)
}
}
return hashes
}
// Reference adds a new reference from a parent node to a child node.
// This function is used to add reference between internal trie node
// and external node(e.g. storage trie root), all internal trie nodes
// are referenced together by database itself.
func (db *Database) Reference(child common.Hash, parent common.Hash) {
db.lock.Lock()
defer db.lock.Unlock()
db.reference(child, parent)
}
// reference is the private locked version of Reference.
func (db *Database) reference(child common.Hash, parent common.Hash) {
// If the node does not exist, it's a node pulled from disk, skip
node, ok := db.dirties[child]
if !ok {
return
}
// If the reference already exists, only duplicate for roots
if db.dirties[parent].children == nil {
db.dirties[parent].children = make(map[common.Hash]uint16)
db.childrenSize += cachedNodeChildrenSize
} else if _, ok = db.dirties[parent].children[child]; ok && parent != (common.Hash{}) {
return
}
node.parents++
db.dirties[parent].children[child]++
if db.dirties[parent].children[child] == 1 {
db.childrenSize += common.HashLength + 2 // uint16 counter
}
}
// Dereference removes an existing reference from a root node.
func (db *Database) Dereference(root common.Hash) {
// Sanity check to ensure that the meta-root is not removed
if root == (common.Hash{}) {
log.Error("Attempted to dereference the trie cache meta root")
return
}
db.lock.Lock()
defer db.lock.Unlock()
nodes, storage, start := len(db.dirties), db.dirtiesSize, time.Now()
db.dereference(root, common.Hash{})
db.gcnodes += uint64(nodes - len(db.dirties))
db.gcsize += storage - db.dirtiesSize
db.gctime += time.Since(start)
memcacheGCTimeTimer.Update(time.Since(start))
memcacheGCSizeMeter.Mark(int64(storage - db.dirtiesSize))
memcacheGCNodesMeter.Mark(int64(nodes - len(db.dirties)))
log.Debug("Dereferenced trie from memory database", "nodes", nodes-len(db.dirties), "size", storage-db.dirtiesSize, "time", time.Since(start),
"gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.dirties), "livesize", db.dirtiesSize)
}
// dereference is the private locked version of Dereference.
func (db *Database) dereference(child common.Hash, parent common.Hash) {
// Dereference the parent-child
node := db.dirties[parent]
if node.children != nil && node.children[child] > 0 {
node.children[child]--
if node.children[child] == 0 {
delete(node.children, child)
db.childrenSize -= (common.HashLength + 2) // uint16 counter
}
}
// If the child does not exist, it's a previously committed node.
node, ok := db.dirties[child]
if !ok {
return
}
// If there are no more references to the child, delete it and cascade
if node.parents > 0 {
// This is a special cornercase where a node loaded from disk (i.e. not in the
// memcache any more) gets reinjected as a new node (short node split into full,
// then reverted into short), causing a cached node to have no parents. That is
// no problem in itself, but don't make maxint parents out of it.
node.parents--
}
if node.parents == 0 {
// Remove the node from the flush-list
switch child {
case db.oldest:
db.oldest = node.flushNext
db.dirties[node.flushNext].flushPrev = common.Hash{}
case db.newest:
db.newest = node.flushPrev
db.dirties[node.flushPrev].flushNext = common.Hash{}
default:
db.dirties[node.flushPrev].flushNext = node.flushNext
db.dirties[node.flushNext].flushPrev = node.flushPrev
}
// Dereference all children and delete the node
node.forChilds(func(hash common.Hash) {
db.dereference(hash, child)
})
delete(db.dirties, child)
db.dirtiesSize -= common.StorageSize(common.HashLength + int(node.size))
if node.children != nil {
db.childrenSize -= cachedNodeChildrenSize
}
}
}
// Update inserts the dirty nodes in provided nodeset into database and
// link the account trie with multiple storage tries if necessary.
func (db *Database) Update(nodes *MergedNodeSet) error {
db.lock.Lock()
defer db.lock.Unlock()
// Insert dirty nodes into the database. In the same tree, it must be
// ensured that children are inserted first, then parent so that children
// can be linked with their parent correctly.
//
// Note, the storage tries must be flushed before the account trie to
// retain the invariant that children go into the dirty cache first.
var order []common.Hash
for owner := range nodes.sets {
if owner == (common.Hash{}) {
continue
}
order = append(order, owner)
}
if _, ok := nodes.sets[common.Hash{}]; ok {
order = append(order, common.Hash{})
}
for _, owner := range order {
subset := nodes.sets[owner]
subset.forEachWithOrder(func(path string, n *memoryNode) {
if n.isDeleted() {
return // ignore deletion
}
db.insert(n.hash, int(n.size), n.node)
})
}
// Link up the account trie and storage trie if the node points
// to an account trie leaf.
if set, present := nodes.sets[common.Hash{}]; present {
for _, n := range set.leaves {
var account types.StateAccount
if err := rlp.DecodeBytes(n.blob, &account); err != nil {
return err
}
if account.Root != types.EmptyRootHash {
db.reference(account.Root, n.parent)
}
}
}
return nil
}
// Size returns the current storage size of the memory cache in front of the
// persistent database layer.
func (db *Database) Size() (common.StorageSize, common.StorageSize) {
db.lock.RLock()
defer db.lock.RUnlock()
// db.dirtiesSize only contains the useful data in the cache, but when reporting
// the total memory consumption, the maintenance metadata is also needed to be
// counted.
var metadataSize = common.StorageSize((len(db.dirties) - 1) * cachedNodeSize)
var metarootRefs = common.StorageSize(len(db.dirties[common.Hash{}].children) * (common.HashLength + 2))
var preimageSize common.StorageSize
if db.preimages != nil {
preimageSize = db.preimages.size()
}
return db.dirtiesSize + db.childrenSize + metadataSize - metarootRefs, preimageSize
}
// GetReader retrieves a node reader belonging to the given state root.
func (db *Database) GetReader(root common.Hash, codec uint64) Reader {
return &hashReader{db: db, codec: codec}
}
// hashReader is reader of hashDatabase which implements the Reader interface.
type hashReader struct {
db *Database
codec uint64
}
// Node retrieves the trie node with the given node hash.
func (reader *hashReader) Node(owner common.Hash, path []byte, hash common.Hash) (node, error) {
blob, err := reader.NodeBlob(owner, path, hash)
if err != nil {
return nil, err
}
return decodeNodeUnsafe(hash[:], blob)
}
// NodeBlob retrieves the RLP-encoded trie node blob with the given node hash.
func (reader *hashReader) NodeBlob(_ common.Hash, _ []byte, hash common.Hash) ([]byte, error) {
return reader.db.Node(hash, reader.codec)
}
// saveCache saves clean state cache to given directory path
// using specified CPU cores.
func (db *Database) saveCache(dir string, threads int) error {
if db.cleans == nil {
return nil
}
log.Info("Writing clean trie cache to disk", "path", dir, "threads", threads)
start := time.Now()
err := db.cleans.SaveToFileConcurrent(dir, threads)
if err != nil {
log.Error("Failed to persist clean trie cache", "error", err)
return err
}
log.Info("Persisted the clean trie cache", "path", dir, "elapsed", common.PrettyDuration(time.Since(start)))
return nil
}
// SaveCache atomically saves fast cache data to the given dir using all
// available CPU cores.
func (db *Database) SaveCache(dir string) error {
return db.saveCache(dir, runtime.GOMAXPROCS(0))
}
// SaveCachePeriodically atomically saves fast cache data to the given dir with
// the specified interval. All dump operation will only use a single CPU core.
func (db *Database) SaveCachePeriodically(dir string, interval time.Duration, stopCh <-chan struct{}) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
db.saveCache(dir, 1)
case <-stopCh:
return
}
}
}
// Scheme returns the node scheme used in the database.
func (db *Database) Scheme() string {
return rawdb.HashScheme
}

View File

@ -1,27 +0,0 @@
package trie
import "github.com/ethereum/go-ethereum/metrics"
var (
memcacheCleanHitMeter = metrics.NewRegisteredMeter("trie/memcache/clean/hit", nil)
memcacheCleanMissMeter = metrics.NewRegisteredMeter("trie/memcache/clean/miss", nil)
memcacheCleanReadMeter = metrics.NewRegisteredMeter("trie/memcache/clean/read", nil)
memcacheCleanWriteMeter = metrics.NewRegisteredMeter("trie/memcache/clean/write", nil)
memcacheDirtyHitMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/hit", nil)
memcacheDirtyMissMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/miss", nil)
memcacheDirtyReadMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/read", nil)
memcacheDirtyWriteMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/write", nil)
memcacheFlushTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/flush/time", nil)
memcacheFlushNodesMeter = metrics.NewRegisteredMeter("trie/memcache/flush/nodes", nil)
memcacheFlushSizeMeter = metrics.NewRegisteredMeter("trie/memcache/flush/size", nil)
memcacheGCTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/gc/time", nil)
memcacheGCNodesMeter = metrics.NewRegisteredMeter("trie/memcache/gc/nodes", nil)
memcacheGCSizeMeter = metrics.NewRegisteredMeter("trie/memcache/gc/size", nil)
memcacheCommitTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/commit/time", nil)
memcacheCommitNodesMeter = metrics.NewRegisteredMeter("trie/memcache/commit/nodes", nil)
memcacheCommitSizeMeter = metrics.NewRegisteredMeter("trie/memcache/commit/size", nil)
)

View File

@ -1,183 +0,0 @@
package trie
import (
"fmt"
"io"
"reflect"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp"
)
// rawNode is a simple binary blob used to differentiate between collapsed trie
// nodes and already encoded RLP binary blobs (while at the same time store them
// in the same cache fields).
type rawNode []byte
func (n rawNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") }
func (n rawNode) EncodeRLP(w io.Writer) error {
_, err := w.Write(n)
return err
}
// rawFullNode represents only the useful data content of a full node, with the
// caches and flags stripped out to minimize its data storage. This type honors
// the same RLP encoding as the original parent.
type rawFullNode [17]node
func (n rawFullNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
func (n rawFullNode) fstring(ind string) string { panic("this should never end up in a live trie") }
func (n rawFullNode) EncodeRLP(w io.Writer) error {
eb := rlp.NewEncoderBuffer(w)
n.encode(eb)
return eb.Flush()
}
// rawShortNode represents only the useful data content of a short node, with the
// caches and flags stripped out to minimize its data storage. This type honors
// the same RLP encoding as the original parent.
type rawShortNode struct {
Key []byte
Val node
}
func (n rawShortNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
func (n rawShortNode) fstring(ind string) string { panic("this should never end up in a live trie") }
// cachedNode is all the information we know about a single cached trie node
// in the memory database write layer.
type cachedNode struct {
node node // Cached collapsed trie node, or raw rlp data
size uint16 // Byte size of the useful cached data
parents uint32 // Number of live nodes referencing this one
children map[common.Hash]uint16 // External children referenced by this node
flushPrev common.Hash // Previous node in the flush-list
flushNext common.Hash // Next node in the flush-list
}
// cachedNodeSize is the raw size of a cachedNode data structure without any
// node data included. It's an approximate size, but should be a lot better
// than not counting them.
var cachedNodeSize = int(reflect.TypeOf(cachedNode{}).Size())
// cachedNodeChildrenSize is the raw size of an initialized but empty external
// reference map.
const cachedNodeChildrenSize = 48
// rlp returns the raw rlp encoded blob of the cached trie node, either directly
// from the cache, or by regenerating it from the collapsed node.
func (n *cachedNode) rlp() []byte {
if node, ok := n.node.(rawNode); ok {
return node
}
return nodeToBytes(n.node)
}
// obj returns the decoded and expanded trie node, either directly from the cache,
// or by regenerating it from the rlp encoded blob.
func (n *cachedNode) obj(hash common.Hash) node {
if node, ok := n.node.(rawNode); ok {
// The raw-blob format nodes are loaded either from the
// clean cache or the database, they are all in their own
// copy and safe to use unsafe decoder.
return mustDecodeNodeUnsafe(hash[:], node)
}
return expandNode(hash[:], n.node)
}
// forChilds invokes the callback for all the tracked children of this node,
// both the implicit ones from inside the node as well as the explicit ones
// from outside the node.
func (n *cachedNode) forChilds(onChild func(hash common.Hash)) {
for child := range n.children {
onChild(child)
}
if _, ok := n.node.(rawNode); !ok {
forGatherChildren(n.node, onChild)
}
}
// forGatherChildren traverses the node hierarchy of a collapsed storage node and
// invokes the callback for all the hashnode children.
func forGatherChildren(n node, onChild func(hash common.Hash)) {
switch n := n.(type) {
case *rawShortNode:
forGatherChildren(n.Val, onChild)
case rawFullNode:
for i := 0; i < 16; i++ {
forGatherChildren(n[i], onChild)
}
case hashNode:
onChild(common.BytesToHash(n))
case valueNode, nil, rawNode:
default:
panic(fmt.Sprintf("unknown node type: %T", n))
}
}
// simplifyNode traverses the hierarchy of an expanded memory node and discards
// all the internal caches, returning a node that only contains the raw data.
func simplifyNode(n node) node {
switch n := n.(type) {
case *shortNode:
// Short nodes discard the flags and cascade
return &rawShortNode{Key: n.Key, Val: simplifyNode(n.Val)}
case *fullNode:
// Full nodes discard the flags and cascade
node := rawFullNode(n.Children)
for i := 0; i < len(node); i++ {
if node[i] != nil {
node[i] = simplifyNode(node[i])
}
}
return node
case valueNode, hashNode, rawNode:
return n
default:
panic(fmt.Sprintf("unknown node type: %T", n))
}
}
// expandNode traverses the node hierarchy of a collapsed storage node and converts
// all fields and keys into expanded memory form.
func expandNode(hash hashNode, n node) node {
switch n := n.(type) {
case *rawShortNode:
// Short nodes need key and child expansion
return &shortNode{
Key: compactToHex(n.Key),
Val: expandNode(nil, n.Val),
flags: nodeFlag{
hash: hash,
},
}
case rawFullNode:
// Full nodes need child expansion
node := &fullNode{
flags: nodeFlag{
hash: hash,
},
}
for i := 0; i < len(node.Children); i++ {
if n[i] != nil {
node.Children[i] = expandNode(nil, n[i])
}
}
return node
case valueNode, hashNode:
return n
default:
panic(fmt.Sprintf("unknown node type: %T", n))
}
}

View File

@ -17,17 +17,137 @@
package trie package trie
import ( import (
"testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/triedb/database"
) )
// Tests that the trie database returns a missing trie node error if attempting // testReader implements database.Reader interface, providing function to
// to retrieve the meta root. // access trie nodes.
func TestDatabaseMetarootFetch(t *testing.T) { type testReader struct {
db := NewDatabase(rawdb.NewMemoryDatabase()) db ethdb.Database
if _, err := db.Node(common.Hash{}, StateTrieCodec); err == nil { scheme string
t.Fatalf("metaroot retrieval succeeded") nodes []*trienode.MergedNodeSet // sorted from new to old
}
// Node implements database.Reader interface, retrieving trie node with
// all available cached layers.
func (r *testReader) Node(owner common.Hash, path []byte, hash common.Hash) ([]byte, error) {
// Check the node presence with the cached layer, from latest to oldest.
for _, nodes := range r.nodes {
if _, ok := nodes.Sets[owner]; !ok {
continue
}
n, ok := nodes.Sets[owner].Nodes[string(path)]
if !ok {
continue
}
if n.IsDeleted() || n.Hash != hash {
return nil, &MissingNodeError{Owner: owner, Path: path, NodeHash: hash}
}
return n.Blob, nil
}
// Check the node presence in database.
return rawdb.ReadTrieNode(r.db, owner, path, hash, r.scheme), nil
}
// testDb implements database.Database interface, using for testing purpose.
type testDb struct {
disk ethdb.Database
root common.Hash
scheme string
nodes map[common.Hash]*trienode.MergedNodeSet
parents map[common.Hash]common.Hash
}
func newTestDatabase(diskdb ethdb.Database, scheme string) *testDb {
return &testDb{
disk: diskdb,
root: types.EmptyRootHash,
scheme: scheme,
nodes: make(map[common.Hash]*trienode.MergedNodeSet),
parents: make(map[common.Hash]common.Hash),
} }
} }
func (db *testDb) Reader(stateRoot common.Hash) (database.Reader, error) {
nodes, _ := db.dirties(stateRoot, true)
return &testReader{db: db.disk, scheme: db.scheme, nodes: nodes}, nil
}
func (db *testDb) Preimage(hash common.Hash) []byte {
return rawdb.ReadPreimage(db.disk, hash)
}
func (db *testDb) InsertPreimage(preimages map[common.Hash][]byte) {
rawdb.WritePreimages(db.disk, preimages)
}
func (db *testDb) Scheme() string { return db.scheme }
func (db *testDb) Update(root common.Hash, parent common.Hash, nodes *trienode.MergedNodeSet) error {
if root == parent {
return nil
}
if _, ok := db.nodes[root]; ok {
return nil
}
db.parents[root] = parent
db.nodes[root] = nodes
return nil
}
func (db *testDb) dirties(root common.Hash, topToBottom bool) ([]*trienode.MergedNodeSet, []common.Hash) {
var (
pending []*trienode.MergedNodeSet
roots []common.Hash
)
for {
if root == db.root {
break
}
nodes, ok := db.nodes[root]
if !ok {
break
}
if topToBottom {
pending = append(pending, nodes)
roots = append(roots, root)
} else {
pending = append([]*trienode.MergedNodeSet{nodes}, pending...)
roots = append([]common.Hash{root}, roots...)
}
root = db.parents[root]
}
return pending, roots
}
func (db *testDb) Commit(root common.Hash) error {
if root == db.root {
return nil
}
pending, roots := db.dirties(root, false)
for i, nodes := range pending {
for owner, set := range nodes.Sets {
if owner == (common.Hash{}) {
continue
}
set.ForEachWithOrder(func(path string, n *trienode.Node) {
rawdb.WriteTrieNode(db.disk, owner, []byte(path), n.Hash, n.Blob, db.scheme)
})
}
nodes.Sets[common.Hash{}].ForEachWithOrder(func(path string, n *trienode.Node) {
rawdb.WriteTrieNode(db.disk, common.Hash{}, []byte(path), n.Hash, n.Blob, db.scheme)
})
db.root = roots[i]
}
for _, root := range roots {
delete(db.nodes, root)
delete(db.parents, root)
}
return nil
}

View File

@ -34,11 +34,6 @@ package trie
// in the case of an odd number. All remaining nibbles (now an even number) fit properly // in the case of an odd number. All remaining nibbles (now an even number) fit properly
// into the remaining bytes. Compact encoding is used for nodes stored on disk. // into the remaining bytes. Compact encoding is used for nodes stored on disk.
// HexToCompact converts a hex path to the compact encoded format
func HexToCompact(hex []byte) []byte {
return hexToCompact(hex)
}
func hexToCompact(hex []byte) []byte { func hexToCompact(hex []byte) []byte {
terminator := byte(0) terminator := byte(0)
if hasTerm(hex) { if hasTerm(hex) {
@ -56,9 +51,8 @@ func hexToCompact(hex []byte) []byte {
return buf return buf
} }
// hexToCompactInPlace places the compact key in input buffer, returning the length // hexToCompactInPlace places the compact key in input buffer, returning the compacted key.
// needed for the representation func hexToCompactInPlace(hex []byte) []byte {
func hexToCompactInPlace(hex []byte) int {
var ( var (
hexLen = len(hex) // length of the hex input hexLen = len(hex) // length of the hex input
firstByte = byte(0) firstByte = byte(0)
@ -82,12 +76,7 @@ func hexToCompactInPlace(hex []byte) int {
hex[bi] = hex[ni]<<4 | hex[ni+1] hex[bi] = hex[ni]<<4 | hex[ni+1]
} }
hex[0] = firstByte hex[0] = firstByte
return binLen return hex[:binLen]
}
// CompactToHex converts a compact encoded path to hex format
func CompactToHex(compact []byte) []byte {
return compactToHex(compact)
} }
func compactToHex(compact []byte) []byte { func compactToHex(compact []byte) []byte {
@ -115,9 +104,9 @@ func keybytesToHex(str []byte) []byte {
return nibbles return nibbles
} }
// hexToKeyBytes turns hex nibbles into key bytes. // hexToKeybytes turns hex nibbles into key bytes.
// This can only be used for keys of even length. // This can only be used for keys of even length.
func hexToKeyBytes(hex []byte) []byte { func hexToKeybytes(hex []byte) []byte {
if hasTerm(hex) { if hasTerm(hex) {
hex = hex[:len(hex)-1] hex = hex[:len(hex)-1]
} }

View File

@ -72,8 +72,8 @@ func TestHexKeybytes(t *testing.T) {
if h := keybytesToHex(test.key); !bytes.Equal(h, test.hexOut) { if h := keybytesToHex(test.key); !bytes.Equal(h, test.hexOut) {
t.Errorf("keybytesToHex(%x) -> %x, want %x", test.key, h, test.hexOut) t.Errorf("keybytesToHex(%x) -> %x, want %x", test.key, h, test.hexOut)
} }
if k := hexToKeyBytes(test.hexIn); !bytes.Equal(k, test.key) { if k := hexToKeybytes(test.hexIn); !bytes.Equal(k, test.key) {
t.Errorf("hexToKeyBytes(%x) -> %x, want %x", test.hexIn, k, test.key) t.Errorf("hexToKeybytes(%x) -> %x, want %x", test.hexIn, k, test.key)
} }
} }
} }
@ -86,8 +86,7 @@ func TestHexToCompactInPlace(t *testing.T) {
} { } {
hexBytes, _ := hex.DecodeString(key) hexBytes, _ := hex.DecodeString(key)
exp := hexToCompact(hexBytes) exp := hexToCompact(hexBytes)
sz := hexToCompactInPlace(hexBytes) got := hexToCompactInPlace(hexBytes)
got := hexBytes[:sz]
if !bytes.Equal(exp, got) { if !bytes.Equal(exp, got) {
t.Fatalf("test %d: encoding err\ninp %v\ngot %x\nexp %x\n", i, key, got, exp) t.Fatalf("test %d: encoding err\ninp %v\ngot %x\nexp %x\n", i, key, got, exp)
} }
@ -102,8 +101,7 @@ func TestHexToCompactInPlaceRandom(t *testing.T) {
hexBytes := keybytesToHex(key) hexBytes := keybytesToHex(key)
hexOrig := []byte(string(hexBytes)) hexOrig := []byte(string(hexBytes))
exp := hexToCompact(hexBytes) exp := hexToCompact(hexBytes)
sz := hexToCompactInPlace(hexBytes) got := hexToCompactInPlace(hexBytes)
got := hexBytes[:sz]
if !bytes.Equal(exp, got) { if !bytes.Equal(exp, got) {
t.Fatalf("encoding err \ncpt %x\nhex %x\ngot %x\nexp %x\n", t.Fatalf("encoding err \ncpt %x\nhex %x\ngot %x\nexp %x\n",
@ -119,6 +117,13 @@ func BenchmarkHexToCompact(b *testing.B) {
} }
} }
func BenchmarkHexToCompactInPlace(b *testing.B) {
testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
for i := 0; i < b.N; i++ {
hexToCompactInPlace(testBytes)
}
}
func BenchmarkCompactToHex(b *testing.B) { func BenchmarkCompactToHex(b *testing.B) {
testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/} testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -136,6 +141,6 @@ func BenchmarkKeybytesToHex(b *testing.B) {
func BenchmarkHexToKeybytes(b *testing.B) { func BenchmarkHexToKeybytes(b *testing.B) {
testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
hexToKeyBytes(testBytes) hexToKeybytes(testBytes)
} }
} }

View File

@ -17,12 +17,18 @@
package trie package trie
import ( import (
"errors"
"fmt" "fmt"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
) )
// MissingNodeError is returned by the trie functions (TryGet, TryUpdate, TryDelete) // ErrCommitted is returned when a already committed trie is requested for usage.
// The potential usages can be `Get`, `Update`, `Delete`, `NodeIterator`, `Prove`
// and so on.
var ErrCommitted = errors.New("trie is already committed")
// MissingNodeError is returned by the trie functions (Get, Update, Delete)
// in the case where a trie node is not present in the local database. It contains // in the case where a trie node is not present in the local database. It contains
// information necessary for retrieving the missing node. // information necessary for retrieving the missing node.
type MissingNodeError struct { type MissingNodeError struct {

View File

@ -84,14 +84,13 @@ func (h *hasher) hash(n node, force bool) (hashed node, cached node) {
} }
return hashed, cached return hashed, cached
default: default:
// Value and hash nodes don't have children so they're left as were // Value and hash nodes don't have children, so they're left as were
return n, n return n, n
} }
} }
// hashShortNodeChildren collapses the short node. The returned collapsed node // hashShortNodeChildren collapses the short node. The returned collapsed node
// holds a live reference to the Key, and must not be modified. // holds a live reference to the Key, and must not be modified.
// The cached
func (h *hasher) hashShortNodeChildren(n *shortNode) (collapsed, cached *shortNode) { func (h *hasher) hashShortNodeChildren(n *shortNode) (collapsed, cached *shortNode) {
// Hash the short node's child, caching the newly hashed subtree // Hash the short node's child, caching the newly hashed subtree
collapsed, cached = n.copy(), n.copy() collapsed, cached = n.copy(), n.copy()
@ -153,7 +152,7 @@ func (h *hasher) shortnodeToHash(n *shortNode, force bool) node {
return h.hashData(enc) return h.hashData(enc)
} }
// shortnodeToHash is used to creates a hashNode from a set of hashNodes, (which // fullnodeToHash is used to create a hashNode from a fullNode, (which
// may contain nil values) // may contain nil values)
func (h *hasher) fullnodeToHash(n *fullNode, force bool) node { func (h *hasher) fullnodeToHash(n *fullNode, force bool) node {
n.encode(h.encbuf) n.encode(h.encbuf)
@ -203,7 +202,7 @@ func (h *hasher) proofHash(original node) (collapsed, hashed node) {
fn, _ := h.hashFullNodeChildren(n) fn, _ := h.hashFullNodeChildren(n)
return fn, h.fullnodeToHash(fn, false) return fn, h.fullnodeToHash(fn, false)
default: default:
// Value and hash nodes don't have children so they're left as were // Value and hash nodes don't have children, so they're left as were
return n, n return n, n
} }
} }

View File

@ -26,9 +26,6 @@ import (
gethtrie "github.com/ethereum/go-ethereum/trie" gethtrie "github.com/ethereum/go-ethereum/trie"
) )
// NodeIterator is an iterator to traverse the trie pre-order.
type NodeIterator = gethtrie.NodeIterator
// NodeResolver is used for looking up trie nodes before reaching into the real // NodeResolver is used for looking up trie nodes before reaching into the real
// persistent layer. This is not mandatory, rather is an optimization for cases // persistent layer. This is not mandatory, rather is an optimization for cases
// where trie nodes can be recovered from some external mechanism without reading // where trie nodes can be recovered from some external mechanism without reading
@ -75,6 +72,9 @@ func (it *Iterator) Prove() [][]byte {
return it.nodeIt.LeafProof() return it.nodeIt.LeafProof()
} }
// NodeIterator is an iterator to traverse the trie pre-order.
type NodeIterator = gethtrie.NodeIterator
// nodeIteratorState represents the iteration state at one particular node of the // nodeIteratorState represents the iteration state at one particular node of the
// trie, which can be resumed at a later invocation. // trie, which can be resumed at a later invocation.
type nodeIteratorState struct { type nodeIteratorState struct {
@ -92,6 +92,7 @@ type nodeIterator struct {
err error // Failure set in case of an internal error in the iterator err error // Failure set in case of an internal error in the iterator
resolver NodeResolver // optional node resolver for avoiding disk hits resolver NodeResolver // optional node resolver for avoiding disk hits
pool []*nodeIteratorState // local pool for iteratorstates
} }
// errIteratorEnd is stored in nodeIterator.err when iteration is done. // errIteratorEnd is stored in nodeIterator.err when iteration is done.
@ -119,6 +120,24 @@ func newNodeIterator(trie *Trie, start []byte) NodeIterator {
return it return it
} }
func (it *nodeIterator) putInPool(item *nodeIteratorState) {
if len(it.pool) < 40 {
item.node = nil
it.pool = append(it.pool, item)
}
}
func (it *nodeIterator) getFromPool() *nodeIteratorState {
idx := len(it.pool) - 1
if idx < 0 {
return new(nodeIteratorState)
}
el := it.pool[idx]
it.pool[idx] = nil
it.pool = it.pool[:idx]
return el
}
func (it *nodeIterator) AddResolver(resolver NodeResolver) { func (it *nodeIterator) AddResolver(resolver NodeResolver) {
it.resolver = resolver it.resolver = resolver
} }
@ -137,14 +156,6 @@ func (it *nodeIterator) Parent() common.Hash {
return it.stack[len(it.stack)-1].parent return it.stack[len(it.stack)-1].parent
} }
func (it *nodeIterator) ParentPath() []byte {
if len(it.stack) == 0 {
return []byte{}
}
pathlen := it.stack[len(it.stack)-1].pathlen
return it.path[:pathlen]
}
func (it *nodeIterator) Leaf() bool { func (it *nodeIterator) Leaf() bool {
return hasTerm(it.path) return hasTerm(it.path)
} }
@ -152,7 +163,7 @@ func (it *nodeIterator) Leaf() bool {
func (it *nodeIterator) LeafKey() []byte { func (it *nodeIterator) LeafKey() []byte {
if len(it.stack) > 0 { if len(it.stack) > 0 {
if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok { if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
return hexToKeyBytes(it.path) return hexToKeybytes(it.path)
} }
} }
panic("not at leaf") panic("not at leaf")
@ -342,7 +353,14 @@ func (it *nodeIterator) resolveHash(hash hashNode, path []byte) (node, error) {
// loaded blob will be tracked, while it's not required here since // loaded blob will be tracked, while it's not required here since
// all loaded nodes won't be linked to trie at all and track nodes // all loaded nodes won't be linked to trie at all and track nodes
// may lead to out-of-memory issue. // may lead to out-of-memory issue.
return it.trie.reader.node(path, common.BytesToHash(hash)) blob, err := it.trie.reader.node(path, common.BytesToHash(hash))
if err != nil {
return nil, err
}
// The raw-blob format nodes are loaded either from the
// clean cache or the database, they are all in their own
// copy and safe to use unsafe decoder.
return mustDecodeNodeUnsafe(hash, blob), nil
} }
func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) { func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) {
@ -356,7 +374,7 @@ func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error)
// loaded blob will be tracked, while it's not required here since // loaded blob will be tracked, while it's not required here since
// all loaded nodes won't be linked to trie at all and track nodes // all loaded nodes won't be linked to trie at all and track nodes
// may lead to out-of-memory issue. // may lead to out-of-memory issue.
return it.trie.reader.nodeBlob(path, common.BytesToHash(hash)) return it.trie.reader.node(path, common.BytesToHash(hash))
} }
func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error { func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error {
@ -371,8 +389,9 @@ func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error {
return nil return nil
} }
func findChild(n *fullNode, index int, path []byte, ancestor common.Hash) (node, *nodeIteratorState, []byte, int) { func (it *nodeIterator) findChild(n *fullNode, index int, ancestor common.Hash) (node, *nodeIteratorState, []byte, int) {
var ( var (
path = it.path
child node child node
state *nodeIteratorState state *nodeIteratorState
childPath []byte childPath []byte
@ -381,13 +400,12 @@ func findChild(n *fullNode, index int, path []byte, ancestor common.Hash) (node,
if n.Children[index] != nil { if n.Children[index] != nil {
child = n.Children[index] child = n.Children[index]
hash, _ := child.cache() hash, _ := child.cache()
state = &nodeIteratorState{ state = it.getFromPool()
hash: common.BytesToHash(hash), state.hash = common.BytesToHash(hash)
node: child, state.node = child
parent: ancestor, state.parent = ancestor
index: -1, state.index = -1
pathlen: len(path), state.pathlen = len(path)
}
childPath = append(childPath, path...) childPath = append(childPath, path...)
childPath = append(childPath, byte(index)) childPath = append(childPath, byte(index))
return child, state, childPath, index return child, state, childPath, index
@ -400,7 +418,7 @@ func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor common.Has
switch node := parent.node.(type) { switch node := parent.node.(type) {
case *fullNode: case *fullNode:
// Full node, move to the first non-nil child. // Full node, move to the first non-nil child.
if child, state, path, index := findChild(node, parent.index+1, it.path, ancestor); child != nil { if child, state, path, index := it.findChild(node, parent.index+1, ancestor); child != nil {
parent.index = index - 1 parent.index = index - 1
return state, path, true return state, path, true
} }
@ -408,13 +426,12 @@ func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor common.Has
// Short node, return the pointer singleton child // Short node, return the pointer singleton child
if parent.index < 0 { if parent.index < 0 {
hash, _ := node.Val.cache() hash, _ := node.Val.cache()
state := &nodeIteratorState{ state := it.getFromPool()
hash: common.BytesToHash(hash), state.hash = common.BytesToHash(hash)
node: node.Val, state.node = node.Val
parent: ancestor, state.parent = ancestor
index: -1, state.index = -1
pathlen: len(it.path), state.pathlen = len(it.path)
}
path := append(it.path, node.Key...) path := append(it.path, node.Key...)
return state, path, true return state, path, true
} }
@ -428,7 +445,7 @@ func (it *nodeIterator) nextChildAt(parent *nodeIteratorState, ancestor common.H
switch n := parent.node.(type) { switch n := parent.node.(type) {
case *fullNode: case *fullNode:
// Full node, move to the first non-nil child before the desired key position // Full node, move to the first non-nil child before the desired key position
child, state, path, index := findChild(n, parent.index+1, it.path, ancestor) child, state, path, index := it.findChild(n, parent.index+1, ancestor)
if child == nil { if child == nil {
// No more children in this fullnode // No more children in this fullnode
return parent, it.path, false return parent, it.path, false
@ -440,7 +457,7 @@ func (it *nodeIterator) nextChildAt(parent *nodeIteratorState, ancestor common.H
} }
// The child is before the seek position. Try advancing // The child is before the seek position. Try advancing
for { for {
nextChild, nextState, nextPath, nextIndex := findChild(n, index+1, it.path, ancestor) nextChild, nextState, nextPath, nextIndex := it.findChild(n, index+1, ancestor)
// If we run out of children, or skipped past the target, return the // If we run out of children, or skipped past the target, return the
// previous one // previous one
if nextChild == nil || bytes.Compare(nextPath, key) >= 0 { if nextChild == nil || bytes.Compare(nextPath, key) >= 0 {
@ -454,13 +471,12 @@ func (it *nodeIterator) nextChildAt(parent *nodeIteratorState, ancestor common.H
// Short node, return the pointer singleton child // Short node, return the pointer singleton child
if parent.index < 0 { if parent.index < 0 {
hash, _ := n.Val.cache() hash, _ := n.Val.cache()
state := &nodeIteratorState{ state := it.getFromPool()
hash: common.BytesToHash(hash), state.hash = common.BytesToHash(hash)
node: n.Val, state.node = n.Val
parent: ancestor, state.parent = ancestor
index: -1, state.index = -1
pathlen: len(it.path), state.pathlen = len(it.path)
}
path := append(it.path, n.Key...) path := append(it.path, n.Key...)
return state, path, true return state, path, true
} }
@ -481,6 +497,8 @@ func (it *nodeIterator) pop() {
it.path = it.path[:last.pathlen] it.path = it.path[:last.pathlen]
it.stack[len(it.stack)-1] = nil it.stack[len(it.stack)-1] = nil
it.stack = it.stack[:len(it.stack)-1] it.stack = it.stack[:len(it.stack)-1]
// last is now unused
it.putInPool(last)
} }
func compareNodes(a, b NodeIterator) int { func compareNodes(a, b NodeIterator) int {

View File

@ -18,32 +18,82 @@ package trie
import ( import (
"bytes" "bytes"
"encoding/binary"
"fmt" "fmt"
"math/rand" "math/rand"
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/ethdb"
geth_trie "github.com/ethereum/go-ethereum/trie"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
) )
var ( // makeTestTrie create a sample test trie to test node-wise reconstruction.
packableTestData = []kvsi{ func makeTestTrie(scheme string) (ethdb.Database, *testDb, *StateTrie, map[string][]byte) {
{"one", 1}, // Create an empty trie
{"two", 2}, db := rawdb.NewMemoryDatabase()
{"three", 3}, triedb := newTestDatabase(db, scheme)
{"four", 4}, trie, _ := NewStateTrie(TrieID(types.EmptyRootHash), triedb)
{"five", 5},
{"ten", 10}, // Fill it with some arbitrary data
content := make(map[string][]byte)
for i := byte(0); i < 255; i++ {
// Map the same data under multiple keys
key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i}
content[string(key)] = val
trie.MustUpdate(key, val)
key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i}
content[string(key)] = val
trie.MustUpdate(key, val)
// Add some other data to inflate the trie
for j := byte(3); j < 13; j++ {
key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i}
content[string(key)] = val
trie.MustUpdate(key, val)
}
}
root, nodes, _ := trie.Commit(false)
if err := triedb.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes)); err != nil {
panic(fmt.Errorf("failed to commit db %v", err))
}
if err := triedb.Commit(root); err != nil {
panic(err)
}
// Re-create the trie based on the new state
trie, _ = NewStateTrie(TrieID(root), triedb)
return db, triedb, trie, content
}
// checkTrieConsistency checks that all nodes in a trie are indeed present.
func checkTrieConsistency(db ethdb.Database, scheme string, root common.Hash, rawTrie bool) error {
ndb := newTestDatabase(db, scheme)
var it NodeIterator
if rawTrie {
trie, err := New(TrieID(root), ndb)
if err != nil {
return nil // Consider a non existent state consistent
}
it = trie.MustNodeIterator(nil)
} else {
trie, err := NewStateTrie(TrieID(root), ndb)
if err != nil {
return nil // Consider a non existent state consistent
}
it = trie.MustNodeIterator(nil)
}
for it.Next(true) {
}
return it.Error()
} }
)
func TestEmptyIterator(t *testing.T) { func TestEmptyIterator(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
iter := trie.NodeIterator(nil) iter := trie.MustNodeIterator(nil)
seen := make(map[string]struct{}) seen := make(map[string]struct{})
for iter.Next(true) { for iter.Next(true) {
@ -55,7 +105,7 @@ func TestEmptyIterator(t *testing.T) {
} }
func TestIterator(t *testing.T) { func TestIterator(t *testing.T) {
db := NewDatabase(rawdb.NewMemoryDatabase()) db := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
trie := NewEmpty(db) trie := NewEmpty(db)
vals := []struct{ k, v string }{ vals := []struct{ k, v string }{
{"do", "verb"}, {"do", "verb"},
@ -69,14 +119,14 @@ func TestIterator(t *testing.T) {
all := make(map[string]string) all := make(map[string]string)
for _, val := range vals { for _, val := range vals {
all[val.k] = val.v all[val.k] = val.v
trie.Update([]byte(val.k), []byte(val.v)) trie.MustUpdate([]byte(val.k), []byte(val.v))
} }
root, nodes := trie.Commit(false) root, nodes, _ := trie.Commit(false)
db.Update(NewWithNodeSet(nodes)) db.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
trie, _ = New(TrieID(root), db, StateTrieCodec) trie, _ = New(TrieID(root), db)
found := make(map[string]string) found := make(map[string]string)
it := NewIterator(trie.NodeIterator(nil)) it := NewIterator(trie.MustNodeIterator(nil))
for it.Next() { for it.Next() {
found[string(it.Key)] = string(it.Value) found[string(it.Key)] = string(it.Value)
} }
@ -93,20 +143,24 @@ type kv struct {
t bool t bool
} }
func (k *kv) cmp(other *kv) int {
return bytes.Compare(k.k, other.k)
}
func TestIteratorLargeData(t *testing.T) { func TestIteratorLargeData(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
vals := make(map[string]*kv) vals := make(map[string]*kv)
for i := byte(0); i < 255; i++ { for i := byte(0); i < 255; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false} value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false}
trie.Update(value.k, value.v) trie.MustUpdate(value.k, value.v)
trie.Update(value2.k, value2.v) trie.MustUpdate(value2.k, value2.v)
vals[string(value.k)] = value vals[string(value.k)] = value
vals[string(value2.k)] = value2 vals[string(value2.k)] = value2
} }
it := NewIterator(trie.NodeIterator(nil)) it := NewIterator(trie.MustNodeIterator(nil))
for it.Next() { for it.Next() {
vals[string(it.Key)].t = true vals[string(it.Key)].t = true
} }
@ -126,39 +180,65 @@ func TestIteratorLargeData(t *testing.T) {
} }
} }
type iterationElement struct {
hash common.Hash
path []byte
blob []byte
}
// Tests that the node iterator indeed walks over the entire database contents. // Tests that the node iterator indeed walks over the entire database contents.
func TestNodeIteratorCoverage(t *testing.T) { func TestNodeIteratorCoverage(t *testing.T) {
db, trie, _ := makeTestTrie(t) testNodeIteratorCoverage(t, rawdb.HashScheme)
testNodeIteratorCoverage(t, rawdb.PathScheme)
}
func testNodeIteratorCoverage(t *testing.T, scheme string) {
// Create some arbitrary test trie to iterate // Create some arbitrary test trie to iterate
db, nodeDb, trie, _ := makeTestTrie(scheme)
// Gather all the node hashes found by the iterator // Gather all the node hashes found by the iterator
hashes := make(map[common.Hash]struct{}) var elements = make(map[common.Hash]iterationElement)
for it := trie.NodeIterator(nil); it.Next(true); { for it := trie.MustNodeIterator(nil); it.Next(true); {
if it.Hash() != (common.Hash{}) { if it.Hash() != (common.Hash{}) {
hashes[it.Hash()] = struct{}{} elements[it.Hash()] = iterationElement{
hash: it.Hash(),
path: common.CopyBytes(it.Path()),
blob: common.CopyBytes(it.NodeBlob()),
}
} }
} }
// Cross check the hashes and the database itself // Cross check the hashes and the database itself
for hash := range hashes { reader, err := nodeDb.Reader(trie.Hash())
if _, err := db.Node(hash, StateTrieCodec); err != nil { if err != nil {
t.Errorf("failed to retrieve reported node %x: %v", hash, err) t.Fatalf("state is not available %x", trie.Hash())
}
for _, element := range elements {
if blob, err := reader.Node(common.Hash{}, element.path, element.hash); err != nil {
t.Errorf("failed to retrieve reported node %x: %v", element.hash, err)
} else if !bytes.Equal(blob, element.blob) {
t.Errorf("node blob is different, want %v got %v", element.blob, blob)
} }
} }
for hash, obj := range db.dirties { var (
if obj != nil && hash != (common.Hash{}) { count int
if _, ok := hashes[hash]; !ok { it = db.NewIterator(nil, nil)
t.Errorf("state entry not reported %x", hash) )
}
}
}
it := db.diskdb.NewIterator(nil, nil)
for it.Next() { for it.Next() {
key := it.Key() res, _, _ := isTrieNode(nodeDb.Scheme(), it.Key(), it.Value())
if _, ok := hashes[common.BytesToHash(key)]; !ok { if !res {
t.Errorf("state entry not reported %x", key) continue
}
count += 1
if elem, ok := elements[crypto.Keccak256Hash(it.Value())]; !ok {
t.Error("state entry not reported")
} else if !bytes.Equal(it.Value(), elem.blob) {
t.Errorf("node blob is different, want %v got %v", elem.blob, it.Value())
} }
} }
it.Release() it.Release()
if count != len(elements) {
t.Errorf("state entry is mismatched %d %d", count, len(elements))
}
} }
type kvs struct{ k, v string } type kvs struct{ k, v string }
@ -187,25 +267,25 @@ var testdata2 = []kvs{
} }
func TestIteratorSeek(t *testing.T) { func TestIteratorSeek(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
for _, val := range testdata1 { for _, val := range testdata1 {
trie.Update([]byte(val.k), []byte(val.v)) trie.MustUpdate([]byte(val.k), []byte(val.v))
} }
// Seek to the middle. // Seek to the middle.
it := NewIterator(trie.NodeIterator([]byte("fab"))) it := NewIterator(trie.MustNodeIterator([]byte("fab")))
if err := checkIteratorOrder(testdata1[4:], it); err != nil { if err := checkIteratorOrder(testdata1[4:], it); err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Seek to a non-existent key. // Seek to a non-existent key.
it = NewIterator(trie.NodeIterator([]byte("barc"))) it = NewIterator(trie.MustNodeIterator([]byte("barc")))
if err := checkIteratorOrder(testdata1[1:], it); err != nil { if err := checkIteratorOrder(testdata1[1:], it); err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Seek beyond the end. // Seek beyond the end.
it = NewIterator(trie.NodeIterator([]byte("z"))) it = NewIterator(trie.MustNodeIterator([]byte("z")))
if err := checkIteratorOrder(nil, it); err != nil { if err := checkIteratorOrder(nil, it); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -228,26 +308,26 @@ func checkIteratorOrder(want []kvs, it *Iterator) error {
} }
func TestDifferenceIterator(t *testing.T) { func TestDifferenceIterator(t *testing.T) {
dba := NewDatabase(rawdb.NewMemoryDatabase()) dba := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
triea := NewEmpty(dba) triea := NewEmpty(dba)
for _, val := range testdata1 { for _, val := range testdata1 {
triea.Update([]byte(val.k), []byte(val.v)) triea.MustUpdate([]byte(val.k), []byte(val.v))
} }
rootA, nodesA := triea.Commit(false) rootA, nodesA, _ := triea.Commit(false)
dba.Update(NewWithNodeSet(nodesA)) dba.Update(rootA, types.EmptyRootHash, trienode.NewWithNodeSet(nodesA))
triea, _ = New(TrieID(rootA), dba, StateTrieCodec) triea, _ = New(TrieID(rootA), dba)
dbb := NewDatabase(rawdb.NewMemoryDatabase()) dbb := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
trieb := NewEmpty(dbb) trieb := NewEmpty(dbb)
for _, val := range testdata2 { for _, val := range testdata2 {
trieb.Update([]byte(val.k), []byte(val.v)) trieb.MustUpdate([]byte(val.k), []byte(val.v))
} }
rootB, nodesB := trieb.Commit(false) rootB, nodesB, _ := trieb.Commit(false)
dbb.Update(NewWithNodeSet(nodesB)) dbb.Update(rootB, types.EmptyRootHash, trienode.NewWithNodeSet(nodesB))
trieb, _ = New(TrieID(rootB), dbb, StateTrieCodec) trieb, _ = New(TrieID(rootB), dbb)
found := make(map[string]string) found := make(map[string]string)
di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) di, _ := NewDifferenceIterator(triea.MustNodeIterator(nil), trieb.MustNodeIterator(nil))
it := NewIterator(di) it := NewIterator(di)
for it.Next() { for it.Next() {
found[string(it.Key)] = string(it.Value) found[string(it.Key)] = string(it.Value)
@ -270,25 +350,25 @@ func TestDifferenceIterator(t *testing.T) {
} }
func TestUnionIterator(t *testing.T) { func TestUnionIterator(t *testing.T) {
dba := NewDatabase(rawdb.NewMemoryDatabase()) dba := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
triea := NewEmpty(dba) triea := NewEmpty(dba)
for _, val := range testdata1 { for _, val := range testdata1 {
triea.Update([]byte(val.k), []byte(val.v)) triea.MustUpdate([]byte(val.k), []byte(val.v))
} }
rootA, nodesA := triea.Commit(false) rootA, nodesA, _ := triea.Commit(false)
dba.Update(NewWithNodeSet(nodesA)) dba.Update(rootA, types.EmptyRootHash, trienode.NewWithNodeSet(nodesA))
triea, _ = New(TrieID(rootA), dba, StateTrieCodec) triea, _ = New(TrieID(rootA), dba)
dbb := NewDatabase(rawdb.NewMemoryDatabase()) dbb := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
trieb := NewEmpty(dbb) trieb := NewEmpty(dbb)
for _, val := range testdata2 { for _, val := range testdata2 {
trieb.Update([]byte(val.k), []byte(val.v)) trieb.MustUpdate([]byte(val.k), []byte(val.v))
} }
rootB, nodesB := trieb.Commit(false) rootB, nodesB, _ := trieb.Commit(false)
dbb.Update(NewWithNodeSet(nodesB)) dbb.Update(rootB, types.EmptyRootHash, trienode.NewWithNodeSet(nodesB))
trieb, _ = New(TrieID(rootB), dbb, StateTrieCodec) trieb, _ = New(TrieID(rootB), dbb)
di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)}) di, _ := NewUnionIterator([]NodeIterator{triea.MustNodeIterator(nil), trieb.MustNodeIterator(nil)})
it := NewIterator(di) it := NewIterator(di)
all := []struct{ k, v string }{ all := []struct{ k, v string }{
@ -323,86 +403,107 @@ func TestUnionIterator(t *testing.T) {
} }
func TestIteratorNoDups(t *testing.T) { func TestIteratorNoDups(t *testing.T) {
tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) db := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
tr := NewEmpty(db)
for _, val := range testdata1 { for _, val := range testdata1 {
tr.Update([]byte(val.k), []byte(val.v)) tr.MustUpdate([]byte(val.k), []byte(val.v))
} }
checkIteratorNoDups(t, tr.NodeIterator(nil), nil) checkIteratorNoDups(t, tr.MustNodeIterator(nil), nil)
} }
// This test checks that nodeIterator.Next can be retried after inserting missing trie nodes. // This test checks that nodeIterator.Next can be retried after inserting missing trie nodes.
func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) } func TestIteratorContinueAfterError(t *testing.T) {
testIteratorContinueAfterError(t, false, rawdb.HashScheme)
func testIteratorContinueAfterError(t *testing.T, memonly bool) { testIteratorContinueAfterError(t, true, rawdb.HashScheme)
diskdb := rawdb.NewMemoryDatabase() testIteratorContinueAfterError(t, false, rawdb.PathScheme)
triedb := NewDatabase(diskdb) testIteratorContinueAfterError(t, true, rawdb.PathScheme)
tr := NewEmpty(triedb)
for _, val := range testdata1 {
tr.Update([]byte(val.k), []byte(val.v))
} }
_, nodes := tr.Commit(false)
triedb.Update(NewWithNodeSet(nodes)) func testIteratorContinueAfterError(t *testing.T, memonly bool, scheme string) {
// if !memonly { diskdb := rawdb.NewMemoryDatabase()
// triedb.Commit(tr.Hash(), false) tdb := newTestDatabase(diskdb, scheme)
// }
wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil) tr := NewEmpty(tdb)
for _, val := range testdata1 {
tr.MustUpdate([]byte(val.k), []byte(val.v))
}
root, nodes, _ := tr.Commit(false)
tdb.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
if !memonly {
tdb.Commit(root)
}
tr, _ = New(TrieID(root), tdb)
wantNodeCount := checkIteratorNoDups(t, tr.MustNodeIterator(nil), nil)
var ( var (
diskKeys [][]byte paths [][]byte
memKeys []common.Hash hashes []common.Hash
) )
if memonly { if memonly {
memKeys = triedb.Nodes() for path, n := range nodes.Nodes {
paths = append(paths, []byte(path))
hashes = append(hashes, n.Hash)
}
} else { } else {
it := diskdb.NewIterator(nil, nil) it := diskdb.NewIterator(nil, nil)
for it.Next() { for it.Next() {
diskKeys = append(diskKeys, it.Key()) ok, path, hash := isTrieNode(tdb.Scheme(), it.Key(), it.Value())
if !ok {
continue
}
paths = append(paths, path)
hashes = append(hashes, hash)
} }
it.Release() it.Release()
} }
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
// Create trie that will load all nodes from DB. // Create trie that will load all nodes from DB.
tr, _ := New(TrieID(tr.Hash()), triedb, StateTrieCodec) tr, _ := New(TrieID(tr.Hash()), tdb)
// Remove a random node from the database. It can't be the root node // Remove a random node from the database. It can't be the root node
// because that one is already loaded. // because that one is already loaded.
var ( var (
rkey common.Hash
rval []byte rval []byte
robj *cachedNode rpath []byte
rhash common.Hash
) )
for { for {
if memonly { if memonly {
rkey = memKeys[rand.Intn(len(memKeys))] rpath = paths[rand.Intn(len(paths))]
} else { n := nodes.Nodes[string(rpath)]
copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))]) if n == nil {
continue
} }
if rkey != tr.Hash() { rhash = n.Hash
} else {
index := rand.Intn(len(paths))
rpath = paths[index]
rhash = hashes[index]
}
if rhash != tr.Hash() {
break break
} }
} }
if memonly { if memonly {
robj = triedb.dirties[rkey] tr.reader.banned = map[string]struct{}{string(rpath): {}}
delete(triedb.dirties, rkey)
} else { } else {
rval, _ = diskdb.Get(rkey[:]) rval = rawdb.ReadTrieNode(diskdb, common.Hash{}, rpath, rhash, tdb.Scheme())
diskdb.Delete(rkey[:]) rawdb.DeleteTrieNode(diskdb, common.Hash{}, rpath, rhash, tdb.Scheme())
} }
// Iterate until the error is hit. // Iterate until the error is hit.
seen := make(map[string]bool) seen := make(map[string]bool)
it := tr.NodeIterator(nil) it := tr.MustNodeIterator(nil)
checkIteratorNoDups(t, it, seen) checkIteratorNoDups(t, it, seen)
missing, ok := it.Error().(*MissingNodeError) missing, ok := it.Error().(*MissingNodeError)
if !ok || missing.NodeHash != rkey { if !ok || missing.NodeHash != rhash {
t.Fatal("didn't hit missing node, got", it.Error()) t.Fatal("didn't hit missing node, got", it.Error())
} }
// Add the node back and continue iteration. // Add the node back and continue iteration.
if memonly { if memonly {
triedb.dirties[rkey] = robj delete(tr.reader.banned, string(rpath))
} else { } else {
diskdb.Put(rkey[:], rval) rawdb.WriteTrieNode(diskdb, common.Hash{}, rpath, rhash, rval, tdb.Scheme())
} }
checkIteratorNoDups(t, it, seen) checkIteratorNoDups(t, it, seen)
if it.Error() != nil { if it.Error() != nil {
@ -417,40 +518,49 @@ func testIteratorContinueAfterError(t *testing.T, memonly bool) {
// Similar to the test above, this one checks that failure to create nodeIterator at a // Similar to the test above, this one checks that failure to create nodeIterator at a
// certain key prefix behaves correctly when Next is called. The expectation is that Next // certain key prefix behaves correctly when Next is called. The expectation is that Next
// should retry seeking before returning true for the first time. // should retry seeking before returning true for the first time.
func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) { func TestIteratorContinueAfterSeekError(t *testing.T) {
testIteratorContinueAfterSeekError(t, true) testIteratorContinueAfterSeekError(t, false, rawdb.HashScheme)
testIteratorContinueAfterSeekError(t, true, rawdb.HashScheme)
testIteratorContinueAfterSeekError(t, false, rawdb.PathScheme)
testIteratorContinueAfterSeekError(t, true, rawdb.PathScheme)
} }
func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) { func testIteratorContinueAfterSeekError(t *testing.T, memonly bool, scheme string) {
// Commit test trie to db, then remove the node containing "bars". // Commit test trie to db, then remove the node containing "bars".
var (
barNodePath []byte
barNodeHash = common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
)
diskdb := rawdb.NewMemoryDatabase() diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb) triedb := newTestDatabase(diskdb, scheme)
ctr := NewEmpty(triedb) ctr := NewEmpty(triedb)
for _, val := range testdata1 { for _, val := range testdata1 {
ctr.Update([]byte(val.k), []byte(val.v)) ctr.MustUpdate([]byte(val.k), []byte(val.v))
}
root, nodes, _ := ctr.Commit(false)
for path, n := range nodes.Nodes {
if n.Hash == barNodeHash {
barNodePath = []byte(path)
break
}
}
triedb.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
if !memonly {
triedb.Commit(root)
} }
root, nodes := ctr.Commit(false)
triedb.Update(NewWithNodeSet(nodes))
// if !memonly {
// triedb.Commit(root, false)
// }
barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
var ( var (
barNodeBlob []byte barNodeBlob []byte
barNodeObj *cachedNode
) )
tr, _ := New(TrieID(root), triedb)
if memonly { if memonly {
barNodeObj = triedb.dirties[barNodeHash] tr.reader.banned = map[string]struct{}{string(barNodePath): {}}
delete(triedb.dirties, barNodeHash)
} else { } else {
barNodeBlob, _ = diskdb.Get(barNodeHash[:]) barNodeBlob = rawdb.ReadTrieNode(diskdb, common.Hash{}, barNodePath, barNodeHash, triedb.Scheme())
diskdb.Delete(barNodeHash[:]) rawdb.DeleteTrieNode(diskdb, common.Hash{}, barNodePath, barNodeHash, triedb.Scheme())
} }
// Create a new iterator that seeks to "bars". Seeking can't proceed because // Create a new iterator that seeks to "bars". Seeking can't proceed because
// the node is missing. // the node is missing.
tr, _ := New(TrieID(root), triedb, StateTrieCodec) it := tr.MustNodeIterator([]byte("bars"))
it := tr.NodeIterator([]byte("bars"))
missing, ok := it.Error().(*MissingNodeError) missing, ok := it.Error().(*MissingNodeError)
if !ok { if !ok {
t.Fatal("want MissingNodeError, got", it.Error()) t.Fatal("want MissingNodeError, got", it.Error())
@ -459,9 +569,9 @@ func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
} }
// Reinsert the missing node. // Reinsert the missing node.
if memonly { if memonly {
triedb.dirties[barNodeHash] = barNodeObj delete(tr.reader.banned, string(barNodePath))
} else { } else {
diskdb.Put(barNodeHash[:], barNodeBlob) rawdb.WriteTrieNode(diskdb, common.Hash{}, barNodePath, barNodeHash, barNodeBlob, triedb.Scheme())
} }
// Check that iteration produces the right set of values. // Check that iteration produces the right set of values.
if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil { if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
@ -482,96 +592,38 @@ func checkIteratorNoDups(t *testing.T, it NodeIterator, seen map[string]bool) in
return len(seen) return len(seen)
} }
type loggingTrieDb struct {
*Database
getCount uint64
}
// GetReader retrieves a node reader belonging to the given state root.
func (db *loggingTrieDb) GetReader(root common.Hash, codec uint64) Reader {
return &loggingNodeReader{db, codec}
}
// loggingNodeReader is a trie node reader that logs the number of get operations.
type loggingNodeReader struct {
db *loggingTrieDb
codec uint64
}
// Node retrieves the trie node with the given node hash.
func (reader *loggingNodeReader) Node(owner common.Hash, path []byte, hash common.Hash) (node, error) {
blob, err := reader.NodeBlob(owner, path, hash)
if err != nil {
return nil, err
}
return decodeNodeUnsafe(hash[:], blob)
}
// NodeBlob retrieves the RLP-encoded trie node blob with the given node hash.
func (reader *loggingNodeReader) NodeBlob(_ common.Hash, _ []byte, hash common.Hash) ([]byte, error) {
reader.db.getCount++
return reader.db.Node(hash, reader.codec)
}
func newLoggingStateTrie(id *ID, db *Database, codec uint64) (*StateTrie, *loggingTrieDb, error) {
logdb := &loggingTrieDb{Database: db}
trie, err := New(id, logdb, codec)
if err != nil {
return nil, nil, err
}
return &StateTrie{trie: *trie, preimages: db.preimages}, logdb, nil
}
// makeLargeTestTrie create a sample test trie
func makeLargeTestTrie(t testing.TB) (*Database, *StateTrie, *loggingTrieDb) {
// Create an empty trie
triedb := NewDatabase(rawdb.NewDatabase(memorydb.New()))
trie, logDb, err := newLoggingStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec)
if err != nil {
t.Fatal(err)
}
// Fill it with some arbitrary data
for i := 0; i < 10000; i++ {
key := make([]byte, 32)
val := make([]byte, 32)
binary.BigEndian.PutUint64(key, uint64(i))
binary.BigEndian.PutUint64(val, uint64(i))
key = crypto.Keccak256(key)
val = crypto.Keccak256(val)
trie.Update(key, val)
}
_, nodes := trie.Commit(false)
triedb.Update(NewWithNodeSet(nodes))
// Return the generated trie
return triedb, trie, logDb
}
// Tests that the node iterator indeed walks over the entire database contents.
func TestNodeIteratorLargeTrie(t *testing.T) {
// Create some arbitrary test trie to iterate
_, trie, logDb := makeLargeTestTrie(t)
// Do a seek operation
trie.NodeIterator(common.FromHex("0x77667766776677766778855885885885"))
// master: 24 get operations
// this pr: 5 get operations
if have, want := logDb.getCount, uint64(5); have != want {
t.Fatalf("Wrong number of lookups during seek, have %d want %d", have, want)
}
}
func TestIteratorNodeBlob(t *testing.T) { func TestIteratorNodeBlob(t *testing.T) {
edb := rawdb.NewMemoryDatabase() testIteratorNodeBlob(t, rawdb.HashScheme)
db := geth_trie.NewDatabase(edb) testIteratorNodeBlob(t, rawdb.PathScheme)
orig := geth_trie.NewEmpty(geth_trie.NewDatabase(rawdb.NewMemoryDatabase()))
if _, err := updateTrie(orig, packableTestData); err != nil {
t.Fatal(err)
} }
root := commitTrie(t, db, orig)
trie := indexTrie(t, edb, root)
found := make(map[common.Hash][]byte) func testIteratorNodeBlob(t *testing.T, scheme string) {
it := trie.NodeIterator(nil) var (
db = rawdb.NewMemoryDatabase()
triedb = newTestDatabase(db, scheme)
trie = NewEmpty(triedb)
)
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
{"horse", "stallion"},
{"shaman", "horse"},
{"doge", "coin"},
{"dog", "puppy"},
{"somethingveryoddindeedthis is", "myothernodedata"},
}
all := make(map[string]string)
for _, val := range vals {
all[val.k] = val.v
trie.MustUpdate([]byte(val.k), []byte(val.v))
}
root, nodes, _ := trie.Commit(false)
triedb.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
triedb.Commit(root)
var found = make(map[common.Hash][]byte)
trie, _ = New(TrieID(root), triedb)
it := trie.MustNodeIterator(nil)
for it.Next(true) { for it.Next(true) {
if it.Hash() == (common.Hash{}) { if it.Hash() == (common.Hash{}) {
continue continue
@ -579,14 +631,18 @@ func TestIteratorNodeBlob(t *testing.T) {
found[it.Hash()] = it.NodeBlob() found[it.Hash()] = it.NodeBlob()
} }
dbIter := edb.NewIterator(nil, nil) dbIter := db.NewIterator(nil, nil)
defer dbIter.Release() defer dbIter.Release()
var count int var count int
for dbIter.Next() { for dbIter.Next() {
got, present := found[common.BytesToHash(dbIter.Key())] ok, _, _ := isTrieNode(triedb.Scheme(), dbIter.Key(), dbIter.Value())
if !ok {
continue
}
got, present := found[crypto.Keccak256Hash(dbIter.Value())]
if !present { if !present {
t.Fatalf("Miss trie node %v", dbIter.Key()) t.Fatal("Miss trie node")
} }
if !bytes.Equal(got, dbIter.Value()) { if !bytes.Equal(got, dbIter.Value()) {
t.Fatalf("Unexpected trie node want %v got %v", dbIter.Value(), got) t.Fatalf("Unexpected trie node want %v got %v", dbIter.Value(), got)
@ -594,6 +650,44 @@ func TestIteratorNodeBlob(t *testing.T) {
count += 1 count += 1
} }
if count != len(found) { if count != len(found) {
t.Fatalf("Wrong number of trie nodes found, want %d, got %d", len(found), count) t.Fatal("Find extra trie node via iterator")
}
}
// isTrieNode is a helper function which reports if the provided
// database entry belongs to a trie node or not. Note in tests
// only single layer trie is used, namely storage trie is not
// considered at all.
func isTrieNode(scheme string, key, val []byte) (bool, []byte, common.Hash) {
var (
path []byte
hash common.Hash
)
if scheme == rawdb.HashScheme {
ok := rawdb.IsLegacyTrieNode(key, val)
if !ok {
return false, nil, common.Hash{}
}
hash = common.BytesToHash(key)
} else {
ok, remain := rawdb.ResolveAccountTrieNodeKey(key)
if !ok {
return false, nil, common.Hash{}
}
path = common.CopyBytes(remain)
hash = crypto.Keccak256Hash(val)
}
return true, path, hash
}
func BenchmarkIterator(b *testing.B) {
diskDb, srcDb, tr, _ := makeTestTrie(rawdb.HashScheme)
root := tr.Hash()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := checkTrieConsistency(diskDb, srcDb.Scheme(), root, false); err != nil {
b.Fatal(err)
}
} }
} }

View File

@ -99,6 +99,19 @@ func (n valueNode) fstring(ind string) string {
return fmt.Sprintf("%x ", []byte(n)) return fmt.Sprintf("%x ", []byte(n))
} }
// rawNode is a simple binary blob used to differentiate between collapsed trie
// nodes and already encoded RLP binary blobs (while at the same time store them
// in the same cache fields).
type rawNode []byte
func (n rawNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") }
func (n rawNode) EncodeRLP(w io.Writer) error {
_, err := w.Write(n)
return err
}
// mustDecodeNode is a wrapper of decodeNode and panic if any error is encountered. // mustDecodeNode is a wrapper of decodeNode and panic if any error is encountered.
func mustDecodeNode(hash, buf []byte) node { func mustDecodeNode(hash, buf []byte) node {
n, err := decodeNode(hash, buf) n, err := decodeNode(hash, buf)

View File

@ -59,29 +59,6 @@ func (n valueNode) encode(w rlp.EncoderBuffer) {
w.WriteBytes(n) w.WriteBytes(n)
} }
func (n rawFullNode) encode(w rlp.EncoderBuffer) {
offset := w.List()
for _, c := range n {
if c != nil {
c.encode(w)
} else {
w.Write(rlp.EmptyString)
}
}
w.ListEnd(offset)
}
func (n *rawShortNode) encode(w rlp.EncoderBuffer) {
offset := w.List()
w.WriteBytes(n.Key)
if n.Val != nil {
n.Val.encode(w)
} else {
w.Write(rlp.EmptyString)
}
w.ListEnd(offset)
}
func (n rawNode) encode(w rlp.EncoderBuffer) { func (n rawNode) encode(w rlp.EncoderBuffer) {
w.Write(n) w.Write(n)
} }

View File

@ -96,7 +96,7 @@ func TestDecodeFullNode(t *testing.T) {
// goos: darwin // goos: darwin
// goarch: arm64 // goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie // pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
// BenchmarkEncodeShortNode // BenchmarkEncodeShortNode
// BenchmarkEncodeShortNode-8 16878850 70.81 ns/op 48 B/op 1 allocs/op // BenchmarkEncodeShortNode-8 16878850 70.81 ns/op 48 B/op 1 allocs/op
func BenchmarkEncodeShortNode(b *testing.B) { func BenchmarkEncodeShortNode(b *testing.B) {
@ -114,7 +114,7 @@ func BenchmarkEncodeShortNode(b *testing.B) {
// goos: darwin // goos: darwin
// goarch: arm64 // goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie // pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
// BenchmarkEncodeFullNode // BenchmarkEncodeFullNode
// BenchmarkEncodeFullNode-8 4323273 284.4 ns/op 576 B/op 1 allocs/op // BenchmarkEncodeFullNode-8 4323273 284.4 ns/op 576 B/op 1 allocs/op
func BenchmarkEncodeFullNode(b *testing.B) { func BenchmarkEncodeFullNode(b *testing.B) {
@ -132,7 +132,7 @@ func BenchmarkEncodeFullNode(b *testing.B) {
// goos: darwin // goos: darwin
// goarch: arm64 // goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie // pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
// BenchmarkDecodeShortNode // BenchmarkDecodeShortNode
// BenchmarkDecodeShortNode-8 7925638 151.0 ns/op 157 B/op 4 allocs/op // BenchmarkDecodeShortNode-8 7925638 151.0 ns/op 157 B/op 4 allocs/op
func BenchmarkDecodeShortNode(b *testing.B) { func BenchmarkDecodeShortNode(b *testing.B) {
@ -153,7 +153,7 @@ func BenchmarkDecodeShortNode(b *testing.B) {
// goos: darwin // goos: darwin
// goarch: arm64 // goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie // pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
// BenchmarkDecodeShortNodeUnsafe // BenchmarkDecodeShortNodeUnsafe
// BenchmarkDecodeShortNodeUnsafe-8 9027476 128.6 ns/op 109 B/op 3 allocs/op // BenchmarkDecodeShortNodeUnsafe-8 9027476 128.6 ns/op 109 B/op 3 allocs/op
func BenchmarkDecodeShortNodeUnsafe(b *testing.B) { func BenchmarkDecodeShortNodeUnsafe(b *testing.B) {
@ -174,7 +174,7 @@ func BenchmarkDecodeShortNodeUnsafe(b *testing.B) {
// goos: darwin // goos: darwin
// goarch: arm64 // goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie // pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
// BenchmarkDecodeFullNode // BenchmarkDecodeFullNode
// BenchmarkDecodeFullNode-8 1597462 761.9 ns/op 1280 B/op 18 allocs/op // BenchmarkDecodeFullNode-8 1597462 761.9 ns/op 1280 B/op 18 allocs/op
func BenchmarkDecodeFullNode(b *testing.B) { func BenchmarkDecodeFullNode(b *testing.B) {
@ -195,7 +195,7 @@ func BenchmarkDecodeFullNode(b *testing.B) {
// goos: darwin // goos: darwin
// goarch: arm64 // goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie // pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
// BenchmarkDecodeFullNodeUnsafe // BenchmarkDecodeFullNodeUnsafe
// BenchmarkDecodeFullNodeUnsafe-8 1789070 687.1 ns/op 704 B/op 17 allocs/op // BenchmarkDecodeFullNodeUnsafe-8 1789070 687.1 ns/op 704 B/op 17 allocs/op
func BenchmarkDecodeFullNodeUnsafe(b *testing.B) { func BenchmarkDecodeFullNodeUnsafe(b *testing.B) {

View File

@ -1,218 +0,0 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"fmt"
"reflect"
"sort"
"strings"
"github.com/ethereum/go-ethereum/common"
)
// memoryNode is all the information we know about a single cached trie node
// in the memory.
type memoryNode struct {
hash common.Hash // Node hash, computed by hashing rlp value, empty for deleted nodes
size uint16 // Byte size of the useful cached data, 0 for deleted nodes
node node // Cached collapsed trie node, or raw rlp data, nil for deleted nodes
}
// memoryNodeSize is the raw size of a memoryNode data structure without any
// node data included. It's an approximate size, but should be a lot better
// than not counting them.
// nolint:unused
var memoryNodeSize = int(reflect.TypeOf(memoryNode{}).Size())
// memorySize returns the total memory size used by this node.
// nolint:unused
func (n *memoryNode) memorySize(pathlen int) int {
return int(n.size) + memoryNodeSize + pathlen
}
// rlp returns the raw rlp encoded blob of the cached trie node, either directly
// from the cache, or by regenerating it from the collapsed node.
// nolint:unused
func (n *memoryNode) rlp() []byte {
if node, ok := n.node.(rawNode); ok {
return node
}
return nodeToBytes(n.node)
}
// obj returns the decoded and expanded trie node, either directly from the cache,
// or by regenerating it from the rlp encoded blob.
// nolint:unused
func (n *memoryNode) obj() node {
if node, ok := n.node.(rawNode); ok {
return mustDecodeNode(n.hash[:], node)
}
return expandNode(n.hash[:], n.node)
}
// isDeleted returns the indicator if the node is marked as deleted.
func (n *memoryNode) isDeleted() bool {
return n.hash == (common.Hash{})
}
// nodeWithPrev wraps the memoryNode with the previous node value.
// nolint: unused
type nodeWithPrev struct {
*memoryNode
prev []byte // RLP-encoded previous value, nil means it's non-existent
}
// unwrap returns the internal memoryNode object.
// nolint:unused
func (n *nodeWithPrev) unwrap() *memoryNode {
return n.memoryNode
}
// memorySize returns the total memory size used by this node. It overloads
// the function in memoryNode by counting the size of previous value as well.
// nolint: unused
func (n *nodeWithPrev) memorySize(pathlen int) int {
return n.memoryNode.memorySize(pathlen) + len(n.prev)
}
// NodeSet contains all dirty nodes collected during the commit operation.
// Each node is keyed by path. It's not thread-safe to use.
type NodeSet struct {
owner common.Hash // the identifier of the trie
nodes map[string]*memoryNode // the set of dirty nodes(inserted, updated, deleted)
leaves []*leaf // the list of dirty leaves
updates int // the count of updated and inserted nodes
deletes int // the count of deleted nodes
// The list of accessed nodes, which records the original node value.
// The origin value is expected to be nil for newly inserted node
// and is expected to be non-nil for other types(updated, deleted).
accessList map[string][]byte
}
// NewNodeSet initializes an empty node set to be used for tracking dirty nodes
// from a specific account or storage trie. The owner is zero for the account
// trie and the owning account address hash for storage tries.
func NewNodeSet(owner common.Hash, accessList map[string][]byte) *NodeSet {
return &NodeSet{
owner: owner,
nodes: make(map[string]*memoryNode),
accessList: accessList,
}
}
// forEachWithOrder iterates the dirty nodes with the order from bottom to top,
// right to left, nodes with the longest path will be iterated first.
func (set *NodeSet) forEachWithOrder(callback func(path string, n *memoryNode)) {
var paths sort.StringSlice
for path := range set.nodes {
paths = append(paths, path)
}
// Bottom-up, longest path first
sort.Sort(sort.Reverse(paths))
for _, path := range paths {
callback(path, set.nodes[path])
}
}
// markUpdated marks the node as dirty(newly-inserted or updated).
func (set *NodeSet) markUpdated(path []byte, node *memoryNode) {
set.nodes[string(path)] = node
set.updates += 1
}
// markDeleted marks the node as deleted.
func (set *NodeSet) markDeleted(path []byte) {
set.nodes[string(path)] = &memoryNode{}
set.deletes += 1
}
// addLeaf collects the provided leaf node into set.
func (set *NodeSet) addLeaf(node *leaf) {
set.leaves = append(set.leaves, node)
}
// Size returns the number of dirty nodes in set.
func (set *NodeSet) Size() (int, int) {
return set.updates, set.deletes
}
// Hashes returns the hashes of all updated nodes. TODO(rjl493456442) how can
// we get rid of it?
func (set *NodeSet) Hashes() []common.Hash {
var ret []common.Hash
for _, node := range set.nodes {
ret = append(ret, node.hash)
}
return ret
}
// Summary returns a string-representation of the NodeSet.
func (set *NodeSet) Summary() string {
var out = new(strings.Builder)
fmt.Fprintf(out, "nodeset owner: %v\n", set.owner)
if set.nodes != nil {
for path, n := range set.nodes {
// Deletion
if n.isDeleted() {
fmt.Fprintf(out, " [-]: %x prev: %x\n", path, set.accessList[path])
continue
}
// Insertion
origin, ok := set.accessList[path]
if !ok {
fmt.Fprintf(out, " [+]: %x -> %v\n", path, n.hash)
continue
}
// Update
fmt.Fprintf(out, " [*]: %x -> %v prev: %x\n", path, n.hash, origin)
}
}
for _, n := range set.leaves {
fmt.Fprintf(out, "[leaf]: %v\n", n)
}
return out.String()
}
// MergedNodeSet represents a merged dirty node set for a group of tries.
type MergedNodeSet struct {
sets map[common.Hash]*NodeSet
}
// NewMergedNodeSet initializes an empty merged set.
func NewMergedNodeSet() *MergedNodeSet {
return &MergedNodeSet{sets: make(map[common.Hash]*NodeSet)}
}
// NewWithNodeSet constructs a merged nodeset with the provided single set.
func NewWithNodeSet(set *NodeSet) *MergedNodeSet {
merged := NewMergedNodeSet()
merged.Merge(set)
return merged
}
// Merge merges the provided dirty nodes of a trie into the set. The assumption
// is held that no duplicated set belonging to the same trie will be merged twice.
func (set *MergedNodeSet) Merge(other *NodeSet) error {
_, present := set.sets[other.owner]
if present {
return fmt.Errorf("duplicate trie for owner %#x", other.owner)
}
set.sets[other.owner] = other
return nil
}

View File

@ -18,17 +18,14 @@ package trie
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/log"
log "github.com/sirupsen/logrus"
) )
var VerifyProof = trie.VerifyProof
var VerifyRangeProof = trie.VerifyRangeProof
// Prove constructs a merkle proof for key. The result contains all encoded nodes // Prove constructs a merkle proof for key. The result contains all encoded nodes
// on the path to the value at key. The value itself is also included in the last // on the path to the value at key. The value itself is also included in the last
// node and can be retrieved by verifying the proof. // node and can be retrieved by verifying the proof.
@ -36,7 +33,11 @@ var VerifyRangeProof = trie.VerifyRangeProof
// If the trie does not contain a value for key, the returned proof contains all // If the trie does not contain a value for key, the returned proof contains all
// nodes of the longest existing prefix of the key (at least the root node), ending // nodes of the longest existing prefix of the key (at least the root node), ending
// with the node that proves the absence of the key. // with the node that proves the absence of the key.
func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error { func (t *Trie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error {
// Short circuit if the trie is already committed and not usable.
if t.committed {
return ErrCommitted
}
// Collect all nodes on the path to key. // Collect all nodes on the path to key.
var ( var (
prefix []byte prefix []byte
@ -67,12 +68,15 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e
// loaded blob will be tracked, while it's not required here since // loaded blob will be tracked, while it's not required here since
// all loaded nodes won't be linked to trie at all and track nodes // all loaded nodes won't be linked to trie at all and track nodes
// may lead to out-of-memory issue. // may lead to out-of-memory issue.
var err error blob, err := t.reader.node(prefix, common.BytesToHash(n))
tn, err = t.reader.node(prefix, common.BytesToHash(n))
if err != nil { if err != nil {
log.Error("Unhandled trie error in Trie.Prove", "err", err) log.Error("Unhandled trie error in Trie.Prove", "err", err)
return err return err
} }
// The raw-blob format nodes are loaded either from the
// clean cache or the database, they are all in their own
// copy and safe to use unsafe decoder.
tn = mustDecodeNodeUnsafe(n, blob)
default: default:
panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
} }
@ -81,10 +85,6 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e
defer returnHasherToPool(hasher) defer returnHasherToPool(hasher)
for i, n := range nodes { for i, n := range nodes {
if fromLevel > 0 {
fromLevel--
continue
}
var hn node var hn node
n, hn = hasher.proofHash(n) n, hn = hasher.proofHash(n)
if hash, ok := hn.(hashNode); ok || i == 0 { if hash, ok := hn.(hashNode); ok || i == 0 {
@ -107,6 +107,510 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e
// If the trie does not contain a value for key, the returned proof contains all // If the trie does not contain a value for key, the returned proof contains all
// nodes of the longest existing prefix of the key (at least the root node), ending // nodes of the longest existing prefix of the key (at least the root node), ending
// with the node that proves the absence of the key. // with the node that proves the absence of the key.
func (t *StateTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error { func (t *StateTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error {
return t.trie.Prove(key, fromLevel, proofDb) return t.trie.Prove(key, proofDb)
}
// VerifyProof checks merkle proofs. The given proof must contain the value for
// key in a trie with the given root hash. VerifyProof returns an error if the
// proof contains invalid trie nodes or the wrong value.
func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) (value []byte, err error) {
key = keybytesToHex(key)
wantHash := rootHash
for i := 0; ; i++ {
buf, _ := proofDb.Get(wantHash[:])
if buf == nil {
return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash)
}
n, err := decodeNode(wantHash[:], buf)
if err != nil {
return nil, fmt.Errorf("bad proof node %d: %v", i, err)
}
keyrest, cld := get(n, key, true)
switch cld := cld.(type) {
case nil:
// The trie doesn't contain the key.
return nil, nil
case hashNode:
key = keyrest
copy(wantHash[:], cld)
case valueNode:
return cld, nil
}
}
}
// proofToPath converts a merkle proof to trie node path. The main purpose of
// this function is recovering a node path from the merkle proof stream. All
// necessary nodes will be resolved and leave the remaining as hashnode.
//
// The given edge proof is allowed to be an existent or non-existent proof.
func proofToPath(rootHash common.Hash, root node, key []byte, proofDb ethdb.KeyValueReader, allowNonExistent bool) (node, []byte, error) {
// resolveNode retrieves and resolves trie node from merkle proof stream
resolveNode := func(hash common.Hash) (node, error) {
buf, _ := proofDb.Get(hash[:])
if buf == nil {
return nil, fmt.Errorf("proof node (hash %064x) missing", hash)
}
n, err := decodeNode(hash[:], buf)
if err != nil {
return nil, fmt.Errorf("bad proof node %v", err)
}
return n, err
}
// If the root node is empty, resolve it first.
// Root node must be included in the proof.
if root == nil {
n, err := resolveNode(rootHash)
if err != nil {
return nil, nil, err
}
root = n
}
var (
err error
child, parent node
keyrest []byte
valnode []byte
)
key, parent = keybytesToHex(key), root
for {
keyrest, child = get(parent, key, false)
switch cld := child.(type) {
case nil:
// The trie doesn't contain the key. It's possible
// the proof is a non-existing proof, but at least
// we can prove all resolved nodes are correct, it's
// enough for us to prove range.
if allowNonExistent {
return root, nil, nil
}
return nil, nil, errors.New("the node is not contained in trie")
case *shortNode:
key, parent = keyrest, child // Already resolved
continue
case *fullNode:
key, parent = keyrest, child // Already resolved
continue
case hashNode:
child, err = resolveNode(common.BytesToHash(cld))
if err != nil {
return nil, nil, err
}
case valueNode:
valnode = cld
}
// Link the parent and child.
switch pnode := parent.(type) {
case *shortNode:
pnode.Val = child
case *fullNode:
pnode.Children[key[0]] = child
default:
panic(fmt.Sprintf("%T: invalid node: %v", pnode, pnode))
}
if len(valnode) > 0 {
return root, valnode, nil // The whole path is resolved
}
key, parent = keyrest, child
}
}
// unsetInternal removes all internal node references(hashnode, embedded node).
// It should be called after a trie is constructed with two edge paths. Also
// the given boundary keys must be the one used to construct the edge paths.
//
// It's the key step for range proof. All visited nodes should be marked dirty
// since the node content might be modified. Besides it can happen that some
// fullnodes only have one child which is disallowed. But if the proof is valid,
// the missing children will be filled, otherwise it will be thrown anyway.
//
// Note we have the assumption here the given boundary keys are different
// and right is larger than left.
func unsetInternal(n node, left []byte, right []byte) (bool, error) {
left, right = keybytesToHex(left), keybytesToHex(right)
// Step down to the fork point. There are two scenarios can happen:
// - the fork point is a shortnode: either the key of left proof or
// right proof doesn't match with shortnode's key.
// - the fork point is a fullnode: both two edge proofs are allowed
// to point to a non-existent key.
var (
pos = 0
parent node
// fork indicator, 0 means no fork, -1 means proof is less, 1 means proof is greater
shortForkLeft, shortForkRight int
)
findFork:
for {
switch rn := (n).(type) {
case *shortNode:
rn.flags = nodeFlag{dirty: true}
// If either the key of left proof or right proof doesn't match with
// shortnode, stop here and the forkpoint is the shortnode.
if len(left)-pos < len(rn.Key) {
shortForkLeft = bytes.Compare(left[pos:], rn.Key)
} else {
shortForkLeft = bytes.Compare(left[pos:pos+len(rn.Key)], rn.Key)
}
if len(right)-pos < len(rn.Key) {
shortForkRight = bytes.Compare(right[pos:], rn.Key)
} else {
shortForkRight = bytes.Compare(right[pos:pos+len(rn.Key)], rn.Key)
}
if shortForkLeft != 0 || shortForkRight != 0 {
break findFork
}
parent = n
n, pos = rn.Val, pos+len(rn.Key)
case *fullNode:
rn.flags = nodeFlag{dirty: true}
// If either the node pointed by left proof or right proof is nil,
// stop here and the forkpoint is the fullnode.
leftnode, rightnode := rn.Children[left[pos]], rn.Children[right[pos]]
if leftnode == nil || rightnode == nil || leftnode != rightnode {
break findFork
}
parent = n
n, pos = rn.Children[left[pos]], pos+1
default:
panic(fmt.Sprintf("%T: invalid node: %v", n, n))
}
}
switch rn := n.(type) {
case *shortNode:
// There can have these five scenarios:
// - both proofs are less than the trie path => no valid range
// - both proofs are greater than the trie path => no valid range
// - left proof is less and right proof is greater => valid range, unset the shortnode entirely
// - left proof points to the shortnode, but right proof is greater
// - right proof points to the shortnode, but left proof is less
if shortForkLeft == -1 && shortForkRight == -1 {
return false, errors.New("empty range")
}
if shortForkLeft == 1 && shortForkRight == 1 {
return false, errors.New("empty range")
}
if shortForkLeft != 0 && shortForkRight != 0 {
// The fork point is root node, unset the entire trie
if parent == nil {
return true, nil
}
parent.(*fullNode).Children[left[pos-1]] = nil
return false, nil
}
// Only one proof points to non-existent key.
if shortForkRight != 0 {
if _, ok := rn.Val.(valueNode); ok {
// The fork point is root node, unset the entire trie
if parent == nil {
return true, nil
}
parent.(*fullNode).Children[left[pos-1]] = nil
return false, nil
}
return false, unset(rn, rn.Val, left[pos:], len(rn.Key), false)
}
if shortForkLeft != 0 {
if _, ok := rn.Val.(valueNode); ok {
// The fork point is root node, unset the entire trie
if parent == nil {
return true, nil
}
parent.(*fullNode).Children[right[pos-1]] = nil
return false, nil
}
return false, unset(rn, rn.Val, right[pos:], len(rn.Key), true)
}
return false, nil
case *fullNode:
// unset all internal nodes in the forkpoint
for i := left[pos] + 1; i < right[pos]; i++ {
rn.Children[i] = nil
}
if err := unset(rn, rn.Children[left[pos]], left[pos:], 1, false); err != nil {
return false, err
}
if err := unset(rn, rn.Children[right[pos]], right[pos:], 1, true); err != nil {
return false, err
}
return false, nil
default:
panic(fmt.Sprintf("%T: invalid node: %v", n, n))
}
}
// unset removes all internal node references either the left most or right most.
// It can meet these scenarios:
//
// - The given path is existent in the trie, unset the associated nodes with the
// specific direction
// - The given path is non-existent in the trie
// - the fork point is a fullnode, the corresponding child pointed by path
// is nil, return
// - the fork point is a shortnode, the shortnode is included in the range,
// keep the entire branch and return.
// - the fork point is a shortnode, the shortnode is excluded in the range,
// unset the entire branch.
func unset(parent node, child node, key []byte, pos int, removeLeft bool) error {
switch cld := child.(type) {
case *fullNode:
if removeLeft {
for i := 0; i < int(key[pos]); i++ {
cld.Children[i] = nil
}
cld.flags = nodeFlag{dirty: true}
} else {
for i := key[pos] + 1; i < 16; i++ {
cld.Children[i] = nil
}
cld.flags = nodeFlag{dirty: true}
}
return unset(cld, cld.Children[key[pos]], key, pos+1, removeLeft)
case *shortNode:
if len(key[pos:]) < len(cld.Key) || !bytes.Equal(cld.Key, key[pos:pos+len(cld.Key)]) {
// Find the fork point, it's an non-existent branch.
if removeLeft {
if bytes.Compare(cld.Key, key[pos:]) < 0 {
// The key of fork shortnode is less than the path
// (it belongs to the range), unset the entire
// branch. The parent must be a fullnode.
fn := parent.(*fullNode)
fn.Children[key[pos-1]] = nil
}
//else {
// The key of fork shortnode is greater than the
// path(it doesn't belong to the range), keep
// it with the cached hash available.
//}
} else {
if bytes.Compare(cld.Key, key[pos:]) > 0 {
// The key of fork shortnode is greater than the
// path(it belongs to the range), unset the entries
// branch. The parent must be a fullnode.
fn := parent.(*fullNode)
fn.Children[key[pos-1]] = nil
}
//else {
// The key of fork shortnode is less than the
// path(it doesn't belong to the range), keep
// it with the cached hash available.
//}
}
return nil
}
if _, ok := cld.Val.(valueNode); ok {
fn := parent.(*fullNode)
fn.Children[key[pos-1]] = nil
return nil
}
cld.flags = nodeFlag{dirty: true}
return unset(cld, cld.Val, key, pos+len(cld.Key), removeLeft)
case nil:
// If the node is nil, then it's a child of the fork point
// fullnode(it's a non-existent branch).
return nil
default:
panic("it shouldn't happen") // hashNode, valueNode
}
}
// hasRightElement returns the indicator whether there exists more elements
// on the right side of the given path. The given path can point to an existent
// key or a non-existent one. This function has the assumption that the whole
// path should already be resolved.
func hasRightElement(node node, key []byte) bool {
pos, key := 0, keybytesToHex(key)
for node != nil {
switch rn := node.(type) {
case *fullNode:
for i := key[pos] + 1; i < 16; i++ {
if rn.Children[i] != nil {
return true
}
}
node, pos = rn.Children[key[pos]], pos+1
case *shortNode:
if len(key)-pos < len(rn.Key) || !bytes.Equal(rn.Key, key[pos:pos+len(rn.Key)]) {
return bytes.Compare(rn.Key, key[pos:]) > 0
}
node, pos = rn.Val, pos+len(rn.Key)
case valueNode:
return false // We have resolved the whole path
default:
panic(fmt.Sprintf("%T: invalid node: %v", node, node)) // hashnode
}
}
return false
}
// VerifyRangeProof checks whether the given leaf nodes and edge proof
// can prove the given trie leaves range is matched with the specific root.
// Besides, the range should be consecutive (no gap inside) and monotonic
// increasing.
//
// Note the given proof actually contains two edge proofs. Both of them can
// be non-existent proofs. For example the first proof is for a non-existent
// key 0x03, the last proof is for a non-existent key 0x10. The given batch
// leaves are [0x04, 0x05, .. 0x09]. It's still feasible to prove the given
// batch is valid.
//
// The firstKey is paired with firstProof, not necessarily the same as keys[0]
// (unless firstProof is an existent proof). Similarly, lastKey and lastProof
// are paired.
//
// Expect the normal case, this function can also be used to verify the following
// range proofs:
//
// - All elements proof. In this case the proof can be nil, but the range should
// be all the leaves in the trie.
//
// - One element proof. In this case no matter the edge proof is a non-existent
// proof or not, we can always verify the correctness of the proof.
//
// - Zero element proof. In this case a single non-existent proof is enough to prove.
// Besides, if there are still some other leaves available on the right side, then
// an error will be returned.
//
// Except returning the error to indicate the proof is valid or not, the function will
// also return a flag to indicate whether there exists more accounts/slots in the trie.
//
// Note: This method does not verify that the proof is of minimal form. If the input
// proofs are 'bloated' with neighbour leaves or random data, aside from the 'useful'
// data, then the proof will still be accepted.
func VerifyRangeProof(rootHash common.Hash, firstKey []byte, keys [][]byte, values [][]byte, proof ethdb.KeyValueReader) (bool, error) {
if len(keys) != len(values) {
return false, fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values))
}
// Ensure the received batch is monotonic increasing and contains no deletions
for i := 0; i < len(keys)-1; i++ {
if bytes.Compare(keys[i], keys[i+1]) >= 0 {
return false, errors.New("range is not monotonically increasing")
}
}
for _, value := range values {
if len(value) == 0 {
return false, errors.New("range contains deletion")
}
}
// Special case, there is no edge proof at all. The given range is expected
// to be the whole leaf-set in the trie.
if proof == nil {
tr := NewStackTrie(nil)
for index, key := range keys {
tr.Update(key, values[index])
}
if have, want := tr.Hash(), rootHash; have != want {
return false, fmt.Errorf("invalid proof, want hash %x, got %x", want, have)
}
return false, nil // No more elements
}
// Special case, there is a provided edge proof but zero key/value
// pairs, ensure there are no more accounts / slots in the trie.
if len(keys) == 0 {
root, val, err := proofToPath(rootHash, nil, firstKey, proof, true)
if err != nil {
return false, err
}
if val != nil || hasRightElement(root, firstKey) {
return false, errors.New("more entries available")
}
return false, nil
}
var lastKey = keys[len(keys)-1]
// Special case, there is only one element and two edge keys are same.
// In this case, we can't construct two edge paths. So handle it here.
if len(keys) == 1 && bytes.Equal(firstKey, lastKey) {
root, val, err := proofToPath(rootHash, nil, firstKey, proof, false)
if err != nil {
return false, err
}
if !bytes.Equal(firstKey, keys[0]) {
return false, errors.New("correct proof but invalid key")
}
if !bytes.Equal(val, values[0]) {
return false, errors.New("correct proof but invalid data")
}
return hasRightElement(root, firstKey), nil
}
// Ok, in all other cases, we require two edge paths available.
// First check the validity of edge keys.
if bytes.Compare(firstKey, lastKey) >= 0 {
return false, errors.New("invalid edge keys")
}
// todo(rjl493456442) different length edge keys should be supported
if len(firstKey) != len(lastKey) {
return false, errors.New("inconsistent edge keys")
}
// Convert the edge proofs to edge trie paths. Then we can
// have the same tree architecture with the original one.
// For the first edge proof, non-existent proof is allowed.
root, _, err := proofToPath(rootHash, nil, firstKey, proof, true)
if err != nil {
return false, err
}
// Pass the root node here, the second path will be merged
// with the first one. For the last edge proof, non-existent
// proof is also allowed.
root, _, err = proofToPath(rootHash, root, lastKey, proof, true)
if err != nil {
return false, err
}
// Remove all internal references. All the removed parts should
// be re-filled(or re-constructed) by the given leaves range.
empty, err := unsetInternal(root, firstKey, lastKey)
if err != nil {
return false, err
}
// Rebuild the trie with the leaf stream, the shape of trie
// should be same with the original one.
tr := &Trie{root: root, reader: newEmptyReader(), tracer: newTracer()}
if empty {
tr.root = nil
}
for index, key := range keys {
tr.Update(key, values[index])
}
if tr.Hash() != rootHash {
return false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash())
}
return hasRightElement(tr.root, keys[len(keys)-1]), nil
}
// get returns the child of the given node. Return nil if the
// node with specified key doesn't exist at all.
//
// There is an additional flag `skipResolved`. If it's set then
// all resolved nodes won't be returned.
func get(tn node, key []byte, skipResolved bool) ([]byte, node) {
for {
switch n := tn.(type) {
case *shortNode:
if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
return nil, nil
}
tn = n.Val
key = key[len(n.Key):]
if !skipResolved {
return key, tn
}
case *fullNode:
tn = n.Children[key[0]]
key = key[1:]
if !skipResolved {
return key, tn
}
case hashNode:
return key, n
case nil:
return key, nil
case valueNode:
return nil, n
default:
panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
}
}
} }

View File

@ -22,13 +22,13 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
mrand "math/rand" mrand "math/rand"
"sort"
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/ethdb/memorydb"
"golang.org/x/exp/slices"
) )
// Prng is a pseudo random number generator seeded by strong randomness. // Prng is a pseudo random number generator seeded by strong randomness.
@ -57,13 +57,13 @@ func makeProvers(trie *Trie) []func(key []byte) *memorydb.Database {
// Create a direct trie based Merkle prover // Create a direct trie based Merkle prover
provers = append(provers, func(key []byte) *memorydb.Database { provers = append(provers, func(key []byte) *memorydb.Database {
proof := memorydb.New() proof := memorydb.New()
trie.Prove(key, 0, proof) trie.Prove(key, proof)
return proof return proof
}) })
// Create a leaf iterator based Merkle prover // Create a leaf iterator based Merkle prover
provers = append(provers, func(key []byte) *memorydb.Database { provers = append(provers, func(key []byte) *memorydb.Database {
proof := memorydb.New() proof := memorydb.New()
if it := NewIterator(trie.NodeIterator(key)); it.Next() && bytes.Equal(key, it.Key) { if it := NewIterator(trie.MustNodeIterator(key)); it.Next() && bytes.Equal(key, it.Key) {
for _, p := range it.Prove() { for _, p := range it.Prove() {
proof.Put(crypto.Keccak256(p), p) proof.Put(crypto.Keccak256(p), p)
} }
@ -94,7 +94,7 @@ func TestProof(t *testing.T) {
} }
func TestOneElementProof(t *testing.T) { func TestOneElementProof(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
updateString(trie, "k", "v") updateString(trie, "k", "v")
for i, prover := range makeProvers(trie) { for i, prover := range makeProvers(trie) {
proof := prover([]byte("k")) proof := prover([]byte("k"))
@ -145,12 +145,12 @@ func TestBadProof(t *testing.T) {
// Tests that missing keys can also be proven. The test explicitly uses a single // Tests that missing keys can also be proven. The test explicitly uses a single
// entry trie and checks for missing keys both before and after the single entry. // entry trie and checks for missing keys both before and after the single entry.
func TestMissingKeyProof(t *testing.T) { func TestMissingKeyProof(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
updateString(trie, "k", "v") updateString(trie, "k", "v")
for i, key := range []string{"a", "j", "l", "z"} { for i, key := range []string{"a", "j", "l", "z"} {
proof := memorydb.New() proof := memorydb.New()
trie.Prove([]byte(key), 0, proof) trie.Prove([]byte(key), proof)
if proof.Len() != 1 { if proof.Len() != 1 {
t.Errorf("test %d: proof should have one element", i) t.Errorf("test %d: proof should have one element", i)
@ -165,30 +165,24 @@ func TestMissingKeyProof(t *testing.T) {
} }
} }
type entrySlice []*kv
func (p entrySlice) Len() int { return len(p) }
func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 }
func (p entrySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// TestRangeProof tests normal range proof with both edge proofs // TestRangeProof tests normal range proof with both edge proofs
// as the existent proof. The test cases are generated randomly. // as the existent proof. The test cases are generated randomly.
func TestRangeProof(t *testing.T) { func TestRangeProof(t *testing.T) {
trie, vals := randomTrie(4096) trie, vals := randomTrie(4096)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).cmp)
for i := 0; i < 500; i++ { for i := 0; i < 500; i++ {
start := mrand.Intn(len(entries)) start := mrand.Intn(len(entries))
end := mrand.Intn(len(entries)-start) + start + 1 end := mrand.Intn(len(entries)-start) + start + 1
proof := memorydb.New() proof := memorydb.New()
if err := trie.Prove(entries[start].k, 0, proof); err != nil { if err := trie.Prove(entries[start].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
if err := trie.Prove(entries[end-1].k, 0, proof); err != nil { if err := trie.Prove(entries[end-1].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
var keys [][]byte var keys [][]byte
@ -197,7 +191,7 @@ func TestRangeProof(t *testing.T) {
keys = append(keys, entries[i].k) keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v) vals = append(vals, entries[i].v)
} }
_, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) _, err := VerifyRangeProof(trie.Hash(), keys[0], keys, vals, proof)
if err != nil { if err != nil {
t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
} }
@ -208,11 +202,11 @@ func TestRangeProof(t *testing.T) {
// The test cases are generated randomly. // The test cases are generated randomly.
func TestRangeProofWithNonExistentProof(t *testing.T) { func TestRangeProofWithNonExistentProof(t *testing.T) {
trie, vals := randomTrie(4096) trie, vals := randomTrie(4096)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).cmp)
for i := 0; i < 500; i++ { for i := 0; i < 500; i++ {
start := mrand.Intn(len(entries)) start := mrand.Intn(len(entries))
end := mrand.Intn(len(entries)-start) + start + 1 end := mrand.Intn(len(entries)-start) + start + 1
@ -227,19 +221,10 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
if bytes.Compare(first, entries[start].k) > 0 { if bytes.Compare(first, entries[start].k) > 0 {
continue continue
} }
// Short circuit if the increased key is same with the next key if err := trie.Prove(first, proof); err != nil {
last := increaseKey(common.CopyBytes(entries[end-1].k))
if end != len(entries) && bytes.Equal(last, entries[end].k) {
continue
}
// Short circuit if the increased key is overflow
if bytes.Compare(last, entries[end-1].k) < 0 {
continue
}
if err := trie.Prove(first, 0, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
if err := trie.Prove(last, 0, proof); err != nil { if err := trie.Prove(entries[end-1].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
var keys [][]byte var keys [][]byte
@ -248,53 +233,32 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
keys = append(keys, entries[i].k) keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v) vals = append(vals, entries[i].v)
} }
_, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof) _, err := VerifyRangeProof(trie.Hash(), first, keys, vals, proof)
if err != nil { if err != nil {
t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
} }
} }
// Special case, two edge proofs for two edge key.
proof := memorydb.New()
first := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes()
if err := trie.Prove(first, 0, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
if err := trie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
var k [][]byte
var v [][]byte
for i := 0; i < len(entries); i++ {
k = append(k, entries[i].k)
v = append(v, entries[i].v)
}
_, err := VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
if err != nil {
t.Fatal("Failed to verify whole rang with non-existent edges")
}
} }
// TestRangeProofWithInvalidNonExistentProof tests such scenarios: // TestRangeProofWithInvalidNonExistentProof tests such scenarios:
// - There exists a gap between the first element and the left edge proof // - There exists a gap between the first element and the left edge proof
// - There exists a gap between the last element and the right edge proof
func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
trie, vals := randomTrie(4096) trie, vals := randomTrie(4096)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).cmp)
// Case 1 // Case 1
start, end := 100, 200 start, end := 100, 200
first := decreaseKey(common.CopyBytes(entries[start].k)) first := decreaseKey(common.CopyBytes(entries[start].k))
proof := memorydb.New() proof := memorydb.New()
if err := trie.Prove(first, 0, proof); err != nil { if err := trie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
if err := trie.Prove(entries[end-1].k, 0, proof); err != nil { if err := trie.Prove(entries[end-1].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
start = 105 // Gap created start = 105 // Gap created
@ -304,29 +268,7 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
_, err := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof) _, err := VerifyRangeProof(trie.Hash(), first, k, v, proof)
if err == nil {
t.Fatalf("Expected to detect the error, got nil")
}
// Case 2
start, end = 100, 200
last := increaseKey(common.CopyBytes(entries[end-1].k))
proof = memorydb.New()
if err := trie.Prove(entries[start].k, 0, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
if err := trie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
end = 195 // Capped slice
k = make([][]byte, 0)
v = make([][]byte, 0)
for i := start; i < end; i++ {
k = append(k, entries[i].k)
v = append(v, entries[i].v)
}
_, err = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof)
if err == nil { if err == nil {
t.Fatalf("Expected to detect the error, got nil") t.Fatalf("Expected to detect the error, got nil")
} }
@ -337,20 +279,20 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
// non-existent one. // non-existent one.
func TestOneElementRangeProof(t *testing.T) { func TestOneElementRangeProof(t *testing.T) {
trie, vals := randomTrie(4096) trie, vals := randomTrie(4096)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).cmp)
// One element with existent edge proof, both edge proofs // One element with existent edge proof, both edge proofs
// point to the SAME key. // point to the SAME key.
start := 1000 start := 1000
proof := memorydb.New() proof := memorydb.New()
if err := trie.Prove(entries[start].k, 0, proof); err != nil { if err := trie.Prove(entries[start].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
_, err := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) _, err := VerifyRangeProof(trie.Hash(), entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -359,13 +301,13 @@ func TestOneElementRangeProof(t *testing.T) {
start = 1000 start = 1000
first := decreaseKey(common.CopyBytes(entries[start].k)) first := decreaseKey(common.CopyBytes(entries[start].k))
proof = memorydb.New() proof = memorydb.New()
if err := trie.Prove(first, 0, proof); err != nil { if err := trie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
if err := trie.Prove(entries[start].k, 0, proof); err != nil { if err := trie.Prove(entries[start].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
_, err = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) _, err = VerifyRangeProof(trie.Hash(), first, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -374,13 +316,13 @@ func TestOneElementRangeProof(t *testing.T) {
start = 1000 start = 1000
last := increaseKey(common.CopyBytes(entries[start].k)) last := increaseKey(common.CopyBytes(entries[start].k))
proof = memorydb.New() proof = memorydb.New()
if err := trie.Prove(entries[start].k, 0, proof); err != nil { if err := trie.Prove(entries[start].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
if err := trie.Prove(last, 0, proof); err != nil { if err := trie.Prove(last, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
_, err = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) _, err = VerifyRangeProof(trie.Hash(), entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -389,32 +331,32 @@ func TestOneElementRangeProof(t *testing.T) {
start = 1000 start = 1000
first, last = decreaseKey(common.CopyBytes(entries[start].k)), increaseKey(common.CopyBytes(entries[start].k)) first, last = decreaseKey(common.CopyBytes(entries[start].k)), increaseKey(common.CopyBytes(entries[start].k))
proof = memorydb.New() proof = memorydb.New()
if err := trie.Prove(first, 0, proof); err != nil { if err := trie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
if err := trie.Prove(last, 0, proof); err != nil { if err := trie.Prove(last, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
_, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) _, err = VerifyRangeProof(trie.Hash(), first, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
// Test the mini trie with only a single element. // Test the mini trie with only a single element.
tinyTrie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) tinyTrie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
entry := &kv{randBytes(32), randBytes(20), false} entry := &kv{randBytes(32), randBytes(20), false}
tinyTrie.Update(entry.k, entry.v) tinyTrie.MustUpdate(entry.k, entry.v)
first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes() first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
last = entry.k last = entry.k
proof = memorydb.New() proof = memorydb.New()
if err := tinyTrie.Prove(first, 0, proof); err != nil { if err := tinyTrie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
if err := tinyTrie.Prove(last, 0, proof); err != nil { if err := tinyTrie.Prove(last, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
_, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof) _, err = VerifyRangeProof(tinyTrie.Hash(), first, [][]byte{entry.k}, [][]byte{entry.v}, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -424,11 +366,11 @@ func TestOneElementRangeProof(t *testing.T) {
// The edge proofs can be nil. // The edge proofs can be nil.
func TestAllElementsProof(t *testing.T) { func TestAllElementsProof(t *testing.T) {
trie, vals := randomTrie(4096) trie, vals := randomTrie(4096)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).cmp)
var k [][]byte var k [][]byte
var v [][]byte var v [][]byte
@ -436,20 +378,20 @@ func TestAllElementsProof(t *testing.T) {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
_, err := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil) _, err := VerifyRangeProof(trie.Hash(), nil, k, v, nil)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
// With edge proofs, it should still work. // With edge proofs, it should still work.
proof := memorydb.New() proof := memorydb.New()
if err := trie.Prove(entries[0].k, 0, proof); err != nil { if err := trie.Prove(entries[0].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
if err := trie.Prove(entries[len(entries)-1].k, 0, proof); err != nil { if err := trie.Prove(entries[len(entries)-1].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
_, err = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof) _, err = VerifyRangeProof(trie.Hash(), k[0], k, v, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -457,14 +399,13 @@ func TestAllElementsProof(t *testing.T) {
// Even with non-existent edge proofs, it should still work. // Even with non-existent edge proofs, it should still work.
proof = memorydb.New() proof = memorydb.New()
first := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes() first := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes() if err := trie.Prove(first, proof); err != nil {
if err := trie.Prove(first, 0, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
if err := trie.Prove(last, 0, proof); err != nil { if err := trie.Prove(entries[len(entries)-1].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
_, err = VerifyRangeProof(trie.Hash(), first, last, k, v, proof) _, err = VerifyRangeProof(trie.Hash(), first, k, v, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -473,22 +414,22 @@ func TestAllElementsProof(t *testing.T) {
// TestSingleSideRangeProof tests the range starts from zero. // TestSingleSideRangeProof tests the range starts from zero.
func TestSingleSideRangeProof(t *testing.T) { func TestSingleSideRangeProof(t *testing.T) {
for i := 0; i < 64; i++ { for i := 0; i < 64; i++ {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
var entries entrySlice var entries []*kv
for i := 0; i < 4096; i++ { for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false} value := &kv{randBytes(32), randBytes(20), false}
trie.Update(value.k, value.v) trie.MustUpdate(value.k, value.v)
entries = append(entries, value) entries = append(entries, value)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).cmp)
var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1} var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
for _, pos := range cases { for _, pos := range cases {
proof := memorydb.New() proof := memorydb.New()
if err := trie.Prove(common.Hash{}.Bytes(), 0, proof); err != nil { if err := trie.Prove(common.Hash{}.Bytes(), proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
if err := trie.Prove(entries[pos].k, 0, proof); err != nil { if err := trie.Prove(entries[pos].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
k := make([][]byte, 0) k := make([][]byte, 0)
@ -497,43 +438,7 @@ func TestSingleSideRangeProof(t *testing.T) {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
_, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof) _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k, v, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
}
}
}
// TestReverseSingleSideRangeProof tests the range ends with 0xffff...fff.
func TestReverseSingleSideRangeProof(t *testing.T) {
for i := 0; i < 64; i++ {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
var entries entrySlice
for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false}
trie.Update(value.k, value.v)
entries = append(entries, value)
}
sort.Sort(entries)
var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
for _, pos := range cases {
proof := memorydb.New()
if err := trie.Prove(entries[pos].k, 0, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
if err := trie.Prove(last.Bytes(), 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
k := make([][]byte, 0)
v := make([][]byte, 0)
for i := pos; i < len(entries); i++ {
k = append(k, entries[i].k)
v = append(v, entries[i].v)
}
_, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -545,20 +450,20 @@ func TestReverseSingleSideRangeProof(t *testing.T) {
// The prover is expected to detect the error. // The prover is expected to detect the error.
func TestBadRangeProof(t *testing.T) { func TestBadRangeProof(t *testing.T) {
trie, vals := randomTrie(4096) trie, vals := randomTrie(4096)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).cmp)
for i := 0; i < 500; i++ { for i := 0; i < 500; i++ {
start := mrand.Intn(len(entries)) start := mrand.Intn(len(entries))
end := mrand.Intn(len(entries)-start) + start + 1 end := mrand.Intn(len(entries)-start) + start + 1
proof := memorydb.New() proof := memorydb.New()
if err := trie.Prove(entries[start].k, 0, proof); err != nil { if err := trie.Prove(entries[start].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
if err := trie.Prove(entries[end-1].k, 0, proof); err != nil { if err := trie.Prove(entries[end-1].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
var keys [][]byte var keys [][]byte
@ -567,7 +472,7 @@ func TestBadRangeProof(t *testing.T) {
keys = append(keys, entries[i].k) keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v) vals = append(vals, entries[i].v)
} }
var first, last = keys[0], keys[len(keys)-1] var first = keys[0]
testcase := mrand.Intn(6) testcase := mrand.Intn(6)
var index int var index int
switch testcase { switch testcase {
@ -582,7 +487,7 @@ func TestBadRangeProof(t *testing.T) {
case 2: case 2:
// Gapped entry slice // Gapped entry slice
index = mrand.Intn(end - start) index = mrand.Intn(end - start)
if (index == 0 && start < 100) || (index == end-start-1 && end <= 100) { if (index == 0 && start < 100) || (index == end-start-1) {
continue continue
} }
keys = append(keys[:index], keys[index+1:]...) keys = append(keys[:index], keys[index+1:]...)
@ -605,7 +510,7 @@ func TestBadRangeProof(t *testing.T) {
index = mrand.Intn(end - start) index = mrand.Intn(end - start)
vals[index] = nil vals[index] = nil
} }
_, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof) _, err := VerifyRangeProof(trie.Hash(), first, keys, vals, proof)
if err == nil { if err == nil {
t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1) t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1)
} }
@ -615,19 +520,19 @@ func TestBadRangeProof(t *testing.T) {
// TestGappedRangeProof focuses on the small trie with embedded nodes. // TestGappedRangeProof focuses on the small trie with embedded nodes.
// If the gapped node is embedded in the trie, it should be detected too. // If the gapped node is embedded in the trie, it should be detected too.
func TestGappedRangeProof(t *testing.T) { func TestGappedRangeProof(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
var entries []*kv // Sorted entries var entries []*kv // Sorted entries
for i := byte(0); i < 10; i++ { for i := byte(0); i < 10; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
trie.Update(value.k, value.v) trie.MustUpdate(value.k, value.v)
entries = append(entries, value) entries = append(entries, value)
} }
first, last := 2, 8 first, last := 2, 8
proof := memorydb.New() proof := memorydb.New()
if err := trie.Prove(entries[first].k, 0, proof); err != nil { if err := trie.Prove(entries[first].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
if err := trie.Prove(entries[last-1].k, 0, proof); err != nil { if err := trie.Prove(entries[last-1].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
var keys [][]byte var keys [][]byte
@ -639,7 +544,7 @@ func TestGappedRangeProof(t *testing.T) {
keys = append(keys, entries[i].k) keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v) vals = append(vals, entries[i].v)
} }
_, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) _, err := VerifyRangeProof(trie.Hash(), keys[0], keys, vals, proof)
if err == nil { if err == nil {
t.Fatal("expect error, got nil") t.Fatal("expect error, got nil")
} }
@ -648,55 +553,53 @@ func TestGappedRangeProof(t *testing.T) {
// TestSameSideProofs tests the element is not in the range covered by proofs // TestSameSideProofs tests the element is not in the range covered by proofs
func TestSameSideProofs(t *testing.T) { func TestSameSideProofs(t *testing.T) {
trie, vals := randomTrie(4096) trie, vals := randomTrie(4096)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).cmp)
pos := 1000 pos := 1000
first := decreaseKey(common.CopyBytes(entries[pos].k)) first := common.CopyBytes(entries[0].k)
first = decreaseKey(first)
last := decreaseKey(common.CopyBytes(entries[pos].k))
proof := memorydb.New() proof := memorydb.New()
if err := trie.Prove(first, 0, proof); err != nil { if err := trie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
if err := trie.Prove(last, 0, proof); err != nil { if err := trie.Prove(entries[2000].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
_, err := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) _, err := VerifyRangeProof(trie.Hash(), first, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
if err == nil { if err == nil {
t.Fatalf("Expected error, got nil") t.Fatalf("Expected error, got nil")
} }
first = increaseKey(common.CopyBytes(entries[pos].k)) first = increaseKey(common.CopyBytes(entries[pos].k))
last = increaseKey(common.CopyBytes(entries[pos].k)) last := increaseKey(common.CopyBytes(entries[pos].k))
last = increaseKey(last) last = increaseKey(last)
proof = memorydb.New() proof = memorydb.New()
if err := trie.Prove(first, 0, proof); err != nil { if err := trie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
if err := trie.Prove(last, 0, proof); err != nil { if err := trie.Prove(last, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
_, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) _, err = VerifyRangeProof(trie.Hash(), first, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
if err == nil { if err == nil {
t.Fatalf("Expected error, got nil") t.Fatalf("Expected error, got nil")
} }
} }
func TestHasRightElement(t *testing.T) { func TestHasRightElement(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
var entries entrySlice var entries []*kv
for i := 0; i < 4096; i++ { for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false} value := &kv{randBytes(32), randBytes(20), false}
trie.Update(value.k, value.v) trie.MustUpdate(value.k, value.v)
entries = append(entries, value) entries = append(entries, value)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).cmp)
var cases = []struct { var cases = []struct {
start int start int
@ -709,48 +612,37 @@ func TestHasRightElement(t *testing.T) {
{50, 100, true}, {50, 100, true},
{50, len(entries), false}, // No more element expected {50, len(entries), false}, // No more element expected
{len(entries) - 1, len(entries), false}, // Single last element with two existent proofs(point to same key) {len(entries) - 1, len(entries), false}, // Single last element with two existent proofs(point to same key)
{len(entries) - 1, -1, false}, // Single last element with non-existent right proof
{0, len(entries), false}, // The whole set with existent left proof {0, len(entries), false}, // The whole set with existent left proof
{-1, len(entries), false}, // The whole set with non-existent left proof {-1, len(entries), false}, // The whole set with non-existent left proof
{-1, -1, false}, // The whole set with non-existent left/right proof
} }
for _, c := range cases { for _, c := range cases {
var ( var (
firstKey []byte firstKey []byte
lastKey []byte
start = c.start start = c.start
end = c.end end = c.end
proof = memorydb.New() proof = memorydb.New()
) )
if c.start == -1 { if c.start == -1 {
firstKey, start = common.Hash{}.Bytes(), 0 firstKey, start = common.Hash{}.Bytes(), 0
if err := trie.Prove(firstKey, 0, proof); err != nil { if err := trie.Prove(firstKey, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
} else { } else {
firstKey = entries[c.start].k firstKey = entries[c.start].k
if err := trie.Prove(entries[c.start].k, 0, proof); err != nil { if err := trie.Prove(entries[c.start].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
} }
if c.end == -1 { if err := trie.Prove(entries[c.end-1].k, proof); err != nil {
lastKey, end = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes(), len(entries)
if err := trie.Prove(lastKey, 0, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
} else {
lastKey = entries[c.end-1].k
if err := trie.Prove(entries[c.end-1].k, 0, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
}
k := make([][]byte, 0) k := make([][]byte, 0)
v := make([][]byte, 0) v := make([][]byte, 0)
for i := start; i < end; i++ { for i := start; i < end; i++ {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof) hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, k, v, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -764,11 +656,11 @@ func TestHasRightElement(t *testing.T) {
// The first edge proof must be a non-existent proof. // The first edge proof must be a non-existent proof.
func TestEmptyRangeProof(t *testing.T) { func TestEmptyRangeProof(t *testing.T) {
trie, vals := randomTrie(4096) trie, vals := randomTrie(4096)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).cmp)
var cases = []struct { var cases = []struct {
pos int pos int
@ -780,10 +672,10 @@ func TestEmptyRangeProof(t *testing.T) {
for _, c := range cases { for _, c := range cases {
proof := memorydb.New() proof := memorydb.New()
first := increaseKey(common.CopyBytes(entries[c.pos].k)) first := increaseKey(common.CopyBytes(entries[c.pos].k))
if err := trie.Prove(first, 0, proof); err != nil { if err := trie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
_, err := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof) _, err := VerifyRangeProof(trie.Hash(), first, nil, nil, proof)
if c.err && err == nil { if c.err && err == nil {
t.Fatalf("Expected error, got nil") t.Fatalf("Expected error, got nil")
} }
@ -799,11 +691,11 @@ func TestEmptyRangeProof(t *testing.T) {
func TestBloatedProof(t *testing.T) { func TestBloatedProof(t *testing.T) {
// Use a small trie // Use a small trie
trie, kvs := nonRandomTrie(100) trie, kvs := nonRandomTrie(100)
var entries entrySlice var entries []*kv
for _, kv := range kvs { for _, kv := range kvs {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).cmp)
var keys [][]byte var keys [][]byte
var vals [][]byte var vals [][]byte
@ -811,7 +703,7 @@ func TestBloatedProof(t *testing.T) {
// In the 'malicious' case, we add proofs for every single item // In the 'malicious' case, we add proofs for every single item
// (but only one key/value pair used as leaf) // (but only one key/value pair used as leaf)
for i, entry := range entries { for i, entry := range entries {
trie.Prove(entry.k, 0, proof) trie.Prove(entry.k, proof)
if i == 50 { if i == 50 {
keys = append(keys, entry.k) keys = append(keys, entry.k)
vals = append(vals, entry.v) vals = append(vals, entry.v)
@ -820,10 +712,10 @@ func TestBloatedProof(t *testing.T) {
// For reference, we use the same function, but _only_ prove the first // For reference, we use the same function, but _only_ prove the first
// and last element // and last element
want := memorydb.New() want := memorydb.New()
trie.Prove(keys[0], 0, want) trie.Prove(keys[0], want)
trie.Prove(keys[len(keys)-1], 0, want) trie.Prove(keys[len(keys)-1], want)
if _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof); err != nil { if _, err := VerifyRangeProof(trie.Hash(), keys[0], keys, vals, proof); err != nil {
t.Fatalf("expected bloated proof to succeed, got %v", err) t.Fatalf("expected bloated proof to succeed, got %v", err)
} }
} }
@ -833,11 +725,11 @@ func TestBloatedProof(t *testing.T) {
// noop technically, but practically should be rejected. // noop technically, but practically should be rejected.
func TestEmptyValueRangeProof(t *testing.T) { func TestEmptyValueRangeProof(t *testing.T) {
trie, values := randomTrie(512) trie, values := randomTrie(512)
var entries entrySlice var entries []*kv
for _, kv := range values { for _, kv := range values {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).cmp)
// Create a new entry with a slightly modified key // Create a new entry with a slightly modified key
mid := len(entries) / 2 mid := len(entries) / 2
@ -854,10 +746,10 @@ func TestEmptyValueRangeProof(t *testing.T) {
start, end := 1, len(entries)-1 start, end := 1, len(entries)-1
proof := memorydb.New() proof := memorydb.New()
if err := trie.Prove(entries[start].k, 0, proof); err != nil { if err := trie.Prove(entries[start].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
if err := trie.Prove(entries[end-1].k, 0, proof); err != nil { if err := trie.Prove(entries[end-1].k, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
var keys [][]byte var keys [][]byte
@ -866,7 +758,7 @@ func TestEmptyValueRangeProof(t *testing.T) {
keys = append(keys, entries[i].k) keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v) vals = append(vals, entries[i].v)
} }
_, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) _, err := VerifyRangeProof(trie.Hash(), keys[0], keys, vals, proof)
if err == nil { if err == nil {
t.Fatalf("Expected failure on noop entry") t.Fatalf("Expected failure on noop entry")
} }
@ -877,11 +769,11 @@ func TestEmptyValueRangeProof(t *testing.T) {
// practically should be rejected. // practically should be rejected.
func TestAllElementsEmptyValueRangeProof(t *testing.T) { func TestAllElementsEmptyValueRangeProof(t *testing.T) {
trie, values := randomTrie(512) trie, values := randomTrie(512)
var entries entrySlice var entries []*kv
for _, kv := range values { for _, kv := range values {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).cmp)
// Create a new entry with a slightly modified key // Create a new entry with a slightly modified key
mid := len(entries) / 2 mid := len(entries) / 2
@ -901,7 +793,7 @@ func TestAllElementsEmptyValueRangeProof(t *testing.T) {
keys = append(keys, entries[i].k) keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v) vals = append(vals, entries[i].v)
} }
_, err := VerifyRangeProof(trie.Hash(), nil, nil, keys, vals, nil) _, err := VerifyRangeProof(trie.Hash(), nil, keys, vals, nil)
if err == nil { if err == nil {
t.Fatalf("Expected failure on noop entry") t.Fatalf("Expected failure on noop entry")
} }
@ -949,7 +841,7 @@ func BenchmarkProve(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
kv := vals[keys[i%len(keys)]] kv := vals[keys[i%len(keys)]]
proofs := memorydb.New() proofs := memorydb.New()
if trie.Prove(kv.k, 0, proofs); proofs.Len() == 0 { if trie.Prove(kv.k, proofs); proofs.Len() == 0 {
b.Fatalf("zero length proof for %x", kv.k) b.Fatalf("zero length proof for %x", kv.k)
} }
} }
@ -963,7 +855,7 @@ func BenchmarkVerifyProof(b *testing.B) {
for k := range vals { for k := range vals {
keys = append(keys, k) keys = append(keys, k)
proof := memorydb.New() proof := memorydb.New()
trie.Prove([]byte(k), 0, proof) trie.Prove([]byte(k), proof)
proofs = append(proofs, proof) proofs = append(proofs, proof)
} }
@ -983,19 +875,19 @@ func BenchmarkVerifyRangeProof5000(b *testing.B) { benchmarkVerifyRangeProof(b,
func benchmarkVerifyRangeProof(b *testing.B, size int) { func benchmarkVerifyRangeProof(b *testing.B, size int) {
trie, vals := randomTrie(8192) trie, vals := randomTrie(8192)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).cmp)
start := 2 start := 2
end := start + size end := start + size
proof := memorydb.New() proof := memorydb.New()
if err := trie.Prove(entries[start].k, 0, proof); err != nil { if err := trie.Prove(entries[start].k, proof); err != nil {
b.Fatalf("Failed to prove the first node %v", err) b.Fatalf("Failed to prove the first node %v", err)
} }
if err := trie.Prove(entries[end-1].k, 0, proof); err != nil { if err := trie.Prove(entries[end-1].k, proof); err != nil {
b.Fatalf("Failed to prove the last node %v", err) b.Fatalf("Failed to prove the last node %v", err)
} }
var keys [][]byte var keys [][]byte
@ -1007,7 +899,7 @@ func benchmarkVerifyRangeProof(b *testing.B, size int) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof) _, err := VerifyRangeProof(trie.Hash(), keys[0], keys, values, proof)
if err != nil { if err != nil {
b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
} }
@ -1020,11 +912,11 @@ func BenchmarkVerifyRangeNoProof1000(b *testing.B) { benchmarkVerifyRangeNoProof
func benchmarkVerifyRangeNoProof(b *testing.B, size int) { func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
trie, vals := randomTrie(size) trie, vals := randomTrie(size)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).cmp)
var keys [][]byte var keys [][]byte
var values [][]byte var values [][]byte
@ -1034,7 +926,7 @@ func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, nil) _, err := VerifyRangeProof(trie.Hash(), keys[0], keys, values, nil)
if err != nil { if err != nil {
b.Fatalf("Expected no error, got %v", err) b.Fatalf("Expected no error, got %v", err)
} }
@ -1042,26 +934,26 @@ func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
} }
func randomTrie(n int) (*Trie, map[string]*kv) { func randomTrie(n int) (*Trie, map[string]*kv) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
vals := make(map[string]*kv) vals := make(map[string]*kv)
for i := byte(0); i < 100; i++ { for i := byte(0); i < 100; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false} value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false}
trie.Update(value.k, value.v) trie.MustUpdate(value.k, value.v)
trie.Update(value2.k, value2.v) trie.MustUpdate(value2.k, value2.v)
vals[string(value.k)] = value vals[string(value.k)] = value
vals[string(value2.k)] = value2 vals[string(value2.k)] = value2
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
value := &kv{randBytes(32), randBytes(20), false} value := &kv{randBytes(32), randBytes(20), false}
trie.Update(value.k, value.v) trie.MustUpdate(value.k, value.v)
vals[string(value.k)] = value vals[string(value.k)] = value
} }
return trie, vals return trie, vals
} }
func nonRandomTrie(n int) (*Trie, map[string]*kv) { func nonRandomTrie(n int) (*Trie, map[string]*kv) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
vals := make(map[string]*kv) vals := make(map[string]*kv)
max := uint64(0xffffffffffffffff) max := uint64(0xffffffffffffffff)
for i := uint64(0); i < uint64(n); i++ { for i := uint64(0); i < uint64(n); i++ {
@ -1071,7 +963,7 @@ func nonRandomTrie(n int) (*Trie, map[string]*kv) {
binary.LittleEndian.PutUint64(value, i-max) binary.LittleEndian.PutUint64(value, i-max)
//value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} //value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
elem := &kv{key, value, false} elem := &kv{key, value, false}
trie.Update(elem.k, elem.v) trie.MustUpdate(elem.k, elem.v)
vals[string(elem.k)] = elem vals[string(elem.k)] = elem
} }
return trie, vals return trie, vals
@ -1086,22 +978,21 @@ func TestRangeProofKeysWithSharedPrefix(t *testing.T) {
common.Hex2Bytes("02"), common.Hex2Bytes("02"),
common.Hex2Bytes("03"), common.Hex2Bytes("03"),
} }
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
for i, key := range keys { for i, key := range keys {
trie.Update(key, vals[i]) trie.MustUpdate(key, vals[i])
} }
root := trie.Hash() root := trie.Hash()
proof := memorydb.New() proof := memorydb.New()
start := common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000") start := common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000")
end := common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") if err := trie.Prove(start, proof); err != nil {
if err := trie.Prove(start, 0, proof); err != nil {
t.Fatalf("failed to prove start: %v", err) t.Fatalf("failed to prove start: %v", err)
} }
if err := trie.Prove(end, 0, proof); err != nil { if err := trie.Prove(keys[len(keys)-1], proof); err != nil {
t.Fatalf("failed to prove end: %v", err) t.Fatalf("failed to prove end: %v", err)
} }
more, err := VerifyRangeProof(root, start, end, keys, vals, proof) more, err := VerifyRangeProof(root, start, keys, vals, proof)
if err != nil { if err != nil {
t.Fatalf("failed to verify range proof: %v", err) t.Fatalf("failed to verify range proof: %v", err)
} }

View File

@ -20,9 +20,26 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
log "github.com/sirupsen/logrus"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/triedb/database"
) )
// SecureTrie is the old name of StateTrie.
// Deprecated: use StateTrie.
type SecureTrie = StateTrie
// NewSecure creates a new StateTrie.
// Deprecated: use NewStateTrie.
func NewSecure(stateRoot common.Hash, owner common.Hash, root common.Hash, db database.Database) (*SecureTrie, error) {
id := &ID{
StateRoot: stateRoot,
Owner: owner,
Root: root,
}
return NewStateTrie(id, db)
}
// StateTrie wraps a trie with key hashing. In a stateTrie trie, all // StateTrie wraps a trie with key hashing. In a stateTrie trie, all
// access operations hash the key using keccak256. This prevents // access operations hash the key using keccak256. This prevents
// calling code from creating long chains of nodes that // calling code from creating long chains of nodes that
@ -35,7 +52,7 @@ import (
// StateTrie is not safe for concurrent use. // StateTrie is not safe for concurrent use.
type StateTrie struct { type StateTrie struct {
trie Trie trie Trie
preimages *preimageStore db database.Database
hashKeyBuf [common.HashLength]byte hashKeyBuf [common.HashLength]byte
secKeyCache map[string][]byte secKeyCache map[string][]byte
secKeyCacheOwner *StateTrie // Pointer to self, replace the key cache on mismatch secKeyCacheOwner *StateTrie // Pointer to self, replace the key cache on mismatch
@ -46,41 +63,44 @@ type StateTrie struct {
// If root is the zero hash or the sha3 hash of an empty string, the // If root is the zero hash or the sha3 hash of an empty string, the
// trie is initially empty. Otherwise, New will panic if db is nil // trie is initially empty. Otherwise, New will panic if db is nil
// and returns MissingNodeError if the root node cannot be found. // and returns MissingNodeError if the root node cannot be found.
func NewStateTrie(id *ID, db *Database, codec uint64) (*StateTrie, error) { func NewStateTrie(id *ID, db database.Database) (*StateTrie, error) {
// TODO: codec can be derived based on whether Owner is the zero hash
if db == nil { if db == nil {
panic("trie.NewStateTrie called without a database") panic("trie.NewStateTrie called without a database")
} }
trie, err := New(id, db, codec) trie, err := New(id, db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &StateTrie{trie: *trie, preimages: db.preimages}, nil return &StateTrie{trie: *trie, db: db}, nil
} }
// Get returns the value for key stored in the trie. // MustGet returns the value for key stored in the trie.
// The value bytes must not be modified by the caller. // The value bytes must not be modified by the caller.
func (t *StateTrie) Get(key []byte) []byte { //
res, err := t.TryGet(key) // This function will omit any encountered error but just
if err != nil { // print out an error message.
log.Error("Unhandled trie error in StateTrie.Get", "err", err) func (t *StateTrie) MustGet(key []byte) []byte {
} return t.trie.MustGet(t.hashKey(key))
return res
} }
// TryGet returns the value for key stored in the trie. // GetStorage attempts to retrieve a storage slot with provided account address
// The value bytes must not be modified by the caller. // and slot key. The value bytes must not be modified by the caller.
// If the specified node is not in the trie, nil will be returned. // If the specified storage slot is not in the trie, nil will be returned.
// If a trie node is not found in the database, a MissingNodeError is returned. // If a trie node is not found in the database, a MissingNodeError is returned.
func (t *StateTrie) TryGet(key []byte) ([]byte, error) { func (t *StateTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) {
return t.trie.TryGet(t.hashKey(key)) enc, err := t.trie.Get(t.hashKey(key))
if err != nil || len(enc) == 0 {
return nil, err
}
_, content, _, err := rlp.Split(enc)
return content, err
} }
// TryGetAccount attempts to retrieve an account with provided account address. // GetAccount attempts to retrieve an account with provided account address.
// If the specified account is not in the trie, nil will be returned. // If the specified account is not in the trie, nil will be returned.
// If a trie node is not found in the database, a MissingNodeError is returned. // If a trie node is not found in the database, a MissingNodeError is returned.
func (t *StateTrie) TryGetAccount(address common.Address) (*types.StateAccount, error) { func (t *StateTrie) GetAccount(address common.Address) (*types.StateAccount, error) {
res, err := t.trie.TryGet(t.hashKey(address.Bytes())) res, err := t.trie.Get(t.hashKey(address.Bytes()))
if res == nil || err != nil { if res == nil || err != nil {
return nil, err return nil, err
} }
@ -89,11 +109,11 @@ func (t *StateTrie) TryGetAccount(address common.Address) (*types.StateAccount,
return ret, err return ret, err
} }
// TryGetAccountByHash does the same thing as TryGetAccount, however // GetAccountByHash does the same thing as GetAccount, however it expects an
// it expects an account hash that is the hash of address. This constitutes an // account hash that is the hash of address. This constitutes an abstraction
// abstraction leak, since the client code needs to know the key format. // leak, since the client code needs to know the key format.
func (t *StateTrie) TryGetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) { func (t *StateTrie) GetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) {
res, err := t.trie.TryGet(addrHash.Bytes()) res, err := t.trie.Get(addrHash.Bytes())
if res == nil || err != nil { if res == nil || err != nil {
return nil, err return nil, err
} }
@ -102,27 +122,30 @@ func (t *StateTrie) TryGetAccountByHash(addrHash common.Hash) (*types.StateAccou
return ret, err return ret, err
} }
// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not // GetNode attempts to retrieve a trie node by compact-encoded path. It is not
// possible to use keybyte-encoding as the path might contain odd nibbles. // possible to use keybyte-encoding as the path might contain odd nibbles.
// If the specified trie node is not in the trie, nil will be returned. // If the specified trie node is not in the trie, nil will be returned.
// If a trie node is not found in the database, a MissingNodeError is returned. // If a trie node is not found in the database, a MissingNodeError is returned.
func (t *StateTrie) TryGetNode(path []byte) ([]byte, int, error) { func (t *StateTrie) GetNode(path []byte) ([]byte, int, error) {
return t.trie.TryGetNode(path) return t.trie.GetNode(path)
} }
// Update associates key with value in the trie. Subsequent calls to // MustUpdate associates key with value in the trie. Subsequent calls to
// Get will return value. If value has length zero, any existing value // Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil. // is deleted from the trie and calls to Get will return nil.
// //
// The value bytes must not be modified by the caller while they are // The value bytes must not be modified by the caller while they are
// stored in the trie. // stored in the trie.
func (t *StateTrie) Update(key, value []byte) { //
if err := t.TryUpdate(key, value); err != nil { // This function will omit any encountered error but just print out an
log.Error("Unhandled trie error in StateTrie.Update", "err", err) // error message.
} func (t *StateTrie) MustUpdate(key, value []byte) {
hk := t.hashKey(key)
t.trie.MustUpdate(hk, value)
t.getSecKeyCache()[string(hk)] = common.CopyBytes(key)
} }
// TryUpdate associates key with value in the trie. Subsequent calls to // UpdateStorage associates key with value in the trie. Subsequent calls to
// Get will return value. If value has length zero, any existing value // Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil. // is deleted from the trie and calls to Get will return nil.
// //
@ -130,9 +153,10 @@ func (t *StateTrie) Update(key, value []byte) {
// stored in the trie. // stored in the trie.
// //
// If a node is not found in the database, a MissingNodeError is returned. // If a node is not found in the database, a MissingNodeError is returned.
func (t *StateTrie) TryUpdate(key, value []byte) error { func (t *StateTrie) UpdateStorage(_ common.Address, key, value []byte) error {
hk := t.hashKey(key) hk := t.hashKey(key)
err := t.trie.TryUpdate(hk, value) v, _ := rlp.EncodeToBytes(value)
err := t.trie.Update(hk, v)
if err != nil { if err != nil {
return err return err
} }
@ -140,42 +164,46 @@ func (t *StateTrie) TryUpdate(key, value []byte) error {
return nil return nil
} }
// TryUpdateAccount account will abstract the write of an account to the // UpdateAccount will abstract the write of an account to the secure trie.
// secure trie. func (t *StateTrie) UpdateAccount(address common.Address, acc *types.StateAccount) error {
func (t *StateTrie) TryUpdateAccount(address common.Address, acc *types.StateAccount) error {
hk := t.hashKey(address.Bytes()) hk := t.hashKey(address.Bytes())
data, err := rlp.EncodeToBytes(acc) data, err := rlp.EncodeToBytes(acc)
if err != nil { if err != nil {
return err return err
} }
if err := t.trie.TryUpdate(hk, data); err != nil { if err := t.trie.Update(hk, data); err != nil {
return err return err
} }
t.getSecKeyCache()[string(hk)] = address.Bytes() t.getSecKeyCache()[string(hk)] = address.Bytes()
return nil return nil
} }
// Delete removes any existing value for key from the trie. func (t *StateTrie) UpdateContractCode(_ common.Address, _ common.Hash, _ []byte) error {
func (t *StateTrie) Delete(key []byte) { return nil
if err := t.TryDelete(key); err != nil {
log.Error("Unhandled trie error in StateTrie.Delete", "err", err)
}
} }
// TryDelete removes any existing value for key from the trie. // MustDelete removes any existing value for key from the trie. This function
// If the specified trie node is not in the trie, nothing will be changed. // will omit any encountered error but just print out an error message.
// If a node is not found in the database, a MissingNodeError is returned. func (t *StateTrie) MustDelete(key []byte) {
func (t *StateTrie) TryDelete(key []byte) error {
hk := t.hashKey(key) hk := t.hashKey(key)
delete(t.getSecKeyCache(), string(hk)) delete(t.getSecKeyCache(), string(hk))
return t.trie.TryDelete(hk) t.trie.MustDelete(hk)
} }
// TryDeleteAccount abstracts an account deletion from the trie. // DeleteStorage removes any existing storage slot from the trie.
func (t *StateTrie) TryDeleteAccount(address common.Address) error { // If the specified trie node is not in the trie, nothing will be changed.
// If a node is not found in the database, a MissingNodeError is returned.
func (t *StateTrie) DeleteStorage(_ common.Address, key []byte) error {
hk := t.hashKey(key)
delete(t.getSecKeyCache(), string(hk))
return t.trie.Delete(hk)
}
// DeleteAccount abstracts an account deletion from the trie.
func (t *StateTrie) DeleteAccount(address common.Address) error {
hk := t.hashKey(address.Bytes()) hk := t.hashKey(address.Bytes())
delete(t.getSecKeyCache(), string(hk)) delete(t.getSecKeyCache(), string(hk))
return t.trie.TryDelete(hk) return t.trie.Delete(hk)
} }
// GetKey returns the sha3 preimage of a hashed key that was // GetKey returns the sha3 preimage of a hashed key that was
@ -184,10 +212,7 @@ func (t *StateTrie) GetKey(shaKey []byte) []byte {
if key, ok := t.getSecKeyCache()[string(shaKey)]; ok { if key, ok := t.getSecKeyCache()[string(shaKey)]; ok {
return key return key
} }
if t.preimages == nil { return t.db.Preimage(common.BytesToHash(shaKey))
return nil
}
return t.preimages.preimage(common.BytesToHash(shaKey))
} }
// Commit collects all dirty nodes in the trie and replaces them with the // Commit collects all dirty nodes in the trie and replaces them with the
@ -197,16 +222,14 @@ func (t *StateTrie) GetKey(shaKey []byte) []byte {
// All cached preimages will be also flushed if preimages recording is enabled. // All cached preimages will be also flushed if preimages recording is enabled.
// Once the trie is committed, it's not usable anymore. A new trie must // Once the trie is committed, it's not usable anymore. A new trie must
// be created with new root and updated trie database for following usage // be created with new root and updated trie database for following usage
func (t *StateTrie) Commit(collectLeaf bool) (common.Hash, *NodeSet) { func (t *StateTrie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, error) {
// Write all the pre-images to the actual disk database // Write all the pre-images to the actual disk database
if len(t.getSecKeyCache()) > 0 { if len(t.getSecKeyCache()) > 0 {
if t.preimages != nil {
preimages := make(map[common.Hash][]byte) preimages := make(map[common.Hash][]byte)
for hk, key := range t.secKeyCache { for hk, key := range t.secKeyCache {
preimages[common.BytesToHash([]byte(hk))] = key preimages[common.BytesToHash([]byte(hk))] = key
} }
t.preimages.insertPreimage(preimages) t.db.InsertPreimage(preimages)
}
t.secKeyCache = make(map[string][]byte) t.secKeyCache = make(map[string][]byte)
} }
// Commit the trie and return its modified nodeset. // Commit the trie and return its modified nodeset.
@ -223,17 +246,23 @@ func (t *StateTrie) Hash() common.Hash {
func (t *StateTrie) Copy() *StateTrie { func (t *StateTrie) Copy() *StateTrie {
return &StateTrie{ return &StateTrie{
trie: *t.trie.Copy(), trie: *t.trie.Copy(),
preimages: t.preimages, db: t.db,
secKeyCache: t.secKeyCache, secKeyCache: t.secKeyCache,
} }
} }
// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration // NodeIterator returns an iterator that returns nodes of the underlying trie.
// starts at the key after the given start key. // Iteration starts at the key after the given start key.
func (t *StateTrie) NodeIterator(start []byte) NodeIterator { func (t *StateTrie) NodeIterator(start []byte) (NodeIterator, error) {
return t.trie.NodeIterator(start) return t.trie.NodeIterator(start)
} }
// MustNodeIterator is a wrapper of NodeIterator and will omit any encountered
// error but just print out an error message.
func (t *StateTrie) MustNodeIterator(start []byte) NodeIterator {
return t.trie.MustNodeIterator(start)
}
// hashKey returns the hash of key as an ephemeral buffer. // hashKey returns the hash of key as an ephemeral buffer.
// The caller must not hold onto the return value because it will become // The caller must not hold onto the return value because it will become
// invalid on the next call to hashKey or secKey. // invalid on the next call to hashKey or secKey.

View File

@ -25,14 +25,22 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
) )
func newEmptySecure() *StateTrie {
trie, _ := NewStateTrie(TrieID(types.EmptyRootHash), newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
return trie
}
// makeTestStateTrie creates a large enough secure trie for testing. // makeTestStateTrie creates a large enough secure trie for testing.
func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) { func makeTestStateTrie() (*testDb, *StateTrie, map[string][]byte) {
// Create an empty trie // Create an empty trie
triedb := NewDatabase(rawdb.NewMemoryDatabase()) triedb := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec) trie, _ := NewStateTrie(TrieID(types.EmptyRootHash), triedb)
// Fill it with some arbitrary data // Fill it with some arbitrary data
content := make(map[string][]byte) content := make(map[string][]byte)
@ -40,33 +48,30 @@ func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) {
// Map the same data under multiple keys // Map the same data under multiple keys
key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i} key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i}
content[string(key)] = val content[string(key)] = val
trie.Update(key, val) trie.MustUpdate(key, val)
key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i} key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i}
content[string(key)] = val content[string(key)] = val
trie.Update(key, val) trie.MustUpdate(key, val)
// Add some other data to inflate the trie // Add some other data to inflate the trie
for j := byte(3); j < 13; j++ { for j := byte(3); j < 13; j++ {
key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i} key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i}
content[string(key)] = val content[string(key)] = val
trie.Update(key, val) trie.MustUpdate(key, val)
} }
} }
root, nodes := trie.Commit(false) root, nodes, _ := trie.Commit(false)
if err := triedb.Update(NewWithNodeSet(nodes)); err != nil { if err := triedb.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes)); err != nil {
panic(fmt.Errorf("failed to commit db %v", err)) panic(fmt.Errorf("failed to commit db %v", err))
} }
// Re-create the trie based on the new state // Re-create the trie based on the new state
trie, _ = NewStateTrie(TrieID(root), triedb, StateTrieCodec) trie, _ = NewStateTrie(TrieID(root), triedb)
return triedb, trie, content return triedb, trie, content
} }
func TestSecureDelete(t *testing.T) { func TestSecureDelete(t *testing.T) {
trie, err := NewStateTrie(TrieID(common.Hash{}), NewDatabase(rawdb.NewMemoryDatabase()), StateTrieCodec) trie := newEmptySecure()
if err != nil {
t.Fatal(err)
}
vals := []struct{ k, v string }{ vals := []struct{ k, v string }{
{"do", "verb"}, {"do", "verb"},
{"ether", "wookiedoo"}, {"ether", "wookiedoo"},
@ -79,9 +84,9 @@ func TestSecureDelete(t *testing.T) {
} }
for _, val := range vals { for _, val := range vals {
if val.v != "" { if val.v != "" {
trie.Update([]byte(val.k), []byte(val.v)) trie.MustUpdate([]byte(val.k), []byte(val.v))
} else { } else {
trie.Delete([]byte(val.k)) trie.MustDelete([]byte(val.k))
} }
} }
hash := trie.Hash() hash := trie.Hash()
@ -92,17 +97,14 @@ func TestSecureDelete(t *testing.T) {
} }
func TestSecureGetKey(t *testing.T) { func TestSecureGetKey(t *testing.T) {
trie, err := NewStateTrie(TrieID(common.Hash{}), NewDatabase(rawdb.NewMemoryDatabase()), StateTrieCodec) trie := newEmptySecure()
if err != nil { trie.MustUpdate([]byte("foo"), []byte("bar"))
t.Fatal(err)
}
trie.Update([]byte("foo"), []byte("bar"))
key := []byte("foo") key := []byte("foo")
value := []byte("bar") value := []byte("bar")
seckey := crypto.Keccak256(key) seckey := crypto.Keccak256(key)
if !bytes.Equal(trie.Get(key), value) { if !bytes.Equal(trie.MustGet(key), value) {
t.Errorf("Get did not return bar") t.Errorf("Get did not return bar")
} }
if k := trie.GetKey(seckey); !bytes.Equal(k, key) { if k := trie.GetKey(seckey); !bytes.Equal(k, key) {
@ -129,15 +131,15 @@ func TestStateTrieConcurrency(t *testing.T) {
for j := byte(0); j < 255; j++ { for j := byte(0); j < 255; j++ {
// Map the same data under multiple keys // Map the same data under multiple keys
key, val := common.LeftPadBytes([]byte{byte(index), 1, j}, 32), []byte{j} key, val := common.LeftPadBytes([]byte{byte(index), 1, j}, 32), []byte{j}
tries[index].Update(key, val) tries[index].MustUpdate(key, val)
key, val = common.LeftPadBytes([]byte{byte(index), 2, j}, 32), []byte{j} key, val = common.LeftPadBytes([]byte{byte(index), 2, j}, 32), []byte{j}
tries[index].Update(key, val) tries[index].MustUpdate(key, val)
// Add some other data to inflate the trie // Add some other data to inflate the trie
for k := byte(3); k < 13; k++ { for k := byte(3); k < 13; k++ {
key, val = common.LeftPadBytes([]byte{byte(index), k, j}, 32), []byte{k, j} key, val = common.LeftPadBytes([]byte{byte(index), k, j}, 32), []byte{k, j}
tries[index].Update(key, val) tries[index].MustUpdate(key, val)
} }
} }
tries[index].Commit(false) tries[index].Commit(false)

View File

@ -16,7 +16,9 @@
package trie package trie
import "github.com/ethereum/go-ethereum/common" import (
"github.com/ethereum/go-ethereum/common"
)
// tracer tracks the changes of trie nodes. During the trie operations, // tracer tracks the changes of trie nodes. During the trie operations,
// some nodes can be deleted from the trie, while these deleted nodes // some nodes can be deleted from the trie, while these deleted nodes
@ -111,15 +113,18 @@ func (t *tracer) copy() *tracer {
} }
} }
// markDeletions puts all tracked deletions into the provided nodeset. // deletedNodes returns a list of node paths which are deleted from the trie.
func (t *tracer) markDeletions(set *NodeSet) { func (t *tracer) deletedNodes() []string {
var paths []string
for path := range t.deletes { for path := range t.deletes {
// It's possible a few deleted nodes were embedded // It's possible a few deleted nodes were embedded
// in their parent before, the deletions can be no // in their parent before, the deletions can be no
// effect by deleting nothing, filter them out. // effect by deleting nothing, filter them out.
if _, ok := set.accessList[path]; !ok { _, ok := t.accessList[path]
if !ok {
continue continue
} }
set.markDeleted([]byte(path)) paths = append(paths, path)
} }
return paths
} }

View File

@ -1,3 +1,19 @@
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
// Package trie implements Merkle Patricia Tries. // Package trie implements Merkle Patricia Tries.
package trie package trie
@ -8,14 +24,10 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
log "github.com/sirupsen/logrus" "github.com/ethereum/go-ethereum/log"
"github.com/cerc-io/plugeth-statediff/indexer/ipld" "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
) "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/triedb/database"
var (
StateTrieCodec uint64 = ipld.MEthStateTrie
StorageTrieCodec uint64 = ipld.MEthStorageTrie
) )
// Trie is a Merkle Patricia Trie. Use New to create a trie that sits on // Trie is a Merkle Patricia Trie. Use New to create a trie that sits on
@ -29,6 +41,10 @@ type Trie struct {
root node root node
owner common.Hash owner common.Hash
// Flag whether the commit operation is already performed. If so the
// trie is not usable(latest states is invisible).
committed bool
// Keep track of the number leaves which have been inserted since the last // Keep track of the number leaves which have been inserted since the last
// hashing operation. This number will not directly map to the number of // hashing operation. This number will not directly map to the number of
// actually unhashed nodes. // actually unhashed nodes.
@ -52,22 +68,21 @@ func (t *Trie) Copy() *Trie {
return &Trie{ return &Trie{
root: t.root, root: t.root,
owner: t.owner, owner: t.owner,
committed: t.committed,
unhashed: t.unhashed, unhashed: t.unhashed,
reader: t.reader, reader: t.reader,
tracer: t.tracer.copy(), tracer: t.tracer.copy(),
} }
} }
// New creates a trie instance with the provided trie id and the read-only // New creates the trie instance with provided trie id and the read-only
// database. The state specified by trie id must be available, otherwise // database. The state specified by trie id must be available, otherwise
// an error will be returned. The trie root specified by trie id can be // an error will be returned. The trie root specified by trie id can be
// zero hash or the sha3 hash of an empty string, then trie is initially // zero hash or the sha3 hash of an empty string, then trie is initially
// empty, otherwise, the root node must be present in database or returns // empty, otherwise, the root node must be present in database or returns
// a MissingNodeError if not. // a MissingNodeError if not.
// The passed codec specifies whether to read state or storage nodes from the func New(id *ID, db database.Database) (*Trie, error) {
// trie. reader, err := newTrieReader(id.StateRoot, id.Owner, db)
func New(id *ID, db NodeReader, codec uint64) (*Trie, error) {
reader, err := newTrieReader(id.StateRoot, id.Owner, db, codec)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -87,42 +102,59 @@ func New(id *ID, db NodeReader, codec uint64) (*Trie, error) {
} }
// NewEmpty is a shortcut to create empty tree. It's mostly used in tests. // NewEmpty is a shortcut to create empty tree. It's mostly used in tests.
func NewEmpty(db *Database) *Trie { func NewEmpty(db database.Database) *Trie {
tr, err := New(TrieID(common.Hash{}), db, StateTrieCodec) tr, _ := New(TrieID(types.EmptyRootHash), db)
if err != nil {
panic(err)
}
return tr return tr
} }
// MustNodeIterator is a wrapper of NodeIterator and will omit any encountered
// error but just print out an error message.
func (t *Trie) MustNodeIterator(start []byte) NodeIterator {
it, err := t.NodeIterator(start)
if err != nil {
log.Error("Unhandled trie error in Trie.NodeIterator", "err", err)
}
return it
}
// NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at // NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at
// the key after the given start key. // the key after the given start key.
func (t *Trie) NodeIterator(start []byte) NodeIterator { func (t *Trie) NodeIterator(start []byte) (NodeIterator, error) {
return newNodeIterator(t, start) // Short circuit if the trie is already committed and not usable.
if t.committed {
return nil, ErrCommitted
}
return newNodeIterator(t, start), nil
} }
// Get returns the value for key stored in the trie. // MustGet is a wrapper of Get and will omit any encountered error but just
// The value bytes must not be modified by the caller. // print out an error message.
func (t *Trie) Get(key []byte) []byte { func (t *Trie) MustGet(key []byte) []byte {
res, err := t.TryGet(key) res, err := t.Get(key)
if err != nil { if err != nil {
log.Error("Unhandled trie error in Trie.Get", "err", err) log.Error("Unhandled trie error in Trie.Get", "err", err)
} }
return res return res
} }
// TryGet returns the value for key stored in the trie. // Get returns the value for key stored in the trie.
// The value bytes must not be modified by the caller. // The value bytes must not be modified by the caller.
// If a node was not found in the database, a MissingNodeError is returned. //
func (t *Trie) TryGet(key []byte) ([]byte, error) { // If the requested node is not present in trie, no error will be returned.
value, newroot, didResolve, err := t.tryGet(t.root, keybytesToHex(key), 0) // If the trie is corrupted, a MissingNodeError is returned.
func (t *Trie) Get(key []byte) ([]byte, error) {
// Short circuit if the trie is already committed and not usable.
if t.committed {
return nil, ErrCommitted
}
value, newroot, didResolve, err := t.get(t.root, keybytesToHex(key), 0)
if err == nil && didResolve { if err == nil && didResolve {
t.root = newroot t.root = newroot
} }
return value, err return value, err
} }
func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode node, didResolve bool, err error) { func (t *Trie) get(origNode node, key []byte, pos int) (value []byte, newnode node, didResolve bool, err error) {
switch n := (origNode).(type) { switch n := (origNode).(type) {
case nil: case nil:
return nil, nil, false, nil return nil, nil, false, nil
@ -133,14 +165,14 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
// key not found in trie // key not found in trie
return nil, n, false, nil return nil, n, false, nil
} }
value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key)) value, newnode, didResolve, err = t.get(n.Val, key, pos+len(n.Key))
if err == nil && didResolve { if err == nil && didResolve {
n = n.copy() n = n.copy()
n.Val = newnode n.Val = newnode
} }
return value, n, didResolve, err return value, n, didResolve, err
case *fullNode: case *fullNode:
value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1) value, newnode, didResolve, err = t.get(n.Children[key[pos]], key, pos+1)
if err == nil && didResolve { if err == nil && didResolve {
n = n.copy() n = n.copy()
n.Children[key[pos]] = newnode n.Children[key[pos]] = newnode
@ -151,17 +183,34 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
if err != nil { if err != nil {
return nil, n, true, err return nil, n, true, err
} }
value, newnode, _, err := t.tryGet(child, key, pos) value, newnode, _, err := t.get(child, key, pos)
return value, newnode, true, err return value, newnode, true, err
default: default:
panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode))
} }
} }
// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not // MustGetNode is a wrapper of GetNode and will omit any encountered error but
// possible to use keybyte-encoding as the path might contain odd nibbles. // just print out an error message.
func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) { func (t *Trie) MustGetNode(path []byte) ([]byte, int) {
item, newroot, resolved, err := t.tryGetNode(t.root, compactToHex(path), 0) item, resolved, err := t.GetNode(path)
if err != nil {
log.Error("Unhandled trie error in Trie.GetNode", "err", err)
}
return item, resolved
}
// GetNode retrieves a trie node by compact-encoded path. It is not possible
// to use keybyte-encoding as the path might contain odd nibbles.
//
// If the requested node is not present in trie, no error will be returned.
// If the trie is corrupted, a MissingNodeError is returned.
func (t *Trie) GetNode(path []byte) ([]byte, int, error) {
// Short circuit if the trie is already committed and not usable.
if t.committed {
return nil, 0, ErrCommitted
}
item, newroot, resolved, err := t.getNode(t.root, compactToHex(path), 0)
if err != nil { if err != nil {
return nil, resolved, err return nil, resolved, err
} }
@ -171,10 +220,10 @@ func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) {
if item == nil { if item == nil {
return nil, resolved, nil return nil, resolved, nil
} }
return item, resolved, err return item, resolved, nil
} }
func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, newnode node, resolved int, err error) { func (t *Trie) getNode(origNode node, path []byte, pos int) (item []byte, newnode node, resolved int, err error) {
// If non-existent path requested, abort // If non-existent path requested, abort
if origNode == nil { if origNode == nil {
return nil, nil, 0, nil return nil, nil, 0, nil
@ -193,7 +242,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
if hash == nil { if hash == nil {
return nil, origNode, 0, errors.New("non-consensus node") return nil, origNode, 0, errors.New("non-consensus node")
} }
blob, err := t.reader.nodeBlob(path, common.BytesToHash(hash)) blob, err := t.reader.node(path, common.BytesToHash(hash))
return blob, origNode, 1, err return blob, origNode, 1, err
} }
// Path still needs to be traversed, descend into children // Path still needs to be traversed, descend into children
@ -207,7 +256,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
// Path branches off from short node // Path branches off from short node
return nil, n, 0, nil return nil, n, 0, nil
} }
item, newnode, resolved, err = t.tryGetNode(n.Val, path, pos+len(n.Key)) item, newnode, resolved, err = t.getNode(n.Val, path, pos+len(n.Key))
if err == nil && resolved > 0 { if err == nil && resolved > 0 {
n = n.copy() n = n.copy()
n.Val = newnode n.Val = newnode
@ -215,7 +264,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
return item, n, resolved, err return item, n, resolved, err
case *fullNode: case *fullNode:
item, newnode, resolved, err = t.tryGetNode(n.Children[path[pos]], path, pos+1) item, newnode, resolved, err = t.getNode(n.Children[path[pos]], path, pos+1)
if err == nil && resolved > 0 { if err == nil && resolved > 0 {
n = n.copy() n = n.copy()
n.Children[path[pos]] = newnode n.Children[path[pos]] = newnode
@ -227,7 +276,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
if err != nil { if err != nil {
return nil, n, 1, err return nil, n, 1, err
} }
item, newnode, resolved, err := t.tryGetNode(child, path, pos) item, newnode, resolved, err := t.getNode(child, path, pos)
return item, newnode, resolved + 1, err return item, newnode, resolved + 1, err
default: default:
@ -235,33 +284,32 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
} }
} }
// MustUpdate is a wrapper of Update and will omit any encountered error but
// just print out an error message.
func (t *Trie) MustUpdate(key, value []byte) {
if err := t.Update(key, value); err != nil {
log.Error("Unhandled trie error in Trie.Update", "err", err)
}
}
// Update associates key with value in the trie. Subsequent calls to // Update associates key with value in the trie. Subsequent calls to
// Get will return value. If value has length zero, any existing value // Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil. // is deleted from the trie and calls to Get will return nil.
// //
// The value bytes must not be modified by the caller while they are // The value bytes must not be modified by the caller while they are
// stored in the trie. // stored in the trie.
func (t *Trie) Update(key, value []byte) { //
if err := t.TryUpdate(key, value); err != nil { // If the requested node is not present in trie, no error will be returned.
log.Error("Unhandled trie error in Trie.Update", "err", err) // If the trie is corrupted, a MissingNodeError is returned.
func (t *Trie) Update(key, value []byte) error {
// Short circuit if the trie is already committed and not usable.
if t.committed {
return ErrCommitted
} }
return t.update(key, value)
} }
// TryUpdate associates key with value in the trie. Subsequent calls to func (t *Trie) update(key, value []byte) error {
// Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil.
//
// The value bytes must not be modified by the caller while they are
// stored in the trie.
//
// If a node was not found in the database, a MissingNodeError is returned.
func (t *Trie) TryUpdate(key, value []byte) error {
return t.tryUpdate(key, value)
}
// tryUpdate expects an RLP-encoded value and performs the core function
// for TryUpdate and TryUpdateAccount.
func (t *Trie) tryUpdate(key, value []byte) error {
t.unhashed++ t.unhashed++
k := keybytesToHex(key) k := keybytesToHex(key)
if len(value) != 0 { if len(value) != 0 {
@ -359,16 +407,23 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
} }
} }
// Delete removes any existing value for key from the trie. // MustDelete is a wrapper of Delete and will omit any encountered error but
func (t *Trie) Delete(key []byte) { // just print out an error message.
if err := t.TryDelete(key); err != nil { func (t *Trie) MustDelete(key []byte) {
if err := t.Delete(key); err != nil {
log.Error("Unhandled trie error in Trie.Delete", "err", err) log.Error("Unhandled trie error in Trie.Delete", "err", err)
} }
} }
// TryDelete removes any existing value for key from the trie. // Delete removes any existing value for key from the trie.
// If a node was not found in the database, a MissingNodeError is returned. //
func (t *Trie) TryDelete(key []byte) error { // If the requested node is not present in trie, no error will be returned.
// If the trie is corrupted, a MissingNodeError is returned.
func (t *Trie) Delete(key []byte) error {
// Short circuit if the trie is already committed and not usable.
if t.committed {
return ErrCommitted
}
t.unhashed++ t.unhashed++
k := keybytesToHex(key) k := keybytesToHex(key)
_, n, err := t.delete(t.root, nil, k) _, n, err := t.delete(t.root, nil, k)
@ -532,7 +587,7 @@ func (t *Trie) resolve(n node, prefix []byte) (node, error) {
// node's original value. The rlp-encoded blob is preferred to be loaded from // node's original value. The rlp-encoded blob is preferred to be loaded from
// database because it's easy to decode node while complex to encode node to blob. // database because it's easy to decode node while complex to encode node to blob.
func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) { func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) {
blob, err := t.reader.nodeBlob(prefix, common.BytesToHash(n)) blob, err := t.reader.node(prefix, common.BytesToHash(n))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -554,17 +609,25 @@ func (t *Trie) Hash() common.Hash {
// The returned nodeset can be nil if the trie is clean (nothing to commit). // The returned nodeset can be nil if the trie is clean (nothing to commit).
// Once the trie is committed, it's not usable anymore. A new trie must // Once the trie is committed, it's not usable anymore. A new trie must
// be created with new root and updated trie database for following usage // be created with new root and updated trie database for following usage
func (t *Trie) Commit(collectLeaf bool) (common.Hash, *NodeSet) { func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, error) {
defer t.tracer.reset() defer t.tracer.reset()
defer func() {
nodes := NewNodeSet(t.owner, t.tracer.accessList) t.committed = true
t.tracer.markDeletions(nodes) }()
// Trie is empty and can be classified into two types of situations: // Trie is empty and can be classified into two types of situations:
// - The trie was empty and no update happens // (a) The trie was empty and no update happens => return nil
// - The trie was non-empty and all nodes are dropped // (b) The trie was non-empty and all nodes are dropped => return
// the node set includes all deleted nodes
if t.root == nil { if t.root == nil {
return types.EmptyRootHash, nodes paths := t.tracer.deletedNodes()
if len(paths) == 0 {
return types.EmptyRootHash, nil, nil // case (a)
}
nodes := trienode.NewNodeSet(t.owner)
for _, path := range paths {
nodes.AddNode([]byte(path), trienode.NewDeleted())
}
return types.EmptyRootHash, nodes, nil // case (b)
} }
// Derive the hash for all dirty nodes first. We hold the assumption // Derive the hash for all dirty nodes first. We hold the assumption
// in the following procedure that all nodes are hashed. // in the following procedure that all nodes are hashed.
@ -576,10 +639,14 @@ func (t *Trie) Commit(collectLeaf bool) (common.Hash, *NodeSet) {
// Replace the root node with the origin hash in order to // Replace the root node with the origin hash in order to
// ensure all resolved nodes are dropped after the commit. // ensure all resolved nodes are dropped after the commit.
t.root = hashedNode t.root = hashedNode
return rootHash, nil return rootHash, nil, nil
} }
t.root = newCommitter(nodes, collectLeaf).Commit(t.root) nodes := trienode.NewNodeSet(t.owner)
return rootHash, nodes for _, path := range t.tracer.deletedNodes() {
nodes.AddNode([]byte(path), trienode.NewDeleted())
}
t.root = newCommitter(nodes, t.tracer, collectLeaf).Commit(t.root)
return rootHash, nodes, nil
} }
// hashRoot calculates the root hash of the given trie // hashRoot calculates the root hash of the given trie
@ -603,4 +670,5 @@ func (t *Trie) Reset() {
t.owner = common.Hash{} t.owner = common.Hash{}
t.unhashed = 0 t.unhashed = 0
t.tracer.reset() t.tracer.reset()
t.committed = false
} }

View File

@ -17,44 +17,33 @@
package trie package trie
import ( import (
"fmt"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/log"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/triedb/database"
) )
// Reader wraps the Node and NodeBlob method of a backing trie store.
type Reader interface {
// Node retrieves the trie node with the provided trie identifier, hexary
// node path and the corresponding node hash.
// No error will be returned if the node is not found.
Node(owner common.Hash, path []byte, hash common.Hash) (node, error)
// NodeBlob retrieves the RLP-encoded trie node blob with the provided trie
// identifier, hexary node path and the corresponding node hash.
// No error will be returned if the node is not found.
NodeBlob(owner common.Hash, path []byte, hash common.Hash) ([]byte, error)
}
// NodeReader wraps all the necessary functions for accessing trie node.
type NodeReader interface {
// GetReader returns a reader for accessing all trie nodes with provided
// state root. Nil is returned in case the state is not available.
GetReader(root common.Hash, codec uint64) Reader
}
// trieReader is a wrapper of the underlying node reader. It's not safe // trieReader is a wrapper of the underlying node reader. It's not safe
// for concurrent usage. // for concurrent usage.
type trieReader struct { type trieReader struct {
owner common.Hash owner common.Hash
reader Reader reader database.Reader
banned map[string]struct{} // Marker to prevent node from being accessed, for tests banned map[string]struct{} // Marker to prevent node from being accessed, for tests
} }
// newTrieReader initializes the trie reader with the given node reader. // newTrieReader initializes the trie reader with the given node reader.
func newTrieReader(stateRoot, owner common.Hash, db NodeReader, codec uint64) (*trieReader, error) { func newTrieReader(stateRoot, owner common.Hash, db database.Database) (*trieReader, error) {
reader := db.GetReader(stateRoot, codec) if stateRoot == (common.Hash{}) || stateRoot == types.EmptyRootHash {
if reader == nil { if stateRoot == (common.Hash{}) {
return nil, fmt.Errorf("state not found #%x", stateRoot) log.Error("Zero state root hash!")
}
return &trieReader{owner: owner}, nil
}
reader, err := db.Reader(stateRoot)
if err != nil {
return nil, &MissingNodeError{Owner: owner, NodeHash: stateRoot, err: err}
} }
return &trieReader{owner: owner, reader: reader}, nil return &trieReader{owner: owner, reader: reader}, nil
} }
@ -65,30 +54,10 @@ func newEmptyReader() *trieReader {
return &trieReader{} return &trieReader{}
} }
// node retrieves the trie node with the provided trie node information.
// An MissingNodeError will be returned in case the node is not found or
// any error is encountered.
func (r *trieReader) node(path []byte, hash common.Hash) (node, error) {
// Perform the logics in tests for preventing trie node access.
if r.banned != nil {
if _, ok := r.banned[string(path)]; ok {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
}
}
if r.reader == nil {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
}
node, err := r.reader.Node(r.owner, path, hash)
if err != nil || node == nil {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err}
}
return node, nil
}
// node retrieves the rlp-encoded trie node with the provided trie node // node retrieves the rlp-encoded trie node with the provided trie node
// information. An MissingNodeError will be returned in case the node is // information. An MissingNodeError will be returned in case the node is
// not found or any error is encountered. // not found or any error is encountered.
func (r *trieReader) nodeBlob(path []byte, hash common.Hash) ([]byte, error) { func (r *trieReader) node(path []byte, hash common.Hash) ([]byte, error) {
// Perform the logics in tests for preventing trie node access. // Perform the logics in tests for preventing trie node access.
if r.banned != nil { if r.banned != nil {
if _, ok := r.banned[string(path)]; ok { if _, ok := r.banned[string(path)]; ok {
@ -98,9 +67,29 @@ func (r *trieReader) nodeBlob(path []byte, hash common.Hash) ([]byte, error) {
if r.reader == nil { if r.reader == nil {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path} return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
} }
blob, err := r.reader.NodeBlob(r.owner, path, hash) blob, err := r.reader.Node(r.owner, path, hash)
if err != nil || len(blob) == 0 { if err != nil || len(blob) == 0 {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err} return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err}
} }
return blob, nil return blob, nil
} }
// MerkleLoader implements triestate.TrieLoader for constructing tries.
type MerkleLoader struct {
db database.Database
}
// NewMerkleLoader creates the merkle trie loader.
func NewMerkleLoader(db database.Database) *MerkleLoader {
return &MerkleLoader{db: db}
}
// OpenTrie opens the main account trie.
func (l *MerkleLoader) OpenTrie(root common.Hash) (triestate.Trie, error) {
return New(TrieID(root), l.db)
}
// OpenStorageTrie opens the storage trie of an account.
func (l *MerkleLoader) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (triestate.Trie, error) {
return New(StorageTrieID(stateRoot, addrHash, root), l.db)
}

File diff suppressed because it is too large Load Diff

View File

@ -1,241 +0,0 @@
package trie
import (
"bytes"
"context"
"fmt"
"math/big"
"math/rand"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
gethstate "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/rlp"
gethtrie "github.com/ethereum/go-ethereum/trie"
"github.com/jmoiron/sqlx"
pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0"
"github.com/cerc-io/plugeth-statediff/indexer/database/sql/postgres"
"github.com/cerc-io/plugeth-statediff/test_helpers"
"github.com/cerc-io/ipld-eth-statedb/internal"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/helper"
)
var (
dbConfig, _ = postgres.TestConfig.WithEnv()
trieConfig = Config{Cache: 256}
)
type kvi struct {
k []byte
v int64
}
type kvMap map[string]*kvi
type kvsi struct {
k string
v int64
}
// NewAccountTrie is a shortcut to create a trie using the StateTrieCodec (ie. IPLD MEthStateTrie codec).
func NewAccountTrie(id *ID, db NodeReader) (*Trie, error) {
return New(id, db, StateTrieCodec)
}
// makeTestTrie create a sample test trie to test node-wise reconstruction.
func makeTestTrie(t testing.TB) (*Database, *StateTrie, map[string][]byte) {
// Create an empty trie
triedb := NewDatabase(rawdb.NewMemoryDatabase())
trie, err := NewStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec)
if err != nil {
t.Fatal(err)
}
// Fill it with some arbitrary data
content := make(map[string][]byte)
for i := byte(0); i < 255; i++ {
// Map the same data under multiple keys
key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i}
content[string(key)] = val
trie.Update(key, val)
key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i}
content[string(key)] = val
trie.Update(key, val)
// Add some other data to inflate the trie
for j := byte(3); j < 13; j++ {
key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i}
content[string(key)] = val
trie.Update(key, val)
}
}
root, nodes := trie.Commit(false)
if err := triedb.Update(NewWithNodeSet(nodes)); err != nil {
panic(fmt.Errorf("failed to commit db %v", err))
}
// Re-create the trie based on the new state
trie, err = NewStateTrie(TrieID(root), triedb, StateTrieCodec)
if err != nil {
t.Fatal(err)
}
return triedb, trie, content
}
func forHashedNodes(tr *Trie) map[string][]byte {
var (
it = tr.NodeIterator(nil)
nodes = make(map[string][]byte)
)
for it.Next(true) {
if it.Hash() == (common.Hash{}) {
continue
}
nodes[string(it.Path())] = common.CopyBytes(it.NodeBlob())
}
return nodes
}
func diffTries(trieA, trieB *Trie) (map[string][]byte, map[string][]byte, map[string][]byte) {
var (
nodesA = forHashedNodes(trieA)
nodesB = forHashedNodes(trieB)
inA = make(map[string][]byte) // hashed nodes in trie a but not b
inB = make(map[string][]byte) // hashed nodes in trie b but not a
both = make(map[string][]byte) // hashed nodes in both tries but different value
)
for path, blobA := range nodesA {
if blobB, ok := nodesB[path]; ok {
if bytes.Equal(blobA, blobB) {
continue
}
both[path] = blobA
continue
}
inA[path] = blobA
}
for path, blobB := range nodesB {
if _, ok := nodesA[path]; ok {
continue
}
inB[path] = blobB
}
return inA, inB, both
}
func packValue(val int64) []byte {
acct := &types.StateAccount{
Balance: big.NewInt(val),
CodeHash: test_helpers.NullCodeHash.Bytes(),
Root: test_helpers.EmptyContractRoot,
}
acct_rlp, err := rlp.EncodeToBytes(acct)
if err != nil {
panic(err)
}
return acct_rlp
}
func updateTrie(tr *gethtrie.Trie, vals []kvsi) (kvMap, error) {
all := kvMap{}
for _, val := range vals {
all[string(val.k)] = &kvi{[]byte(val.k), val.v}
tr.Update([]byte(val.k), packValue(val.v))
}
return all, nil
}
func commitTrie(t testing.TB, db *gethtrie.Database, tr *gethtrie.Trie) common.Hash {
t.Helper()
root, nodes := tr.Commit(false)
if err := db.Update(gethtrie.NewWithNodeSet(nodes)); err != nil {
t.Fatal(err)
}
if err := db.Commit(root, false); err != nil {
t.Fatal(err)
}
return root
}
func makePgIpfsEthDB(t testing.TB) ethdb.Database {
pg_db, err := postgres.ConnectSQLX(context.Background(), dbConfig)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
if err := TearDownDB(pg_db); err != nil {
t.Fatal(err)
}
})
return pgipfsethdb.NewDatabase(pg_db, internal.MakeCacheConfig(t))
}
// commit a LevelDB state trie, index to IPLD and return new trie
func indexTrie(t testing.TB, edb ethdb.Database, root common.Hash) *Trie {
t.Helper()
dbConfig.Driver = postgres.PGX
err := helper.IndexStateDiff(dbConfig, gethstate.NewDatabase(edb), common.Hash{}, root)
if err != nil {
t.Fatal(err)
}
ipfs_db := makePgIpfsEthDB(t)
tr, err := New(TrieID(root), NewDatabase(ipfs_db), StateTrieCodec)
if err != nil {
t.Fatal(err)
}
return tr
}
// generates a random Geth LevelDB trie of n key-value pairs and corresponding value map
func randomGethTrie(n int, db *gethtrie.Database) (*gethtrie.Trie, kvMap) {
trie := gethtrie.NewEmpty(db)
var vals []*kvi
for i := byte(0); i < 100; i++ {
e := &kvi{common.LeftPadBytes([]byte{i}, 32), int64(i)}
e2 := &kvi{common.LeftPadBytes([]byte{i + 10}, 32), int64(i)}
vals = append(vals, e, e2)
}
for i := 0; i < n; i++ {
k := randBytes(32)
v := rand.Int63()
vals = append(vals, &kvi{k, v})
}
all := kvMap{}
for _, val := range vals {
all[string(val.k)] = &kvi{[]byte(val.k), val.v}
trie.Update([]byte(val.k), packValue(val.v))
}
return trie, all
}
// TearDownDB is used to tear down the watcher dbs after tests
func TearDownDB(db *sqlx.DB) error {
tx, err := db.Beginx()
if err != nil {
return err
}
statements := []string{
`DELETE FROM nodes`,
`DELETE FROM ipld.blocks`,
`DELETE FROM eth.header_cids`,
`DELETE FROM eth.uncle_cids`,
`DELETE FROM eth.transaction_cids`,
`DELETE FROM eth.receipt_cids`,
`DELETE FROM eth.state_cids`,
`DELETE FROM eth.storage_cids`,
`DELETE FROM eth.log_cids`,
`DELETE FROM eth_meta.watched_addresses`,
}
for _, stm := range statements {
if _, err = tx.Exec(stm); err != nil {
return fmt.Errorf("error executing `%s`: %w", stm, err)
}
}
return tx.Commit()
}

View File

@ -0,0 +1,339 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package triedb
import (
"errors"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/log"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/triedb/database"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/triedb/hashdb"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/triedb/pathdb"
)
// Config defines all necessary options for database.
type Config struct {
Preimages bool // Flag whether the preimage of node key is recorded
IsVerkle bool // Flag whether the db is holding a verkle tree
HashDB *hashdb.Config // Configs for hash-based scheme
PathDB *pathdb.Config // Configs for experimental path-based scheme
}
// HashDefaults represents a config for using hash-based scheme with
// default settings.
var HashDefaults = &Config{
Preimages: false,
HashDB: hashdb.Defaults,
}
// backend defines the methods needed to access/update trie nodes in different
// state scheme.
type backend interface {
// Scheme returns the identifier of used storage scheme.
Scheme() string
// Initialized returns an indicator if the state data is already initialized
// according to the state scheme.
Initialized(genesisRoot common.Hash) bool
// Size returns the current storage size of the diff layers on top of the
// disk layer and the storage size of the nodes cached in the disk layer.
//
// For hash scheme, there is no differentiation between diff layer nodes
// and dirty disk layer nodes, so both are merged into the second return.
Size() (common.StorageSize, common.StorageSize)
// Update performs a state transition by committing dirty nodes contained
// in the given set in order to update state from the specified parent to
// the specified root.
//
// The passed in maps(nodes, states) will be retained to avoid copying
// everything. Therefore, these maps must not be changed afterwards.
Update(root common.Hash, parent common.Hash, block uint64, nodes *trienode.MergedNodeSet, states *triestate.Set) error
// Commit writes all relevant trie nodes belonging to the specified state
// to disk. Report specifies whether logs will be displayed in info level.
Commit(root common.Hash, report bool) error
// Close closes the trie database backend and releases all held resources.
Close() error
}
// Database is the wrapper of the underlying backend which is shared by different
// types of node backend as an entrypoint. It's responsible for all interactions
// relevant with trie nodes and node preimages.
type Database struct {
config *Config // Configuration for trie database
diskdb ethdb.Database // Persistent database to store the snapshot
preimages *preimageStore // The store for caching preimages
backend backend // The backend for managing trie nodes
}
// NewDatabase initializes the trie database with default settings, note
// the legacy hash-based scheme is used by default.
func NewDatabase(diskdb ethdb.Database, config *Config) *Database {
// Sanitize the config and use the default one if it's not specified.
if config == nil {
config = HashDefaults
}
var preimages *preimageStore
if config.Preimages {
preimages = newPreimageStore(diskdb)
}
db := &Database{
config: config,
diskdb: diskdb,
preimages: preimages,
}
if config.HashDB != nil && config.PathDB != nil {
log.Crit("Both 'hash' and 'path' mode are configured")
}
if config.PathDB != nil {
db.backend = pathdb.New(diskdb, config.PathDB)
} else {
var resolver hashdb.ChildResolver
if config.IsVerkle {
// TODO define verkle resolver
log.Crit("Verkle node resolver is not defined")
} else {
resolver = trie.MerkleResolver{}
}
db.backend = hashdb.New(diskdb, config.HashDB, resolver)
}
return db
}
// Reader returns a reader for accessing all trie nodes with provided state root.
// An error will be returned if the requested state is not available.
func (db *Database) Reader(blockRoot common.Hash) (database.Reader, error) {
switch b := db.backend.(type) {
case *hashdb.Database:
return b.Reader(blockRoot)
case *pathdb.Database:
return b.Reader(blockRoot)
}
return nil, errors.New("unknown backend")
}
// Update performs a state transition by committing dirty nodes contained in the
// given set in order to update state from the specified parent to the specified
// root. The held pre-images accumulated up to this point will be flushed in case
// the size exceeds the threshold.
//
// The passed in maps(nodes, states) will be retained to avoid copying everything.
// Therefore, these maps must not be changed afterwards.
func (db *Database) Update(root common.Hash, parent common.Hash, block uint64, nodes *trienode.MergedNodeSet, states *triestate.Set) error {
if db.preimages != nil {
db.preimages.commit(false)
}
return db.backend.Update(root, parent, block, nodes, states)
}
// Commit iterates over all the children of a particular node, writes them out
// to disk. As a side effect, all pre-images accumulated up to this point are
// also written.
func (db *Database) Commit(root common.Hash, report bool) error {
if db.preimages != nil {
db.preimages.commit(true)
}
return db.backend.Commit(root, report)
}
// Size returns the storage size of diff layer nodes above the persistent disk
// layer, the dirty nodes buffered within the disk layer, and the size of cached
// preimages.
func (db *Database) Size() (common.StorageSize, common.StorageSize, common.StorageSize) {
var (
diffs, nodes common.StorageSize
preimages common.StorageSize
)
diffs, nodes = db.backend.Size()
if db.preimages != nil {
preimages = db.preimages.size()
}
return diffs, nodes, preimages
}
// Initialized returns an indicator if the state data is already initialized
// according to the state scheme.
func (db *Database) Initialized(genesisRoot common.Hash) bool {
return db.backend.Initialized(genesisRoot)
}
// Scheme returns the node scheme used in the database.
func (db *Database) Scheme() string {
return db.backend.Scheme()
}
// Close flushes the dangling preimages to disk and closes the trie database.
// It is meant to be called when closing the blockchain object, so that all
// resources held can be released correctly.
func (db *Database) Close() error {
db.WritePreimages()
return db.backend.Close()
}
// WritePreimages flushes all accumulated preimages to disk forcibly.
func (db *Database) WritePreimages() {
if db.preimages != nil {
// db.preimages.commit(true)
}
}
// Preimage retrieves a cached trie node pre-image from preimage store.
func (db *Database) Preimage(hash common.Hash) []byte {
if db.preimages == nil {
return nil
}
return db.preimages.preimage(hash)
}
// InsertPreimage writes pre-images of trie node to the preimage store.
func (db *Database) InsertPreimage(preimages map[common.Hash][]byte) {
if db.preimages == nil {
return
}
db.preimages.insertPreimage(preimages)
}
// Cap iteratively flushes old but still referenced trie nodes until the total
// memory usage goes below the given threshold. The held pre-images accumulated
// up to this point will be flushed in case the size exceeds the threshold.
//
// It's only supported by hash-based database and will return an error for others.
func (db *Database) Cap(limit common.StorageSize) error {
hdb, ok := db.backend.(*hashdb.Database)
if !ok {
return errors.New("not supported")
}
if db.preimages != nil {
// db.preimages.commit(false)
}
return hdb.Cap(limit)
}
// Reference adds a new reference from a parent node to a child node. This function
// is used to add reference between internal trie node and external node(e.g. storage
// trie root), all internal trie nodes are referenced together by database itself.
//
// It's only supported by hash-based database and will return an error for others.
func (db *Database) Reference(root common.Hash, parent common.Hash) error {
hdb, ok := db.backend.(*hashdb.Database)
if !ok {
return errors.New("not supported")
}
hdb.Reference(root, parent)
return nil
}
// Dereference removes an existing reference from a root node. It's only
// supported by hash-based database and will return an error for others.
func (db *Database) Dereference(root common.Hash) error {
hdb, ok := db.backend.(*hashdb.Database)
if !ok {
return errors.New("not supported")
}
hdb.Dereference(root)
return nil
}
// Recover rollbacks the database to a specified historical point. The state is
// supported as the rollback destination only if it's canonical state and the
// corresponding trie histories are existent. It's only supported by path-based
// database and will return an error for others.
func (db *Database) Recover(target common.Hash) error {
pdb, ok := db.backend.(*pathdb.Database)
if !ok {
return errors.New("not supported")
}
var loader triestate.TrieLoader
if db.config.IsVerkle {
// TODO define verkle loader
log.Crit("Verkle loader is not defined")
} else {
loader = trie.NewMerkleLoader(db)
}
return pdb.Recover(target, loader)
}
// Recoverable returns the indicator if the specified state is enabled to be
// recovered. It's only supported by path-based database and will return an
// error for others.
func (db *Database) Recoverable(root common.Hash) (bool, error) {
pdb, ok := db.backend.(*pathdb.Database)
if !ok {
return false, errors.New("not supported")
}
return pdb.Recoverable(root), nil
}
// Disable deactivates the database and invalidates all available state layers
// as stale to prevent access to the persistent state, which is in the syncing
// stage.
//
// It's only supported by path-based database and will return an error for others.
func (db *Database) Disable() error {
pdb, ok := db.backend.(*pathdb.Database)
if !ok {
return errors.New("not supported")
}
return pdb.Disable()
}
// Enable activates database and resets the state tree with the provided persistent
// state root once the state sync is finished.
func (db *Database) Enable(root common.Hash) error {
pdb, ok := db.backend.(*pathdb.Database)
if !ok {
return errors.New("not supported")
}
return pdb.Enable(root)
}
// Journal commits an entire diff hierarchy to disk into a single journal entry.
// This is meant to be used during shutdown to persist the snapshot without
// flattening everything down (bad for reorgs). It's only supported by path-based
// database and will return an error for others.
func (db *Database) Journal(root common.Hash) error {
pdb, ok := db.backend.(*pathdb.Database)
if !ok {
return errors.New("not supported")
}
return pdb.Journal(root)
}
// SetBufferSize sets the node buffer size to the provided value(in bytes).
// It's only supported by path-based database and will return an error for
// others.
func (db *Database) SetBufferSize(size int) error {
pdb, ok := db.backend.(*pathdb.Database)
if !ok {
return errors.New("not supported")
}
return pdb.SetBufferSize(size)
}
// IsVerkle returns the indicator if the database is holding a verkle tree.
func (db *Database) IsVerkle() bool {
return db.config.IsVerkle
}

View File

@ -0,0 +1,48 @@
// Copyright 2024 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package database
import (
"github.com/ethereum/go-ethereum/common"
)
// Reader wraps the Node method of a backing trie reader.
type Reader interface {
// Node retrieves the trie node blob with the provided trie identifier,
// node path and the corresponding node hash. No error will be returned
// if the node is not found.
Node(owner common.Hash, path []byte, hash common.Hash) ([]byte, error)
}
// PreimageStore wraps the methods of a backing store for reading and writing
// trie node preimages.
type PreimageStore interface {
// Preimage retrieves the preimage of the specified hash.
Preimage(hash common.Hash) []byte
// InsertPreimage commits a set of preimages along with their hashes.
InsertPreimage(preimages map[common.Hash][]byte)
}
// Database wraps the methods of a backing trie store.
type Database interface {
PreimageStore
// Reader returns a node reader associated with the specific state.
// An error will be returned if the specified state is not available.
Reader(stateRoot common.Hash) (Reader, error)
}

View File

@ -0,0 +1,665 @@
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package hashdb
import (
"errors"
"fmt"
"reflect"
"sync"
"time"
"github.com/VictoriaMetrics/fastcache"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/rlp"
"github.com/cerc-io/ipld-eth-statedb/internal"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
)
var (
memcacheCleanHitMeter = metrics.NewRegisteredMeter("hashdb/memcache/clean/hit", nil)
memcacheCleanMissMeter = metrics.NewRegisteredMeter("hashdb/memcache/clean/miss", nil)
memcacheCleanReadMeter = metrics.NewRegisteredMeter("hashdb/memcache/clean/read", nil)
memcacheCleanWriteMeter = metrics.NewRegisteredMeter("hashdb/memcache/clean/write", nil)
memcacheDirtyHitMeter = metrics.NewRegisteredMeter("hashdb/memcache/dirty/hit", nil)
memcacheDirtyMissMeter = metrics.NewRegisteredMeter("hashdb/memcache/dirty/miss", nil)
memcacheDirtyReadMeter = metrics.NewRegisteredMeter("hashdb/memcache/dirty/read", nil)
memcacheDirtyWriteMeter = metrics.NewRegisteredMeter("hashdb/memcache/dirty/write", nil)
memcacheFlushTimeTimer = metrics.NewRegisteredResettingTimer("hashdb/memcache/flush/time", nil)
memcacheFlushNodesMeter = metrics.NewRegisteredMeter("hashdb/memcache/flush/nodes", nil)
memcacheFlushBytesMeter = metrics.NewRegisteredMeter("hashdb/memcache/flush/bytes", nil)
memcacheGCTimeTimer = metrics.NewRegisteredResettingTimer("hashdb/memcache/gc/time", nil)
memcacheGCNodesMeter = metrics.NewRegisteredMeter("hashdb/memcache/gc/nodes", nil)
memcacheGCBytesMeter = metrics.NewRegisteredMeter("hashdb/memcache/gc/bytes", nil)
memcacheCommitTimeTimer = metrics.NewRegisteredResettingTimer("hashdb/memcache/commit/time", nil)
memcacheCommitNodesMeter = metrics.NewRegisteredMeter("hashdb/memcache/commit/nodes", nil)
memcacheCommitBytesMeter = metrics.NewRegisteredMeter("hashdb/memcache/commit/bytes", nil)
)
// ChildResolver defines the required method to decode the provided
// trie node and iterate the children on top.
type ChildResolver interface {
ForEach(node []byte, onChild func(common.Hash))
}
// Config contains the settings for database.
type Config struct {
CleanCacheSize int // Maximum memory allowance (in bytes) for caching clean nodes
}
// Defaults is the default setting for database if it's not specified.
// Notably, clean cache is disabled explicitly,
var Defaults = &Config{
// Explicitly set clean cache size to 0 to avoid creating fastcache,
// otherwise database must be closed when it's no longer needed to
// prevent memory leak.
CleanCacheSize: 0,
}
// Database is an intermediate write layer between the trie data structures and
// the disk database. The aim is to accumulate trie writes in-memory and only
// periodically flush a couple tries to disk, garbage collecting the remainder.
type Database struct {
diskdb ethdb.Database // Persistent storage for matured trie nodes
resolver ChildResolver // The handler to resolve children of nodes
cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs
dirties map[common.Hash]*cachedNode // Data and references relationships of dirty trie nodes
oldest common.Hash // Oldest tracked node, flush-list head
newest common.Hash // Newest tracked node, flush-list tail
gctime time.Duration // Time spent on garbage collection since last commit
gcnodes uint64 // Nodes garbage collected since last commit
gcsize common.StorageSize // Data storage garbage collected since last commit
flushtime time.Duration // Time spent on data flushing since last commit
flushnodes uint64 // Nodes flushed since last commit
flushsize common.StorageSize // Data storage flushed since last commit
dirtiesSize common.StorageSize // Storage size of the dirty node cache (exc. metadata)
childrenSize common.StorageSize // Storage size of the external children tracking
lock sync.RWMutex
}
// cachedNode is all the information we know about a single cached trie node
// in the memory database write layer.
type cachedNode struct {
node []byte // Encoded node blob, immutable
parents uint32 // Number of live nodes referencing this one
external map[common.Hash]struct{} // The set of external children
flushPrev common.Hash // Previous node in the flush-list
flushNext common.Hash // Next node in the flush-list
}
// cachedNodeSize is the raw size of a cachedNode data structure without any
// node data included. It's an approximate size, but should be a lot better
// than not counting them.
var cachedNodeSize = int(reflect.TypeOf(cachedNode{}).Size())
// forChildren invokes the callback for all the tracked children of this node,
// both the implicit ones from inside the node as well as the explicit ones
// from outside the node.
func (n *cachedNode) forChildren(resolver ChildResolver, onChild func(hash common.Hash)) {
for child := range n.external {
onChild(child)
}
resolver.ForEach(n.node, onChild)
}
// New initializes the hash-based node database.
func New(diskdb ethdb.Database, config *Config, resolver ChildResolver) *Database {
if config == nil {
config = Defaults
}
var cleans *fastcache.Cache
if config.CleanCacheSize > 0 {
cleans = fastcache.New(config.CleanCacheSize)
}
return &Database{
diskdb: diskdb,
resolver: resolver,
cleans: cleans,
dirties: make(map[common.Hash]*cachedNode),
}
}
// insert inserts a trie node into the memory database. All nodes inserted by
// this function will be reference tracked. This function assumes the lock is
// already held.
func (db *Database) insert(hash common.Hash, node []byte) {
// If the node's already cached, skip
if _, ok := db.dirties[hash]; ok {
return
}
memcacheDirtyWriteMeter.Mark(int64(len(node)))
// Create the cached entry for this node
entry := &cachedNode{
node: node,
flushPrev: db.newest,
}
entry.forChildren(db.resolver, func(child common.Hash) {
if c := db.dirties[child]; c != nil {
c.parents++
}
})
db.dirties[hash] = entry
// Update the flush-list endpoints
if db.oldest == (common.Hash{}) {
db.oldest, db.newest = hash, hash
} else {
db.dirties[db.newest].flushNext, db.newest = hash, hash
}
db.dirtiesSize += common.StorageSize(common.HashLength + len(node))
}
// node retrieves an encoded cached trie node from memory. If it cannot be found
// cached, the method queries the persistent database for the content.
func (db *Database) node(hash common.Hash, codec uint64) ([]byte, error) {
// It doesn't make sense to retrieve the metaroot
if hash == (common.Hash{}) {
return nil, errors.New("not found")
}
// Retrieve the node from the clean cache if available
if db.cleans != nil {
if enc := db.cleans.Get(nil, hash[:]); enc != nil {
memcacheCleanHitMeter.Mark(1)
memcacheCleanReadMeter.Mark(int64(len(enc)))
return enc, nil
}
}
// Retrieve the node from the dirty cache if available.
db.lock.RLock()
dirty := db.dirties[hash]
db.lock.RUnlock()
// Return the cached node if it's found in the dirty set.
// The dirty.node field is immutable and safe to read it
// even without lock guard.
if dirty != nil {
memcacheDirtyHitMeter.Mark(1)
memcacheDirtyReadMeter.Mark(int64(len(dirty.node)))
return dirty.node, nil
}
memcacheDirtyMissMeter.Mark(1)
// Content unavailable in memory, attempt to retrieve from disk
cid, err := internal.Keccak256ToCid(codec, hash[:])
if err != nil {
return nil, err
}
enc, err := db.diskdb.Get(cid.Bytes())
if err != nil {
return nil, err
}
if len(enc) != 0 {
if db.cleans != nil {
db.cleans.Set(hash[:], enc)
memcacheCleanMissMeter.Mark(1)
memcacheCleanWriteMeter.Mark(int64(len(enc)))
}
return enc, nil
}
return nil, errors.New("not found")
}
// Reference adds a new reference from a parent node to a child node.
// This function is used to add reference between internal trie node
// and external node(e.g. storage trie root), all internal trie nodes
// are referenced together by database itself.
func (db *Database) Reference(child common.Hash, parent common.Hash) {
db.lock.Lock()
defer db.lock.Unlock()
db.reference(child, parent)
}
// reference is the private locked version of Reference.
func (db *Database) reference(child common.Hash, parent common.Hash) {
// If the node does not exist, it's a node pulled from disk, skip
node, ok := db.dirties[child]
if !ok {
return
}
// The reference is for state root, increase the reference counter.
if parent == (common.Hash{}) {
node.parents += 1
return
}
// The reference is for external storage trie, don't duplicate if
// the reference is already existent.
if db.dirties[parent].external == nil {
db.dirties[parent].external = make(map[common.Hash]struct{})
}
if _, ok := db.dirties[parent].external[child]; ok {
return
}
node.parents++
db.dirties[parent].external[child] = struct{}{}
db.childrenSize += common.HashLength
}
// Dereference removes an existing reference from a root node.
func (db *Database) Dereference(root common.Hash) {
// Sanity check to ensure that the meta-root is not removed
if root == (common.Hash{}) {
log.Error("Attempted to dereference the trie cache meta root")
return
}
db.lock.Lock()
defer db.lock.Unlock()
nodes, storage, start := len(db.dirties), db.dirtiesSize, time.Now()
db.dereference(root)
db.gcnodes += uint64(nodes - len(db.dirties))
db.gcsize += storage - db.dirtiesSize
db.gctime += time.Since(start)
memcacheGCTimeTimer.Update(time.Since(start))
memcacheGCBytesMeter.Mark(int64(storage - db.dirtiesSize))
memcacheGCNodesMeter.Mark(int64(nodes - len(db.dirties)))
log.Debug("Dereferenced trie from memory database", "nodes", nodes-len(db.dirties), "size", storage-db.dirtiesSize, "time", time.Since(start),
"gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.dirties), "livesize", db.dirtiesSize)
}
// dereference is the private locked version of Dereference.
func (db *Database) dereference(hash common.Hash) {
// If the node does not exist, it's a previously committed node.
node, ok := db.dirties[hash]
if !ok {
return
}
// If there are no more references to the node, delete it and cascade
if node.parents > 0 {
// This is a special cornercase where a node loaded from disk (i.e. not in the
// memcache any more) gets reinjected as a new node (short node split into full,
// then reverted into short), causing a cached node to have no parents. That is
// no problem in itself, but don't make maxint parents out of it.
node.parents--
}
if node.parents == 0 {
// Remove the node from the flush-list
switch hash {
case db.oldest:
db.oldest = node.flushNext
if node.flushNext != (common.Hash{}) {
db.dirties[node.flushNext].flushPrev = common.Hash{}
}
case db.newest:
db.newest = node.flushPrev
if node.flushPrev != (common.Hash{}) {
db.dirties[node.flushPrev].flushNext = common.Hash{}
}
default:
db.dirties[node.flushPrev].flushNext = node.flushNext
db.dirties[node.flushNext].flushPrev = node.flushPrev
}
// Dereference all children and delete the node
node.forChildren(db.resolver, func(child common.Hash) {
db.dereference(child)
})
delete(db.dirties, hash)
db.dirtiesSize -= common.StorageSize(common.HashLength + len(node.node))
if node.external != nil {
db.childrenSize -= common.StorageSize(len(node.external) * common.HashLength)
}
}
}
// Cap iteratively flushes old but still referenced trie nodes until the total
// memory usage goes below the given threshold.
func (db *Database) Cap(limit common.StorageSize) error {
db.lock.Lock()
defer db.lock.Unlock()
// Create a database batch to flush persistent data out. It is important that
// outside code doesn't see an inconsistent state (referenced data removed from
// memory cache during commit but not yet in persistent storage). This is ensured
// by only uncaching existing data when the database write finalizes.
batch := db.diskdb.NewBatch()
nodes, storage, start := len(db.dirties), db.dirtiesSize, time.Now()
// db.dirtiesSize only contains the useful data in the cache, but when reporting
// the total memory consumption, the maintenance metadata is also needed to be
// counted.
size := db.dirtiesSize + common.StorageSize(len(db.dirties)*cachedNodeSize)
size += db.childrenSize
// Keep committing nodes from the flush-list until we're below allowance
oldest := db.oldest
for size > limit && oldest != (common.Hash{}) {
// Fetch the oldest referenced node and push into the batch
node := db.dirties[oldest]
rawdb.WriteLegacyTrieNode(batch, oldest, node.node)
// If we exceeded the ideal batch size, commit and reset
if batch.ValueSize() >= ethdb.IdealBatchSize {
if err := batch.Write(); err != nil {
log.Error("Failed to write flush list to disk", "err", err)
return err
}
batch.Reset()
}
// Iterate to the next flush item, or abort if the size cap was achieved. Size
// is the total size, including the useful cached data (hash -> blob), the
// cache item metadata, as well as external children mappings.
size -= common.StorageSize(common.HashLength + len(node.node) + cachedNodeSize)
if node.external != nil {
size -= common.StorageSize(len(node.external) * common.HashLength)
}
oldest = node.flushNext
}
// Flush out any remainder data from the last batch
if err := batch.Write(); err != nil {
log.Error("Failed to write flush list to disk", "err", err)
return err
}
// Write successful, clear out the flushed data
for db.oldest != oldest {
node := db.dirties[db.oldest]
delete(db.dirties, db.oldest)
db.oldest = node.flushNext
db.dirtiesSize -= common.StorageSize(common.HashLength + len(node.node))
if node.external != nil {
db.childrenSize -= common.StorageSize(len(node.external) * common.HashLength)
}
}
if db.oldest != (common.Hash{}) {
db.dirties[db.oldest].flushPrev = common.Hash{}
}
db.flushnodes += uint64(nodes - len(db.dirties))
db.flushsize += storage - db.dirtiesSize
db.flushtime += time.Since(start)
memcacheFlushTimeTimer.Update(time.Since(start))
memcacheFlushBytesMeter.Mark(int64(storage - db.dirtiesSize))
memcacheFlushNodesMeter.Mark(int64(nodes - len(db.dirties)))
log.Debug("Persisted nodes from memory database", "nodes", nodes-len(db.dirties), "size", storage-db.dirtiesSize, "time", time.Since(start),
"flushnodes", db.flushnodes, "flushsize", db.flushsize, "flushtime", db.flushtime, "livenodes", len(db.dirties), "livesize", db.dirtiesSize)
return nil
}
// Commit iterates over all the children of a particular node, writes them out
// to disk, forcefully tearing down all references in both directions. As a side
// effect, all pre-images accumulated up to this point are also written.
func (db *Database) Commit(node common.Hash, report bool) error {
db.lock.Lock()
defer db.lock.Unlock()
// Create a database batch to flush persistent data out. It is important that
// outside code doesn't see an inconsistent state (referenced data removed from
// memory cache during commit but not yet in persistent storage). This is ensured
// by only uncaching existing data when the database write finalizes.
start := time.Now()
batch := db.diskdb.NewBatch()
// Move the trie itself into the batch, flushing if enough data is accumulated
nodes, storage := len(db.dirties), db.dirtiesSize
uncacher := &cleaner{db}
if err := db.commit(node, batch, uncacher); err != nil {
log.Error("Failed to commit trie from trie database", "err", err)
return err
}
// Trie mostly committed to disk, flush any batch leftovers
if err := batch.Write(); err != nil {
log.Error("Failed to write trie to disk", "err", err)
return err
}
// Uncache any leftovers in the last batch
if err := batch.Replay(uncacher); err != nil {
return err
}
batch.Reset()
// Reset the storage counters and bumped metrics
memcacheCommitTimeTimer.Update(time.Since(start))
memcacheCommitBytesMeter.Mark(int64(storage - db.dirtiesSize))
memcacheCommitNodesMeter.Mark(int64(nodes - len(db.dirties)))
logger := log.Info
if !report {
logger = log.Debug
}
logger("Persisted trie from memory database", "nodes", nodes-len(db.dirties)+int(db.flushnodes), "size", storage-db.dirtiesSize+db.flushsize, "time", time.Since(start)+db.flushtime,
"gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.dirties), "livesize", db.dirtiesSize)
// Reset the garbage collection statistics
db.gcnodes, db.gcsize, db.gctime = 0, 0, 0
db.flushnodes, db.flushsize, db.flushtime = 0, 0, 0
return nil
}
// commit is the private locked version of Commit.
func (db *Database) commit(hash common.Hash, batch ethdb.Batch, uncacher *cleaner) error {
// If the node does not exist, it's a previously committed node
node, ok := db.dirties[hash]
if !ok {
return nil
}
var err error
// Dereference all children and delete the node
node.forChildren(db.resolver, func(child common.Hash) {
if err == nil {
err = db.commit(child, batch, uncacher)
}
})
if err != nil {
return err
}
// If we've reached an optimal batch size, commit and start over
rawdb.WriteLegacyTrieNode(batch, hash, node.node)
if batch.ValueSize() >= ethdb.IdealBatchSize {
if err := batch.Write(); err != nil {
return err
}
err := batch.Replay(uncacher)
if err != nil {
return err
}
batch.Reset()
}
return nil
}
// cleaner is a database batch replayer that takes a batch of write operations
// and cleans up the trie database from anything written to disk.
type cleaner struct {
db *Database
}
// Put reacts to database writes and implements dirty data uncaching. This is the
// post-processing step of a commit operation where the already persisted trie is
// removed from the dirty cache and moved into the clean cache. The reason behind
// the two-phase commit is to ensure data availability while moving from memory
// to disk.
func (c *cleaner) Put(key []byte, rlp []byte) error {
hash := common.BytesToHash(key)
// If the node does not exist, we're done on this path
node, ok := c.db.dirties[hash]
if !ok {
return nil
}
// Node still exists, remove it from the flush-list
switch hash {
case c.db.oldest:
c.db.oldest = node.flushNext
if node.flushNext != (common.Hash{}) {
c.db.dirties[node.flushNext].flushPrev = common.Hash{}
}
case c.db.newest:
c.db.newest = node.flushPrev
if node.flushPrev != (common.Hash{}) {
c.db.dirties[node.flushPrev].flushNext = common.Hash{}
}
default:
c.db.dirties[node.flushPrev].flushNext = node.flushNext
c.db.dirties[node.flushNext].flushPrev = node.flushPrev
}
// Remove the node from the dirty cache
delete(c.db.dirties, hash)
c.db.dirtiesSize -= common.StorageSize(common.HashLength + len(node.node))
if node.external != nil {
c.db.childrenSize -= common.StorageSize(len(node.external) * common.HashLength)
}
// Move the flushed node into the clean cache to prevent insta-reloads
if c.db.cleans != nil {
c.db.cleans.Set(hash[:], rlp)
memcacheCleanWriteMeter.Mark(int64(len(rlp)))
}
return nil
}
func (c *cleaner) Delete(key []byte) error {
panic("not implemented")
}
// Initialized returns an indicator if state data is already initialized
// in hash-based scheme by checking the presence of genesis state.
func (db *Database) Initialized(genesisRoot common.Hash) bool {
return rawdb.HasLegacyTrieNode(db.diskdb, genesisRoot)
}
// Update inserts the dirty nodes in provided nodeset into database and link the
// account trie with multiple storage tries if necessary.
func (db *Database) Update(root common.Hash, parent common.Hash, block uint64, nodes *trienode.MergedNodeSet, states *triestate.Set) error {
// Ensure the parent state is present and signal a warning if not.
if parent != types.EmptyRootHash {
if blob, _ := db.node(parent, internal.StateTrieCodec); len(blob) == 0 {
log.Error("parent state is not present")
}
}
db.lock.Lock()
defer db.lock.Unlock()
// Insert dirty nodes into the database. In the same tree, it must be
// ensured that children are inserted first, then parent so that children
// can be linked with their parent correctly.
//
// Note, the storage tries must be flushed before the account trie to
// retain the invariant that children go into the dirty cache first.
var order []common.Hash
for owner := range nodes.Sets {
if owner == (common.Hash{}) {
continue
}
order = append(order, owner)
}
if _, ok := nodes.Sets[common.Hash{}]; ok {
order = append(order, common.Hash{})
}
for _, owner := range order {
subset := nodes.Sets[owner]
subset.ForEachWithOrder(func(path string, n *trienode.Node) {
if n.IsDeleted() {
return // ignore deletion
}
db.insert(n.Hash, n.Blob)
})
}
// Link up the account trie and storage trie if the node points
// to an account trie leaf.
if set, present := nodes.Sets[common.Hash{}]; present {
for _, n := range set.Leaves {
var account types.StateAccount
if err := rlp.DecodeBytes(n.Blob, &account); err != nil {
return err
}
if account.Root != types.EmptyRootHash {
db.reference(account.Root, n.Parent)
}
}
}
return nil
}
// Size returns the current storage size of the memory cache in front of the
// persistent database layer.
//
// The first return will always be 0, representing the memory stored in unbounded
// diff layers above the dirty cache. This is only available in pathdb.
func (db *Database) Size() (common.StorageSize, common.StorageSize) {
db.lock.RLock()
defer db.lock.RUnlock()
// db.dirtiesSize only contains the useful data in the cache, but when reporting
// the total memory consumption, the maintenance metadata is also needed to be
// counted.
var metadataSize = common.StorageSize(len(db.dirties) * cachedNodeSize)
return 0, db.dirtiesSize + db.childrenSize + metadataSize
}
// Close closes the trie database and releases all held resources.
func (db *Database) Close() error {
if db.cleans != nil {
db.cleans.Reset()
db.cleans = nil
}
return nil
}
// Scheme returns the node scheme used in the database.
func (db *Database) Scheme() string {
return rawdb.HashScheme
}
// Reader retrieves a node reader belonging to the given state root.
// An error will be returned if the requested state is not available.
func (db *Database) Reader(root common.Hash) (*reader, error) {
if _, err := db.node(root, internal.StateTrieCodec); err != nil {
return nil, fmt.Errorf("state %#x is not available, %v", root, err)
}
return &reader{db: db}, nil
}
// reader is a state reader of Database which implements the Reader interface.
type reader struct {
db *Database
}
// Node retrieves the trie node with the given node hash. No error will be
// returned if the node is not found.
func (reader *reader) Node(owner common.Hash, path []byte, hash common.Hash) ([]byte, error) {
// this is an account node iff the owner hash is zero
codec := internal.StateTrieCodec
if owner != (common.Hash{}) {
codec = internal.StorageTrieCodec
}
blob, _ := reader.db.node(hash, codec)
return blob, nil
}

View File

@ -0,0 +1,485 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package pathdb
import (
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/params"
)
const (
// maxDiffLayers is the maximum diff layers allowed in the layer tree.
maxDiffLayers = 128
// defaultCleanSize is the default memory allowance of clean cache.
defaultCleanSize = 16 * 1024 * 1024
// maxBufferSize is the maximum memory allowance of node buffer.
// Too large nodebuffer will cause the system to pause for a long
// time when write happens. Also, the largest batch that pebble can
// support is 4GB, node will panic if batch size exceeds this limit.
maxBufferSize = 256 * 1024 * 1024
// DefaultBufferSize is the default memory allowance of node buffer
// that aggregates the writes from above until it's flushed into the
// disk. It's meant to be used once the initial sync is finished.
// Do not increase the buffer size arbitrarily, otherwise the system
// pause time will increase when the database writes happen.
DefaultBufferSize = 64 * 1024 * 1024
)
// layer is the interface implemented by all state layers which includes some
// public methods and some additional methods for internal usage.
type layer interface {
// Node retrieves the trie node with the node info. An error will be returned
// if the read operation exits abnormally. For example, if the layer is already
// stale, or the associated state is regarded as corrupted. Notably, no error
// will be returned if the requested node is not found in database.
Node(owner common.Hash, path []byte, hash common.Hash) ([]byte, error)
// rootHash returns the root hash for which this layer was made.
rootHash() common.Hash
// stateID returns the associated state id of layer.
stateID() uint64
// parentLayer returns the subsequent layer of it, or nil if the disk was reached.
parentLayer() layer
// update creates a new layer on top of the existing layer diff tree with
// the provided dirty trie nodes along with the state change set.
//
// Note, the maps are retained by the method to avoid copying everything.
update(root common.Hash, id uint64, block uint64, nodes map[common.Hash]map[string]*trienode.Node, states *triestate.Set) *diffLayer
// journal commits an entire diff hierarchy to disk into a single journal entry.
// This is meant to be used during shutdown to persist the layer without
// flattening everything down (bad for reorgs).
journal(w io.Writer) error
}
// Config contains the settings for database.
type Config struct {
StateHistory uint64 // Number of recent blocks to maintain state history for
CleanCacheSize int // Maximum memory allowance (in bytes) for caching clean nodes
DirtyCacheSize int // Maximum memory allowance (in bytes) for caching dirty nodes
ReadOnly bool // Flag whether the database is opened in read only mode.
}
// sanitize checks the provided user configurations and changes anything that's
// unreasonable or unworkable.
func (c *Config) sanitize() *Config {
conf := *c
if conf.DirtyCacheSize > maxBufferSize {
log.Warn("Sanitizing invalid node buffer size", "provided", common.StorageSize(conf.DirtyCacheSize), "updated", common.StorageSize(maxBufferSize))
conf.DirtyCacheSize = maxBufferSize
}
return &conf
}
// Defaults contains default settings for Ethereum mainnet.
var Defaults = &Config{
StateHistory: params.FullImmutabilityThreshold,
CleanCacheSize: defaultCleanSize,
DirtyCacheSize: DefaultBufferSize,
}
// ReadOnly is the config in order to open database in read only mode.
var ReadOnly = &Config{ReadOnly: true}
// Database is a multiple-layered structure for maintaining in-memory trie nodes.
// It consists of one persistent base layer backed by a key-value store, on top
// of which arbitrarily many in-memory diff layers are stacked. The memory diffs
// can form a tree with branching, but the disk layer is singleton and common to
// all. If a reorg goes deeper than the disk layer, a batch of reverse diffs can
// be applied to rollback. The deepest reorg that can be handled depends on the
// amount of state histories tracked in the disk.
//
// At most one readable and writable database can be opened at the same time in
// the whole system which ensures that only one database writer can operate disk
// state. Unexpected open operations can cause the system to panic.
type Database struct {
// readOnly is the flag whether the mutation is allowed to be applied.
// It will be set automatically when the database is journaled during
// the shutdown to reject all following unexpected mutations.
readOnly bool // Flag if database is opened in read only mode
waitSync bool // Flag if database is deactivated due to initial state sync
bufferSize int // Memory allowance (in bytes) for caching dirty nodes
config *Config // Configuration for database
diskdb ethdb.Database // Persistent storage for matured trie nodes
tree *layerTree // The group for all known layers
freezer *rawdb.ResettableFreezer // Freezer for storing trie histories, nil possible in tests
lock sync.RWMutex // Lock to prevent mutations from happening at the same time
}
// New attempts to load an already existing layer from a persistent key-value
// store (with a number of memory layers from a journal). If the journal is not
// matched with the base persistent layer, all the recorded diff layers are discarded.
func New(diskdb ethdb.Database, config *Config) *Database {
if config == nil {
config = Defaults
}
config = config.sanitize()
db := &Database{
readOnly: config.ReadOnly,
bufferSize: config.DirtyCacheSize,
config: config,
diskdb: diskdb,
}
// Construct the layer tree by resolving the in-disk singleton state
// and in-memory layer journal.
db.tree = newLayerTree(db.loadLayers())
// Open the freezer for state history if the passed database contains an
// ancient store. Otherwise, all the relevant functionalities are disabled.
//
// Because the freezer can only be opened once at the same time, this
// mechanism also ensures that at most one **non-readOnly** database
// is opened at the same time to prevent accidental mutation.
if ancient, err := diskdb.AncientDatadir(); err == nil && ancient != "" && !db.readOnly {
freezer, err := rawdb.NewStateFreezer(ancient, false)
if err != nil {
log.Crit("Failed to open state history freezer", "err", err)
}
db.freezer = freezer
diskLayerID := db.tree.bottom().stateID()
if diskLayerID == 0 {
// Reset the entire state histories in case the trie database is
// not initialized yet, as these state histories are not expected.
frozen, err := db.freezer.Ancients()
if err != nil {
log.Crit("Failed to retrieve head of state history", "err", err)
}
if frozen != 0 {
err := db.freezer.Reset()
if err != nil {
log.Crit("Failed to reset state histories", "err", err)
}
log.Info("Truncated extraneous state history")
}
} else {
// Truncate the extra state histories above in freezer in case
// it's not aligned with the disk layer.
pruned, err := truncateFromHead(db.diskdb, freezer, diskLayerID)
if err != nil {
log.Crit("Failed to truncate extra state histories", "err", err)
}
if pruned != 0 {
log.Warn("Truncated extra state histories", "number", pruned)
}
}
}
// Disable database in case node is still in the initial state sync stage.
if rawdb.ReadSnapSyncStatusFlag(diskdb) == rawdb.StateSyncRunning && !db.readOnly {
if err := db.Disable(); err != nil {
log.Crit("Failed to disable database", "err", err) // impossible to happen
}
}
log.Warn("Path-based state scheme is an experimental feature")
return db
}
// Reader retrieves a layer belonging to the given state root.
func (db *Database) Reader(root common.Hash) (layer, error) {
l := db.tree.get(root)
if l == nil {
return nil, fmt.Errorf("state %#x is not available", root)
}
return l, nil
}
// Update adds a new layer into the tree, if that can be linked to an existing
// old parent. It is disallowed to insert a disk layer (the origin of all). Apart
// from that this function will flatten the extra diff layers at bottom into disk
// to only keep 128 diff layers in memory by default.
//
// The passed in maps(nodes, states) will be retained to avoid copying everything.
// Therefore, these maps must not be changed afterwards.
func (db *Database) Update(root common.Hash, parentRoot common.Hash, block uint64, nodes *trienode.MergedNodeSet, states *triestate.Set) error {
// Hold the lock to prevent concurrent mutations.
db.lock.Lock()
defer db.lock.Unlock()
// Short circuit if the mutation is not allowed.
if err := db.modifyAllowed(); err != nil {
return err
}
if err := db.tree.add(root, parentRoot, block, nodes, states); err != nil {
return err
}
// Keep 128 diff layers in the memory, persistent layer is 129th.
// - head layer is paired with HEAD state
// - head-1 layer is paired with HEAD-1 state
// - head-127 layer(bottom-most diff layer) is paired with HEAD-127 state
// - head-128 layer(disk layer) is paired with HEAD-128 state
return db.tree.cap(root, maxDiffLayers)
}
// Commit traverses downwards the layer tree from a specified layer with the
// provided state root and all the layers below are flattened downwards. It
// can be used alone and mostly for test purposes.
func (db *Database) Commit(root common.Hash, report bool) error {
// Hold the lock to prevent concurrent mutations.
db.lock.Lock()
defer db.lock.Unlock()
// Short circuit if the mutation is not allowed.
if err := db.modifyAllowed(); err != nil {
return err
}
return db.tree.cap(root, 0)
}
// Disable deactivates the database and invalidates all available state layers
// as stale to prevent access to the persistent state, which is in the syncing
// stage.
func (db *Database) Disable() error {
db.lock.Lock()
defer db.lock.Unlock()
// Short circuit if the database is in read only mode.
if db.readOnly {
return errDatabaseReadOnly
}
// Prevent duplicated disable operation.
if db.waitSync {
log.Error("Reject duplicated disable operation")
return nil
}
db.waitSync = true
// Mark the disk layer as stale to prevent access to persistent state.
db.tree.bottom().markStale()
// Write the initial sync flag to persist it across restarts.
rawdb.WriteSnapSyncStatusFlag(db.diskdb, rawdb.StateSyncRunning)
log.Info("Disabled trie database due to state sync")
return nil
}
// Enable activates database and resets the state tree with the provided persistent
// state root once the state sync is finished.
func (db *Database) Enable(root common.Hash) error {
db.lock.Lock()
defer db.lock.Unlock()
// Short circuit if the database is in read only mode.
if db.readOnly {
return errDatabaseReadOnly
}
// Ensure the provided state root matches the stored one.
root = types.TrieRootHash(root)
_, stored := rawdb.ReadAccountTrieNode(db.diskdb, nil)
if stored != root {
return fmt.Errorf("state root mismatch: stored %x, synced %x", stored, root)
}
// Drop the stale state journal in persistent database and
// reset the persistent state id back to zero.
batch := db.diskdb.NewBatch()
rawdb.DeleteTrieJournal(batch)
rawdb.WritePersistentStateID(batch, 0)
if err := batch.Write(); err != nil {
return err
}
// Clean up all state histories in freezer. Theoretically
// all root->id mappings should be removed as well. Since
// mappings can be huge and might take a while to clear
// them, just leave them in disk and wait for overwriting.
if db.freezer != nil {
if err := db.freezer.Reset(); err != nil {
return err
}
}
// Re-construct a new disk layer backed by persistent state
// with **empty clean cache and node buffer**.
db.tree.reset(newDiskLayer(root, 0, db, nil, newNodeBuffer(db.bufferSize, nil, 0)))
// Re-enable the database as the final step.
db.waitSync = false
rawdb.WriteSnapSyncStatusFlag(db.diskdb, rawdb.StateSyncFinished)
log.Info("Rebuilt trie database", "root", root)
return nil
}
// Recover rollbacks the database to a specified historical point.
// The state is supported as the rollback destination only if it's
// canonical state and the corresponding trie histories are existent.
func (db *Database) Recover(root common.Hash, loader triestate.TrieLoader) error {
db.lock.Lock()
defer db.lock.Unlock()
// Short circuit if rollback operation is not supported.
if err := db.modifyAllowed(); err != nil {
return err
}
if db.freezer == nil {
return errors.New("state rollback is non-supported")
}
// Short circuit if the target state is not recoverable.
root = types.TrieRootHash(root)
if !db.Recoverable(root) {
return errStateUnrecoverable
}
// Apply the state histories upon the disk layer in order.
var (
start = time.Now()
dl = db.tree.bottom()
)
for dl.rootHash() != root {
h, err := readHistory(db.freezer, dl.stateID())
if err != nil {
return err
}
dl, err = dl.revert(h, loader)
if err != nil {
return err
}
// reset layer with newly created disk layer. It must be
// done after each revert operation, otherwise the new
// disk layer won't be accessible from outside.
db.tree.reset(dl)
}
rawdb.DeleteTrieJournal(db.diskdb)
_, err := truncateFromHead(db.diskdb, db.freezer, dl.stateID())
if err != nil {
return err
}
log.Debug("Recovered state", "root", root, "elapsed", common.PrettyDuration(time.Since(start)))
return nil
}
// Recoverable returns the indicator if the specified state is recoverable.
func (db *Database) Recoverable(root common.Hash) bool {
// Ensure the requested state is a known state.
root = types.TrieRootHash(root)
id := rawdb.ReadStateID(db.diskdb, root)
if id == nil {
return false
}
// Recoverable state must below the disk layer. The recoverable
// state only refers the state that is currently not available,
// but can be restored by applying state history.
dl := db.tree.bottom()
if *id >= dl.stateID() {
return false
}
// Ensure the requested state is a canonical state and all state
// histories in range [id+1, disklayer.ID] are present and complete.
parent := root
return checkHistories(db.freezer, *id+1, dl.stateID()-*id, func(m *meta) error {
if m.parent != parent {
return errors.New("unexpected state history")
}
if len(m.incomplete) > 0 {
return errors.New("incomplete state history")
}
parent = m.root
return nil
}) == nil
}
// Close closes the trie database and the held freezer.
func (db *Database) Close() error {
db.lock.Lock()
defer db.lock.Unlock()
// Set the database to read-only mode to prevent all
// following mutations.
db.readOnly = true
// Release the memory held by clean cache.
db.tree.bottom().resetCache()
// Close the attached state history freezer.
if db.freezer == nil {
return nil
}
return db.freezer.Close()
}
// Size returns the current storage size of the memory cache in front of the
// persistent database layer.
func (db *Database) Size() (diffs common.StorageSize, nodes common.StorageSize) {
db.tree.forEach(func(layer layer) {
if diff, ok := layer.(*diffLayer); ok {
diffs += common.StorageSize(diff.memory)
}
if disk, ok := layer.(*diskLayer); ok {
nodes += disk.size()
}
})
return diffs, nodes
}
// Initialized returns an indicator if the state data is already
// initialized in path-based scheme.
func (db *Database) Initialized(genesisRoot common.Hash) bool {
var inited bool
db.tree.forEach(func(layer layer) {
if layer.rootHash() != types.EmptyRootHash {
inited = true
}
})
if !inited {
inited = rawdb.ReadSnapSyncStatusFlag(db.diskdb) != rawdb.StateSyncUnknown
}
return inited
}
// SetBufferSize sets the node buffer size to the provided value(in bytes).
func (db *Database) SetBufferSize(size int) error {
db.lock.Lock()
defer db.lock.Unlock()
if size > maxBufferSize {
log.Info("Capped node buffer size", "provided", common.StorageSize(size), "adjusted", common.StorageSize(maxBufferSize))
size = maxBufferSize
}
db.bufferSize = size
return db.tree.bottom().setBufferSize(db.bufferSize)
}
// Scheme returns the node scheme used in the database.
func (db *Database) Scheme() string {
return rawdb.PathScheme
}
// modifyAllowed returns the indicator if mutation is allowed. This function
// assumes the db.lock is already held.
func (db *Database) modifyAllowed() error {
if db.readOnly {
return errDatabaseReadOnly
}
if db.waitSync {
return errDatabaseWaitSync
}
return nil
}

View File

@ -0,0 +1,608 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package pathdb
import (
"bytes"
"errors"
"fmt"
"math/rand"
"testing"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/testutil"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp"
"github.com/holiman/uint256"
)
func updateTrie(addrHash common.Hash, root common.Hash, dirties, cleans map[common.Hash][]byte) (common.Hash, *trienode.NodeSet) {
h, err := newTestHasher(addrHash, root, cleans)
if err != nil {
panic(fmt.Errorf("failed to create hasher, err: %w", err))
}
for key, val := range dirties {
if len(val) == 0 {
h.Delete(key.Bytes())
} else {
h.Update(key.Bytes(), val)
}
}
root, nodes, _ := h.Commit(false)
return root, nodes
}
func generateAccount(storageRoot common.Hash) types.StateAccount {
return types.StateAccount{
Nonce: uint64(rand.Intn(100)),
Balance: uint256.NewInt(rand.Uint64()),
CodeHash: testutil.RandBytes(32),
Root: storageRoot,
}
}
const (
createAccountOp int = iota
modifyAccountOp
deleteAccountOp
opLen
)
type genctx struct {
accounts map[common.Hash][]byte
storages map[common.Hash]map[common.Hash][]byte
accountOrigin map[common.Address][]byte
storageOrigin map[common.Address]map[common.Hash][]byte
nodes *trienode.MergedNodeSet
}
func newCtx() *genctx {
return &genctx{
accounts: make(map[common.Hash][]byte),
storages: make(map[common.Hash]map[common.Hash][]byte),
accountOrigin: make(map[common.Address][]byte),
storageOrigin: make(map[common.Address]map[common.Hash][]byte),
nodes: trienode.NewMergedNodeSet(),
}
}
type tester struct {
db *Database
roots []common.Hash
preimages map[common.Hash]common.Address
accounts map[common.Hash][]byte
storages map[common.Hash]map[common.Hash][]byte
// state snapshots
snapAccounts map[common.Hash]map[common.Hash][]byte
snapStorages map[common.Hash]map[common.Hash]map[common.Hash][]byte
}
func newTester(t *testing.T, historyLimit uint64) *tester {
var (
disk, _ = rawdb.NewDatabaseWithFreezer(rawdb.NewMemoryDatabase(), t.TempDir(), "", false)
db = New(disk, &Config{
StateHistory: historyLimit,
CleanCacheSize: 256 * 1024,
DirtyCacheSize: 256 * 1024,
})
obj = &tester{
db: db,
preimages: make(map[common.Hash]common.Address),
accounts: make(map[common.Hash][]byte),
storages: make(map[common.Hash]map[common.Hash][]byte),
snapAccounts: make(map[common.Hash]map[common.Hash][]byte),
snapStorages: make(map[common.Hash]map[common.Hash]map[common.Hash][]byte),
}
)
for i := 0; i < 2*128; i++ {
var parent = types.EmptyRootHash
if len(obj.roots) != 0 {
parent = obj.roots[len(obj.roots)-1]
}
root, nodes, states := obj.generate(parent)
if err := db.Update(root, parent, uint64(i), nodes, states); err != nil {
panic(fmt.Errorf("failed to update state changes, err: %w", err))
}
obj.roots = append(obj.roots, root)
}
return obj
}
func (t *tester) release() {
t.db.Close()
t.db.diskdb.Close()
}
func (t *tester) randAccount() (common.Address, []byte) {
for addrHash, account := range t.accounts {
return t.preimages[addrHash], account
}
return common.Address{}, nil
}
func (t *tester) generateStorage(ctx *genctx, addr common.Address) common.Hash {
var (
addrHash = crypto.Keccak256Hash(addr.Bytes())
storage = make(map[common.Hash][]byte)
origin = make(map[common.Hash][]byte)
)
for i := 0; i < 10; i++ {
v, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(testutil.RandBytes(32)))
hash := testutil.RandomHash()
storage[hash] = v
origin[hash] = nil
}
root, set := updateTrie(addrHash, types.EmptyRootHash, storage, nil)
ctx.storages[addrHash] = storage
ctx.storageOrigin[addr] = origin
ctx.nodes.Merge(set)
return root
}
func (t *tester) mutateStorage(ctx *genctx, addr common.Address, root common.Hash) common.Hash {
var (
addrHash = crypto.Keccak256Hash(addr.Bytes())
storage = make(map[common.Hash][]byte)
origin = make(map[common.Hash][]byte)
)
for hash, val := range t.storages[addrHash] {
origin[hash] = val
storage[hash] = nil
if len(origin) == 3 {
break
}
}
for i := 0; i < 3; i++ {
v, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(testutil.RandBytes(32)))
hash := testutil.RandomHash()
storage[hash] = v
origin[hash] = nil
}
root, set := updateTrie(crypto.Keccak256Hash(addr.Bytes()), root, storage, t.storages[addrHash])
ctx.storages[addrHash] = storage
ctx.storageOrigin[addr] = origin
ctx.nodes.Merge(set)
return root
}
func (t *tester) clearStorage(ctx *genctx, addr common.Address, root common.Hash) common.Hash {
var (
addrHash = crypto.Keccak256Hash(addr.Bytes())
storage = make(map[common.Hash][]byte)
origin = make(map[common.Hash][]byte)
)
for hash, val := range t.storages[addrHash] {
origin[hash] = val
storage[hash] = nil
}
root, set := updateTrie(addrHash, root, storage, t.storages[addrHash])
if root != types.EmptyRootHash {
panic("failed to clear storage trie")
}
ctx.storages[addrHash] = storage
ctx.storageOrigin[addr] = origin
ctx.nodes.Merge(set)
return root
}
func (t *tester) generate(parent common.Hash) (common.Hash, *trienode.MergedNodeSet, *triestate.Set) {
var (
ctx = newCtx()
dirties = make(map[common.Hash]struct{})
)
for i := 0; i < 20; i++ {
switch rand.Intn(opLen) {
case createAccountOp:
// account creation
addr := testutil.RandomAddress()
addrHash := crypto.Keccak256Hash(addr.Bytes())
if _, ok := t.accounts[addrHash]; ok {
continue
}
if _, ok := dirties[addrHash]; ok {
continue
}
dirties[addrHash] = struct{}{}
root := t.generateStorage(ctx, addr)
ctx.accounts[addrHash] = types.SlimAccountRLP(generateAccount(root))
ctx.accountOrigin[addr] = nil
t.preimages[addrHash] = addr
case modifyAccountOp:
// account mutation
addr, account := t.randAccount()
if addr == (common.Address{}) {
continue
}
addrHash := crypto.Keccak256Hash(addr.Bytes())
if _, ok := dirties[addrHash]; ok {
continue
}
dirties[addrHash] = struct{}{}
acct, _ := types.FullAccount(account)
stRoot := t.mutateStorage(ctx, addr, acct.Root)
newAccount := types.SlimAccountRLP(generateAccount(stRoot))
ctx.accounts[addrHash] = newAccount
ctx.accountOrigin[addr] = account
case deleteAccountOp:
// account deletion
addr, account := t.randAccount()
if addr == (common.Address{}) {
continue
}
addrHash := crypto.Keccak256Hash(addr.Bytes())
if _, ok := dirties[addrHash]; ok {
continue
}
dirties[addrHash] = struct{}{}
acct, _ := types.FullAccount(account)
if acct.Root != types.EmptyRootHash {
t.clearStorage(ctx, addr, acct.Root)
}
ctx.accounts[addrHash] = nil
ctx.accountOrigin[addr] = account
}
}
root, set := updateTrie(common.Hash{}, parent, ctx.accounts, t.accounts)
ctx.nodes.Merge(set)
// Save state snapshot before commit
t.snapAccounts[parent] = copyAccounts(t.accounts)
t.snapStorages[parent] = copyStorages(t.storages)
// Commit all changes to live state set
for addrHash, account := range ctx.accounts {
if len(account) == 0 {
delete(t.accounts, addrHash)
} else {
t.accounts[addrHash] = account
}
}
for addrHash, slots := range ctx.storages {
if _, ok := t.storages[addrHash]; !ok {
t.storages[addrHash] = make(map[common.Hash][]byte)
}
for sHash, slot := range slots {
if len(slot) == 0 {
delete(t.storages[addrHash], sHash)
} else {
t.storages[addrHash][sHash] = slot
}
}
}
return root, ctx.nodes, triestate.New(ctx.accountOrigin, ctx.storageOrigin, nil)
}
// lastRoot returns the latest root hash, or empty if nothing is cached.
func (t *tester) lastHash() common.Hash {
if len(t.roots) == 0 {
return common.Hash{}
}
return t.roots[len(t.roots)-1]
}
func (t *tester) verifyState(root common.Hash) error {
reader, err := t.db.Reader(root)
if err != nil {
return err
}
_, err = reader.Node(common.Hash{}, nil, root)
if err != nil {
return errors.New("root node is not available")
}
for addrHash, account := range t.snapAccounts[root] {
blob, err := reader.Node(common.Hash{}, addrHash.Bytes(), crypto.Keccak256Hash(account))
if err != nil || !bytes.Equal(blob, account) {
return fmt.Errorf("account is mismatched: %w", err)
}
}
for addrHash, slots := range t.snapStorages[root] {
for hash, slot := range slots {
blob, err := reader.Node(addrHash, hash.Bytes(), crypto.Keccak256Hash(slot))
if err != nil || !bytes.Equal(blob, slot) {
return fmt.Errorf("slot is mismatched: %w", err)
}
}
}
return nil
}
func (t *tester) verifyHistory() error {
bottom := t.bottomIndex()
for i, root := range t.roots {
// The state history related to the state above disk layer should not exist.
if i > bottom {
_, err := readHistory(t.db.freezer, uint64(i+1))
if err == nil {
return errors.New("unexpected state history")
}
continue
}
// The state history related to the state below or equal to the disk layer
// should exist.
obj, err := readHistory(t.db.freezer, uint64(i+1))
if err != nil {
return err
}
parent := types.EmptyRootHash
if i != 0 {
parent = t.roots[i-1]
}
if obj.meta.parent != parent {
return fmt.Errorf("unexpected parent, want: %x, got: %x", parent, obj.meta.parent)
}
if obj.meta.root != root {
return fmt.Errorf("unexpected root, want: %x, got: %x", root, obj.meta.root)
}
}
return nil
}
// bottomIndex returns the index of current disk layer.
func (t *tester) bottomIndex() int {
bottom := t.db.tree.bottom()
for i := 0; i < len(t.roots); i++ {
if t.roots[i] == bottom.rootHash() {
return i
}
}
return -1
}
func TestDatabaseRollback(t *testing.T) {
// Verify state histories
tester := newTester(t, 0)
defer tester.release()
if err := tester.verifyHistory(); err != nil {
t.Fatalf("Invalid state history, err: %v", err)
}
// Revert database from top to bottom
for i := tester.bottomIndex(); i >= 0; i-- {
root := tester.roots[i]
parent := types.EmptyRootHash
if i > 0 {
parent = tester.roots[i-1]
}
loader := newHashLoader(tester.snapAccounts[root], tester.snapStorages[root])
if err := tester.db.Recover(parent, loader); err != nil {
t.Fatalf("Failed to revert db, err: %v", err)
}
tester.verifyState(parent)
}
if tester.db.tree.len() != 1 {
t.Fatal("Only disk layer is expected")
}
}
func TestDatabaseRecoverable(t *testing.T) {
var (
tester = newTester(t, 0)
index = tester.bottomIndex()
)
defer tester.release()
var cases = []struct {
root common.Hash
expect bool
}{
// Unknown state should be unrecoverable
{common.Hash{0x1}, false},
// Initial state should be recoverable
{types.EmptyRootHash, true},
// Initial state should be recoverable
{common.Hash{}, true},
// Layers below current disk layer are recoverable
{tester.roots[index-1], true},
// Disklayer itself is not recoverable, since it's
// available for accessing.
{tester.roots[index], false},
// Layers above current disk layer are not recoverable
// since they are available for accessing.
{tester.roots[index+1], false},
}
for i, c := range cases {
result := tester.db.Recoverable(c.root)
if result != c.expect {
t.Fatalf("case: %d, unexpected result, want %t, got %t", i, c.expect, result)
}
}
}
func TestDisable(t *testing.T) {
tester := newTester(t, 0)
defer tester.release()
_, stored := rawdb.ReadAccountTrieNode(tester.db.diskdb, nil)
if err := tester.db.Disable(); err != nil {
t.Fatal("Failed to deactivate database")
}
if err := tester.db.Enable(types.EmptyRootHash); err == nil {
t.Fatalf("Invalid activation should be rejected")
}
if err := tester.db.Enable(stored); err != nil {
t.Fatal("Failed to activate database")
}
// Ensure journal is deleted from disk
if blob := rawdb.ReadTrieJournal(tester.db.diskdb); len(blob) != 0 {
t.Fatal("Failed to clean journal")
}
// Ensure all trie histories are removed
n, err := tester.db.freezer.Ancients()
if err != nil {
t.Fatal("Failed to clean state history")
}
if n != 0 {
t.Fatal("Failed to clean state history")
}
// Verify layer tree structure, single disk layer is expected
if tester.db.tree.len() != 1 {
t.Fatalf("Extra layer kept %d", tester.db.tree.len())
}
if tester.db.tree.bottom().rootHash() != stored {
t.Fatalf("Root hash is not matched exp %x got %x", stored, tester.db.tree.bottom().rootHash())
}
}
func TestCommit(t *testing.T) {
tester := newTester(t, 0)
defer tester.release()
if err := tester.db.Commit(tester.lastHash(), false); err != nil {
t.Fatalf("Failed to cap database, err: %v", err)
}
// Verify layer tree structure, single disk layer is expected
if tester.db.tree.len() != 1 {
t.Fatal("Layer tree structure is invalid")
}
if tester.db.tree.bottom().rootHash() != tester.lastHash() {
t.Fatal("Layer tree structure is invalid")
}
// Verify states
if err := tester.verifyState(tester.lastHash()); err != nil {
t.Fatalf("State is invalid, err: %v", err)
}
// Verify state histories
if err := tester.verifyHistory(); err != nil {
t.Fatalf("State history is invalid, err: %v", err)
}
}
func TestJournal(t *testing.T) {
tester := newTester(t, 0)
defer tester.release()
if err := tester.db.Journal(tester.lastHash()); err != nil {
t.Errorf("Failed to journal, err: %v", err)
}
tester.db.Close()
tester.db = New(tester.db.diskdb, nil)
// Verify states including disk layer and all diff on top.
for i := 0; i < len(tester.roots); i++ {
if i >= tester.bottomIndex() {
if err := tester.verifyState(tester.roots[i]); err != nil {
t.Fatalf("Invalid state, err: %v", err)
}
continue
}
if err := tester.verifyState(tester.roots[i]); err == nil {
t.Fatal("Unexpected state")
}
}
}
func TestCorruptedJournal(t *testing.T) {
tester := newTester(t, 0)
defer tester.release()
if err := tester.db.Journal(tester.lastHash()); err != nil {
t.Errorf("Failed to journal, err: %v", err)
}
tester.db.Close()
_, root := rawdb.ReadAccountTrieNode(tester.db.diskdb, nil)
// Mutate the journal in disk, it should be regarded as invalid
blob := rawdb.ReadTrieJournal(tester.db.diskdb)
blob[0] = 1
rawdb.WriteTrieJournal(tester.db.diskdb, blob)
// Verify states, all not-yet-written states should be discarded
tester.db = New(tester.db.diskdb, nil)
for i := 0; i < len(tester.roots); i++ {
if tester.roots[i] == root {
if err := tester.verifyState(root); err != nil {
t.Fatalf("Disk state is corrupted, err: %v", err)
}
continue
}
if err := tester.verifyState(tester.roots[i]); err == nil {
t.Fatal("Unexpected state")
}
}
}
// TestTailTruncateHistory function is designed to test a specific edge case where,
// when history objects are removed from the end, it should trigger a state flush
// if the ID of the new tail object is even higher than the persisted state ID.
//
// For example, let's say the ID of the persistent state is 10, and the current
// history objects range from ID(5) to ID(15). As we accumulate six more objects,
// the history will expand to cover ID(11) to ID(21). ID(11) then becomes the
// oldest history object, and its ID is even higher than the stored state.
//
// In this scenario, it is mandatory to update the persistent state before
// truncating the tail histories. This ensures that the ID of the persistent state
// always falls within the range of [oldest-history-id, latest-history-id].
func TestTailTruncateHistory(t *testing.T) {
tester := newTester(t, 10)
defer tester.release()
tester.db.Close()
tester.db = New(tester.db.diskdb, &Config{StateHistory: 10})
head, err := tester.db.freezer.Ancients()
if err != nil {
t.Fatalf("Failed to obtain freezer head")
}
stored := rawdb.ReadPersistentStateID(tester.db.diskdb)
if head != stored {
t.Fatalf("Failed to truncate excess history object above, stored: %d, head: %d", stored, head)
}
}
// copyAccounts returns a deep-copied account set of the provided one.
func copyAccounts(set map[common.Hash][]byte) map[common.Hash][]byte {
copied := make(map[common.Hash][]byte, len(set))
for key, val := range set {
copied[key] = common.CopyBytes(val)
}
return copied
}
// copyStorages returns a deep-copied storage set of the provided one.
func copyStorages(set map[common.Hash]map[common.Hash][]byte) map[common.Hash]map[common.Hash][]byte {
copied := make(map[common.Hash]map[common.Hash][]byte, len(set))
for addrHash, subset := range set {
copied[addrHash] = make(map[common.Hash][]byte, len(subset))
for key, val := range subset {
copied[addrHash][key] = common.CopyBytes(val)
}
}
return copied
}

View File

@ -0,0 +1,174 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package pathdb
import (
"fmt"
"sync"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/log"
)
// diffLayer represents a collection of modifications made to the in-memory tries
// along with associated state changes after running a block on top.
//
// The goal of a diff layer is to act as a journal, tracking recent modifications
// made to the state, that have not yet graduated into a semi-immutable state.
type diffLayer struct {
// Immutables
root common.Hash // Root hash to which this layer diff belongs to
id uint64 // Corresponding state id
block uint64 // Associated block number
nodes map[common.Hash]map[string]*trienode.Node // Cached trie nodes indexed by owner and path
states *triestate.Set // Associated state change set for building history
memory uint64 // Approximate guess as to how much memory we use
parent layer // Parent layer modified by this one, never nil, **can be changed**
lock sync.RWMutex // Lock used to protect parent
}
// newDiffLayer creates a new diff layer on top of an existing layer.
func newDiffLayer(parent layer, root common.Hash, id uint64, block uint64, nodes map[common.Hash]map[string]*trienode.Node, states *triestate.Set) *diffLayer {
var (
size int64
count int
)
dl := &diffLayer{
root: root,
id: id,
block: block,
nodes: nodes,
states: states,
parent: parent,
}
for _, subset := range nodes {
for path, n := range subset {
dl.memory += uint64(n.Size() + len(path))
size += int64(len(n.Blob) + len(path))
}
count += len(subset)
}
if states != nil {
dl.memory += uint64(states.Size())
}
dirtyWriteMeter.Mark(size)
diffLayerNodesMeter.Mark(int64(count))
diffLayerBytesMeter.Mark(int64(dl.memory))
log.Debug("Created new diff layer", "id", id, "block", block, "nodes", count, "size", common.StorageSize(dl.memory))
return dl
}
// rootHash implements the layer interface, returning the root hash of
// corresponding state.
func (dl *diffLayer) rootHash() common.Hash {
return dl.root
}
// stateID implements the layer interface, returning the state id of the layer.
func (dl *diffLayer) stateID() uint64 {
return dl.id
}
// parentLayer implements the layer interface, returning the subsequent
// layer of the diff layer.
func (dl *diffLayer) parentLayer() layer {
dl.lock.RLock()
defer dl.lock.RUnlock()
return dl.parent
}
// node retrieves the node with provided node information. It's the internal
// version of Node function with additional accessed layer tracked. No error
// will be returned if node is not found.
func (dl *diffLayer) node(owner common.Hash, path []byte, hash common.Hash, depth int) ([]byte, error) {
// Hold the lock, ensure the parent won't be changed during the
// state accessing.
dl.lock.RLock()
defer dl.lock.RUnlock()
// If the trie node is known locally, return it
subset, ok := dl.nodes[owner]
if ok {
n, ok := subset[string(path)]
if ok {
// If the trie node is not hash matched, or marked as removed,
// bubble up an error here. It shouldn't happen at all.
if n.Hash != hash {
dirtyFalseMeter.Mark(1)
log.Error("Unexpected trie node in diff layer", "owner", owner, "path", path, "expect", hash, "got", n.Hash)
return nil, newUnexpectedNodeError("diff", hash, n.Hash, owner, path, n.Blob)
}
dirtyHitMeter.Mark(1)
dirtyNodeHitDepthHist.Update(int64(depth))
dirtyReadMeter.Mark(int64(len(n.Blob)))
return n.Blob, nil
}
}
// Trie node unknown to this layer, resolve from parent
if diff, ok := dl.parent.(*diffLayer); ok {
return diff.node(owner, path, hash, depth+1)
}
// Failed to resolve through diff layers, fallback to disk layer
return dl.parent.Node(owner, path, hash)
}
// Node implements the layer interface, retrieving the trie node blob with the
// provided node information. No error will be returned if the node is not found.
func (dl *diffLayer) Node(owner common.Hash, path []byte, hash common.Hash) ([]byte, error) {
return dl.node(owner, path, hash, 0)
}
// update implements the layer interface, creating a new layer on top of the
// existing layer tree with the specified data items.
func (dl *diffLayer) update(root common.Hash, id uint64, block uint64, nodes map[common.Hash]map[string]*trienode.Node, states *triestate.Set) *diffLayer {
return newDiffLayer(dl, root, id, block, nodes, states)
}
// persist flushes the diff layer and all its parent layers to disk layer.
func (dl *diffLayer) persist(force bool) (layer, error) {
if parent, ok := dl.parentLayer().(*diffLayer); ok {
// Hold the lock to prevent any read operation until the new
// parent is linked correctly.
dl.lock.Lock()
// The merging of diff layers starts at the bottom-most layer,
// therefore we recurse down here, flattening on the way up
// (diffToDisk).
result, err := parent.persist(force)
if err != nil {
dl.lock.Unlock()
return nil, err
}
dl.parent = result
dl.lock.Unlock()
}
return diffToDisk(dl, force)
}
// diffToDisk merges a bottom-most diff into the persistent disk layer underneath
// it. The method will panic if called onto a non-bottom-most diff layer.
func diffToDisk(layer *diffLayer, force bool) (layer, error) {
disk, ok := layer.parentLayer().(*diskLayer)
if !ok {
panic(fmt.Sprintf("unknown layer type: %T", layer.parentLayer()))
}
return disk.commit(layer, force)
}

View File

@ -0,0 +1,170 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package pathdb
import (
"bytes"
"testing"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/testutil"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
)
func emptyLayer() *diskLayer {
return &diskLayer{
db: New(rawdb.NewMemoryDatabase(), nil),
buffer: newNodeBuffer(DefaultBufferSize, nil, 0),
}
}
// goos: darwin
// goarch: arm64
// pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
// BenchmarkSearch128Layers
// BenchmarkSearch128Layers-8 243826 4755 ns/op
func BenchmarkSearch128Layers(b *testing.B) { benchmarkSearch(b, 0, 128) }
// goos: darwin
// goarch: arm64
// pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
// BenchmarkSearch512Layers
// BenchmarkSearch512Layers-8 49686 24256 ns/op
func BenchmarkSearch512Layers(b *testing.B) { benchmarkSearch(b, 0, 512) }
// goos: darwin
// goarch: arm64
// pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
// BenchmarkSearch1Layer
// BenchmarkSearch1Layer-8 14062725 88.40 ns/op
func BenchmarkSearch1Layer(b *testing.B) { benchmarkSearch(b, 127, 128) }
func benchmarkSearch(b *testing.B, depth int, total int) {
var (
npath []byte
nhash common.Hash
nblob []byte
)
// First, we set up 128 diff layers, with 3K items each
fill := func(parent layer, index int) *diffLayer {
nodes := make(map[common.Hash]map[string]*trienode.Node)
nodes[common.Hash{}] = make(map[string]*trienode.Node)
for i := 0; i < 3000; i++ {
var (
path = testutil.RandBytes(32)
node = testutil.RandomNode()
)
nodes[common.Hash{}][string(path)] = trienode.New(node.Hash, node.Blob)
if npath == nil && depth == index {
npath = common.CopyBytes(path)
nblob = common.CopyBytes(node.Blob)
nhash = node.Hash
}
}
return newDiffLayer(parent, common.Hash{}, 0, 0, nodes, nil)
}
var layer layer
layer = emptyLayer()
for i := 0; i < total; i++ {
layer = fill(layer, i)
}
b.ResetTimer()
var (
have []byte
err error
)
for i := 0; i < b.N; i++ {
have, err = layer.Node(common.Hash{}, npath, nhash)
if err != nil {
b.Fatal(err)
}
}
if !bytes.Equal(have, nblob) {
b.Fatalf("have %x want %x", have, nblob)
}
}
// goos: darwin
// goarch: arm64
// pkg: github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie
// BenchmarkPersist
// BenchmarkPersist-8 10 111252975 ns/op
func BenchmarkPersist(b *testing.B) {
// First, we set up 128 diff layers, with 3K items each
fill := func(parent layer) *diffLayer {
nodes := make(map[common.Hash]map[string]*trienode.Node)
nodes[common.Hash{}] = make(map[string]*trienode.Node)
for i := 0; i < 3000; i++ {
var (
path = testutil.RandBytes(32)
node = testutil.RandomNode()
)
nodes[common.Hash{}][string(path)] = trienode.New(node.Hash, node.Blob)
}
return newDiffLayer(parent, common.Hash{}, 0, 0, nodes, nil)
}
for i := 0; i < b.N; i++ {
b.StopTimer()
var layer layer
layer = emptyLayer()
for i := 1; i < 128; i++ {
layer = fill(layer)
}
b.StartTimer()
dl, ok := layer.(*diffLayer)
if !ok {
break
}
dl.persist(false)
}
}
// BenchmarkJournal benchmarks the performance for journaling the layers.
//
// BenchmarkJournal
// BenchmarkJournal-8 10 110969279 ns/op
func BenchmarkJournal(b *testing.B) {
b.SkipNow()
// First, we set up 128 diff layers, with 3K items each
fill := func(parent layer) *diffLayer {
nodes := make(map[common.Hash]map[string]*trienode.Node)
nodes[common.Hash{}] = make(map[string]*trienode.Node)
for i := 0; i < 3000; i++ {
var (
path = testutil.RandBytes(32)
node = testutil.RandomNode()
)
nodes[common.Hash{}][string(path)] = trienode.New(node.Hash, node.Blob)
}
// TODO(rjl493456442) a non-nil state set is expected.
return newDiffLayer(parent, common.Hash{}, 0, 0, nodes, nil)
}
var layer layer
layer = emptyLayer()
for i := 0; i < 128; i++ {
layer = fill(layer)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
layer.journal(new(bytes.Buffer))
}
}

View File

@ -0,0 +1,338 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package pathdb
import (
"errors"
"fmt"
"sync"
"github.com/VictoriaMetrics/fastcache"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
"golang.org/x/crypto/sha3"
)
// diskLayer is a low level persistent layer built on top of a key-value store.
type diskLayer struct {
root common.Hash // Immutable, root hash to which this layer was made for
id uint64 // Immutable, corresponding state id
db *Database // Path-based trie database
cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs
buffer *nodebuffer // Node buffer to aggregate writes
stale bool // Signals that the layer became stale (state progressed)
lock sync.RWMutex // Lock used to protect stale flag
}
// newDiskLayer creates a new disk layer based on the passing arguments.
func newDiskLayer(root common.Hash, id uint64, db *Database, cleans *fastcache.Cache, buffer *nodebuffer) *diskLayer {
// Initialize a clean cache if the memory allowance is not zero
// or reuse the provided cache if it is not nil (inherited from
// the original disk layer).
if cleans == nil && db.config.CleanCacheSize != 0 {
cleans = fastcache.New(db.config.CleanCacheSize)
}
return &diskLayer{
root: root,
id: id,
db: db,
cleans: cleans,
buffer: buffer,
}
}
// root implements the layer interface, returning root hash of corresponding state.
func (dl *diskLayer) rootHash() common.Hash {
return dl.root
}
// stateID implements the layer interface, returning the state id of disk layer.
func (dl *diskLayer) stateID() uint64 {
return dl.id
}
// parent implements the layer interface, returning nil as there's no layer
// below the disk.
func (dl *diskLayer) parentLayer() layer {
return nil
}
// isStale return whether this layer has become stale (was flattened across) or if
// it's still live.
func (dl *diskLayer) isStale() bool {
dl.lock.RLock()
defer dl.lock.RUnlock()
return dl.stale
}
// markStale sets the stale flag as true.
func (dl *diskLayer) markStale() {
dl.lock.Lock()
defer dl.lock.Unlock()
if dl.stale {
panic("triedb disk layer is stale") // we've committed into the same base from two children, boom
}
dl.stale = true
}
// Node implements the layer interface, retrieving the trie node with the
// provided node info. No error will be returned if the node is not found.
func (dl *diskLayer) Node(owner common.Hash, path []byte, hash common.Hash) ([]byte, error) {
dl.lock.RLock()
defer dl.lock.RUnlock()
if dl.stale {
return nil, errSnapshotStale
}
// Try to retrieve the trie node from the not-yet-written
// node buffer first. Note the buffer is lock free since
// it's impossible to mutate the buffer before tagging the
// layer as stale.
n, err := dl.buffer.node(owner, path, hash)
if err != nil {
return nil, err
}
if n != nil {
dirtyHitMeter.Mark(1)
dirtyReadMeter.Mark(int64(len(n.Blob)))
return n.Blob, nil
}
dirtyMissMeter.Mark(1)
// Try to retrieve the trie node from the clean memory cache
key := cacheKey(owner, path)
if dl.cleans != nil {
if blob := dl.cleans.Get(nil, key); len(blob) > 0 {
h := newHasher()
defer h.release()
got := h.hash(blob)
if got == hash {
cleanHitMeter.Mark(1)
cleanReadMeter.Mark(int64(len(blob)))
return blob, nil
}
cleanFalseMeter.Mark(1)
log.Error("Unexpected trie node in clean cache", "owner", owner, "path", path, "expect", hash, "got", got)
}
cleanMissMeter.Mark(1)
}
// Try to retrieve the trie node from the disk.
var (
nBlob []byte
nHash common.Hash
)
if owner == (common.Hash{}) {
nBlob, nHash = rawdb.ReadAccountTrieNode(dl.db.diskdb, path)
} else {
nBlob, nHash = rawdb.ReadStorageTrieNode(dl.db.diskdb, owner, path)
}
if nHash != hash {
diskFalseMeter.Mark(1)
log.Error("Unexpected trie node in disk", "owner", owner, "path", path, "expect", hash, "got", nHash)
return nil, newUnexpectedNodeError("disk", hash, nHash, owner, path, nBlob)
}
if dl.cleans != nil && len(nBlob) > 0 {
dl.cleans.Set(key, nBlob)
cleanWriteMeter.Mark(int64(len(nBlob)))
}
return nBlob, nil
}
// update implements the layer interface, returning a new diff layer on top
// with the given state set.
func (dl *diskLayer) update(root common.Hash, id uint64, block uint64, nodes map[common.Hash]map[string]*trienode.Node, states *triestate.Set) *diffLayer {
return newDiffLayer(dl, root, id, block, nodes, states)
}
// commit merges the given bottom-most diff layer into the node buffer
// and returns a newly constructed disk layer. Note the current disk
// layer must be tagged as stale first to prevent re-access.
func (dl *diskLayer) commit(bottom *diffLayer, force bool) (*diskLayer, error) {
dl.lock.Lock()
defer dl.lock.Unlock()
// Construct and store the state history first. If crash happens after storing
// the state history but without flushing the corresponding states(journal),
// the stored state history will be truncated from head in the next restart.
var (
overflow bool
oldest uint64
)
if dl.db.freezer != nil {
err := writeHistory(dl.db.freezer, bottom)
if err != nil {
return nil, err
}
// Determine if the persisted history object has exceeded the configured
// limitation, set the overflow as true if so.
tail, err := dl.db.freezer.Tail()
if err != nil {
return nil, err
}
limit := dl.db.config.StateHistory
if limit != 0 && bottom.stateID()-tail > limit {
overflow = true
oldest = bottom.stateID() - limit + 1 // track the id of history **after truncation**
}
}
// Mark the diskLayer as stale before applying any mutations on top.
dl.stale = true
// Store the root->id lookup afterwards. All stored lookups are identified
// by the **unique** state root. It's impossible that in the same chain
// blocks are not adjacent but have the same root.
if dl.id == 0 {
rawdb.WriteStateID(dl.db.diskdb, dl.root, 0)
}
rawdb.WriteStateID(dl.db.diskdb, bottom.rootHash(), bottom.stateID())
// Construct a new disk layer by merging the nodes from the provided diff
// layer, and flush the content in disk layer if there are too many nodes
// cached. The clean cache is inherited from the original disk layer.
ndl := newDiskLayer(bottom.root, bottom.stateID(), dl.db, dl.cleans, dl.buffer.commit(bottom.nodes))
// In a unique scenario where the ID of the oldest history object (after tail
// truncation) surpasses the persisted state ID, we take the necessary action
// of forcibly committing the cached dirty nodes to ensure that the persisted
// state ID remains higher.
if !force && rawdb.ReadPersistentStateID(dl.db.diskdb) < oldest {
force = true
}
if err := ndl.buffer.flush(ndl.db.diskdb, ndl.cleans, ndl.id, force); err != nil {
return nil, err
}
// To remove outdated history objects from the end, we set the 'tail' parameter
// to 'oldest-1' due to the offset between the freezer index and the history ID.
if overflow {
pruned, err := truncateFromTail(ndl.db.diskdb, ndl.db.freezer, oldest-1)
if err != nil {
return nil, err
}
log.Debug("Pruned state history", "items", pruned, "tailid", oldest)
}
return ndl, nil
}
// revert applies the given state history and return a reverted disk layer.
func (dl *diskLayer) revert(h *history, loader triestate.TrieLoader) (*diskLayer, error) {
if h.meta.root != dl.rootHash() {
return nil, errUnexpectedHistory
}
// Reject if the provided state history is incomplete. It's due to
// a large construct SELF-DESTRUCT which can't be handled because
// of memory limitation.
if len(h.meta.incomplete) > 0 {
return nil, errors.New("incomplete state history")
}
if dl.id == 0 {
return nil, fmt.Errorf("%w: zero state id", errStateUnrecoverable)
}
// Apply the reverse state changes upon the current state. This must
// be done before holding the lock in order to access state in "this"
// layer.
nodes, err := triestate.Apply(h.meta.parent, h.meta.root, h.accounts, h.storages, loader)
if err != nil {
return nil, err
}
// Mark the diskLayer as stale before applying any mutations on top.
dl.lock.Lock()
defer dl.lock.Unlock()
dl.stale = true
// State change may be applied to node buffer, or the persistent
// state, depends on if node buffer is empty or not. If the node
// buffer is not empty, it means that the state transition that
// needs to be reverted is not yet flushed and cached in node
// buffer, otherwise, manipulate persistent state directly.
if !dl.buffer.empty() {
err := dl.buffer.revert(dl.db.diskdb, nodes)
if err != nil {
return nil, err
}
} else {
batch := dl.db.diskdb.NewBatch()
writeNodes(batch, nodes, dl.cleans)
rawdb.WritePersistentStateID(batch, dl.id-1)
if err := batch.Write(); err != nil {
log.Crit("Failed to write states", "err", err)
}
}
return newDiskLayer(h.meta.parent, dl.id-1, dl.db, dl.cleans, dl.buffer), nil
}
// setBufferSize sets the node buffer size to the provided value.
func (dl *diskLayer) setBufferSize(size int) error {
dl.lock.RLock()
defer dl.lock.RUnlock()
if dl.stale {
return errSnapshotStale
}
return dl.buffer.setSize(size, dl.db.diskdb, dl.cleans, dl.id)
}
// size returns the approximate size of cached nodes in the disk layer.
func (dl *diskLayer) size() common.StorageSize {
dl.lock.RLock()
defer dl.lock.RUnlock()
if dl.stale {
return 0
}
return common.StorageSize(dl.buffer.size)
}
// resetCache releases the memory held by clean cache to prevent memory leak.
func (dl *diskLayer) resetCache() {
dl.lock.RLock()
defer dl.lock.RUnlock()
// Stale disk layer loses the ownership of clean cache.
if dl.stale {
return
}
if dl.cleans != nil {
dl.cleans.Reset()
}
}
// hasher is used to compute the sha256 hash of the provided data.
type hasher struct{ sha crypto.KeccakState }
var hasherPool = sync.Pool{
New: func() interface{} { return &hasher{sha: sha3.NewLegacyKeccak256().(crypto.KeccakState)} },
}
func newHasher() *hasher {
return hasherPool.Get().(*hasher)
}
func (h *hasher) hash(data []byte) common.Hash {
return crypto.HashData(h.sha, data)
}
func (h *hasher) release() {
hasherPool.Put(h)
}

View File

@ -0,0 +1,60 @@
// Copyright 2023 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>
package pathdb
import (
"errors"
"fmt"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
)
var (
// errDatabaseReadOnly is returned if the database is opened in read only mode
// to prevent any mutation.
errDatabaseReadOnly = errors.New("read only")
// errDatabaseWaitSync is returned if the initial state sync is not completed
// yet and database is disabled to prevent accessing state.
errDatabaseWaitSync = errors.New("waiting for sync")
// errSnapshotStale is returned from data accessors if the underlying layer
// layer had been invalidated due to the chain progressing forward far enough
// to not maintain the layer's original state.
errSnapshotStale = errors.New("layer stale")
// errUnexpectedHistory is returned if an unmatched state history is applied
// to the database for state rollback.
errUnexpectedHistory = errors.New("unexpected state history")
// errStateUnrecoverable is returned if state is required to be reverted to
// a destination without associated state history available.
errStateUnrecoverable = errors.New("state is unrecoverable")
// errUnexpectedNode is returned if the requested node with specified path is
// not hash matched with expectation.
errUnexpectedNode = errors.New("unexpected node")
)
func newUnexpectedNodeError(loc string, expHash common.Hash, gotHash common.Hash, owner common.Hash, path []byte, blob []byte) error {
blobHex := "nil"
if len(blob) > 0 {
blobHex = hexutil.Encode(blob)
}
return fmt.Errorf("%w, loc: %s, node: (%x %v), %x!=%x, blob: %s", errUnexpectedNode, loc, owner, path, expHash, gotHash, blobHex)
}

View File

@ -0,0 +1,649 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>
package pathdb
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"time"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/log"
"golang.org/x/exp/slices"
)
// State history records the state changes involved in executing a block. The
// state can be reverted to the previous version by applying the associated
// history object (state reverse diff). State history objects are kept to
// guarantee that the system can perform state rollbacks in case of deep reorg.
//
// Each state transition will generate a state history object. Note that not
// every block has a corresponding state history object. If a block performs
// no state changes whatsoever, no state is created for it. Each state history
// will have a sequentially increasing number acting as its unique identifier.
//
// The state history is written to disk (ancient store) when the corresponding
// diff layer is merged into the disk layer. At the same time, system can prune
// the oldest histories according to config.
//
// Disk State
// ^
// |
// +------------+ +---------+ +---------+ +---------+
// | Init State |---->| State 1 |---->| ... |---->| State n |
// +------------+ +---------+ +---------+ +---------+
//
// +-----------+ +------+ +-----------+
// | History 1 |----> | ... |---->| History n |
// +-----------+ +------+ +-----------+
//
// # Rollback
//
// If the system wants to roll back to a previous state n, it needs to ensure
// all history objects from n+1 up to the current disk layer are existent. The
// history objects are applied to the state in reverse order, starting from the
// current disk layer.
const (
accountIndexSize = common.AddressLength + 13 // The length of encoded account index
slotIndexSize = common.HashLength + 5 // The length of encoded slot index
historyMetaSize = 9 + 2*common.HashLength // The length of fixed size part of meta object
stateHistoryVersion = uint8(0) // initial version of state history structure.
)
// Each state history entry is consisted of five elements:
//
// # metadata
// This object contains a few meta fields, such as the associated state root,
// block number, version tag and so on. This object may contain an extra
// accountHash list which means the storage changes belong to these accounts
// are not complete due to large contract destruction. The incomplete history
// can not be used for rollback and serving archive state request.
//
// # account index
// This object contains some index information of account. For example, offset
// and length indicate the location of the data belonging to the account. Besides,
// storageOffset and storageSlots indicate the storage modification location
// belonging to the account.
//
// The size of each account index is *fixed*, and all indexes are sorted
// lexicographically. Thus binary search can be performed to quickly locate a
// specific account.
//
// # account data
// Account data is a concatenated byte stream composed of all account data.
// The account data can be solved by the offset and length info indicated
// by corresponding account index.
//
// fixed size
// ^ ^
// / \
// +-----------------+-----------------+----------------+-----------------+
// | Account index 1 | Account index 2 | ... | Account index N |
// +-----------------+-----------------+----------------+-----------------+
// |
// | length
// offset |----------------+
// v v
// +----------------+----------------+----------------+----------------+
// | Account data 1 | Account data 2 | ... | Account data N |
// +----------------+----------------+----------------+----------------+
//
// # storage index
// This object is similar with account index. It's also fixed size and contains
// the location info of storage slot data.
//
// # storage data
// Storage data is a concatenated byte stream composed of all storage slot data.
// The storage slot data can be solved by the location info indicated by
// corresponding account index and storage slot index.
//
// fixed size
// ^ ^
// / \
// +-----------------+-----------------+----------------+-----------------+
// | Account index 1 | Account index 2 | ... | Account index N |
// +-----------------+-----------------+----------------+-----------------+
// |
// | storage slots
// storage offset |-----------------------------------------------------+
// v v
// +-----------------+-----------------+-----------------+
// | storage index 1 | storage index 2 | storage index 3 |
// +-----------------+-----------------+-----------------+
// | length
// offset |-------------+
// v v
// +-------------+
// | slot data 1 |
// +-------------+
// accountIndex describes the metadata belonging to an account.
type accountIndex struct {
address common.Address // The address of account
length uint8 // The length of account data, size limited by 255
offset uint32 // The offset of item in account data table
storageOffset uint32 // The offset of storage index in storage index table
storageSlots uint32 // The number of mutated storage slots belonging to the account
}
// encode packs account index into byte stream.
func (i *accountIndex) encode() []byte {
var buf [accountIndexSize]byte
copy(buf[:], i.address.Bytes())
buf[common.AddressLength] = i.length
binary.BigEndian.PutUint32(buf[common.AddressLength+1:], i.offset)
binary.BigEndian.PutUint32(buf[common.AddressLength+5:], i.storageOffset)
binary.BigEndian.PutUint32(buf[common.AddressLength+9:], i.storageSlots)
return buf[:]
}
// decode unpacks account index from byte stream.
func (i *accountIndex) decode(blob []byte) {
i.address = common.BytesToAddress(blob[:common.AddressLength])
i.length = blob[common.AddressLength]
i.offset = binary.BigEndian.Uint32(blob[common.AddressLength+1:])
i.storageOffset = binary.BigEndian.Uint32(blob[common.AddressLength+5:])
i.storageSlots = binary.BigEndian.Uint32(blob[common.AddressLength+9:])
}
// slotIndex describes the metadata belonging to a storage slot.
type slotIndex struct {
hash common.Hash // The hash of slot key
length uint8 // The length of storage slot, up to 32 bytes defined in protocol
offset uint32 // The offset of item in storage slot data table
}
// encode packs slot index into byte stream.
func (i *slotIndex) encode() []byte {
var buf [slotIndexSize]byte
copy(buf[:common.HashLength], i.hash.Bytes())
buf[common.HashLength] = i.length
binary.BigEndian.PutUint32(buf[common.HashLength+1:], i.offset)
return buf[:]
}
// decode unpack slot index from the byte stream.
func (i *slotIndex) decode(blob []byte) {
i.hash = common.BytesToHash(blob[:common.HashLength])
i.length = blob[common.HashLength]
i.offset = binary.BigEndian.Uint32(blob[common.HashLength+1:])
}
// meta describes the meta data of state history object.
type meta struct {
version uint8 // version tag of history object
parent common.Hash // prev-state root before the state transition
root common.Hash // post-state root after the state transition
block uint64 // associated block number
incomplete []common.Address // list of address whose storage set is incomplete
}
// encode packs the meta object into byte stream.
func (m *meta) encode() []byte {
buf := make([]byte, historyMetaSize+len(m.incomplete)*common.AddressLength)
buf[0] = m.version
copy(buf[1:1+common.HashLength], m.parent.Bytes())
copy(buf[1+common.HashLength:1+2*common.HashLength], m.root.Bytes())
binary.BigEndian.PutUint64(buf[1+2*common.HashLength:historyMetaSize], m.block)
for i, h := range m.incomplete {
copy(buf[i*common.AddressLength+historyMetaSize:], h.Bytes())
}
return buf[:]
}
// decode unpacks the meta object from byte stream.
func (m *meta) decode(blob []byte) error {
if len(blob) < 1 {
return fmt.Errorf("no version tag")
}
switch blob[0] {
case stateHistoryVersion:
if len(blob) < historyMetaSize {
return fmt.Errorf("invalid state history meta, len: %d", len(blob))
}
if (len(blob)-historyMetaSize)%common.AddressLength != 0 {
return fmt.Errorf("corrupted state history meta, len: %d", len(blob))
}
m.version = blob[0]
m.parent = common.BytesToHash(blob[1 : 1+common.HashLength])
m.root = common.BytesToHash(blob[1+common.HashLength : 1+2*common.HashLength])
m.block = binary.BigEndian.Uint64(blob[1+2*common.HashLength : historyMetaSize])
for pos := historyMetaSize; pos < len(blob); {
m.incomplete = append(m.incomplete, common.BytesToAddress(blob[pos:pos+common.AddressLength]))
pos += common.AddressLength
}
return nil
default:
return fmt.Errorf("unknown version %d", blob[0])
}
}
// history represents a set of state changes belong to a block along with
// the metadata including the state roots involved in the state transition.
// State history objects in disk are linked with each other by a unique id
// (8-bytes integer), the oldest state history object can be pruned on demand
// in order to control the storage size.
type history struct {
meta *meta // Meta data of history
accounts map[common.Address][]byte // Account data keyed by its address hash
accountList []common.Address // Sorted account hash list
storages map[common.Address]map[common.Hash][]byte // Storage data keyed by its address hash and slot hash
storageList map[common.Address][]common.Hash // Sorted slot hash list
}
// newHistory constructs the state history object with provided state change set.
func newHistory(root common.Hash, parent common.Hash, block uint64, states *triestate.Set) *history {
var (
accountList []common.Address
storageList = make(map[common.Address][]common.Hash)
incomplete []common.Address
)
for addr := range states.Accounts {
accountList = append(accountList, addr)
}
slices.SortFunc(accountList, common.Address.Cmp)
for addr, slots := range states.Storages {
slist := make([]common.Hash, 0, len(slots))
for slotHash := range slots {
slist = append(slist, slotHash)
}
slices.SortFunc(slist, common.Hash.Cmp)
storageList[addr] = slist
}
for addr := range states.Incomplete {
incomplete = append(incomplete, addr)
}
slices.SortFunc(incomplete, common.Address.Cmp)
return &history{
meta: &meta{
version: stateHistoryVersion,
parent: parent,
root: root,
block: block,
incomplete: incomplete,
},
accounts: states.Accounts,
accountList: accountList,
storages: states.Storages,
storageList: storageList,
}
}
// encode serializes the state history and returns four byte streams represent
// concatenated account/storage data, account/storage indexes respectively.
func (h *history) encode() ([]byte, []byte, []byte, []byte) {
var (
slotNumber uint32 // the number of processed slots
accountData []byte // the buffer for concatenated account data
storageData []byte // the buffer for concatenated storage data
accountIndexes []byte // the buffer for concatenated account index
storageIndexes []byte // the buffer for concatenated storage index
)
for _, addr := range h.accountList {
accIndex := accountIndex{
address: addr,
length: uint8(len(h.accounts[addr])),
offset: uint32(len(accountData)),
}
slots, exist := h.storages[addr]
if exist {
// Encode storage slots in order
for _, slotHash := range h.storageList[addr] {
sIndex := slotIndex{
hash: slotHash,
length: uint8(len(slots[slotHash])),
offset: uint32(len(storageData)),
}
storageData = append(storageData, slots[slotHash]...)
storageIndexes = append(storageIndexes, sIndex.encode()...)
}
// Fill up the storage meta in account index
accIndex.storageOffset = slotNumber
accIndex.storageSlots = uint32(len(slots))
slotNumber += uint32(len(slots))
}
accountData = append(accountData, h.accounts[addr]...)
accountIndexes = append(accountIndexes, accIndex.encode()...)
}
return accountData, storageData, accountIndexes, storageIndexes
}
// decoder wraps the byte streams for decoding with extra meta fields.
type decoder struct {
accountData []byte // the buffer for concatenated account data
storageData []byte // the buffer for concatenated storage data
accountIndexes []byte // the buffer for concatenated account index
storageIndexes []byte // the buffer for concatenated storage index
lastAccount *common.Address // the address of last resolved account
lastAccountRead uint32 // the read-cursor position of account data
lastSlotIndexRead uint32 // the read-cursor position of storage slot index
lastSlotDataRead uint32 // the read-cursor position of storage slot data
}
// verify validates the provided byte streams for decoding state history. A few
// checks will be performed to quickly detect data corruption. The byte stream
// is regarded as corrupted if:
//
// - account indexes buffer is empty(empty state set is invalid)
// - account indexes/storage indexer buffer is not aligned
//
// note, these situations are allowed:
//
// - empty account data: all accounts were not present
// - empty storage set: no slots are modified
func (r *decoder) verify() error {
if len(r.accountIndexes)%accountIndexSize != 0 || len(r.accountIndexes) == 0 {
return fmt.Errorf("invalid account index, len: %d", len(r.accountIndexes))
}
if len(r.storageIndexes)%slotIndexSize != 0 {
return fmt.Errorf("invalid storage index, len: %d", len(r.storageIndexes))
}
return nil
}
// readAccount parses the account from the byte stream with specified position.
func (r *decoder) readAccount(pos int) (accountIndex, []byte, error) {
// Decode account index from the index byte stream.
var index accountIndex
if (pos+1)*accountIndexSize > len(r.accountIndexes) {
return accountIndex{}, nil, errors.New("account data buffer is corrupted")
}
index.decode(r.accountIndexes[pos*accountIndexSize : (pos+1)*accountIndexSize])
// Perform validation before parsing account data, ensure
// - account is sorted in order in byte stream
// - account data is strictly encoded with no gap inside
// - account data is not out-of-slice
if r.lastAccount != nil { // zero address is possible
if bytes.Compare(r.lastAccount.Bytes(), index.address.Bytes()) >= 0 {
return accountIndex{}, nil, errors.New("account is not in order")
}
}
if index.offset != r.lastAccountRead {
return accountIndex{}, nil, errors.New("account data buffer is gaped")
}
last := index.offset + uint32(index.length)
if uint32(len(r.accountData)) < last {
return accountIndex{}, nil, errors.New("account data buffer is corrupted")
}
data := r.accountData[index.offset:last]
r.lastAccount = &index.address
r.lastAccountRead = last
return index, data, nil
}
// readStorage parses the storage slots from the byte stream with specified account.
func (r *decoder) readStorage(accIndex accountIndex) ([]common.Hash, map[common.Hash][]byte, error) {
var (
last common.Hash
list []common.Hash
storage = make(map[common.Hash][]byte)
)
for j := 0; j < int(accIndex.storageSlots); j++ {
var (
index slotIndex
start = (accIndex.storageOffset + uint32(j)) * uint32(slotIndexSize)
end = (accIndex.storageOffset + uint32(j+1)) * uint32(slotIndexSize)
)
// Perform validation before parsing storage slot data, ensure
// - slot index is not out-of-slice
// - slot data is not out-of-slice
// - slot is sorted in order in byte stream
// - slot indexes is strictly encoded with no gap inside
// - slot data is strictly encoded with no gap inside
if start != r.lastSlotIndexRead {
return nil, nil, errors.New("storage index buffer is gapped")
}
if uint32(len(r.storageIndexes)) < end {
return nil, nil, errors.New("storage index buffer is corrupted")
}
index.decode(r.storageIndexes[start:end])
if bytes.Compare(last.Bytes(), index.hash.Bytes()) >= 0 {
return nil, nil, errors.New("storage slot is not in order")
}
if index.offset != r.lastSlotDataRead {
return nil, nil, errors.New("storage data buffer is gapped")
}
sEnd := index.offset + uint32(index.length)
if uint32(len(r.storageData)) < sEnd {
return nil, nil, errors.New("storage data buffer is corrupted")
}
storage[index.hash] = r.storageData[r.lastSlotDataRead:sEnd]
list = append(list, index.hash)
last = index.hash
r.lastSlotIndexRead = end
r.lastSlotDataRead = sEnd
}
return list, storage, nil
}
// decode deserializes the account and storage data from the provided byte stream.
func (h *history) decode(accountData, storageData, accountIndexes, storageIndexes []byte) error {
var (
accounts = make(map[common.Address][]byte)
storages = make(map[common.Address]map[common.Hash][]byte)
accountList []common.Address
storageList = make(map[common.Address][]common.Hash)
r = &decoder{
accountData: accountData,
storageData: storageData,
accountIndexes: accountIndexes,
storageIndexes: storageIndexes,
}
)
if err := r.verify(); err != nil {
return err
}
for i := 0; i < len(accountIndexes)/accountIndexSize; i++ {
// Resolve account first
accIndex, accData, err := r.readAccount(i)
if err != nil {
return err
}
accounts[accIndex.address] = accData
accountList = append(accountList, accIndex.address)
// Resolve storage slots
slotList, slotData, err := r.readStorage(accIndex)
if err != nil {
return err
}
if len(slotList) > 0 {
storageList[accIndex.address] = slotList
storages[accIndex.address] = slotData
}
}
h.accounts = accounts
h.accountList = accountList
h.storages = storages
h.storageList = storageList
return nil
}
// readHistory reads and decodes the state history object by the given id.
func readHistory(freezer *rawdb.ResettableFreezer, id uint64) (*history, error) {
blob := rawdb.ReadStateHistoryMeta(freezer, id)
if len(blob) == 0 {
return nil, fmt.Errorf("state history not found %d", id)
}
var m meta
if err := m.decode(blob); err != nil {
return nil, err
}
var (
dec = history{meta: &m}
accountData = rawdb.ReadStateAccountHistory(freezer, id)
storageData = rawdb.ReadStateStorageHistory(freezer, id)
accountIndexes = rawdb.ReadStateAccountIndex(freezer, id)
storageIndexes = rawdb.ReadStateStorageIndex(freezer, id)
)
if err := dec.decode(accountData, storageData, accountIndexes, storageIndexes); err != nil {
return nil, err
}
return &dec, nil
}
// writeHistory persists the state history with the provided state set.
func writeHistory(freezer *rawdb.ResettableFreezer, dl *diffLayer) error {
// Short circuit if state set is not available.
if dl.states == nil {
return errors.New("state change set is not available")
}
var (
start = time.Now()
history = newHistory(dl.rootHash(), dl.parentLayer().rootHash(), dl.block, dl.states)
)
accountData, storageData, accountIndex, storageIndex := history.encode()
dataSize := common.StorageSize(len(accountData) + len(storageData))
indexSize := common.StorageSize(len(accountIndex) + len(storageIndex))
// Write history data into five freezer table respectively.
rawdb.WriteStateHistory(freezer, dl.stateID(), history.meta.encode(), accountIndex, storageIndex, accountData, storageData)
historyDataBytesMeter.Mark(int64(dataSize))
historyIndexBytesMeter.Mark(int64(indexSize))
historyBuildTimeMeter.UpdateSince(start)
log.Debug("Stored state history", "id", dl.stateID(), "block", dl.block, "data", dataSize, "index", indexSize, "elapsed", common.PrettyDuration(time.Since(start)))
return nil
}
// checkHistories retrieves a batch of meta objects with the specified range
// and performs the callback on each item.
func checkHistories(freezer *rawdb.ResettableFreezer, start, count uint64, check func(*meta) error) error {
for count > 0 {
number := count
if number > 10000 {
number = 10000 // split the big read into small chunks
}
blobs, err := rawdb.ReadStateHistoryMetaList(freezer, start, number)
if err != nil {
return err
}
for _, blob := range blobs {
var dec meta
if err := dec.decode(blob); err != nil {
return err
}
if err := check(&dec); err != nil {
return err
}
}
count -= uint64(len(blobs))
start += uint64(len(blobs))
}
return nil
}
// truncateFromHead removes the extra state histories from the head with the given
// parameters. It returns the number of items removed from the head.
func truncateFromHead(db ethdb.Batcher, freezer *rawdb.ResettableFreezer, nhead uint64) (int, error) {
ohead, err := freezer.Ancients()
if err != nil {
return 0, err
}
otail, err := freezer.Tail()
if err != nil {
return 0, err
}
// Ensure that the truncation target falls within the specified range.
if ohead < nhead || nhead < otail {
return 0, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", otail, ohead, nhead)
}
// Short circuit if nothing to truncate.
if ohead == nhead {
return 0, nil
}
// Load the meta objects in range [nhead+1, ohead]
blobs, err := rawdb.ReadStateHistoryMetaList(freezer, nhead+1, ohead-nhead)
if err != nil {
return 0, err
}
batch := db.NewBatch()
for _, blob := range blobs {
var m meta
if err := m.decode(blob); err != nil {
return 0, err
}
rawdb.DeleteStateID(batch, m.root)
}
if err := batch.Write(); err != nil {
return 0, err
}
ohead, err = freezer.TruncateHead(nhead)
if err != nil {
return 0, err
}
return int(ohead - nhead), nil
}
// truncateFromTail removes the extra state histories from the tail with the given
// parameters. It returns the number of items removed from the tail.
func truncateFromTail(db ethdb.Batcher, freezer *rawdb.ResettableFreezer, ntail uint64) (int, error) {
ohead, err := freezer.Ancients()
if err != nil {
return 0, err
}
otail, err := freezer.Tail()
if err != nil {
return 0, err
}
// Ensure that the truncation target falls within the specified range.
if otail > ntail || ntail > ohead {
return 0, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", otail, ohead, ntail)
}
// Short circuit if nothing to truncate.
if otail == ntail {
return 0, nil
}
// Load the meta objects in range [otail+1, ntail]
blobs, err := rawdb.ReadStateHistoryMetaList(freezer, otail+1, ntail-otail)
if err != nil {
return 0, err
}
batch := db.NewBatch()
for _, blob := range blobs {
var m meta
if err := m.decode(blob); err != nil {
return 0, err
}
rawdb.DeleteStateID(batch, m.root)
}
if err := batch.Write(); err != nil {
return 0, err
}
otail, err = freezer.TruncateTail(ntail)
if err != nil {
return 0, err
}
return int(ntail - otail), nil
}

View File

@ -0,0 +1,334 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>
package pathdb
import (
"bytes"
"fmt"
"reflect"
"testing"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/testutil"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/rlp"
)
// randomStateSet generates a random state change set.
func randomStateSet(n int) *triestate.Set {
var (
accounts = make(map[common.Address][]byte)
storages = make(map[common.Address]map[common.Hash][]byte)
)
for i := 0; i < n; i++ {
addr := testutil.RandomAddress()
storages[addr] = make(map[common.Hash][]byte)
for j := 0; j < 3; j++ {
v, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(testutil.RandBytes(32)))
storages[addr][testutil.RandomHash()] = v
}
account := generateAccount(types.EmptyRootHash)
accounts[addr] = types.SlimAccountRLP(account)
}
return triestate.New(accounts, storages, nil)
}
func makeHistory() *history {
return newHistory(testutil.RandomHash(), types.EmptyRootHash, 0, randomStateSet(3))
}
func makeHistories(n int) []*history {
var (
parent = types.EmptyRootHash
result []*history
)
for i := 0; i < n; i++ {
root := testutil.RandomHash()
h := newHistory(root, parent, uint64(i), randomStateSet(3))
parent = root
result = append(result, h)
}
return result
}
func TestEncodeDecodeHistory(t *testing.T) {
var (
m meta
dec history
obj = makeHistory()
)
// check if meta data can be correctly encode/decode
blob := obj.meta.encode()
if err := m.decode(blob); err != nil {
t.Fatalf("Failed to decode %v", err)
}
if !reflect.DeepEqual(&m, obj.meta) {
t.Fatal("meta is mismatched")
}
// check if account/storage data can be correctly encode/decode
accountData, storageData, accountIndexes, storageIndexes := obj.encode()
if err := dec.decode(accountData, storageData, accountIndexes, storageIndexes); err != nil {
t.Fatalf("Failed to decode, err: %v", err)
}
if !compareSet(dec.accounts, obj.accounts) {
t.Fatal("account data is mismatched")
}
if !compareStorages(dec.storages, obj.storages) {
t.Fatal("storage data is mismatched")
}
if !compareList(dec.accountList, obj.accountList) {
t.Fatal("account list is mismatched")
}
if !compareStorageList(dec.storageList, obj.storageList) {
t.Fatal("storage list is mismatched")
}
}
func checkHistory(t *testing.T, db ethdb.KeyValueReader, freezer *rawdb.ResettableFreezer, id uint64, root common.Hash, exist bool) {
blob := rawdb.ReadStateHistoryMeta(freezer, id)
if exist && len(blob) == 0 {
t.Fatalf("Failed to load trie history, %d", id)
}
if !exist && len(blob) != 0 {
t.Fatalf("Unexpected trie history, %d", id)
}
if exist && rawdb.ReadStateID(db, root) == nil {
t.Fatalf("Root->ID mapping is not found, %d", id)
}
if !exist && rawdb.ReadStateID(db, root) != nil {
t.Fatalf("Unexpected root->ID mapping, %d", id)
}
}
func checkHistoriesInRange(t *testing.T, db ethdb.KeyValueReader, freezer *rawdb.ResettableFreezer, from, to uint64, roots []common.Hash, exist bool) {
for i, j := from, 0; i <= to; i, j = i+1, j+1 {
checkHistory(t, db, freezer, i, roots[j], exist)
}
}
func TestTruncateHeadHistory(t *testing.T) {
var (
roots []common.Hash
hs = makeHistories(10)
db = rawdb.NewMemoryDatabase()
freezer, _ = openFreezer(t.TempDir(), false)
)
defer freezer.Close()
for i := 0; i < len(hs); i++ {
accountData, storageData, accountIndex, storageIndex := hs[i].encode()
rawdb.WriteStateHistory(freezer, uint64(i+1), hs[i].meta.encode(), accountIndex, storageIndex, accountData, storageData)
rawdb.WriteStateID(db, hs[i].meta.root, uint64(i+1))
roots = append(roots, hs[i].meta.root)
}
for size := len(hs); size > 0; size-- {
pruned, err := truncateFromHead(db, freezer, uint64(size-1))
if err != nil {
t.Fatalf("Failed to truncate from head %v", err)
}
if pruned != 1 {
t.Error("Unexpected pruned items", "want", 1, "got", pruned)
}
checkHistoriesInRange(t, db, freezer, uint64(size), uint64(10), roots[size-1:], false)
checkHistoriesInRange(t, db, freezer, uint64(1), uint64(size-1), roots[:size-1], true)
}
}
func TestTruncateTailHistory(t *testing.T) {
var (
roots []common.Hash
hs = makeHistories(10)
db = rawdb.NewMemoryDatabase()
freezer, _ = openFreezer(t.TempDir(), false)
)
defer freezer.Close()
for i := 0; i < len(hs); i++ {
accountData, storageData, accountIndex, storageIndex := hs[i].encode()
rawdb.WriteStateHistory(freezer, uint64(i+1), hs[i].meta.encode(), accountIndex, storageIndex, accountData, storageData)
rawdb.WriteStateID(db, hs[i].meta.root, uint64(i+1))
roots = append(roots, hs[i].meta.root)
}
for newTail := 1; newTail < len(hs); newTail++ {
pruned, _ := truncateFromTail(db, freezer, uint64(newTail))
if pruned != 1 {
t.Error("Unexpected pruned items", "want", 1, "got", pruned)
}
checkHistoriesInRange(t, db, freezer, uint64(1), uint64(newTail), roots[:newTail], false)
checkHistoriesInRange(t, db, freezer, uint64(newTail+1), uint64(10), roots[newTail:], true)
}
}
func TestTruncateTailHistories(t *testing.T) {
var cases = []struct {
limit uint64
expPruned int
maxPruned uint64
minUnpruned uint64
empty bool
}{
{
1, 9, 9, 10, false,
},
{
0, 10, 10, 0 /* no meaning */, true,
},
{
10, 0, 0, 1, false,
},
}
for i, c := range cases {
var (
roots []common.Hash
hs = makeHistories(10)
db = rawdb.NewMemoryDatabase()
freezer, _ = openFreezer(t.TempDir()+fmt.Sprintf("%d", i), false)
)
defer freezer.Close()
for i := 0; i < len(hs); i++ {
accountData, storageData, accountIndex, storageIndex := hs[i].encode()
rawdb.WriteStateHistory(freezer, uint64(i+1), hs[i].meta.encode(), accountIndex, storageIndex, accountData, storageData)
rawdb.WriteStateID(db, hs[i].meta.root, uint64(i+1))
roots = append(roots, hs[i].meta.root)
}
pruned, _ := truncateFromTail(db, freezer, uint64(10)-c.limit)
if pruned != c.expPruned {
t.Error("Unexpected pruned items", "want", c.expPruned, "got", pruned)
}
if c.empty {
checkHistoriesInRange(t, db, freezer, uint64(1), uint64(10), roots, false)
} else {
tail := 10 - int(c.limit)
checkHistoriesInRange(t, db, freezer, uint64(1), c.maxPruned, roots[:tail], false)
checkHistoriesInRange(t, db, freezer, c.minUnpruned, uint64(10), roots[tail:], true)
}
}
}
func TestTruncateOutOfRange(t *testing.T) {
var (
hs = makeHistories(10)
db = rawdb.NewMemoryDatabase()
freezer, _ = openFreezer(t.TempDir(), false)
)
defer freezer.Close()
for i := 0; i < len(hs); i++ {
accountData, storageData, accountIndex, storageIndex := hs[i].encode()
rawdb.WriteStateHistory(freezer, uint64(i+1), hs[i].meta.encode(), accountIndex, storageIndex, accountData, storageData)
rawdb.WriteStateID(db, hs[i].meta.root, uint64(i+1))
}
truncateFromTail(db, freezer, uint64(len(hs)/2))
// Ensure of-out-range truncations are rejected correctly.
head, _ := freezer.Ancients()
tail, _ := freezer.Tail()
cases := []struct {
mode int
target uint64
expErr error
}{
{0, head, nil}, // nothing to delete
{0, head + 1, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", tail, head, head+1)},
{0, tail - 1, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", tail, head, tail-1)},
{1, tail, nil}, // nothing to delete
{1, head + 1, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", tail, head, head+1)},
{1, tail - 1, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", tail, head, tail-1)},
}
for _, c := range cases {
var gotErr error
if c.mode == 0 {
_, gotErr = truncateFromHead(db, freezer, c.target)
} else {
_, gotErr = truncateFromTail(db, freezer, c.target)
}
if !reflect.DeepEqual(gotErr, c.expErr) {
t.Errorf("Unexpected error, want: %v, got: %v", c.expErr, gotErr)
}
}
}
// openFreezer initializes the freezer instance for storing state histories.
func openFreezer(datadir string, readOnly bool) (*rawdb.ResettableFreezer, error) {
return rawdb.NewStateFreezer(datadir, readOnly)
}
func compareSet[k comparable](a, b map[k][]byte) bool {
if len(a) != len(b) {
return false
}
for key, valA := range a {
valB, ok := b[key]
if !ok {
return false
}
if !bytes.Equal(valA, valB) {
return false
}
}
return true
}
func compareList[k comparable](a, b []k) bool {
if len(a) != len(b) {
return false
}
for i := 0; i < len(a); i++ {
if a[i] != b[i] {
return false
}
}
return true
}
func compareStorages(a, b map[common.Address]map[common.Hash][]byte) bool {
if len(a) != len(b) {
return false
}
for h, subA := range a {
subB, ok := b[h]
if !ok {
return false
}
if !compareSet(subA, subB) {
return false
}
}
return true
}
func compareStorageList(a, b map[common.Address][]common.Hash) bool {
if len(a) != len(b) {
return false
}
for h, la := range a {
lb, ok := b[h]
if !ok {
return false
}
if !compareList(la, lb) {
return false
}
}
return true
}

View File

@ -0,0 +1,387 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package pathdb
import (
"bytes"
"errors"
"fmt"
"io"
"time"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp"
)
var (
errMissJournal = errors.New("journal not found")
errMissVersion = errors.New("version not found")
errUnexpectedVersion = errors.New("unexpected journal version")
errMissDiskRoot = errors.New("disk layer root not found")
errUnmatchedJournal = errors.New("unmatched journal")
)
const journalVersion uint64 = 0
// journalNode represents a trie node persisted in the journal.
type journalNode struct {
Path []byte // Path of the node in the trie
Blob []byte // RLP-encoded trie node blob, nil means the node is deleted
}
// journalNodes represents a list trie nodes belong to a single account
// or the main account trie.
type journalNodes struct {
Owner common.Hash
Nodes []journalNode
}
// journalAccounts represents a list accounts belong to the layer.
type journalAccounts struct {
Addresses []common.Address
Accounts [][]byte
}
// journalStorage represents a list of storage slots belong to an account.
type journalStorage struct {
Incomplete bool
Account common.Address
Hashes []common.Hash
Slots [][]byte
}
// loadJournal tries to parse the layer journal from the disk.
func (db *Database) loadJournal(diskRoot common.Hash) (layer, error) {
journal := rawdb.ReadTrieJournal(db.diskdb)
if len(journal) == 0 {
return nil, errMissJournal
}
r := rlp.NewStream(bytes.NewReader(journal), 0)
// Firstly, resolve the first element as the journal version
version, err := r.Uint64()
if err != nil {
return nil, errMissVersion
}
if version != journalVersion {
return nil, fmt.Errorf("%w want %d got %d", errUnexpectedVersion, journalVersion, version)
}
// Secondly, resolve the disk layer root, ensure it's continuous
// with disk layer. Note now we can ensure it's the layer journal
// correct version, so we expect everything can be resolved properly.
var root common.Hash
if err := r.Decode(&root); err != nil {
return nil, errMissDiskRoot
}
// The journal is not matched with persistent state, discard them.
// It can happen that geth crashes without persisting the journal.
if !bytes.Equal(root.Bytes(), diskRoot.Bytes()) {
return nil, fmt.Errorf("%w want %x got %x", errUnmatchedJournal, root, diskRoot)
}
// Load the disk layer from the journal
base, err := db.loadDiskLayer(r)
if err != nil {
return nil, err
}
// Load all the diff layers from the journal
head, err := db.loadDiffLayer(base, r)
if err != nil {
return nil, err
}
log.Debug("Loaded layer journal", "diskroot", diskRoot, "diffhead", head.rootHash())
return head, nil
}
// loadLayers loads a pre-existing state layer backed by a key-value store.
func (db *Database) loadLayers() layer {
// Retrieve the root node of persistent state.
_, root := rawdb.ReadAccountTrieNode(db.diskdb, nil)
root = types.TrieRootHash(root)
// Load the layers by resolving the journal
head, err := db.loadJournal(root)
if err == nil {
return head
}
// journal is not matched(or missing) with the persistent state, discard
// it. Display log for discarding journal, but try to avoid showing
// useless information when the db is created from scratch.
if !(root == types.EmptyRootHash && errors.Is(err, errMissJournal)) {
log.Info("Failed to load journal, discard it", "err", err)
}
// Return single layer with persistent state.
return newDiskLayer(root, rawdb.ReadPersistentStateID(db.diskdb), db, nil, newNodeBuffer(db.bufferSize, nil, 0))
}
// loadDiskLayer reads the binary blob from the layer journal, reconstructing
// a new disk layer on it.
func (db *Database) loadDiskLayer(r *rlp.Stream) (layer, error) {
// Resolve disk layer root
var root common.Hash
if err := r.Decode(&root); err != nil {
return nil, fmt.Errorf("load disk root: %v", err)
}
// Resolve the state id of disk layer, it can be different
// with the persistent id tracked in disk, the id distance
// is the number of transitions aggregated in disk layer.
var id uint64
if err := r.Decode(&id); err != nil {
return nil, fmt.Errorf("load state id: %v", err)
}
stored := rawdb.ReadPersistentStateID(db.diskdb)
if stored > id {
return nil, fmt.Errorf("invalid state id: stored %d resolved %d", stored, id)
}
// Resolve nodes cached in node buffer
var encoded []journalNodes
if err := r.Decode(&encoded); err != nil {
return nil, fmt.Errorf("load disk nodes: %v", err)
}
nodes := make(map[common.Hash]map[string]*trienode.Node)
for _, entry := range encoded {
subset := make(map[string]*trienode.Node)
for _, n := range entry.Nodes {
if len(n.Blob) > 0 {
subset[string(n.Path)] = trienode.New(crypto.Keccak256Hash(n.Blob), n.Blob)
} else {
subset[string(n.Path)] = trienode.NewDeleted()
}
}
nodes[entry.Owner] = subset
}
// Calculate the internal state transitions by id difference.
base := newDiskLayer(root, id, db, nil, newNodeBuffer(db.bufferSize, nodes, id-stored))
return base, nil
}
// loadDiffLayer reads the next sections of a layer journal, reconstructing a new
// diff and verifying that it can be linked to the requested parent.
func (db *Database) loadDiffLayer(parent layer, r *rlp.Stream) (layer, error) {
// Read the next diff journal entry
var root common.Hash
if err := r.Decode(&root); err != nil {
// The first read may fail with EOF, marking the end of the journal
if err == io.EOF {
return parent, nil
}
return nil, fmt.Errorf("load diff root: %v", err)
}
var block uint64
if err := r.Decode(&block); err != nil {
return nil, fmt.Errorf("load block number: %v", err)
}
// Read in-memory trie nodes from journal
var encoded []journalNodes
if err := r.Decode(&encoded); err != nil {
return nil, fmt.Errorf("load diff nodes: %v", err)
}
nodes := make(map[common.Hash]map[string]*trienode.Node)
for _, entry := range encoded {
subset := make(map[string]*trienode.Node)
for _, n := range entry.Nodes {
if len(n.Blob) > 0 {
subset[string(n.Path)] = trienode.New(crypto.Keccak256Hash(n.Blob), n.Blob)
} else {
subset[string(n.Path)] = trienode.NewDeleted()
}
}
nodes[entry.Owner] = subset
}
// Read state changes from journal
var (
jaccounts journalAccounts
jstorages []journalStorage
accounts = make(map[common.Address][]byte)
storages = make(map[common.Address]map[common.Hash][]byte)
incomplete = make(map[common.Address]struct{})
)
if err := r.Decode(&jaccounts); err != nil {
return nil, fmt.Errorf("load diff accounts: %v", err)
}
for i, addr := range jaccounts.Addresses {
accounts[addr] = jaccounts.Accounts[i]
}
if err := r.Decode(&jstorages); err != nil {
return nil, fmt.Errorf("load diff storages: %v", err)
}
for _, entry := range jstorages {
set := make(map[common.Hash][]byte)
for i, h := range entry.Hashes {
if len(entry.Slots[i]) > 0 {
set[h] = entry.Slots[i]
} else {
set[h] = nil
}
}
if entry.Incomplete {
incomplete[entry.Account] = struct{}{}
}
storages[entry.Account] = set
}
return db.loadDiffLayer(newDiffLayer(parent, root, parent.stateID()+1, block, nodes, triestate.New(accounts, storages, incomplete)), r)
}
// journal implements the layer interface, marshaling the un-flushed trie nodes
// along with layer meta data into provided byte buffer.
func (dl *diskLayer) journal(w io.Writer) error {
dl.lock.RLock()
defer dl.lock.RUnlock()
// Ensure the layer didn't get stale
if dl.stale {
return errSnapshotStale
}
// Step one, write the disk root into the journal.
if err := rlp.Encode(w, dl.root); err != nil {
return err
}
// Step two, write the corresponding state id into the journal
if err := rlp.Encode(w, dl.id); err != nil {
return err
}
// Step three, write all unwritten nodes into the journal
nodes := make([]journalNodes, 0, len(dl.buffer.nodes))
for owner, subset := range dl.buffer.nodes {
entry := journalNodes{Owner: owner}
for path, node := range subset {
entry.Nodes = append(entry.Nodes, journalNode{Path: []byte(path), Blob: node.Blob})
}
nodes = append(nodes, entry)
}
if err := rlp.Encode(w, nodes); err != nil {
return err
}
log.Debug("Journaled pathdb disk layer", "root", dl.root, "nodes", len(dl.buffer.nodes))
return nil
}
// journal implements the layer interface, writing the memory layer contents
// into a buffer to be stored in the database as the layer journal.
func (dl *diffLayer) journal(w io.Writer) error {
dl.lock.RLock()
defer dl.lock.RUnlock()
// journal the parent first
if err := dl.parent.journal(w); err != nil {
return err
}
// Everything below was journaled, persist this layer too
if err := rlp.Encode(w, dl.root); err != nil {
return err
}
if err := rlp.Encode(w, dl.block); err != nil {
return err
}
// Write the accumulated trie nodes into buffer
nodes := make([]journalNodes, 0, len(dl.nodes))
for owner, subset := range dl.nodes {
entry := journalNodes{Owner: owner}
for path, node := range subset {
entry.Nodes = append(entry.Nodes, journalNode{Path: []byte(path), Blob: node.Blob})
}
nodes = append(nodes, entry)
}
if err := rlp.Encode(w, nodes); err != nil {
return err
}
// Write the accumulated state changes into buffer
var jacct journalAccounts
for addr, account := range dl.states.Accounts {
jacct.Addresses = append(jacct.Addresses, addr)
jacct.Accounts = append(jacct.Accounts, account)
}
if err := rlp.Encode(w, jacct); err != nil {
return err
}
storage := make([]journalStorage, 0, len(dl.states.Storages))
for addr, slots := range dl.states.Storages {
entry := journalStorage{Account: addr}
if _, ok := dl.states.Incomplete[addr]; ok {
entry.Incomplete = true
}
for slotHash, slot := range slots {
entry.Hashes = append(entry.Hashes, slotHash)
entry.Slots = append(entry.Slots, slot)
}
storage = append(storage, entry)
}
if err := rlp.Encode(w, storage); err != nil {
return err
}
log.Debug("Journaled pathdb diff layer", "root", dl.root, "parent", dl.parent.rootHash(), "id", dl.stateID(), "block", dl.block, "nodes", len(dl.nodes))
return nil
}
// Journal commits an entire diff hierarchy to disk into a single journal entry.
// This is meant to be used during shutdown to persist the layer without
// flattening everything down (bad for reorgs). And this function will mark the
// database as read-only to prevent all following mutation to disk.
func (db *Database) Journal(root common.Hash) error {
// Retrieve the head layer to journal from.
l := db.tree.get(root)
if l == nil {
return fmt.Errorf("triedb layer [%#x] missing", root)
}
disk := db.tree.bottom()
if l, ok := l.(*diffLayer); ok {
log.Info("Persisting dirty state to disk", "head", l.block, "root", root, "layers", l.id-disk.id+disk.buffer.layers)
} else { // disk layer only on noop runs (likely) or deep reorgs (unlikely)
log.Info("Persisting dirty state to disk", "root", root, "layers", disk.buffer.layers)
}
start := time.Now()
// Run the journaling
db.lock.Lock()
defer db.lock.Unlock()
// Short circuit if the database is in read only mode.
if db.readOnly {
return errDatabaseReadOnly
}
// Firstly write out the metadata of journal
journal := new(bytes.Buffer)
if err := rlp.Encode(journal, journalVersion); err != nil {
return err
}
// The stored state in disk might be empty, convert the
// root to emptyRoot in this case.
_, diskroot := rawdb.ReadAccountTrieNode(db.diskdb, nil)
diskroot = types.TrieRootHash(diskroot)
// Secondly write out the state root in disk, ensure all layers
// on top are continuous with disk.
if err := rlp.Encode(journal, diskroot); err != nil {
return err
}
// Finally write out the journal of each layer in reverse order.
if err := l.journal(journal); err != nil {
return err
}
// Store the journal into the database and return
rawdb.WriteTrieJournal(db.diskdb, journal.Bytes())
// Set the db in read only mode to reject all following mutations
db.readOnly = true
log.Info("Persisted dirty state to disk", "size", common.StorageSize(journal.Len()), "elapsed", common.PrettyDuration(time.Since(start)))
return nil
}

View File

@ -0,0 +1,214 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>
package pathdb
import (
"errors"
"fmt"
"sync"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
)
// layerTree is a group of state layers identified by the state root.
// This structure defines a few basic operations for manipulating
// state layers linked with each other in a tree structure. It's
// thread-safe to use. However, callers need to ensure the thread-safety
// of the referenced layer by themselves.
type layerTree struct {
lock sync.RWMutex
layers map[common.Hash]layer
}
// newLayerTree constructs the layerTree with the given head layer.
func newLayerTree(head layer) *layerTree {
tree := new(layerTree)
tree.reset(head)
return tree
}
// reset initializes the layerTree by the given head layer.
// All the ancestors will be iterated out and linked in the tree.
func (tree *layerTree) reset(head layer) {
tree.lock.Lock()
defer tree.lock.Unlock()
var layers = make(map[common.Hash]layer)
for head != nil {
layers[head.rootHash()] = head
head = head.parentLayer()
}
tree.layers = layers
}
// get retrieves a layer belonging to the given state root.
func (tree *layerTree) get(root common.Hash) layer {
tree.lock.RLock()
defer tree.lock.RUnlock()
return tree.layers[types.TrieRootHash(root)]
}
// forEach iterates the stored layers inside and applies the
// given callback on them.
func (tree *layerTree) forEach(onLayer func(layer)) {
tree.lock.RLock()
defer tree.lock.RUnlock()
for _, layer := range tree.layers {
onLayer(layer)
}
}
// len returns the number of layers cached.
func (tree *layerTree) len() int {
tree.lock.RLock()
defer tree.lock.RUnlock()
return len(tree.layers)
}
// add inserts a new layer into the tree if it can be linked to an existing old parent.
func (tree *layerTree) add(root common.Hash, parentRoot common.Hash, block uint64, nodes *trienode.MergedNodeSet, states *triestate.Set) error {
// Reject noop updates to avoid self-loops. This is a special case that can
// happen for clique networks and proof-of-stake networks where empty blocks
// don't modify the state (0 block subsidy).
//
// Although we could silently ignore this internally, it should be the caller's
// responsibility to avoid even attempting to insert such a layer.
root, parentRoot = types.TrieRootHash(root), types.TrieRootHash(parentRoot)
if root == parentRoot {
return errors.New("layer cycle")
}
parent := tree.get(parentRoot)
if parent == nil {
return fmt.Errorf("triedb parent [%#x] layer missing", parentRoot)
}
l := parent.update(root, parent.stateID()+1, block, nodes.Flatten(), states)
tree.lock.Lock()
tree.layers[l.rootHash()] = l
tree.lock.Unlock()
return nil
}
// cap traverses downwards the diff tree until the number of allowed diff layers
// are crossed. All diffs beyond the permitted number are flattened downwards.
func (tree *layerTree) cap(root common.Hash, layers int) error {
// Retrieve the head layer to cap from
root = types.TrieRootHash(root)
l := tree.get(root)
if l == nil {
return fmt.Errorf("triedb layer [%#x] missing", root)
}
diff, ok := l.(*diffLayer)
if !ok {
return fmt.Errorf("triedb layer [%#x] is disk layer", root)
}
tree.lock.Lock()
defer tree.lock.Unlock()
// If full commit was requested, flatten the diffs and merge onto disk
if layers == 0 {
base, err := diff.persist(true)
if err != nil {
return err
}
// Replace the entire layer tree with the flat base
tree.layers = map[common.Hash]layer{base.rootHash(): base}
return nil
}
// Dive until we run out of layers or reach the persistent database
for i := 0; i < layers-1; i++ {
// If we still have diff layers below, continue down
if parent, ok := diff.parentLayer().(*diffLayer); ok {
diff = parent
} else {
// Diff stack too shallow, return without modifications
return nil
}
}
// We're out of layers, flatten anything below, stopping if it's the disk or if
// the memory limit is not yet exceeded.
switch parent := diff.parentLayer().(type) {
case *diskLayer:
return nil
case *diffLayer:
// Hold the lock to prevent any read operations until the new
// parent is linked correctly.
diff.lock.Lock()
base, err := parent.persist(false)
if err != nil {
diff.lock.Unlock()
return err
}
tree.layers[base.rootHash()] = base
diff.parent = base
diff.lock.Unlock()
default:
panic(fmt.Sprintf("unknown data layer in triedb: %T", parent))
}
// Remove any layer that is stale or links into a stale layer
children := make(map[common.Hash][]common.Hash)
for root, layer := range tree.layers {
if dl, ok := layer.(*diffLayer); ok {
parent := dl.parentLayer().rootHash()
children[parent] = append(children[parent], root)
}
}
var remove func(root common.Hash)
remove = func(root common.Hash) {
delete(tree.layers, root)
for _, child := range children[root] {
remove(child)
}
delete(children, root)
}
for root, layer := range tree.layers {
if dl, ok := layer.(*diskLayer); ok && dl.isStale() {
remove(root)
}
}
return nil
}
// bottom returns the bottom-most disk layer in this tree.
func (tree *layerTree) bottom() *diskLayer {
tree.lock.RLock()
defer tree.lock.RUnlock()
if len(tree.layers) == 0 {
return nil // Shouldn't happen, empty tree
}
// pick a random one as the entry point
var current layer
for _, layer := range tree.layers {
current = layer
break
}
for current.parentLayer() != nil {
current = current.parentLayer()
}
return current.(*diskLayer)
}

View File

@ -0,0 +1,50 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>
package pathdb
import "github.com/ethereum/go-ethereum/metrics"
var (
cleanHitMeter = metrics.NewRegisteredMeter("pathdb/clean/hit", nil)
cleanMissMeter = metrics.NewRegisteredMeter("pathdb/clean/miss", nil)
cleanReadMeter = metrics.NewRegisteredMeter("pathdb/clean/read", nil)
cleanWriteMeter = metrics.NewRegisteredMeter("pathdb/clean/write", nil)
dirtyHitMeter = metrics.NewRegisteredMeter("pathdb/dirty/hit", nil)
dirtyMissMeter = metrics.NewRegisteredMeter("pathdb/dirty/miss", nil)
dirtyReadMeter = metrics.NewRegisteredMeter("pathdb/dirty/read", nil)
dirtyWriteMeter = metrics.NewRegisteredMeter("pathdb/dirty/write", nil)
dirtyNodeHitDepthHist = metrics.NewRegisteredHistogram("pathdb/dirty/depth", nil, metrics.NewExpDecaySample(1028, 0.015))
cleanFalseMeter = metrics.NewRegisteredMeter("pathdb/clean/false", nil)
dirtyFalseMeter = metrics.NewRegisteredMeter("pathdb/dirty/false", nil)
diskFalseMeter = metrics.NewRegisteredMeter("pathdb/disk/false", nil)
commitTimeTimer = metrics.NewRegisteredTimer("pathdb/commit/time", nil)
commitNodesMeter = metrics.NewRegisteredMeter("pathdb/commit/nodes", nil)
commitBytesMeter = metrics.NewRegisteredMeter("pathdb/commit/bytes", nil)
gcNodesMeter = metrics.NewRegisteredMeter("pathdb/gc/nodes", nil)
gcBytesMeter = metrics.NewRegisteredMeter("pathdb/gc/bytes", nil)
diffLayerBytesMeter = metrics.NewRegisteredMeter("pathdb/diff/bytes", nil)
diffLayerNodesMeter = metrics.NewRegisteredMeter("pathdb/diff/nodes", nil)
historyBuildTimeMeter = metrics.NewRegisteredTimer("pathdb/history/time", nil)
historyDataBytesMeter = metrics.NewRegisteredMeter("pathdb/history/bytes/data", nil)
historyIndexBytesMeter = metrics.NewRegisteredMeter("pathdb/history/bytes/index", nil)
)

View File

@ -0,0 +1,275 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package pathdb
import (
"fmt"
"time"
"github.com/VictoriaMetrics/fastcache"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/log"
)
// nodebuffer is a collection of modified trie nodes to aggregate the disk
// write. The content of the nodebuffer must be checked before diving into
// disk (since it basically is not-yet-written data).
type nodebuffer struct {
layers uint64 // The number of diff layers aggregated inside
size uint64 // The size of aggregated writes
limit uint64 // The maximum memory allowance in bytes
nodes map[common.Hash]map[string]*trienode.Node // The dirty node set, mapped by owner and path
}
// newNodeBuffer initializes the node buffer with the provided nodes.
func newNodeBuffer(limit int, nodes map[common.Hash]map[string]*trienode.Node, layers uint64) *nodebuffer {
if nodes == nil {
nodes = make(map[common.Hash]map[string]*trienode.Node)
}
var size uint64
for _, subset := range nodes {
for path, n := range subset {
size += uint64(len(n.Blob) + len(path))
}
}
return &nodebuffer{
layers: layers,
nodes: nodes,
size: size,
limit: uint64(limit),
}
}
// node retrieves the trie node with given node info.
func (b *nodebuffer) node(owner common.Hash, path []byte, hash common.Hash) (*trienode.Node, error) {
subset, ok := b.nodes[owner]
if !ok {
return nil, nil
}
n, ok := subset[string(path)]
if !ok {
return nil, nil
}
if n.Hash != hash {
dirtyFalseMeter.Mark(1)
log.Error("Unexpected trie node in node buffer", "owner", owner, "path", path, "expect", hash, "got", n.Hash)
return nil, newUnexpectedNodeError("dirty", hash, n.Hash, owner, path, n.Blob)
}
return n, nil
}
// commit merges the dirty nodes into the nodebuffer. This operation won't take
// the ownership of the nodes map which belongs to the bottom-most diff layer.
// It will just hold the node references from the given map which are safe to
// copy.
func (b *nodebuffer) commit(nodes map[common.Hash]map[string]*trienode.Node) *nodebuffer {
var (
delta int64
overwrite int64
overwriteSize int64
)
for owner, subset := range nodes {
current, exist := b.nodes[owner]
if !exist {
// Allocate a new map for the subset instead of claiming it directly
// from the passed map to avoid potential concurrent map read/write.
// The nodes belong to original diff layer are still accessible even
// after merging, thus the ownership of nodes map should still belong
// to original layer and any mutation on it should be prevented.
current = make(map[string]*trienode.Node)
for path, n := range subset {
current[path] = n
delta += int64(len(n.Blob) + len(path))
}
b.nodes[owner] = current
continue
}
for path, n := range subset {
if orig, exist := current[path]; !exist {
delta += int64(len(n.Blob) + len(path))
} else {
delta += int64(len(n.Blob) - len(orig.Blob))
overwrite++
overwriteSize += int64(len(orig.Blob) + len(path))
}
current[path] = n
}
b.nodes[owner] = current
}
b.updateSize(delta)
b.layers++
gcNodesMeter.Mark(overwrite)
gcBytesMeter.Mark(overwriteSize)
return b
}
// revert is the reverse operation of commit. It also merges the provided nodes
// into the nodebuffer, the difference is that the provided node set should
// revert the changes made by the last state transition.
func (b *nodebuffer) revert(db ethdb.KeyValueReader, nodes map[common.Hash]map[string]*trienode.Node) error {
// Short circuit if no embedded state transition to revert.
if b.layers == 0 {
return errStateUnrecoverable
}
b.layers--
// Reset the entire buffer if only a single transition left.
if b.layers == 0 {
b.reset()
return nil
}
var delta int64
for owner, subset := range nodes {
current, ok := b.nodes[owner]
if !ok {
panic(fmt.Sprintf("non-existent subset (%x)", owner))
}
for path, n := range subset {
orig, ok := current[path]
if !ok {
// There is a special case in MPT that one child is removed from
// a fullNode which only has two children, and then a new child
// with different position is immediately inserted into the fullNode.
// In this case, the clean child of the fullNode will also be
// marked as dirty because of node collapse and expansion.
//
// In case of database rollback, don't panic if this "clean"
// node occurs which is not present in buffer.
var nhash common.Hash
if owner == (common.Hash{}) {
_, nhash = rawdb.ReadAccountTrieNode(db, []byte(path))
} else {
_, nhash = rawdb.ReadStorageTrieNode(db, owner, []byte(path))
}
// Ignore the clean node in the case described above.
if nhash == n.Hash {
continue
}
panic(fmt.Sprintf("non-existent node (%x %v) blob: %v", owner, path, crypto.Keccak256Hash(n.Blob).Hex()))
}
current[path] = n
delta += int64(len(n.Blob)) - int64(len(orig.Blob))
}
}
b.updateSize(delta)
return nil
}
// updateSize updates the total cache size by the given delta.
func (b *nodebuffer) updateSize(delta int64) {
size := int64(b.size) + delta
if size >= 0 {
b.size = uint64(size)
return
}
s := b.size
b.size = 0
log.Error("Invalid pathdb buffer size", "prev", common.StorageSize(s), "delta", common.StorageSize(delta))
}
// reset cleans up the disk cache.
func (b *nodebuffer) reset() {
b.layers = 0
b.size = 0
b.nodes = make(map[common.Hash]map[string]*trienode.Node)
}
// empty returns an indicator if nodebuffer contains any state transition inside.
func (b *nodebuffer) empty() bool {
return b.layers == 0
}
// setSize sets the buffer size to the provided number, and invokes a flush
// operation if the current memory usage exceeds the new limit.
func (b *nodebuffer) setSize(size int, db ethdb.KeyValueStore, clean *fastcache.Cache, id uint64) error {
b.limit = uint64(size)
return b.flush(db, clean, id, false)
}
// flush persists the in-memory dirty trie node into the disk if the configured
// memory threshold is reached. Note, all data must be written atomically.
func (b *nodebuffer) flush(db ethdb.KeyValueStore, clean *fastcache.Cache, id uint64, force bool) error {
if b.size <= b.limit && !force {
return nil
}
// Ensure the target state id is aligned with the internal counter.
head := rawdb.ReadPersistentStateID(db)
if head+b.layers != id {
return fmt.Errorf("buffer layers (%d) cannot be applied on top of persisted state id (%d) to reach requested state id (%d)", b.layers, head, id)
}
var (
start = time.Now()
batch = db.NewBatchWithSize(int(b.size))
)
nodes := writeNodes(batch, b.nodes, clean)
rawdb.WritePersistentStateID(batch, id)
// Flush all mutations in a single batch
size := batch.ValueSize()
if err := batch.Write(); err != nil {
return err
}
commitBytesMeter.Mark(int64(size))
commitNodesMeter.Mark(int64(nodes))
commitTimeTimer.UpdateSince(start)
log.Debug("Persisted pathdb nodes", "nodes", len(b.nodes), "bytes", common.StorageSize(size), "elapsed", common.PrettyDuration(time.Since(start)))
b.reset()
return nil
}
// writeNodes writes the trie nodes into the provided database batch.
// Note this function will also inject all the newly written nodes
// into clean cache.
func writeNodes(batch ethdb.Batch, nodes map[common.Hash]map[string]*trienode.Node, clean *fastcache.Cache) (total int) {
for owner, subset := range nodes {
for path, n := range subset {
if n.IsDeleted() {
if owner == (common.Hash{}) {
rawdb.DeleteAccountTrieNode(batch, []byte(path))
} else {
rawdb.DeleteStorageTrieNode(batch, owner, []byte(path))
}
if clean != nil {
clean.Del(cacheKey(owner, []byte(path)))
}
} else {
if owner == (common.Hash{}) {
rawdb.WriteAccountTrieNode(batch, []byte(path), n.Blob)
} else {
rawdb.WriteStorageTrieNode(batch, owner, []byte(path), n.Blob)
}
if clean != nil {
clean.Set(cacheKey(owner, []byte(path)), n.Blob)
}
}
}
total += len(subset)
}
return total
}
// cacheKey constructs the unique key of clean cache.
func cacheKey(owner common.Hash, path []byte) []byte {
if owner == (common.Hash{}) {
return path
}
return append(owner.Bytes(), path...)
}

View File

@ -0,0 +1,157 @@
// Copyright 2023 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package pathdb
import (
"bytes"
"fmt"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"golang.org/x/exp/slices"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/trienode"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie/triestate"
)
// testHasher is a test utility for computing root hash of a batch of state
// elements. The hash algorithm is to sort all the elements in lexicographical
// order, concat the key and value in turn, and perform hash calculation on
// the concatenated bytes. Except the root hash, a nodeset will be returned
// once Commit is called, which contains all the changes made to hasher.
type testHasher struct {
owner common.Hash // owner identifier
root common.Hash // original root
dirties map[common.Hash][]byte // dirty states
cleans map[common.Hash][]byte // clean states
}
// newTestHasher constructs a hasher object with provided states.
func newTestHasher(owner common.Hash, root common.Hash, cleans map[common.Hash][]byte) (*testHasher, error) {
if cleans == nil {
cleans = make(map[common.Hash][]byte)
}
if got, _ := hash(cleans); got != root {
return nil, fmt.Errorf("state root mismatched, want: %x, got: %x", root, got)
}
return &testHasher{
owner: owner,
root: root,
dirties: make(map[common.Hash][]byte),
cleans: cleans,
}, nil
}
// Get returns the value for key stored in the trie.
func (h *testHasher) Get(key []byte) ([]byte, error) {
hash := common.BytesToHash(key)
val, ok := h.dirties[hash]
if ok {
return val, nil
}
return h.cleans[hash], nil
}
// Update associates key with value in the trie.
func (h *testHasher) Update(key, value []byte) error {
h.dirties[common.BytesToHash(key)] = common.CopyBytes(value)
return nil
}
// Delete removes any existing value for key from the trie.
func (h *testHasher) Delete(key []byte) error {
h.dirties[common.BytesToHash(key)] = nil
return nil
}
// Commit computes the new hash of the states and returns the set with all
// state changes.
func (h *testHasher) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, error) {
var (
nodes = make(map[common.Hash][]byte)
set = trienode.NewNodeSet(h.owner)
)
for hash, val := range h.cleans {
nodes[hash] = val
}
for hash, val := range h.dirties {
nodes[hash] = val
if bytes.Equal(val, h.cleans[hash]) {
continue
}
if len(val) == 0 {
set.AddNode(hash.Bytes(), trienode.NewDeleted())
} else {
set.AddNode(hash.Bytes(), trienode.New(crypto.Keccak256Hash(val), val))
}
}
root, blob := hash(nodes)
// Include the dirty root node as well.
if root != types.EmptyRootHash && root != h.root {
set.AddNode(nil, trienode.New(root, blob))
}
if root == types.EmptyRootHash && h.root != types.EmptyRootHash {
set.AddNode(nil, trienode.NewDeleted())
}
return root, set, nil
}
// hash performs the hash computation upon the provided states.
func hash(states map[common.Hash][]byte) (common.Hash, []byte) {
var hs []common.Hash
for hash := range states {
hs = append(hs, hash)
}
slices.SortFunc(hs, common.Hash.Cmp)
var input []byte
for _, hash := range hs {
if len(states[hash]) == 0 {
continue
}
input = append(input, hash.Bytes()...)
input = append(input, states[hash]...)
}
if len(input) == 0 {
return types.EmptyRootHash, nil
}
return crypto.Keccak256Hash(input), input
}
type hashLoader struct {
accounts map[common.Hash][]byte
storages map[common.Hash]map[common.Hash][]byte
}
func newHashLoader(accounts map[common.Hash][]byte, storages map[common.Hash]map[common.Hash][]byte) *hashLoader {
return &hashLoader{
accounts: accounts,
storages: storages,
}
}
// OpenTrie opens the main account trie.
func (l *hashLoader) OpenTrie(root common.Hash) (triestate.Trie, error) {
return newTestHasher(common.Hash{}, root, l.accounts)
}
// OpenStorageTrie opens the storage trie of an account.
func (l *hashLoader) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (triestate.Trie, error) {
return newTestHasher(addrHash, root, l.storages[addrHash])
}

View File

@ -14,7 +14,7 @@
// You should have received a copy of the GNU Lesser General Public License // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie package triedb
import ( import (
"sync" "sync"
@ -69,6 +69,23 @@ func (store *preimageStore) preimage(hash common.Hash) []byte {
return rawdb.ReadPreimage(store.disk, hash) return rawdb.ReadPreimage(store.disk, hash)
} }
// commit flushes the cached preimages into the disk.
func (store *preimageStore) commit(force bool) error {
store.lock.Lock()
defer store.lock.Unlock()
if store.preimagesSize <= 4*1024*1024 && !force {
return nil
}
batch := store.disk.NewBatch()
rawdb.WritePreimages(batch, store.preimages)
if err := batch.Write(); err != nil {
return err
}
store.preimages, store.preimagesSize = make(map[common.Hash][]byte), 0
return nil
}
// size returns the current storage size of accumulated preimages. // size returns the current storage size of accumulated preimages.
func (store *preimageStore) size() common.StorageSize { func (store *preimageStore) size() common.StorageSize {
store.lock.RLock() store.lock.RLock()