Redo trie_by_cid package to be read-write

* use logrus instead of geth log
* remove benchmarks
* impl NodeIterator.ParentPath
* update go mods
This commit is contained in:
Roy Crihfield 2023-04-25 17:14:04 +08:00
parent c839739320
commit 7381b35dc6
43 changed files with 7237 additions and 707 deletions

5
go.mod
View File

@ -5,6 +5,7 @@ go 1.18
require ( require (
github.com/VictoriaMetrics/fastcache v1.6.0 github.com/VictoriaMetrics/fastcache v1.6.0
github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha
github.com/davecgh/go-spew v1.1.1
github.com/ethereum/go-ethereum v1.11.5 github.com/ethereum/go-ethereum v1.11.5
github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d
github.com/ipfs/go-cid v0.2.0 github.com/ipfs/go-cid v0.2.0
@ -21,13 +22,13 @@ require (
github.com/DataDog/zstd v1.5.2 // indirect github.com/DataDog/zstd v1.5.2 // indirect
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/btcsuite/btcd/btcec/v2 v2.2.0 // indirect github.com/btcsuite/btcd/btcec/v2 v2.2.0 // indirect
github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cockroachdb/errors v1.9.1 // indirect github.com/cockroachdb/errors v1.9.1 // indirect
github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect
github.com/cockroachdb/pebble v0.0.0-20230209160836-829675f94811 // indirect github.com/cockroachdb/pebble v0.0.0-20230209160836-829675f94811 // indirect
github.com/cockroachdb/redact v1.1.3 // indirect github.com/cockroachdb/redact v1.1.3 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/deckarep/golang-set/v2 v2.1.0 // indirect github.com/deckarep/golang-set/v2 v2.1.0 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect
github.com/edsrzf/mmap-go v1.0.0 // indirect github.com/edsrzf/mmap-go v1.0.0 // indirect
@ -112,4 +113,4 @@ require (
lukechampine.com/blake3 v1.1.6 // indirect lukechampine.com/blake3 v1.1.6 // indirect
) )
replace github.com/ethereum/go-ethereum v1.11.5 => github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.1-alpha replace github.com/ethereum/go-ethereum v1.11.5 => github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.4

7
go.sum
View File

@ -17,12 +17,13 @@ github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5
github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g= github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/btcsuite/btcd v0.22.0-beta h1:LTDpDKUM5EeOFBPM8IXpinEcmZ6FWfNZbE3lfrfdnWo=
github.com/btcsuite/btcd/btcec/v2 v2.2.0 h1:fzn1qaOt32TuLjFlkzYSsBC35Q3KUjT1SwPxiMSCF5k= github.com/btcsuite/btcd/btcec/v2 v2.2.0 h1:fzn1qaOt32TuLjFlkzYSsBC35Q3KUjT1SwPxiMSCF5k=
github.com/btcsuite/btcd/btcec/v2 v2.2.0/go.mod h1:U7MHm051Al6XmscBQ0BoNydpOTsFAn707034b5nY8zU= github.com/btcsuite/btcd/btcec/v2 v2.2.0/go.mod h1:U7MHm051Al6XmscBQ0BoNydpOTsFAn707034b5nY8zU=
github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2 h1:KdUfX2zKommPRa+PD0sWZUyXe9w277ABlgELO7H04IM=
github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.1-alpha h1:x9muuG0Z2W/UAkwAq+0SXhYG9MCP6SJlGEfFNQa6JPI= github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.3-alpha h1:rhRmK/NeWMnQ07E4DuLb7WSh9FMotlXMPPaOrf8GJwM=
github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.1-alpha/go.mod h1:DIk2wFexjyzvyjuzSOtBEIAPRNZTnLXNbIHEyq1Igek= github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.3-alpha/go.mod h1:DIk2wFexjyzvyjuzSOtBEIAPRNZTnLXNbIHEyq1Igek=
github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha h1:I1iXTaIjbTH8ehzNXmT2waXcYBifi1yjK6FK3W3a0Pg= github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha h1:I1iXTaIjbTH8ehzNXmT2waXcYBifi1yjK6FK3W3a0Pg=
github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha/go.mod h1:EGAdV/YewEADFDDVF1k9GNwy8vNWR29Xb87sRHgMIng= github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha/go.mod h1:EGAdV/YewEADFDDVF1k9GNwy8vNWR29Xb87sRHgMIng=
github.com/cespare/cp v0.1.0 h1:SE+dxFebS7Iik5LK0tsi1k9ZCxEaFX4AjQmoyA+1dJk= github.com/cespare/cp v0.1.0 h1:SE+dxFebS7Iik5LK0tsi1k9ZCxEaFX4AjQmoyA+1dJk=

View File

@ -1,6 +1,10 @@
package internal package internal
import ( import (
"testing"
"time"
pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0"
"github.com/ipfs/go-cid" "github.com/ipfs/go-cid"
"github.com/multiformats/go-multihash" "github.com/multiformats/go-multihash"
) )
@ -12,3 +16,12 @@ func Keccak256ToCid(codec uint64, h []byte) (cid.Cid, error) {
} }
return cid.NewCidV1(codec, buf), nil return cid.NewCidV1(codec, buf), nil
} }
// returns a cache config with unique name (groupcache names are global)
func MakeCacheConfig(t testing.TB) pgipfsethdb.CacheConfig {
return pgipfsethdb.CacheConfig{
Name: t.Name(),
Size: 3000000, // 3MB
ExpiryDuration: time.Hour,
}
}

3
trie_by_cid/doc.go Normal file
View File

@ -0,0 +1,3 @@
// This package is a near complete copy of go-ethereum/trie and go-ethereum/core/state, modified to use
// a v0 IPFS blockstore as the backing DB, i.e. DB values are indexed by CID rather than hash.
package trie_by_cid

View File

@ -20,7 +20,10 @@ var (
mockTD = big.NewInt(1) mockTD = big.NewInt(1)
) )
func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, rootB common.Hash) error { // IndexStateDiff indexes a single statediff.
// - uses TestChainConfig
// - block hash/number are left as zero
func IndexStateDiff(dbConfig postgres.Config, stateCache state.Database, rootA, rootB common.Hash) error {
_, indexer, err := indexer.NewStateDiffIndexer( _, indexer, err := indexer.NewStateDiffIndexer(
context.Background(), ChainConfig, node.Info{}, dbConfig) context.Background(), ChainConfig, node.Info{}, dbConfig)
if err != nil { if err != nil {
@ -28,11 +31,10 @@ func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, root
} }
defer indexer.Close() // fixme: hangs when using PGX driver defer indexer.Close() // fixme: hangs when using PGX driver
// generating statediff payload for block, and transform the data into Postgres
builder := statediff.NewBuilder(stateCache) builder := statediff.NewBuilder(stateCache)
block := types.NewBlock(&types.Header{Root: rootB}, nil, nil, nil, NewHasher()) block := types.NewBlock(&types.Header{Root: rootB}, nil, nil, nil, NewHasher())
// todo: use dummy block hashes to just produce trie structure for testing // uses zero block hash/number, we only need the trie structure here
args := statediff.Args{ args := statediff.Args{
OldStateRoot: rootA, OldStateRoot: rootA,
NewStateRoot: rootB, NewStateRoot: rootB,
@ -45,12 +47,7 @@ func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, root
if err != nil { if err != nil {
return err return err
} }
// for _, node := range diff.Nodes { // we don't need to index diff.Nodes since we are just interested in the trie
// err := indexer.PushStateNode(tx, node, block.Hash().String())
// if err != nil {
// return err
// }
// }
for _, ipld := range diff.IPLDs { for _, ipld := range diff.IPLDs {
if err := indexer.PushIPLD(tx, ipld); err != nil { if err := indexer.PushIPLD(tx, ipld); err != nil {
return err return err

View File

@ -0,0 +1,136 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"github.com/ethereum/go-ethereum/common"
)
type accessList struct {
addresses map[common.Address]int
slots []map[common.Hash]struct{}
}
// ContainsAddress returns true if the address is in the access list.
func (al *accessList) ContainsAddress(address common.Address) bool {
_, ok := al.addresses[address]
return ok
}
// Contains checks if a slot within an account is present in the access list, returning
// separate flags for the presence of the account and the slot respectively.
func (al *accessList) Contains(address common.Address, slot common.Hash) (addressPresent bool, slotPresent bool) {
idx, ok := al.addresses[address]
if !ok {
// no such address (and hence zero slots)
return false, false
}
if idx == -1 {
// address yes, but no slots
return true, false
}
_, slotPresent = al.slots[idx][slot]
return true, slotPresent
}
// newAccessList creates a new accessList.
func newAccessList() *accessList {
return &accessList{
addresses: make(map[common.Address]int),
}
}
// Copy creates an independent copy of an accessList.
func (a *accessList) Copy() *accessList {
cp := newAccessList()
for k, v := range a.addresses {
cp.addresses[k] = v
}
cp.slots = make([]map[common.Hash]struct{}, len(a.slots))
for i, slotMap := range a.slots {
newSlotmap := make(map[common.Hash]struct{}, len(slotMap))
for k := range slotMap {
newSlotmap[k] = struct{}{}
}
cp.slots[i] = newSlotmap
}
return cp
}
// AddAddress adds an address to the access list, and returns 'true' if the operation
// caused a change (addr was not previously in the list).
func (al *accessList) AddAddress(address common.Address) bool {
if _, present := al.addresses[address]; present {
return false
}
al.addresses[address] = -1
return true
}
// AddSlot adds the specified (addr, slot) combo to the access list.
// Return values are:
// - address added
// - slot added
// For any 'true' value returned, a corresponding journal entry must be made.
func (al *accessList) AddSlot(address common.Address, slot common.Hash) (addrChange bool, slotChange bool) {
idx, addrPresent := al.addresses[address]
if !addrPresent || idx == -1 {
// Address not present, or addr present but no slots there
al.addresses[address] = len(al.slots)
slotmap := map[common.Hash]struct{}{slot: {}}
al.slots = append(al.slots, slotmap)
return !addrPresent, true
}
// There is already an (address,slot) mapping
slotmap := al.slots[idx]
if _, ok := slotmap[slot]; !ok {
slotmap[slot] = struct{}{}
// Journal add slot change
return false, true
}
// No changes required
return false, false
}
// DeleteSlot removes an (address, slot)-tuple from the access list.
// This operation needs to be performed in the same order as the addition happened.
// This method is meant to be used by the journal, which maintains ordering of
// operations.
func (al *accessList) DeleteSlot(address common.Address, slot common.Hash) {
idx, addrOk := al.addresses[address]
// There are two ways this can fail
if !addrOk {
panic("reverting slot change, address not present in list")
}
slotmap := al.slots[idx]
delete(slotmap, slot)
// If that was the last (first) slot, remove it
// Since additions and rollbacks are always performed in order,
// we can delete the item without worrying about screwing up later indices
if len(slotmap) == 0 {
al.slots = al.slots[:idx]
al.addresses[address] = -1
}
}
// DeleteAddress removes an address from the access list. This operation
// needs to be performed in the same order as the addition happened.
// This method is meant to be used by the journal, which maintains ordering of
// operations.
func (al *accessList) DeleteAddress(address common.Address) {
delete(al.addresses, address)
}

View File

@ -1,14 +1,30 @@
// Copyright 2017 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state package state
import ( import (
"errors" "errors"
"fmt"
"github.com/VictoriaMetrics/fastcache"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/lru"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/statediff/indexer/ipld" "github.com/ethereum/go-ethereum/statediff/indexer/ipld"
lru "github.com/hashicorp/golang-lru"
"github.com/cerc-io/ipld-eth-statedb/internal" "github.com/cerc-io/ipld-eth-statedb/internal"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
@ -28,7 +44,10 @@ type Database interface {
OpenTrie(root common.Hash) (Trie, error) OpenTrie(root common.Hash) (Trie, error)
// OpenStorageTrie opens the storage trie of an account. // OpenStorageTrie opens the storage trie of an account.
OpenStorageTrie(addrHash, root common.Hash) (Trie, error) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error)
// CopyTrie returns an independent copy of the given trie.
CopyTrie(Trie) Trie
// ContractCode retrieves a particular contract's code. // ContractCode retrieves a particular contract's code.
ContractCode(codeHash common.Hash) ([]byte, error) ContractCode(codeHash common.Hash) ([]byte, error)
@ -36,17 +55,79 @@ type Database interface {
// ContractCodeSize retrieves a particular contracts code's size. // ContractCodeSize retrieves a particular contracts code's size.
ContractCodeSize(codeHash common.Hash) (int, error) ContractCodeSize(codeHash common.Hash) (int, error)
// DiskDB returns the underlying key-value disk database.
DiskDB() ethdb.KeyValueStore
// TrieDB retrieves the low level trie database used for data storage. // TrieDB retrieves the low level trie database used for data storage.
TrieDB() *trie.Database TrieDB() *trie.Database
} }
// Trie is a Ethereum Merkle Patricia trie. // Trie is a Ethereum Merkle Patricia trie.
type Trie interface { type Trie interface {
// GetKey returns the sha3 preimage of a hashed key that was previously used
// to store a value.
//
// TODO(fjl): remove this when StateTrie is removed
GetKey([]byte) []byte
// TryGet returns the value for key stored in the trie. The value bytes must
// not be modified by the caller. If a node was not found in the database, a
// trie.MissingNodeError is returned.
TryGet(key []byte) ([]byte, error) TryGet(key []byte) ([]byte, error)
// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not
// possible to use keybyte-encoding as the path might contain odd nibbles.
TryGetNode(path []byte) ([]byte, int, error) TryGetNode(path []byte) ([]byte, int, error)
TryGetAccount(key []byte) (*types.StateAccount, error)
// TryGetAccount abstracts an account read from the trie. It retrieves the
// account blob from the trie with provided account address and decodes it
// with associated decoding algorithm. If the specified account is not in
// the trie, nil will be returned. If the trie is corrupted(e.g. some nodes
// are missing or the account blob is incorrect for decoding), an error will
// be returned.
TryGetAccount(address common.Address) (*types.StateAccount, error)
// TryUpdate associates key with value in the trie. If value has length zero, any
// existing value is deleted from the trie. The value bytes must not be modified
// by the caller while they are stored in the trie. If a node was not found in the
// database, a trie.MissingNodeError is returned.
TryUpdate(key, value []byte) error
// TryUpdateAccount abstracts an account write to the trie. It encodes the
// provided account object with associated algorithm and then updates it
// in the trie with provided address.
TryUpdateAccount(address common.Address, account *types.StateAccount) error
// TryDelete removes any existing value for key from the trie. If a node was not
// found in the database, a trie.MissingNodeError is returned.
TryDelete(key []byte) error
// TryDeleteAccount abstracts an account deletion from the trie.
TryDeleteAccount(address common.Address) error
// Hash returns the root hash of the trie. It does not write to the database and
// can be used even if the trie doesn't have one.
Hash() common.Hash Hash() common.Hash
// Commit collects all dirty nodes in the trie and replace them with the
// corresponding node hash. All collected nodes(including dirty leaves if
// collectLeaf is true) will be encapsulated into a nodeset for return.
// The returned nodeset can be nil if the trie is clean(nothing to commit).
// Once the trie is committed, it's not usable anymore. A new trie must
// be created with new root and updated trie database for following usage
Commit(collectLeaf bool) (common.Hash, *trie.NodeSet)
// NodeIterator returns an iterator that returns nodes of the trie. Iteration
// starts at the key after the given start key.
NodeIterator(startKey []byte) trie.NodeIterator NodeIterator(startKey []byte) trie.NodeIterator
// Prove constructs a Merkle proof for key. The result contains all encoded nodes
// on the path to the value at key. The value itself is also included in the last
// node and can be retrieved by verifying the proof.
//
// If the trie does not contain a value for key, the returned proof contains all
// nodes of the longest existing prefix of the key (at least the root), ending
// with the node that proves the absence of the key.
Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error
} }
@ -61,23 +142,34 @@ func NewDatabase(db ethdb.Database) Database {
// is safe for concurrent use and retains a lot of collapsed RLP trie nodes in a // is safe for concurrent use and retains a lot of collapsed RLP trie nodes in a
// large memory cache. // large memory cache.
func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database { func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database {
csc, _ := lru.New(codeSizeCacheSize)
return &cachingDB{ return &cachingDB{
db: trie.NewDatabaseWithConfig(db, config), disk: db,
codeSizeCache: csc, codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize),
codeCache: fastcache.New(codeCacheSize), codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize),
triedb: trie.NewDatabaseWithConfig(db, config),
}
}
// NewDatabaseWithNodeDB creates a state database with an already initialized node database.
func NewDatabaseWithNodeDB(db ethdb.Database, triedb *trie.Database) Database {
return &cachingDB{
disk: db,
codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize),
codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize),
triedb: triedb,
} }
} }
type cachingDB struct { type cachingDB struct {
db *trie.Database disk ethdb.KeyValueStore
codeSizeCache *lru.Cache codeSizeCache *lru.Cache[common.Hash, int]
codeCache *fastcache.Cache codeCache *lru.SizeConstrainedCache[common.Hash, []byte]
triedb *trie.Database
} }
// OpenTrie opens the main account trie at a specific root hash. // OpenTrie opens the main account trie at a specific root hash.
func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
tr, err := trie.NewStateTrie(common.Hash{}, root, db.db) tr, err := trie.NewStateTrie(trie.StateTrieID(root), db.triedb, trie.StateTrieCodec)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -85,29 +177,40 @@ func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
} }
// OpenStorageTrie opens the storage trie of an account. // OpenStorageTrie opens the storage trie of an account.
func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) { func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) {
tr, err := trie.NewStorageTrie(addrHash, root, db.db) tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, addrHash, root), db.triedb, trie.StorageTrieCodec)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return tr, nil return tr, nil
} }
// CopyTrie returns an independent copy of the given trie.
func (db *cachingDB) CopyTrie(t Trie) Trie {
switch t := t.(type) {
case *trie.StateTrie:
return t.Copy()
default:
panic(fmt.Errorf("unknown trie type %T", t))
}
}
// ContractCode retrieves a particular contract's code. // ContractCode retrieves a particular contract's code.
func (db *cachingDB) ContractCode(codeHash common.Hash) ([]byte, error) { func (db *cachingDB) ContractCode(codeHash common.Hash) ([]byte, error) {
if code := db.codeCache.Get(nil, codeHash.Bytes()); len(code) > 0 { code, _ := db.codeCache.Get(codeHash)
if len(code) > 0 {
return code, nil return code, nil
} }
codeCID, err := internal.Keccak256ToCid(ipld.RawBinary, codeHash.Bytes()) cid, err := internal.Keccak256ToCid(ipld.RawBinary, codeHash.Bytes())
if err != nil { if err != nil {
return nil, err return nil, err
} }
code, err := db.db.DiskDB().Get(codeCID.Bytes()) code, err = db.disk.Get(cid.Bytes())
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(code) > 0 { if len(code) > 0 {
db.codeCache.Set(codeHash.Bytes(), code) db.codeCache.Add(codeHash, code)
db.codeSizeCache.Add(codeHash, len(code)) db.codeSizeCache.Add(codeHash, len(code))
return code, nil return code, nil
} }
@ -117,13 +220,18 @@ func (db *cachingDB) ContractCode(codeHash common.Hash) ([]byte, error) {
// ContractCodeSize retrieves a particular contracts code's size. // ContractCodeSize retrieves a particular contracts code's size.
func (db *cachingDB) ContractCodeSize(codeHash common.Hash) (int, error) { func (db *cachingDB) ContractCodeSize(codeHash common.Hash) (int, error) {
if cached, ok := db.codeSizeCache.Get(codeHash); ok { if cached, ok := db.codeSizeCache.Get(codeHash); ok {
return cached.(int), nil return cached, nil
} }
code, err := db.ContractCode(codeHash) code, err := db.ContractCode(codeHash)
return len(code), err return len(code), err
} }
// DiskDB returns the underlying key-value disk database.
func (db *cachingDB) DiskDB() ethdb.KeyValueStore {
return db.disk
}
// TrieDB retrieves any intermediate trie-node caching layer. // TrieDB retrieves any intermediate trie-node caching layer.
func (db *cachingDB) TrieDB() *trie.Database { func (db *cachingDB) TrieDB() *trie.Database {
return db.db return db.triedb
} }

View File

@ -0,0 +1,282 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"math/big"
"github.com/ethereum/go-ethereum/common"
)
// journalEntry is a modification entry in the state change journal that can be
// reverted on demand.
type journalEntry interface {
// revert undoes the changes introduced by this journal entry.
revert(*StateDB)
// dirtied returns the Ethereum address modified by this journal entry.
dirtied() *common.Address
}
// journal contains the list of state modifications applied since the last state
// commit. These are tracked to be able to be reverted in the case of an execution
// exception or request for reversal.
type journal struct {
entries []journalEntry // Current changes tracked by the journal
dirties map[common.Address]int // Dirty accounts and the number of changes
}
// newJournal creates a new initialized journal.
func newJournal() *journal {
return &journal{
dirties: make(map[common.Address]int),
}
}
// append inserts a new modification entry to the end of the change journal.
func (j *journal) append(entry journalEntry) {
j.entries = append(j.entries, entry)
if addr := entry.dirtied(); addr != nil {
j.dirties[*addr]++
}
}
// revert undoes a batch of journalled modifications along with any reverted
// dirty handling too.
func (j *journal) revert(statedb *StateDB, snapshot int) {
for i := len(j.entries) - 1; i >= snapshot; i-- {
// Undo the changes made by the operation
j.entries[i].revert(statedb)
// Drop any dirty tracking induced by the change
if addr := j.entries[i].dirtied(); addr != nil {
if j.dirties[*addr]--; j.dirties[*addr] == 0 {
delete(j.dirties, *addr)
}
}
}
j.entries = j.entries[:snapshot]
}
// dirty explicitly sets an address to dirty, even if the change entries would
// otherwise suggest it as clean. This method is an ugly hack to handle the RIPEMD
// precompile consensus exception.
func (j *journal) dirty(addr common.Address) {
j.dirties[addr]++
}
// length returns the current number of entries in the journal.
func (j *journal) length() int {
return len(j.entries)
}
type (
// Changes to the account trie.
createObjectChange struct {
account *common.Address
}
resetObjectChange struct {
prev *stateObject
prevdestruct bool
}
suicideChange struct {
account *common.Address
prev bool // whether account had already suicided
prevbalance *big.Int
}
// Changes to individual accounts.
balanceChange struct {
account *common.Address
prev *big.Int
}
nonceChange struct {
account *common.Address
prev uint64
}
storageChange struct {
account *common.Address
key, prevalue common.Hash
}
codeChange struct {
account *common.Address
prevcode, prevhash []byte
}
// Changes to other state values.
refundChange struct {
prev uint64
}
addLogChange struct {
txhash common.Hash
}
addPreimageChange struct {
hash common.Hash
}
touchChange struct {
account *common.Address
}
// Changes to the access list
accessListAddAccountChange struct {
address *common.Address
}
accessListAddSlotChange struct {
address *common.Address
slot *common.Hash
}
transientStorageChange struct {
account *common.Address
key, prevalue common.Hash
}
)
func (ch createObjectChange) revert(s *StateDB) {
delete(s.stateObjects, *ch.account)
delete(s.stateObjectsDirty, *ch.account)
}
func (ch createObjectChange) dirtied() *common.Address {
return ch.account
}
func (ch resetObjectChange) revert(s *StateDB) {
s.setStateObject(ch.prev)
if !ch.prevdestruct {
delete(s.stateObjectsDestruct, ch.prev.address)
}
}
func (ch resetObjectChange) dirtied() *common.Address {
return nil
}
func (ch suicideChange) revert(s *StateDB) {
obj := s.getStateObject(*ch.account)
if obj != nil {
obj.suicided = ch.prev
obj.setBalance(ch.prevbalance)
}
}
func (ch suicideChange) dirtied() *common.Address {
return ch.account
}
var ripemd = common.HexToAddress("0000000000000000000000000000000000000003")
func (ch touchChange) revert(s *StateDB) {
}
func (ch touchChange) dirtied() *common.Address {
return ch.account
}
func (ch balanceChange) revert(s *StateDB) {
s.getStateObject(*ch.account).setBalance(ch.prev)
}
func (ch balanceChange) dirtied() *common.Address {
return ch.account
}
func (ch nonceChange) revert(s *StateDB) {
s.getStateObject(*ch.account).setNonce(ch.prev)
}
func (ch nonceChange) dirtied() *common.Address {
return ch.account
}
func (ch codeChange) revert(s *StateDB) {
s.getStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode)
}
func (ch codeChange) dirtied() *common.Address {
return ch.account
}
func (ch storageChange) revert(s *StateDB) {
s.getStateObject(*ch.account).setState(ch.key, ch.prevalue)
}
func (ch storageChange) dirtied() *common.Address {
return ch.account
}
func (ch transientStorageChange) revert(s *StateDB) {
s.setTransientState(*ch.account, ch.key, ch.prevalue)
}
func (ch transientStorageChange) dirtied() *common.Address {
return nil
}
func (ch refundChange) revert(s *StateDB) {
s.refund = ch.prev
}
func (ch refundChange) dirtied() *common.Address {
return nil
}
func (ch addLogChange) revert(s *StateDB) {
logs := s.logs[ch.txhash]
if len(logs) == 1 {
delete(s.logs, ch.txhash)
} else {
s.logs[ch.txhash] = logs[:len(logs)-1]
}
s.logSize--
}
func (ch addLogChange) dirtied() *common.Address {
return nil
}
func (ch addPreimageChange) revert(s *StateDB) {
delete(s.preimages, ch.hash)
}
func (ch addPreimageChange) dirtied() *common.Address {
return nil
}
func (ch accessListAddAccountChange) revert(s *StateDB) {
/*
One important invariant here, is that whenever a (addr, slot) is added, if the
addr is not already present, the add causes two journal entries:
- one for the address,
- one for the (address,slot)
Therefore, when unrolling the change, we can always blindly delete the
(addr) at this point, since no storage adds can remain when come upon
a single (addr) change.
*/
s.accessList.DeleteAddress(*ch.address)
}
func (ch accessListAddAccountChange) dirtied() *common.Address {
return nil
}
func (ch accessListAddSlotChange) revert(s *StateDB) {
s.accessList.DeleteSlot(*ch.address, *ch.slot)
}
func (ch accessListAddSlotChange) dirtied() *common.Address {
return nil
}

View File

@ -0,0 +1,30 @@
// Copyright 2021 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import "github.com/ethereum/go-ethereum/metrics"
var (
accountUpdatedMeter = metrics.NewRegisteredMeter("state/update/account", nil)
storageUpdatedMeter = metrics.NewRegisteredMeter("state/update/storage", nil)
accountDeletedMeter = metrics.NewRegisteredMeter("state/delete/account", nil)
storageDeletedMeter = metrics.NewRegisteredMeter("state/delete/storage", nil)
accountTrieUpdatedMeter = metrics.NewRegisteredMeter("state/update/accountnodes", nil)
storageTriesUpdatedMeter = metrics.NewRegisteredMeter("state/update/storagenodes", nil)
accountTrieDeletedMeter = metrics.NewRegisteredMeter("state/delete/accountnodes", nil)
storageTriesDeletedMeter = metrics.NewRegisteredMeter("state/delete/storagenodes", nil)
)

View File

@ -1,8 +1,25 @@
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state package state
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io"
"math/big" "math/big"
"time" "time"
@ -11,15 +28,7 @@ import (
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) // "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
var (
// emptyRoot is the known root hash of an empty trie.
// this is calculated as: emptyRoot = crypto.Keccak256(rlp.Encode([][]byte{}))
// that is, the keccak356 hash of the rlp encoding of an empty trie node (empty byte slice array)
emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
// emptyCodeHash is the CodeHash for an EOA, for an account without contract code deployed
emptyCodeHash = crypto.Keccak256(nil)
) )
type Code []byte type Code []byte
@ -34,7 +43,6 @@ func (s Storage) String() (str string) {
for key, value := range s { for key, value := range s {
str += fmt.Sprintf("%X : %X\n", key, value) str += fmt.Sprintf("%X : %X\n", key, value)
} }
return return
} }
@ -43,32 +51,39 @@ func (s Storage) Copy() Storage {
for key, value := range s { for key, value := range s {
cpy[key] = value cpy[key] = value
} }
return cpy return cpy
} }
// stateObject represents an Ethereum account which is being accessed. // stateObject represents an Ethereum account which is being modified.
// //
// The usage pattern is as follows: // The usage pattern is as follows:
// First you need to obtain a state object. // First you need to obtain a state object.
// Account values can be accessed through the object. // Account values can be accessed and modified through the object.
type stateObject struct { type stateObject struct {
address common.Address address common.Address
addrHash common.Hash // hash of ethereum address of the account addrHash common.Hash // hash of ethereum address of the account
data types.StateAccount data types.StateAccount
db *StateDB db *StateDB
// Caches. // Write caches.
trie Trie // storage trie, which becomes non-nil on first access trie Trie // storage trie, which becomes non-nil on first access
code Code // contract bytecode, which gets set when code is loaded code Code // contract bytecode, which gets set when code is loaded
originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction
fakeStorage Storage // Fake storage which constructed by caller for debugging purpose. pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block
dirtyStorage Storage // Storage entries that have been modified in the current transaction execution
// Cache flags.
// When an object is marked suicided it will be deleted from the trie
// during the "update" phase of the state transition.
dirtyCode bool // true if the code was updated
suicided bool
deleted bool
} }
// empty returns whether the account is considered empty. // empty returns whether the account is considered empty.
func (s *stateObject) empty() bool { func (s *stateObject) empty() bool {
return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, emptyCodeHash) return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, types.EmptyCodeHash.Bytes())
} }
// newObject creates a state object. // newObject creates a state object.
@ -77,71 +92,128 @@ func newObject(db *StateDB, address common.Address, data types.StateAccount) *st
data.Balance = new(big.Int) data.Balance = new(big.Int)
} }
if data.CodeHash == nil { if data.CodeHash == nil {
data.CodeHash = emptyCodeHash data.CodeHash = types.EmptyCodeHash.Bytes()
} }
if data.Root == (common.Hash{}) { if data.Root == (common.Hash{}) {
data.Root = emptyRoot data.Root = types.EmptyRootHash
} }
return &stateObject{ return &stateObject{
db: db, db: db,
address: address, address: address,
addrHash: crypto.Keccak256Hash(address[:]), addrHash: crypto.Keccak256Hash(address[:]),
data: data, data: data,
originStorage: make(Storage), originStorage: make(Storage),
pendingStorage: make(Storage),
dirtyStorage: make(Storage),
} }
} }
// setError remembers the first non-nil error it is called with. // EncodeRLP implements rlp.Encoder.
func (s *stateObject) setError(err error) { func (s *stateObject) EncodeRLP(w io.Writer) error {
s.db.setError(err) return rlp.Encode(w, &s.data)
} }
func (s *stateObject) getTrie(db Database) Trie { func (s *stateObject) markSuicided() {
s.suicided = true
}
func (s *stateObject) touch() {
s.db.journal.append(touchChange{
account: &s.address,
})
if s.address == ripemd {
// Explicitly put it in the dirty-cache, which is otherwise generated from
// flattened journals.
s.db.journal.dirty(s.address)
}
}
// getTrie returns the associated storage trie. The trie will be opened
// if it's not loaded previously. An error will be returned if trie can't
// be loaded.
func (s *stateObject) getTrie(db Database) (Trie, error) {
if s.trie == nil { if s.trie == nil {
// // Try fetching from prefetcher first // Try fetching from prefetcher first
// // We don't prefetch empty tries // We don't prefetch empty tries
// if s.data.Root != emptyRoot && s.db.prefetcher != nil { if s.data.Root != types.EmptyRootHash && s.db.prefetcher != nil {
// // When the miner is creating the pending state, there is no // When the miner is creating the pending state, there is no
// // prefetcher // prefetcher
// s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root) s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root)
// } }
if s.trie == nil { if s.trie == nil {
var err error tr, err := db.OpenStorageTrie(s.db.originalRoot, s.addrHash, s.data.Root)
s.trie, err = db.OpenStorageTrie(s.addrHash, s.data.Root)
if err != nil { if err != nil {
s.trie, _ = db.OpenStorageTrie(s.addrHash, common.Hash{}) return nil, err
s.setError(fmt.Errorf("can't create storage trie: %w", err))
} }
s.trie = tr
} }
} }
return s.trie return s.trie, nil
}
// GetState retrieves a value from the account storage trie.
func (s *stateObject) GetState(db Database, key common.Hash) common.Hash {
// If we have a dirty value for this state entry, return it
value, dirty := s.dirtyStorage[key]
if dirty {
return value
}
// Otherwise return the entry's original value
return s.GetCommittedState(db, key)
} }
// GetCommittedState retrieves a value from the committed account storage trie. // GetCommittedState retrieves a value from the committed account storage trie.
func (s *stateObject) GetState(db Database, key common.Hash) common.Hash { func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash {
// If the fake storage is set, only lookup the state here(in the debugging mode) // If we have a pending write or clean cached, return that
if s.fakeStorage != nil { if value, pending := s.pendingStorage[key]; pending {
return s.fakeStorage[key] return value
} }
// If we have a cached value, return that
if value, cached := s.originStorage[key]; cached { if value, cached := s.originStorage[key]; cached {
return value return value
} }
// If no live objects are available, load from the database. // If the object was destructed in *this* block (and potentially resurrected),
start := time.Now() // the storage has been cleared out, and we should *not* consult the previous
enc, err := s.getTrie(db).TryGet(key.Bytes()) // database about any storage values. The only possible alternatives are:
if metrics.EnabledExpensive { // 1) resurrect happened, and new slot values were set -- those should
s.db.StorageReads += time.Since(start) // have been handles via pendingStorage above.
} // 2) we don't have new values, and can deliver empty response back
if err != nil { if _, destructed := s.db.stateObjectsDestruct[s.address]; destructed {
s.setError(err)
return common.Hash{} return common.Hash{}
} }
// If no live objects are available, attempt to use snapshots
var (
enc []byte
err error
)
if s.db.snap != nil {
start := time.Now()
enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key.Bytes()))
if metrics.EnabledExpensive {
s.db.SnapshotStorageReads += time.Since(start)
}
}
// If the snapshot is unavailable or reading from it fails, load from the database.
if s.db.snap == nil || err != nil {
start := time.Now()
tr, err := s.getTrie(db)
if err != nil {
s.db.setError(err)
return common.Hash{}
}
enc, err = tr.TryGet(key.Bytes())
if metrics.EnabledExpensive {
s.db.StorageReads += time.Since(start)
}
if err != nil {
s.db.setError(err)
return common.Hash{}
}
}
var value common.Hash var value common.Hash
if len(enc) > 0 { if len(enc) > 0 {
_, content, _, err := rlp.Split(enc) _, content, _, err := rlp.Split(enc)
if err != nil { if err != nil {
s.setError(err) s.db.setError(err)
} }
value.SetBytes(content) value.SetBytes(content)
} }
@ -149,6 +221,182 @@ func (s *stateObject) GetState(db Database, key common.Hash) common.Hash {
return value return value
} }
// SetState updates a value in account storage.
func (s *stateObject) SetState(db Database, key, value common.Hash) {
// If the new value is the same as old, don't set
prev := s.GetState(db, key)
if prev == value {
return
}
// New value is different, update and journal the change
s.db.journal.append(storageChange{
account: &s.address,
key: key,
prevalue: prev,
})
s.setState(key, value)
}
func (s *stateObject) setState(key, value common.Hash) {
s.dirtyStorage[key] = value
}
// finalise moves all dirty storage slots into the pending area to be hashed or
// committed later. It is invoked at the end of every transaction.
func (s *stateObject) finalise(prefetch bool) {
slotsToPrefetch := make([][]byte, 0, len(s.dirtyStorage))
for key, value := range s.dirtyStorage {
s.pendingStorage[key] = value
if value != s.originStorage[key] {
slotsToPrefetch = append(slotsToPrefetch, common.CopyBytes(key[:])) // Copy needed for closure
}
}
if s.db.prefetcher != nil && prefetch && len(slotsToPrefetch) > 0 && s.data.Root != types.EmptyRootHash {
s.db.prefetcher.prefetch(s.addrHash, s.data.Root, slotsToPrefetch)
}
if len(s.dirtyStorage) > 0 {
s.dirtyStorage = make(Storage)
}
}
// updateTrie writes cached storage modifications into the object's storage trie.
// It will return nil if the trie has not been loaded and no changes have been
// made. An error will be returned if the trie can't be loaded/updated correctly.
func (s *stateObject) updateTrie(db Database) (Trie, error) {
// Make sure all dirty slots are finalized into the pending storage area
s.finalise(false) // Don't prefetch anymore, pull directly if need be
if len(s.pendingStorage) == 0 {
return s.trie, nil
}
// Track the amount of time wasted on updating the storage trie
if metrics.EnabledExpensive {
defer func(start time.Time) { s.db.StorageUpdates += time.Since(start) }(time.Now())
}
// The snapshot storage map for the object
var (
storage map[common.Hash][]byte
hasher = s.db.hasher
)
tr, err := s.getTrie(db)
if err != nil {
s.db.setError(err)
return nil, err
}
// Insert all the pending updates into the trie
usedStorage := make([][]byte, 0, len(s.pendingStorage))
for key, value := range s.pendingStorage {
// Skip noop changes, persist actual changes
if value == s.originStorage[key] {
continue
}
s.originStorage[key] = value
var v []byte
if (value == common.Hash{}) {
if err := tr.TryDelete(key[:]); err != nil {
s.db.setError(err)
return nil, err
}
s.db.StorageDeleted += 1
} else {
// Encoding []byte cannot fail, ok to ignore the error.
v, _ = rlp.EncodeToBytes(common.TrimLeftZeroes(value[:]))
if err := tr.TryUpdate(key[:], v); err != nil {
s.db.setError(err)
return nil, err
}
s.db.StorageUpdated += 1
}
// If state snapshotting is active, cache the data til commit
if s.db.snap != nil {
if storage == nil {
// Retrieve the old storage map, if available, create a new one otherwise
if storage = s.db.snapStorage[s.addrHash]; storage == nil {
storage = make(map[common.Hash][]byte)
s.db.snapStorage[s.addrHash] = storage
}
}
storage[crypto.HashData(hasher, key[:])] = v // v will be nil if it's deleted
}
usedStorage = append(usedStorage, common.CopyBytes(key[:])) // Copy needed for closure
}
if s.db.prefetcher != nil {
s.db.prefetcher.used(s.addrHash, s.data.Root, usedStorage)
}
if len(s.pendingStorage) > 0 {
s.pendingStorage = make(Storage)
}
return tr, nil
}
// UpdateRoot sets the trie root to the current root hash of. An error
// will be returned if trie root hash is not computed correctly.
func (s *stateObject) updateRoot(db Database) {
tr, err := s.updateTrie(db)
if err != nil {
return
}
// If nothing changed, don't bother with hashing anything
if tr == nil {
return
}
// Track the amount of time wasted on hashing the storage trie
if metrics.EnabledExpensive {
defer func(start time.Time) { s.db.StorageHashes += time.Since(start) }(time.Now())
}
s.data.Root = tr.Hash()
}
// AddBalance adds amount to s's balance.
// It is used to add funds to the destination account of a transfer.
func (s *stateObject) AddBalance(amount *big.Int) {
// EIP161: We must check emptiness for the objects such that the account
// clearing (0,0,0 objects) can take effect.
if amount.Sign() == 0 {
if s.empty() {
s.touch()
}
return
}
s.SetBalance(new(big.Int).Add(s.Balance(), amount))
}
// SubBalance removes amount from s's balance.
// It is used to remove funds from the origin account of a transfer.
func (s *stateObject) SubBalance(amount *big.Int) {
if amount.Sign() == 0 {
return
}
s.SetBalance(new(big.Int).Sub(s.Balance(), amount))
}
func (s *stateObject) SetBalance(amount *big.Int) {
s.db.journal.append(balanceChange{
account: &s.address,
prev: new(big.Int).Set(s.data.Balance),
})
s.setBalance(amount)
}
func (s *stateObject) setBalance(amount *big.Int) {
s.data.Balance = amount
}
func (s *stateObject) deepCopy(db *StateDB) *stateObject {
stateObject := newObject(db, s.address, s.data)
if s.trie != nil {
stateObject.trie = db.db.CopyTrie(s.trie)
}
stateObject.code = s.code
stateObject.dirtyStorage = s.dirtyStorage.Copy()
stateObject.originStorage = s.originStorage.Copy()
stateObject.pendingStorage = s.pendingStorage.Copy()
stateObject.suicided = s.suicided
stateObject.dirtyCode = s.dirtyCode
stateObject.deleted = s.deleted
return stateObject
}
// //
// Attribute accessors // Attribute accessors
// //
@ -163,12 +411,12 @@ func (s *stateObject) Code(db Database) []byte {
if s.code != nil { if s.code != nil {
return s.code return s.code
} }
if bytes.Equal(s.CodeHash(), emptyCodeHash) { if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return nil return nil
} }
code, err := db.ContractCode(common.BytesToHash(s.CodeHash())) code, err := db.ContractCode(common.BytesToHash(s.CodeHash()))
if err != nil { if err != nil {
s.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err)) s.db.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err))
} }
s.code = code s.code = code
return code return code
@ -181,16 +429,44 @@ func (s *stateObject) CodeSize(db Database) int {
if s.code != nil { if s.code != nil {
return len(s.code) return len(s.code)
} }
if bytes.Equal(s.CodeHash(), emptyCodeHash) { if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return 0 return 0
} }
size, err := db.ContractCodeSize(common.BytesToHash(s.CodeHash())) size, err := db.ContractCodeSize(common.BytesToHash(s.CodeHash()))
if err != nil { if err != nil {
s.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err)) s.db.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err))
} }
return size return size
} }
func (s *stateObject) SetCode(codeHash common.Hash, code []byte) {
prevcode := s.Code(s.db.db)
s.db.journal.append(codeChange{
account: &s.address,
prevhash: s.CodeHash(),
prevcode: prevcode,
})
s.setCode(codeHash, code)
}
func (s *stateObject) setCode(codeHash common.Hash, code []byte) {
s.code = code
s.data.CodeHash = codeHash[:]
s.dirtyCode = true
}
func (s *stateObject) SetNonce(nonce uint64) {
s.db.journal.append(nonceChange{
account: &s.address,
prev: s.data.Nonce,
})
s.setNonce(nonce)
}
func (s *stateObject) setNonce(nonce uint64) {
s.data.Nonce = nonce
}
func (s *stateObject) CodeHash() []byte { func (s *stateObject) CodeHash() []byte {
return s.data.CodeHash return s.data.CodeHash
} }
@ -202,10 +478,3 @@ func (s *stateObject) Balance() *big.Int {
func (s *stateObject) Nonce() uint64 { func (s *stateObject) Nonce() uint64 {
return s.data.Nonce return s.data.Nonce
} }
// Never called, but must be present to allow stateObject to be used
// as a vm.Account interface that also satisfies the vm.ContractRef
// interface. Interfaces are awesome.
func (s *stateObject) Value() *big.Int {
panic("Value on stateObject should never be called")
}

View File

@ -0,0 +1,46 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"bytes"
"testing"
"github.com/ethereum/go-ethereum/common"
)
func BenchmarkCutOriginal(b *testing.B) {
value := common.HexToHash("0x01")
for i := 0; i < b.N; i++ {
bytes.TrimLeft(value[:], "\x00")
}
}
func BenchmarkCutsetterFn(b *testing.B) {
value := common.HexToHash("0x01")
cutSetFn := func(r rune) bool { return r == 0 }
for i := 0; i < b.N; i++ {
bytes.TrimLeftFunc(value[:], cutSetFn)
}
}
func BenchmarkCutCustomTrim(b *testing.B) {
value := common.HexToHash("0x01")
for i := 0; i < b.N; i++ {
common.TrimLeftZeroes(value[:])
}
}

View File

@ -0,0 +1,219 @@
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"bytes"
"context"
"math/big"
"testing"
pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres"
"github.com/cerc-io/ipld-eth-statedb/internal"
)
var (
testCtx = context.Background()
testConfig, _ = postgres.DefaultConfig.WithEnv()
teardownStatements = []string{`TRUNCATE ipld.blocks`}
)
type stateTest struct {
db ethdb.Database
state *StateDB
}
func newStateTest(t *testing.T) *stateTest {
pool, err := postgres.ConnectSQLX(testCtx, testConfig)
if err != nil {
t.Fatal(err)
}
db := pgipfsethdb.NewDatabase(pool, internal.MakeCacheConfig(t))
sdb, err := New(common.Hash{}, NewDatabase(db), nil)
if err != nil {
t.Fatal(err)
}
return &stateTest{db: db, state: sdb}
}
func TestNull(t *testing.T) {
s := newStateTest(t)
address := common.HexToAddress("0x823140710bf13990e4500136726d8b55")
s.state.CreateAccount(address)
//value := common.FromHex("0x823140710bf13990e4500136726d8b55")
var value common.Hash
s.state.SetState(address, common.Hash{}, value)
// s.state.Commit(false)
if value := s.state.GetState(address, common.Hash{}); value != (common.Hash{}) {
t.Errorf("expected empty current value, got %x", value)
}
if value := s.state.GetCommittedState(address, common.Hash{}); value != (common.Hash{}) {
t.Errorf("expected empty committed value, got %x", value)
}
}
func TestSnapshot(t *testing.T) {
stateobjaddr := common.BytesToAddress([]byte("aa"))
var storageaddr common.Hash
data1 := common.BytesToHash([]byte{42})
data2 := common.BytesToHash([]byte{43})
s := newStateTest(t)
// snapshot the genesis state
genesis := s.state.Snapshot()
// set initial state object value
s.state.SetState(stateobjaddr, storageaddr, data1)
snapshot := s.state.Snapshot()
// set a new state object value, revert it and ensure correct content
s.state.SetState(stateobjaddr, storageaddr, data2)
s.state.RevertToSnapshot(snapshot)
if v := s.state.GetState(stateobjaddr, storageaddr); v != data1 {
t.Errorf("wrong storage value %v, want %v", v, data1)
}
if v := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) {
t.Errorf("wrong committed storage value %v, want %v", v, common.Hash{})
}
// revert up to the genesis state and ensure correct content
s.state.RevertToSnapshot(genesis)
if v := s.state.GetState(stateobjaddr, storageaddr); v != (common.Hash{}) {
t.Errorf("wrong storage value %v, want %v", v, common.Hash{})
}
if v := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) {
t.Errorf("wrong committed storage value %v, want %v", v, common.Hash{})
}
}
func TestSnapshotEmpty(t *testing.T) {
s := newStateTest(t)
s.state.RevertToSnapshot(s.state.Snapshot())
}
func TestSnapshot2(t *testing.T) {
state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
stateobjaddr0 := common.BytesToAddress([]byte("so0"))
stateobjaddr1 := common.BytesToAddress([]byte("so1"))
var storageaddr common.Hash
data0 := common.BytesToHash([]byte{17})
data1 := common.BytesToHash([]byte{18})
state.SetState(stateobjaddr0, storageaddr, data0)
state.SetState(stateobjaddr1, storageaddr, data1)
// db, trie are already non-empty values
so0 := state.getStateObject(stateobjaddr0)
so0.SetBalance(big.NewInt(42))
so0.SetNonce(43)
so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'})
so0.suicided = false
so0.deleted = false
state.setStateObject(so0)
// root, _ := state.Commit(false)
// state, _ = New(root, state.db, state.snaps)
// and one with deleted == true
so1 := state.getStateObject(stateobjaddr1)
so1.SetBalance(big.NewInt(52))
so1.SetNonce(53)
so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'})
so1.suicided = true
so1.deleted = true
state.setStateObject(so1)
so1 = state.getStateObject(stateobjaddr1)
if so1 != nil {
t.Fatalf("deleted object not nil when getting")
}
snapshot := state.Snapshot()
state.RevertToSnapshot(snapshot)
so0Restored := state.getStateObject(stateobjaddr0)
// Update lazily-loaded values before comparing.
so0Restored.GetState(state.db, storageaddr)
so0Restored.Code(state.db)
// non-deleted is equal (restored)
compareStateObjects(so0Restored, so0, t)
// deleted should be nil, both before and after restore of state copy
so1Restored := state.getStateObject(stateobjaddr1)
if so1Restored != nil {
t.Fatalf("deleted object not nil after restoring snapshot: %+v", so1Restored)
}
}
func compareStateObjects(so0, so1 *stateObject, t *testing.T) {
if so0.Address() != so1.Address() {
t.Fatalf("Address mismatch: have %v, want %v", so0.address, so1.address)
}
if so0.Balance().Cmp(so1.Balance()) != 0 {
t.Fatalf("Balance mismatch: have %v, want %v", so0.Balance(), so1.Balance())
}
if so0.Nonce() != so1.Nonce() {
t.Fatalf("Nonce mismatch: have %v, want %v", so0.Nonce(), so1.Nonce())
}
if so0.data.Root != so1.data.Root {
t.Errorf("Root mismatch: have %x, want %x", so0.data.Root[:], so1.data.Root[:])
}
if !bytes.Equal(so0.CodeHash(), so1.CodeHash()) {
t.Fatalf("CodeHash mismatch: have %v, want %v", so0.CodeHash(), so1.CodeHash())
}
if !bytes.Equal(so0.code, so1.code) {
t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code)
}
if len(so1.dirtyStorage) != len(so0.dirtyStorage) {
t.Errorf("Dirty storage size mismatch: have %d, want %d", len(so1.dirtyStorage), len(so0.dirtyStorage))
}
for k, v := range so1.dirtyStorage {
if so0.dirtyStorage[k] != v {
t.Errorf("Dirty storage key %x mismatch: have %v, want %v", k, so0.dirtyStorage[k], v)
}
}
for k, v := range so0.dirtyStorage {
if so1.dirtyStorage[k] != v {
t.Errorf("Dirty storage key %x mismatch: have %v, want none.", k, v)
}
}
if len(so1.originStorage) != len(so0.originStorage) {
t.Errorf("Origin storage size mismatch: have %d, want %d", len(so1.originStorage), len(so0.originStorage))
}
for k, v := range so1.originStorage {
if so0.originStorage[k] != v {
t.Errorf("Origin storage key %x mismatch: have %v, want %v", k, so0.originStorage[k], v)
}
}
for k, v := range so0.originStorage {
if so1.originStorage[k] != v {
t.Errorf("Origin storage key %x mismatch: have %v, want none.", k, v)
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,594 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"bytes"
"encoding/binary"
"fmt"
"math"
"math/big"
"math/rand"
"reflect"
"strings"
"sync"
"testing"
"testing/quick"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
)
// TestCopy tests that copying a StateDB object indeed makes the original and
// the copy independent of each other. This test is a regression test against
// https://github.com/ethereum/go-ethereum/pull/15549.
func TestCopy(t *testing.T) {
// Create a random state test to copy and modify "independently"
orig, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
for i := byte(0); i < 255; i++ {
obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
obj.AddBalance(big.NewInt(int64(i)))
orig.updateStateObject(obj)
}
orig.Finalise(false)
// Copy the state
copy := orig.Copy()
// Copy the copy state
ccopy := copy.Copy()
// modify all in memory
for i := byte(0); i < 255; i++ {
origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
origObj.AddBalance(big.NewInt(2 * int64(i)))
copyObj.AddBalance(big.NewInt(3 * int64(i)))
ccopyObj.AddBalance(big.NewInt(4 * int64(i)))
orig.updateStateObject(origObj)
copy.updateStateObject(copyObj)
ccopy.updateStateObject(copyObj)
}
// Finalise the changes on all concurrently
finalise := func(wg *sync.WaitGroup, db *StateDB) {
defer wg.Done()
db.Finalise(true)
}
var wg sync.WaitGroup
wg.Add(3)
go finalise(&wg, orig)
go finalise(&wg, copy)
go finalise(&wg, ccopy)
wg.Wait()
// Verify that the three states have been updated independently
for i := byte(0); i < 255; i++ {
origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
if want := big.NewInt(3 * int64(i)); origObj.Balance().Cmp(want) != 0 {
t.Errorf("orig obj %d: balance mismatch: have %v, want %v", i, origObj.Balance(), want)
}
if want := big.NewInt(4 * int64(i)); copyObj.Balance().Cmp(want) != 0 {
t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, copyObj.Balance(), want)
}
if want := big.NewInt(5 * int64(i)); ccopyObj.Balance().Cmp(want) != 0 {
t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, ccopyObj.Balance(), want)
}
}
}
func TestSnapshotRandom(t *testing.T) {
config := &quick.Config{MaxCount: 1000}
err := quick.Check((*snapshotTest).run, config)
if cerr, ok := err.(*quick.CheckError); ok {
test := cerr.In[0].(*snapshotTest)
t.Errorf("%v:\n%s", test.err, test)
} else if err != nil {
t.Error(err)
}
}
// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes
// captured by the snapshot. Instances of this test with pseudorandom content are created
// by Generate.
//
// The test works as follows:
//
// A new state is created and all actions are applied to it. Several snapshots are taken
// in between actions. The test then reverts each snapshot. For each snapshot the actions
// leading up to it are replayed on a fresh, empty state. The behaviour of all public
// accessor methods on the reverted state must match the return value of the equivalent
// methods on the replayed state.
type snapshotTest struct {
addrs []common.Address // all account addresses
actions []testAction // modifications to the state
snapshots []int // actions indexes at which snapshot is taken
err error // failure details are reported through this field
}
type testAction struct {
name string
fn func(testAction, *StateDB)
args []int64
noAddr bool
}
// newTestAction creates a random action that changes state.
func newTestAction(addr common.Address, r *rand.Rand) testAction {
actions := []testAction{
{
name: "SetBalance",
fn: func(a testAction, s *StateDB) {
s.SetBalance(addr, big.NewInt(a.args[0]))
},
args: make([]int64, 1),
},
{
name: "AddBalance",
fn: func(a testAction, s *StateDB) {
s.AddBalance(addr, big.NewInt(a.args[0]))
},
args: make([]int64, 1),
},
{
name: "SetNonce",
fn: func(a testAction, s *StateDB) {
s.SetNonce(addr, uint64(a.args[0]))
},
args: make([]int64, 1),
},
{
name: "SetState",
fn: func(a testAction, s *StateDB) {
var key, val common.Hash
binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
s.SetState(addr, key, val)
},
args: make([]int64, 2),
},
{
name: "SetCode",
fn: func(a testAction, s *StateDB) {
code := make([]byte, 16)
binary.BigEndian.PutUint64(code, uint64(a.args[0]))
binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
s.SetCode(addr, code)
},
args: make([]int64, 2),
},
{
name: "CreateAccount",
fn: func(a testAction, s *StateDB) {
s.CreateAccount(addr)
},
},
{
name: "Suicide",
fn: func(a testAction, s *StateDB) {
s.Suicide(addr)
},
},
{
name: "AddRefund",
fn: func(a testAction, s *StateDB) {
s.AddRefund(uint64(a.args[0]))
},
args: make([]int64, 1),
noAddr: true,
},
{
name: "AddLog",
fn: func(a testAction, s *StateDB) {
data := make([]byte, 2)
binary.BigEndian.PutUint16(data, uint16(a.args[0]))
s.AddLog(&types.Log{Address: addr, Data: data})
},
args: make([]int64, 1),
},
{
name: "AddPreimage",
fn: func(a testAction, s *StateDB) {
preimage := []byte{1}
hash := common.BytesToHash(preimage)
s.AddPreimage(hash, preimage)
},
args: make([]int64, 1),
},
{
name: "AddAddressToAccessList",
fn: func(a testAction, s *StateDB) {
s.AddAddressToAccessList(addr)
},
},
{
name: "AddSlotToAccessList",
fn: func(a testAction, s *StateDB) {
s.AddSlotToAccessList(addr,
common.Hash{byte(a.args[0])})
},
args: make([]int64, 1),
},
{
name: "SetTransientState",
fn: func(a testAction, s *StateDB) {
var key, val common.Hash
binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
s.SetTransientState(addr, key, val)
},
args: make([]int64, 2),
},
}
action := actions[r.Intn(len(actions))]
var nameargs []string
if !action.noAddr {
nameargs = append(nameargs, addr.Hex())
}
for i := range action.args {
action.args[i] = rand.Int63n(100)
nameargs = append(nameargs, fmt.Sprint(action.args[i]))
}
action.name += strings.Join(nameargs, ", ")
return action
}
// Generate returns a new snapshot test of the given size. All randomness is
// derived from r.
func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
// Generate random actions.
addrs := make([]common.Address, 50)
for i := range addrs {
addrs[i][0] = byte(i)
}
actions := make([]testAction, size)
for i := range actions {
addr := addrs[r.Intn(len(addrs))]
actions[i] = newTestAction(addr, r)
}
// Generate snapshot indexes.
nsnapshots := int(math.Sqrt(float64(size)))
if size > 0 && nsnapshots == 0 {
nsnapshots = 1
}
snapshots := make([]int, nsnapshots)
snaplen := len(actions) / nsnapshots
for i := range snapshots {
// Try to place the snapshots some number of actions apart from each other.
snapshots[i] = (i * snaplen) + r.Intn(snaplen)
}
return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
}
func (test *snapshotTest) String() string {
out := new(bytes.Buffer)
sindex := 0
for i, action := range test.actions {
if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
sindex++
}
fmt.Fprintf(out, "%4d: %s\n", i, action.name)
}
return out.String()
}
func (test *snapshotTest) run() bool {
// Run all actions and create snapshots.
var (
state, _ = New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
snapshotRevs = make([]int, len(test.snapshots))
sindex = 0
)
for i, action := range test.actions {
if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
snapshotRevs[sindex] = state.Snapshot()
sindex++
}
action.fn(action, state)
}
// Revert all snapshots in reverse order. Each revert must yield a state
// that is equivalent to fresh state with all actions up the snapshot applied.
for sindex--; sindex >= 0; sindex-- {
checkstate, _ := New(common.Hash{}, state.Database(), nil)
for _, action := range test.actions[:test.snapshots[sindex]] {
action.fn(action, checkstate)
}
state.RevertToSnapshot(snapshotRevs[sindex])
if err := test.checkEqual(state, checkstate); err != nil {
test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
return false
}
}
return true
}
// checkEqual checks that methods of state and checkstate return the same values.
func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
for _, addr := range test.addrs {
var err error
checkeq := func(op string, a, b interface{}) bool {
if err == nil && !reflect.DeepEqual(a, b) {
err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b)
return false
}
return true
}
// Check basic accessor methods.
checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr))
checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
// Check storage.
if obj := state.getStateObject(addr); obj != nil {
state.ForEachStorage(addr, func(key, value common.Hash) bool {
return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
})
checkstate.ForEachStorage(addr, func(key, value common.Hash) bool {
return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
})
}
if err != nil {
return err
}
}
if state.GetRefund() != checkstate.GetRefund() {
return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
state.GetRefund(), checkstate.GetRefund())
}
if !reflect.DeepEqual(state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{})) {
return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{}))
}
return nil
}
// TestCopyOfCopy tests that modified objects are carried over to the copy, and the copy of the copy.
// See https://github.com/ethereum/go-ethereum/pull/15225#issuecomment-380191512
func TestCopyOfCopy(t *testing.T) {
state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
addr := common.HexToAddress("aaaa")
state.SetBalance(addr, big.NewInt(42))
if got := state.Copy().GetBalance(addr).Uint64(); got != 42 {
t.Fatalf("1st copy fail, expected 42, got %v", got)
}
if got := state.Copy().Copy().GetBalance(addr).Uint64(); got != 42 {
t.Fatalf("2nd copy fail, expected 42, got %v", got)
}
}
func TestStateDBAccessList(t *testing.T) {
// Some helpers
addr := func(a string) common.Address {
return common.HexToAddress(a)
}
slot := func(a string) common.Hash {
return common.HexToHash(a)
}
memDb := rawdb.NewMemoryDatabase()
db := NewDatabase(memDb)
state, _ := New(common.Hash{}, db, nil)
state.accessList = newAccessList()
verifyAddrs := func(astrings ...string) {
t.Helper()
// convert to common.Address form
var addresses []common.Address
var addressMap = make(map[common.Address]struct{})
for _, astring := range astrings {
address := addr(astring)
addresses = append(addresses, address)
addressMap[address] = struct{}{}
}
// Check that the given addresses are in the access list
for _, address := range addresses {
if !state.AddressInAccessList(address) {
t.Fatalf("expected %x to be in access list", address)
}
}
// Check that only the expected addresses are present in the access list
for address := range state.accessList.addresses {
if _, exist := addressMap[address]; !exist {
t.Fatalf("extra address %x in access list", address)
}
}
}
verifySlots := func(addrString string, slotStrings ...string) {
if !state.AddressInAccessList(addr(addrString)) {
t.Fatalf("scope missing address/slots %v", addrString)
}
var address = addr(addrString)
// convert to common.Hash form
var slots []common.Hash
var slotMap = make(map[common.Hash]struct{})
for _, slotString := range slotStrings {
s := slot(slotString)
slots = append(slots, s)
slotMap[s] = struct{}{}
}
// Check that the expected items are in the access list
for i, s := range slots {
if _, slotPresent := state.SlotInAccessList(address, s); !slotPresent {
t.Fatalf("input %d: scope missing slot %v (address %v)", i, s, addrString)
}
}
// Check that no extra elements are in the access list
index := state.accessList.addresses[address]
if index >= 0 {
stateSlots := state.accessList.slots[index]
for s := range stateSlots {
if _, slotPresent := slotMap[s]; !slotPresent {
t.Fatalf("scope has extra slot %v (address %v)", s, addrString)
}
}
}
}
state.AddAddressToAccessList(addr("aa")) // 1
state.AddSlotToAccessList(addr("bb"), slot("01")) // 2,3
state.AddSlotToAccessList(addr("bb"), slot("02")) // 4
verifyAddrs("aa", "bb")
verifySlots("bb", "01", "02")
// Make a copy
stateCopy1 := state.Copy()
if exp, got := 4, state.journal.length(); exp != got {
t.Fatalf("journal length mismatch: have %d, want %d", got, exp)
}
// same again, should cause no journal entries
state.AddSlotToAccessList(addr("bb"), slot("01"))
state.AddSlotToAccessList(addr("bb"), slot("02"))
state.AddAddressToAccessList(addr("aa"))
if exp, got := 4, state.journal.length(); exp != got {
t.Fatalf("journal length mismatch: have %d, want %d", got, exp)
}
// some new ones
state.AddSlotToAccessList(addr("bb"), slot("03")) // 5
state.AddSlotToAccessList(addr("aa"), slot("01")) // 6
state.AddSlotToAccessList(addr("cc"), slot("01")) // 7,8
state.AddAddressToAccessList(addr("cc"))
if exp, got := 8, state.journal.length(); exp != got {
t.Fatalf("journal length mismatch: have %d, want %d", got, exp)
}
verifyAddrs("aa", "bb", "cc")
verifySlots("aa", "01")
verifySlots("bb", "01", "02", "03")
verifySlots("cc", "01")
// now start rolling back changes
state.journal.revert(state, 7)
if _, ok := state.SlotInAccessList(addr("cc"), slot("01")); ok {
t.Fatalf("slot present, expected missing")
}
verifyAddrs("aa", "bb", "cc")
verifySlots("aa", "01")
verifySlots("bb", "01", "02", "03")
state.journal.revert(state, 6)
if state.AddressInAccessList(addr("cc")) {
t.Fatalf("addr present, expected missing")
}
verifyAddrs("aa", "bb")
verifySlots("aa", "01")
verifySlots("bb", "01", "02", "03")
state.journal.revert(state, 5)
if _, ok := state.SlotInAccessList(addr("aa"), slot("01")); ok {
t.Fatalf("slot present, expected missing")
}
verifyAddrs("aa", "bb")
verifySlots("bb", "01", "02", "03")
state.journal.revert(state, 4)
if _, ok := state.SlotInAccessList(addr("bb"), slot("03")); ok {
t.Fatalf("slot present, expected missing")
}
verifyAddrs("aa", "bb")
verifySlots("bb", "01", "02")
state.journal.revert(state, 3)
if _, ok := state.SlotInAccessList(addr("bb"), slot("02")); ok {
t.Fatalf("slot present, expected missing")
}
verifyAddrs("aa", "bb")
verifySlots("bb", "01")
state.journal.revert(state, 2)
if _, ok := state.SlotInAccessList(addr("bb"), slot("01")); ok {
t.Fatalf("slot present, expected missing")
}
verifyAddrs("aa", "bb")
state.journal.revert(state, 1)
if state.AddressInAccessList(addr("bb")) {
t.Fatalf("addr present, expected missing")
}
verifyAddrs("aa")
state.journal.revert(state, 0)
if state.AddressInAccessList(addr("aa")) {
t.Fatalf("addr present, expected missing")
}
if got, exp := len(state.accessList.addresses), 0; got != exp {
t.Fatalf("expected empty, got %d", got)
}
if got, exp := len(state.accessList.slots), 0; got != exp {
t.Fatalf("expected empty, got %d", got)
}
// Check the copy
// Make a copy
state = stateCopy1
verifyAddrs("aa", "bb")
verifySlots("bb", "01", "02")
if got, exp := len(state.accessList.addresses), 2; got != exp {
t.Fatalf("expected empty, got %d", got)
}
if got, exp := len(state.accessList.slots), 1; got != exp {
t.Fatalf("expected empty, got %d", got)
}
}
func TestStateDBTransientStorage(t *testing.T) {
memDb := rawdb.NewMemoryDatabase()
db := NewDatabase(memDb)
state, _ := New(common.Hash{}, db, nil)
key := common.Hash{0x01}
value := common.Hash{0x02}
addr := common.Address{}
state.SetTransientState(addr, key, value)
if exp, got := 1, state.journal.length(); exp != got {
t.Fatalf("journal length mismatch: have %d, want %d", got, exp)
}
// the retrieved value should equal what was set
if got := state.GetTransientState(addr, key); got != value {
t.Fatalf("transient storage mismatch: have %x, want %x", got, value)
}
// revert the transient state being set and then check that the
// value is now the empty hash
state.journal.revert(state, 0)
if got, exp := state.GetTransientState(addr, key), (common.Hash{}); exp != got {
t.Fatalf("transient storage mismatch: have %x, want %x", got, exp)
}
// set transient state and then copy the statedb and ensure that
// the transient state is copied
state.SetTransientState(addr, key, value)
cpy := state.Copy()
if got := cpy.GetTransientState(addr, key); got != value {
t.Fatalf("transient storage mismatch: have %x, want %x", got, value)
}
}

View File

@ -0,0 +1,55 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"github.com/ethereum/go-ethereum/common"
)
// transientStorage is a representation of EIP-1153 "Transient Storage".
type transientStorage map[common.Address]Storage
// newTransientStorage creates a new instance of a transientStorage.
func newTransientStorage() transientStorage {
return make(transientStorage)
}
// Set sets the transient-storage `value` for `key` at the given `addr`.
func (t transientStorage) Set(addr common.Address, key, value common.Hash) {
if _, ok := t[addr]; !ok {
t[addr] = make(Storage)
}
t[addr][key] = value
}
// Get gets the transient storage for `key` at the given `addr`.
func (t transientStorage) Get(addr common.Address, key common.Hash) common.Hash {
val, ok := t[addr]
if !ok {
return common.Hash{}
}
return val[key]
}
// Copy does a deep copy of the transientStorage
func (t transientStorage) Copy() transientStorage {
storage := make(transientStorage)
for key, value := range t {
storage[key] = value.Copy()
}
return storage
}

View File

@ -0,0 +1,354 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"sync"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/metrics"
log "github.com/sirupsen/logrus"
)
var (
// triePrefetchMetricsPrefix is the prefix under which to publish the metrics.
triePrefetchMetricsPrefix = "trie/prefetch/"
)
// triePrefetcher is an active prefetcher, which receives accounts or storage
// items and does trie-loading of them. The goal is to get as much useful content
// into the caches as possible.
//
// Note, the prefetcher's API is not thread safe.
type triePrefetcher struct {
db Database // Database to fetch trie nodes through
root common.Hash // Root hash of the account trie for metrics
fetches map[string]Trie // Partially or fully fetcher tries
fetchers map[string]*subfetcher // Subfetchers for each trie
deliveryMissMeter metrics.Meter
accountLoadMeter metrics.Meter
accountDupMeter metrics.Meter
accountSkipMeter metrics.Meter
accountWasteMeter metrics.Meter
storageLoadMeter metrics.Meter
storageDupMeter metrics.Meter
storageSkipMeter metrics.Meter
storageWasteMeter metrics.Meter
}
func newTriePrefetcher(db Database, root common.Hash, namespace string) *triePrefetcher {
prefix := triePrefetchMetricsPrefix + namespace
p := &triePrefetcher{
db: db,
root: root,
fetchers: make(map[string]*subfetcher), // Active prefetchers use the fetchers map
deliveryMissMeter: metrics.GetOrRegisterMeter(prefix+"/deliverymiss", nil),
accountLoadMeter: metrics.GetOrRegisterMeter(prefix+"/account/load", nil),
accountDupMeter: metrics.GetOrRegisterMeter(prefix+"/account/dup", nil),
accountSkipMeter: metrics.GetOrRegisterMeter(prefix+"/account/skip", nil),
accountWasteMeter: metrics.GetOrRegisterMeter(prefix+"/account/waste", nil),
storageLoadMeter: metrics.GetOrRegisterMeter(prefix+"/storage/load", nil),
storageDupMeter: metrics.GetOrRegisterMeter(prefix+"/storage/dup", nil),
storageSkipMeter: metrics.GetOrRegisterMeter(prefix+"/storage/skip", nil),
storageWasteMeter: metrics.GetOrRegisterMeter(prefix+"/storage/waste", nil),
}
return p
}
// close iterates over all the subfetchers, aborts any that were left spinning
// and reports the stats to the metrics subsystem.
func (p *triePrefetcher) close() {
for _, fetcher := range p.fetchers {
fetcher.abort() // safe to do multiple times
if metrics.Enabled {
if fetcher.root == p.root {
p.accountLoadMeter.Mark(int64(len(fetcher.seen)))
p.accountDupMeter.Mark(int64(fetcher.dups))
p.accountSkipMeter.Mark(int64(len(fetcher.tasks)))
for _, key := range fetcher.used {
delete(fetcher.seen, string(key))
}
p.accountWasteMeter.Mark(int64(len(fetcher.seen)))
} else {
p.storageLoadMeter.Mark(int64(len(fetcher.seen)))
p.storageDupMeter.Mark(int64(fetcher.dups))
p.storageSkipMeter.Mark(int64(len(fetcher.tasks)))
for _, key := range fetcher.used {
delete(fetcher.seen, string(key))
}
p.storageWasteMeter.Mark(int64(len(fetcher.seen)))
}
}
}
// Clear out all fetchers (will crash on a second call, deliberate)
p.fetchers = nil
}
// copy creates a deep-but-inactive copy of the trie prefetcher. Any trie data
// already loaded will be copied over, but no goroutines will be started. This
// is mostly used in the miner which creates a copy of it's actively mutated
// state to be sealed while it may further mutate the state.
func (p *triePrefetcher) copy() *triePrefetcher {
copy := &triePrefetcher{
db: p.db,
root: p.root,
fetches: make(map[string]Trie), // Active prefetchers use the fetches map
deliveryMissMeter: p.deliveryMissMeter,
accountLoadMeter: p.accountLoadMeter,
accountDupMeter: p.accountDupMeter,
accountSkipMeter: p.accountSkipMeter,
accountWasteMeter: p.accountWasteMeter,
storageLoadMeter: p.storageLoadMeter,
storageDupMeter: p.storageDupMeter,
storageSkipMeter: p.storageSkipMeter,
storageWasteMeter: p.storageWasteMeter,
}
// If the prefetcher is already a copy, duplicate the data
if p.fetches != nil {
for root, fetch := range p.fetches {
if fetch == nil {
continue
}
copy.fetches[root] = p.db.CopyTrie(fetch)
}
return copy
}
// Otherwise we're copying an active fetcher, retrieve the current states
for id, fetcher := range p.fetchers {
copy.fetches[id] = fetcher.peek()
}
return copy
}
// prefetch schedules a batch of trie items to prefetch.
func (p *triePrefetcher) prefetch(owner common.Hash, root common.Hash, keys [][]byte) {
// If the prefetcher is an inactive one, bail out
if p.fetches != nil {
return
}
// Active fetcher, schedule the retrievals
id := p.trieID(owner, root)
fetcher := p.fetchers[id]
if fetcher == nil {
fetcher = newSubfetcher(p.db, p.root, owner, root)
p.fetchers[id] = fetcher
}
fetcher.schedule(keys)
}
// trie returns the trie matching the root hash, or nil if the prefetcher doesn't
// have it.
func (p *triePrefetcher) trie(owner common.Hash, root common.Hash) Trie {
// If the prefetcher is inactive, return from existing deep copies
id := p.trieID(owner, root)
if p.fetches != nil {
trie := p.fetches[id]
if trie == nil {
p.deliveryMissMeter.Mark(1)
return nil
}
return p.db.CopyTrie(trie)
}
// Otherwise the prefetcher is active, bail if no trie was prefetched for this root
fetcher := p.fetchers[id]
if fetcher == nil {
p.deliveryMissMeter.Mark(1)
return nil
}
// Interrupt the prefetcher if it's by any chance still running and return
// a copy of any pre-loaded trie.
fetcher.abort() // safe to do multiple times
trie := fetcher.peek()
if trie == nil {
p.deliveryMissMeter.Mark(1)
return nil
}
return trie
}
// used marks a batch of state items used to allow creating statistics as to
// how useful or wasteful the prefetcher is.
func (p *triePrefetcher) used(owner common.Hash, root common.Hash, used [][]byte) {
if fetcher := p.fetchers[p.trieID(owner, root)]; fetcher != nil {
fetcher.used = used
}
}
// trieID returns an unique trie identifier consists the trie owner and root hash.
func (p *triePrefetcher) trieID(owner common.Hash, root common.Hash) string {
return string(append(owner.Bytes(), root.Bytes()...))
}
// subfetcher is a trie fetcher goroutine responsible for pulling entries for a
// single trie. It is spawned when a new root is encountered and lives until the
// main prefetcher is paused and either all requested items are processed or if
// the trie being worked on is retrieved from the prefetcher.
type subfetcher struct {
db Database // Database to load trie nodes through
state common.Hash // Root hash of the state to prefetch
owner common.Hash // Owner of the trie, usually account hash
root common.Hash // Root hash of the trie to prefetch
trie Trie // Trie being populated with nodes
tasks [][]byte // Items queued up for retrieval
lock sync.Mutex // Lock protecting the task queue
wake chan struct{} // Wake channel if a new task is scheduled
stop chan struct{} // Channel to interrupt processing
term chan struct{} // Channel to signal interruption
copy chan chan Trie // Channel to request a copy of the current trie
seen map[string]struct{} // Tracks the entries already loaded
dups int // Number of duplicate preload tasks
used [][]byte // Tracks the entries used in the end
}
// newSubfetcher creates a goroutine to prefetch state items belonging to a
// particular root hash.
func newSubfetcher(db Database, state common.Hash, owner common.Hash, root common.Hash) *subfetcher {
sf := &subfetcher{
db: db,
state: state,
owner: owner,
root: root,
wake: make(chan struct{}, 1),
stop: make(chan struct{}),
term: make(chan struct{}),
copy: make(chan chan Trie),
seen: make(map[string]struct{}),
}
go sf.loop()
return sf
}
// schedule adds a batch of trie keys to the queue to prefetch.
func (sf *subfetcher) schedule(keys [][]byte) {
// Append the tasks to the current queue
sf.lock.Lock()
sf.tasks = append(sf.tasks, keys...)
sf.lock.Unlock()
// Notify the prefetcher, it's fine if it's already terminated
select {
case sf.wake <- struct{}{}:
default:
}
}
// peek tries to retrieve a deep copy of the fetcher's trie in whatever form it
// is currently.
func (sf *subfetcher) peek() Trie {
ch := make(chan Trie)
select {
case sf.copy <- ch:
// Subfetcher still alive, return copy from it
return <-ch
case <-sf.term:
// Subfetcher already terminated, return a copy directly
if sf.trie == nil {
return nil
}
return sf.db.CopyTrie(sf.trie)
}
}
// abort interrupts the subfetcher immediately. It is safe to call abort multiple
// times but it is not thread safe.
func (sf *subfetcher) abort() {
select {
case <-sf.stop:
default:
close(sf.stop)
}
<-sf.term
}
// loop waits for new tasks to be scheduled and keeps loading them until it runs
// out of tasks or its underlying trie is retrieved for committing.
func (sf *subfetcher) loop() {
// No matter how the loop stops, signal anyone waiting that it's terminated
defer close(sf.term)
// Start by opening the trie and stop processing if it fails
if sf.owner == (common.Hash{}) {
trie, err := sf.db.OpenTrie(sf.root)
if err != nil {
log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err)
return
}
sf.trie = trie
} else {
trie, err := sf.db.OpenStorageTrie(sf.state, sf.owner, sf.root)
if err != nil {
log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err)
return
}
sf.trie = trie
}
// Trie opened successfully, keep prefetching items
for {
select {
case <-sf.wake:
// Subfetcher was woken up, retrieve any tasks to avoid spinning the lock
sf.lock.Lock()
tasks := sf.tasks
sf.tasks = nil
sf.lock.Unlock()
// Prefetch any tasks until the loop is interrupted
for i, task := range tasks {
select {
case <-sf.stop:
// If termination is requested, add any leftover back and return
sf.lock.Lock()
sf.tasks = append(sf.tasks, tasks[i:]...)
sf.lock.Unlock()
return
case ch := <-sf.copy:
// Somebody wants a copy of the current trie, grant them
ch <- sf.db.CopyTrie(sf.trie)
default:
// No termination request yet, prefetch the next entry
if _, ok := sf.seen[string(task)]; ok {
sf.dups++
} else {
sf.trie.TryGet(task)
sf.seen[string(task)] = struct{}{}
}
}
}
case ch := <-sf.copy:
// Somebody wants a copy of the current trie, grant them
ch <- sf.db.CopyTrie(sf.trie)
case <-sf.stop:
// Termination is requested, abort and leave remaining tasks
return
}
}
}

View File

@ -0,0 +1,110 @@
// Copyright 2021 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"math/big"
"testing"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
)
func filledStateDB() *StateDB {
state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
// Create an account and check if the retrieved balance is correct
addr := common.HexToAddress("0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe")
skey := common.HexToHash("aaa")
sval := common.HexToHash("bbb")
state.SetBalance(addr, big.NewInt(42)) // Change the account trie
state.SetCode(addr, []byte("hello")) // Change an external metadata
state.SetState(addr, skey, sval) // Change the storage trie
for i := 0; i < 100; i++ {
sk := common.BigToHash(big.NewInt(int64(i)))
state.SetState(addr, sk, sk) // Change the storage trie
}
return state
}
func TestCopyAndClose(t *testing.T) {
db := filledStateDB()
prefetcher := newTriePrefetcher(db.db, db.originalRoot, "")
skey := common.HexToHash("aaa")
prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
time.Sleep(1 * time.Second)
a := prefetcher.trie(common.Hash{}, db.originalRoot)
prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
b := prefetcher.trie(common.Hash{}, db.originalRoot)
cpy := prefetcher.copy()
cpy.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
cpy.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
c := cpy.trie(common.Hash{}, db.originalRoot)
prefetcher.close()
cpy2 := cpy.copy()
cpy2.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
d := cpy2.trie(common.Hash{}, db.originalRoot)
cpy.close()
cpy2.close()
if a.Hash() != b.Hash() || a.Hash() != c.Hash() || a.Hash() != d.Hash() {
t.Fatalf("Invalid trie, hashes should be equal: %v %v %v %v", a.Hash(), b.Hash(), c.Hash(), d.Hash())
}
}
func TestUseAfterClose(t *testing.T) {
db := filledStateDB()
prefetcher := newTriePrefetcher(db.db, db.originalRoot, "")
skey := common.HexToHash("aaa")
prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
a := prefetcher.trie(common.Hash{}, db.originalRoot)
prefetcher.close()
b := prefetcher.trie(common.Hash{}, db.originalRoot)
if a == nil {
t.Fatal("Prefetching before close should not return nil")
}
if b != nil {
t.Fatal("Trie after close should return nil")
}
}
func TestCopyClose(t *testing.T) {
db := filledStateDB()
prefetcher := newTriePrefetcher(db.db, db.originalRoot, "")
skey := common.HexToHash("aaa")
prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
cpy := prefetcher.copy()
a := prefetcher.trie(common.Hash{}, db.originalRoot)
b := cpy.trie(common.Hash{}, db.originalRoot)
prefetcher.close()
c := prefetcher.trie(common.Hash{}, db.originalRoot)
d := cpy.trie(common.Hash{}, db.originalRoot)
if a == nil {
t.Fatal("Prefetching before close should not return nil")
}
if b == nil {
t.Fatal("Copy trie should return nil")
}
if c != nil {
t.Fatal("Trie after close should return nil")
}
if d == nil {
t.Fatal("Copy trie should not return nil")
}
}

View File

@ -0,0 +1,196 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"fmt"
"github.com/ethereum/go-ethereum/common"
)
// leaf represents a trie leaf node
type leaf struct {
blob []byte // raw blob of leaf
parent common.Hash // the hash of parent node
}
// committer is the tool used for the trie Commit operation. The committer will
// capture all dirty nodes during the commit process and keep them cached in
// insertion order.
type committer struct {
nodes *NodeSet
collectLeaf bool
}
// newCommitter creates a new committer or picks one from the pool.
func newCommitter(nodeset *NodeSet, collectLeaf bool) *committer {
return &committer{
nodes: nodeset,
collectLeaf: collectLeaf,
}
}
// Commit collapses a node down into a hash node.
func (c *committer) Commit(n node) hashNode {
return c.commit(nil, n).(hashNode)
}
// commit collapses a node down into a hash node and returns it.
func (c *committer) commit(path []byte, n node) node {
// if this path is clean, use available cached data
hash, dirty := n.cache()
if hash != nil && !dirty {
return hash
}
// Commit children, then parent, and remove the dirty flag.
switch cn := n.(type) {
case *shortNode:
// Commit child
collapsed := cn.copy()
// If the child is fullNode, recursively commit,
// otherwise it can only be hashNode or valueNode.
if _, ok := cn.Val.(*fullNode); ok {
collapsed.Val = c.commit(append(path, cn.Key...), cn.Val)
}
// The key needs to be copied, since we're adding it to the
// modified nodeset.
collapsed.Key = hexToCompact(cn.Key)
hashedNode := c.store(path, collapsed)
if hn, ok := hashedNode.(hashNode); ok {
return hn
}
return collapsed
case *fullNode:
hashedKids := c.commitChildren(path, cn)
collapsed := cn.copy()
collapsed.Children = hashedKids
hashedNode := c.store(path, collapsed)
if hn, ok := hashedNode.(hashNode); ok {
return hn
}
return collapsed
case hashNode:
return cn
default:
// nil, valuenode shouldn't be committed
panic(fmt.Sprintf("%T: invalid node: %v", n, n))
}
}
// commitChildren commits the children of the given fullnode
func (c *committer) commitChildren(path []byte, n *fullNode) [17]node {
var children [17]node
for i := 0; i < 16; i++ {
child := n.Children[i]
if child == nil {
continue
}
// If it's the hashed child, save the hash value directly.
// Note: it's impossible that the child in range [0, 15]
// is a valueNode.
if hn, ok := child.(hashNode); ok {
children[i] = hn
continue
}
// Commit the child recursively and store the "hashed" value.
// Note the returned node can be some embedded nodes, so it's
// possible the type is not hashNode.
children[i] = c.commit(append(path, byte(i)), child)
}
// For the 17th child, it's possible the type is valuenode.
if n.Children[16] != nil {
children[16] = n.Children[16]
}
return children
}
// store hashes the node n and adds it to the modified nodeset. If leaf collection
// is enabled, leaf nodes will be tracked in the modified nodeset as well.
func (c *committer) store(path []byte, n node) node {
// Larger nodes are replaced by their hash and stored in the database.
var hash, _ = n.cache()
// This was not generated - must be a small node stored in the parent.
// In theory, we should check if the node is leaf here (embedded node
// usually is leaf node). But small value (less than 32bytes) is not
// our target (leaves in account trie only).
if hash == nil {
// The node is embedded in its parent, in other words, this node
// will not be stored in the database independently, mark it as
// deleted only if the node was existent in database before.
if _, ok := c.nodes.accessList[string(path)]; ok {
c.nodes.markDeleted(path)
}
return n
}
// We have the hash already, estimate the RLP encoding-size of the node.
// The size is used for mem tracking, does not need to be exact
var (
size = estimateSize(n)
nhash = common.BytesToHash(hash)
mnode = &memoryNode{
hash: nhash,
node: simplifyNode(n),
size: uint16(size),
}
)
// Collect the dirty node to nodeset for return.
c.nodes.markUpdated(path, mnode)
// Collect the corresponding leaf node if it's required. We don't check
// full node since it's impossible to store value in fullNode. The key
// length of leaves should be exactly same.
if c.collectLeaf {
if sn, ok := n.(*shortNode); ok {
if val, ok := sn.Val.(valueNode); ok {
c.nodes.addLeaf(&leaf{blob: val, parent: nhash})
}
}
}
return hash
}
// estimateSize estimates the size of an rlp-encoded node, without actually
// rlp-encoding it (zero allocs). This method has been experimentally tried, and with a trie
// with 1000 leaves, the only errors above 1% are on small shortnodes, where this
// method overestimates by 2 or 3 bytes (e.g. 37 instead of 35)
func estimateSize(n node) int {
switch n := n.(type) {
case *shortNode:
// A short node contains a compacted key, and a value.
return 3 + len(n.Key) + estimateSize(n.Val)
case *fullNode:
// A full node contains up to 16 hashes (some nils), and a key
s := 3
for i := 0; i < 16; i++ {
if child := n.Children[i]; child != nil {
s += estimateSize(child)
} else {
s++
}
}
return s
case valueNode:
return 1 + len(n)
case hashNode:
return 1 + len(n)
default:
panic(fmt.Sprintf("node type %T", n))
}
}

View File

@ -18,25 +18,50 @@ package trie
import ( import (
"errors" "errors"
"runtime"
"sync"
"time"
"github.com/VictoriaMetrics/fastcache" "github.com/VictoriaMetrics/fastcache"
"github.com/cerc-io/ipld-eth-statedb/internal"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
"github.com/ipfs/go-cid" log "github.com/sirupsen/logrus"
) )
type CidKey = cid.Cid // Database is an intermediate write layer between the trie data structures and
// the disk database. The aim is to accumulate trie writes in-memory and only
func isEmpty(key CidKey) bool { // periodically flush a couple tries to disk, garbage collecting the remainder.
return len(key.KeyString()) == 0 //
} // Note, the trie Database is **not** thread safe in its mutations, but it **is**
// thread safe in providing individual, independent node access. The rationale
// Database is an intermediate read-only layer between the trie data structures and // behind this split design is to provide read access to RPC handlers and sync
// the disk database. This trie Database is thread safe in providing individual, // servers even while the trie is executing expensive garbage collection.
// independent node access.
type Database struct { type Database struct {
diskdb ethdb.KeyValueStore // Persistent storage for matured trie nodes diskdb ethdb.Database // Persistent storage for matured trie nodes
cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs
cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs
dirties map[common.Hash]*cachedNode // Data and references relationships of dirty trie nodes
oldest common.Hash // Oldest tracked node, flush-list head
newest common.Hash // Newest tracked node, flush-list tail
gctime time.Duration // Time spent on garbage collection since last commit
gcnodes uint64 // Nodes garbage collected since last commit
gcsize common.StorageSize // Data storage garbage collected since last commit
flushtime time.Duration // Time spent on data flushing since last commit
flushnodes uint64 // Nodes flushed since last commit
flushsize common.StorageSize // Data storage flushed since last commit
dirtiesSize common.StorageSize // Storage size of the dirty node cache (exc. metadata)
childrenSize common.StorageSize // Storage size of the external children tracking
preimages *preimageStore // The store for caching preimages
lock sync.RWMutex
} }
// Config defines all necessary options for database. // Config defines all necessary options for database.
@ -46,14 +71,14 @@ type Config = trie.Config
// NewDatabase creates a new trie database to store ephemeral trie content before // NewDatabase creates a new trie database to store ephemeral trie content before
// its written out to disk or garbage collected. No read cache is created, so all // its written out to disk or garbage collected. No read cache is created, so all
// data retrievals will hit the underlying disk database. // data retrievals will hit the underlying disk database.
func NewDatabase(diskdb ethdb.KeyValueStore) *Database { func NewDatabase(diskdb ethdb.Database) *Database {
return NewDatabaseWithConfig(diskdb, nil) return NewDatabaseWithConfig(diskdb, nil)
} }
// NewDatabaseWithConfig creates a new trie database to store ephemeral trie content // NewDatabaseWithConfig creates a new trie database to store ephemeral trie content
// before it's written out to disk or garbage collected. It also acts as a read cache // before its written out to disk or garbage collected. It also acts as a read cache
// for nodes loaded from disk. // for nodes loaded from disk.
func NewDatabaseWithConfig(diskdb ethdb.KeyValueStore, config *Config) *Database { func NewDatabaseWithConfig(diskdb ethdb.Database, config *Config) *Database {
var cleans *fastcache.Cache var cleans *fastcache.Cache
if config != nil && config.Cache > 0 { if config != nil && config.Cache > 0 {
if config.Journal == "" { if config.Journal == "" {
@ -62,44 +87,354 @@ func NewDatabaseWithConfig(diskdb ethdb.KeyValueStore, config *Config) *Database
cleans = fastcache.LoadFromFileOrNew(config.Journal, config.Cache*1024*1024) cleans = fastcache.LoadFromFileOrNew(config.Journal, config.Cache*1024*1024)
} }
} }
var preimage *preimageStore
if config != nil && config.Preimages {
preimage = newPreimageStore(diskdb)
}
db := &Database{ db := &Database{
diskdb: diskdb, diskdb: diskdb,
cleans: cleans, cleans: cleans,
dirties: map[common.Hash]*cachedNode{{}: {
children: make(map[common.Hash]uint16),
}},
preimages: preimage,
} }
return db return db
} }
// DiskDB retrieves the persistent storage backing the trie database. // insert inserts a simplified trie node into the memory database.
func (db *Database) DiskDB() ethdb.KeyValueStore { // All nodes inserted by this function will be reference tracked
return db.diskdb // and in theory should only used for **trie nodes** insertion.
func (db *Database) insert(hash common.Hash, size int, node node) {
// If the node's already cached, skip
if _, ok := db.dirties[hash]; ok {
return
}
memcacheDirtyWriteMeter.Mark(int64(size))
// Create the cached entry for this node
entry := &cachedNode{
node: node,
size: uint16(size),
flushPrev: db.newest,
}
entry.forChilds(func(child common.Hash) {
if c := db.dirties[child]; c != nil {
c.parents++
}
})
db.dirties[hash] = entry
// Update the flush-list endpoints
if db.oldest == (common.Hash{}) {
db.oldest, db.newest = hash, hash
} else {
db.dirties[db.newest].flushNext, db.newest = hash, hash
}
db.dirtiesSize += common.StorageSize(common.HashLength + entry.size)
} }
// Node retrieves an encoded trie node by CID. If it cannot be found // Node retrieves an encoded cached trie node from memory. If it cannot be found
// cached in memory, it queries the persistent database. // cached, the method queries the persistent database for the content.
func (db *Database) Node(key CidKey) ([]byte, error) { func (db *Database) Node(hash common.Hash, codec uint64) ([]byte, error) {
// It doesn't make sense to retrieve the metaroot // It doesn't make sense to retrieve the metaroot
if isEmpty(key) { if hash == (common.Hash{}) {
return nil, errors.New("not found") return nil, errors.New("not found")
} }
cidbytes := key.Bytes()
// Retrieve the node from the clean cache if available // Retrieve the node from the clean cache if available
if db.cleans != nil { if db.cleans != nil {
if enc := db.cleans.Get(nil, cidbytes); enc != nil { if enc := db.cleans.Get(nil, hash[:]); enc != nil {
memcacheCleanHitMeter.Mark(1)
memcacheCleanReadMeter.Mark(int64(len(enc)))
return enc, nil return enc, nil
} }
} }
// Retrieve the node from the dirty cache if available
db.lock.RLock()
dirty := db.dirties[hash]
db.lock.RUnlock()
if dirty != nil {
memcacheDirtyHitMeter.Mark(1)
memcacheDirtyReadMeter.Mark(int64(dirty.size))
return dirty.rlp(), nil
}
memcacheDirtyMissMeter.Mark(1)
// Content unavailable in memory, attempt to retrieve from disk // Content unavailable in memory, attempt to retrieve from disk
enc, err := db.diskdb.Get(cidbytes) cid, err := internal.Keccak256ToCid(codec, hash[:])
if err != nil {
return nil, err
}
enc, err := db.diskdb.Get(cid.Bytes())
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(enc) != 0 { if len(enc) != 0 {
if db.cleans != nil { if db.cleans != nil {
db.cleans.Set(cidbytes, enc) db.cleans.Set(hash[:], enc)
memcacheCleanMissMeter.Mark(1)
memcacheCleanWriteMeter.Mark(int64(len(enc)))
} }
return enc, nil return enc, nil
} }
return nil, errors.New("not found") return nil, errors.New("not found")
} }
// Nodes retrieves the hashes of all the nodes cached within the memory database.
// This method is extremely expensive and should only be used to validate internal
// states in test code.
func (db *Database) Nodes() []common.Hash {
db.lock.RLock()
defer db.lock.RUnlock()
var hashes = make([]common.Hash, 0, len(db.dirties))
for hash := range db.dirties {
if hash != (common.Hash{}) { // Special case for "root" references/nodes
hashes = append(hashes, hash)
}
}
return hashes
}
// Reference adds a new reference from a parent node to a child node.
// This function is used to add reference between internal trie node
// and external node(e.g. storage trie root), all internal trie nodes
// are referenced together by database itself.
func (db *Database) Reference(child common.Hash, parent common.Hash) {
db.lock.Lock()
defer db.lock.Unlock()
db.reference(child, parent)
}
// reference is the private locked version of Reference.
func (db *Database) reference(child common.Hash, parent common.Hash) {
// If the node does not exist, it's a node pulled from disk, skip
node, ok := db.dirties[child]
if !ok {
return
}
// If the reference already exists, only duplicate for roots
if db.dirties[parent].children == nil {
db.dirties[parent].children = make(map[common.Hash]uint16)
db.childrenSize += cachedNodeChildrenSize
} else if _, ok = db.dirties[parent].children[child]; ok && parent != (common.Hash{}) {
return
}
node.parents++
db.dirties[parent].children[child]++
if db.dirties[parent].children[child] == 1 {
db.childrenSize += common.HashLength + 2 // uint16 counter
}
}
// Dereference removes an existing reference from a root node.
func (db *Database) Dereference(root common.Hash) {
// Sanity check to ensure that the meta-root is not removed
if root == (common.Hash{}) {
log.Error("Attempted to dereference the trie cache meta root")
return
}
db.lock.Lock()
defer db.lock.Unlock()
nodes, storage, start := len(db.dirties), db.dirtiesSize, time.Now()
db.dereference(root, common.Hash{})
db.gcnodes += uint64(nodes - len(db.dirties))
db.gcsize += storage - db.dirtiesSize
db.gctime += time.Since(start)
memcacheGCTimeTimer.Update(time.Since(start))
memcacheGCSizeMeter.Mark(int64(storage - db.dirtiesSize))
memcacheGCNodesMeter.Mark(int64(nodes - len(db.dirties)))
log.Debug("Dereferenced trie from memory database", "nodes", nodes-len(db.dirties), "size", storage-db.dirtiesSize, "time", time.Since(start),
"gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.dirties), "livesize", db.dirtiesSize)
}
// dereference is the private locked version of Dereference.
func (db *Database) dereference(child common.Hash, parent common.Hash) {
// Dereference the parent-child
node := db.dirties[parent]
if node.children != nil && node.children[child] > 0 {
node.children[child]--
if node.children[child] == 0 {
delete(node.children, child)
db.childrenSize -= (common.HashLength + 2) // uint16 counter
}
}
// If the child does not exist, it's a previously committed node.
node, ok := db.dirties[child]
if !ok {
return
}
// If there are no more references to the child, delete it and cascade
if node.parents > 0 {
// This is a special cornercase where a node loaded from disk (i.e. not in the
// memcache any more) gets reinjected as a new node (short node split into full,
// then reverted into short), causing a cached node to have no parents. That is
// no problem in itself, but don't make maxint parents out of it.
node.parents--
}
if node.parents == 0 {
// Remove the node from the flush-list
switch child {
case db.oldest:
db.oldest = node.flushNext
db.dirties[node.flushNext].flushPrev = common.Hash{}
case db.newest:
db.newest = node.flushPrev
db.dirties[node.flushPrev].flushNext = common.Hash{}
default:
db.dirties[node.flushPrev].flushNext = node.flushNext
db.dirties[node.flushNext].flushPrev = node.flushPrev
}
// Dereference all children and delete the node
node.forChilds(func(hash common.Hash) {
db.dereference(hash, child)
})
delete(db.dirties, child)
db.dirtiesSize -= common.StorageSize(common.HashLength + int(node.size))
if node.children != nil {
db.childrenSize -= cachedNodeChildrenSize
}
}
}
// Update inserts the dirty nodes in provided nodeset into database and
// link the account trie with multiple storage tries if necessary.
func (db *Database) Update(nodes *MergedNodeSet) error {
db.lock.Lock()
defer db.lock.Unlock()
// Insert dirty nodes into the database. In the same tree, it must be
// ensured that children are inserted first, then parent so that children
// can be linked with their parent correctly.
//
// Note, the storage tries must be flushed before the account trie to
// retain the invariant that children go into the dirty cache first.
var order []common.Hash
for owner := range nodes.sets {
if owner == (common.Hash{}) {
continue
}
order = append(order, owner)
}
if _, ok := nodes.sets[common.Hash{}]; ok {
order = append(order, common.Hash{})
}
for _, owner := range order {
subset := nodes.sets[owner]
subset.forEachWithOrder(func(path string, n *memoryNode) {
if n.isDeleted() {
return // ignore deletion
}
db.insert(n.hash, int(n.size), n.node)
})
}
// Link up the account trie and storage trie if the node points
// to an account trie leaf.
if set, present := nodes.sets[common.Hash{}]; present {
for _, n := range set.leaves {
var account types.StateAccount
if err := rlp.DecodeBytes(n.blob, &account); err != nil {
return err
}
if account.Root != types.EmptyRootHash {
db.reference(account.Root, n.parent)
}
}
}
return nil
}
// Size returns the current storage size of the memory cache in front of the
// persistent database layer.
func (db *Database) Size() (common.StorageSize, common.StorageSize) {
db.lock.RLock()
defer db.lock.RUnlock()
// db.dirtiesSize only contains the useful data in the cache, but when reporting
// the total memory consumption, the maintenance metadata is also needed to be
// counted.
var metadataSize = common.StorageSize((len(db.dirties) - 1) * cachedNodeSize)
var metarootRefs = common.StorageSize(len(db.dirties[common.Hash{}].children) * (common.HashLength + 2))
var preimageSize common.StorageSize
if db.preimages != nil {
preimageSize = db.preimages.size()
}
return db.dirtiesSize + db.childrenSize + metadataSize - metarootRefs, preimageSize
}
// GetReader retrieves a node reader belonging to the given state root.
func (db *Database) GetReader(root common.Hash, codec uint64) Reader {
return &hashReader{db: db, codec: codec}
}
// hashReader is reader of hashDatabase which implements the Reader interface.
type hashReader struct {
db *Database
codec uint64
}
// Node retrieves the trie node with the given node hash.
func (reader *hashReader) Node(owner common.Hash, path []byte, hash common.Hash) (node, error) {
blob, err := reader.NodeBlob(owner, path, hash)
if err != nil {
return nil, err
}
return decodeNodeUnsafe(hash[:], blob)
}
// NodeBlob retrieves the RLP-encoded trie node blob with the given node hash.
func (reader *hashReader) NodeBlob(_ common.Hash, _ []byte, hash common.Hash) ([]byte, error) {
return reader.db.Node(hash, reader.codec)
}
// saveCache saves clean state cache to given directory path
// using specified CPU cores.
func (db *Database) saveCache(dir string, threads int) error {
if db.cleans == nil {
return nil
}
log.Info("Writing clean trie cache to disk", "path", dir, "threads", threads)
start := time.Now()
err := db.cleans.SaveToFileConcurrent(dir, threads)
if err != nil {
log.Error("Failed to persist clean trie cache", "error", err)
return err
}
log.Info("Persisted the clean trie cache", "path", dir, "elapsed", common.PrettyDuration(time.Since(start)))
return nil
}
// SaveCache atomically saves fast cache data to the given dir using all
// available CPU cores.
func (db *Database) SaveCache(dir string) error {
return db.saveCache(dir, runtime.GOMAXPROCS(0))
}
// SaveCachePeriodically atomically saves fast cache data to the given dir with
// the specified interval. All dump operation will only use a single CPU core.
func (db *Database) SaveCachePeriodically(dir string, interval time.Duration, stopCh <-chan struct{}) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
db.saveCache(dir, 1)
case <-stopCh:
return
}
}
}
// Scheme returns the node scheme used in the database.
func (db *Database) Scheme() string {
return rawdb.HashScheme
}

View File

@ -0,0 +1,27 @@
package trie
import "github.com/ethereum/go-ethereum/metrics"
var (
memcacheCleanHitMeter = metrics.NewRegisteredMeter("trie/memcache/clean/hit", nil)
memcacheCleanMissMeter = metrics.NewRegisteredMeter("trie/memcache/clean/miss", nil)
memcacheCleanReadMeter = metrics.NewRegisteredMeter("trie/memcache/clean/read", nil)
memcacheCleanWriteMeter = metrics.NewRegisteredMeter("trie/memcache/clean/write", nil)
memcacheDirtyHitMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/hit", nil)
memcacheDirtyMissMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/miss", nil)
memcacheDirtyReadMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/read", nil)
memcacheDirtyWriteMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/write", nil)
memcacheFlushTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/flush/time", nil)
memcacheFlushNodesMeter = metrics.NewRegisteredMeter("trie/memcache/flush/nodes", nil)
memcacheFlushSizeMeter = metrics.NewRegisteredMeter("trie/memcache/flush/size", nil)
memcacheGCTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/gc/time", nil)
memcacheGCNodesMeter = metrics.NewRegisteredMeter("trie/memcache/gc/nodes", nil)
memcacheGCSizeMeter = metrics.NewRegisteredMeter("trie/memcache/gc/size", nil)
memcacheCommitTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/commit/time", nil)
memcacheCommitNodesMeter = metrics.NewRegisteredMeter("trie/memcache/commit/nodes", nil)
memcacheCommitSizeMeter = metrics.NewRegisteredMeter("trie/memcache/commit/size", nil)
)

View File

@ -0,0 +1,183 @@
package trie
import (
"fmt"
"io"
"reflect"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp"
)
// rawNode is a simple binary blob used to differentiate between collapsed trie
// nodes and already encoded RLP binary blobs (while at the same time store them
// in the same cache fields).
type rawNode []byte
func (n rawNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") }
func (n rawNode) EncodeRLP(w io.Writer) error {
_, err := w.Write(n)
return err
}
// rawFullNode represents only the useful data content of a full node, with the
// caches and flags stripped out to minimize its data storage. This type honors
// the same RLP encoding as the original parent.
type rawFullNode [17]node
func (n rawFullNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
func (n rawFullNode) fstring(ind string) string { panic("this should never end up in a live trie") }
func (n rawFullNode) EncodeRLP(w io.Writer) error {
eb := rlp.NewEncoderBuffer(w)
n.encode(eb)
return eb.Flush()
}
// rawShortNode represents only the useful data content of a short node, with the
// caches and flags stripped out to minimize its data storage. This type honors
// the same RLP encoding as the original parent.
type rawShortNode struct {
Key []byte
Val node
}
func (n rawShortNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
func (n rawShortNode) fstring(ind string) string { panic("this should never end up in a live trie") }
// cachedNode is all the information we know about a single cached trie node
// in the memory database write layer.
type cachedNode struct {
node node // Cached collapsed trie node, or raw rlp data
size uint16 // Byte size of the useful cached data
parents uint32 // Number of live nodes referencing this one
children map[common.Hash]uint16 // External children referenced by this node
flushPrev common.Hash // Previous node in the flush-list
flushNext common.Hash // Next node in the flush-list
}
// cachedNodeSize is the raw size of a cachedNode data structure without any
// node data included. It's an approximate size, but should be a lot better
// than not counting them.
var cachedNodeSize = int(reflect.TypeOf(cachedNode{}).Size())
// cachedNodeChildrenSize is the raw size of an initialized but empty external
// reference map.
const cachedNodeChildrenSize = 48
// rlp returns the raw rlp encoded blob of the cached trie node, either directly
// from the cache, or by regenerating it from the collapsed node.
func (n *cachedNode) rlp() []byte {
if node, ok := n.node.(rawNode); ok {
return node
}
return nodeToBytes(n.node)
}
// obj returns the decoded and expanded trie node, either directly from the cache,
// or by regenerating it from the rlp encoded blob.
func (n *cachedNode) obj(hash common.Hash) node {
if node, ok := n.node.(rawNode); ok {
// The raw-blob format nodes are loaded either from the
// clean cache or the database, they are all in their own
// copy and safe to use unsafe decoder.
return mustDecodeNodeUnsafe(hash[:], node)
}
return expandNode(hash[:], n.node)
}
// forChilds invokes the callback for all the tracked children of this node,
// both the implicit ones from inside the node as well as the explicit ones
// from outside the node.
func (n *cachedNode) forChilds(onChild func(hash common.Hash)) {
for child := range n.children {
onChild(child)
}
if _, ok := n.node.(rawNode); !ok {
forGatherChildren(n.node, onChild)
}
}
// forGatherChildren traverses the node hierarchy of a collapsed storage node and
// invokes the callback for all the hashnode children.
func forGatherChildren(n node, onChild func(hash common.Hash)) {
switch n := n.(type) {
case *rawShortNode:
forGatherChildren(n.Val, onChild)
case rawFullNode:
for i := 0; i < 16; i++ {
forGatherChildren(n[i], onChild)
}
case hashNode:
onChild(common.BytesToHash(n))
case valueNode, nil, rawNode:
default:
panic(fmt.Sprintf("unknown node type: %T", n))
}
}
// simplifyNode traverses the hierarchy of an expanded memory node and discards
// all the internal caches, returning a node that only contains the raw data.
func simplifyNode(n node) node {
switch n := n.(type) {
case *shortNode:
// Short nodes discard the flags and cascade
return &rawShortNode{Key: n.Key, Val: simplifyNode(n.Val)}
case *fullNode:
// Full nodes discard the flags and cascade
node := rawFullNode(n.Children)
for i := 0; i < len(node); i++ {
if node[i] != nil {
node[i] = simplifyNode(node[i])
}
}
return node
case valueNode, hashNode, rawNode:
return n
default:
panic(fmt.Sprintf("unknown node type: %T", n))
}
}
// expandNode traverses the node hierarchy of a collapsed storage node and converts
// all fields and keys into expanded memory form.
func expandNode(hash hashNode, n node) node {
switch n := n.(type) {
case *rawShortNode:
// Short nodes need key and child expansion
return &shortNode{
Key: compactToHex(n.Key),
Val: expandNode(nil, n.Val),
flags: nodeFlag{
hash: hash,
},
}
case rawFullNode:
// Full nodes need child expansion
node := &fullNode{
flags: nodeFlag{
hash: hash,
},
}
for i := 0; i < len(node.Children); i++ {
if n[i] != nil {
node.Children[i] = expandNode(nil, n[i])
}
}
return node
case valueNode, hashNode:
return n
default:
panic(fmt.Sprintf("unknown node type: %T", n))
}
}

View File

@ -14,21 +14,20 @@
// You should have received a copy of the GNU Lesser General Public License // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie_test package trie
import ( import (
"testing" "testing"
"github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
) )
// Tests that the trie database returns a missing trie node error if attempting // Tests that the trie database returns a missing trie node error if attempting
// to retrieve the meta root. // to retrieve the meta root.
func TestDatabaseMetarootFetch(t *testing.T) { func TestDatabaseMetarootFetch(t *testing.T) {
db := trie.NewDatabase(memorydb.New()) db := NewDatabase(rawdb.NewMemoryDatabase())
if _, err := db.Node(trie.CidKey{}); err == nil { if _, err := db.Node(common.Hash{}, StateTrieCodec); err == nil {
t.Fatalf("metaroot retrieval succeeded") t.Fatalf("metaroot retrieval succeeded")
} }
} }

View File

@ -16,8 +16,81 @@
package trie package trie
// Trie keys are dealt with in three distinct encodings:
//
// KEYBYTES encoding contains the actual key and nothing else. This encoding is the
// input to most API functions.
//
// HEX encoding contains one byte for each nibble of the key and an optional trailing
// 'terminator' byte of value 0x10 which indicates whether or not the node at the key
// contains a value. Hex key encoding is used for nodes loaded in memory because it's
// convenient to access.
//
// COMPACT encoding is defined by the Ethereum Yellow Paper (it's called "hex prefix
// encoding" there) and contains the bytes of the key and a flag. The high nibble of the
// first byte contains the flag; the lowest bit encoding the oddness of the length and
// the second-lowest encoding whether the node at the key is a value node. The low nibble
// of the first byte is zero in the case of an even number of nibbles and the first nibble
// in the case of an odd number. All remaining nibbles (now an even number) fit properly
// into the remaining bytes. Compact encoding is used for nodes stored on disk.
// HexToCompact converts a hex path to the compact encoded format
func HexToCompact(hex []byte) []byte {
return hexToCompact(hex)
}
func hexToCompact(hex []byte) []byte {
terminator := byte(0)
if hasTerm(hex) {
terminator = 1
hex = hex[:len(hex)-1]
}
buf := make([]byte, len(hex)/2+1)
buf[0] = terminator << 5 // the flag byte
if len(hex)&1 == 1 {
buf[0] |= 1 << 4 // odd flag
buf[0] |= hex[0] // first nibble is contained in the first byte
hex = hex[1:]
}
decodeNibbles(hex, buf[1:])
return buf
}
// hexToCompactInPlace places the compact key in input buffer, returning the length
// needed for the representation
func hexToCompactInPlace(hex []byte) int {
var (
hexLen = len(hex) // length of the hex input
firstByte = byte(0)
)
// Check if we have a terminator there
if hexLen > 0 && hex[hexLen-1] == 16 {
firstByte = 1 << 5
hexLen-- // last part was the terminator, ignore that
}
var (
binLen = hexLen/2 + 1
ni = 0 // index in hex
bi = 1 // index in bin (compact)
)
if hexLen&1 == 1 {
firstByte |= 1 << 4 // odd flag
firstByte |= hex[0] // first nibble is contained in the first byte
ni++
}
for ; ni < hexLen; bi, ni = bi+1, ni+2 {
hex[bi] = hex[ni]<<4 | hex[ni+1]
}
hex[0] = firstByte
return binLen
}
// CompactToHex converts a compact encoded path to hex format // CompactToHex converts a compact encoded path to hex format
func CompactToHex(compact []byte) []byte { func CompactToHex(compact []byte) []byte {
return compactToHex(compact)
}
func compactToHex(compact []byte) []byte {
if len(compact) == 0 { if len(compact) == 0 {
return compact return compact
} }
@ -62,6 +135,20 @@ func decodeNibbles(nibbles []byte, bytes []byte) {
} }
} }
// prefixLen returns the length of the common prefix of a and b.
func prefixLen(a, b []byte) int {
var i, length = 0, len(a)
if len(b) < length {
length = len(b)
}
for ; i < length; i++ {
if a[i] != b[i] {
break
}
}
return i
}
// hasTerm returns whether a hex key has the terminator flag. // hasTerm returns whether a hex key has the terminator flag.
func hasTerm(s []byte) bool { func hasTerm(s []byte) bool {
return len(s) > 0 && s[len(s)-1] == 16 return len(s) > 0 && s[len(s)-1] == 16

View File

@ -0,0 +1,141 @@
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"bytes"
crand "crypto/rand"
"encoding/hex"
"math/rand"
"testing"
)
func TestHexCompact(t *testing.T) {
tests := []struct{ hex, compact []byte }{
// empty keys, with and without terminator.
{hex: []byte{}, compact: []byte{0x00}},
{hex: []byte{16}, compact: []byte{0x20}},
// odd length, no terminator
{hex: []byte{1, 2, 3, 4, 5}, compact: []byte{0x11, 0x23, 0x45}},
// even length, no terminator
{hex: []byte{0, 1, 2, 3, 4, 5}, compact: []byte{0x00, 0x01, 0x23, 0x45}},
// odd length, terminator
{hex: []byte{15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x3f, 0x1c, 0xb8}},
// even length, terminator
{hex: []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x20, 0x0f, 0x1c, 0xb8}},
}
for _, test := range tests {
if c := hexToCompact(test.hex); !bytes.Equal(c, test.compact) {
t.Errorf("hexToCompact(%x) -> %x, want %x", test.hex, c, test.compact)
}
if h := compactToHex(test.compact); !bytes.Equal(h, test.hex) {
t.Errorf("compactToHex(%x) -> %x, want %x", test.compact, h, test.hex)
}
}
}
func TestHexKeybytes(t *testing.T) {
tests := []struct{ key, hexIn, hexOut []byte }{
{key: []byte{}, hexIn: []byte{16}, hexOut: []byte{16}},
{key: []byte{}, hexIn: []byte{}, hexOut: []byte{16}},
{
key: []byte{0x12, 0x34, 0x56},
hexIn: []byte{1, 2, 3, 4, 5, 6, 16},
hexOut: []byte{1, 2, 3, 4, 5, 6, 16},
},
{
key: []byte{0x12, 0x34, 0x5},
hexIn: []byte{1, 2, 3, 4, 0, 5, 16},
hexOut: []byte{1, 2, 3, 4, 0, 5, 16},
},
{
key: []byte{0x12, 0x34, 0x56},
hexIn: []byte{1, 2, 3, 4, 5, 6},
hexOut: []byte{1, 2, 3, 4, 5, 6, 16},
},
}
for _, test := range tests {
if h := keybytesToHex(test.key); !bytes.Equal(h, test.hexOut) {
t.Errorf("keybytesToHex(%x) -> %x, want %x", test.key, h, test.hexOut)
}
if k := hexToKeyBytes(test.hexIn); !bytes.Equal(k, test.key) {
t.Errorf("hexToKeyBytes(%x) -> %x, want %x", test.hexIn, k, test.key)
}
}
}
func TestHexToCompactInPlace(t *testing.T) {
for i, key := range []string{
"00",
"060a040c0f000a090b040803010801010900080d090a0a0d0903000b10",
"10",
} {
hexBytes, _ := hex.DecodeString(key)
exp := hexToCompact(hexBytes)
sz := hexToCompactInPlace(hexBytes)
got := hexBytes[:sz]
if !bytes.Equal(exp, got) {
t.Fatalf("test %d: encoding err\ninp %v\ngot %x\nexp %x\n", i, key, got, exp)
}
}
}
func TestHexToCompactInPlaceRandom(t *testing.T) {
for i := 0; i < 10000; i++ {
l := rand.Intn(128)
key := make([]byte, l)
crand.Read(key)
hexBytes := keybytesToHex(key)
hexOrig := []byte(string(hexBytes))
exp := hexToCompact(hexBytes)
sz := hexToCompactInPlace(hexBytes)
got := hexBytes[:sz]
if !bytes.Equal(exp, got) {
t.Fatalf("encoding err \ncpt %x\nhex %x\ngot %x\nexp %x\n",
key, hexOrig, got, exp)
}
}
}
func BenchmarkHexToCompact(b *testing.B) {
testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
for i := 0; i < b.N; i++ {
hexToCompact(testBytes)
}
}
func BenchmarkCompactToHex(b *testing.B) {
testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
for i := 0; i < b.N; i++ {
compactToHex(testBytes)
}
}
func BenchmarkKeybytesToHex(b *testing.B) {
testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
for i := 0; i < b.N; i++ {
keybytesToHex(testBytes)
}
}
func BenchmarkHexToKeybytes(b *testing.B) {
testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
for i := 0; i < b.N; i++ {
hexToKeyBytes(testBytes)
}
}

View File

@ -27,7 +27,7 @@ import (
// information necessary for retrieving the missing node. // information necessary for retrieving the missing node.
type MissingNodeError struct { type MissingNodeError struct {
Owner common.Hash // owner of the trie if it's 2-layered trie Owner common.Hash // owner of the trie if it's 2-layered trie
NodeHash []byte // hash of the missing node NodeHash common.Hash // hash of the missing node
Path []byte // hex-encoded path to the missing node Path []byte // hex-encoded path to the missing node
err error // concrete error for missing trie node err error // concrete error for missing trie node
} }

View File

@ -21,7 +21,6 @@ import (
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
"golang.org/x/crypto/sha3" "golang.org/x/crypto/sha3"
) )
@ -99,7 +98,7 @@ func (h *hasher) hashShortNodeChildren(n *shortNode) (collapsed, cached *shortNo
// Previously, we did copy this one. We don't seem to need to actually // Previously, we did copy this one. We don't seem to need to actually
// do that, since we don't overwrite/reuse keys // do that, since we don't overwrite/reuse keys
//cached.Key = common.CopyBytes(n.Key) //cached.Key = common.CopyBytes(n.Key)
collapsed.Key = trie.HexToCompact(n.Key) collapsed.Key = hexToCompact(n.Key)
// Unless the child is a valuenode or hashnode, hash it // Unless the child is a valuenode or hashnode, hash it
switch n.Val.(type) { switch n.Val.(type) {
case *fullNode, *shortNode: case *fullNode, *shortNode:
@ -171,8 +170,8 @@ func (h *hasher) fullnodeToHash(n *fullNode, force bool) node {
// //
// All node encoding must be done like this: // All node encoding must be done like this:
// //
// node.encode(h.encbuf) // node.encode(h.encbuf)
// enc := h.encodedBytes() // enc := h.encodedBytes()
// //
// This convention exists because node.encode can only be inlined/escape-analyzed when // This convention exists because node.encode can only be inlined/escape-analyzed when
// called on a concrete receiver type. // called on a concrete receiver type.

View File

@ -18,14 +18,26 @@ package trie
import ( import (
"bytes" "bytes"
"container/heap"
"errors" "errors"
"time"
"github.com/ethereum/go-ethereum/statediff/indexer/database/metrics"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/core/types"
gethtrie "github.com/ethereum/go-ethereum/trie"
) )
// NodeIterator is a re-export of the go-ethereum interface // NodeIterator is an iterator to traverse the trie pre-order.
type NodeIterator = trie.NodeIterator type NodeIterator = gethtrie.NodeIterator
// NodeResolver is used for looking up trie nodes before reaching into the real
// persistent layer. This is not mandatory, rather is an optimization for cases
// where trie nodes can be recovered from some external mechanism without reading
// from disk. In those cases, this resolver allows short circuiting accesses and
// returning them from memory.
type NodeResolver = gethtrie.NodeResolver
// Iterator is a key-value trie iterator that traverses a Trie. // Iterator is a key-value trie iterator that traverses a Trie.
type Iterator struct { type Iterator struct {
@ -82,7 +94,7 @@ type nodeIterator struct {
path []byte // Path to the current node path []byte // Path to the current node
err error // Failure set in case of an internal error in the iterator err error // Failure set in case of an internal error in the iterator
resolver trie.NodeResolver // Optional intermediate resolver above the disk layer resolver NodeResolver // optional node resolver for avoiding disk hits
} }
// errIteratorEnd is stored in nodeIterator.err when iteration is done. // errIteratorEnd is stored in nodeIterator.err when iteration is done.
@ -99,7 +111,7 @@ func (e seekError) Error() string {
} }
func newNodeIterator(trie *Trie, start []byte) NodeIterator { func newNodeIterator(trie *Trie, start []byte) NodeIterator {
if trie.Hash() == emptyRoot { if trie.Hash() == types.EmptyRootHash {
return &nodeIterator{ return &nodeIterator{
trie: trie, trie: trie,
err: errIteratorEnd, err: errIteratorEnd,
@ -110,7 +122,7 @@ func newNodeIterator(trie *Trie, start []byte) NodeIterator {
return it return it
} }
func (it *nodeIterator) AddResolver(resolver trie.NodeResolver) { func (it *nodeIterator) AddResolver(resolver NodeResolver) {
it.resolver = resolver it.resolver = resolver
} }
@ -128,6 +140,14 @@ func (it *nodeIterator) Parent() common.Hash {
return it.stack[len(it.stack)-1].parent return it.stack[len(it.stack)-1].parent
} }
func (it *nodeIterator) ParentPath() []byte {
if len(it.stack) == 0 {
return []byte{}
}
pathlen := it.stack[len(it.stack)-1].pathlen
return it.path[:pathlen]
}
func (it *nodeIterator) Leaf() bool { func (it *nodeIterator) Leaf() bool {
return hasTerm(it.path) return hasTerm(it.path)
} }
@ -241,7 +261,7 @@ func (it *nodeIterator) seek(prefix []byte) error {
func (it *nodeIterator) init() (*nodeIteratorState, error) { func (it *nodeIterator) init() (*nodeIteratorState, error) {
root := it.trie.Hash() root := it.trie.Hash()
state := &nodeIteratorState{node: it.trie.root, index: -1} state := &nodeIteratorState{node: it.trie.root, index: -1}
if root != emptyRoot { if root != types.EmptyRootHash {
state.hash = root state.hash = root
} }
return state, state.resolve(it, nil) return state, state.resolve(it, nil)
@ -320,7 +340,12 @@ func (it *nodeIterator) resolveHash(hash hashNode, path []byte) (node, error) {
} }
} }
} }
return it.trie.resolveHash(hash, path) // Retrieve the specified node from the underlying node reader.
// it.trie.resolveAndTrack is not used since in that function the
// loaded blob will be tracked, while it's not required here since
// all loaded nodes won't be linked to trie at all and track nodes
// may lead to out-of-memory issue.
return it.trie.reader.node(path, common.BytesToHash(hash))
} }
func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) { func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) {
@ -329,7 +354,12 @@ func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error)
return blob, nil return blob, nil
} }
} }
return it.trie.resolveBlob(hash, path) // Retrieve the specified node from the underlying node reader.
// it.trie.resolveAndTrack is not used since in that function the
// loaded blob will be tracked, while it's not required here since
// all loaded nodes won't be linked to trie at all and track nodes
// may lead to out-of-memory issue.
return it.trie.reader.nodeBlob(path, common.BytesToHash(hash))
} }
func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error { func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error {
@ -455,3 +485,248 @@ func (it *nodeIterator) pop() {
it.stack[len(it.stack)-1] = nil it.stack[len(it.stack)-1] = nil
it.stack = it.stack[:len(it.stack)-1] it.stack = it.stack[:len(it.stack)-1]
} }
func compareNodes(a, b NodeIterator) int {
if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 {
return cmp
}
if a.Leaf() && !b.Leaf() {
return -1
} else if b.Leaf() && !a.Leaf() {
return 1
}
if cmp := bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()); cmp != 0 {
return cmp
}
if a.Leaf() && b.Leaf() {
return bytes.Compare(a.LeafBlob(), b.LeafBlob())
}
return 0
}
type differenceIterator struct {
a, b NodeIterator // Nodes returned are those in b - a.
eof bool // Indicates a has run out of elements
count int // Number of nodes scanned on either trie
}
// NewDifferenceIterator constructs a NodeIterator that iterates over elements in b that
// are not in a. Returns the iterator, and a pointer to an integer recording the number
// of nodes seen.
func NewDifferenceIterator(a, b NodeIterator) (NodeIterator, *int) {
a.Next(true)
it := &differenceIterator{
a: a,
b: b,
}
return it, &it.count
}
func (it *differenceIterator) Hash() common.Hash {
return it.b.Hash()
}
func (it *differenceIterator) Parent() common.Hash {
return it.b.Parent()
}
func (it *differenceIterator) ParentPath() []byte {
return it.b.ParentPath()
}
func (it *differenceIterator) Leaf() bool {
return it.b.Leaf()
}
func (it *differenceIterator) LeafKey() []byte {
return it.b.LeafKey()
}
func (it *differenceIterator) LeafBlob() []byte {
return it.b.LeafBlob()
}
func (it *differenceIterator) LeafProof() [][]byte {
return it.b.LeafProof()
}
func (it *differenceIterator) Path() []byte {
return it.b.Path()
}
func (it *differenceIterator) NodeBlob() []byte {
return it.b.NodeBlob()
}
func (it *differenceIterator) AddResolver(resolver NodeResolver) {
panic("not implemented")
}
func (it *differenceIterator) Next(bool) bool {
defer metrics.UpdateDuration(time.Now(), metrics.IndexerMetrics.DifferenceIteratorNextTimer)
// Invariants:
// - We always advance at least one element in b.
// - At the start of this function, a's path is lexically greater than b's.
if !it.b.Next(true) {
return false
}
it.count++
if it.eof {
// a has reached eof, so we just return all elements from b
return true
}
for {
switch compareNodes(it.a, it.b) {
case -1:
// b jumped past a; advance a
if !it.a.Next(true) {
it.eof = true
return true
}
it.count++
case 1:
// b is before a
return true
case 0:
// a and b are identical; skip this whole subtree if the nodes have hashes
hasHash := it.a.Hash() == common.Hash{}
if !it.b.Next(hasHash) {
return false
}
it.count++
if !it.a.Next(hasHash) {
it.eof = true
return true
}
it.count++
}
}
}
func (it *differenceIterator) Error() error {
if err := it.a.Error(); err != nil {
return err
}
return it.b.Error()
}
type nodeIteratorHeap []NodeIterator
func (h nodeIteratorHeap) Len() int { return len(h) }
func (h nodeIteratorHeap) Less(i, j int) bool { return compareNodes(h[i], h[j]) < 0 }
func (h nodeIteratorHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *nodeIteratorHeap) Push(x interface{}) { *h = append(*h, x.(NodeIterator)) }
func (h *nodeIteratorHeap) Pop() interface{} {
n := len(*h)
x := (*h)[n-1]
*h = (*h)[0 : n-1]
return x
}
type unionIterator struct {
items *nodeIteratorHeap // Nodes returned are the union of the ones in these iterators
count int // Number of nodes scanned across all tries
}
// NewUnionIterator constructs a NodeIterator that iterates over elements in the union
// of the provided NodeIterators. Returns the iterator, and a pointer to an integer
// recording the number of nodes visited.
func NewUnionIterator(iters []NodeIterator) (NodeIterator, *int) {
h := make(nodeIteratorHeap, len(iters))
copy(h, iters)
heap.Init(&h)
ui := &unionIterator{items: &h}
return ui, &ui.count
}
func (it *unionIterator) Hash() common.Hash {
return (*it.items)[0].Hash()
}
func (it *unionIterator) Parent() common.Hash {
return (*it.items)[0].Parent()
}
func (it *unionIterator) ParentPath() []byte {
return (*it.items)[0].ParentPath()
}
func (it *unionIterator) Leaf() bool {
return (*it.items)[0].Leaf()
}
func (it *unionIterator) LeafKey() []byte {
return (*it.items)[0].LeafKey()
}
func (it *unionIterator) LeafBlob() []byte {
return (*it.items)[0].LeafBlob()
}
func (it *unionIterator) LeafProof() [][]byte {
return (*it.items)[0].LeafProof()
}
func (it *unionIterator) Path() []byte {
return (*it.items)[0].Path()
}
func (it *unionIterator) NodeBlob() []byte {
return (*it.items)[0].NodeBlob()
}
func (it *unionIterator) AddResolver(resolver NodeResolver) {
panic("not implemented")
}
// Next returns the next node in the union of tries being iterated over.
//
// It does this by maintaining a heap of iterators, sorted by the iteration
// order of their next elements, with one entry for each source trie. Each
// time Next() is called, it takes the least element from the heap to return,
// advancing any other iterators that also point to that same element. These
// iterators are called with descend=false, since we know that any nodes under
// these nodes will also be duplicates, found in the currently selected iterator.
// Whenever an iterator is advanced, it is pushed back into the heap if it still
// has elements remaining.
//
// In the case that descend=false - eg, we're asked to ignore all subnodes of the
// current node - we also advance any iterators in the heap that have the current
// path as a prefix.
func (it *unionIterator) Next(descend bool) bool {
if len(*it.items) == 0 {
return false
}
// Get the next key from the union
least := heap.Pop(it.items).(NodeIterator)
// Skip over other nodes as long as they're identical, or, if we're not descending, as
// long as they have the same prefix as the current node.
for len(*it.items) > 0 && ((!descend && bytes.HasPrefix((*it.items)[0].Path(), least.Path())) || compareNodes(least, (*it.items)[0]) == 0) {
skipped := heap.Pop(it.items).(NodeIterator)
// Skip the whole subtree if the nodes have hashes; otherwise just skip this node
if skipped.Next(skipped.Hash() == common.Hash{}) {
it.count++
// If there are more elements, push the iterator back on the heap
heap.Push(it.items, skipped)
}
}
if least.Next(descend) {
it.count++
heap.Push(it.items, least)
}
return len(*it.items) > 0
}
func (it *unionIterator) Error() error {
for i := 0; i < len(*it.items); i++ {
if err := (*it.items)[i].Error(); err != nil {
return err
}
}
return nil
}

View File

@ -14,28 +14,24 @@
// You should have received a copy of the GNU Lesser General Public License // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie_test package trie
import ( import (
"bytes" "bytes"
"context" "encoding/binary"
"fmt" "fmt"
"math/rand"
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
geth_trie "github.com/ethereum/go-ethereum/trie" geth_trie "github.com/ethereum/go-ethereum/trie"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
) )
var ( var (
dbConfig, _ = postgres.DefaultConfig.WithEnv() packableTestData = []kvsi{
trieConfig = trie.Config{Cache: 256}
ctx = context.Background()
testdata0 = []kvs{
{"one", 1}, {"one", 1},
{"two", 2}, {"two", 2},
{"three", 3}, {"three", 3},
@ -43,20 +39,10 @@ var (
{"five", 5}, {"five", 5},
{"ten", 10}, {"ten", 10},
} }
testdata1 = []kvs{
{"barb", 0},
{"bard", 1},
{"bars", 2},
{"bar", 3},
{"fab", 4},
{"food", 5},
{"foos", 6},
{"foo", 7},
}
) )
func TestEmptyIterator(t *testing.T) { func TestEmptyIterator(t *testing.T) {
trie := trie.NewEmpty(trie.NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
iter := trie.NodeIterator(nil) iter := trie.NodeIterator(nil)
seen := make(map[string]struct{}) seen := make(map[string]struct{})
@ -69,63 +55,163 @@ func TestEmptyIterator(t *testing.T) {
} }
func TestIterator(t *testing.T) { func TestIterator(t *testing.T) {
edb := rawdb.NewMemoryDatabase() db := NewDatabase(rawdb.NewMemoryDatabase())
db := geth_trie.NewDatabase(edb) trie := NewEmpty(db)
origtrie := geth_trie.NewEmpty(db) vals := []struct{ k, v string }{
all, err := updateTrie(origtrie, testdata0) {"do", "verb"},
if err != nil { {"ether", "wookiedoo"},
t.Fatal(err) {"horse", "stallion"},
{"shaman", "horse"},
{"doge", "coin"},
{"dog", "puppy"},
{"somethingveryoddindeedthis is", "myothernodedata"},
} }
// commit and index data all := make(map[string]string)
root := commitTrie(t, db, origtrie) for _, val := range vals {
tr := indexTrie(t, edb, root) all[val.k] = val.v
trie.Update([]byte(val.k), []byte(val.v))
}
root, nodes := trie.Commit(false)
db.Update(NewWithNodeSet(nodes))
found := make(map[string]int64) trie, _ = New(TrieID(root), db, StateTrieCodec)
it := trie.NewIterator(tr.NodeIterator(nil)) found := make(map[string]string)
it := NewIterator(trie.NodeIterator(nil))
for it.Next() { for it.Next() {
found[string(it.Key)] = unpackValue(it.Value) found[string(it.Key)] = string(it.Value)
} }
if len(found) != len(all) { for k, v := range all {
t.Errorf("number of iterated values do not match: want %d, found %d", len(all), len(found)) if found[k] != v {
} t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], v)
for k, kv := range all {
if found[k] != kv.v {
t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], kv.v)
} }
} }
} }
func TestIteratorSeek(t *testing.T) { type kv struct {
edb := rawdb.NewMemoryDatabase() k, v []byte
db := geth_trie.NewDatabase(edb) t bool
orig := geth_trie.NewEmpty(geth_trie.NewDatabase(rawdb.NewMemoryDatabase())) }
if _, err := updateTrie(orig, testdata1); err != nil {
t.Fatal(err) func TestIteratorLargeData(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
vals := make(map[string]*kv)
for i := byte(0); i < 255; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false}
trie.Update(value.k, value.v)
trie.Update(value2.k, value2.v)
vals[string(value.k)] = value
vals[string(value2.k)] = value2
}
it := NewIterator(trie.NodeIterator(nil))
for it.Next() {
vals[string(it.Key)].t = true
}
var untouched []*kv
for _, value := range vals {
if !value.t {
untouched = append(untouched, value)
}
}
if len(untouched) > 0 {
t.Errorf("Missed %d nodes", len(untouched))
for _, value := range untouched {
t.Error(value)
}
}
}
// Tests that the node iterator indeed walks over the entire database contents.
func TestNodeIteratorCoverage(t *testing.T) {
db, trie, _ := makeTestTrie(t)
// Create some arbitrary test trie to iterate
// Gather all the node hashes found by the iterator
hashes := make(map[common.Hash]struct{})
for it := trie.NodeIterator(nil); it.Next(true); {
if it.Hash() != (common.Hash{}) {
hashes[it.Hash()] = struct{}{}
}
}
// Cross check the hashes and the database itself
for hash := range hashes {
if _, err := db.Node(hash, StateTrieCodec); err != nil {
t.Errorf("failed to retrieve reported node %x: %v", hash, err)
}
}
for hash, obj := range db.dirties {
if obj != nil && hash != (common.Hash{}) {
if _, ok := hashes[hash]; !ok {
t.Errorf("state entry not reported %x", hash)
}
}
}
it := db.diskdb.NewIterator(nil, nil)
for it.Next() {
key := it.Key()
if _, ok := hashes[common.BytesToHash(key)]; !ok {
t.Errorf("state entry not reported %x", key)
}
}
it.Release()
}
type kvs struct{ k, v string }
var testdata1 = []kvs{
{"barb", "ba"},
{"bard", "bc"},
{"bars", "bb"},
{"bar", "b"},
{"fab", "z"},
{"food", "ab"},
{"foos", "aa"},
{"foo", "a"},
}
var testdata2 = []kvs{
{"aardvark", "c"},
{"bar", "b"},
{"barb", "bd"},
{"bars", "be"},
{"fab", "z"},
{"foo", "a"},
{"foos", "aa"},
{"food", "ab"},
{"jars", "d"},
}
func TestIteratorSeek(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
for _, val := range testdata1 {
trie.Update([]byte(val.k), []byte(val.v))
} }
root := commitTrie(t, db, orig)
tr := indexTrie(t, edb, root)
// Seek to the middle. // Seek to the middle.
it := trie.NewIterator(tr.NodeIterator([]byte("fab"))) it := NewIterator(trie.NodeIterator([]byte("fab")))
if err := checkIteratorOrder(testdata1[4:], it); err != nil { if err := checkIteratorOrder(testdata1[4:], it); err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Seek to a non-existent key. // Seek to a non-existent key.
it = trie.NewIterator(tr.NodeIterator([]byte("barc"))) it = NewIterator(trie.NodeIterator([]byte("barc")))
if err := checkIteratorOrder(testdata1[1:], it); err != nil { if err := checkIteratorOrder(testdata1[1:], it); err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Seek beyond the end. // Seek beyond the end.
it = trie.NewIterator(tr.NodeIterator([]byte("z"))) it = NewIterator(trie.NodeIterator([]byte("z")))
if err := checkIteratorOrder(nil, it); err != nil { if err := checkIteratorOrder(nil, it); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
func checkIteratorOrder(want []kvs, it *trie.Iterator) error { func checkIteratorOrder(want []kvs, it *Iterator) error {
for it.Next() { for it.Next() {
if len(want) == 0 { if len(want) == 0 {
return fmt.Errorf("didn't expect any more values, got key %q", it.Key) return fmt.Errorf("didn't expect any more values, got key %q", it.Key)
@ -141,11 +227,366 @@ func checkIteratorOrder(want []kvs, it *trie.Iterator) error {
return nil return nil
} }
func TestDifferenceIterator(t *testing.T) {
dba := NewDatabase(rawdb.NewMemoryDatabase())
triea := NewEmpty(dba)
for _, val := range testdata1 {
triea.Update([]byte(val.k), []byte(val.v))
}
rootA, nodesA := triea.Commit(false)
dba.Update(NewWithNodeSet(nodesA))
triea, _ = New(TrieID(rootA), dba, StateTrieCodec)
dbb := NewDatabase(rawdb.NewMemoryDatabase())
trieb := NewEmpty(dbb)
for _, val := range testdata2 {
trieb.Update([]byte(val.k), []byte(val.v))
}
rootB, nodesB := trieb.Commit(false)
dbb.Update(NewWithNodeSet(nodesB))
trieb, _ = New(TrieID(rootB), dbb, StateTrieCodec)
found := make(map[string]string)
di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
it := NewIterator(di)
for it.Next() {
found[string(it.Key)] = string(it.Value)
}
all := []struct{ k, v string }{
{"aardvark", "c"},
{"barb", "bd"},
{"bars", "be"},
{"jars", "d"},
}
for _, item := range all {
if found[item.k] != item.v {
t.Errorf("iterator value mismatch for %s: got %v want %v", item.k, found[item.k], item.v)
}
}
if len(found) != len(all) {
t.Errorf("iterator count mismatch: got %d values, want %d", len(found), len(all))
}
}
func TestUnionIterator(t *testing.T) {
dba := NewDatabase(rawdb.NewMemoryDatabase())
triea := NewEmpty(dba)
for _, val := range testdata1 {
triea.Update([]byte(val.k), []byte(val.v))
}
rootA, nodesA := triea.Commit(false)
dba.Update(NewWithNodeSet(nodesA))
triea, _ = New(TrieID(rootA), dba, StateTrieCodec)
dbb := NewDatabase(rawdb.NewMemoryDatabase())
trieb := NewEmpty(dbb)
for _, val := range testdata2 {
trieb.Update([]byte(val.k), []byte(val.v))
}
rootB, nodesB := trieb.Commit(false)
dbb.Update(NewWithNodeSet(nodesB))
trieb, _ = New(TrieID(rootB), dbb, StateTrieCodec)
di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)})
it := NewIterator(di)
all := []struct{ k, v string }{
{"aardvark", "c"},
{"barb", "ba"},
{"barb", "bd"},
{"bard", "bc"},
{"bars", "bb"},
{"bars", "be"},
{"bar", "b"},
{"fab", "z"},
{"food", "ab"},
{"foos", "aa"},
{"foo", "a"},
{"jars", "d"},
}
for i, kv := range all {
if !it.Next() {
t.Errorf("Iterator ends prematurely at element %d", i)
}
if kv.k != string(it.Key) {
t.Errorf("iterator value mismatch for element %d: got key %s want %s", i, it.Key, kv.k)
}
if kv.v != string(it.Value) {
t.Errorf("iterator value mismatch for element %d: got value %s want %s", i, it.Value, kv.v)
}
}
if it.Next() {
t.Errorf("Iterator returned extra values.")
}
}
func TestIteratorNoDups(t *testing.T) {
tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
for _, val := range testdata1 {
tr.Update([]byte(val.k), []byte(val.v))
}
checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
}
// This test checks that nodeIterator.Next can be retried after inserting missing trie nodes.
func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) }
func testIteratorContinueAfterError(t *testing.T, memonly bool) {
diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb)
tr := NewEmpty(triedb)
for _, val := range testdata1 {
tr.Update([]byte(val.k), []byte(val.v))
}
_, nodes := tr.Commit(false)
triedb.Update(NewWithNodeSet(nodes))
// if !memonly {
// triedb.Commit(tr.Hash(), false)
// }
wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
var (
diskKeys [][]byte
memKeys []common.Hash
)
if memonly {
memKeys = triedb.Nodes()
} else {
it := diskdb.NewIterator(nil, nil)
for it.Next() {
diskKeys = append(diskKeys, it.Key())
}
it.Release()
}
for i := 0; i < 20; i++ {
// Create trie that will load all nodes from DB.
tr, _ := New(TrieID(tr.Hash()), triedb, StateTrieCodec)
// Remove a random node from the database. It can't be the root node
// because that one is already loaded.
var (
rkey common.Hash
rval []byte
robj *cachedNode
)
for {
if memonly {
rkey = memKeys[rand.Intn(len(memKeys))]
} else {
copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))])
}
if rkey != tr.Hash() {
break
}
}
if memonly {
robj = triedb.dirties[rkey]
delete(triedb.dirties, rkey)
} else {
rval, _ = diskdb.Get(rkey[:])
diskdb.Delete(rkey[:])
}
// Iterate until the error is hit.
seen := make(map[string]bool)
it := tr.NodeIterator(nil)
checkIteratorNoDups(t, it, seen)
missing, ok := it.Error().(*MissingNodeError)
if !ok || missing.NodeHash != rkey {
t.Fatal("didn't hit missing node, got", it.Error())
}
// Add the node back and continue iteration.
if memonly {
triedb.dirties[rkey] = robj
} else {
diskdb.Put(rkey[:], rval)
}
checkIteratorNoDups(t, it, seen)
if it.Error() != nil {
t.Fatal("unexpected error", it.Error())
}
if len(seen) != wantNodeCount {
t.Fatal("wrong node iteration count, got", len(seen), "want", wantNodeCount)
}
}
}
// Similar to the test above, this one checks that failure to create nodeIterator at a
// certain key prefix behaves correctly when Next is called. The expectation is that Next
// should retry seeking before returning true for the first time.
func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) {
testIteratorContinueAfterSeekError(t, true)
}
func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
// Commit test trie to db, then remove the node containing "bars".
diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb)
ctr := NewEmpty(triedb)
for _, val := range testdata1 {
ctr.Update([]byte(val.k), []byte(val.v))
}
root, nodes := ctr.Commit(false)
triedb.Update(NewWithNodeSet(nodes))
// if !memonly {
// triedb.Commit(root, false)
// }
barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
var (
barNodeBlob []byte
barNodeObj *cachedNode
)
if memonly {
barNodeObj = triedb.dirties[barNodeHash]
delete(triedb.dirties, barNodeHash)
} else {
barNodeBlob, _ = diskdb.Get(barNodeHash[:])
diskdb.Delete(barNodeHash[:])
}
// Create a new iterator that seeks to "bars". Seeking can't proceed because
// the node is missing.
tr, _ := New(TrieID(root), triedb, StateTrieCodec)
it := tr.NodeIterator([]byte("bars"))
missing, ok := it.Error().(*MissingNodeError)
if !ok {
t.Fatal("want MissingNodeError, got", it.Error())
} else if missing.NodeHash != barNodeHash {
t.Fatal("wrong node missing")
}
// Reinsert the missing node.
if memonly {
triedb.dirties[barNodeHash] = barNodeObj
} else {
diskdb.Put(barNodeHash[:], barNodeBlob)
}
// Check that iteration produces the right set of values.
if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
t.Fatal(err)
}
}
func checkIteratorNoDups(t *testing.T, it NodeIterator, seen map[string]bool) int {
if seen == nil {
seen = make(map[string]bool)
}
for it.Next(true) {
if seen[string(it.Path())] {
t.Fatalf("iterator visited node path %x twice", it.Path())
}
seen[string(it.Path())] = true
}
return len(seen)
}
type loggingTrieDb struct {
*Database
getCount uint64
}
// GetReader retrieves a node reader belonging to the given state root.
func (db *loggingTrieDb) GetReader(root common.Hash, codec uint64) Reader {
return &loggingNodeReader{db, codec}
}
// hashReader is reader of hashDatabase which implements the Reader interface.
type loggingNodeReader struct {
db *loggingTrieDb
codec uint64
}
// Node retrieves the trie node with the given node hash.
func (reader *loggingNodeReader) Node(owner common.Hash, path []byte, hash common.Hash) (node, error) {
blob, err := reader.NodeBlob(owner, path, hash)
if err != nil {
return nil, err
}
return decodeNodeUnsafe(hash[:], blob)
}
// NodeBlob retrieves the RLP-encoded trie node blob with the given node hash.
func (reader *loggingNodeReader) NodeBlob(_ common.Hash, _ []byte, hash common.Hash) ([]byte, error) {
reader.db.getCount++
return reader.db.Node(hash, reader.codec)
}
func newLoggingStateTrie(id *ID, db *Database, codec uint64) (*StateTrie, *loggingTrieDb, error) {
logdb := &loggingTrieDb{Database: db}
trie, err := New(id, logdb, codec)
if err != nil {
return nil, nil, err
}
return &StateTrie{trie: *trie, preimages: db.preimages}, logdb, nil
}
// makeLargeTestTrie create a sample test trie
func makeLargeTestTrie(t testing.TB) (*Database, *StateTrie, *loggingTrieDb) {
// Create an empty trie
triedb := NewDatabase(rawdb.NewDatabase(memorydb.New()))
trie, logDb, err := newLoggingStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec)
if err != nil {
t.Fatal(err)
}
// Fill it with some arbitrary data
for i := 0; i < 10000; i++ {
key := make([]byte, 32)
val := make([]byte, 32)
binary.BigEndian.PutUint64(key, uint64(i))
binary.BigEndian.PutUint64(val, uint64(i))
key = crypto.Keccak256(key)
val = crypto.Keccak256(val)
trie.Update(key, val)
}
_, nodes := trie.Commit(false)
triedb.Update(NewWithNodeSet(nodes))
// Return the generated trie
return triedb, trie, logDb
}
// Tests that the node iterator indeed walks over the entire database contents.
func TestNodeIteratorLargeTrie(t *testing.T) {
// Create some arbitrary test trie to iterate
_, trie, logDb := makeLargeTestTrie(t)
// Do a seek operation
trie.NodeIterator(common.FromHex("0x77667766776677766778855885885885"))
// master: 24 get operations
// this pr: 5 get operations
if have, want := logDb.getCount, uint64(5); have != want {
t.Fatalf("Wrong number of lookups during seek, have %d want %d", have, want)
}
}
func TestIteratorNodeBlob(t *testing.T) { func TestIteratorNodeBlob(t *testing.T) {
// var (
// db = rawdb.NewMemoryDatabase()
// triedb = NewDatabase(db)
// trie = NewEmpty(triedb)
// )
// vals := []struct{ k, v string }{
// {"do", "verb"},
// {"ether", "wookiedoo"},
// {"horse", "stallion"},
// {"shaman", "horse"},
// {"doge", "coin"},
// {"dog", "puppy"},
// {"somethingveryoddindeedthis is", "myothernodedata"},
// }
// all := make(map[string]string)
// for _, val := range vals {
// all[val.k] = val.v
// trie.Update([]byte(val.k), []byte(val.v))
// }
// _, nodes := trie.Commit(false)
// triedb.Update(NewWithNodeSet(nodes))
edb := rawdb.NewMemoryDatabase() edb := rawdb.NewMemoryDatabase()
db := geth_trie.NewDatabase(edb) db := geth_trie.NewDatabase(edb)
orig := geth_trie.NewEmpty(geth_trie.NewDatabase(rawdb.NewMemoryDatabase())) orig := geth_trie.NewEmpty(geth_trie.NewDatabase(rawdb.NewMemoryDatabase()))
if _, err := updateTrie(orig, testdata1); err != nil { if _, err := updateTrie(orig, packableTestData); err != nil {
t.Fatal(err) t.Fatal(err)
} }
root := commitTrie(t, db, orig) root := commitTrie(t, db, orig)
@ -167,7 +608,7 @@ func TestIteratorNodeBlob(t *testing.T) {
for dbIter.Next() { for dbIter.Next() {
got, present := found[common.BytesToHash(dbIter.Key())] got, present := found[common.BytesToHash(dbIter.Key())]
if !present { if !present {
t.Fatalf("Missing trie node %v", dbIter.Key()) t.Fatalf("Miss trie node %v", dbIter.Key())
} }
if !bytes.Equal(got, dbIter.Value()) { if !bytes.Equal(got, dbIter.Value()) {
t.Fatalf("Unexpected trie node want %v got %v", dbIter.Value(), got) t.Fatalf("Unexpected trie node want %v got %v", dbIter.Value(), got)
@ -175,6 +616,6 @@ func TestIteratorNodeBlob(t *testing.T) {
count += 1 count += 1
} }
if count != len(found) { if count != len(found) {
t.Fatal("Find extra trie node via iterator") t.Fatalf("Wrong number of trie nodes found, want %d, got %d", len(found), count)
} }
} }

View File

@ -23,7 +23,6 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
) )
var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"} var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"}
@ -157,7 +156,7 @@ func decodeShort(hash, elems []byte) (node, error) {
return nil, err return nil, err
} }
flag := nodeFlag{hash: hash} flag := nodeFlag{hash: hash}
key := trie.CompactToHex(kbuf) key := compactToHex(kbuf)
if hasTerm(key) { if hasTerm(key) {
// value node // value node
val, _, err := rlp.SplitString(rest) val, _, err := rlp.SplitString(rest)

View File

@ -58,3 +58,30 @@ func (n hashNode) encode(w rlp.EncoderBuffer) {
func (n valueNode) encode(w rlp.EncoderBuffer) { func (n valueNode) encode(w rlp.EncoderBuffer) {
w.WriteBytes(n) w.WriteBytes(n)
} }
func (n rawFullNode) encode(w rlp.EncoderBuffer) {
offset := w.List()
for _, c := range n {
if c != nil {
c.encode(w)
} else {
w.Write(rlp.EmptyString)
}
}
w.ListEnd(offset)
}
func (n *rawShortNode) encode(w rlp.EncoderBuffer) {
offset := w.List()
w.WriteBytes(n.Key)
if n.Val != nil {
n.Val.encode(w)
} else {
w.Write(rlp.EmptyString)
}
w.ListEnd(offset)
}
func (n rawNode) encode(w rlp.EncoderBuffer) {
w.Write(n)
}

View File

@ -20,6 +20,7 @@ import (
"bytes" "bytes"
"testing" "testing"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
@ -92,3 +93,123 @@ func TestDecodeFullNode(t *testing.T) {
t.Fatalf("decode full node err: %v", err) t.Fatalf("decode full node err: %v", err)
} }
} }
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkEncodeShortNode
// BenchmarkEncodeShortNode-8 16878850 70.81 ns/op 48 B/op 1 allocs/op
func BenchmarkEncodeShortNode(b *testing.B) {
node := &shortNode{
Key: []byte{0x1, 0x2},
Val: hashNode(randBytes(32)),
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
nodeToBytes(node)
}
}
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkEncodeFullNode
// BenchmarkEncodeFullNode-8 4323273 284.4 ns/op 576 B/op 1 allocs/op
func BenchmarkEncodeFullNode(b *testing.B) {
node := &fullNode{}
for i := 0; i < 16; i++ {
node.Children[i] = hashNode(randBytes(32))
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
nodeToBytes(node)
}
}
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkDecodeShortNode
// BenchmarkDecodeShortNode-8 7925638 151.0 ns/op 157 B/op 4 allocs/op
func BenchmarkDecodeShortNode(b *testing.B) {
node := &shortNode{
Key: []byte{0x1, 0x2},
Val: hashNode(randBytes(32)),
}
blob := nodeToBytes(node)
hash := crypto.Keccak256(blob)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
mustDecodeNode(hash, blob)
}
}
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkDecodeShortNodeUnsafe
// BenchmarkDecodeShortNodeUnsafe-8 9027476 128.6 ns/op 109 B/op 3 allocs/op
func BenchmarkDecodeShortNodeUnsafe(b *testing.B) {
node := &shortNode{
Key: []byte{0x1, 0x2},
Val: hashNode(randBytes(32)),
}
blob := nodeToBytes(node)
hash := crypto.Keccak256(blob)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
mustDecodeNodeUnsafe(hash, blob)
}
}
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkDecodeFullNode
// BenchmarkDecodeFullNode-8 1597462 761.9 ns/op 1280 B/op 18 allocs/op
func BenchmarkDecodeFullNode(b *testing.B) {
node := &fullNode{}
for i := 0; i < 16; i++ {
node.Children[i] = hashNode(randBytes(32))
}
blob := nodeToBytes(node)
hash := crypto.Keccak256(blob)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
mustDecodeNode(hash, blob)
}
}
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkDecodeFullNodeUnsafe
// BenchmarkDecodeFullNodeUnsafe-8 1789070 687.1 ns/op 704 B/op 17 allocs/op
func BenchmarkDecodeFullNodeUnsafe(b *testing.B) {
node := &fullNode{}
for i := 0; i < 16; i++ {
node.Children[i] = hashNode(randBytes(32))
}
blob := nodeToBytes(node)
hash := crypto.Keccak256(blob)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
mustDecodeNodeUnsafe(hash, blob)
}
}

218
trie_by_cid/trie/nodeset.go Normal file
View File

@ -0,0 +1,218 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"fmt"
"reflect"
"sort"
"strings"
"github.com/ethereum/go-ethereum/common"
)
// memoryNode is all the information we know about a single cached trie node
// in the memory.
type memoryNode struct {
hash common.Hash // Node hash, computed by hashing rlp value, empty for deleted nodes
size uint16 // Byte size of the useful cached data, 0 for deleted nodes
node node // Cached collapsed trie node, or raw rlp data, nil for deleted nodes
}
// memoryNodeSize is the raw size of a memoryNode data structure without any
// node data included. It's an approximate size, but should be a lot better
// than not counting them.
// nolint:unused
var memoryNodeSize = int(reflect.TypeOf(memoryNode{}).Size())
// memorySize returns the total memory size used by this node.
// nolint:unused
func (n *memoryNode) memorySize(pathlen int) int {
return int(n.size) + memoryNodeSize + pathlen
}
// rlp returns the raw rlp encoded blob of the cached trie node, either directly
// from the cache, or by regenerating it from the collapsed node.
// nolint:unused
func (n *memoryNode) rlp() []byte {
if node, ok := n.node.(rawNode); ok {
return node
}
return nodeToBytes(n.node)
}
// obj returns the decoded and expanded trie node, either directly from the cache,
// or by regenerating it from the rlp encoded blob.
// nolint:unused
func (n *memoryNode) obj() node {
if node, ok := n.node.(rawNode); ok {
return mustDecodeNode(n.hash[:], node)
}
return expandNode(n.hash[:], n.node)
}
// isDeleted returns the indicator if the node is marked as deleted.
func (n *memoryNode) isDeleted() bool {
return n.hash == (common.Hash{})
}
// nodeWithPrev wraps the memoryNode with the previous node value.
// nolint: unused
type nodeWithPrev struct {
*memoryNode
prev []byte // RLP-encoded previous value, nil means it's non-existent
}
// unwrap returns the internal memoryNode object.
// nolint:unused
func (n *nodeWithPrev) unwrap() *memoryNode {
return n.memoryNode
}
// memorySize returns the total memory size used by this node. It overloads
// the function in memoryNode by counting the size of previous value as well.
// nolint: unused
func (n *nodeWithPrev) memorySize(pathlen int) int {
return n.memoryNode.memorySize(pathlen) + len(n.prev)
}
// NodeSet contains all dirty nodes collected during the commit operation.
// Each node is keyed by path. It's not thread-safe to use.
type NodeSet struct {
owner common.Hash // the identifier of the trie
nodes map[string]*memoryNode // the set of dirty nodes(inserted, updated, deleted)
leaves []*leaf // the list of dirty leaves
updates int // the count of updated and inserted nodes
deletes int // the count of deleted nodes
// The list of accessed nodes, which records the original node value.
// The origin value is expected to be nil for newly inserted node
// and is expected to be non-nil for other types(updated, deleted).
accessList map[string][]byte
}
// NewNodeSet initializes an empty node set to be used for tracking dirty nodes
// from a specific account or storage trie. The owner is zero for the account
// trie and the owning account address hash for storage tries.
func NewNodeSet(owner common.Hash, accessList map[string][]byte) *NodeSet {
return &NodeSet{
owner: owner,
nodes: make(map[string]*memoryNode),
accessList: accessList,
}
}
// forEachWithOrder iterates the dirty nodes with the order from bottom to top,
// right to left, nodes with the longest path will be iterated first.
func (set *NodeSet) forEachWithOrder(callback func(path string, n *memoryNode)) {
var paths sort.StringSlice
for path := range set.nodes {
paths = append(paths, path)
}
// Bottom-up, longest path first
sort.Sort(sort.Reverse(paths))
for _, path := range paths {
callback(path, set.nodes[path])
}
}
// markUpdated marks the node as dirty(newly-inserted or updated).
func (set *NodeSet) markUpdated(path []byte, node *memoryNode) {
set.nodes[string(path)] = node
set.updates += 1
}
// markDeleted marks the node as deleted.
func (set *NodeSet) markDeleted(path []byte) {
set.nodes[string(path)] = &memoryNode{}
set.deletes += 1
}
// addLeaf collects the provided leaf node into set.
func (set *NodeSet) addLeaf(node *leaf) {
set.leaves = append(set.leaves, node)
}
// Size returns the number of dirty nodes in set.
func (set *NodeSet) Size() (int, int) {
return set.updates, set.deletes
}
// Hashes returns the hashes of all updated nodes. TODO(rjl493456442) how can
// we get rid of it?
func (set *NodeSet) Hashes() []common.Hash {
var ret []common.Hash
for _, node := range set.nodes {
ret = append(ret, node.hash)
}
return ret
}
// Summary returns a string-representation of the NodeSet.
func (set *NodeSet) Summary() string {
var out = new(strings.Builder)
fmt.Fprintf(out, "nodeset owner: %v\n", set.owner)
if set.nodes != nil {
for path, n := range set.nodes {
// Deletion
if n.isDeleted() {
fmt.Fprintf(out, " [-]: %x prev: %x\n", path, set.accessList[path])
continue
}
// Insertion
origin, ok := set.accessList[path]
if !ok {
fmt.Fprintf(out, " [+]: %x -> %v\n", path, n.hash)
continue
}
// Update
fmt.Fprintf(out, " [*]: %x -> %v prev: %x\n", path, n.hash, origin)
}
}
for _, n := range set.leaves {
fmt.Fprintf(out, "[leaf]: %v\n", n)
}
return out.String()
}
// MergedNodeSet represents a merged dirty node set for a group of tries.
type MergedNodeSet struct {
sets map[common.Hash]*NodeSet
}
// NewMergedNodeSet initializes an empty merged set.
func NewMergedNodeSet() *MergedNodeSet {
return &MergedNodeSet{sets: make(map[common.Hash]*NodeSet)}
}
// NewWithNodeSet constructs a merged nodeset with the provided single set.
func NewWithNodeSet(set *NodeSet) *MergedNodeSet {
merged := NewMergedNodeSet()
merged.Merge(set)
return merged
}
// Merge merges the provided dirty nodes of a trie into the set. The assumption
// is held that no duplicated set belonging to the same trie will be merged twice.
func (set *MergedNodeSet) Merge(other *NodeSet) error {
_, present := set.sets[other.owner]
if present {
return fmt.Errorf("duplicate trie for owner %#x", other.owner)
}
set.sets[other.owner] = other
return nil
}

View File

@ -0,0 +1,78 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"sync"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/ethdb"
)
// preimageStore is the store for caching preimages of node key.
type preimageStore struct {
lock sync.RWMutex
disk ethdb.KeyValueStore
preimages map[common.Hash][]byte // Preimages of nodes from the secure trie
preimagesSize common.StorageSize // Storage size of the preimages cache
}
// newPreimageStore initializes the store for caching preimages.
func newPreimageStore(disk ethdb.KeyValueStore) *preimageStore {
return &preimageStore{
disk: disk,
preimages: make(map[common.Hash][]byte),
}
}
// insertPreimage writes a new trie node pre-image to the memory database if it's
// yet unknown. The method will NOT make a copy of the slice, only use if the
// preimage will NOT be changed later on.
func (store *preimageStore) insertPreimage(preimages map[common.Hash][]byte) {
store.lock.Lock()
defer store.lock.Unlock()
for hash, preimage := range preimages {
if _, ok := store.preimages[hash]; ok {
continue
}
store.preimages[hash] = preimage
store.preimagesSize += common.StorageSize(common.HashLength + len(preimage))
}
}
// preimage retrieves a cached trie node pre-image from memory. If it cannot be
// found cached, the method queries the persistent database for the content.
func (store *preimageStore) preimage(hash common.Hash) []byte {
store.lock.RLock()
preimage := store.preimages[hash]
store.lock.RUnlock()
if preimage != nil {
return preimage
}
return rawdb.ReadPreimage(store.disk, hash)
}
// size returns the current storage size of accumulated preimages.
func (store *preimageStore) size() common.StorageSize {
store.lock.RLock()
defer store.lock.RUnlock()
return store.preimagesSize
}

View File

@ -20,9 +20,10 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
log "github.com/sirupsen/logrus"
) )
var VerifyProof = trie.VerifyProof var VerifyProof = trie.VerifyProof
@ -61,10 +62,15 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e
key = key[1:] key = key[1:]
nodes = append(nodes, n) nodes = append(nodes, n)
case hashNode: case hashNode:
// Retrieve the specified node from the underlying node reader.
// trie.resolveAndTrack is not used since in that function the
// loaded blob will be tracked, while it's not required here since
// all loaded nodes won't be linked to trie at all and track nodes
// may lead to out-of-memory issue.
var err error var err error
tn, err = t.resolveHash(n, prefix) tn, err = t.reader.node(prefix, common.BytesToHash(n))
if err != nil { if err != nil {
log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) log.Error("Unhandled trie error in Trie.Prove", "err", err)
return err return err
} }
default: default:

View File

@ -14,29 +14,39 @@
// You should have received a copy of the GNU Lesser General Public License // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie_test package trie
import ( import (
"bytes" "bytes"
crand "crypto/rand"
"encoding/binary"
"fmt"
mrand "math/rand" mrand "math/rand"
"sort" "sort"
"testing" "testing"
"time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/ethdb/memorydb"
geth_trie "github.com/ethereum/go-ethereum/trie"
. "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
) )
// tweakable trie size // Prng is a pseudo random number generator seeded by strong randomness.
var scaleFactor = 512 // The randomness is printed on startup in order to make failures reproducible.
var prng = initRnd()
func init() { func initRnd() *mrand.Rand {
mrand.Seed(time.Now().UnixNano()) var seed [8]byte
crand.Read(seed[:])
rnd := mrand.New(mrand.NewSource(int64(binary.LittleEndian.Uint64(seed[:]))))
fmt.Printf("Seed: %x\n", seed)
return rnd
}
func randBytes(n int) []byte {
r := make([]byte, n)
prng.Read(r)
return r
} }
// makeProvers creates Merkle trie provers based on different implementations to // makeProvers creates Merkle trie provers based on different implementations to
@ -64,7 +74,7 @@ func makeProvers(trie *Trie) []func(key []byte) *memorydb.Database {
} }
func TestProof(t *testing.T) { func TestProof(t *testing.T) {
trie, vals := randomTrie(t, scaleFactor) trie, vals := randomTrie(500)
root := trie.Hash() root := trie.Hash()
for i, prover := range makeProvers(trie) { for i, prover := range makeProvers(trie) {
for _, kv := range vals { for _, kv := range vals {
@ -72,11 +82,11 @@ func TestProof(t *testing.T) {
if proof == nil { if proof == nil {
t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k) t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k)
} }
val, err := geth_trie.VerifyProof(root, kv.k, proof) val, err := VerifyProof(root, kv.k, proof)
if err != nil { if err != nil {
t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x", i, kv.k, err, proof) t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x", i, kv.k, err, proof)
} }
if kv.v != unpackValue(val) { if !bytes.Equal(val, kv.v) {
t.Fatalf("prover %d: verified value mismatch for key %x: have %x, want %x", i, kv.k, val, kv.v) t.Fatalf("prover %d: verified value mismatch for key %x: have %x, want %x", i, kv.k, val, kv.v)
} }
} }
@ -84,13 +94,8 @@ func TestProof(t *testing.T) {
} }
func TestOneElementProof(t *testing.T) { func TestOneElementProof(t *testing.T) {
edb := rawdb.NewMemoryDatabase() trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
db := geth_trie.NewDatabase(edb) updateString(trie, "k", "v")
orig := geth_trie.NewEmpty(db)
orig.Update([]byte("k"), packValue(42))
root := commitTrie(t, db, orig)
trie := indexTrie(t, edb, root)
for i, prover := range makeProvers(trie) { for i, prover := range makeProvers(trie) {
proof := prover([]byte("k")) proof := prover([]byte("k"))
if proof == nil { if proof == nil {
@ -99,18 +104,18 @@ func TestOneElementProof(t *testing.T) {
if proof.Len() != 1 { if proof.Len() != 1 {
t.Errorf("prover %d: proof should have one element", i) t.Errorf("prover %d: proof should have one element", i)
} }
val, err := geth_trie.VerifyProof(trie.Hash(), []byte("k"), proof) val, err := VerifyProof(trie.Hash(), []byte("k"), proof)
if err != nil { if err != nil {
t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof)
} }
if 42 != unpackValue(val) { if !bytes.Equal(val, []byte("v")) {
t.Fatalf("prover %d: verified value mismatch: have %x, want 'k'", i, val) t.Fatalf("prover %d: verified value mismatch: have %x, want 'k'", i, val)
} }
} }
} }
func TestBadProof(t *testing.T) { func TestBadProof(t *testing.T) {
trie, vals := randomTrie(t, 2*scaleFactor) trie, vals := randomTrie(800)
root := trie.Hash() root := trie.Hash()
for i, prover := range makeProvers(trie) { for i, prover := range makeProvers(trie) {
for _, kv := range vals { for _, kv := range vals {
@ -140,12 +145,8 @@ func TestBadProof(t *testing.T) {
// Tests that missing keys can also be proven. The test explicitly uses a single // Tests that missing keys can also be proven. The test explicitly uses a single
// entry trie and checks for missing keys both before and after the single entry. // entry trie and checks for missing keys both before and after the single entry.
func TestMissingKeyProof(t *testing.T) { func TestMissingKeyProof(t *testing.T) {
edb := rawdb.NewMemoryDatabase() trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
db := geth_trie.NewDatabase(edb) updateString(trie, "k", "v")
orig := geth_trie.NewEmpty(db)
orig.Update([]byte("k"), packValue(42))
root := commitTrie(t, db, orig)
trie := indexTrie(t, edb, root)
for i, key := range []string{"a", "j", "l", "z"} { for i, key := range []string{"a", "j", "l", "z"} {
proof := memorydb.New() proof := memorydb.New()
@ -164,15 +165,7 @@ func TestMissingKeyProof(t *testing.T) {
} }
} }
type entry struct { type entrySlice []*kv
k, v []byte
}
func packEntry(kv *kv) *entry {
return &entry{kv.k, packValue(kv.v)}
}
type entrySlice []*entry
func (p entrySlice) Len() int { return len(p) } func (p entrySlice) Len() int { return len(p) }
func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 } func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 }
@ -181,10 +174,10 @@ func (p entrySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// TestRangeProof tests normal range proof with both edge proofs // TestRangeProof tests normal range proof with both edge proofs
// as the existent proof. The test cases are generated randomly. // as the existent proof. The test cases are generated randomly.
func TestRangeProof(t *testing.T) { func TestRangeProof(t *testing.T) {
trie, vals := randomTrie(t, 8*scaleFactor) trie, vals := randomTrie(4096)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
for i := 0; i < 500; i++ { for i := 0; i < 500; i++ {
@ -214,10 +207,10 @@ func TestRangeProof(t *testing.T) {
// TestRangeProof tests normal range proof with two non-existent proofs. // TestRangeProof tests normal range proof with two non-existent proofs.
// The test cases are generated randomly. // The test cases are generated randomly.
func TestRangeProofWithNonExistentProof(t *testing.T) { func TestRangeProofWithNonExistentProof(t *testing.T) {
trie, vals := randomTrie(t, 8*scaleFactor) trie, vals := randomTrie(4096)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
for i := 0; i < 500; i++ { for i := 0; i < 500; i++ {
@ -286,10 +279,10 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
// - There exists a gap between the first element and the left edge proof // - There exists a gap between the first element and the left edge proof
// - There exists a gap between the last element and the right edge proof // - There exists a gap between the last element and the right edge proof
func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
trie, vals := randomTrie(t, 8*scaleFactor) trie, vals := randomTrie(4096)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -343,10 +336,10 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
// element. The first edge proof can be existent one or // element. The first edge proof can be existent one or
// non-existent one. // non-existent one.
func TestOneElementRangeProof(t *testing.T) { func TestOneElementRangeProof(t *testing.T) {
trie, vals := randomTrie(t, 8*scaleFactor) trie, vals := randomTrie(4096)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -408,38 +401,32 @@ func TestOneElementRangeProof(t *testing.T) {
} }
// Test the mini trie with only a single element. // Test the mini trie with only a single element.
t.Run("single element", func(t *testing.T) { tinyTrie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
edb := rawdb.NewMemoryDatabase() entry := &kv{randBytes(32), randBytes(20), false}
db := geth_trie.NewDatabase(edb) tinyTrie.Update(entry.k, entry.v)
orig := geth_trie.NewEmpty(db)
entry := &entry{randBytes(32), packValue(mrand.Int63())}
orig.Update(entry.k, entry.v)
root := commitTrie(t, db, orig)
tinyTrie := indexTrie(t, edb, root)
first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes() first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
last = entry.k last = entry.k
proof = memorydb.New() proof = memorydb.New()
if err := tinyTrie.Prove(first, 0, proof); err != nil { if err := tinyTrie.Prove(first, 0, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
if err := tinyTrie.Prove(last, 0, proof); err != nil { if err := tinyTrie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
_, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof) _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
})
} }
// TestAllElementsProof tests the range proof with all elements. // TestAllElementsProof tests the range proof with all elements.
// The edge proofs can be nil. // The edge proofs can be nil.
func TestAllElementsProof(t *testing.T) { func TestAllElementsProof(t *testing.T) {
trie, vals := randomTrie(t, 8*scaleFactor) trie, vals := randomTrie(4096)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -483,86 +470,73 @@ func TestAllElementsProof(t *testing.T) {
} }
} }
func positionCases(size int) []int {
var cases []int
for _, pos := range []int{0, 1, 50, 100, 1000, 2000, size - 1} {
if pos >= size {
break
}
cases = append(cases, pos)
}
return cases
}
// TestSingleSideRangeProof tests the range starts from zero. // TestSingleSideRangeProof tests the range starts from zero.
func TestSingleSideRangeProof(t *testing.T) { func TestSingleSideRangeProof(t *testing.T) {
edb := rawdb.NewMemoryDatabase() for i := 0; i < 64; i++ {
db := geth_trie.NewDatabase(edb) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
orig := geth_trie.NewEmpty(db) var entries entrySlice
var entries entrySlice for i := 0; i < 4096; i++ {
for i := 0; i < 8*scaleFactor; i++ { value := &kv{randBytes(32), randBytes(20), false}
value := &entry{randBytes(32), packValue(mrand.Int63())} trie.Update(value.k, value.v)
orig.Update(value.k, value.v) entries = append(entries, value)
entries = append(entries, value) }
} sort.Sort(entries)
root := commitTrie(t, db, orig)
trie := indexTrie(t, edb, root)
sort.Sort(entries)
for _, pos := range positionCases(len(entries)) { var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
proof := memorydb.New() for _, pos := range cases {
if err := trie.Prove(common.Hash{}.Bytes(), 0, proof); err != nil { proof := memorydb.New()
t.Fatalf("Failed to prove the first node %v", err) if err := trie.Prove(common.Hash{}.Bytes(), 0, proof); err != nil {
} t.Fatalf("Failed to prove the first node %v", err)
if err := trie.Prove(entries[pos].k, 0, proof); err != nil { }
t.Fatalf("Failed to prove the first node %v", err) if err := trie.Prove(entries[pos].k, 0, proof); err != nil {
} t.Fatalf("Failed to prove the first node %v", err)
k := make([][]byte, 0) }
v := make([][]byte, 0) k := make([][]byte, 0)
for i := 0; i <= pos; i++ { v := make([][]byte, 0)
k = append(k, entries[i].k) for i := 0; i <= pos; i++ {
v = append(v, entries[i].v) k = append(k, entries[i].k)
} v = append(v, entries[i].v)
_, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof) }
if err != nil { _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof)
t.Fatalf("Expected no error, got %v", err) if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
} }
} }
} }
// TestReverseSingleSideRangeProof tests the range ends with 0xffff...fff. // TestReverseSingleSideRangeProof tests the range ends with 0xffff...fff.
func TestReverseSingleSideRangeProof(t *testing.T) { func TestReverseSingleSideRangeProof(t *testing.T) {
edb := rawdb.NewMemoryDatabase() for i := 0; i < 64; i++ {
db := geth_trie.NewDatabase(edb) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
orig := geth_trie.NewEmpty(db) var entries entrySlice
var entries entrySlice for i := 0; i < 4096; i++ {
for i := 0; i < 8*scaleFactor; i++ { value := &kv{randBytes(32), randBytes(20), false}
value := &entry{randBytes(32), packValue(mrand.Int63())} trie.Update(value.k, value.v)
orig.Update(value.k, value.v) entries = append(entries, value)
entries = append(entries, value) }
} sort.Sort(entries)
root := commitTrie(t, db, orig)
trie := indexTrie(t, edb, root)
sort.Sort(entries)
for _, pos := range positionCases(len(entries)) { var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
proof := memorydb.New() for _, pos := range cases {
if err := trie.Prove(entries[pos].k, 0, proof); err != nil { proof := memorydb.New()
t.Fatalf("Failed to prove the first node %v", err) if err := trie.Prove(entries[pos].k, 0, proof); err != nil {
} t.Fatalf("Failed to prove the first node %v", err)
last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") }
if err := trie.Prove(last.Bytes(), 0, proof); err != nil { last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
t.Fatalf("Failed to prove the last node %v", err) if err := trie.Prove(last.Bytes(), 0, proof); err != nil {
} t.Fatalf("Failed to prove the last node %v", err)
k := make([][]byte, 0) }
v := make([][]byte, 0) k := make([][]byte, 0)
for i := pos; i < len(entries); i++ { v := make([][]byte, 0)
k = append(k, entries[i].k) for i := pos; i < len(entries); i++ {
v = append(v, entries[i].v) k = append(k, entries[i].k)
} v = append(v, entries[i].v)
_, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof) }
if err != nil { _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof)
t.Fatalf("Expected no error, got %v", err) if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
} }
} }
} }
@ -570,10 +544,10 @@ func TestReverseSingleSideRangeProof(t *testing.T) {
// TestBadRangeProof tests a few cases which the proof is wrong. // TestBadRangeProof tests a few cases which the proof is wrong.
// The prover is expected to detect the error. // The prover is expected to detect the error.
func TestBadRangeProof(t *testing.T) { func TestBadRangeProof(t *testing.T) {
trie, vals := randomTrie(t, 8*scaleFactor) trie, vals := randomTrie(4096)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -641,17 +615,13 @@ func TestBadRangeProof(t *testing.T) {
// TestGappedRangeProof focuses on the small trie with embedded nodes. // TestGappedRangeProof focuses on the small trie with embedded nodes.
// If the gapped node is embedded in the trie, it should be detected too. // If the gapped node is embedded in the trie, it should be detected too.
func TestGappedRangeProof(t *testing.T) { func TestGappedRangeProof(t *testing.T) {
edb := rawdb.NewMemoryDatabase() trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
db := geth_trie.NewDatabase(edb) var entries []*kv // Sorted entries
orig := geth_trie.NewEmpty(db)
var entries entrySlice
for i := byte(0); i < 10; i++ { for i := byte(0); i < 10; i++ {
value := &entry{common.LeftPadBytes([]byte{i}, 32), packValue(int64(i))} value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
orig.Update(value.k, value.v) trie.Update(value.k, value.v)
entries = append(entries, value) entries = append(entries, value)
} }
root := commitTrie(t, db, orig)
trie := indexTrie(t, edb, root)
first, last := 2, 8 first, last := 2, 8
proof := memorydb.New() proof := memorydb.New()
if err := trie.Prove(entries[first].k, 0, proof); err != nil { if err := trie.Prove(entries[first].k, 0, proof); err != nil {
@ -677,10 +647,10 @@ func TestGappedRangeProof(t *testing.T) {
// TestSameSideProofs tests the element is not in the range covered by proofs // TestSameSideProofs tests the element is not in the range covered by proofs
func TestSameSideProofs(t *testing.T) { func TestSameSideProofs(t *testing.T) {
trie, vals := randomTrie(t, 8*scaleFactor) trie, vals := randomTrie(4096)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -719,17 +689,13 @@ func TestSameSideProofs(t *testing.T) {
} }
func TestHasRightElement(t *testing.T) { func TestHasRightElement(t *testing.T) {
edb := rawdb.NewMemoryDatabase() trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
db := geth_trie.NewDatabase(edb)
orig := geth_trie.NewEmpty(db)
var entries entrySlice var entries entrySlice
for i := 0; i < 8*scaleFactor; i++ { for i := 0; i < 4096; i++ {
value := &entry{randBytes(32), packValue(int64(i))} value := &kv{randBytes(32), randBytes(20), false}
orig.Update(value.k, value.v) trie.Update(value.k, value.v)
entries = append(entries, value) entries = append(entries, value)
} }
root := commitTrie(t, db, orig)
trie := indexTrie(t, edb, root)
sort.Sort(entries) sort.Sort(entries)
var cases = []struct { var cases = []struct {
@ -797,10 +763,10 @@ func TestHasRightElement(t *testing.T) {
// TestEmptyRangeProof tests the range proof with "no" element. // TestEmptyRangeProof tests the range proof with "no" element.
// The first edge proof must be a non-existent proof. // The first edge proof must be a non-existent proof.
func TestEmptyRangeProof(t *testing.T) { func TestEmptyRangeProof(t *testing.T) {
trie, vals := randomTrie(t, 8*scaleFactor) trie, vals := randomTrie(4096)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -827,14 +793,49 @@ func TestEmptyRangeProof(t *testing.T) {
} }
} }
// TestBloatedProof tests a malicious proof, where the proof is more or less the
// whole trie. Previously we didn't accept such packets, but the new APIs do, so
// lets leave this test as a bit weird, but present.
func TestBloatedProof(t *testing.T) {
// Use a small trie
trie, kvs := nonRandomTrie(100)
var entries entrySlice
for _, kv := range kvs {
entries = append(entries, kv)
}
sort.Sort(entries)
var keys [][]byte
var vals [][]byte
proof := memorydb.New()
// In the 'malicious' case, we add proofs for every single item
// (but only one key/value pair used as leaf)
for i, entry := range entries {
trie.Prove(entry.k, 0, proof)
if i == 50 {
keys = append(keys, entry.k)
vals = append(vals, entry.v)
}
}
// For reference, we use the same function, but _only_ prove the first
// and last element
want := memorydb.New()
trie.Prove(keys[0], 0, want)
trie.Prove(keys[len(keys)-1], 0, want)
if _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof); err != nil {
t.Fatalf("expected bloated proof to succeed, got %v", err)
}
}
// TestEmptyValueRangeProof tests normal range proof with both edge proofs // TestEmptyValueRangeProof tests normal range proof with both edge proofs
// as the existent proof, but with an extra empty value included, which is a // as the existent proof, but with an extra empty value included, which is a
// noop technically, but practically should be rejected. // noop technically, but practically should be rejected.
func TestEmptyValueRangeProof(t *testing.T) { func TestEmptyValueRangeProof(t *testing.T) {
trie, values := randomTrie(t, scaleFactor) trie, values := randomTrie(512)
var entries entrySlice var entries entrySlice
for _, kv := range values { for _, kv := range values {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -847,8 +848,8 @@ func TestEmptyValueRangeProof(t *testing.T) {
break break
} }
} }
noop := &entry{key, []byte{}} noop := &kv{key, []byte{}, false}
entries = append(append(append(entrySlice{}, entries[:mid]...), noop), entries[mid:]...) entries = append(append(append([]*kv{}, entries[:mid]...), noop), entries[mid:]...)
start, end := 1, len(entries)-1 start, end := 1, len(entries)-1
@ -875,10 +876,10 @@ func TestEmptyValueRangeProof(t *testing.T) {
// but with an extra empty value included, which is a noop technically, but // but with an extra empty value included, which is a noop technically, but
// practically should be rejected. // practically should be rejected.
func TestAllElementsEmptyValueRangeProof(t *testing.T) { func TestAllElementsEmptyValueRangeProof(t *testing.T) {
trie, values := randomTrie(t, scaleFactor) trie, values := randomTrie(512)
var entries entrySlice var entries entrySlice
for _, kv := range values { for _, kv := range values {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -891,8 +892,8 @@ func TestAllElementsEmptyValueRangeProof(t *testing.T) {
break break
} }
} }
noop := &entry{key, nil} noop := &kv{key, []byte{}, false}
entries = append(append(append(entrySlice{}, entries[:mid]...), noop), entries[mid:]...) entries = append(append(append([]*kv{}, entries[:mid]...), noop), entries[mid:]...)
var keys [][]byte var keys [][]byte
var vals [][]byte var vals [][]byte
@ -938,7 +939,7 @@ func decreaseKey(key []byte) []byte {
} }
func BenchmarkProve(b *testing.B) { func BenchmarkProve(b *testing.B) {
trie, vals := randomTrie(b, 100) trie, vals := randomTrie(100)
var keys []string var keys []string
for k := range vals { for k := range vals {
keys = append(keys, k) keys = append(keys, k)
@ -955,7 +956,7 @@ func BenchmarkProve(b *testing.B) {
} }
func BenchmarkVerifyProof(b *testing.B) { func BenchmarkVerifyProof(b *testing.B) {
trie, vals := randomTrie(b, 100) trie, vals := randomTrie(100)
root := trie.Hash() root := trie.Hash()
var keys []string var keys []string
var proofs []*memorydb.Database var proofs []*memorydb.Database
@ -981,10 +982,10 @@ func BenchmarkVerifyRangeProof1000(b *testing.B) { benchmarkVerifyRangeProof(b,
func BenchmarkVerifyRangeProof5000(b *testing.B) { benchmarkVerifyRangeProof(b, 5000) } func BenchmarkVerifyRangeProof5000(b *testing.B) { benchmarkVerifyRangeProof(b, 5000) }
func benchmarkVerifyRangeProof(b *testing.B, size int) { func benchmarkVerifyRangeProof(b *testing.B, size int) {
trie, vals := randomTrie(b, 8192) trie, vals := randomTrie(8192)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -1018,10 +1019,10 @@ func BenchmarkVerifyRangeNoProof500(b *testing.B) { benchmarkVerifyRangeNoProof
func BenchmarkVerifyRangeNoProof1000(b *testing.B) { benchmarkVerifyRangeNoProof(b, 1000) } func BenchmarkVerifyRangeNoProof1000(b *testing.B) { benchmarkVerifyRangeNoProof(b, 1000) }
func benchmarkVerifyRangeNoProof(b *testing.B, size int) { func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
trie, vals := randomTrie(b, size) trie, vals := randomTrie(size)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -1040,24 +1041,56 @@ func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
} }
} }
func randomTrie(n int) (*Trie, map[string]*kv) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
vals := make(map[string]*kv)
for i := byte(0); i < 100; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false}
trie.Update(value.k, value.v)
trie.Update(value2.k, value2.v)
vals[string(value.k)] = value
vals[string(value2.k)] = value2
}
for i := 0; i < n; i++ {
value := &kv{randBytes(32), randBytes(20), false}
trie.Update(value.k, value.v)
vals[string(value.k)] = value
}
return trie, vals
}
func nonRandomTrie(n int) (*Trie, map[string]*kv) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
vals := make(map[string]*kv)
max := uint64(0xffffffffffffffff)
for i := uint64(0); i < uint64(n); i++ {
value := make([]byte, 32)
key := make([]byte, 32)
binary.LittleEndian.PutUint64(key, i)
binary.LittleEndian.PutUint64(value, i-max)
//value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
elem := &kv{key, value, false}
trie.Update(elem.k, elem.v)
vals[string(elem.k)] = elem
}
return trie, vals
}
func TestRangeProofKeysWithSharedPrefix(t *testing.T) { func TestRangeProofKeysWithSharedPrefix(t *testing.T) {
keys := [][]byte{ keys := [][]byte{
common.Hex2Bytes("aa10000000000000000000000000000000000000000000000000000000000000"), common.Hex2Bytes("aa10000000000000000000000000000000000000000000000000000000000000"),
common.Hex2Bytes("aa20000000000000000000000000000000000000000000000000000000000000"), common.Hex2Bytes("aa20000000000000000000000000000000000000000000000000000000000000"),
} }
vals := [][]byte{ vals := [][]byte{
packValue(2), common.Hex2Bytes("02"),
packValue(3), common.Hex2Bytes("03"),
} }
edb := rawdb.NewMemoryDatabase() trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
db := geth_trie.NewDatabase(edb)
orig := geth_trie.NewEmpty(db)
for i, key := range keys { for i, key := range keys {
orig.Update(key, vals[i]) trie.Update(key, vals[i])
} }
root := commitTrie(t, db, orig) root := trie.Hash()
trie := indexTrie(t, edb, root)
proof := memorydb.New() proof := memorydb.New()
start := common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000") start := common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000")
end := common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") end := common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")

View File

@ -17,82 +17,88 @@
package trie package trie
import ( import (
"fmt"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/statediff/indexer/ipld" log "github.com/sirupsen/logrus"
) )
// StateTrie wraps a trie with key hashing. In a secure trie, all // StateTrie wraps a trie with key hashing. In a stateTrie trie, all
// access operations hash the key using keccak256. This prevents // access operations hash the key using keccak256. This prevents
// calling code from creating long chains of nodes that // calling code from creating long chains of nodes that
// increase the access time. // increase the access time.
// //
// Contrary to a regular trie, a StateTrie can only be created with // Contrary to a regular trie, a StateTrie can only be created with
// New and must have an attached database. // New and must have an attached database. The database also stores
// the preimage of each key if preimage recording is enabled.
// //
// StateTrie is not safe for concurrent use. // StateTrie is not safe for concurrent use.
type StateTrie struct { type StateTrie struct {
trie Trie trie Trie
hashKeyBuf [common.HashLength]byte preimages *preimageStore
hashKeyBuf [common.HashLength]byte
secKeyCache map[string][]byte
secKeyCacheOwner *StateTrie // Pointer to self, replace the key cache on mismatch
} }
// NewStateTrie creates a trie with an existing root node from a backing database // NewStateTrie creates a trie with an existing root node from a backing database.
// and optional intermediate in-memory node pool.
// //
// If root is the zero hash or the sha3 hash of an empty string, the // If root is the zero hash or the sha3 hash of an empty string, the
// trie is initially empty. Otherwise, New will panic if db is nil // trie is initially empty. Otherwise, New will panic if db is nil
// and returns MissingNodeError if the root node cannot be found. // and returns MissingNodeError if the root node cannot be found.
// func NewStateTrie(id *ID, db *Database, codec uint64) (*StateTrie, error) {
// Accessing the trie loads nodes from the database or node pool on demand.
// Loaded nodes are kept around until their 'cache generation' expires.
// A new cache generation is created by each call to Commit.
// cachelimit sets the number of past cache generations to keep.
//
// Retrieves IPLD blocks by CID encoded as "eth-state-trie"
func NewStateTrie(owner common.Hash, root common.Hash, db *Database) (*StateTrie, error) {
return newStateTrie(owner, root, db, ipld.MEthStateTrie)
}
// NewStorageTrie is identical to NewStateTrie, but retrieves IPLD blocks encoded
// as "eth-storage-trie"
func NewStorageTrie(owner common.Hash, root common.Hash, db *Database) (*StateTrie, error) {
return newStateTrie(owner, root, db, ipld.MEthStorageTrie)
}
func newStateTrie(owner common.Hash, root common.Hash, db *Database, codec uint64) (*StateTrie, error) {
if db == nil { if db == nil {
panic("NewStateTrie called without a database") panic("trie.NewStateTrie called without a database")
} }
trie, err := New(owner, root, db, codec) trie, err := New(id, db, codec)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &StateTrie{trie: *trie}, nil return &StateTrie{trie: *trie, preimages: db.preimages}, nil
}
// Get returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
func (t *StateTrie) Get(key []byte) []byte {
res, err := t.TryGet(key)
if err != nil {
log.Error("Unhandled trie error in StateTrie.Get", "err", err)
}
return res
} }
// TryGet returns the value for key stored in the trie. // TryGet returns the value for key stored in the trie.
// The value bytes must not be modified by the caller. // The value bytes must not be modified by the caller.
// If a node was not found in the database, a MissingNodeError is returned. // If the specified node is not in the trie, nil will be returned.
// If a trie node is not found in the database, a MissingNodeError is returned.
func (t *StateTrie) TryGet(key []byte) ([]byte, error) { func (t *StateTrie) TryGet(key []byte) ([]byte, error) {
return t.trie.TryGet(t.hashKey(key)) return t.trie.TryGet(t.hashKey(key))
} }
func (t *StateTrie) TryGetAccount(key []byte) (*types.StateAccount, error) { // TryGetAccount attempts to retrieve an account with provided account address.
var ret types.StateAccount // If the specified account is not in the trie, nil will be returned.
res, err := t.TryGet(key) // If a trie node is not found in the database, a MissingNodeError is returned.
if err != nil { func (t *StateTrie) TryGetAccount(address common.Address) (*types.StateAccount, error) {
// log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) res, err := t.trie.TryGet(t.hashKey(address.Bytes()))
panic(fmt.Sprintf("Unhandled trie error: %v", err)) if res == nil || err != nil {
return &ret, err return nil, err
} }
if res == nil { ret := new(types.StateAccount)
return nil, nil err = rlp.DecodeBytes(res, ret)
return ret, err
}
// TryGetAccountByHash does the same thing as TryGetAccount, however
// it expects an account hash that is the hash of address. This constitutes an
// abstraction leak, since the client code needs to know the key format.
func (t *StateTrie) TryGetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) {
res, err := t.trie.TryGet(addrHash.Bytes())
if res == nil || err != nil {
return nil, err
} }
err = rlp.DecodeBytes(res, &ret) ret := new(types.StateAccount)
return &ret, err err = rlp.DecodeBytes(res, ret)
return ret, err
} }
// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not // TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not
@ -103,12 +109,124 @@ func (t *StateTrie) TryGetNode(path []byte) ([]byte, int, error) {
return t.trie.TryGetNode(path) return t.trie.TryGetNode(path)
} }
// Update associates key with value in the trie. Subsequent calls to
// Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil.
//
// The value bytes must not be modified by the caller while they are
// stored in the trie.
func (t *StateTrie) Update(key, value []byte) {
if err := t.TryUpdate(key, value); err != nil {
log.Error("Unhandled trie error in StateTrie.Update", "err", err)
}
}
// TryUpdate associates key with value in the trie. Subsequent calls to
// Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil.
//
// The value bytes must not be modified by the caller while they are
// stored in the trie.
//
// If a node is not found in the database, a MissingNodeError is returned.
func (t *StateTrie) TryUpdate(key, value []byte) error {
hk := t.hashKey(key)
err := t.trie.TryUpdate(hk, value)
if err != nil {
return err
}
t.getSecKeyCache()[string(hk)] = common.CopyBytes(key)
return nil
}
// TryUpdateAccount account will abstract the write of an account to the
// secure trie.
func (t *StateTrie) TryUpdateAccount(address common.Address, acc *types.StateAccount) error {
hk := t.hashKey(address.Bytes())
data, err := rlp.EncodeToBytes(acc)
if err != nil {
return err
}
if err := t.trie.TryUpdate(hk, data); err != nil {
return err
}
t.getSecKeyCache()[string(hk)] = address.Bytes()
return nil
}
// Delete removes any existing value for key from the trie.
func (t *StateTrie) Delete(key []byte) {
if err := t.TryDelete(key); err != nil {
log.Error("Unhandled trie error in StateTrie.Delete", "err", err)
}
}
// TryDelete removes any existing value for key from the trie.
// If the specified trie node is not in the trie, nothing will be changed.
// If a node is not found in the database, a MissingNodeError is returned.
func (t *StateTrie) TryDelete(key []byte) error {
hk := t.hashKey(key)
delete(t.getSecKeyCache(), string(hk))
return t.trie.TryDelete(hk)
}
// TryDeleteAccount abstracts an account deletion from the trie.
func (t *StateTrie) TryDeleteAccount(address common.Address) error {
hk := t.hashKey(address.Bytes())
delete(t.getSecKeyCache(), string(hk))
return t.trie.TryDelete(hk)
}
// GetKey returns the sha3 preimage of a hashed key that was
// previously used to store a value.
func (t *StateTrie) GetKey(shaKey []byte) []byte {
if key, ok := t.getSecKeyCache()[string(shaKey)]; ok {
return key
}
if t.preimages == nil {
return nil
}
return t.preimages.preimage(common.BytesToHash(shaKey))
}
// Commit collects all dirty nodes in the trie and replaces them with the
// corresponding node hash. All collected nodes (including dirty leaves if
// collectLeaf is true) will be encapsulated into a nodeset for return.
// The returned nodeset can be nil if the trie is clean (nothing to commit).
// All cached preimages will be also flushed if preimages recording is enabled.
// Once the trie is committed, it's not usable anymore. A new trie must
// be created with new root and updated trie database for following usage
func (t *StateTrie) Commit(collectLeaf bool) (common.Hash, *NodeSet) {
// Write all the pre-images to the actual disk database
if len(t.getSecKeyCache()) > 0 {
if t.preimages != nil {
preimages := make(map[common.Hash][]byte)
for hk, key := range t.secKeyCache {
preimages[common.BytesToHash([]byte(hk))] = key
}
t.preimages.insertPreimage(preimages)
}
t.secKeyCache = make(map[string][]byte)
}
// Commit the trie and return its modified nodeset.
return t.trie.Commit(collectLeaf)
}
// Hash returns the root hash of StateTrie. It does not write to the // Hash returns the root hash of StateTrie. It does not write to the
// database and can be used even if the trie doesn't have one. // database and can be used even if the trie doesn't have one.
func (t *StateTrie) Hash() common.Hash { func (t *StateTrie) Hash() common.Hash {
return t.trie.Hash() return t.trie.Hash()
} }
// Copy returns a copy of StateTrie.
func (t *StateTrie) Copy() *StateTrie {
return &StateTrie{
trie: *t.trie.Copy(),
preimages: t.preimages,
secKeyCache: t.secKeyCache,
}
}
// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration // NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration
// starts at the key after the given start key. // starts at the key after the given start key.
func (t *StateTrie) NodeIterator(start []byte) NodeIterator { func (t *StateTrie) NodeIterator(start []byte) NodeIterator {
@ -126,3 +244,14 @@ func (t *StateTrie) hashKey(key []byte) []byte {
returnHasherToPool(h) returnHasherToPool(h)
return t.hashKeyBuf[:] return t.hashKeyBuf[:]
} }
// getSecKeyCache returns the current secure key cache, creating a new one if
// ownership changed (i.e. the current secure trie is a copy of another owning
// the actual cache).
func (t *StateTrie) getSecKeyCache() map[string][]byte {
if t != t.secKeyCacheOwner {
t.secKeyCacheOwner = t
t.secKeyCache = make(map[string][]byte)
}
return t.secKeyCache
}

View File

@ -0,0 +1,148 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"bytes"
"fmt"
"runtime"
"sync"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto"
)
// makeTestStateTrie creates a large enough secure trie for testing.
func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) {
// Create an empty trie
triedb := NewDatabase(rawdb.NewMemoryDatabase())
trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec)
// Fill it with some arbitrary data
content := make(map[string][]byte)
for i := byte(0); i < 255; i++ {
// Map the same data under multiple keys
key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i}
content[string(key)] = val
trie.Update(key, val)
key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i}
content[string(key)] = val
trie.Update(key, val)
// Add some other data to inflate the trie
for j := byte(3); j < 13; j++ {
key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i}
content[string(key)] = val
trie.Update(key, val)
}
}
root, nodes := trie.Commit(false)
if err := triedb.Update(NewWithNodeSet(nodes)); err != nil {
panic(fmt.Errorf("failed to commit db %v", err))
}
// Re-create the trie based on the new state
trie, _ = NewStateTrie(TrieID(root), triedb, StateTrieCodec)
return triedb, trie, content
}
func TestSecureDelete(t *testing.T) {
trie, err := NewStateTrie(TrieID(common.Hash{}), NewDatabase(rawdb.NewMemoryDatabase()), StateTrieCodec)
if err != nil {
t.Fatal(err)
}
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
{"horse", "stallion"},
{"shaman", "horse"},
{"doge", "coin"},
{"ether", ""},
{"dog", "puppy"},
{"shaman", ""},
}
for _, val := range vals {
if val.v != "" {
trie.Update([]byte(val.k), []byte(val.v))
} else {
trie.Delete([]byte(val.k))
}
}
hash := trie.Hash()
exp := common.HexToHash("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d")
if hash != exp {
t.Errorf("expected %x got %x", exp, hash)
}
}
func TestSecureGetKey(t *testing.T) {
trie, err := NewStateTrie(TrieID(common.Hash{}), NewDatabase(rawdb.NewMemoryDatabase()), StateTrieCodec)
if err != nil {
t.Fatal(err)
}
trie.Update([]byte("foo"), []byte("bar"))
key := []byte("foo")
value := []byte("bar")
seckey := crypto.Keccak256(key)
if !bytes.Equal(trie.Get(key), value) {
t.Errorf("Get did not return bar")
}
if k := trie.GetKey(seckey); !bytes.Equal(k, key) {
t.Errorf("GetKey returned %q, want %q", k, key)
}
}
func TestStateTrieConcurrency(t *testing.T) {
// Create an initial trie and copy if for concurrent access
_, trie, _ := makeTestStateTrie()
threads := runtime.NumCPU()
tries := make([]*StateTrie, threads)
for i := 0; i < threads; i++ {
tries[i] = trie.Copy()
}
// Start a batch of goroutines interacting with the trie
pend := new(sync.WaitGroup)
pend.Add(threads)
for i := 0; i < threads; i++ {
go func(index int) {
defer pend.Done()
for j := byte(0); j < 255; j++ {
// Map the same data under multiple keys
key, val := common.LeftPadBytes([]byte{byte(index), 1, j}, 32), []byte{j}
tries[index].Update(key, val)
key, val = common.LeftPadBytes([]byte{byte(index), 2, j}, 32), []byte{j}
tries[index].Update(key, val)
// Add some other data to inflate the trie
for k := byte(3); k < 13; k++ {
key, val = common.LeftPadBytes([]byte{byte(index), k, j}, 32), []byte{k, j}
tries[index].Update(key, val)
}
}
tries[index].Commit(false)
}(i)
}
// Wait for all threads to finish
pend.Wait()
}

125
trie_by_cid/trie/tracer.go Normal file
View File

@ -0,0 +1,125 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import "github.com/ethereum/go-ethereum/common"
// tracer tracks the changes of trie nodes. During the trie operations,
// some nodes can be deleted from the trie, while these deleted nodes
// won't be captured by trie.Hasher or trie.Committer. Thus, these deleted
// nodes won't be removed from the disk at all. Tracer is an auxiliary tool
// used to track all insert and delete operations of trie and capture all
// deleted nodes eventually.
//
// The changed nodes can be mainly divided into two categories: the leaf
// node and intermediate node. The former is inserted/deleted by callers
// while the latter is inserted/deleted in order to follow the rule of trie.
// This tool can track all of them no matter the node is embedded in its
// parent or not, but valueNode is never tracked.
//
// Besides, it's also used for recording the original value of the nodes
// when they are resolved from the disk. The pre-value of the nodes will
// be used to construct trie history in the future.
//
// Note tracer is not thread-safe, callers should be responsible for handling
// the concurrency issues by themselves.
type tracer struct {
inserts map[string]struct{}
deletes map[string]struct{}
accessList map[string][]byte
}
// newTracer initializes the tracer for capturing trie changes.
func newTracer() *tracer {
return &tracer{
inserts: make(map[string]struct{}),
deletes: make(map[string]struct{}),
accessList: make(map[string][]byte),
}
}
// onRead tracks the newly loaded trie node and caches the rlp-encoded
// blob internally. Don't change the value outside of function since
// it's not deep-copied.
func (t *tracer) onRead(path []byte, val []byte) {
t.accessList[string(path)] = val
}
// onInsert tracks the newly inserted trie node. If it's already
// in the deletion set (resurrected node), then just wipe it from
// the deletion set as it's "untouched".
func (t *tracer) onInsert(path []byte) {
if _, present := t.deletes[string(path)]; present {
delete(t.deletes, string(path))
return
}
t.inserts[string(path)] = struct{}{}
}
// onDelete tracks the newly deleted trie node. If it's already
// in the addition set, then just wipe it from the addition set
// as it's untouched.
func (t *tracer) onDelete(path []byte) {
if _, present := t.inserts[string(path)]; present {
delete(t.inserts, string(path))
return
}
t.deletes[string(path)] = struct{}{}
}
// reset clears the content tracked by tracer.
func (t *tracer) reset() {
t.inserts = make(map[string]struct{})
t.deletes = make(map[string]struct{})
t.accessList = make(map[string][]byte)
}
// copy returns a deep copied tracer instance.
func (t *tracer) copy() *tracer {
var (
inserts = make(map[string]struct{})
deletes = make(map[string]struct{})
accessList = make(map[string][]byte)
)
for path := range t.inserts {
inserts[path] = struct{}{}
}
for path := range t.deletes {
deletes[path] = struct{}{}
}
for path, blob := range t.accessList {
accessList[path] = common.CopyBytes(blob)
}
return &tracer{
inserts: inserts,
deletes: deletes,
accessList: accessList,
}
}
// markDeletions puts all tracked deletions into the provided nodeset.
func (t *tracer) markDeletions(set *NodeSet) {
for path := range t.deletes {
// It's possible a few deleted nodes were embedded
// in their parent before, the deletions can be no
// effect by deleting nothing, filter them out.
if _, ok := set.accessList[path]; !ok {
continue
}
set.markDeleted([]byte(path))
}
}

View File

@ -7,18 +7,15 @@ import (
"fmt" "fmt"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/core/types"
log "github.com/sirupsen/logrus"
util "github.com/cerc-io/ipld-eth-statedb/internal"
"github.com/ethereum/go-ethereum/statediff/indexer/ipld" "github.com/ethereum/go-ethereum/statediff/indexer/ipld"
) )
var ( var (
// emptyRoot is the known root hash of an empty trie. StateTrieCodec uint64 = ipld.MEthStateTrie
emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") StorageTrieCodec uint64 = ipld.MEthStorageTrie
// emptyState is the known hash of an empty state trie entry.
emptyState = crypto.Keccak256Hash(nil)
) )
// Trie is a Merkle Patricia Trie. Use New to create a trie that sits on // Trie is a Merkle Patricia Trie. Use New to create a trie that sits on
@ -37,29 +34,48 @@ type Trie struct {
// actually unhashed nodes. // actually unhashed nodes.
unhashed int unhashed int
// db is the handler trie can retrieve nodes from. It's // reader is the handler trie can retrieve nodes from.
// only for reading purpose and not available for writing. reader *trieReader
db *Database
// Multihash codec for key encoding // tracer is the tool to track the trie changes.
codec uint64 // It will be reset after each commit operation.
tracer *tracer
} }
// New creates a trie with an existing root node from db and an assigned // newFlag returns the cache flag value for a newly created node.
// owner for storage proximity. func (t *Trie) newFlag() nodeFlag {
// return nodeFlag{dirty: true}
// If root is the zero hash or the sha3 hash of an empty string, the }
// trie is initially empty and does not require a database. Otherwise,
// New will panic if db is nil and returns a MissingNodeError if root does // Copy returns a copy of Trie.
// not exist in the database. Accessing the trie loads nodes from db on demand. func (t *Trie) Copy() *Trie {
func New(owner common.Hash, root common.Hash, db *Database, codec uint64) (*Trie, error) { return &Trie{
trie := &Trie{ root: t.root,
owner: owner, owner: t.owner,
db: db, unhashed: t.unhashed,
codec: codec, reader: t.reader,
tracer: t.tracer.copy(),
} }
if root != (common.Hash{}) && root != emptyRoot { }
rootnode, err := trie.resolveHash(root[:], nil)
// New creates a trie instance with the provided trie id and the read-only
// database. The state specified by trie id must be available, otherwise
// an error will be returned. The trie root specified by trie id can be
// zero hash or the sha3 hash of an empty string, then trie is initially
// empty, otherwise, the root node must be present in database or returns
// a MissingNodeError if not.
func New(id *ID, db NodeReader, codec uint64) (*Trie, error) {
reader, err := newTrieReader(id.StateRoot, id.Owner, db, codec)
if err != nil {
return nil, err
}
trie := &Trie{
owner: id.Owner,
reader: reader,
tracer: newTracer(),
}
if id.Root != (common.Hash{}) && id.Root != types.EmptyRootHash {
rootnode, err := trie.resolveAndTrack(id.Root[:], nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -70,7 +86,10 @@ func New(owner common.Hash, root common.Hash, db *Database, codec uint64) (*Trie
// NewEmpty is a shortcut to create empty tree. It's mostly used in tests. // NewEmpty is a shortcut to create empty tree. It's mostly used in tests.
func NewEmpty(db *Database) *Trie { func NewEmpty(db *Database) *Trie {
tr, _ := New(common.Hash{}, common.Hash{}, db, ipld.MEthStateTrie) tr, err := New(TrieID(common.Hash{}), db, StateTrieCodec)
if err != nil {
panic(err)
}
return tr return tr
} }
@ -80,6 +99,16 @@ func (t *Trie) NodeIterator(start []byte) NodeIterator {
return newNodeIterator(t, start) return newNodeIterator(t, start)
} }
// Get returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
func (t *Trie) Get(key []byte) []byte {
res, err := t.TryGet(key)
if err != nil {
log.Error("Unhandled trie error in Trie.Get", "err", err)
}
return res
}
// TryGet returns the value for key stored in the trie. // TryGet returns the value for key stored in the trie.
// The value bytes must not be modified by the caller. // The value bytes must not be modified by the caller.
// If a node was not found in the database, a MissingNodeError is returned. // If a node was not found in the database, a MissingNodeError is returned.
@ -116,7 +145,7 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
} }
return value, n, didResolve, err return value, n, didResolve, err
case hashNode: case hashNode:
child, err := t.resolveHash(n, key[:pos]) child, err := t.resolveAndTrack(n, key[:pos])
if err != nil { if err != nil {
return nil, n, true, err return nil, n, true, err
} }
@ -130,7 +159,7 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not // TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not
// possible to use keybyte-encoding as the path might contain odd nibbles. // possible to use keybyte-encoding as the path might contain odd nibbles.
func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) { func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) {
item, newroot, resolved, err := t.tryGetNode(t.root, CompactToHex(path), 0) item, newroot, resolved, err := t.tryGetNode(t.root, compactToHex(path), 0)
if err != nil { if err != nil {
return nil, resolved, err return nil, resolved, err
} }
@ -162,11 +191,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
if hash == nil { if hash == nil {
return nil, origNode, 0, errors.New("non-consensus node") return nil, origNode, 0, errors.New("non-consensus node")
} }
cid, err := util.Keccak256ToCid(t.codec, hash) blob, err := t.reader.nodeBlob(path, common.BytesToHash(hash))
if err != nil {
return nil, origNode, 0, err
}
blob, err := t.db.Node(cid)
return blob, origNode, 1, err return blob, origNode, 1, err
} }
// Path still needs to be traversed, descend into children // Path still needs to be traversed, descend into children
@ -196,7 +221,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
return item, n, resolved, err return item, n, resolved, err
case hashNode: case hashNode:
child, err := t.resolveHash(n, path[:pos]) child, err := t.resolveAndTrack(n, path[:pos])
if err != nil { if err != nil {
return nil, n, 1, err return nil, n, 1, err
} }
@ -208,42 +233,309 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
} }
} }
// resolveHash loads node from the underlying database with the provided // Update associates key with value in the trie. Subsequent calls to
// node hash and path prefix. // Get will return value. If value has length zero, any existing value
func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { // is deleted from the trie and calls to Get will return nil.
cid, err := util.Keccak256ToCid(t.codec, n) //
if err != nil { // The value bytes must not be modified by the caller while they are
return nil, err // stored in the trie.
func (t *Trie) Update(key, value []byte) {
if err := t.TryUpdate(key, value); err != nil {
log.Error("Unhandled trie error in Trie.Update", "err", err)
} }
enc, err := t.db.Node(cid)
if err != nil {
return nil, &MissingNodeError{Owner: t.owner, NodeHash: n, Path: prefix, err: err}
}
node, err := decodeNodeUnsafe(n, enc)
if err != nil {
return nil, err
}
if node != nil {
return node, nil
}
return nil, &MissingNodeError{Owner: t.owner, NodeHash: n, Path: prefix}
} }
// resolveHash loads rlp-encoded node blob from the underlying database // TryUpdate associates key with value in the trie. Subsequent calls to
// with the provided node hash and path prefix. // Get will return value. If value has length zero, any existing value
func (t *Trie) resolveBlob(n hashNode, prefix []byte) ([]byte, error) { // is deleted from the trie and calls to Get will return nil.
cid, err := util.Keccak256ToCid(t.codec, n) //
// The value bytes must not be modified by the caller while they are
// stored in the trie.
//
// If a node was not found in the database, a MissingNodeError is returned.
func (t *Trie) TryUpdate(key, value []byte) error {
return t.tryUpdate(key, value)
}
// tryUpdate expects an RLP-encoded value and performs the core function
// for TryUpdate and TryUpdateAccount.
func (t *Trie) tryUpdate(key, value []byte) error {
t.unhashed++
k := keybytesToHex(key)
if len(value) != 0 {
_, n, err := t.insert(t.root, nil, k, valueNode(value))
if err != nil {
return err
}
t.root = n
} else {
_, n, err := t.delete(t.root, nil, k)
if err != nil {
return err
}
t.root = n
}
return nil
}
func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error) {
if len(key) == 0 {
if v, ok := n.(valueNode); ok {
return !bytes.Equal(v, value.(valueNode)), value, nil
}
return true, value, nil
}
switch n := n.(type) {
case *shortNode:
matchlen := prefixLen(key, n.Key)
// If the whole key matches, keep this short node as is
// and only update the value.
if matchlen == len(n.Key) {
dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value)
if !dirty || err != nil {
return false, n, err
}
return true, &shortNode{n.Key, nn, t.newFlag()}, nil
}
// Otherwise branch out at the index where they differ.
branch := &fullNode{flags: t.newFlag()}
var err error
_, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val)
if err != nil {
return false, nil, err
}
_, branch.Children[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value)
if err != nil {
return false, nil, err
}
// Replace this shortNode with the branch if it occurs at index 0.
if matchlen == 0 {
return true, branch, nil
}
// New branch node is created as a child of the original short node.
// Track the newly inserted node in the tracer. The node identifier
// passed is the path from the root node.
t.tracer.onInsert(append(prefix, key[:matchlen]...))
// Replace it with a short node leading up to the branch.
return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil
case *fullNode:
dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value)
if !dirty || err != nil {
return false, n, err
}
n = n.copy()
n.flags = t.newFlag()
n.Children[key[0]] = nn
return true, n, nil
case nil:
// New short node is created and track it in the tracer. The node identifier
// passed is the path from the root node. Note the valueNode won't be tracked
// since it's always embedded in its parent.
t.tracer.onInsert(prefix)
return true, &shortNode{key, value, t.newFlag()}, nil
case hashNode:
// We've hit a part of the trie that isn't loaded yet. Load
// the node and insert into it. This leaves all child nodes on
// the path to the value in the trie.
rn, err := t.resolveAndTrack(n, prefix)
if err != nil {
return false, nil, err
}
dirty, nn, err := t.insert(rn, prefix, key, value)
if !dirty || err != nil {
return false, rn, err
}
return true, nn, nil
default:
panic(fmt.Sprintf("%T: invalid node: %v", n, n))
}
}
// Delete removes any existing value for key from the trie.
func (t *Trie) Delete(key []byte) {
if err := t.TryDelete(key); err != nil {
log.Error("Unhandled trie error in Trie.Delete", "err", err)
}
}
// TryDelete removes any existing value for key from the trie.
// If a node was not found in the database, a MissingNodeError is returned.
func (t *Trie) TryDelete(key []byte) error {
t.unhashed++
k := keybytesToHex(key)
_, n, err := t.delete(t.root, nil, k)
if err != nil {
return err
}
t.root = n
return nil
}
// delete returns the new root of the trie with key deleted.
// It reduces the trie to minimal form by simplifying
// nodes on the way up after deleting recursively.
func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
switch n := n.(type) {
case *shortNode:
matchlen := prefixLen(key, n.Key)
if matchlen < len(n.Key) {
return false, n, nil // don't replace n on mismatch
}
if matchlen == len(key) {
// The matched short node is deleted entirely and track
// it in the deletion set. The same the valueNode doesn't
// need to be tracked at all since it's always embedded.
t.tracer.onDelete(prefix)
return true, nil, nil // remove n entirely for whole matches
}
// The key is longer than n.Key. Remove the remaining suffix
// from the subtrie. Child can never be nil here since the
// subtrie must contain at least two other values with keys
// longer than n.Key.
dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):])
if !dirty || err != nil {
return false, n, err
}
switch child := child.(type) {
case *shortNode:
// The child shortNode is merged into its parent, track
// is deleted as well.
t.tracer.onDelete(append(prefix, n.Key...))
// Deleting from the subtrie reduced it to another
// short node. Merge the nodes to avoid creating a
// shortNode{..., shortNode{...}}. Use concat (which
// always creates a new slice) instead of append to
// avoid modifying n.Key since it might be shared with
// other nodes.
return true, &shortNode{concat(n.Key, child.Key...), child.Val, t.newFlag()}, nil
default:
return true, &shortNode{n.Key, child, t.newFlag()}, nil
}
case *fullNode:
dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:])
if !dirty || err != nil {
return false, n, err
}
n = n.copy()
n.flags = t.newFlag()
n.Children[key[0]] = nn
// Because n is a full node, it must've contained at least two children
// before the delete operation. If the new child value is non-nil, n still
// has at least two children after the deletion, and cannot be reduced to
// a short node.
if nn != nil {
return true, n, nil
}
// Reduction:
// Check how many non-nil entries are left after deleting and
// reduce the full node to a short node if only one entry is
// left. Since n must've contained at least two children
// before deletion (otherwise it would not be a full node) n
// can never be reduced to nil.
//
// When the loop is done, pos contains the index of the single
// value that is left in n or -2 if n contains at least two
// values.
pos := -1
for i, cld := range &n.Children {
if cld != nil {
if pos == -1 {
pos = i
} else {
pos = -2
break
}
}
}
if pos >= 0 {
if pos != 16 {
// If the remaining entry is a short node, it replaces
// n and its key gets the missing nibble tacked to the
// front. This avoids creating an invalid
// shortNode{..., shortNode{...}}. Since the entry
// might not be loaded yet, resolve it just for this
// check.
cnode, err := t.resolve(n.Children[pos], append(prefix, byte(pos)))
if err != nil {
return false, nil, err
}
if cnode, ok := cnode.(*shortNode); ok {
// Replace the entire full node with the short node.
// Mark the original short node as deleted since the
// value is embedded into the parent now.
t.tracer.onDelete(append(prefix, byte(pos)))
k := append([]byte{byte(pos)}, cnode.Key...)
return true, &shortNode{k, cnode.Val, t.newFlag()}, nil
}
}
// Otherwise, n is replaced by a one-nibble short node
// containing the child.
return true, &shortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil
}
// n still contains at least two values and cannot be reduced.
return true, n, nil
case valueNode:
return true, nil, nil
case nil:
return false, nil, nil
case hashNode:
// We've hit a part of the trie that isn't loaded yet. Load
// the node and delete from it. This leaves all child nodes on
// the path to the value in the trie.
rn, err := t.resolveAndTrack(n, prefix)
if err != nil {
return false, nil, err
}
dirty, nn, err := t.delete(rn, prefix, key)
if !dirty || err != nil {
return false, rn, err
}
return true, nn, nil
default:
panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key))
}
}
func concat(s1 []byte, s2 ...byte) []byte {
r := make([]byte, len(s1)+len(s2))
copy(r, s1)
copy(r[len(s1):], s2)
return r
}
func (t *Trie) resolve(n node, prefix []byte) (node, error) {
if n, ok := n.(hashNode); ok {
return t.resolveAndTrack(n, prefix)
}
return n, nil
}
// resolveAndTrack loads node from the underlying store with the given node hash
// and path prefix and also tracks the loaded node blob in tracer treated as the
// node's original value. The rlp-encoded blob is preferred to be loaded from
// database because it's easy to decode node while complex to encode node to blob.
func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) {
blob, err := t.reader.nodeBlob(prefix, common.BytesToHash(n))
if err != nil { if err != nil {
return nil, err return nil, err
} }
blob, err := t.db.Node(cid) t.tracer.onRead(prefix, blob)
if err != nil { return mustDecodeNode(n, blob), nil
return nil, err
}
if len(blob) != 0 {
return blob, nil
}
return nil, &MissingNodeError{Owner: t.owner, NodeHash: n, Path: prefix}
} }
// Hash returns the root hash of the trie. It does not write to the // Hash returns the root hash of the trie. It does not write to the
@ -254,15 +546,59 @@ func (t *Trie) Hash() common.Hash {
return common.BytesToHash(hash.(hashNode)) return common.BytesToHash(hash.(hashNode))
} }
// Commit collects all dirty nodes in the trie and replaces them with the
// corresponding node hash. All collected nodes (including dirty leaves if
// collectLeaf is true) will be encapsulated into a nodeset for return.
// The returned nodeset can be nil if the trie is clean (nothing to commit).
// Once the trie is committed, it's not usable anymore. A new trie must
// be created with new root and updated trie database for following usage
func (t *Trie) Commit(collectLeaf bool) (common.Hash, *NodeSet) {
defer t.tracer.reset()
nodes := NewNodeSet(t.owner, t.tracer.accessList)
t.tracer.markDeletions(nodes)
// Trie is empty and can be classified into two types of situations:
// - The trie was empty and no update happens
// - The trie was non-empty and all nodes are dropped
if t.root == nil {
return types.EmptyRootHash, nodes
}
// Derive the hash for all dirty nodes first. We hold the assumption
// in the following procedure that all nodes are hashed.
rootHash := t.Hash()
// Do a quick check if we really need to commit. This can happen e.g.
// if we load a trie for reading storage values, but don't write to it.
if hashedNode, dirty := t.root.cache(); !dirty {
// Replace the root node with the origin hash in order to
// ensure all resolved nodes are dropped after the commit.
t.root = hashedNode
return rootHash, nil
}
t.root = newCommitter(nodes, collectLeaf).Commit(t.root)
return rootHash, nodes
}
// hashRoot calculates the root hash of the given trie // hashRoot calculates the root hash of the given trie
func (t *Trie) hashRoot() (node, node) { func (t *Trie) hashRoot() (node, node) {
if t.root == nil { if t.root == nil {
return hashNode(emptyRoot.Bytes()), nil return hashNode(types.EmptyRootHash.Bytes()), nil
} }
// If the number of changes is below 100, we let one thread handle it // If the number of changes is below 100, we let one thread handle it
h := newHasher(t.unhashed >= 100) h := newHasher(t.unhashed >= 100)
defer returnHasherToPool(h) defer func() {
returnHasherToPool(h)
t.unhashed = 0
}()
hashed, cached := h.hash(t.root, true) hashed, cached := h.hash(t.root, true)
t.unhashed = 0
return hashed, cached return hashed, cached
} }
// Reset drops the referenced root node and cleans all internal state.
func (t *Trie) Reset() {
t.root = nil
t.owner = common.Hash{}
t.unhashed = 0
t.tracer.reset()
}

View File

@ -0,0 +1,55 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>
package trie
import "github.com/ethereum/go-ethereum/common"
// ID is the identifier for uniquely identifying a trie.
type ID struct {
StateRoot common.Hash // The root of the corresponding state(block.root)
Owner common.Hash // The contract address hash which the trie belongs to
Root common.Hash // The root hash of trie
}
// StateTrieID constructs an identifier for state trie with the provided state root.
func StateTrieID(root common.Hash) *ID {
return &ID{
StateRoot: root,
Owner: common.Hash{},
Root: root,
}
}
// StorageTrieID constructs an identifier for storage trie which belongs to a certain
// state and contract specified by the stateRoot and owner.
func StorageTrieID(stateRoot common.Hash, owner common.Hash, root common.Hash) *ID {
return &ID{
StateRoot: stateRoot,
Owner: owner,
Root: root,
}
}
// TrieID constructs an identifier for a standard trie(not a second-layer trie)
// with provided root. It's mostly used in tests and some other tries like CHT trie.
func TrieID(root common.Hash) *ID {
return &ID{
StateRoot: root,
Owner: common.Hash{},
Root: root,
}
}

View File

@ -0,0 +1,106 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"fmt"
"github.com/ethereum/go-ethereum/common"
)
// Reader wraps the Node and NodeBlob method of a backing trie store.
type Reader interface {
// Node retrieves the trie node with the provided trie identifier, hexary
// node path and the corresponding node hash.
// No error will be returned if the node is not found.
Node(owner common.Hash, path []byte, hash common.Hash) (node, error)
// NodeBlob retrieves the RLP-encoded trie node blob with the provided trie
// identifier, hexary node path and the corresponding node hash.
// No error will be returned if the node is not found.
NodeBlob(owner common.Hash, path []byte, hash common.Hash) ([]byte, error)
}
// NodeReader wraps all the necessary functions for accessing trie node.
type NodeReader interface {
// GetReader returns a reader for accessing all trie nodes with provided
// state root. Nil is returned in case the state is not available.
GetReader(root common.Hash, codec uint64) Reader
}
// trieReader is a wrapper of the underlying node reader. It's not safe
// for concurrent usage.
type trieReader struct {
owner common.Hash
reader Reader
banned map[string]struct{} // Marker to prevent node from being accessed, for tests
}
// newTrieReader initializes the trie reader with the given node reader.
func newTrieReader(stateRoot, owner common.Hash, db NodeReader, codec uint64) (*trieReader, error) {
reader := db.GetReader(stateRoot, codec)
if reader == nil {
return nil, fmt.Errorf("state not found #%x", stateRoot)
}
return &trieReader{owner: owner, reader: reader}, nil
}
// newEmptyReader initializes the pure in-memory reader. All read operations
// should be forbidden and returns the MissingNodeError.
func newEmptyReader() *trieReader {
return &trieReader{}
}
// node retrieves the trie node with the provided trie node information.
// An MissingNodeError will be returned in case the node is not found or
// any error is encountered.
func (r *trieReader) node(path []byte, hash common.Hash) (node, error) {
// Perform the logics in tests for preventing trie node access.
if r.banned != nil {
if _, ok := r.banned[string(path)]; ok {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
}
}
if r.reader == nil {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
}
node, err := r.reader.Node(r.owner, path, hash)
if err != nil || node == nil {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err}
}
return node, nil
}
// node retrieves the rlp-encoded trie node with the provided trie node
// information. An MissingNodeError will be returned in case the node is
// not found or any error is encountered.
func (r *trieReader) nodeBlob(path []byte, hash common.Hash) ([]byte, error) {
// Perform the logics in tests for preventing trie node access.
if r.banned != nil {
if _, ok := r.banned[string(path)]; ok {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
}
}
if r.reader == nil {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
}
blob, err := r.reader.NodeBlob(r.owner, path, hash)
if err != nil || len(blob) == 0 {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err}
}
return blob, nil
}

View File

@ -14,26 +14,34 @@
// You should have received a copy of the GNU Lesser General Public License // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie_test package trie
import ( import (
"bytes"
"encoding/binary"
"errors"
"fmt" "fmt"
"math/big" "math/big"
"math/rand" "math/rand"
"reflect"
"testing" "testing"
"testing/quick"
"github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
geth_trie "github.com/ethereum/go-ethereum/trie"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
) )
func TestTrieEmpty(t *testing.T) { func init() {
trie := trie.NewEmpty(trie.NewDatabase(rawdb.NewMemoryDatabase())) spew.Config.Indent = " "
spew.Config.DisableMethods = false
}
func TestEmptyTrie(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
res := trie.Hash() res := trie.Hash()
exp := types.EmptyRootHash exp := types.EmptyRootHash
if res != exp { if res != exp {
@ -41,42 +49,543 @@ func TestTrieEmpty(t *testing.T) {
} }
} }
func TestTrieMissingRoot(t *testing.T) { func TestNull(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
key := make([]byte, 32)
value := []byte("test")
trie.Update(key, value)
if !bytes.Equal(trie.Get(key), value) {
t.Fatal("wrong value")
}
}
func TestMissingRoot(t *testing.T) {
root := common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33") root := common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33")
tr, err := newStateTrie(root, trie.NewDatabase(rawdb.NewMemoryDatabase())) trie, err := NewAccountTrie(TrieID(root), NewDatabase(rawdb.NewMemoryDatabase()))
if tr != nil { if trie != nil {
t.Error("New returned non-nil trie for invalid root") t.Error("New returned non-nil trie for invalid root")
} }
if _, ok := err.(*trie.MissingNodeError); !ok { if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("New returned wrong error: %v", err) t.Errorf("New returned wrong error: %v", err)
} }
} }
func TestTrieBasic(t *testing.T) { func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) }
edb := rawdb.NewMemoryDatabase()
db := geth_trie.NewDatabase(edb) func testMissingNode(t *testing.T, memonly bool) {
origtrie := geth_trie.NewEmpty(db) diskdb := rawdb.NewMemoryDatabase()
origtrie.Update([]byte("foo"), packValue(842)) triedb := NewDatabase(diskdb)
expected := commitTrie(t, db, origtrie)
tr := indexTrie(t, edb, expected) trie := NewEmpty(triedb)
got := tr.Hash() updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer")
if expected != got { updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf")
t.Errorf("got %x expected %x", got, expected) root, nodes := trie.Commit(false)
triedb.Update(NewWithNodeSet(nodes))
trie, _ = NewAccountTrie(TrieID(root), triedb)
_, err := trie.TryGet([]byte("120000"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
trie, _ = NewAccountTrie(TrieID(root), triedb)
_, err = trie.TryGet([]byte("120099"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
trie, _ = NewAccountTrie(TrieID(root), triedb)
_, err = trie.TryGet([]byte("123456"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
trie, _ = NewAccountTrie(TrieID(root), triedb)
err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
trie, _ = NewAccountTrie(TrieID(root), triedb)
err = trie.TryDelete([]byte("123456"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
hash := common.HexToHash("0xe1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9")
if memonly {
delete(triedb.dirties, hash)
} else {
diskdb.Delete(hash[:])
}
trie, _ = NewAccountTrie(TrieID(root), triedb)
_, err = trie.TryGet([]byte("120000"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
trie, _ = NewAccountTrie(TrieID(root), triedb)
_, err = trie.TryGet([]byte("120099"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
trie, _ = NewAccountTrie(TrieID(root), triedb)
_, err = trie.TryGet([]byte("123456"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
trie, _ = NewAccountTrie(TrieID(root), triedb)
err = trie.TryUpdate([]byte("120099"), []byte("zxcv"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
trie, _ = NewAccountTrie(TrieID(root), triedb)
err = trie.TryDelete([]byte("123456"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
} }
checkValue(t, tr, []byte("foo"))
} }
func TestTrieTiny(t *testing.T) { func TestInsert(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
updateString(trie, "doe", "reindeer")
updateString(trie, "dog", "puppy")
updateString(trie, "dogglesworth", "cat")
exp := common.HexToHash("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3")
root := trie.Hash()
if root != exp {
t.Errorf("case 1: exp %x got %x", exp, root)
}
trie = NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
updateString(trie, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
exp = common.HexToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab")
root, _ = trie.Commit(false)
if root != exp {
t.Errorf("case 2: exp %x got %x", exp, root)
}
}
func TestGet(t *testing.T) {
db := NewDatabase(rawdb.NewMemoryDatabase())
trie := NewEmpty(db)
updateString(trie, "doe", "reindeer")
updateString(trie, "dog", "puppy")
updateString(trie, "dogglesworth", "cat")
for i := 0; i < 2; i++ {
res := getString(trie, "dog")
if !bytes.Equal(res, []byte("puppy")) {
t.Errorf("expected puppy got %x", res)
}
unknown := getString(trie, "unknown")
if unknown != nil {
t.Errorf("expected nil got %x", unknown)
}
if i == 1 {
return
}
root, nodes := trie.Commit(false)
db.Update(NewWithNodeSet(nodes))
trie, _ = NewAccountTrie(TrieID(root), db)
}
}
func TestDelete(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
{"horse", "stallion"},
{"shaman", "horse"},
{"doge", "coin"},
{"ether", ""},
{"dog", "puppy"},
{"shaman", ""},
}
for _, val := range vals {
if val.v != "" {
updateString(trie, val.k, val.v)
} else {
deleteString(trie, val.k)
}
}
hash := trie.Hash()
exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
if hash != exp {
t.Errorf("expected %x got %x", exp, hash)
}
}
func TestEmptyValues(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
{"horse", "stallion"},
{"shaman", "horse"},
{"doge", "coin"},
{"ether", ""},
{"dog", "puppy"},
{"shaman", ""},
}
for _, val := range vals {
updateString(trie, val.k, val.v)
}
hash := trie.Hash()
exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
if hash != exp {
t.Errorf("expected %x got %x", exp, hash)
}
}
func TestReplication(t *testing.T) {
triedb := NewDatabase(rawdb.NewMemoryDatabase())
trie := NewEmpty(triedb)
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
{"horse", "stallion"},
{"shaman", "horse"},
{"doge", "coin"},
{"dog", "puppy"},
{"somethingveryoddindeedthis is", "myothernodedata"},
}
for _, val := range vals {
updateString(trie, val.k, val.v)
}
exp, nodes := trie.Commit(false)
triedb.Update(NewWithNodeSet(nodes))
// create a new trie on top of the database and check that lookups work.
trie2, err := NewAccountTrie(TrieID(exp), triedb)
if err != nil {
t.Fatalf("can't recreate trie at %x: %v", exp, err)
}
for _, kv := range vals {
if string(getString(trie2, kv.k)) != kv.v {
t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v)
}
}
hash, nodes := trie2.Commit(false)
if hash != exp {
t.Errorf("root failure. expected %x got %x", exp, hash)
}
// recreate the trie after commit
if nodes != nil {
triedb.Update(NewWithNodeSet(nodes))
}
trie2, err = NewAccountTrie(TrieID(hash), triedb)
if err != nil {
t.Fatalf("can't recreate trie at %x: %v", exp, err)
}
// perform some insertions on the new trie.
vals2 := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
{"horse", "stallion"},
// {"shaman", "horse"},
// {"doge", "coin"},
// {"ether", ""},
// {"dog", "puppy"},
// {"somethingveryoddindeedthis is", "myothernodedata"},
// {"shaman", ""},
}
for _, val := range vals2 {
updateString(trie2, val.k, val.v)
}
if hash := trie2.Hash(); hash != exp {
t.Errorf("root failure. expected %x got %x", exp, hash)
}
}
func TestLargeValue(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
trie.Update([]byte("key1"), []byte{99, 99, 99, 99})
trie.Update([]byte("key2"), bytes.Repeat([]byte{1}, 32))
trie.Hash()
}
// TestRandomCases tests some cases that were found via random fuzzing
func TestRandomCases(t *testing.T) {
var rt = []randTestStep{
{op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 0
{op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 1
{op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000002")}, // step 2
{op: 2, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("")}, // step 3
{op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 4
{op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 5
{op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 6
{op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 7
{op: 0, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("0000000000000008")}, // step 8
{op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000009")}, // step 9
{op: 2, key: common.Hex2Bytes("fd"), value: common.Hex2Bytes("")}, // step 10
{op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 11
{op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 12
{op: 0, key: common.Hex2Bytes("fd"), value: common.Hex2Bytes("000000000000000d")}, // step 13
{op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 14
{op: 1, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("")}, // step 15
{op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 16
{op: 0, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("0000000000000011")}, // step 17
{op: 5, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 18
{op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 19
{op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000014")}, // step 20
{op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000015")}, // step 21
{op: 0, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("0000000000000016")}, // step 22
{op: 5, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 23
{op: 1, key: common.Hex2Bytes("980c393656413a15c8da01978ed9f89feb80b502f58f2d640e3a2f5f7a99a7018f1b573befd92053ac6f78fca4a87268"), value: common.Hex2Bytes("")}, // step 24
{op: 1, key: common.Hex2Bytes("fd"), value: common.Hex2Bytes("")}, // step 25
}
runRandTest(rt)
}
// randTest performs random trie operations.
// Instances of this test are created by Generate.
type randTest []randTestStep
type randTestStep struct {
op int
key []byte // for opUpdate, opDelete, opGet
value []byte // for opUpdate
err error // for debugging
}
const (
opUpdate = iota
opDelete
opGet
opHash
opCommit
opItercheckhash
opNodeDiff
opProve
opMax // boundary value, not an actual op
)
func (randTest) Generate(r *rand.Rand, size int) reflect.Value {
var allKeys [][]byte
genKey := func() []byte {
if len(allKeys) < 2 || r.Intn(100) < 10 {
// new key
key := make([]byte, r.Intn(50))
r.Read(key)
allKeys = append(allKeys, key)
return key
}
// use existing key
return allKeys[r.Intn(len(allKeys))]
}
var steps randTest
for i := 0; i < size; i++ {
step := randTestStep{op: r.Intn(opMax)}
switch step.op {
case opUpdate:
step.key = genKey()
step.value = make([]byte, 8)
binary.BigEndian.PutUint64(step.value, uint64(i))
case opGet, opDelete, opProve:
step.key = genKey()
}
steps = append(steps, step)
}
return reflect.ValueOf(steps)
}
func verifyAccessList(old *Trie, new *Trie, set *NodeSet) error {
deletes, inserts, updates := diffTries(old, new)
// Check insertion set
for path := range inserts {
n, ok := set.nodes[path]
if !ok || n.isDeleted() {
return errors.New("expect new node")
}
_, ok = set.accessList[path]
if ok {
return errors.New("unexpected origin value")
}
}
// Check deletion set
for path, blob := range deletes {
n, ok := set.nodes[path]
if !ok || !n.isDeleted() {
return errors.New("expect deleted node")
}
v, ok := set.accessList[path]
if !ok {
return errors.New("expect origin value")
}
if !bytes.Equal(v, blob) {
return errors.New("invalid origin value")
}
}
// Check update set
for path, blob := range updates {
n, ok := set.nodes[path]
if !ok || n.isDeleted() {
return errors.New("expect updated node")
}
v, ok := set.accessList[path]
if !ok {
return errors.New("expect origin value")
}
if !bytes.Equal(v, blob) {
return errors.New("invalid origin value")
}
}
return nil
}
func runRandTest(rt randTest) bool {
var (
triedb = NewDatabase(rawdb.NewMemoryDatabase())
tr = NewEmpty(triedb)
values = make(map[string]string) // tracks content of the trie
origTrie = NewEmpty(triedb)
)
for i, step := range rt {
// fmt.Printf("{op: %d, key: common.Hex2Bytes(\"%x\"), value: common.Hex2Bytes(\"%x\")}, // step %d\n",
// step.op, step.key, step.value, i)
switch step.op {
case opUpdate:
tr.Update(step.key, step.value)
values[string(step.key)] = string(step.value)
case opDelete:
tr.Delete(step.key)
delete(values, string(step.key))
case opGet:
v := tr.Get(step.key)
want := values[string(step.key)]
if string(v) != want {
rt[i].err = fmt.Errorf("mismatch for key %#x, got %#x want %#x", step.key, v, want)
}
case opProve:
hash := tr.Hash()
if hash == types.EmptyRootHash {
continue
}
proofDb := rawdb.NewMemoryDatabase()
err := tr.Prove(step.key, 0, proofDb)
if err != nil {
rt[i].err = fmt.Errorf("failed for proving key %#x, %v", step.key, err)
}
_, err = VerifyProof(hash, step.key, proofDb)
if err != nil {
rt[i].err = fmt.Errorf("failed for verifying key %#x, %v", step.key, err)
}
case opHash:
tr.Hash()
case opCommit:
root, nodes := tr.Commit(true)
if nodes != nil {
triedb.Update(NewWithNodeSet(nodes))
}
newtr, err := NewAccountTrie(TrieID(root), triedb)
if err != nil {
rt[i].err = err
return false
}
if nodes != nil {
if err := verifyAccessList(origTrie, newtr, nodes); err != nil {
rt[i].err = err
return false
}
}
tr = newtr
origTrie = tr.Copy()
case opItercheckhash:
checktr := NewEmpty(triedb)
it := NewIterator(tr.NodeIterator(nil))
for it.Next() {
checktr.Update(it.Key, it.Value)
}
if tr.Hash() != checktr.Hash() {
rt[i].err = fmt.Errorf("hash mismatch in opItercheckhash")
}
case opNodeDiff:
var (
origIter = origTrie.NodeIterator(nil)
curIter = tr.NodeIterator(nil)
origSeen = make(map[string]struct{})
curSeen = make(map[string]struct{})
)
for origIter.Next(true) {
if origIter.Leaf() {
continue
}
origSeen[string(origIter.Path())] = struct{}{}
}
for curIter.Next(true) {
if curIter.Leaf() {
continue
}
curSeen[string(curIter.Path())] = struct{}{}
}
var (
insertExp = make(map[string]struct{})
deleteExp = make(map[string]struct{})
)
for path := range curSeen {
_, present := origSeen[path]
if !present {
insertExp[path] = struct{}{}
}
}
for path := range origSeen {
_, present := curSeen[path]
if !present {
deleteExp[path] = struct{}{}
}
}
if len(insertExp) != len(tr.tracer.inserts) {
rt[i].err = fmt.Errorf("insert set mismatch")
}
if len(deleteExp) != len(tr.tracer.deletes) {
rt[i].err = fmt.Errorf("delete set mismatch")
}
for insert := range tr.tracer.inserts {
if _, present := insertExp[insert]; !present {
rt[i].err = fmt.Errorf("missing inserted node")
}
}
for del := range tr.tracer.deletes {
if _, present := deleteExp[del]; !present {
rt[i].err = fmt.Errorf("missing deleted node")
}
}
}
// Abort the test on error.
if rt[i].err != nil {
return false
}
}
return true
}
func TestRandom(t *testing.T) {
if err := quick.Check(runRandTest, nil); err != nil {
if cerr, ok := err.(*quick.CheckError); ok {
t.Fatalf("random test iteration %d failed: %s", cerr.Count, spew.Sdump(cerr.In))
}
t.Fatal(err)
}
}
func TestTinyTrie(t *testing.T) {
// Create a realistic account trie to hash // Create a realistic account trie to hash
_, accounts := makeAccounts(5) _, accounts := makeAccounts(5)
edb := rawdb.NewMemoryDatabase() trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
db := geth_trie.NewDatabase(edb)
origtrie := geth_trie.NewEmpty(db)
type testCase struct { type testCase struct {
key, account []byte key, account []byte
root common.Hash root common.Hash
} }
cases := []testCase{ cases := []testCase{
{ {
common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001337"), common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001337"),
@ -92,48 +601,48 @@ func TestTrieTiny(t *testing.T) {
common.HexToHash("0608c1d1dc3905fa22204c7a0e43644831c3b6d3def0f274be623a948197e64a"), common.HexToHash("0608c1d1dc3905fa22204c7a0e43644831c3b6d3def0f274be623a948197e64a"),
}, },
} }
for i, tc := range cases { for i, c := range cases {
t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) { trie.Update(c.key, c.account)
origtrie.Update(tc.key, tc.account) root := trie.Hash()
trie := indexTrie(t, edb, commitTrie(t, db, origtrie)) if root != c.root {
if exp, root := tc.root, trie.Hash(); exp != root { t.Errorf("case %d: got %x, exp %x", i, root, c.root)
t.Errorf("got %x, exp %x", root, exp) }
} }
checkValue(t, trie, tc.key) checktr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
}) it := NewIterator(trie.NodeIterator(nil))
for it.Next() {
checktr.Update(it.Key, it.Value)
}
if troot, itroot := trie.Hash(), checktr.Hash(); troot != itroot {
t.Fatalf("hash mismatch in opItercheckhash, trie: %x, check: %x", troot, itroot)
} }
} }
func TestTrieMedium(t *testing.T) { func TestCommitAfterHash(t *testing.T) {
// Create a realistic account trie to hash // Create a realistic account trie to hash
addresses, accounts := makeAccounts(1000) addresses, accounts := makeAccounts(1000)
edb := rawdb.NewMemoryDatabase() trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
db := geth_trie.NewDatabase(edb)
origtrie := geth_trie.NewEmpty(db)
var keys [][]byte
for i := 0; i < len(addresses); i++ { for i := 0; i < len(addresses); i++ {
key := crypto.Keccak256(addresses[i][:]) trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i])
if i%50 == 0 {
keys = append(keys, key)
}
origtrie.Update(key, accounts[i])
} }
tr := indexTrie(t, edb, commitTrie(t, db, origtrie)) // Insert the accounts into the trie and hash it
trie.Hash()
root := tr.Hash() trie.Commit(false)
root := trie.Hash()
exp := common.HexToHash("72f9d3f3fe1e1dd7b8936442e7642aef76371472d94319900790053c493f3fe6") exp := common.HexToHash("72f9d3f3fe1e1dd7b8936442e7642aef76371472d94319900790053c493f3fe6")
if exp != root { if exp != root {
t.Errorf("got %x, exp %x", root, exp) t.Errorf("got %x, exp %x", root, exp)
} }
root, _ = trie.Commit(false)
for _, key := range keys { if exp != root {
checkValue(t, tr, key) t.Errorf("got %x, exp %x", root, exp)
} }
} }
// Make deterministically random accounts
func makeAccounts(size int) (addresses [][20]byte, accounts [][]byte) { func makeAccounts(size int) (addresses [][20]byte, accounts [][]byte) {
// Make the random benchmark deterministic
random := rand.New(rand.NewSource(0)) random := rand.New(rand.NewSource(0))
// Create a realistic account trie to hash
addresses = make([][20]byte, size) addresses = make([][20]byte, size)
for i := 0; i < len(addresses); i++ { for i := 0; i < len(addresses); i++ {
data := make([]byte, 20) data := make([]byte, 20)
@ -149,25 +658,40 @@ func makeAccounts(size int) (addresses [][20]byte, accounts [][]byte) {
) )
// The big.Rand function is not deterministic with regards to 64 vs 32 bit systems, // The big.Rand function is not deterministic with regards to 64 vs 32 bit systems,
// and will consume different amount of data from the rand source. // and will consume different amount of data from the rand source.
// balance = new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) //balance = new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
// Therefore, we instead just read via byte buffer // Therefore, we instead just read via byte buffer
numBytes := random.Uint32() % 33 // [0, 32] bytes numBytes := random.Uint32() % 33 // [0, 32] bytes
balanceBytes := make([]byte, numBytes) balanceBytes := make([]byte, numBytes)
random.Read(balanceBytes) random.Read(balanceBytes)
balance := new(big.Int).SetBytes(balanceBytes) balance := new(big.Int).SetBytes(balanceBytes)
acct := &types.StateAccount{Nonce: nonce, Balance: balance, Root: root, CodeHash: code} data, _ := rlp.EncodeToBytes(&types.StateAccount{Nonce: nonce, Balance: balance, Root: root, CodeHash: code})
data, _ := rlp.EncodeToBytes(acct)
accounts[i] = data accounts[i] = data
} }
return addresses, accounts return addresses, accounts
} }
func checkValue(t *testing.T, tr *trie.Trie, key []byte) { func getString(trie *Trie, k string) []byte {
val, err := tr.TryGet(key) return trie.Get([]byte(k))
if err != nil { }
t.Fatalf("error getting node: %s", err)
} func updateString(trie *Trie, k, v string) {
if len(val) == 0 { trie.Update([]byte(k), []byte(v))
t.Errorf("failed to get value for %x", key) }
func deleteString(trie *Trie, k string) {
trie.Delete([]byte(k))
}
func TestDecodeNode(t *testing.T) {
t.Parallel()
var (
hash = make([]byte, 20)
elems = make([]byte, 20)
)
for i := 0; i < 5000000; i++ {
prng.Read(hash)
prng.Read(elems)
decodeNode(hash, elems)
} }
} }

View File

@ -1,43 +1,133 @@
package trie_test package trie
import ( import (
"bytes"
"context"
"fmt" "fmt"
"math/big" "math/big"
"math/rand" "math/rand"
"testing" "testing"
"time"
"github.com/jmoiron/sqlx"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
geth_state "github.com/ethereum/go-ethereum/core/state" gethstate "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
geth_trie "github.com/ethereum/go-ethereum/trie" gethtrie "github.com/ethereum/go-ethereum/trie"
"github.com/jmoiron/sqlx"
pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0" pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/helper"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/state"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
"github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres" "github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres"
"github.com/ethereum/go-ethereum/statediff/indexer/ipld"
"github.com/ethereum/go-ethereum/statediff/test_helpers" "github.com/ethereum/go-ethereum/statediff/test_helpers"
"github.com/cerc-io/ipld-eth-statedb/internal"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/helper"
) )
type kv struct { var (
dbConfig, _ = postgres.DefaultConfig.WithEnv()
trieConfig = Config{Cache: 256}
)
type kvi struct {
k []byte k []byte
v int64 v int64
} }
type kvMap map[string]*kv type kvMap map[string]*kvi
type kvs struct { type kvsi struct {
k string k string
v int64 v int64
} }
// NewAccountTrie is a shortcut to create a trie using the StateTrieCodec (ie. IPLD MEthStateTrie codec).
func NewAccountTrie(id *ID, db NodeReader) (*Trie, error) {
return New(id, db, StateTrieCodec)
}
// makeTestTrie create a sample test trie to test node-wise reconstruction.
func makeTestTrie(t testing.TB) (*Database, *StateTrie, map[string][]byte) {
// Create an empty trie
triedb := NewDatabase(rawdb.NewMemoryDatabase())
trie, err := NewStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec)
if err != nil {
t.Fatal(err)
}
// Fill it with some arbitrary data
content := make(map[string][]byte)
for i := byte(0); i < 255; i++ {
// Map the same data under multiple keys
key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i}
content[string(key)] = val
trie.Update(key, val)
key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i}
content[string(key)] = val
trie.Update(key, val)
// Add some other data to inflate the trie
for j := byte(3); j < 13; j++ {
key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i}
content[string(key)] = val
trie.Update(key, val)
}
}
root, nodes := trie.Commit(false)
if err := triedb.Update(NewWithNodeSet(nodes)); err != nil {
panic(fmt.Errorf("failed to commit db %v", err))
}
// Re-create the trie based on the new state
trie, err = NewStateTrie(TrieID(root), triedb, StateTrieCodec)
if err != nil {
t.Fatal(err)
}
return triedb, trie, content
}
func forHashedNodes(tr *Trie) map[string][]byte {
var (
it = tr.NodeIterator(nil)
nodes = make(map[string][]byte)
)
for it.Next(true) {
if it.Hash() == (common.Hash{}) {
continue
}
nodes[string(it.Path())] = common.CopyBytes(it.NodeBlob())
}
return nodes
}
func diffTries(trieA, trieB *Trie) (map[string][]byte, map[string][]byte, map[string][]byte) {
var (
nodesA = forHashedNodes(trieA)
nodesB = forHashedNodes(trieB)
inA = make(map[string][]byte) // hashed nodes in trie a but not b
inB = make(map[string][]byte) // hashed nodes in trie b but not a
both = make(map[string][]byte) // hashed nodes in both tries but different value
)
for path, blobA := range nodesA {
if blobB, ok := nodesB[path]; ok {
if bytes.Equal(blobA, blobB) {
continue
}
both[path] = blobA
continue
}
inA[path] = blobA
}
for path, blobB := range nodesB {
if _, ok := nodesA[path]; ok {
continue
}
inB[path] = blobB
}
return inA, inB, both
}
func packValue(val int64) []byte { func packValue(val int64) []byte {
acct := &types.StateAccount{ acct := &types.StateAccount{
Balance: big.NewInt(val), Balance: big.NewInt(val),
@ -51,27 +141,19 @@ func packValue(val int64) []byte {
return acct_rlp return acct_rlp
} }
func unpackValue(val []byte) int64 { func updateTrie(tr *gethtrie.Trie, vals []kvsi) (kvMap, error) {
var acct types.StateAccount
if err := rlp.DecodeBytes(val, &acct); err != nil {
panic(err)
}
return acct.Balance.Int64()
}
func updateTrie(tr *geth_trie.Trie, vals []kvs) (kvMap, error) {
all := kvMap{} all := kvMap{}
for _, val := range vals { for _, val := range vals {
all[string(val.k)] = &kv{[]byte(val.k), val.v} all[string(val.k)] = &kvi{[]byte(val.k), val.v}
tr.Update([]byte(val.k), packValue(val.v)) tr.Update([]byte(val.k), packValue(val.v))
} }
return all, nil return all, nil
} }
func commitTrie(t testing.TB, db *geth_trie.Database, tr *geth_trie.Trie) common.Hash { func commitTrie(t testing.TB, db *gethtrie.Database, tr *gethtrie.Trie) common.Hash {
t.Helper() t.Helper()
root, nodes := tr.Commit(false) root, nodes := tr.Commit(false)
if err := db.Update(geth_trie.NewWithNodeSet(nodes)); err != nil { if err := db.Update(gethtrie.NewWithNodeSet(nodes)); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := db.Commit(root, false); err != nil { if err := db.Commit(root, false); err != nil {
@ -80,16 +162,8 @@ func commitTrie(t testing.TB, db *geth_trie.Database, tr *geth_trie.Trie) common
return root return root
} }
// commit a LevelDB state trie, index to IPLD and return new trie func makePgIpfsEthDB(t testing.TB) ethdb.Database {
func indexTrie(t testing.TB, edb ethdb.Database, root common.Hash) *trie.Trie { pg_db, err := postgres.ConnectSQLX(context.Background(), dbConfig)
t.Helper()
dbConfig.Driver = postgres.PGX
err := helper.IndexChain(dbConfig, geth_state.NewDatabase(edb), common.Hash{}, root)
if err != nil {
t.Fatal(err)
}
pg_db, err := postgres.ConnectSQLX(ctx, dbConfig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -98,62 +172,48 @@ func indexTrie(t testing.TB, edb ethdb.Database, root common.Hash) *trie.Trie {
t.Fatal(err) t.Fatal(err)
} }
}) })
return pgipfsethdb.NewDatabase(pg_db, internal.MakeCacheConfig(t))
}
ipfs_db := pgipfsethdb.NewDatabase(pg_db, makeCacheConfig(t)) // commit a LevelDB state trie, index to IPLD and return new trie
sdb_db := state.NewDatabase(ipfs_db) func indexTrie(t testing.TB, edb ethdb.Database, root common.Hash) *Trie {
tr, err := newStateTrie(root, sdb_db.TrieDB()) t.Helper()
dbConfig.Driver = postgres.PGX
err := helper.IndexStateDiff(dbConfig, gethstate.NewDatabase(edb), common.Hash{}, root)
if err != nil {
t.Fatal(err)
}
ipfs_db := makePgIpfsEthDB(t)
tr, err := New(TrieID(root), NewDatabase(ipfs_db), StateTrieCodec)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
return tr return tr
} }
func newStateTrie(root common.Hash, db *trie.Database) (*trie.Trie, error) {
tr, err := trie.New(common.Hash{}, root, db, ipld.MEthStateTrie)
if err != nil {
return nil, err
}
return tr, nil
}
// generates a random Geth LevelDB trie of n key-value pairs and corresponding value map // generates a random Geth LevelDB trie of n key-value pairs and corresponding value map
func randomGethTrie(n int, db *geth_trie.Database) (*geth_trie.Trie, kvMap) { func randomGethTrie(n int, db *gethtrie.Database) (*gethtrie.Trie, kvMap) {
trie := geth_trie.NewEmpty(db) trie := gethtrie.NewEmpty(db)
var vals []*kv var vals []*kvi
for i := byte(0); i < 100; i++ { for i := byte(0); i < 100; i++ {
e := &kv{common.LeftPadBytes([]byte{i}, 32), int64(i)} e := &kvi{common.LeftPadBytes([]byte{i}, 32), int64(i)}
e2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), int64(i)} e2 := &kvi{common.LeftPadBytes([]byte{i + 10}, 32), int64(i)}
vals = append(vals, e, e2) vals = append(vals, e, e2)
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
k := randBytes(32) k := randBytes(32)
v := rand.Int63() v := rand.Int63()
vals = append(vals, &kv{k, v}) vals = append(vals, &kvi{k, v})
} }
all := kvMap{} all := kvMap{}
for _, val := range vals { for _, val := range vals {
all[string(val.k)] = &kv{[]byte(val.k), val.v} all[string(val.k)] = &kvi{[]byte(val.k), val.v}
trie.Update([]byte(val.k), packValue(val.v)) trie.Update([]byte(val.k), packValue(val.v))
} }
return trie, all return trie, all
} }
// generates a random IPLD-indexed trie
func randomTrie(t testing.TB, n int) (*trie.Trie, kvMap) {
edb := rawdb.NewMemoryDatabase()
db := geth_trie.NewDatabase(edb)
orig, vals := randomGethTrie(n, db)
root := commitTrie(t, db, orig)
trie := indexTrie(t, edb, root)
return trie, vals
}
func randBytes(n int) []byte {
r := make([]byte, n)
rand.Read(r)
return r
}
// TearDownDB is used to tear down the watcher dbs after tests // TearDownDB is used to tear down the watcher dbs after tests
func TearDownDB(db *sqlx.DB) error { func TearDownDB(db *sqlx.DB) error {
tx, err := db.Beginx() tx, err := db.Beginx()
@ -179,12 +239,3 @@ func TearDownDB(db *sqlx.DB) error {
} }
return tx.Commit() return tx.Commit()
} }
// returns a cache config with unique name (groupcache names are global)
func makeCacheConfig(t testing.TB) pgipfsethdb.CacheConfig {
return pgipfsethdb.CacheConfig{
Name: t.Name(),
Size: 3000000, // 3MB
ExpiryDuration: time.Hour,
}
}